mirror of
https://github.com/SChernykh/p2pool.git
synced 2025-01-08 03:39:24 +00:00
Stratum: added TLS support
This commit is contained in:
parent
cb8ea37dab
commit
b0ac084ab2
8 changed files with 461 additions and 51 deletions
|
@ -17,6 +17,7 @@ option(WITH_RANDOMX "Include the RandomX library in the build. If this is turned
|
|||
option(WITH_LTO "Use link-time compiler optimization (if linking fails for you, run cmake with -DWITH_LTO=OFF)" ON)
|
||||
option(WITH_UPNP "Include UPnP support. If this is turned off, p2pool will not be able to configure port forwarding on UPnP-enabled routers." ON)
|
||||
option(WITH_GRPC "Include gRPC support. If this is turned off, p2pool will not be able to merge mine with Tari." ON)
|
||||
option(WITH_TLS "Include TLS support. If this is turned off, p2pool will not support Stratum TLS connections." ON)
|
||||
|
||||
option(DEV_TEST_SYNC "[Developer only] Sync test, stop p2pool after sync is complete" OFF)
|
||||
option(DEV_WITH_TSAN "[Developer only] Compile with thread sanitizer" OFF)
|
||||
|
@ -166,6 +167,13 @@ if (WITH_GRPC)
|
|||
set(SOURCES ${SOURCES} src/merge_mining_client_tari.cpp)
|
||||
endif()
|
||||
|
||||
if (WITH_GRPC AND WITH_TLS)
|
||||
add_definitions(-DWITH_TLS)
|
||||
|
||||
set(HEADERS ${HEADERS} src/tls.h)
|
||||
set(SOURCES ${SOURCES} src/tls.cpp)
|
||||
endif()
|
||||
|
||||
source_group(TREE "${CMAKE_CURRENT_SOURCE_DIR}" PREFIX "Header Files" FILES ${HEADERS})
|
||||
source_group(TREE "${CMAKE_CURRENT_SOURCE_DIR}" PREFIX "Source Files" FILES ${SOURCES})
|
||||
|
||||
|
|
|
@ -35,5 +35,6 @@ add_definitions(-DPROTOBUF_ENABLE_DEBUG_LOGGING_MAY_LEAK_PII=0)
|
|||
add_subdirectory(external/src/grpc)
|
||||
|
||||
include_directories(external/src/grpc/third_party/abseil-cpp)
|
||||
include_directories(external/src/grpc/third_party/boringssl-with-bazel/src/include)
|
||||
include_directories(external/src/grpc/third_party/protobuf/src)
|
||||
include_directories(external/src/grpc/include)
|
||||
|
|
|
@ -272,7 +272,7 @@ bool StratumServer::on_login(StratumClient* client, uint32_t id, const char* log
|
|||
target = std::max(target, aux_diff.target());
|
||||
|
||||
if (get_custom_diff(login, client->m_customDiff)) {
|
||||
LOGINFO(5, "client " << log::Gray() << static_cast<char*>(client->m_addrString) << " set custom difficulty " << client->m_customDiff);
|
||||
LOGINFO(5, "client " << log::Gray() << static_cast<char*>(client->m_addrString) << log::NoColor() << " set custom difficulty " << client->m_customDiff);
|
||||
target = std::max(target, client->m_customDiff.target());
|
||||
}
|
||||
else if (m_autoDiff) {
|
||||
|
@ -282,7 +282,7 @@ bool StratumServer::on_login(StratumClient* client, uint32_t id, const char* log
|
|||
|
||||
if (get_custom_user(login, client->m_customUser)) {
|
||||
const char* s = client->m_customUser;
|
||||
LOGINFO(5, "client " << log::Gray() << static_cast<char*>(client->m_addrString) << " set custom user " << s);
|
||||
LOGINFO(5, "client " << log::Gray() << static_cast<char*>(client->m_addrString) << log::NoColor() << " set custom user " << s);
|
||||
}
|
||||
|
||||
uint32_t job_id;
|
||||
|
@ -541,6 +541,7 @@ void StratumServer::show_workers()
|
|||
size_t n = 0;
|
||||
|
||||
LOGINFO(0, log::pad_right("IP:port", addr_len + 8)
|
||||
<< "TLS "
|
||||
<< log::pad_right("uptime", 20)
|
||||
<< log::pad_right("difficulty", 20)
|
||||
<< log::pad_right("hashrate", 15)
|
||||
|
@ -558,6 +559,7 @@ void StratumServer::show_workers()
|
|||
}
|
||||
}
|
||||
LOGINFO(0, log::pad_right(static_cast<const char*>(c->m_addrString), addr_len + 8)
|
||||
<< (c->m_tls.is_empty() ? "no " : "yes ")
|
||||
<< log::pad_right(log::Duration(cur_time - c->m_connectedTime), 20)
|
||||
<< log::pad_right(diff, 20)
|
||||
<< log::pad_right(log::Hashrate(c->m_autoDiff.lo / AUTO_DIFF_TARGET_TIME, m_autoDiff && (c->m_autoDiff != 0)), 15)
|
||||
|
@ -1080,7 +1082,8 @@ void StratumServer::on_shutdown()
|
|||
}
|
||||
|
||||
StratumServer::StratumClient::StratumClient()
|
||||
: Client(m_stratumReadBuf, sizeof(m_stratumReadBuf))
|
||||
: Client(m_rawReadBuf, sizeof(m_rawReadBuf))
|
||||
, m_stratumReadBufBytes(0)
|
||||
, m_rpcId(0)
|
||||
, m_perConnectionJobId(0)
|
||||
, m_connectedTime(0)
|
||||
|
@ -1100,6 +1103,10 @@ StratumServer::StratumClient::StratumClient()
|
|||
void StratumServer::StratumClient::reset()
|
||||
{
|
||||
Client::reset();
|
||||
|
||||
m_stratumReadBuf[0] = '\0';
|
||||
m_stratumReadBufBytes = 0;
|
||||
|
||||
m_rpcId = 0;
|
||||
m_perConnectionJobId = 0;
|
||||
m_connectedTime = 0;
|
||||
|
@ -1127,35 +1134,69 @@ bool StratumServer::StratumClient::on_connect()
|
|||
|
||||
bool StratumServer::StratumClient::on_read(char* data, uint32_t size)
|
||||
{
|
||||
if ((data != m_readBuf + m_numRead) || (data + size > m_readBuf + m_readBufSize)) {
|
||||
LOGERR(1, "client: invalid data pointer or size in on_read()");
|
||||
ban(DEFAULT_BAN_TIME);
|
||||
return false;
|
||||
}
|
||||
|
||||
m_numRead += size;
|
||||
|
||||
char* line_start = m_readBuf;
|
||||
for (char* c = data; c < m_readBuf + m_numRead; ++c) {
|
||||
if (*c == '\n') {
|
||||
*c = '\0';
|
||||
if (!process_request(line_start, static_cast<uint32_t>(c - line_start))) {
|
||||
ban(DEFAULT_BAN_TIME);
|
||||
#ifdef WITH_TLS
|
||||
if (!m_tlsChecked) {
|
||||
if (data[0] == 0x16) {
|
||||
if (!m_tls.init()) {
|
||||
LOGWARN(5, "client " << static_cast<const char*>(m_addrString) << ": TLS init failed");
|
||||
return false;
|
||||
}
|
||||
line_start = c + 1;
|
||||
LOGINFO(5, "client " << log::Gray() << static_cast<const char*>(m_addrString) << log::NoColor() << " is using TLS");
|
||||
}
|
||||
m_tlsChecked = true;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Move the possible unfinished line to the beginning of m_readBuf to free up more space for reading
|
||||
if (line_start != m_readBuf) {
|
||||
m_numRead = static_cast<uint32_t>(m_readBuf + m_numRead - line_start);
|
||||
if (m_numRead > 0) {
|
||||
memmove(m_readBuf, line_start, m_numRead);
|
||||
auto on_parse = [this](char* data, uint32_t size) {
|
||||
if (static_cast<size_t>(m_stratumReadBufBytes) + size > STRATUM_BUF_SIZE) {
|
||||
LOGWARN(4, "client " << static_cast<const char*>(m_addrString) << " sent too long Stratum message");
|
||||
ban(DEFAULT_BAN_TIME);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
memcpy(m_stratumReadBuf + m_stratumReadBufBytes, data, size);
|
||||
m_stratumReadBufBytes += size;
|
||||
|
||||
char* line_start = m_stratumReadBuf;
|
||||
for (char *e = line_start + m_stratumReadBufBytes, *c = e - size; c < e; ++c) {
|
||||
if (*c == '\n') {
|
||||
*c = '\0';
|
||||
if (!process_request(line_start, static_cast<uint32_t>(c - line_start))) {
|
||||
ban(DEFAULT_BAN_TIME);
|
||||
return false;
|
||||
}
|
||||
line_start = c + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Move the possible unfinished line to the beginning of m_stratumReadBuf to free up more space for reading
|
||||
if (line_start != m_stratumReadBuf) {
|
||||
m_stratumReadBufBytes = static_cast<uint32_t>(m_stratumReadBuf + m_stratumReadBufBytes - line_start);
|
||||
if (m_stratumReadBufBytes > 0) {
|
||||
memmove(m_stratumReadBuf, line_start, m_stratumReadBufBytes);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
#ifdef WITH_TLS
|
||||
if (!m_tls.is_empty()) {
|
||||
auto on_write = [this](const uint8_t* data, size_t size) {
|
||||
return m_owner->send(this, [data, size](uint8_t* buf, size_t buf_size) -> size_t {
|
||||
if (buf_size < size) {
|
||||
return 0;
|
||||
}
|
||||
memcpy(buf, data, size);
|
||||
return size;
|
||||
}, true);
|
||||
};
|
||||
|
||||
return m_tls.on_read(data, size, std::move(on_parse), std::move(on_write));
|
||||
}
|
||||
#endif
|
||||
|
||||
return on_parse(data, size);
|
||||
}
|
||||
|
||||
bool StratumServer::StratumClient::process_request(char* data, uint32_t size)
|
||||
|
|
|
@ -52,7 +52,10 @@ public:
|
|||
[[nodiscard]] bool process_login(rapidjson::Document& doc, uint32_t id);
|
||||
[[nodiscard]] bool process_submit(rapidjson::Document& doc, uint32_t id);
|
||||
|
||||
alignas(8) char m_rawReadBuf[STRATUM_BUF_SIZE];
|
||||
|
||||
alignas(8) char m_stratumReadBuf[STRATUM_BUF_SIZE];
|
||||
uint32_t m_stratumReadBufBytes;
|
||||
|
||||
uint32_t m_rpcId;
|
||||
uint32_t m_perConnectionJobId;
|
||||
|
|
|
@ -538,7 +538,7 @@ void TCPServer::print_bans()
|
|||
}
|
||||
}
|
||||
|
||||
bool TCPServer::send_internal(Client* client, Callback<size_t, uint8_t*, size_t>::Base&& callback)
|
||||
bool TCPServer::send_internal(Client* client, Callback<size_t, uint8_t*, size_t>::Base&& callback, bool raw)
|
||||
{
|
||||
check_event_loop_thread(__func__);
|
||||
|
||||
|
@ -559,34 +559,48 @@ bool TCPServer::send_internal(Client* client, Callback<size_t, uint8_t*, size_t>
|
|||
return false;
|
||||
}
|
||||
|
||||
WriteBuf* buf = get_write_buffer(bytes_written);
|
||||
auto on_write = [this, client](const uint8_t* data, size_t size) {
|
||||
WriteBuf* buf = get_write_buffer(size);
|
||||
|
||||
buf->m_write.data = buf;
|
||||
buf->m_client = client;
|
||||
buf->m_write.data = buf;
|
||||
buf->m_client = client;
|
||||
|
||||
if (buf->m_dataCapacity < bytes_written) {
|
||||
buf->m_dataCapacity = round_up(bytes_written, 64);
|
||||
buf->m_data = realloc_hook(buf->m_data, buf->m_dataCapacity);
|
||||
if (!buf->m_data) {
|
||||
LOGERR(0, "failed to allocate " << buf->m_dataCapacity << " bytes to send data");
|
||||
PANIC_STOP();
|
||||
if (buf->m_dataCapacity < size) {
|
||||
buf->m_dataCapacity = round_up(size, 64);
|
||||
buf->m_data = realloc_hook(buf->m_data, buf->m_dataCapacity);
|
||||
if (!buf->m_data) {
|
||||
LOGERR(0, "failed to allocate " << buf->m_dataCapacity << " bytes to send data");
|
||||
PANIC_STOP();
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(buf->m_data, data, size);
|
||||
|
||||
uv_buf_t bufs[1];
|
||||
bufs[0].base = reinterpret_cast<char*>(buf->m_data);
|
||||
bufs[0].len = static_cast<int>(size);
|
||||
|
||||
const int err = uv_write(&buf->m_write, reinterpret_cast<uv_stream_t*>(&client->m_socket), bufs, 1, Client::on_write);
|
||||
if (err) {
|
||||
LOGWARN(1, "failed to start writing data to client connection " << static_cast<const char*>(client->m_addrString) << ", error " << uv_err_name(err));
|
||||
return_write_buffer(buf);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
#ifdef WITH_TLS
|
||||
if (!client->m_tls.is_empty() && !raw) {
|
||||
if (!client->m_tls.on_write(m_callbackBuf.data(), bytes_written, std::move(on_write))) {
|
||||
LOGWARN(1, "TLS write failed to client connection " << static_cast<const char*>(client->m_addrString));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
memcpy(buf->m_data, m_callbackBuf.data(), bytes_written);
|
||||
|
||||
uv_buf_t bufs[1];
|
||||
bufs[0].base = reinterpret_cast<char*>(buf->m_data);
|
||||
bufs[0].len = static_cast<int>(bytes_written);
|
||||
|
||||
const int err = uv_write(&buf->m_write, reinterpret_cast<uv_stream_t*>(&client->m_socket), bufs, 1, Client::on_write);
|
||||
if (err) {
|
||||
LOGWARN(1, "failed to start writing data to client connection " << static_cast<const char*>(client->m_addrString) << ", error " << uv_err_name(err));
|
||||
return_write_buffer(buf);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return on_write(m_callbackBuf.data(), bytes_written);
|
||||
}
|
||||
|
||||
const char* TCPServer::get_log_category() const
|
||||
|
@ -999,6 +1013,7 @@ TCPServer::Client::Client(char* read_buf, size_t size)
|
|||
, m_addrString{}
|
||||
, m_socks5ProxyState(Socks5ProxyState::Default)
|
||||
, m_resetCounter{ 0 }
|
||||
, m_tlsChecked(false)
|
||||
{
|
||||
m_readBuf[0] = '\0';
|
||||
m_readBuf[m_readBufSize - 1] = '\0';
|
||||
|
@ -1023,6 +1038,11 @@ void TCPServer::Client::reset()
|
|||
m_socks5ProxyState = Socks5ProxyState::Default;
|
||||
m_readBuf[0] = '\0';
|
||||
m_readBuf[m_readBufSize - 1] = '\0';
|
||||
|
||||
#ifdef WITH_TLS
|
||||
m_tls.reset();
|
||||
m_tlsChecked = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
void TCPServer::Client::on_alloc(uv_handle_t* handle, size_t /*suggested_size*/, uv_buf_t* buf)
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#pragma once
|
||||
|
||||
#include "uv_util.h"
|
||||
#include "tls.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace p2pool {
|
||||
|
@ -106,6 +108,11 @@ public:
|
|||
} m_socks5ProxyState;
|
||||
|
||||
std::atomic<uint32_t> m_resetCounter;
|
||||
|
||||
#ifdef WITH_TLS
|
||||
ServerTls m_tls;
|
||||
bool m_tlsChecked;
|
||||
#endif
|
||||
};
|
||||
|
||||
struct WriteBuf
|
||||
|
@ -128,7 +135,7 @@ public:
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
[[nodiscard]] FORCEINLINE bool send(Client* client, T&& callback) { return send_internal(client, Callback<size_t, uint8_t*, size_t>::Derived<T>(std::move(callback))); }
|
||||
[[nodiscard]] FORCEINLINE bool send(Client* client, T&& callback, bool raw = false) { return send_internal(client, Callback<size_t, uint8_t*, size_t>::Derived<T>(std::move(callback)), raw); }
|
||||
|
||||
private:
|
||||
static void on_new_connection(uv_stream_t* server, int status);
|
||||
|
@ -138,7 +145,7 @@ private:
|
|||
void on_new_client(uv_stream_t* server);
|
||||
void on_new_client(uv_stream_t* server, Client* client);
|
||||
|
||||
[[nodiscard]] bool send_internal(Client* client, Callback<size_t, uint8_t*, size_t>::Base&& callback);
|
||||
[[nodiscard]] bool send_internal(Client* client, Callback<size_t, uint8_t*, size_t>::Base&& callback, bool raw);
|
||||
|
||||
allocate_client_callback m_allocateNewClient;
|
||||
|
||||
|
|
272
src/tls.cpp
Normal file
272
src/tls.cpp
Normal file
|
@ -0,0 +1,272 @@
|
|||
/*
|
||||
* This file is part of the Monero P2Pool <https://github.com/SChernykh/p2pool>
|
||||
* Copyright (c) 2021-2024 SChernykh <https://github.com/SChernykh>
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, version 3.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful, but
|
||||
* WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#include "common.h"
|
||||
#include "tls.h"
|
||||
|
||||
LOG_CATEGORY(TLS)
|
||||
|
||||
namespace p2pool {
|
||||
|
||||
static bssl::UniquePtr<EVP_PKEY> init_evp_pkey()
|
||||
{
|
||||
bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
|
||||
|
||||
if (!evp_pkey.get()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bssl::UniquePtr<EC_KEY> ec_key(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
|
||||
|
||||
if (!ec_key || !EC_KEY_generate_key(ec_key.get())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!EVP_PKEY_assign_EC_KEY(evp_pkey.get(), ec_key.release())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return evp_pkey;
|
||||
}
|
||||
|
||||
static bssl::UniquePtr<EVP_PKEY> s_evp_pkey = init_evp_pkey();
|
||||
|
||||
static bssl::UniquePtr<X509> init_cert()
|
||||
{
|
||||
bssl::UniquePtr<X509> x509(X509_new());
|
||||
|
||||
if (!x509.get()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!X509_set_version(x509.get(), X509_VERSION_3)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::mt19937_64 rng(RandomDeviceSeed::instance);
|
||||
rng.discard(10000);
|
||||
|
||||
const uint64_t serial = rng();
|
||||
|
||||
if (!ASN1_INTEGER_set_uint64(X509_get_serialNumber(x509.get()), serial)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
constexpr int64_t YEAR = 31557600;
|
||||
|
||||
const time_t cur_time = time(nullptr);
|
||||
|
||||
const time_t t0 = cur_time - (cur_time % YEAR);
|
||||
const time_t t1 = t0 - YEAR * 10;
|
||||
const time_t t2 = t0 + YEAR * 10;
|
||||
|
||||
if (!ASN1_TIME_set(X509_get_notBefore(x509.get()), t1) || !ASN1_TIME_set(X509_get_notAfter(x509.get()), t2)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
X509_NAME* subject = X509_get_subject_name(x509.get());
|
||||
|
||||
if (!X509_NAME_add_entry_by_txt(subject, "C", MBSTRING_ASC, reinterpret_cast<const uint8_t*>("US"), -1, -1, 0) ||
|
||||
!X509_NAME_add_entry_by_txt(subject, "O", MBSTRING_ASC, reinterpret_cast<const uint8_t*>("BoringSSL"), -1, -1, 0) ||
|
||||
!X509_set_issuer_name(x509.get(), subject)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bssl::UniquePtr<STACK_OF(ASN1_OBJECT)> ekus(sk_ASN1_OBJECT_new_null());
|
||||
|
||||
if (!ekus || !sk_ASN1_OBJECT_push(ekus.get(), OBJ_nid2obj(NID_server_auth)) || !X509_add1_ext_i2d(x509.get(), NID_ext_key_usage, ekus.get(), 1, 0)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!X509_set_pubkey(x509.get(), s_evp_pkey.get())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!X509_sign(x509.get(), s_evp_pkey.get(), EVP_sha256())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return x509;
|
||||
}
|
||||
|
||||
static bssl::UniquePtr<X509> s_cert = init_cert();
|
||||
|
||||
static bssl::UniquePtr<SSL_CTX> init_ctx()
|
||||
{
|
||||
if (!s_evp_pkey.get() || !s_cert.get()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
|
||||
|
||||
if (!ctx.get()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!SSL_CTX_use_PrivateKey(ctx.get(), s_evp_pkey.get())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!SSL_CTX_use_certificate(ctx.get(), s_cert.get())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
static bssl::UniquePtr<SSL_CTX> s_ctx = init_ctx();
|
||||
|
||||
void ServerTls::reset()
|
||||
{
|
||||
m_ssl.reset(nullptr);
|
||||
}
|
||||
|
||||
bool ServerTls::init()
|
||||
{
|
||||
if (!s_ctx.get()) {
|
||||
static std::atomic<uint32_t> ctx_error_shown = 0;
|
||||
if (ctx_error_shown.exchange(1) == 0) {
|
||||
LOGERR(0, "Failed to initialize an SSL context");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
m_ssl.reset(SSL_new(s_ctx.get()));
|
||||
|
||||
if (!m_ssl.get()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
SSL_set_accept_state(m_ssl.get());
|
||||
|
||||
BIO* rbio = BIO_new(BIO_s_mem());
|
||||
BIO* wbio = BIO_new(BIO_s_mem());
|
||||
|
||||
if (!rbio || !wbio) {
|
||||
BIO_free(rbio);
|
||||
BIO_free(wbio);
|
||||
|
||||
m_ssl.reset(nullptr);
|
||||
return false;
|
||||
}
|
||||
|
||||
SSL_set_bio(m_ssl.get(), rbio, wbio);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ServerTls::on_read_internal(char* data, uint32_t size, ReadCallback::Base&& read_callback, WriteCallback::Base&& write_callback)
|
||||
{
|
||||
SSL* ssl = m_ssl.get();
|
||||
if (!ssl) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!BIO_write_all(SSL_get_rbio(ssl), data, size)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int bytes_read = 0;
|
||||
char buf[1024];
|
||||
|
||||
if (!SSL_is_init_finished(ssl)) {
|
||||
const int result = SSL_do_handshake(ssl);
|
||||
|
||||
if (!result) {
|
||||
// EOF
|
||||
return false;
|
||||
}
|
||||
|
||||
// Send pending handshake data, if any
|
||||
BIO* wbio = SSL_get_wbio(ssl);
|
||||
if (!wbio) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint8_t* bio_data;
|
||||
size_t bio_len;
|
||||
|
||||
if (!BIO_mem_contents(wbio, &bio_data, &bio_len)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (bio_len > 0) {
|
||||
if (!write_callback(bio_data, bio_len)) {
|
||||
return false;
|
||||
}
|
||||
if (!BIO_reset(wbio)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if ((result < 0) && (SSL_get_error(ssl, result) == SSL_ERROR_WANT_READ)) {
|
||||
// Continue handshake, nothing to read yet
|
||||
return true;
|
||||
}
|
||||
else if (result == 1) {
|
||||
// Handshake finished, skip to "SSL_read" further down
|
||||
}
|
||||
else {
|
||||
// Some other error
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
while ((bytes_read = SSL_read(ssl, buf, sizeof(buf))) > 0) {
|
||||
if (!read_callback(buf, static_cast<uint32_t>(bytes_read))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ServerTls::on_write_internal(const uint8_t* data, size_t size, WriteCallback::Base&& write_callback)
|
||||
{
|
||||
SSL* ssl = m_ssl.get();
|
||||
if (!ssl) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (SSL_write(ssl, data, static_cast<int>(size)) <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
BIO* wbio = SSL_get_wbio(ssl);
|
||||
if (!wbio) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint8_t* bio_data;
|
||||
size_t bio_len;
|
||||
|
||||
if (!BIO_mem_contents(wbio, &bio_data, &bio_len)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (bio_len > 0) {
|
||||
if (!write_callback(bio_data, bio_len)) {
|
||||
return false;
|
||||
}
|
||||
if (!BIO_reset(wbio)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace p2pool
|
58
src/tls.h
Normal file
58
src/tls.h
Normal file
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* This file is part of the Monero P2Pool <https://github.com/SChernykh/p2pool>
|
||||
* Copyright (c) 2021-2024 SChernykh <https://github.com/SChernykh>
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, version 3.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful, but
|
||||
* WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||||
* General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openssl/base.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
namespace p2pool {
|
||||
|
||||
class ServerTls
|
||||
{
|
||||
public:
|
||||
FORCEINLINE ServerTls() { reset(); }
|
||||
|
||||
void reset();
|
||||
[[nodiscard]] bool init();
|
||||
|
||||
template<typename T, typename U>
|
||||
[[nodiscard]] FORCEINLINE bool on_read(char* data, uint32_t size, T&& read_callback, U&& write_callback)
|
||||
{
|
||||
return on_read_internal(data, size, ReadCallback::Derived<T>(std::move(read_callback)), WriteCallback::Derived<U>(std::move(write_callback)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
[[nodiscard]] FORCEINLINE bool on_write(const uint8_t* data, size_t size, T&& write_callback)
|
||||
{
|
||||
return on_write_internal(data, size, WriteCallback::Derived<T>(std::move(write_callback)));
|
||||
}
|
||||
|
||||
[[nodiscard]] FORCEINLINE bool is_empty() const { return m_ssl.get() == nullptr; }
|
||||
|
||||
private:
|
||||
typedef Callback<bool, char*, uint32_t> ReadCallback;
|
||||
typedef Callback<bool, const uint8_t*, size_t> WriteCallback;
|
||||
|
||||
[[nodiscard]] bool on_read_internal(char* data, uint32_t size, ReadCallback::Base&& read_callback, WriteCallback::Base&& write_callback);
|
||||
[[nodiscard]] bool on_write_internal(const uint8_t* data, size_t size, WriteCallback::Base&& write_callback);
|
||||
|
||||
private:
|
||||
bssl::UniquePtr<SSL> m_ssl;
|
||||
};
|
||||
|
||||
} // namespace p2pool
|
Loading…
Reference in a new issue