Author: | Wojciech Muła |
---|---|
Added on: | 2021-02-14 |
Updated on: | 2021-02-18 (add new, faster implementation suggested by Travis Downs; fixed info about port usage, as noted by InstLatX64) |
This is a follow-up to the article SIMDized counting byte in byte stream. Only AVX512BW variants are discussed here, and performance is analyzed only for the Skylake-X CPU.
We want to count how many times a given byte appears in a byte stream. The following C++ code shows the naive algorithm:
    size_t countbyte(uint8_t* data, size_t size, uint8_t b) {
        size_t result = 0;
        for (size_t i=0; i < size; i++)
            result += (data[i] == b);

        return result;
    }
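For example, counting the occurrences of a single character in a small buffer (a minimal usage sketch with made-up input data):

    #include <cstdint>
    #include <cstdio>

    int main() {
        uint8_t data[] = {'a', 'b', 'a', 'c', 'a'};

        // prints 3 -- the byte 'a' occurs three times
        printf("%zu\n", countbyte(data, sizeof(data), 'a'));
        return 0;
    }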
AVX512 allows us to express the problem in a really elegant way. Unlike other common SIMD extensions, the result of a vector comparison in AVX512 is a bitmask stored in a mask register; in AVX512BW a mask register is a 64-bit word, one bit per input byte.
We only have to produce such a bitmask with an equal-to comparison, and then count the set bits in the bitmask. The bit count operation is cheap: a single popcnt instruction, with a throughput of one per cycle on current CPUs.
The following C++ procedure shows the actual implementation.
    uint64_t avx512bw_count_bytes(const uint8_t* data, size_t size, uint8_t byte) {
        const uint8_t* end = data + size;
        const uint8_t* ptr = data;
        uint64_t sum = 0;

        const __m512i v = _mm512_set1_epi8(byte);

        while (ptr + 64 < end) {
            const __m512i in = _mm512_loadu_si512((const __m512i*)ptr);

            sum += __builtin_popcountll(_mm512_cmpeq_epi8_mask(in, v));

            ptr += 64;
        }

        return sum + scalar_count_bytes(ptr, end - ptr, byte);
    }
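The tail of fewer than 64 bytes is handled by a scalar routine. The helper scalar_count_bytes is not shown in the listing; it can be assumed to be the naive loop from the beginning of the article, for instance:

    // assumed helper: the naive scalar loop, used only for the tail
    uint64_t scalar_count_bytes(const uint8_t* data, size_t size, uint8_t byte) {
        uint64_t result = 0;
        for (size_t i=0; i < size; i++)
            result += (data[i] == byte);

        return result;
    }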
The key lines are:
    const __m512i in = _mm512_loadu_si512((const __m512i*)ptr);

    sum += __builtin_popcountll(_mm512_cmpeq_epi8_mask(in, v));
Here _mm512_cmpeq_epi8_mask compares all 64 bytes at once and returns the result as a 64-bit mask, and __builtin_popcountll is a GCC/Clang built-in that counts the set bits; it compiles down to a single popcnt instruction.
GCC 10.2 compiles the while loop into:
    loop:
        vpcmpeqb -0x40(%rax),%zmm0,%k0  ; load & compare -> bitmask in k0
        add      $0x40,%rax             ; increment the pointer `ptr`
        kmovq    %k0,%rdx               ; transfer from mask to general purpose register rdx
        popcnt   %rdx,%rdx              ; count bits in rdx
        add      %rdx,%rbx              ; update `sum` variable
        cmp      %rsi,%rax
        jb       loop
The code is short, both in C++ and in assembly: the loop consists of only seven instructions. It is also fast, and unrolling makes it even faster (see the results below).
Even if we unroll the loop, we will be hit by congestion on the execution ports: both vpcmpeqb and kmovq can be executed only on port #5. And while the latency of both AVX512 instructions is only one CPU cycle, popcnt has a latency of three cycles.
Obviously this is not bad, but it made me wonder whether we can do better. The main idea is to avoid moving data between register sets; the original approach uses 1) vector, 2) mask and 3) general purpose registers.
The goal is to keep the comparison result in a vector register and follow the algorithm used for the AVX2 and SSE instruction sets. That algorithm updates 8-bit counters at most 255 times and then adds them into a wider accumulator, so the byte counters never overflow. Below is the main loop of the AVX2 implementation:
    while (ptr + 256*32 < end) {
        local_sum = _mm256_setzero_si256();

        // update 32 x 8-bit counter
        for (int i=0; i < 255; i++, ptr += 32) {
            const __m256i in = _mm256_loadu_si256((const __m256i*)ptr);
            const __m256i eq = _mm256_cmpeq_epi8(in, v); // 0 or -1

            local_sum = _mm256_sub_epi8(local_sum, eq);
        }

        // update the global accumulator 4 x 64-bit
        const __m256i tmp = _mm256_sad_epu8(local_sum, _mm256_setzero_si256());
        global_sum = _mm256_add_epi64(global_sum, tmp);
    }
Porting that procedure to AVX512 is straightforward; the only missing piece is an equivalent of _mm256_cmpeq_epi8 that produces its result in a vector register rather than a mask register. Fortunately, it's quite easy to build one.
We use the fact that if two numbers x and y are equal, then x xor y = 0. If the operands are not equal, then x xor y has some non-zero value. Note that if the bytes are interpreted as unsigned numbers, the non-zero values fall in the range [1..255].
Now the question is: can we reduce such an unspecified non-zero value to something as convenient as the -1 produced in the AVX2 case? Yes, with the minimum operator: the expression min(x xor y, 1) yields 0 for equal bytes and 1 otherwise.
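A scalar model of the trick, just to illustrate what the vector code evaluates per byte (plain C++, not part of the vectorized implementation):

    #include <algorithm>
    #include <cstdint>

    // returns 0 when x == y and 1 otherwise
    uint8_t not_equal(uint8_t x, uint8_t y) {
        const uint8_t t = x ^ y;        // 0 when equal, 1..255 otherwise
        return std::min<uint8_t>(t, 1); // clamp any non-zero value down to 1
    }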
AVX512BW provides vpminub (_mm512_min_epu8), which computes the byte-wise minimum of unsigned values, so the translation of the byte-wise algorithm is possible. Below is the C++ implementation:
    while (ptr + 64*255 < end) {
        __m512i local_sum = _mm512_set1_epi8(-1); // 255

        // update 64 x 8-bit counter
        for (int i=0; i < 255; i++, ptr += 64) {
            const __m512i in = _mm512_loadu_si512((const __m512i*)ptr);
            const __m512i t0 = _mm512_xor_si512(in, v);
            const __m512i t1 = _mm512_min_epu8(t0, v_01); // 0 if equal, 1 otherwise

            local_sum = _mm512_sub_epi8(local_sum, t1);
        }

        // update the global accumulator 8 x 64-bit
        const __m512i tmp = _mm512_sad_epu8(local_sum, _mm512_setzero_si512());
        vector_acc = _mm512_add_epi64(vector_acc, tmp);
    }
It's almost identical to the AVX2 version, with just two exceptions: the equality test is expressed as xor followed by an unsigned minimum with 1, and local_sum starts at 255 and is decremented once for every non-matching byte, so after 255 iterations each 8-bit counter holds exactly the number of matches.
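The snippet above shows only the main loop. A complete procedure still has to fold the eight 64-bit lanes of vector_acc into a single number and count the leftover tail with scalar code. A minimal sketch of that epilogue, assuming the _mm512_reduce_add_epi64 helper intrinsic (provided by recent GCC and Clang) and a scalar tail routine like the one sketched earlier:

    // fold 8 x 64-bit partial sums into a single scalar
    uint64_t sum = _mm512_reduce_add_epi64(vector_acc);

    // count the bytes that did not fill a whole 64*255-byte block
    sum += scalar_count_bytes(ptr, end - ptr, byte);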
GCC 10.2 compiles the while loop into:
    loop:
        vpxord    (%rax),%zmm2,%zmm0 ; load & xor
        add       $0x40,%rax         ; increment the pointer `ptr`
        mov       %rdx,%rdi
        vpminub   %zmm0,%zmm3,%zmm0  ; calculate min(xor-result, 1)
        vpsubb    %zmm0,%zmm1,%zmm0  ; update `local_sum`
        vmovdqa64 %zmm0,%zmm1
        cmp       %rdx,%rax
        jne       loop
All vector instructions that appear in the listing (vpxord, vpminub, vpsubb) have single-cycle latency and are not restricted to a single execution port. When the loop is unrolled, it is therefore more likely that the vector instructions will be executed in parallel.
The third variant is a modification of the bytemask algorithm: the byte counters are updated by a masked add. True, we get back to mask registers; however, loop unrolling hides the longer latency of the masked operation.
This modification was proposed by Travis Downs; I merely coded it.
Below is the unrolled while loop:
    const __m512i v    = _mm512_set1_epi8(byte);
    const __m512i v_01 = _mm512_set1_epi8(0x01);
    __m512i global_sum = _mm512_setzero_si512();

    while (ptr + 64 * (4*63) < end) {
        __m512i local_sum0 = _mm512_setzero_si512();
        __m512i local_sum1 = _mm512_setzero_si512();
        __m512i local_sum2 = _mm512_setzero_si512();
        __m512i local_sum3 = _mm512_setzero_si512();

        // update 64 x 8-bit counter; each counter is incremented at most 63 times,
        // so even the sum of the four local counters (<= 252) fits in a byte
        for (int i=0; i < 63; i++, ptr += 4*64) {
            const __m512i in0 = _mm512_loadu_si512((const __m512i*)(ptr + 0*64));
            const __m512i in1 = _mm512_loadu_si512((const __m512i*)(ptr + 1*64));
            const __m512i in2 = _mm512_loadu_si512((const __m512i*)(ptr + 2*64));
            const __m512i in3 = _mm512_loadu_si512((const __m512i*)(ptr + 3*64));

            const __mmask64 eq0 = _mm512_cmpeq_epi8_mask(in0, v);
            const __mmask64 eq1 = _mm512_cmpeq_epi8_mask(in1, v);
            const __mmask64 eq2 = _mm512_cmpeq_epi8_mask(in2, v);
            const __mmask64 eq3 = _mm512_cmpeq_epi8_mask(in3, v);

            local_sum0 = _mm512_mask_add_epi8(local_sum0, eq0, local_sum0, v_01);
            local_sum1 = _mm512_mask_add_epi8(local_sum1, eq1, local_sum1, v_01);
            local_sum2 = _mm512_mask_add_epi8(local_sum2, eq2, local_sum2, v_01);
            local_sum3 = _mm512_mask_add_epi8(local_sum3, eq3, local_sum3, v_01);
        }

        // merge the four local counters
        local_sum0 = _mm512_add_epi8(local_sum0, local_sum1);
        local_sum2 = _mm512_add_epi8(local_sum2, local_sum3);
        local_sum0 = _mm512_add_epi8(local_sum0, local_sum2);

        // update the global accumulator 8 x 64-bit
        const __m512i tmp = _mm512_sad_epu8(local_sum0, _mm512_setzero_si512());
        global_sum = _mm512_add_epi64(global_sum, tmp);
    }
GCC 10.2 compiles the while loop into:
    loop:
        vpcmpeqb  (%rdi),%zmm1,%k4       ; load & compare
        vmovdqa64 %zmm2,%zmm6
        vmovdqa64 %zmm3,%zmm9
        add       $0x100,%rdi            ; increment the pointer `ptr`
        vpcmpeqb  -0xc0(%rdi),%zmm1,%k3  ; load & compare
        vpcmpeqb  -0x80(%rdi),%zmm1,%k2  ; load & compare
        vmovdqa64 %zmm4,%zmm7
        vpcmpeqb  -0x40(%rdi),%zmm1,%k1  ; load & compare
        vmovdqa64 %zmm5,%zmm8
        vpaddb    %zmm2,%zmm0,%zmm6{%k4} ; update local_sum
        vpaddb    %zmm3,%zmm0,%zmm9{%k3} ; update local_sum
        vpaddb    %zmm4,%zmm0,%zmm7{%k2} ; update local_sum
        vmovdqa64 %zmm6,%zmm2
        vpaddb    %zmm5,%zmm0,%zmm8{%k1} ; update local_sum
        vmovdqa64 %zmm9,%zmm3
        vmovdqa64 %zmm7,%zmm4
        vmovdqa64 %zmm8,%zmm5
        cmp       %rax,%rdi
        jne       loop
CPU: Intel(R) Xeon(R) W-2104 CPU @ 3.20GHz
Compiler: GCC version 8.4.0 (Ubuntu 8.4.0-1ubuntu1~16.04.1)
Procedures:
procedure | time in cycles per byte (average) | time in cycles per byte (best) | speed-up | |
---|---|---|---|---|
scalar | 0.458 | 0.435 | 1.0 | ██▋ |
AVX2 | 0.071 | 0.069 | 6.3 | ████████████████▊ |
AVX512BW | 0.061 | 0.057 | 7.6 | ████████████████████▎ |
AVX512BW (unrolled 4x) | 0.044 | 0.044 | 9.9 | ██████████████████████████▎ |
AVX512BW (ver 5) | 0.040 | 0.040 | 10.9 | █████████████████████████████ |
AVX512BW (ver 5 unrolled 2x) | 0.033 | 0.032 | 13.6 | ████████████████████████████████████▎ |
AVX512BW (ver 5 unrolled 4x) | 0.035 | 0.035 | 12.4 | █████████████████████████████████▏ |
AVX512BW (ver 6 unrolled 4x) | 0.029 | 0.029 | 15.0 | ████████████████████████████████████████ |
All implementations are available on GitHub.