// Review notes (free-form preamble; everything from "diff --git" onward is the unchanged patch text).
//
// Purpose: unified diff against src/crypto/astrobwt/AstroBWT.cpp with three hunks touching sort_indices().
//   * Hunk -86,15 +86,54: adds the file-static helper smaller(v, a, b), changes the first parameter of
//     sort_indices from `int N` to `uint32_t N`, and replaces the one-element-per-iteration counting
//     loop with an ITER(X) macro invoked 16 times per iteration, followed by a one-at-a-time tail loop.
//   * Hunk -109,41 +148,47: replaces the two descending loops that fill tmp_indices and indices with
//     ITER(X)-based loops that step `i` down from N by 8 and then by 1, and removes the `smaller`
//     lambda in favour of calls to smaller(v, ...).
//   * Hunk -157,7 +202,7: updates the do/while condition to the new smaller(v, t, prev_t) call.
//
// smaller(v, a, b): compares `a >> 21` with `b >> 21` first; on a tie it masks both values to their
// low 21 bits, returns false when those are equal, and otherwise compares the byte-swapped 64-bit
// words read at v + a + 5 and v + b + 5. The removed lambda used `% (1 << 21)` for the offsets and had
// no early return for equal offsets.
//
// NOTE(review): `const uint32_t n = N - 15;` is unsigned arithmetic, so for N < 15 it wraps to a very
// large value and the 16-way loop would run far beyond N -- confirm callers never pass N < 15.
// NOTE(review): every ITER reads a 64-bit word starting at v + i for i up to N - 1, and smaller() reads
// one at v + offset + 5; presumably the buffer behind `v` has readable padding past N -- confirm.
// NOTE(review): the context line `int j = i - 1;` stays signed while `i` becomes uint32_t in this
// patch -- check the surrounding loop for signed/unsigned mixing.
// NOTE(review): the casts appear here as `reinterpret_cast(...)` and `static_cast(-1)` with no template
// arguments; presumably the `<...>` parts were stripped when this text was extracted (line breaks are
// collapsed as well) -- regenerate the patch from the repository before trying to apply it.
// NOTE(review): in the first and second hunks no added `{` is visible for the scope that the `}` after
// the first `#undef ITER` closes; possibly lost in the same extraction -- verify against the repository.
diff --git a/src/crypto/astrobwt/AstroBWT.cpp b/src/crypto/astrobwt/AstroBWT.cpp index 5dee11657..4b3331f4b 100644 --- a/src/crypto/astrobwt/AstroBWT.cpp +++ b/src/crypto/astrobwt/AstroBWT.cpp @@ -86,15 +86,54 @@ static void Salsa20_XORKeyStream_AVX256(const void* key, void* output, size_t si } #endif -void sort_indices(int N, const uint8_t* v, uint64_t* indices, uint64_t* tmp_indices) +static inline bool smaller(const uint8_t* v, uint64_t a, uint64_t b) +{ + const uint64_t value_a = a >> 21; + const uint64_t value_b = b >> 21; + + if (value_a < value_b) { + return true; + } + + if (value_a > value_b) { + return false; + } + + a &= (1 << 21) - 1; + b &= (1 << 21) - 1; + + if (a == b) { + return false; + } + + const uint64_t data_a = bswap_64(*reinterpret_cast(v + a + 5)); + const uint64_t data_b = bswap_64(*reinterpret_cast(v + b + 5)); + return (data_a < data_b); +} + +void sort_indices(uint32_t N, const uint8_t* v, uint64_t* indices, uint64_t* tmp_indices) { uint32_t counters[2][COUNTING_SORT_SIZE] = {}; - for (int i = 0; i < N; ++i) { - const uint64_t k = bswap_64(*reinterpret_cast(v + i)); - ++counters[0][(k >> (64 - COUNTING_SORT_BITS * 2)) & (COUNTING_SORT_SIZE - 1)]; - ++counters[1][k >> (64 - COUNTING_SORT_BITS)]; +#define ITER(X) \ + do { \ + const uint64_t k = bswap_64(*reinterpret_cast(v + i + X)); \ + ++counters[0][(k >> (64 - COUNTING_SORT_BITS * 2)) & (COUNTING_SORT_SIZE - 1)]; \ + ++counters[1][k >> (64 - COUNTING_SORT_BITS)]; \ + } while (0) + + uint32_t i = 0; + const uint32_t n = N - 15; + for (; i < n; i += 16) { + ITER(0); ITER(1); ITER(2); ITER(3); ITER(4); ITER(5); ITER(6); ITER(7); + ITER(8); ITER(9); ITER(10); ITER(11); ITER(12); ITER(13); ITER(14); ITER(15); + } + for (; i < N; ++i) { + ITER(0); + } + +#undef ITER } uint32_t prev[2] = { counters[0][0], counters[1][0] }; @@ -109,41 +148,47 @@ void sort_indices(int N, const uint8_t* v, uint64_t* indices, uint64_t* tmp_indi prev[1] = cur[1]; } - for (int i = N - 1; i >= 0; --i) { - const 
uint64_t k = bswap_64(*reinterpret_cast(v + i)); - tmp_indices[counters[0][(k >> (64 - COUNTING_SORT_BITS * 2)) & (COUNTING_SORT_SIZE - 1)]--] = (k & (static_cast(-1) << 21)) | i; - } +#define ITER(X) \ + do { \ + const uint64_t k = bswap_64(*reinterpret_cast(v + (i - X))); \ + tmp_indices[counters[0][(k >> (64 - COUNTING_SORT_BITS * 2)) & (COUNTING_SORT_SIZE - 1)]--] = (k & (static_cast(-1) << 21)) | (i - X); \ + } while (0) - for (int i = N - 1; i >= 0; --i) - { - const uint64_t data = tmp_indices[i]; - indices[counters[1][data >> (64 - COUNTING_SORT_BITS)]--] = data; - } - - auto smaller = [v](uint64_t a, uint64_t b) - { - const uint64_t value_a = a >> 21; - const uint64_t value_b = b >> 21; - - if (value_a < value_b) { - return true; + uint32_t i = N; + for (; i >= 8; i -= 8) { + ITER(1); ITER(2); ITER(3); ITER(4); ITER(5); ITER(6); ITER(7); ITER(8); + } + for (; i > 0; --i) { + ITER(1); } - if (value_a > value_b) { - return false; +#undef ITER + } + + { +#define ITER(X) \ + do { \ + const uint64_t data = tmp_indices[i - X]; \ + indices[counters[1][data >> (64 - COUNTING_SORT_BITS)]--] = data; \ + } while (0) + + uint32_t i = N; + for (; i >= 8; i -= 8) { + ITER(1); ITER(2); ITER(3); ITER(4); ITER(5); ITER(6); ITER(7); ITER(8); + } + for (; i > 0; --i) { + ITER(1); } - const uint64_t data_a = bswap_64(*reinterpret_cast(v + (a % (1 << 21)) + 5)); - const uint64_t data_b = bswap_64(*reinterpret_cast(v + (b % (1 << 21)) + 5)); - return (data_a < data_b); - }; +#undef ITER + } uint64_t prev_t = indices[0]; - for (int i = 1; i < N; ++i) + for (uint32_t i = 1; i < N; ++i) { uint64_t t = indices[i]; - if (smaller(t, prev_t)) + if (smaller(v, t, prev_t)) { const uint64_t t2 = prev_t; int j = i - 1; @@ -157,7 +202,7 @@ void sort_indices(int N, const uint8_t* v, uint64_t* indices, uint64_t* tmp_indi } prev_t = indices[j]; - } while (smaller(t, prev_t)); + } while (smaller(v, t, prev_t)); indices[j + 1] = t; t = t2; }