From e7db291669d07ea42d695205037445694d90e51e Mon Sep 17 00:00:00 2001 From: Lee *!* Clagett Date: Sat, 16 Mar 2024 21:40:48 -0400 Subject: [PATCH] ZMQ Hardening (#96) --- src/db/data.cpp | 6 +- src/db/data.h | 4 +- src/db/storage.cpp | 35 +++-- src/rest_server.cpp | 2 +- src/rpc/admin.cpp | 10 +- src/rpc/client.cpp | 24 ++- src/rpc/client.h | 4 +- src/rpc/daemon_pub.cpp | 9 +- src/rpc/daemon_zmq.cpp | 89 ++++++++--- src/rpc/light_wallet.cpp | 6 +- src/server_main.cpp | 2 +- src/wire/adapted/array.h | 2 +- src/wire/error.cpp | 4 + src/wire/error.h | 2 + src/wire/json/read.cpp | 57 +++---- src/wire/json/read.h | 3 +- src/wire/msgpack/error.cpp | 4 + src/wire/msgpack/error.h | 4 +- src/wire/msgpack/read.cpp | 152 ++++++++++--------- src/wire/msgpack/read.h | 36 ++--- src/wire/read.cpp | 7 + src/wire/read.h | 114 +++++++++++--- src/wire/traits.h | 49 ++++++ src/wire/wrappers_impl.h | 20 ++- tests/unit/db/data.test.cpp | 16 +- tests/unit/db/subaddress.test.cpp | 160 ++++++++++---------- tests/unit/scanner.test.cpp | 16 +- tests/unit/wire/json/read.write.test.cpp | 5 +- tests/unit/wire/msgpack/read.write.test.cpp | 6 +- tests/unit/wire/read.write.test.cpp | 11 +- 30 files changed, 564 insertions(+), 295 deletions(-) diff --git a/src/db/data.cpp b/src/db/data.cpp index d78f8c1..21cc81f 100644 --- a/src/db/data.cpp +++ b/src/db/data.cpp @@ -41,7 +41,9 @@ #include "wire/msgpack.h" #include "wire/uuid.h" #include "wire/vector.h" +#include "wire/wrapper/array.h" #include "wire/wrapper/defaulted.h" +#include "wire/wrappers_impl.h" namespace lws { @@ -82,7 +84,7 @@ namespace db { wire::object(format, wire::field<0>("key", std::ref(self.first)), - wire::field<1>("value", std::ref(self.second)) + wire::optional_field<1>("value", std::ref(self.second)) ); } } @@ -91,7 +93,7 @@ namespace db { bool is_first = true; minor_index last = minor_index::primary; - for (const auto& elem : self.second) + for (const auto& elem : self.second.get_container()) { if (elem[1] < elem[0]) { diff --git a/src/db/data.h b/src/db/data.h index 369e2bd..10aca85 100644 --- a/src/db/data.h +++ b/src/db/data.h @@ -45,6 +45,7 @@ #include "wire/json/fwd.h" #include "wire/msgpack/fwd.h" #include "wire/traits.h" +#include "wire/wrapper/array.h" namespace lws { @@ -138,7 +139,8 @@ namespace db using index_range = std::array; //! Ranges within a major index - using index_ranges = std::vector; + using min_index_ranges = wire::min_element_size<2>; + using index_ranges = wire::array_, min_index_ranges>; //! Compatible with msgpack_table using subaddress_dict = std::pair; diff --git a/src/db/storage.cpp b/src/db/storage.cpp index 4a68e05..b641acf 100644 --- a/src/db/storage.cpp +++ b/src/db/storage.cpp @@ -64,12 +64,21 @@ #include "wire/wrapper/array.h" #include "wire/wrappers_impl.h" +namespace wire +{ + template + static bool operator<(const array_& lhs, const array_& rhs) + { + return lhs.get_container() < rhs.get_container(); + } +} + namespace lws { namespace db { namespace v0 - { + { //! Orignal DB value, with no txn fee struct output { @@ -1156,7 +1165,7 @@ namespace db ); } - static void write_bytes(wire::json_writer& dest, const std::pair>>>>& self) + static void write_bytes(wire::json_writer& dest, const std::pair>& self) { wire::object(dest, wire::field("id", std::cref(self.first)), @@ -2756,9 +2765,9 @@ namespace db const auto add_out = [&out] (major_index major, index_range minor) { if (out.empty() || out.back().first != major) - out.emplace_back(major, index_ranges{minor}); + out.emplace_back(major, index_ranges{std::vector{minor}}); else - out.back().second.push_back(minor); + out.back().second.get_container().push_back(minor); }; const auto check_max_range = [&subaddr_count, max_subaddr] (const index_range& range) -> bool @@ -2771,7 +2780,7 @@ namespace db }; const auto check_max_ranges = [&check_max_range] (const index_ranges& ranges) -> bool { - for (const auto& range : ranges) + for (const auto& range : ranges.get_container()) { if (!check_max_range(range)) return false; @@ -2802,7 +2811,7 @@ namespace db for (auto& major_entry : subaddrs) { - new_dict.clear(); + new_dict.get_container().clear(); if (!check_subaddress_dict(major_entry)) { MERROR("Invalid subaddress_dict given to storage::upsert_subaddrs"); @@ -2826,8 +2835,8 @@ namespace db if (!old_dict) return old_dict.error(); - auto& old_range = old_dict->second; - const auto& new_range = major_entry.second; + auto& old_range = old_dict->second.get_container(); + const auto& new_range = major_entry.second.get_container(); auto old_loc = old_range.begin(); auto new_loc = new_range.begin(); @@ -2838,13 +2847,13 @@ namespace db if (!check_max_range(*new_loc)) return {error::max_subaddresses}; - new_dict.push_back(*new_loc); + new_dict.get_container().push_back(*new_loc); add_out(major_entry.first, *new_loc); ++new_loc; } else if (std::uint64_t(old_loc->at(1)) + 1 < std::uint32_t(new_loc->at(0))) { // existing has no overlap with new - new_dict.push_back(*old_loc); + new_dict.get_container().push_back(*old_loc); ++old_loc; } else if (old_loc->at(0) <= new_loc->at(0) && new_loc->at(1) <= old_loc->at(1)) @@ -2873,17 +2882,17 @@ namespace db } } - std::copy(old_loc, old_range.end(), std::back_inserter(new_dict)); + std::copy(old_loc, old_range.end(), std::back_inserter(new_dict.get_container())); for ( ; new_loc != new_range.end(); ++new_loc) { if (!check_max_range(*new_loc)) return {error::max_subaddresses}; - new_dict.push_back(*new_loc); + new_dict.get_container().push_back(*new_loc); add_out(major_entry.first, *new_loc); } } - for (const auto& new_indexes : new_dict) + for (const auto& new_indexes : new_dict.get_container()) { for (std::uint64_t minor : boost::counting_range(std::uint64_t(new_indexes[0]), std::uint64_t(new_indexes[1]) + 1)) { diff --git a/src/rest_server.cpp b/src/rest_server.cpp index 57ef4eb..fd3bc88 100644 --- a/src/rest_server.cpp +++ b/src/rest_server.cpp @@ -708,7 +708,7 @@ namespace lws for (std::uint64_t elem : boost::counting_range(std::uint64_t(major_i), std::uint64_t(major_i) + n_major)) { ranges.emplace_back( - db::major_index(elem), db::index_ranges{db::index_range{db::minor_index(minor_i), db::minor_index(minor_i + n_minor - 1)}} + db::major_index(elem), db::index_ranges{{db::index_range{db::minor_index(minor_i), db::minor_index(minor_i + n_minor - 1)}}} ); } auto upserted = disk.upsert_subaddresses(id, req.creds.address, req.creds.key, ranges, options.max_subaddresses); diff --git a/src/rpc/admin.cpp b/src/rpc/admin.cpp index 98a986c..cc05c95 100644 --- a/src/rpc/admin.cpp +++ b/src/rpc/admin.cpp @@ -49,7 +49,7 @@ namespace wire { static void write_bytes(wire::writer& dest, const std::pair>& self) { - wire::object(dest, wire::field<0>("key", self.first), wire::field<1>("value", self.second)); + wire::object(dest, wire::field<0>("key", std::cref(self.first)), wire::field<1>("value", std::cref(self.second))); } } @@ -119,8 +119,14 @@ namespace template void read_addresses(wire::reader& source, T& self, U... field) { + using min_address_size = + wire::min_element_sizeof; + std::vector addresses; - wire::object(source, wire::field("addresses", std::ref(addresses)), std::move(field)...); + wire::object(source, + wire::field("addresses", wire::array(std::ref(addresses))), + std::move(field)... + ); self.addresses.reserve(addresses.size()); for (const auto& elem : addresses) diff --git a/src/rpc/client.cpp b/src/rpc/client.cpp index 761317e..f68f73d 100644 --- a/src/rpc/client.cpp +++ b/src/rpc/client.cpp @@ -58,6 +58,8 @@ namespace rpc constexpr const char minimal_chain_topic[] = "json-minimal-chain_main"; constexpr const char full_txpool_topic[] = "json-full-txpool_add"; constexpr const int daemon_zmq_linger = 0; + constexpr const std::int64_t max_msg_sub = 10 * 1024 * 1024; // 50 MiB + constexpr const std::int64_t max_msg_req = 350 * 1024 * 1024; // 350 MiB constexpr const std::chrono::seconds chain_poll_timeout{20}; constexpr const std::chrono::minutes chain_sub_timeout{4}; @@ -166,13 +168,20 @@ namespace rpc MONERO_ZMQ_CHECK(zmq_setsockopt(signal_sub, ZMQ_SUBSCRIBE, signal, sizeof(signal) - 1)); return success(); } + + template + expect do_set_option(void* sock, const int option, const T value) noexcept + { + MONERO_ZMQ_CHECK(zmq_setsockopt(sock, option, std::addressof(value), sizeof(value))); + return success(); + } } // anonymous namespace detail { struct context { - explicit context(zcontext comm, socket signal_pub, socket external_pub, rcontext rmq, std::string daemon_addr, std::string sub_addr, std::chrono::minutes interval) + explicit context(zcontext comm, socket signal_pub, socket external_pub, rcontext rmq, std::string daemon_addr, std::string sub_addr, std::chrono::minutes interval, bool untrusted_daemon) : comm(std::move(comm)) , signal_pub(std::move(signal_pub)) , external_pub(std::move(external_pub)) @@ -185,6 +194,7 @@ namespace rpc , cached{} , sync_pub() , sync_rates() + , untrusted_daemon(untrusted_daemon) { if (std::chrono::minutes{0} < cache_interval) rates_conn.set_server(crypto_compare.host, boost::none, epee::net_utils::ssl_support_t::e_ssl_support_enabled); @@ -202,6 +212,7 @@ namespace rpc rates cached; boost::mutex sync_pub; boost::mutex sync_rates; + const bool untrusted_daemon; }; } // detail @@ -254,14 +265,15 @@ namespace rpc { MONERO_PRECOND(ctx != nullptr); - int option = daemon_zmq_linger; client out{std::move(ctx)}; out.daemon.reset(zmq_socket(out.ctx->comm.get(), ZMQ_REQ)); if (out.daemon.get() == nullptr) return net::zmq::get_error_code(); + MONERO_CHECK(do_set_option(out.daemon.get(), ZMQ_LINGER, daemon_zmq_linger)); + if (out.ctx->untrusted_daemon) + MONERO_CHECK(do_set_option(out.daemon.get(), ZMQ_MAXMSGSIZE, max_msg_req)); MONERO_ZMQ_CHECK(zmq_connect(out.daemon.get(), out.ctx->daemon_addr.c_str())); - MONERO_ZMQ_CHECK(zmq_setsockopt(out.daemon.get(), ZMQ_LINGER, &option, sizeof(option))); if (!out.ctx->sub_addr.empty()) { @@ -269,6 +281,8 @@ namespace rpc if (out.daemon_sub.get() == nullptr) return net::zmq::get_error_code(); + if (out.ctx->untrusted_daemon) + MONERO_CHECK(do_set_option(out.daemon_sub.get(), ZMQ_MAXMSGSIZE, max_msg_sub)); MONERO_ZMQ_CHECK(zmq_connect(out.daemon_sub.get(), out.ctx->sub_addr.c_str())); MONERO_CHECK(do_subscribe(out.daemon_sub.get(), minimal_chain_topic)); MONERO_CHECK(do_subscribe(out.daemon_sub.get(), full_txpool_topic)); @@ -424,7 +438,7 @@ namespace rpc return ctx->cached; } - context context::make(std::string daemon_addr, std::string sub_addr, std::string pub_addr, rmq_details rmq_info, std::chrono::minutes rates_interval) + context context::make(std::string daemon_addr, std::string sub_addr, std::string pub_addr, rmq_details rmq_info, std::chrono::minutes rates_interval, const bool untrusted_daemon) { zcontext comm{zmq_init(1)}; if (comm == nullptr) @@ -502,7 +516,7 @@ namespace rpc return context{ std::make_shared( - std::move(comm), std::move(pub), std::move(external_pub), std::move(rmq), std::move(daemon_addr), std::move(sub_addr), rates_interval + std::move(comm), std::move(pub), std::move(external_pub), std::move(rmq), std::move(daemon_addr), std::move(sub_addr), rates_interval, untrusted_daemon ) }; } diff --git a/src/rpc/client.h b/src/rpc/client.h index 8735c62..502c885 100644 --- a/src/rpc/client.h +++ b/src/rpc/client.h @@ -204,8 +204,10 @@ namespace rpc \param rmq_info Required information for RMQ publishing (if enabled) \param rates_interval Frequency to retrieve exchange rates. Set value to `<= 0` to disable exchange rate retrieval. + \param True if additional size constraints should be placed on + daemon messages */ - static context make(std::string daemon_addr, std::string sub_addr, std::string pub_addr, rmq_details rmq_info, std::chrono::minutes rates_interval); + static context make(std::string daemon_addr, std::string sub_addr, std::string pub_addr, rmq_details rmq_info, std::chrono::minutes rates_interval, const bool untrusted_daemon); context(context&&) = default; context(context const&) = delete; diff --git a/src/rpc/daemon_pub.cpp b/src/rpc/daemon_pub.cpp index 944d3c7..a809127 100644 --- a/src/rpc/daemon_pub.cpp +++ b/src/rpc/daemon_pub.cpp @@ -34,19 +34,24 @@ #include "wire/field.h" #include "wire/traits.h" #include "wire/json/read.h" +#include "wire/wrapper/array.h" +#include "wire/wrappers_impl.h" namespace { + using max_txes_pub = wire::max_element_count<775>; + struct dummy_chain_array { using value_type = crypto::hash; - std::uint64_t count; + std::size_t count = 0; std::reference_wrapper id; void clear() noexcept {} void reserve(std::size_t) noexcept {} + std::size_t size() const noexcept { return count; } crypto::hash& back() noexcept { return id; } void emplace_back() { ++count; } }; @@ -88,7 +93,7 @@ namespace rpc static void read_bytes(wire::json_reader& source, full_txpool_pub& self) { - wire_read::array(source, self.txes); + wire_read::bytes(source, wire::array(std::ref(self.txes))); } expect full_txpool_pub::from_json(std::string&& source) diff --git a/src/rpc/daemon_zmq.cpp b/src/rpc/daemon_zmq.cpp index 681d66f..1c05ca3 100644 --- a/src/rpc/daemon_zmq.cpp +++ b/src/rpc/daemon_zmq.cpp @@ -28,11 +28,14 @@ #include "daemon_zmq.h" #include +#include "cryptonote_config.h" // monero/src #include "crypto/crypto.h" // monero/src #include "rpc/message_data_structs.h" // monero/src #include "wire/crypto.h" #include "wire/json.h" +#include "wire/wrapper/array.h" #include "wire/wrapper/variant.h" +#include "wire/wrappers_impl.h" #include "wire/vector.h" namespace @@ -43,6 +46,17 @@ namespace constexpr const std::size_t default_outputs = 4; constexpr const std::size_t default_txextra_size = 2048; constexpr const std::size_t default_txpool_size = 32; + + using max_blocks_per_fetch = + wire::max_element_count; + + //! Not the default in cryptonote, but roughly a 31.8 MiB block + using max_txes_per_block = wire::max_element_count<21845>; + + using max_inputs_per_tx = wire::max_element_count<3000>; + using max_outputs_per_tx = wire::max_element_count<2000>; + using max_ring_size = wire::max_element_count<4600>; + using max_txpool_size = wire::max_element_count<775>; } namespace rct @@ -65,7 +79,11 @@ namespace rct static void read_bytes(wire::json_reader& source, mgSig& self) { - wire::object(source, WIRE_FIELD(ss), WIRE_FIELD(cc)); + using max_256 = wire::max_element_count<256>; + wire::object(source, + wire::field("ss", wire::array(std::ref(self.ss))), + WIRE_FIELD(cc) + ); } static void read_bytes(wire::json_reader& source, BulletproofPlus& self) @@ -142,13 +160,20 @@ namespace rct void read_bytes(wire::json_reader& source, prunable_helper& self) { + using rf_min_size = wire::min_element_sizeof; + using bf_max = wire::max_element_count; + using bf_plus_max = wire::max_element_count; + using mlsags_max = wire::max_element_count<256>; + using clsags_max = wire::max_element_count<256>; + using pseudo_outs_max = wire::max_element_count<256>; + wire::object(source, - wire::field("range_proofs", std::ref(self.prunable.rangeSigs)), - wire::field("bulletproofs", std::ref(self.prunable.bulletproofs)), - wire::field("bulletproofs_plus", std::ref(self.prunable.bulletproofs_plus)), - wire::field("mlsags", std::ref(self.prunable.MGs)), - wire::field("clsags", std::ref(self.prunable.CLSAGs)), - wire::field("pseudo_outs", std::ref(self.pseudo_outs)) + wire::field("range_proofs", wire::array(std::ref(self.prunable.rangeSigs))), + wire::field("bulletproofs", wire::array(std::ref(self.prunable.bulletproofs))), + wire::field("bulletproofs_plus", wire::array(std::ref(self.prunable.bulletproofs_plus))), + wire::field("mlsags", wire::array(std::ref(self.prunable.MGs))), + wire::field("clsags", wire::array(std::ref(self.prunable.CLSAGs))), + wire::field("pseudo_outs", wire::array(std::ref(self.pseudo_outs))) ); const bool pruned = @@ -166,15 +191,16 @@ namespace rct static void read_bytes(wire::json_reader& source, rctSig& self) { - boost::optional> ecdhInfo; - boost::optional outPk; + using min_ecdh = wire::min_element_sizeof; + using min_ctkey = wire::min_element_sizeof; + boost::optional txnFee; boost::optional prunable; self.outPk.reserve(default_inputs); wire::object(source, WIRE_FIELD(type), - wire::optional_field("encrypted", std::ref(ecdhInfo)), - wire::optional_field("commitments", std::ref(outPk)), + wire::optional_field("encrypted", wire::array(std::ref(self.ecdhInfo))), + wire::optional_field("commitments", wire::array(std::ref(self.outPk))), wire::optional_field("fee", std::ref(txnFee)), wire::optional_field("prunable", std::ref(prunable)) ); @@ -182,13 +208,11 @@ namespace rct self.txnFee = 0; if (self.type != RCTTypeNull) { - if (!ecdhInfo || !outPk || !txnFee) + if (self.ecdhInfo.empty() || self.outPk.empty() || !txnFee) WIRE_DLOG_THROW(wire::error::schema::missing_key, "Expected fields `encrypted`, `commitments`, and `fee`"); - self.ecdhInfo = std::move(*ecdhInfo); - self.outPk = std::move(*outPk); self.txnFee = std::move(*txnFee); } - else if (ecdhInfo || outPk || txnFee) + else if (!self.ecdhInfo.empty() || !self.outPk.empty() || txnFee) WIRE_DLOG_THROW(wire::error::schema::invalid_key, "Did not expected `encrypted`, `commitments`, or `fee`"); if (prunable) @@ -243,7 +267,11 @@ namespace cryptonote } static void read_bytes(wire::json_reader& source, txin_to_key& self) { - wire::object(source, WIRE_FIELD(amount), WIRE_FIELD(key_offsets), wire::field("key_image", std::ref(self.k_image))); + wire::object(source, + WIRE_FIELD(amount), + WIRE_FIELD_ARRAY(key_offsets, max_ring_size), + wire::field("key_image", std::ref(self.k_image)) + ); } static void read_bytes(wire::json_reader& source, txin_v& self) { @@ -264,35 +292,47 @@ namespace cryptonote wire::object(source, WIRE_FIELD(version), WIRE_FIELD(unlock_time), - wire::field("inputs", std::ref(self.vin)), - wire::field("outputs", std::ref(self.vout)), + wire::field("inputs", wire::array(std::ref(self.vin))), + wire::field("outputs", wire::array(std::ref(self.vout))), WIRE_FIELD(extra), + WIRE_FIELD_ARRAY(signatures, max_inputs_per_tx), wire::field("ringct", std::ref(self.rct_signatures)) ); } static void read_bytes(wire::json_reader& source, block& self) { + using min_hash_size = wire::min_element_sizeof; self.tx_hashes.reserve(default_transaction_count); wire::object(source, WIRE_FIELD(major_version), WIRE_FIELD(minor_version), WIRE_FIELD(timestamp), WIRE_FIELD(miner_tx), - WIRE_FIELD(tx_hashes), + WIRE_FIELD_ARRAY(tx_hashes, min_hash_size), WIRE_FIELD(prev_id), WIRE_FIELD(nonce) ); } - namespace rpc + static void read_bytes(wire::json_reader& source, std::vector& self) { + wire_read::array_unchecked(source, self, 0, max_txes_per_block{}); + } + + namespace rpc + { static void read_bytes(wire::json_reader& source, block_with_transactions& self) { self.transactions.reserve(default_transaction_count); wire::object(source, WIRE_FIELD(block), WIRE_FIELD(transactions)); } + static void read_bytes(wire::json_reader& source, std::vector& self) + { + wire_read::array_unchecked(source, self, 0, max_blocks_per_fetch{}); + } + static void read_bytes(wire::json_reader& source, tx_in_pool& self) { wire::object(source, WIRE_FIELD(tx), WIRE_FIELD(tx_hash)); @@ -310,11 +350,16 @@ void lws::rpc::read_bytes(wire::json_reader& source, get_blocks_fast_response& s { self.blocks.reserve(default_blocks_fetched); self.output_indices.reserve(default_blocks_fetched); - wire::object(source, WIRE_FIELD(blocks), WIRE_FIELD(output_indices), WIRE_FIELD(start_height), WIRE_FIELD(current_height)); + wire::object(source, + WIRE_FIELD(blocks), + wire::field("output_indices", wire::array(wire::array(wire::array(std::ref(self.output_indices))))), + WIRE_FIELD(start_height), + WIRE_FIELD(current_height) + ); } void lws::rpc::read_bytes(wire::json_reader& source, get_transaction_pool_response& self) { self.transactions.reserve(default_txpool_size); - wire::object(source, WIRE_FIELD(transactions)); + wire::object(source, WIRE_FIELD_ARRAY(transactions, max_txpool_size)); } diff --git a/src/rpc/light_wallet.cpp b/src/rpc/light_wallet.cpp index 3590784..5994fb1 100644 --- a/src/rpc/light_wallet.cpp +++ b/src/rpc/light_wallet.cpp @@ -50,6 +50,8 @@ namespace { + using max_subaddrs = wire::max_element_count<16384>; + enum class iso_timestamp : std::uint64_t {}; struct rct_bytes @@ -178,7 +180,7 @@ namespace lws } void rpc::read_bytes(wire::json_reader& source, safe_uint64_array& self) { - for (std::size_t count = source.start_array(); !source.is_array_end(count); --count) + for (std::size_t count = source.start_array(0); !source.is_array_end(count); --count) self.values.emplace_back(wire::integer::cast_unsigned(source.safe_unsigned_integer())); source.end_array(); } @@ -374,7 +376,7 @@ namespace lws wire::object(source, wire::field("address", std::ref(address)), wire::field("view_key", std::ref(unwrap(unwrap(self.creds.key)))), - WIRE_FIELD(subaddrs), + WIRE_FIELD_ARRAY(subaddrs, max_subaddrs), WIRE_OPTIONAL_FIELD(get_all) ); convert_address(address, self.creds.address); diff --git a/src/server_main.cpp b/src/server_main.cpp index c3f83f4..1b99539 100644 --- a/src/server_main.cpp +++ b/src/server_main.cpp @@ -281,7 +281,7 @@ namespace boost::filesystem::create_directories(prog.db_path); auto disk = lws::db::storage::open(prog.db_path.c_str(), prog.create_queue_max); - auto ctx = lws::rpc::context::make(std::move(prog.daemon_rpc), std::move(prog.daemon_sub), std::move(prog.zmq_pub), std::move(prog.rmq), prog.rates_interval); + auto ctx = lws::rpc::context::make(std::move(prog.daemon_rpc), std::move(prog.daemon_sub), std::move(prog.zmq_pub), std::move(prog.rmq), prog.rates_interval, prog.untrusted_daemon); MINFO("Using monerod ZMQ RPC at " << ctx.daemon_address()); auto client = lws::scanner::sync(disk.clone(), ctx.connect().value(), prog.untrusted_daemon).value(); diff --git a/src/wire/adapted/array.h b/src/wire/adapted/array.h index 5c98c4f..93535f6 100644 --- a/src/wire/adapted/array.h +++ b/src/wire/adapted/array.h @@ -76,7 +76,7 @@ namespace wire template inline void read_bytes(R& source, std::array& dest) { - std::size_t count = source.start_array(); + std::size_t count = source.start_array(0); const bool json = (count == 0); if (!json && count != dest.size()) WIRE_DLOG_THROW(wire::error::schema::array, "Expected array of size " << dest.size()); diff --git a/src/wire/error.cpp b/src/wire/error.cpp index 790d220..e3d52b6 100644 --- a/src/wire/error.cpp +++ b/src/wire/error.cpp @@ -42,6 +42,10 @@ namespace wire return "No schema errors"; case schema::array: return "Schema expected array"; + case schema::array_max_element: + return "Schema expected array size to be smaller"; + case schema::array_min_size: + return "Schema expected minimum wire size per array element to be larger"; case schema::binary: return "Schema expected binary value of variable size"; case schema::boolean: diff --git a/src/wire/error.h b/src/wire/error.h index c67a978..b16654c 100644 --- a/src/wire/error.h +++ b/src/wire/error.h @@ -54,6 +54,8 @@ namespace wire { none = 0, //!< Must be zero for `expect<..>` array, //!< Expected an array value + array_max_element,//!< Exceeded max array count + array_min_size, //!< Below min element wire size binary, //!< Expected a binary value of variable length boolean, //!< Expected a boolean value enumeration, //!< Expected a value from a specific set diff --git a/src/wire/json/read.cpp b/src/wire/json/read.cpp index a964c32..0043e09 100644 --- a/src/wire/json/read.cpp +++ b/src/wire/json/read.cpp @@ -48,13 +48,13 @@ namespace }; //! \throw std::system_error by converting `code` into a std::error_code - [[noreturn]] void throw_json_error(const epee::span source, const rapidjson::Reader& reader, const wire::error::schema expected) + [[noreturn]] void throw_json_error(const epee::span source, const rapidjson::Reader& reader, const wire::error::schema expected) { const std::size_t offset = std::min(source.size(), reader.GetErrorOffset()); const std::size_t start = offset;//std::max(snippet_size / 2, offset) - (snippet_size / 2); const std::size_t end = start + std::min(snippet_size, source.size() - start); - const boost::string_ref text{source.data() + start, end - start}; + const boost::string_ref text{reinterpret_cast(source.data()) + start, end - start}; const rapidjson::ParseErrorCode parse_error = reader.GetParseErrorCode(); switch (parse_error) { @@ -178,17 +178,19 @@ namespace wire void json_reader::read_next_value(rapidjson_sax& handler) { - rapidjson::InsituStringStream stream{current_.data()}; - if (!reader_.Parse(stream, handler)) - throw_json_error(current_, reader_, handler.expected_); - current_.remove_prefix(stream.Tell()); + rapidjson::MemoryStream stream{reinterpret_cast(remaining_.data()), remaining_.size()}; + rapidjson::EncodedInputStream, rapidjson::MemoryStream> istream{stream}; + if (!reader_.Parse(istream, handler)) + throw_json_error(remaining_, reader_, handler.expected_); + remaining_.remove_prefix(istream.Tell()); } char json_reader::get_next_token() { - rapidjson::InsituStringStream stream{current_.data()}; - rapidjson::SkipWhitespace(stream); - current_.remove_prefix(stream.Tell()); + rapidjson::MemoryStream stream{reinterpret_cast(remaining_.data()), remaining_.size()}; + rapidjson::EncodedInputStream, rapidjson::MemoryStream> istream{stream}; + rapidjson::SkipWhitespace(istream); + remaining_.remove_prefix(istream.Tell()); return stream.Peek(); } @@ -196,15 +198,15 @@ namespace wire { if (get_next_token() != '"') WIRE_DLOG_THROW_(error::schema::string); - current_.remove_prefix(1); + remaining_.remove_prefix(1); - void const* const end = std::memchr(current_.data(), '"', current_.size()); + void const* const end = std::memchr(remaining_.data(), '"', remaining_.size()); if (!end) WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorStringMissQuotationMark)); - char const* const begin = current_.data(); - const std::size_t length = current_.remove_prefix(static_cast(end) - current_.data() + 1); - return {begin, length - 1}; + std::uint8_t const* const begin = remaining_.data(); + const std::size_t length = remaining_.remove_prefix(static_cast(end) - remaining_.data() + 1); + return {reinterpret_cast(begin), length - 1}; } void json_reader::skip_value() @@ -214,11 +216,12 @@ namespace wire } json_reader::json_reader(std::string&& source) - : reader(), + : reader(nullptr), source_(std::move(source)), - current_(std::addressof(source_[0]), source_.size()), reader_() - {} + { + remaining_ = {reinterpret_cast(source_.data()), source_.size()}; + } void json_reader::check_complete() const { @@ -271,13 +274,13 @@ namespace wire { if (get_next_token() != '"') WIRE_DLOG_THROW_(error::schema::string); - current_.remove_prefix(1); + remaining_.remove_prefix(1); const std::uintmax_t out = unsigned_integer(); if (get_next_token() != '"') WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorStringMissQuotationMark)); - current_.remove_prefix(1); + remaining_.remove_prefix(1); return out; } @@ -316,11 +319,11 @@ namespace wire WIRE_DLOG_THROW(error::schema::fixed_binary, "of size" << dest.size() * 2 << " but got " << value.size()); } - std::size_t json_reader::start_array() + std::size_t json_reader::start_array(std::size_t) { if (get_next_token() != '[') WIRE_DLOG_THROW_(error::schema::array); - current_.remove_prefix(1); + remaining_.remove_prefix(1); increment_depth(); return 0; } @@ -332,7 +335,7 @@ namespace wire WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorArrayMissCommaOrSquareBracket)); if (next == ']') { - current_.remove_prefix(1); + remaining_.remove_prefix(1); return true; } @@ -340,7 +343,7 @@ namespace wire { if (next != ',') WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorArrayMissCommaOrSquareBracket)); - current_.remove_prefix(1); + remaining_.remove_prefix(1); } return false; } @@ -349,7 +352,7 @@ namespace wire { if (get_next_token() != '{') WIRE_DLOG_THROW_(error::schema::object); - current_.remove_prefix(1); + remaining_.remove_prefix(1); increment_depth(); return 0; } @@ -377,7 +380,7 @@ namespace wire WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorObjectMissCommaOrCurlyBracket)); if (next == '}') { - current_.remove_prefix(1); + remaining_.remove_prefix(1); return false; } @@ -386,7 +389,7 @@ namespace wire { if (next != ',') WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorObjectMissCommaOrCurlyBracket)); - current_.remove_prefix(1); + remaining_.remove_prefix(1); } ++state; @@ -395,7 +398,7 @@ namespace wire index = process_key(json_key.value.string); if (get_next_token() != ':') WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorObjectMissColon)); - current_.remove_prefix(1); + remaining_.remove_prefix(1); // parse value if (index != map.size()) diff --git a/src/wire/json/read.h b/src/wire/json/read.h index 7937276..6bbd23e 100644 --- a/src/wire/json/read.h +++ b/src/wire/json/read.h @@ -48,7 +48,6 @@ namespace wire struct rapidjson_sax; std::string source_; - epee::span current_; rapidjson::Reader reader_; void read_next_value(rapidjson_sax& handler); @@ -90,7 +89,7 @@ namespace wire //! \throw wire::exception if next token not `[`. - std::size_t start_array() override final; + std::size_t start_array(std::size_t) override final; //! Skips whitespace to next token. \return True if next token is eof or ']'. bool is_array_end(std::size_t count) override final; diff --git a/src/wire/msgpack/error.cpp b/src/wire/msgpack/error.cpp index a1278dd..8ddcc93 100644 --- a/src/wire/msgpack/error.cpp +++ b/src/wire/msgpack/error.cpp @@ -43,8 +43,12 @@ namespace error return "Unable to encode integer in msgpack"; case msgpack::invalid: return "Invalid msgpack encoding"; + case msgpack::max_tree_size: + return "Exceeded tag tracking amount"; case msgpack::not_enough_bytes: return "Expected more bytes in the msgpack stream"; + case msgpack::underflow_tree: + return "Expected more tags"; } return "Unknown msgpack error"; diff --git a/src/wire/msgpack/error.h b/src/wire/msgpack/error.h index 345d047..38e2786 100644 --- a/src/wire/msgpack/error.h +++ b/src/wire/msgpack/error.h @@ -40,7 +40,9 @@ namespace error incomplete, integer_encoding, invalid, - not_enough_bytes + max_tree_size, + not_enough_bytes, + underflow_tree }; //! \return Static string describing error `value`. diff --git a/src/wire/msgpack/read.cpp b/src/wire/msgpack/read.cpp index d7e762b..cf532a6 100644 --- a/src/wire/msgpack/read.cpp +++ b/src/wire/msgpack/read.cpp @@ -77,7 +77,7 @@ namespace //! \return Integer `T` encoded as big endian in `source`. template - T read_endian(epee::byte_slice& source) + T read_endian(epee::span& source) { static_assert(std::is_integral::value, "must be integral type"); static constexpr const std::size_t bits = 8 * sizeof(T); @@ -95,12 +95,12 @@ namespace //! \return Integer `T` encoded as big endian in `source`. template - T read_endian(epee::byte_slice& source, const wire::msgpack::type) + T read_endian(epee::span& source, const wire::msgpack::type) { return read_endian(source); } //! \return Integer `T` whose encoding is specified by tag `next` template - T read_integer(epee::byte_slice& source, const wire::msgpack::tag next) + T read_integer(epee::span& source, const wire::msgpack::tag next) { try { @@ -135,20 +135,21 @@ namespace WIRE_DLOG_THROW_(wire::error::schema::integer); } - epee::byte_slice read_raw(epee::byte_slice& source, const std::size_t bytes) + epee::span read_raw(epee::span& source, const std::size_t bytes) { if (source.size() < bytes) WIRE_DLOG_THROW_(wire::error::msgpack::not_enough_bytes); - return source.take_slice(bytes); + const std::size_t actual = source.remove_prefix(bytes); + return {source.data() - actual, actual}; } template - epee::byte_slice read_raw(epee::byte_slice& source) + epee::span read_raw(epee::span& source) { return read_raw(source, wire::integer::cast_unsigned(read_endian(source))); } - epee::byte_slice read_string(epee::byte_slice& source, const wire::msgpack::tag next) + epee::span read_string(epee::span& source, const wire::msgpack::tag next) { switch (next) { @@ -170,7 +171,7 @@ namespace } //! \return Binary blob encoded message - epee::byte_slice read_binary(epee::byte_slice& source, const wire::msgpack::tag next) + epee::span read_binary(epee::span& source, const wire::msgpack::tag next) { switch (next) { @@ -189,21 +190,21 @@ namespace namespace wire { - void msgpack_reader::throw_logic_error() + void msgpack_reader::throw_wire_exception() { - throw std::logic_error{"Bug in msgpack_reader usage"}; + WIRE_DLOG_THROW_(error::msgpack::underflow_tree); } void msgpack_reader::skip_value() { - assert(remaining_); - if (limits::max() == remaining_) - throw std::runtime_error{"msgpack_reader exceeded tree tracking"}; + assert(tags_remaining_); + if (limits::max() == tags_remaining_) + WIRE_DLOG_THROW_(error::msgpack::max_tree_size); - const std::size_t initial = remaining_; + const std::size_t initial = tags_remaining_; do { - const std::size_t size = source_.size(); + const std::size_t size = remaining_.size(); const msgpack::tag next = peek_tag(); switch (next) { @@ -213,59 +214,59 @@ namespace wire case msgpack::tag::unused: case msgpack::tag::False: case msgpack::tag::True: - source_.remove_prefix(1); + remaining_.remove_prefix(1); break; case msgpack::tag::binary8: case msgpack::tag::binary16: case msgpack::tag::binary32: - source_.remove_prefix(1); - read_binary(source_, next); + remaining_.remove_prefix(1); + read_binary(remaining_, next); break; case msgpack::tag::extension8: - source_.remove_prefix(1); - read_raw(source_); - source_.remove_prefix(1); + remaining_.remove_prefix(1); + read_raw(remaining_); + remaining_.remove_prefix(1); break; case msgpack::tag::extension16: - source_.remove_prefix(1); - read_raw(source_); - source_.remove_prefix(1); + remaining_.remove_prefix(1); + read_raw(remaining_); + remaining_.remove_prefix(1); break; case msgpack::tag::extension32: - source_.remove_prefix(1); - read_raw(source_); - source_.remove_prefix(1); + remaining_.remove_prefix(1); + read_raw(remaining_); + remaining_.remove_prefix(1); break; case msgpack::tag::int8: case msgpack::tag::uint8: - source_.remove_prefix(2); + remaining_.remove_prefix(2); break; case msgpack::tag::int16: case msgpack::tag::uint16: case msgpack::tag::fixed_extension1: - source_.remove_prefix(3); + remaining_.remove_prefix(3); break; case msgpack::tag::int32: case msgpack::tag::uint32: case msgpack::tag::float32: - source_.remove_prefix(5); + remaining_.remove_prefix(5); break; case msgpack::tag::int64: case msgpack::tag::uint64: case msgpack::tag::float64: - source_.remove_prefix(9); + remaining_.remove_prefix(9); break; case msgpack::tag::fixed_extension2: - source_.remove_prefix(4); + remaining_.remove_prefix(4); break; case msgpack::tag::fixed_extension4: - source_.remove_prefix(6); + remaining_.remove_prefix(6); break; case msgpack::tag::fixed_extension8: - source_.remove_prefix(10); + remaining_.remove_prefix(10); break; case msgpack::tag::fixed_extension16: - source_.remove_prefix(18); + remaining_.remove_prefix(18); break; #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wswitch" @@ -273,8 +274,8 @@ namespace wire case msgpack::tag::string8: case msgpack::tag::string16: case msgpack::tag::string32: - source_.remove_prefix(1); - read_string(source_, next); + remaining_.remove_prefix(1); + read_string(remaining_, next); break; case msgpack::tag(0x90): case msgpack::tag(0x91): case msgpack::tag(0x92): case msgpack::tag(0x93): case msgpack::tag(0x94): case msgpack::tag(0x95): @@ -284,7 +285,7 @@ namespace wire case msgpack::tag(0x9f): case msgpack::tag::array16: case msgpack::tag::array32: - start_array(); + start_array(0); break; case msgpack::tag(0x80): case msgpack::tag(0x81): case msgpack::tag(0x82): case msgpack::tag(0x83): case msgpack::tag(0x84): case msgpack::tag(0x85): @@ -299,27 +300,27 @@ namespace wire #pragma GCC diagnostic pop }; - if (size == source_.size()) + if (size == remaining_.size()) { if (!msgpack::ftag_unsigned::matches(next) && !msgpack::ftag_signed::matches(next)) WIRE_DLOG_THROW_(error::msgpack::invalid); - source_.remove_prefix(1); + remaining_.remove_prefix(1); } - update_remaining(); - } while (initial <= remaining_); + update_tags_remaining(); + } while (initial <= tags_remaining_); } msgpack::tag msgpack_reader::peek_tag() { - if (source_.empty()) + if (remaining_.empty()) WIRE_DLOG_THROW_(error::msgpack::not_enough_bytes); - return msgpack::tag(*source_.data()); + return msgpack::tag(*remaining_.data()); } msgpack::tag msgpack_reader::get_tag() { const msgpack::tag next = peek_tag(); - source_.remove_prefix(1); + remaining_.remove_prefix(1); return next; } @@ -327,12 +328,12 @@ namespace wire { if (msgpack::ftag_signed::matches(next)) return *reinterpret_cast(std::addressof(next)); // special case - return read_integer(source_, next); + return read_integer(remaining_, next); } std::uintmax_t msgpack_reader::do_unsigned_integer(const msgpack::tag next) { - return read_integer(source_, next); + return read_integer(remaining_, next); } template @@ -347,7 +348,7 @@ namespace wire { if (type.Tag() == next) { - out = integer::cast_unsigned(read_endian(source_, type)); + out = integer::cast_unsigned(read_endian(remaining_, type)); return true; } return false; @@ -361,13 +362,13 @@ namespace wire void msgpack_reader::check_complete() const { - if (remaining_) + if (tags_remaining_) WIRE_DLOG_THROW_(error::msgpack::incomplete); } bool msgpack_reader::boolean() { - update_remaining(); + update_tags_remaining(); switch (get_tag()) { case msgpack::tag::True: @@ -382,14 +383,14 @@ namespace wire double msgpack_reader::real() { - update_remaining(); + update_tags_remaining(); const auto read_float = [this](auto value) { - if (source_.size() < sizeof(value)) + if (remaining_.size() < sizeof(value)) WIRE_DLOG_THROW_(error::msgpack::not_enough_bytes); - std::memcpy(std::addressof(value), source_.data(), sizeof(value)); - source_.remove_prefix(sizeof(value)); + std::memcpy(std::addressof(value), remaining_.data(), sizeof(value)); + remaining_.remove_prefix(sizeof(value)); return value; }; @@ -407,34 +408,38 @@ namespace wire std::string msgpack_reader::string() { - update_remaining(); - const epee::byte_slice bytes = read_string(source_, get_tag()); + update_tags_remaining(); + const epee::span bytes = read_string(remaining_, get_tag()); return std::string{reinterpret_cast(bytes.data()), bytes.size()}; } std::vector msgpack_reader::binary() { - update_remaining(); - const epee::byte_slice bytes = read_binary(source_, get_tag()); + update_tags_remaining(); + const epee::span bytes = read_binary(remaining_, get_tag()); return std::vector{bytes.begin(), bytes.end()}; } void msgpack_reader::binary(epee::span dest) { - update_remaining(); - const epee::byte_slice bytes = read_binary(source_, get_tag()); + update_tags_remaining(); + const epee::span bytes = read_binary(remaining_, get_tag()); if (dest.size() != bytes.size()) WIRE_DLOG_THROW(error::schema::fixed_binary, "of size " << dest.size() << " but got " << bytes.size()); std::memcpy(dest.data(), bytes.data(), dest.size()); } - std::size_t msgpack_reader::start_array() + std::size_t msgpack_reader::start_array(const std::size_t min_element_size) { const std::size_t upcoming = read_count(error::schema::array); - if (limits::max() - remaining_ < upcoming) - throw std::runtime_error{"Exceeded max tree tracking for msgpack_reader"}; - remaining_ += upcoming; + if (limits::max() - tags_remaining_ < upcoming) + WIRE_DLOG_THROW_(error::msgpack::max_tree_size); + if (min_element_size && (remaining_.size() / min_element_size) < upcoming) + WIRE_DLOG_THROW(error::schema::array, upcoming << " array elements of at least " << min_element_size << " bytes each exceeds " << remaining_.size() << " remaining bytes"); + + tags_remaining_ += upcoming; + increment_depth(); return upcoming; } @@ -442,7 +447,7 @@ namespace wire { if (count) return false; - update_remaining(); + update_tags_remaining(); return true; } @@ -451,10 +456,11 @@ namespace wire const std::size_t upcoming = read_count(error::schema::object); if (limits::max() / 2 < upcoming) - throw std::runtime_error{"Exceeded max object tracking for msgpack_reader"}; - if (limits::max() - remaining_ < upcoming * 2) - throw std::runtime_error{"Exceeded msgpack_reader:: tree tracking"}; - remaining_ += upcoming * 2; + WIRE_DLOG_THROW_(error::msgpack::max_tree_size); + if (limits::max() - tags_remaining_ < upcoming * 2) + WIRE_DLOG_THROW_(error::msgpack::max_tree_size); + tags_remaining_ += upcoming * 2; + increment_depth(); return upcoming; } @@ -463,14 +469,14 @@ namespace wire index = map.size(); for ( ;state; --state) { - update_remaining(); // for key + update_tags_remaining(); // for key const msgpack::tag next = get_tag(); const bool single = msgpack::ftag_unsigned::matches(next); if (single || matches(next)) { unsigned key = std::uint8_t(next); if (!single) - key = read_integer(source_, next); + key = read_integer(remaining_, next); for (const key_map& elem : map) { if (elem.id == key) @@ -482,7 +488,7 @@ namespace wire } else if (msgpack::ftag_string::matches(next) || matches(next)) { - const epee::byte_slice key = read_string(source_, next); + const epee::span key = read_string(remaining_, next); for (const key_map& elem : map) { const boost::string_ref elem_{elem.name}; @@ -503,7 +509,7 @@ namespace wire } skip_value(); } // until state == 0 - update_remaining(); // for end of object + update_tags_remaining(); // for end of object return false; } } diff --git a/src/wire/msgpack/read.h b/src/wire/msgpack/read.h index 2190684..6ad03a2 100644 --- a/src/wire/msgpack/read.h +++ b/src/wire/msgpack/read.h @@ -45,30 +45,30 @@ namespace wire class msgpack_reader : public reader { epee::byte_slice source_; - std::size_t remaining_; //!< Expected number of elements remaining + std::size_t tags_remaining_; //!< Expected number of elements remaining - //! \throw std::logic_error - [[noreturn]] void throw_logic_error(); - //! Decrement remaining_ if not zero, \throw std::logic_error when `remaining_ == 0`. - void update_remaining() + //! \throw wire::exception with `error::msgpack::underflow_tree` + [[noreturn]] void throw_wire_exception(); + //! Decrement tags_remaining_ if not zero, \throw std::logic_error when `tags_remaining_ == 0`. + void update_tags_remaining() { - if (remaining_) - --remaining_; + if (tags_remaining_) + --tags_remaining_; else - throw_logic_error(); + throw_wire_exception(); } //! Skips next value. \throw wire::exception if invalid JSON syntax. void skip_value(); - //! \return Next tag but leave `source_` untouched. + //! \return Next tag but leave `remaining_` untouched. msgpack::tag peek_tag(); - //! \return Next tag and remove first byte from `source_`. + //! \return Next tag and remove first byte from `remaining_`. msgpack::tag get_tag(); - //! \return Integer from `soure_` where positive fixed tag has been checked. + //! \return Integer from `remaining_` where positive fixed tag has been checked. std::intmax_t do_integer(msgpack::tag); - //! \return Integer from `source_` where fixed tag has been checked. + //! \return Integer from `remaining_` where fixed tag has been checked. std::uintmax_t do_unsigned_integer(msgpack::tag); //! \return Number of items determined by `T` fixed tag and `U` tuple of tags. @@ -77,8 +77,10 @@ namespace wire public: explicit msgpack_reader(epee::byte_slice&& source) - : reader(), source_(std::move(source)), remaining_(1) - {} + : reader(nullptr), source_(std::move(source)), tags_remaining_(1) + { + remaining_ = {source_.data(), source_.size()}; + } //! \throw wire::exception if JSON parsing is incomplete. void check_complete() const override final; @@ -89,7 +91,7 @@ namespace wire //! \throw wire::expception if next token not an integer. std::intmax_t integer() override final { - update_remaining(); + update_tags_remaining(); const msgpack::tag next = get_tag(); if (std::uint8_t(next) <= msgpack::ftag_unsigned::max()) return std::uint8_t(next); @@ -99,7 +101,7 @@ namespace wire //! \throw wire::exception if next token not an unsigned integer. std::uintmax_t unsigned_integer() override final { - update_remaining(); + update_tags_remaining(); const msgpack::tag next = get_tag(); if (std::uint8_t(next) <= msgpack::ftag_unsigned::max()) return std::uint8_t(next); @@ -120,7 +122,7 @@ namespace wire //! \throw wire::exception if next token not `[`. - std::size_t start_array() override final; + std::size_t start_array(std::size_t min_element_size) override final; //! \return true when `count == 0`. bool is_array_end(const std::size_t count) override final; diff --git a/src/wire/read.cpp b/src/wire/read.cpp index def3d2f..1474c9c 100644 --- a/src/wire/read.cpp +++ b/src/wire/read.cpp @@ -35,6 +35,13 @@ void wire::reader::increment_depth() WIRE_DLOG_THROW_(error::schema::maximum_depth); } +void wire::reader::decrement_depth() +{ + if (!depth_) + throw std::logic_error{"reader::decrement_depth() already at zero"}; + --depth_; +} + [[noreturn]] void wire::integer::throw_exception(std::intmax_t source, std::intmax_t min, std::intmax_t max) { static_assert( diff --git a/src/wire/read.h b/src/wire/read.h index 74d074c..1c44319 100644 --- a/src/wire/read.h +++ b/src/wire/read.h @@ -44,19 +44,27 @@ namespace wire { //! Interface for converting "wire" (byte) formats to C/C++ objects without a DOM. class reader - { + { std::size_t depth_; //!< Tracks number of recursive objects and arrays protected: - //! \throw wire::exception if max depth is reached - void increment_depth(); - void decrement_depth() noexcept { --depth_; } + epee::span remaining_; //!< Derived class tracks unprocessed bytes here + + reader(const epee::span remaining) noexcept + : depth_(0), remaining_(remaining) + {} reader(const reader&) = default; reader(reader&&) = default; reader& operator=(const reader&) = default; reader& operator=(reader&&) = default; + //! \throw wire::exception if max depth is reached + void increment_depth(); + + //! \throw std::logic_error if already `depth() == 0`. + void decrement_depth(); + public: struct key_map { @@ -70,16 +78,15 @@ namespace wire //! \return Assume delimited arrays in generic interface (some optimizations disabled) static constexpr std::true_type delimited_arrays() noexcept { return {}; } - reader() noexcept - : depth_(0) - {} - virtual ~reader() noexcept {} //! \return Number of recursive objects and arrays std::size_t depth() const noexcept { return depth_; } + //! \return Unprocessed bytes + epee::span remaining() const noexcept { return remaining_; } + //! \throw wire::exception if parsing is incomplete. virtual void check_complete() const = 0; @@ -104,14 +111,20 @@ namespace wire //! \throw wire::exception if next value cannot be read as binary into `dest`. virtual void binary(epee::span dest) = 0; - //! \throw wire::exception if next value not array - virtual std::size_t start_array() = 0; + /* \param min_element_size of each array element in any format - if known. + Derived types with explicit element count should verify available + space, and throw a `wire::exception` on issues. + \throw wire::exception if next value not array + \throw wire::exception if not enough bytes for all array elements + (with epee/msgpack which has specified number of elements). + \return Number of values to read before calling `is_array_end()` */ + virtual std::size_t start_array(std::size_t min_element_size) = 0; //! \return True if there is another element to read. virtual bool is_array_end(std::size_t count) = 0; //! \throw wire::exception if array end delimiter not present. - void end_array() noexcept { decrement_depth(); } + void end_array() { decrement_depth(); } //! \throw wire::exception if not object begin. \return State to be given to `key(...)` function. @@ -134,7 +147,7 @@ namespace wire */ virtual bool key(epee::span map, std::size_t& state, std::size_t& index) = 0; - void end_object() noexcept { decrement_depth(); } + void end_object() { decrement_depth(); } }; template @@ -247,28 +260,84 @@ namespace wire_read return {}; } + // Trap objects that do not have standard insertion functions + template + void array_insert(const R&, const T&...) noexcept + { + static_assert(std::is_same::value, "type T does not have a valid insertion function"); + } + + // Insert to sorted containers + template + inline auto array_insert(R& source, T& dest) -> decltype(dest.emplace_hint(dest.end(), std::declval()), bool(true)) + { + V val{}; + wire_read::bytes(source, val); + dest.emplace_hint(dest.end(), std::move(val)); + return true; + } + + // Insert into unsorted containers + template + inline auto array_insert(R& source, T& dest) -> decltype(dest.emplace_back(), dest.back(), bool(true)) + { + // more efficient to process the object in-place in many cases + dest.emplace_back(); + wire_read::bytes(source, dest.back()); + return true; + } + + // no compile-time checks for the array constraints template - inline void array(R& source, T& dest) + inline void array_unchecked(R& source, T& dest, const std::size_t min_element_size, const std::size_t max_element_count) { using value_type = typename T::value_type; - static_assert(!std::is_same::value, "read array of chars as binary"); + static_assert(!std::is_same::value, "read array of chars as string"); + static_assert(!std::is_same::value, "read array of signed chars as binary"); static_assert(!std::is_same::value, "read array of unsigned chars as binary"); - std::size_t count = source.start_array(); + std::size_t count = source.start_array(min_element_size); + + // quick check for epee/msgpack formats + if (max_element_count < count) + throw_exception(wire::error::schema::array_max_element, "", nullptr); + + // also checked by derived formats when count is known + if (min_element_size && (source.remaining().size() / min_element_size) < count) + throw_exception(wire::error::schema::array_min_size, "", nullptr); dest.clear(); wire::reserve(dest, count); bool more = count; + const std::size_t start_bytes = source.remaining().size(); while (more || !source.is_array_end(count)) { - dest.emplace_back(); - read_bytes(source, dest.back()); + // check for json/cbor formats + if (source.delimited_arrays() && max_element_count <= dest.size()) + throw_exception(wire::error::schema::array_max_element, "", nullptr); + + wire_read::array_insert(source, dest); --count; more &= bool(count); + + if (((start_bytes - source.remaining().size()) / dest.size()) < min_element_size) + throw_exception(wire::error::schema::array_min_size, "", nullptr); } - return source.end_array(); + source.end_array(); + } + + template::max()> + inline void array(R& source, T& dest, wire::min_element_size min_element_size, wire::max_element_count max_element_count = {}) + { + using value_type = typename T::value_type; + static_assert( + min_element_size.template check() || max_element_count.template check(), + "array unpacking memory issues" + ); + // each set of template args generates unique ASM, merge them down + array_unchecked(source, dest, min_element_size, max_element_count); } template @@ -413,7 +482,14 @@ namespace wire template inline std::enable_if_t::value> read_bytes(R& source, T& dest) { - wire_read::array(source, dest); + static constexpr const std::size_t wire_size = + default_min_element_size::value; + static_assert( + wire_size != 0, + "no sane default array constraints for the reader / value_type pair" + ); + + wire_read::array(source, dest, min_element_size{}); } template diff --git a/src/wire/traits.h b/src/wire/traits.h index 5f79c7a..0255949 100644 --- a/src/wire/traits.h +++ b/src/wire/traits.h @@ -84,6 +84,52 @@ namespace wire : is_array // all array types in old output engine were optional when empty {}; + //! A constraint for `wire_read::array` where a max of `N` elements can be read. + template + struct max_element_count + : std::integral_constant + { + // The threshold is low - min_element_size is a better constraint metric + static constexpr std::size_t max_bytes() noexcept { return 512 * 1024; } // 512 KiB + + //! \return True if `N` C++ objects of type `T` are below `max_bytes()` threshold. + template + static constexpr bool check() noexcept + { + return N <= (max_bytes() / sizeof(T)); + } + }; + + //! A constraint for `wire_read::array` where each element must use at least `N` bytes on the wire. + template + struct min_element_size + : std::integral_constant + { + static constexpr std::size_t max_ratio() noexcept { return 4; } + + //! \return True if C++ object of type `T` with minimum wire size `N` is below `max_ratio()`. + template + static constexpr bool check() noexcept + { + return N != 0 ? ((sizeof(T) / N) <= max_ratio()) : false; + } + }; + + /*! Trait used in `wire/read.h` for default `min_element_size` behavior based + on an array of `T` objects and `R` reader type. This trait can be used + instead of the `wire::array(...)` (and associated macros) functionality, as + it sets a global value. The last argument is for `enable_if`. */ + template + struct default_min_element_size + : std::integral_constant + {}; + + //! If `T` is a blob, a safe default for all formats is the size of the blob + template + struct default_min_element_size::value>> + : std::integral_constant + {}; + // example usage : `wire::sum(std::size_t(wire::available(fields))...)` inline constexpr int sum() noexcept @@ -96,6 +142,9 @@ namespace wire return head + sum(tail...); } + template + using min_element_sizeof = min_element_size; + //! If container has no `reserve(0)` function, this function is used template inline void reserve(const T&...) noexcept diff --git a/src/wire/wrappers_impl.h b/src/wire/wrappers_impl.h index 7166358..0274eaa 100644 --- a/src/wire/wrappers_impl.h +++ b/src/wire/wrappers_impl.h @@ -44,7 +44,25 @@ namespace wire { // see constraints directly above `array_` definition static_assert(std::is_same::value, "array_ must have a read constraint for memory purposes"); - wire_read::array(source, wrapper.get_read_object()); + } + + template + inline void read_bytes(R& source, array_>& wrapper) + { + using array_type = array_>; + using value_type = typename array_type::value_type; + using constraint = typename array_type::constraint; + static_assert(constraint::template check(), "max reserve bytes exceeded for element"); + wire_read::array(source, wrapper.get_read_object(), min_element_size<0>{}, constraint{}); + } + template + inline void read_bytes(R& source, array_>& wrapper) + { + using array_type = array_>; + using value_type = typename array_type::value_type; + using constraint = typename array_type::constraint; + static_assert(constraint::template check(), "max compression ratio exceeded for element"); + wire_read::array(source, wrapper.get_read_object(), constraint{}); } template diff --git a/tests/unit/db/data.test.cpp b/tests/unit/db/data.test.cpp index e40760f..ac670fc 100644 --- a/tests/unit/db/data.test.cpp +++ b/tests/unit/db/data.test.cpp @@ -35,41 +35,41 @@ LWS_CASE("db::data::check_subaddress_dict") EXPECT(lws::db::check_subaddress_dict( { lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{{lws::db::minor_index(0), lws::db::minor_index(0)}}} + lws::db::index_ranges{{lws::db::index_range{{lws::db::minor_index(0), lws::db::minor_index(0)}}}} } )); EXPECT(lws::db::check_subaddress_dict( { lws::db::major_index(0), - lws::db::index_ranges{ + lws::db::index_ranges{{ lws::db::index_range{{lws::db::minor_index(0), lws::db::minor_index(0)}}, lws::db::index_range{{lws::db::minor_index(2), lws::db::minor_index(10)}} - } + }} } )); EXPECT(!lws::db::check_subaddress_dict( { lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{{lws::db::minor_index(1), lws::db::minor_index(0)}}} + lws::db::index_ranges{{lws::db::index_range{{lws::db::minor_index(1), lws::db::minor_index(0)}}}} } )); EXPECT(!lws::db::check_subaddress_dict( { lws::db::major_index(0), - lws::db::index_ranges{ + lws::db::index_ranges{{ lws::db::index_range{{lws::db::minor_index(0), lws::db::minor_index(4)}}, lws::db::index_range{{lws::db::minor_index(1), lws::db::minor_index(10)}} - } + }} } )); EXPECT(!lws::db::check_subaddress_dict( { lws::db::major_index(0), - lws::db::index_ranges{ + lws::db::index_ranges{{ lws::db::index_range{{lws::db::minor_index(0), lws::db::minor_index(0)}}, lws::db::index_range{{lws::db::minor_index(1), lws::db::minor_index(10)}} - } + }} } )); diff --git a/tests/unit/db/subaddress.test.cpp b/tests/unit/db/subaddress.test.cpp index 9974829..9086537 100644 --- a/tests/unit/db/subaddress.test.cpp +++ b/tests/unit/db/subaddress.test.cpp @@ -56,7 +56,7 @@ namespace lws::db::cursor::subaddress_indexes cur = nullptr; for (const auto& major_entry : source) { - for (const auto& minor_entry : major_entry.second) + for (const auto& minor_entry : major_entry.second.get_container()) { for (std::uint64_t elem : boost::counting_range(std::uint64_t(minor_entry[0]), std::uint64_t(minor_entry[1]) + 1)) { @@ -96,7 +96,7 @@ LWS_CASE("db::storage::upsert_subaddresses") std::vector subs{}; subs.emplace_back( lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}} + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}}} ); auto result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 100); { @@ -105,9 +105,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(100)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(100)); check_address_map(lest_env, reader, user, subs); } @@ -121,9 +121,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(fetched.has_value()); EXPECT(fetched->size() == 1); EXPECT(fetched->at(0).first == lws::db::major_index(0)); - EXPECT(fetched->at(0).second.size() == 1); - EXPECT(fetched->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(fetched->at(0).second[0][1] == lws::db::minor_index(100)); + EXPECT(fetched->at(0).second.get_container().size() == 1); + EXPECT(fetched->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(fetched->at(0).second.get_container()[0][1] == lws::db::minor_index(100)); } SECTION("Upsert Appended") @@ -131,15 +131,15 @@ LWS_CASE("db::storage::upsert_subaddresses") std::vector subs{}; subs.emplace_back( lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}} + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}}} ); auto result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 100); EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(100)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(100)); { auto reader = MONERO_UNWRAP(db.start_read()); @@ -147,14 +147,14 @@ LWS_CASE("db::storage::upsert_subaddresses") } subs.back().second = - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)}}; + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)}}}; result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 200); EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(101)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(101)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); { auto reader = MONERO_UNWRAP(db.start_read()); @@ -162,7 +162,7 @@ LWS_CASE("db::storage::upsert_subaddresses") } subs.back().second = - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(201), lws::db::minor_index(201)}}; + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(201), lws::db::minor_index(201)}}}; result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 200); EXPECT(result.has_error()); EXPECT(result == lws::error::max_subaddresses); @@ -172,9 +172,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(fetched.has_value()); EXPECT(fetched->size() == 1); EXPECT(fetched->at(0).first == lws::db::major_index(0)); - EXPECT(fetched->at(0).second.size() == 1); - EXPECT(fetched->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(fetched->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(fetched->at(0).second.get_container().size() == 1); + EXPECT(fetched->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(fetched->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); } SECTION("Upsert Prepended") @@ -182,15 +182,15 @@ LWS_CASE("db::storage::upsert_subaddresses") std::vector subs{}; subs.emplace_back( lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)}} + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)}}} ); auto result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 100); EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(101)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(101)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); { auto reader = MONERO_UNWRAP(db.start_read()); @@ -198,7 +198,7 @@ LWS_CASE("db::storage::upsert_subaddresses") } subs.back().second = - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}}; + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}}}; result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 199); EXPECT(result.has_error()); @@ -208,9 +208,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(100)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(100)); lws::db::storage_reader reader = MONERO_UNWRAP(db.start_read()); check_address_map(lest_env, reader, user, subs); @@ -219,9 +219,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(fetched.has_value()); EXPECT(fetched->size() == 1); EXPECT(fetched->at(0).first == lws::db::major_index(0)); - EXPECT(fetched->at(0).second.size() == 1); - EXPECT(fetched->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(fetched->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(fetched->at(0).second.get_container().size() == 1); + EXPECT(fetched->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(fetched->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); } SECTION("Upsert Wrapped") @@ -229,15 +229,15 @@ LWS_CASE("db::storage::upsert_subaddresses") std::vector subs{}; subs.emplace_back( lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)}} + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)}}} ); auto result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 100); EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(101)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(101)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); { auto reader = MONERO_UNWRAP(db.start_read()); @@ -245,7 +245,7 @@ LWS_CASE("db::storage::upsert_subaddresses") } subs.back().second = - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(300)}}; + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(300)}}}; result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 299); EXPECT(result.has_error()); @@ -255,11 +255,11 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 2); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(100)); - EXPECT(result->at(0).second[1][0] == lws::db::minor_index(201)); - EXPECT(result->at(0).second[1][1] == lws::db::minor_index(300)); + EXPECT(result->at(0).second.get_container().size() == 2); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(100)); + EXPECT(result->at(0).second.get_container()[1][0] == lws::db::minor_index(201)); + EXPECT(result->at(0).second.get_container()[1][1] == lws::db::minor_index(300)); lws::db::storage_reader reader = MONERO_UNWRAP(db.start_read()); check_address_map(lest_env, reader, user, subs); @@ -267,9 +267,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(fetched.has_value()); EXPECT(fetched->size() == 1); EXPECT(fetched->at(0).first == lws::db::major_index(0)); - EXPECT(fetched->at(0).second.size() == 1); - EXPECT(fetched->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(fetched->at(0).second[0][1] == lws::db::minor_index(300)); + EXPECT(fetched->at(0).second.get_container().size() == 1); + EXPECT(fetched->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(fetched->at(0).second.get_container()[0][1] == lws::db::minor_index(300)); } SECTION("Upsert After") @@ -277,15 +277,15 @@ LWS_CASE("db::storage::upsert_subaddresses") std::vector subs{}; subs.emplace_back( lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}} + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}}} ); auto result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 100); EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(100)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(100)); { auto reader = MONERO_UNWRAP(db.start_read()); @@ -293,7 +293,7 @@ LWS_CASE("db::storage::upsert_subaddresses") } subs.back().second = - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(102), lws::db::minor_index(200)}}; + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(102), lws::db::minor_index(200)}}}; result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 198); EXPECT(result.has_error()); EXPECT(result == lws::error::max_subaddresses); @@ -302,9 +302,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(102)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(102)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); auto reader = MONERO_UNWRAP(db.start_read()); check_address_map(lest_env, reader, user, subs); @@ -312,11 +312,11 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(fetched.has_value()); EXPECT(fetched->size() == 1); EXPECT(fetched->at(0).first == lws::db::major_index(0)); - EXPECT(fetched->at(0).second.size() == 2); - EXPECT(fetched->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(fetched->at(0).second[0][1] == lws::db::minor_index(100)); - EXPECT(fetched->at(0).second[1][0] == lws::db::minor_index(102)); - EXPECT(fetched->at(0).second[1][1] == lws::db::minor_index(200)); + EXPECT(fetched->at(0).second.get_container().size() == 2); + EXPECT(fetched->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(fetched->at(0).second.get_container()[0][1] == lws::db::minor_index(100)); + EXPECT(fetched->at(0).second.get_container()[1][0] == lws::db::minor_index(102)); + EXPECT(fetched->at(0).second.get_container()[1][1] == lws::db::minor_index(200)); } SECTION("Upsert Before") @@ -324,15 +324,15 @@ LWS_CASE("db::storage::upsert_subaddresses") std::vector subs{}; subs.emplace_back( lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)}} + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)}}} ); auto result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 100); EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(101)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(101)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); { auto reader = MONERO_UNWRAP(db.start_read()); @@ -340,7 +340,7 @@ LWS_CASE("db::storage::upsert_subaddresses") } subs.back().second = - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(99)}}; + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(99)}}}; result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 198); EXPECT(result.has_error()); EXPECT(result == lws::error::max_subaddresses); @@ -349,9 +349,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(99)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(99)); auto reader = MONERO_UNWRAP(db.start_read()); check_address_map(lest_env, reader, user, subs); @@ -359,11 +359,11 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(fetched.has_value()); EXPECT(fetched->size() == 1); EXPECT(fetched->at(0).first == lws::db::major_index(0)); - EXPECT(fetched->at(0).second.size() == 2); - EXPECT(fetched->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(fetched->at(0).second[0][1] == lws::db::minor_index(99)); - EXPECT(fetched->at(0).second[1][0] == lws::db::minor_index(101)); - EXPECT(fetched->at(0).second[1][1] == lws::db::minor_index(200)); + EXPECT(fetched->at(0).second.get_container().size() == 2); + EXPECT(fetched->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(fetched->at(0).second.get_container()[0][1] == lws::db::minor_index(99)); + EXPECT(fetched->at(0).second.get_container()[1][0] == lws::db::minor_index(101)); + EXPECT(fetched->at(0).second.get_container()[1][1] == lws::db::minor_index(200)); } SECTION("Upsert Encapsulated") @@ -371,15 +371,15 @@ LWS_CASE("db::storage::upsert_subaddresses") std::vector subs{}; subs.emplace_back( lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(200)}} + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(200)}}} ); auto result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 200); EXPECT(result.has_value()); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index(0)); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(result->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(result->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); { auto reader = MONERO_UNWRAP(db.start_read()); @@ -387,7 +387,7 @@ LWS_CASE("db::storage::upsert_subaddresses") } subs.back().second = - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(5), lws::db::minor_index(99)}}; + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(5), lws::db::minor_index(99)}}}; result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 300); EXPECT(result.has_value()); EXPECT(result->size() == 0); @@ -398,9 +398,9 @@ LWS_CASE("db::storage::upsert_subaddresses") EXPECT(fetched.has_value()); EXPECT(fetched->size() == 1); EXPECT(fetched->at(0).first == lws::db::major_index(0)); - EXPECT(fetched->at(0).second.size() == 1); - EXPECT(fetched->at(0).second[0][0] == lws::db::minor_index(1)); - EXPECT(fetched->at(0).second[0][1] == lws::db::minor_index(200)); + EXPECT(fetched->at(0).second.get_container().size() == 1); + EXPECT(fetched->at(0).second.get_container()[0][0] == lws::db::minor_index(1)); + EXPECT(fetched->at(0).second.get_container()[0][1] == lws::db::minor_index(200)); } @@ -409,9 +409,9 @@ LWS_CASE("db::storage::upsert_subaddresses") std::vector subs{}; subs.emplace_back( lws::db::major_index(0), - lws::db::index_ranges{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}} + lws::db::index_ranges{{lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(100)}}} ); - subs.back().second.push_back( + subs.back().second.get_container().push_back( lws::db::index_range{lws::db::minor_index(101), lws::db::minor_index(200)} ); auto result = db.upsert_subaddresses(lws::db::account_id(1), user.account, user.view, subs, 100); diff --git a/tests/unit/scanner.test.cpp b/tests/unit/scanner.test.cpp index 18faa8d..52ac9c0 100644 --- a/tests/unit/scanner.test.cpp +++ b/tests/unit/scanner.test.cpp @@ -321,7 +321,7 @@ LWS_CASE("lws::scanner::sync and lws::scanner::run") { lws::scanner::reset(); auto rpc = - lws::rpc::context::make(rendevous, {}, {}, {}, std::chrono::minutes{0}); + lws::rpc::context::make(rendevous, {}, {}, {}, std::chrono::minutes{0}, false); lws::db::test::cleanup_db on_scope_exit{}; @@ -412,7 +412,7 @@ LWS_CASE("lws::scanner::sync and lws::scanner::run") lws::db::subaddress_dict{ lws::db::major_index::primary, lws::db::index_ranges{ - lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(2)} + {lws::db::index_range{lws::db::minor_index(1), lws::db::minor_index(2)}} } } }; @@ -421,12 +421,12 @@ LWS_CASE("lws::scanner::sync and lws::scanner::run") EXPECT(result); EXPECT(result->size() == 1); EXPECT(result->at(0).first == lws::db::major_index::primary); - EXPECT(result->at(0).second.size() == 1); - EXPECT(result->at(0).second.at(0).size() == 2); - EXPECT(result->at(0).second.at(0).at(0) == lws::db::minor_index(1)); - EXPECT(result->at(0).second.at(0).at(1) == lws::db::minor_index(2)); - } - + EXPECT(result->at(0).second.get_container().size() == 1); + EXPECT(result->at(0).second.get_container().at(0).size() == 2); + EXPECT(result->at(0).second.get_container().at(0).at(0) == lws::db::minor_index(1)); + EXPECT(result->at(0).second.get_container().at(0).at(1) == lws::db::minor_index(2)); + } + std::vector destinations; destinations.emplace_back(); destinations.back().amount = 8000; diff --git a/tests/unit/wire/json/read.write.test.cpp b/tests/unit/wire/json/read.write.test.cpp index f67f035..e0e451e 100644 --- a/tests/unit/wire/json/read.write.test.cpp +++ b/tests/unit/wire/json/read.write.test.cpp @@ -10,6 +10,8 @@ #include "wire/json/read.h" #include "wire/json/write.h" #include "wire/vector.h" +#include "wire/wrapper/array.h" +#include "wire/wrappers_impl.h" #include "wire/base.test.h" @@ -31,7 +33,8 @@ namespace template void basic_object_map(F& format, T& self) { - wire::object(format, WIRE_FIELD(utf8), WIRE_FIELD(vec), WIRE_FIELD(data), WIRE_FIELD(choice)); + using max_vec = wire::max_element_count<100>; + wire::object(format, WIRE_FIELD(utf8), WIRE_FIELD_ARRAY(vec, max_vec), WIRE_FIELD(data), WIRE_FIELD(choice)); } template diff --git a/tests/unit/wire/msgpack/read.write.test.cpp b/tests/unit/wire/msgpack/read.write.test.cpp index e91eebc..8ad2521 100644 --- a/tests/unit/wire/msgpack/read.write.test.cpp +++ b/tests/unit/wire/msgpack/read.write.test.cpp @@ -30,9 +30,12 @@ #include #include #include +#include "wire/field.h" #include "wire/traits.h" #include "wire/msgpack.h" #include "wire/vector.h" +#include "wire/wrapper/array.h" +#include "wire/wrappers_impl.h" #include "wire/base.test.h" @@ -65,9 +68,10 @@ namespace template void basic_object_map(F& format, T& self) { + using vec_max = wire::max_element_count<100>; wire::object(format, WIRE_FIELD_ID(0, utf8), - WIRE_FIELD_ID(1, vec), + wire::field<1>("vec", wire::array(std::ref(self.vec))), WIRE_FIELD_ID(2, data), WIRE_FIELD_ID(254, choice) ); diff --git a/tests/unit/wire/read.write.test.cpp b/tests/unit/wire/read.write.test.cpp index 1607413..94d3c4c 100644 --- a/tests/unit/wire/read.write.test.cpp +++ b/tests/unit/wire/read.write.test.cpp @@ -36,6 +36,8 @@ #include "wire/json.h" #include "wire/msgpack.h" #include "wire/vector.h" +#include "wire/wrapper/array.h" +#include "wire/wrappers_impl.h" #include "wire/base.test.h" @@ -70,12 +72,13 @@ namespace template void complex_map(F& format, T& self) { + using max_vec = wire::max_element_count<100>; wire::object(format, - WIRE_FIELD(objects), - WIRE_FIELD(ints), - WIRE_FIELD(uints), + WIRE_FIELD_ARRAY(objects, max_vec), + WIRE_FIELD_ARRAY(ints, max_vec), + WIRE_FIELD_ARRAY(uints, max_vec), WIRE_FIELD(blobs), - WIRE_FIELD(strings), + WIRE_FIELD_ARRAY(strings, max_vec), WIRE_FIELD(choice) ); }