Thrift&SSL编程#08#编写基于Windows系统的测试代码(单向验证:客户端验证服务端)

关于Thrift和OpenSSL的安全通信,上篇我们描述了数字证书的生成方法,本文在此基础上编写单向验证的测试代码。

  • 01#编译OpenSSL库;
  • 02#编译Boost库;
  • 03#编译zlib库;
  • 04#编译libevent库;
  • 05#编译Thrift库;
  • 06#生成客户端和服务端通信所用的数字证书;
  • 07#编写基于Linux系统的测试代码(单向验证:客户端验证服务端);
  • 08#编写基于Windows系统的测试代码(单向验证:客户端验证服务端)
  • 09#编写基于Linux系统的测试代码(双向验证:客户端验证服务端+服务端验证客户端);
  • 10#编写基于Windows系统的测试代码(双向验证:客户端验证服务端+服务端验证客户端);
  • 11#自定义数字证书的验证策略;

数字证书单向验证原理简介

提到数字证书的验证,一般指的是:

  • CA机构使用其根公钥证书(ca.crt)签发服务端公钥证书(server.crt);
  • 服务端程序内置服务端私钥证书(server.key)和服务端公钥证书(server.crt);
  • 客户端程序内置根公钥证书(ca.crt);
  • 客户端程序启动后,向服务端请求证书,然后服务端将server.crt发送给客户端;
  • 客户端获得服务端公钥证书(server.crt)后,使用根公钥证书(ca.crt)验证server.crt的合法性;
  • 验证机制如果发现server.crtca.crt签发的,则继续SSL握手操作,如对称密钥交换等步骤;
  • 验证机制如果发现server.crtca.crt签发的,则终止SSL握手操作,终止本次通信;

上述步骤即是我们常见的单向验证,即客户端验证服务端的合法性,如果服务端是合法的,则进行SSL安全通信,否则就终止通信。

如果验证机制通过了server.crt的合法性校验,那服务端私钥证书server.key扮演了什么角色?

  • 我们知道SSL考虑到加密性能,在进行真实数据通信时,最终还得采用对称加密。
  • 客户端生成对称加密密钥(random_key),再由服务端公钥证书(server.crt)进行非对称加密后得到密文,然后将此密文发给服务端。
  • 服务端获得密文后,由服务端私钥证书(server.key)进行解密,得到对称加密密钥的明文(random_key)。
  • 后续的通信,双方均使用该密钥(random_key)进行加解密。

服务端程序

以下是SSL通信所用的服务端测试程序,从代码看其用到了server.crt和server.key两个证书。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <stdint.h>
#include <inttypes.h>
#include <signal.h>
#include <iostream>
#include <stdexcept>
#include <sstream>

#include <thrift/concurrency/ThreadFactory.h>
#include <thrift/concurrency/ThreadManager.h>
#include <thrift/transport/TSSLServerSocket.h>
#include <thrift/transport/TSSLSocket.h>
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/server/TSimpleServer.h>
#include "FirstService.h"

using namespace std;
using namespace apache::thrift::protocol;
using namespace apache::thrift::transport;
using namespace apache::thrift::server;
using namespace thrift::test;

apache::thrift::concurrency::Monitor gMonitor;
void signal_handler(int signum)
{
if (signum == SIGINT)
{
gMonitor.notifyAll();
}
}

class TestHandler : public FirstServiceIf
{
public:
TestHandler() = default;

void testString(string& out, const string& thing) override
{
printf("testString(\"%s\")\n", thing.c_str());
out = thing;
}
};

int main(int argc, char** argv)
{
std::string sServerCRT = "cert.v3/server.crt";
std::string sServerKEY = "cert.v3/server.key";

// Dispatcher
TBinaryProtocolFactoryT<TBufferBase>* binaryProtocolFactory = new TBinaryProtocolFactoryT<TBufferBase>();
binaryProtocolFactory->setContainerSizeLimit(0);
binaryProtocolFactory->setStringSizeLimit(0);

std::shared_ptr<TProtocolFactory> protocolFactory;
protocolFactory.reset(binaryProtocolFactory);

// Processors
std::shared_ptr<TestHandler> testHandler(new TestHandler());
std::shared_ptr<TProcessor> testProcessor(new FirstServiceProcessor(testHandler));

// Transport
std::shared_ptr<TSSLSocketFactory> sslSocketFactory;
sslSocketFactory = std::shared_ptr<TSSLSocketFactory>(new TSSLSocketFactory());
sslSocketFactory->loadCertificate(sServerCRT.c_str());
sslSocketFactory->loadPrivateKey(sServerKEY.c_str());
sslSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

std::shared_ptr<TServerSocket> serverSocket;
serverSocket = std::shared_ptr<TServerSocket>(new TSSLServerSocket(9090, sslSocketFactory));

// Factory
std::shared_ptr<TTransportFactory> transportFactory;
transportFactory = std::make_shared<TBufferedTransportFactory>();

// Server Info
cout << "Starting simple server (buffered/binary) listen on: " << serverSocket->getPort() << endl;

// Server
std::shared_ptr<apache::thrift::server::TServer> server;
server.reset(new TSimpleServer(testProcessor, serverSocket, transportFactory, protocolFactory));

if (server.get() != nullptr)
{
apache::thrift::concurrency::ThreadFactory factory;
factory.setDetached(false);
std::shared_ptr<apache::thrift::concurrency::Runnable> serverThreadRunner(server);
std::shared_ptr<apache::thrift::concurrency::Thread> thread = factory.newThread(serverThreadRunner);

signal(SIGINT, signal_handler);

thread->start();
gMonitor.waitForever(); // wait for a shutdown signal

signal(SIGINT, SIG_DFL);

server->stop();
thread->join();
server.reset();
}

cout << "done." << endl;

return 0;
}

客户端程序

以下是SSL通信所用的客户端测试程序,从代码看其用到了ca.crt一个证书。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include <iostream>
#include <sstream>
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/transport/TSSLSocket.h>
#include "FirstService.h"

using namespace apache::thrift::protocol;
using namespace apache::thrift::transport;
using namespace thrift::test;

///////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Proto>
class TPedanticProtocol : public Proto
{
public:
TPedanticProtocol(std::shared_ptr<TTransport>& transport) : Proto(transport),
m_last_seqid((std::numeric_limits<int32_t>::max)() - 10)
{

}

virtual uint32_t writeMessageBegin_virt(const std::string& name,
const TMessageType messageType, const int32_t in_seqid) override
{
int32_t seqid = in_seqid;
if (!seqid)
{
// this is typical for normal cpp generated code
seqid = ++m_last_seqid;
}

return Proto::writeMessageBegin_virt(name, messageType, seqid);
}

virtual uint32_t readMessageBegin_virt(std::string& name, TMessageType& messageType, int32_t& seqid) override
{
uint32_t result = Proto::readMessageBegin_virt(name, messageType, seqid);
if (seqid != m_last_seqid)
{
std::stringstream ss;
ss << "ERROR: send request with seqid " << m_last_seqid << " and got reply with seqid " << seqid;
throw std::logic_error(ss.str());
}

return result;
}

private:
int32_t m_last_seqid;
};

int main(int argc, char** argv)
{
std::string sCACertPath = "cert.v3/ca.crt";
std::string sServerHost = "127.0.0.1";

///
std::shared_ptr<TSSLSocketFactory> factory;
factory = std::shared_ptr<TSSLSocketFactory>(new TSSLSocketFactory());
factory->loadTrustedCertificates(sCACertPath.c_str());
factory->authenticate(true);
factory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

std::shared_ptr<TSocket> socket;
socket = factory->createSocket(sServerHost, 9090);

///
std::shared_ptr<TTransport> transport;
transport = std::make_shared<TBufferedTransport>(socket);

typedef TPedanticProtocol<TBinaryProtocol> TPedanticBinaryProtocol;
std::shared_ptr<TProtocol> protocol;
protocol = std::make_shared<TPedanticBinaryProtocol>(transport);

// Connection info
std::cout << "Connecting (buffered/binary" << ") to: " << sServerHost << ":" << socket->getPort() << std::endl;

FirstServiceClient testClient(protocol);

try
{
transport->open();
std::string s;
testClient.testString(s, "Test");
std::cout << "testString(\"Test\")" << " = " << s << std::endl;
if (s != "Test")
{
std::cout << "error" << std::endl;
}
}
catch (TTransportException& ex)
{
std::cout << "Connect failed: " << ex.what() << std::endl;

return 0;
}

transport->close();

return 0;
}

上述客户端代码和服务端代码,可以在如下两个地址下载: