diff --git a/src/crypto.cpp b/src/crypto.cpp index 9cf064c..def5e30 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -159,14 +159,14 @@ class Cache public: Cache() { - uv_mutex_init_checked(&derivations_lock); + uv_rwlock_init_checked(&derivations_lock); uv_rwlock_init_checked(&public_keys_lock); uv_rwlock_init_checked(&tx_keys_lock); } ~Cache() { - uv_mutex_destroy(&derivations_lock); + uv_rwlock_destroy(&derivations_lock); uv_rwlock_destroy(&public_keys_lock); uv_rwlock_destroy(&tx_keys_lock); } @@ -177,33 +177,45 @@ public: memcpy(index.data(), key1.h, HASH_SIZE); memcpy(index.data() + HASH_SIZE, key2.h, HASH_SIZE); + derivation = {}; { - MutexLock lock(derivations_lock); + ReadLock lock(derivations_lock); auto it = derivations.find(index); if (it != derivations.end()) { - derivation = it->second.m_derivation; - view_tag = it->second.get_view_tag(output_index); - return true; + const DerivationEntry& entry = it->second; + derivation = entry.m_derivation; + if (entry.find_view_tag(output_index, view_tag)) { + return true; + } } } - ge_p3 point; - ge_p2 point2; - ge_p1p1 point3; + if (derivation.empty()) { + ge_p3 point; + ge_p2 point2; + ge_p1p1 point3; - if (ge_frombytes_vartime(&point, key1.h) != 0) { - return false; + if (ge_frombytes_vartime(&point, key1.h) != 0) { + return false; + } + + ge_scalarmult(&point2, key2.h, &point); + ge_mul8(&point3, &point2); + ge_p1p1_to_p2(&point2, &point3); + ge_tobytes(reinterpret_cast(&derivation), &point2); } - ge_scalarmult(&point2, key2.h, &point); - ge_mul8(&point3, &point2); - ge_p1p1_to_p2(&point2, &point3); - ge_tobytes(reinterpret_cast(&derivation), &point2); + derive_view_tag(derivation, output_index, view_tag); { - MutexLock lock(derivations_lock); - auto result = derivations.emplace(index, DerivationEntry{ derivation, {} }); - view_tag = result.first->second.get_view_tag(output_index); + WriteLock lock(derivations_lock); + + DerivationEntry& entry = derivations.emplace(index, DerivationEntry{ derivation, {} }).first->second; + + const uint32_t k = static_cast(output_index << 8) | view_tag; + if (std::find(entry.m_viewTags.begin(), entry.m_viewTags.end(), k) == entry.m_viewTags.end()) { + entry.m_viewTags.emplace_back(k); + } } return true; @@ -285,7 +297,7 @@ public: void clear() { - { MutexLock lock(derivations_lock); derivations.clear(); } + { WriteLock lock(derivations_lock); derivations.clear(); } { WriteLock lock(public_keys_lock); public_keys.clear(); } { WriteLock lock(tx_keys_lock); tx_keys.clear(); } } @@ -296,22 +308,18 @@ private: hash m_derivation; std::vector m_viewTags; - uint8_t get_view_tag(size_t output_index) { + bool find_view_tag(size_t output_index, uint8_t& view_tag) const { for (uint32_t k : m_viewTags) { if ((k >> 8) == output_index) { - return static_cast(k); + view_tag = static_cast(k); + return true; } } - - uint8_t t; - derive_view_tag(m_derivation, output_index, t); - m_viewTags.emplace_back(static_cast(output_index << 8) | t); - - return t; + return false; } }; - uv_mutex_t derivations_lock; + uv_rwlock_t derivations_lock; unordered_map, DerivationEntry> derivations; uv_rwlock_t public_keys_lock;