AstroBWT speedup

This commit is contained in:
SChernykh 2021-10-18 18:05:51 +02:00
parent 7627b23212
commit 04f50c24e2

View file

@ -86,15 +86,54 @@ static void Salsa20_XORKeyStream_AVX256(const void* key, void* output, size_t si
} }
#endif #endif
// NOTE(review): the signature line below is a side-by-side diff rendering: the removed
// `void sort_indices(int N, ...)` text is fused with the added `smaller` signature.
// Only `static inline bool smaller(...)` belongs to this function.
//
// Strict "less than" for two packed 64-bit sort entries.
//
// Each entry is split at bit 21: the upper 43 bits (entry >> 21) are compared first,
// and only when they tie are the low 21 bits used as byte offsets into `v` to fetch
// more data for the comparison.
//
// v    - byte buffer the offsets refer to; 8 bytes are read at v + offset + 5, so the
//        buffer must be readable through offset + 12 for every offset passed in.
// a, b - packed entries to compare.
//
// Returns true when `a` orders before `b`; returns false when they tie, including
// when both carry the same offset.
void sort_indices(int N, const uint8_t* v, uint64_t* indices, uint64_t* tmp_indices) static inline bool smaller(const uint8_t* v, uint64_t a, uint64_t b)
{
// Fast path: the upper 43 bits decide the order whenever they differ.
const uint64_t value_a = a >> 21;
const uint64_t value_b = b >> 21;
if (value_a < value_b) {
return true;
}
if (value_a > value_b) {
return false;
}
// Upper bits tie: keep only the low 21 bits, which are used as offsets into v.
a &= (1 << 21) - 1;
b &= (1 << 21) - 1;
// Same offset means the same entry, so neither is smaller (and no memory is read).
if (a == b) {
return false;
}
// Tie-break on 8 bytes starting 5 bytes past each offset. 43 bits = 5 whole bytes + 3 bits,
// so this presumably resumes at the first byte not fully covered by the upper bits -- TODO confirm.
// bswap_64 is defined outside this view; presumably it byte-swaps so that the loaded
// bytes compare in memory order on a little-endian host -- verify.
const uint64_t data_a = bswap_64(*reinterpret_cast<const uint64_t*>(v + a + 5));
const uint64_t data_b = bswap_64(*reinterpret_cast<const uint64_t*>(v + b + 5));
return (data_a < data_b);
}
void sort_indices(uint32_t N, const uint8_t* v, uint64_t* indices, uint64_t* tmp_indices)
{ {
uint32_t counters[2][COUNTING_SORT_SIZE] = {}; uint32_t counters[2][COUNTING_SORT_SIZE] = {};
for (int i = 0; i < N; ++i)
{ {
const uint64_t k = bswap_64(*reinterpret_cast<const uint64_t*>(v + i)); #define ITER(X) \
++counters[0][(k >> (64 - COUNTING_SORT_BITS * 2)) & (COUNTING_SORT_SIZE - 1)]; do { \
++counters[1][k >> (64 - COUNTING_SORT_BITS)]; const uint64_t k = bswap_64(*reinterpret_cast<const uint64_t*>(v + i + X)); \
++counters[0][(k >> (64 - COUNTING_SORT_BITS * 2)) & (COUNTING_SORT_SIZE - 1)]; \
++counters[1][k >> (64 - COUNTING_SORT_BITS)]; \
} while (0)
uint32_t i = 0;
const uint32_t n = N - 15;
for (; i < n; i += 16) {
ITER(0); ITER(1); ITER(2); ITER(3); ITER(4); ITER(5); ITER(6); ITER(7);
ITER(8); ITER(9); ITER(10); ITER(11); ITER(12); ITER(13); ITER(14); ITER(15);
}
for (; i < N; ++i) {
ITER(0);
}
#undef ITER
} }
uint32_t prev[2] = { counters[0][0], counters[1][0] }; uint32_t prev[2] = { counters[0][0], counters[1][0] };
@ -109,41 +148,47 @@ void sort_indices(int N, const uint8_t* v, uint64_t* indices, uint64_t* tmp_indi
prev[1] = cur[1]; prev[1] = cur[1];
} }
for (int i = N - 1; i >= 0; --i)
{ {
const uint64_t k = bswap_64(*reinterpret_cast<const uint64_t*>(v + i)); #define ITER(X) \
tmp_indices[counters[0][(k >> (64 - COUNTING_SORT_BITS * 2)) & (COUNTING_SORT_SIZE - 1)]--] = (k & (static_cast<uint64_t>(-1) << 21)) | i; do { \
} const uint64_t k = bswap_64(*reinterpret_cast<const uint64_t*>(v + (i - X))); \
tmp_indices[counters[0][(k >> (64 - COUNTING_SORT_BITS * 2)) & (COUNTING_SORT_SIZE - 1)]--] = (k & (static_cast<uint64_t>(-1) << 21)) | (i - X); \
} while (0)
for (int i = N - 1; i >= 0; --i) uint32_t i = N;
{ for (; i >= 8; i -= 8) {
const uint64_t data = tmp_indices[i]; ITER(1); ITER(2); ITER(3); ITER(4); ITER(5); ITER(6); ITER(7); ITER(8);
indices[counters[1][data >> (64 - COUNTING_SORT_BITS)]--] = data; }
} for (; i > 0; --i) {
ITER(1);
auto smaller = [v](uint64_t a, uint64_t b)
{
const uint64_t value_a = a >> 21;
const uint64_t value_b = b >> 21;
if (value_a < value_b) {
return true;
} }
if (value_a > value_b) { #undef ITER
return false; }
{
#define ITER(X) \
do { \
const uint64_t data = tmp_indices[i - X]; \
indices[counters[1][data >> (64 - COUNTING_SORT_BITS)]--] = data; \
} while (0)
uint32_t i = N;
for (; i >= 8; i -= 8) {
ITER(1); ITER(2); ITER(3); ITER(4); ITER(5); ITER(6); ITER(7); ITER(8);
}
for (; i > 0; --i) {
ITER(1);
} }
const uint64_t data_a = bswap_64(*reinterpret_cast<const uint64_t*>(v + (a % (1 << 21)) + 5)); #undef ITER
const uint64_t data_b = bswap_64(*reinterpret_cast<const uint64_t*>(v + (b % (1 << 21)) + 5)); }
return (data_a < data_b);
};
uint64_t prev_t = indices[0]; uint64_t prev_t = indices[0];
for (int i = 1; i < N; ++i) for (uint32_t i = 1; i < N; ++i)
{ {
uint64_t t = indices[i]; uint64_t t = indices[i];
if (smaller(t, prev_t)) if (smaller(v, t, prev_t))
{ {
const uint64_t t2 = prev_t; const uint64_t t2 = prev_t;
int j = i - 1; int j = i - 1;
@ -157,7 +202,7 @@ void sort_indices(int N, const uint8_t* v, uint64_t* indices, uint64_t* tmp_indi
} }
prev_t = indices[j]; prev_t = indices[j];
} while (smaller(t, prev_t)); } while (smaller(v, t, prev_t));
indices[j + 1] = t; indices[j + 1] = t;
t = t2; t = t2;
} }