Stratum: added TLS support
Some checks are pending
C/C++ CI / build-alpine-static (map[arch:aarch64 branch:latest-stable flags:-ffunction-sections -Wno-error=inline -mfix-cortex-a53-835769 -mfix-cortex-a53-843419]) (push) Waiting to run
C/C++ CI / build-alpine-static (map[arch:riscv64 branch:edge flags:-ffunction-sections -Wno-error=inline]) (push) Waiting to run
C/C++ CI / build-alpine-static (map[arch:x86_64 branch:latest-stable flags:-ffunction-sections -Wno-error=inline]) (push) Waiting to run
C/C++ CI / build-ubuntu (map[c:gcc-11 cpp:g++-11 flags: os:ubuntu-20.04]) (push) Waiting to run
C/C++ CI / build-ubuntu (map[c:gcc-12 cpp:g++-12 flags: os:ubuntu-22.04]) (push) Waiting to run
C/C++ CI / build-ubuntu (map[c:gcc-8 cpp:g++-8 flags: os:ubuntu-20.04]) (push) Waiting to run
C/C++ CI / build-ubuntu-static-libs (map[flags:-fuse-linker-plugin -ffunction-sections -Wno-error=inline]) (push) Waiting to run
C/C++ CI / build-ubuntu-aarch64 (map[flags:-fuse-linker-plugin -ffunction-sections -mfix-cortex-a53-835769 -mfix-cortex-a53-843419 os:ubuntu-20.04]) (push) Waiting to run
C/C++ CI / build-ubuntu-aarch64 (map[flags:-fuse-linker-plugin -ffunction-sections -mfix-cortex-a53-835769 -mfix-cortex-a53-843419 os:ubuntu-22.04]) (push) Waiting to run
C/C++ CI / build-windows-msys2 (map[c:clang cxx:clang++ flags:-fuse-ld=lld -Wno-unused-command-line-argument -Wno-nan-infinity-disabled]) (push) Waiting to run
C/C++ CI / build-windows-msys2 (map[c:gcc cxx:g++ flags:-ffunction-sections -Wno-error=maybe-uninitialized -Wno-error=attributes]) (push) Waiting to run
C/C++ CI / build-windows-msbuild (map[grpc:OFF os:2019 rx:OFF upnp:OFF vs:Visual Studio 16 2019 vspath:C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise]) (push) Waiting to run
C/C++ CI / build-windows-msbuild (map[grpc:OFF os:2019 rx:OFF upnp:ON vs:Visual Studio 16 2019 vspath:C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise]) (push) Waiting to run
C/C++ CI / build-windows-msbuild (map[grpc:ON os:2019 rx:ON upnp:ON vs:Visual Studio 16 2019 vspath:C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise]) (push) Waiting to run
clang-tidy / clang-tidy (push) Waiting to run
C/C++ CI / build-windows-msbuild (map[grpc:OFF os:2019 rx:ON upnp:ON vs:Visual Studio 16 2019 vspath:C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise]) (push) Waiting to run
C/C++ CI / build-windows-msbuild (map[grpc:ON os:2022 rx:ON upnp:ON vs:Visual Studio 17 2022 vspath:C:\Program Files\Microsoft Visual Studio\2022\Enterprise]) (push) Waiting to run
C/C++ CI / build-macos (push) Waiting to run
C/C++ CI / build-macos-aarch64 (push) Waiting to run
C/C++ CI / build-freebsd (map[architecture:x86-64 host:ubuntu-latest name:freebsd version:13.3]) (push) Waiting to run
C/C++ CI / build-openbsd (map[architecture:x86-64 host:ubuntu-latest name:openbsd version:7.4]) (push) Waiting to run
CodeQL / Analyze (cpp) (push) Waiting to run
cppcheck / cppcheck-ubuntu (push) Waiting to run
cppcheck / cppcheck-windows (push) Waiting to run
Microsoft C++ Code Analysis / Analyze (push) Waiting to run
source-snapshot / source-snapshot (push) Waiting to run
Sync test (old) / sync-test-ubuntu-tsan (push) Waiting to run
Sync test (old) / sync-test-ubuntu-msan (push) Waiting to run
Sync test (old) / sync-test-ubuntu-ubsan (push) Waiting to run
Sync test (old) / sync-test-ubuntu-asan (push) Waiting to run
Sync test (old) / sync-test-macos (map[flags:-Og -ftrapv -target arm64-apple-macos-11 os:macos-14]) (push) Waiting to run
Sync test (old) / sync-test-macos (map[flags:-Og -ftrapv os:macos-13]) (push) Waiting to run
Sync test (old) / sync-test-windows-debug-asan (push) Waiting to run
Sync test (old) / sync-test-windows-leaks (push) Waiting to run
Sync test / sync-test-ubuntu-tsan (push) Waiting to run
Sync test / sync-test-ubuntu-msan (push) Waiting to run
Sync test / sync-test-ubuntu-ubsan (push) Waiting to run
Sync test / sync-test-ubuntu-asan (push) Waiting to run
Sync test / sync-test-macos (map[flags:-Og -ftrapv -target arm64-apple-macos-11 os:macos-14]) (push) Waiting to run
Sync test / sync-test-macos (map[flags:-Og -ftrapv os:macos-13]) (push) Waiting to run
Sync test / sync-test-windows-debug-asan (push) Waiting to run
Sync test / sync-test-windows-leaks (push) Waiting to run

This commit is contained in:
SChernykh 2024-08-04 21:56:26 +02:00
parent cb8ea37dab
commit 127dcc04bf
8 changed files with 485 additions and 51 deletions

View file

@ -17,6 +17,7 @@ option(WITH_RANDOMX "Include the RandomX library in the build. If this is turned
option(WITH_LTO "Use link-time compiler optimization (if linking fails for you, run cmake with -DWITH_LTO=OFF)" ON)
option(WITH_UPNP "Include UPnP support. If this is turned off, p2pool will not be able to configure port forwarding on UPnP-enabled routers." ON)
option(WITH_GRPC "Include gRPC support. If this is turned off, p2pool will not be able to merge mine with Tari." ON)
option(WITH_TLS "Include TLS support. If this is turned off, p2pool will not support Stratum TLS connections." ON)
option(DEV_TEST_SYNC "[Developer only] Sync test, stop p2pool after sync is complete" OFF)
option(DEV_WITH_TSAN "[Developer only] Compile with thread sanitizer" OFF)
@ -166,6 +167,13 @@ if (WITH_GRPC)
set(SOURCES ${SOURCES} src/merge_mining_client_tari.cpp)
endif()
if (WITH_GRPC AND WITH_TLS)
add_definitions(-DWITH_TLS)
set(HEADERS ${HEADERS} src/tls.h)
set(SOURCES ${SOURCES} src/tls.cpp)
endif()
source_group(TREE "${CMAKE_CURRENT_SOURCE_DIR}" PREFIX "Header Files" FILES ${HEADERS})
source_group(TREE "${CMAKE_CURRENT_SOURCE_DIR}" PREFIX "Source Files" FILES ${SOURCES})

View file

@ -35,5 +35,6 @@ add_definitions(-DPROTOBUF_ENABLE_DEBUG_LOGGING_MAY_LEAK_PII=0)
add_subdirectory(external/src/grpc)
include_directories(external/src/grpc/third_party/abseil-cpp)
include_directories(external/src/grpc/third_party/boringssl-with-bazel/src/include)
include_directories(external/src/grpc/third_party/protobuf/src)
include_directories(external/src/grpc/include)

View file

@ -272,7 +272,7 @@ bool StratumServer::on_login(StratumClient* client, uint32_t id, const char* log
target = std::max(target, aux_diff.target());
if (get_custom_diff(login, client->m_customDiff)) {
LOGINFO(5, "client " << log::Gray() << static_cast<char*>(client->m_addrString) << " set custom difficulty " << client->m_customDiff);
LOGINFO(5, "client " << log::Gray() << static_cast<char*>(client->m_addrString) << log::NoColor() << " set custom difficulty " << client->m_customDiff);
target = std::max(target, client->m_customDiff.target());
}
else if (m_autoDiff) {
@ -282,7 +282,7 @@ bool StratumServer::on_login(StratumClient* client, uint32_t id, const char* log
if (get_custom_user(login, client->m_customUser)) {
const char* s = client->m_customUser;
LOGINFO(5, "client " << log::Gray() << static_cast<char*>(client->m_addrString) << " set custom user " << s);
LOGINFO(5, "client " << log::Gray() << static_cast<char*>(client->m_addrString) << log::NoColor() << " set custom user " << s);
}
uint32_t job_id;
@ -541,6 +541,7 @@ void StratumServer::show_workers()
size_t n = 0;
LOGINFO(0, log::pad_right("IP:port", addr_len + 8)
<< "TLS "
<< log::pad_right("uptime", 20)
<< log::pad_right("difficulty", 20)
<< log::pad_right("hashrate", 15)
@ -557,7 +558,15 @@ void StratumServer::show_workers()
++diff.lo;
}
}
#ifdef WITH_TLS
const bool is_tls = c->m_tls.is_empty();
#else
constexpr bool is_tls = false;
#endif
LOGINFO(0, log::pad_right(static_cast<const char*>(c->m_addrString), addr_len + 8)
<< (is_tls ? "no " : "yes ")
<< log::pad_right(log::Duration(cur_time - c->m_connectedTime), 20)
<< log::pad_right(diff, 20)
<< log::pad_right(log::Hashrate(c->m_autoDiff.lo / AUTO_DIFF_TARGET_TIME, m_autoDiff && (c->m_autoDiff != 0)), 15)
@ -1080,7 +1089,8 @@ void StratumServer::on_shutdown()
}
StratumServer::StratumClient::StratumClient()
: Client(m_stratumReadBuf, sizeof(m_stratumReadBuf))
: Client(m_rawReadBuf, sizeof(m_rawReadBuf))
, m_stratumReadBufBytes(0)
, m_rpcId(0)
, m_perConnectionJobId(0)
, m_connectedTime(0)
@ -1100,6 +1110,10 @@ StratumServer::StratumClient::StratumClient()
void StratumServer::StratumClient::reset()
{
Client::reset();
m_stratumReadBuf[0] = '\0';
m_stratumReadBufBytes = 0;
m_rpcId = 0;
m_perConnectionJobId = 0;
m_connectedTime = 0;
@ -1127,35 +1141,69 @@ bool StratumServer::StratumClient::on_connect()
bool StratumServer::StratumClient::on_read(char* data, uint32_t size)
{
if ((data != m_readBuf + m_numRead) || (data + size > m_readBuf + m_readBufSize)) {
LOGERR(1, "client: invalid data pointer or size in on_read()");
ban(DEFAULT_BAN_TIME);
return false;
}
m_numRead += size;
char* line_start = m_readBuf;
for (char* c = data; c < m_readBuf + m_numRead; ++c) {
if (*c == '\n') {
*c = '\0';
if (!process_request(line_start, static_cast<uint32_t>(c - line_start))) {
ban(DEFAULT_BAN_TIME);
#ifdef WITH_TLS
if (!m_tlsChecked) {
if (data[0] == 0x16) {
if (!m_tls.init()) {
LOGWARN(5, "client " << static_cast<const char*>(m_addrString) << ": TLS init failed");
return false;
}
line_start = c + 1;
LOGINFO(5, "client " << log::Gray() << static_cast<const char*>(m_addrString) << log::NoColor() << " is using TLS");
}
m_tlsChecked = true;
}
#endif
// Move the possible unfinished line to the beginning of m_readBuf to free up more space for reading
if (line_start != m_readBuf) {
m_numRead = static_cast<uint32_t>(m_readBuf + m_numRead - line_start);
if (m_numRead > 0) {
memmove(m_readBuf, line_start, m_numRead);
auto on_parse = [this](char* data, uint32_t size) {
if (static_cast<size_t>(m_stratumReadBufBytes) + size > STRATUM_BUF_SIZE) {
LOGWARN(4, "client " << static_cast<const char*>(m_addrString) << " sent too long Stratum message");
ban(DEFAULT_BAN_TIME);
return false;
}
}
return true;
memcpy(m_stratumReadBuf + m_stratumReadBufBytes, data, size);
m_stratumReadBufBytes += size;
char* line_start = m_stratumReadBuf;
for (char *e = line_start + m_stratumReadBufBytes, *c = e - size; c < e; ++c) {
if (*c == '\n') {
*c = '\0';
if (!process_request(line_start, static_cast<uint32_t>(c - line_start))) {
ban(DEFAULT_BAN_TIME);
return false;
}
line_start = c + 1;
}
}
// Move the possible unfinished line to the beginning of m_stratumReadBuf to free up more space for reading
if (line_start != m_stratumReadBuf) {
m_stratumReadBufBytes = static_cast<uint32_t>(m_stratumReadBuf + m_stratumReadBufBytes - line_start);
if (m_stratumReadBufBytes > 0) {
memmove(m_stratumReadBuf, line_start, m_stratumReadBufBytes);
}
}
return true;
};
#ifdef WITH_TLS
if (!m_tls.is_empty()) {
auto on_write = [this](const uint8_t* data, size_t size) {
return m_owner->send(this, [data, size](uint8_t* buf, size_t buf_size) -> size_t {
if (buf_size < size) {
return 0;
}
memcpy(buf, data, size);
return size;
}, true);
};
return m_tls.on_read(data, size, std::move(on_parse), std::move(on_write));
}
#endif
return on_parse(data, size);
}
bool StratumServer::StratumClient::process_request(char* data, uint32_t size)

View file

@ -52,7 +52,10 @@ public:
[[nodiscard]] bool process_login(rapidjson::Document& doc, uint32_t id);
[[nodiscard]] bool process_submit(rapidjson::Document& doc, uint32_t id);
alignas(8) char m_rawReadBuf[STRATUM_BUF_SIZE];
alignas(8) char m_stratumReadBuf[STRATUM_BUF_SIZE];
uint32_t m_stratumReadBufBytes;
uint32_t m_rpcId;
uint32_t m_perConnectionJobId;

View file

@ -538,7 +538,7 @@ void TCPServer::print_bans()
}
}
bool TCPServer::send_internal(Client* client, Callback<size_t, uint8_t*, size_t>::Base&& callback)
bool TCPServer::send_internal(Client* client, Callback<size_t, uint8_t*, size_t>::Base&& callback, bool raw)
{
check_event_loop_thread(__func__);
@ -559,34 +559,50 @@ bool TCPServer::send_internal(Client* client, Callback<size_t, uint8_t*, size_t>
return false;
}
WriteBuf* buf = get_write_buffer(bytes_written);
auto on_write = [this, client](const uint8_t* data, size_t size) {
WriteBuf* buf = get_write_buffer(size);
buf->m_write.data = buf;
buf->m_client = client;
buf->m_write.data = buf;
buf->m_client = client;
if (buf->m_dataCapacity < bytes_written) {
buf->m_dataCapacity = round_up(bytes_written, 64);
buf->m_data = realloc_hook(buf->m_data, buf->m_dataCapacity);
if (!buf->m_data) {
LOGERR(0, "failed to allocate " << buf->m_dataCapacity << " bytes to send data");
PANIC_STOP();
if (buf->m_dataCapacity < size) {
buf->m_dataCapacity = round_up(size, 64);
buf->m_data = realloc_hook(buf->m_data, buf->m_dataCapacity);
if (!buf->m_data) {
LOGERR(0, "failed to allocate " << buf->m_dataCapacity << " bytes to send data");
PANIC_STOP();
}
}
memcpy(buf->m_data, data, size);
uv_buf_t bufs[1];
bufs[0].base = reinterpret_cast<char*>(buf->m_data);
bufs[0].len = static_cast<int>(size);
const int err = uv_write(&buf->m_write, reinterpret_cast<uv_stream_t*>(&client->m_socket), bufs, 1, Client::on_write);
if (err) {
LOGWARN(1, "failed to start writing data to client connection " << static_cast<const char*>(client->m_addrString) << ", error " << uv_err_name(err));
return_write_buffer(buf);
return false;
}
return true;
};
#ifdef WITH_TLS
if (!client->m_tls.is_empty() && !raw) {
if (!client->m_tls.on_write(m_callbackBuf.data(), bytes_written, std::move(on_write))) {
LOGWARN(1, "TLS write failed to client connection " << static_cast<const char*>(client->m_addrString));
return false;
}
return true;
}
#else
(void)raw;
#endif
memcpy(buf->m_data, m_callbackBuf.data(), bytes_written);
uv_buf_t bufs[1];
bufs[0].base = reinterpret_cast<char*>(buf->m_data);
bufs[0].len = static_cast<int>(bytes_written);
const int err = uv_write(&buf->m_write, reinterpret_cast<uv_stream_t*>(&client->m_socket), bufs, 1, Client::on_write);
if (err) {
LOGWARN(1, "failed to start writing data to client connection " << static_cast<const char*>(client->m_addrString) << ", error " << uv_err_name(err));
return_write_buffer(buf);
return false;
}
return true;
return on_write(m_callbackBuf.data(), bytes_written);
}
const char* TCPServer::get_log_category() const
@ -999,6 +1015,9 @@ TCPServer::Client::Client(char* read_buf, size_t size)
, m_addrString{}
, m_socks5ProxyState(Socks5ProxyState::Default)
, m_resetCounter{ 0 }
#ifdef WITH_TLS
, m_tlsChecked(false)
#endif
{
m_readBuf[0] = '\0';
m_readBuf[m_readBufSize - 1] = '\0';
@ -1023,6 +1042,11 @@ void TCPServer::Client::reset()
m_socks5ProxyState = Socks5ProxyState::Default;
m_readBuf[0] = '\0';
m_readBuf[m_readBufSize - 1] = '\0';
#ifdef WITH_TLS
m_tls.reset();
m_tlsChecked = false;
#endif
}
void TCPServer::Client::on_alloc(uv_handle_t* handle, size_t /*suggested_size*/, uv_buf_t* buf)

View file

@ -18,6 +18,11 @@
#pragma once
#include "uv_util.h"
#ifdef WITH_TLS
#include "tls.h"
#endif
#include <map>
namespace p2pool {
@ -106,6 +111,11 @@ public:
} m_socks5ProxyState;
std::atomic<uint32_t> m_resetCounter;
#ifdef WITH_TLS
ServerTls m_tls;
bool m_tlsChecked;
#endif
};
struct WriteBuf
@ -128,7 +138,7 @@ public:
}
template<typename T>
[[nodiscard]] FORCEINLINE bool send(Client* client, T&& callback) { return send_internal(client, Callback<size_t, uint8_t*, size_t>::Derived<T>(std::move(callback))); }
[[nodiscard]] FORCEINLINE bool send(Client* client, T&& callback, bool raw = false) { return send_internal(client, Callback<size_t, uint8_t*, size_t>::Derived<T>(std::move(callback)), raw); }
private:
static void on_new_connection(uv_stream_t* server, int status);
@ -138,7 +148,7 @@ private:
void on_new_client(uv_stream_t* server);
void on_new_client(uv_stream_t* server, Client* client);
[[nodiscard]] bool send_internal(Client* client, Callback<size_t, uint8_t*, size_t>::Base&& callback);
[[nodiscard]] bool send_internal(Client* client, Callback<size_t, uint8_t*, size_t>::Base&& callback, bool raw);
allocate_client_callback m_allocateNewClient;

272
src/tls.cpp Normal file
View file

@ -0,0 +1,272 @@
/*
* This file is part of the Monero P2Pool <https://github.com/SChernykh/p2pool>
* Copyright (c) 2021-2024 SChernykh <https://github.com/SChernykh>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, version 3.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "common.h"
#include "tls.h"
LOG_CATEGORY(TLS)
namespace p2pool {
static bssl::UniquePtr<EVP_PKEY> init_evp_pkey()
{
bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
if (!evp_pkey.get()) {
return nullptr;
}
bssl::UniquePtr<EC_KEY> ec_key(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
if (!ec_key || !EC_KEY_generate_key(ec_key.get())) {
return nullptr;
}
if (!EVP_PKEY_assign_EC_KEY(evp_pkey.get(), ec_key.release())) {
return nullptr;
}
return evp_pkey;
}
static bssl::UniquePtr<EVP_PKEY> s_evp_pkey = init_evp_pkey();
static bssl::UniquePtr<X509> init_cert()
{
bssl::UniquePtr<X509> x509(X509_new());
if (!x509.get()) {
return nullptr;
}
if (!X509_set_version(x509.get(), X509_VERSION_3)) {
return nullptr;
}
std::mt19937_64 rng(RandomDeviceSeed::instance);
rng.discard(10000);
const uint64_t serial = rng();
if (!ASN1_INTEGER_set_uint64(X509_get_serialNumber(x509.get()), serial)) {
return nullptr;
}
constexpr int64_t YEAR = 31557600;
const time_t cur_time = time(nullptr);
const time_t t0 = cur_time - (cur_time % YEAR);
const time_t t1 = t0 - YEAR * 10;
const time_t t2 = t0 + YEAR * 10;
if (!ASN1_TIME_set(X509_get_notBefore(x509.get()), t1) || !ASN1_TIME_set(X509_get_notAfter(x509.get()), t2)) {
return nullptr;
}
X509_NAME* subject = X509_get_subject_name(x509.get());
if (!X509_NAME_add_entry_by_txt(subject, "C", MBSTRING_ASC, reinterpret_cast<const uint8_t*>("US"), -1, -1, 0) ||
!X509_NAME_add_entry_by_txt(subject, "O", MBSTRING_ASC, reinterpret_cast<const uint8_t*>("BoringSSL"), -1, -1, 0) ||
!X509_set_issuer_name(x509.get(), subject)) {
return nullptr;
}
bssl::UniquePtr<STACK_OF(ASN1_OBJECT)> ekus(sk_ASN1_OBJECT_new_null());
if (!ekus || !sk_ASN1_OBJECT_push(ekus.get(), OBJ_nid2obj(NID_server_auth)) || !X509_add1_ext_i2d(x509.get(), NID_ext_key_usage, ekus.get(), 1, 0)) {
return nullptr;
}
if (!X509_set_pubkey(x509.get(), s_evp_pkey.get())) {
return nullptr;
}
if (!X509_sign(x509.get(), s_evp_pkey.get(), EVP_sha256())) {
return nullptr;
}
return x509;
}
static bssl::UniquePtr<X509> s_cert = init_cert();
static bssl::UniquePtr<SSL_CTX> init_ctx()
{
if (!s_evp_pkey.get() || !s_cert.get()) {
return nullptr;
}
bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
if (!ctx.get()) {
return nullptr;
}
if (!SSL_CTX_use_PrivateKey(ctx.get(), s_evp_pkey.get())) {
return nullptr;
}
if (!SSL_CTX_use_certificate(ctx.get(), s_cert.get())) {
return nullptr;
}
return ctx;
}
static bssl::UniquePtr<SSL_CTX> s_ctx = init_ctx();
void ServerTls::reset()
{
m_ssl.reset(nullptr);
}
bool ServerTls::init()
{
if (!s_ctx.get()) {
static std::atomic<uint32_t> ctx_error_shown = 0;
if (ctx_error_shown.exchange(1) == 0) {
LOGERR(0, "Failed to initialize an SSL context");
}
return false;
}
m_ssl.reset(SSL_new(s_ctx.get()));
if (!m_ssl.get()) {
return false;
}
SSL_set_accept_state(m_ssl.get());
BIO* rbio = BIO_new(BIO_s_mem());
BIO* wbio = BIO_new(BIO_s_mem());
if (!rbio || !wbio) {
BIO_free(rbio);
BIO_free(wbio);
m_ssl.reset(nullptr);
return false;
}
SSL_set_bio(m_ssl.get(), rbio, wbio);
return true;
}
bool ServerTls::on_read_internal(char* data, uint32_t size, ReadCallback::Base&& read_callback, WriteCallback::Base&& write_callback)
{
SSL* ssl = m_ssl.get();
if (!ssl) {
return false;
}
if (!BIO_write_all(SSL_get_rbio(ssl), data, size)) {
return false;
}
int bytes_read = 0;
char buf[1024];
if (!SSL_is_init_finished(ssl)) {
const int result = SSL_do_handshake(ssl);
if (!result) {
// EOF
return false;
}
// Send pending handshake data, if any
BIO* wbio = SSL_get_wbio(ssl);
if (!wbio) {
return false;
}
const uint8_t* bio_data;
size_t bio_len;
if (!BIO_mem_contents(wbio, &bio_data, &bio_len)) {
return false;
}
if (bio_len > 0) {
if (!write_callback(bio_data, bio_len)) {
return false;
}
if (!BIO_reset(wbio)) {
return false;
}
}
if ((result < 0) && (SSL_get_error(ssl, result) == SSL_ERROR_WANT_READ)) {
// Continue handshake, nothing to read yet
return true;
}
else if (result == 1) {
// Handshake finished, skip to "SSL_read" further down
}
else {
// Some other error
return false;
}
}
while ((bytes_read = SSL_read(ssl, buf, sizeof(buf))) > 0) {
if (!read_callback(buf, static_cast<uint32_t>(bytes_read))) {
return false;
}
}
return true;
}
bool ServerTls::on_write_internal(const uint8_t* data, size_t size, WriteCallback::Base&& write_callback)
{
SSL* ssl = m_ssl.get();
if (!ssl) {
return false;
}
if (SSL_write(ssl, data, static_cast<int>(size)) <= 0) {
return false;
}
BIO* wbio = SSL_get_wbio(ssl);
if (!wbio) {
return false;
}
const uint8_t* bio_data;
size_t bio_len;
if (!BIO_mem_contents(wbio, &bio_data, &bio_len)) {
return false;
}
if (bio_len > 0) {
if (!write_callback(bio_data, bio_len)) {
return false;
}
if (!BIO_reset(wbio)) {
return false;
}
}
return true;
}
} // namespace p2pool

68
src/tls.h Normal file
View file

@ -0,0 +1,68 @@
/*
* This file is part of the Monero P2Pool <https://github.com/SChernykh/p2pool>
* Copyright (c) 2021-2024 SChernykh <https://github.com/SChernykh>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, version 3.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#pragma once
#include <openssl/base.h>
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wcast-qual"
#endif
#include <openssl/ssl.h>
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
namespace p2pool {
class ServerTls
{
public:
FORCEINLINE ServerTls() { reset(); }
void reset();
[[nodiscard]] bool init();
template<typename T, typename U>
[[nodiscard]] FORCEINLINE bool on_read(char* data, uint32_t size, T&& read_callback, U&& write_callback)
{
return on_read_internal(data, size, ReadCallback::Derived<T>(std::move(read_callback)), WriteCallback::Derived<U>(std::move(write_callback)));
}
template<typename T>
[[nodiscard]] FORCEINLINE bool on_write(const uint8_t* data, size_t size, T&& write_callback)
{
return on_write_internal(data, size, WriteCallback::Derived<T>(std::move(write_callback)));
}
[[nodiscard]] FORCEINLINE bool is_empty() const { return m_ssl.get() == nullptr; }
private:
typedef Callback<bool, char*, uint32_t> ReadCallback;
typedef Callback<bool, const uint8_t*, size_t> WriteCallback;
[[nodiscard]] bool on_read_internal(char* data, uint32_t size, ReadCallback::Base&& read_callback, WriteCallback::Base&& write_callback);
[[nodiscard]] bool on_write_internal(const uint8_t* data, size_t size, WriteCallback::Base&& write_callback);
private:
bssl::UniquePtr<SSL> m_ssl;
};
} // namespace p2pool