Counting byte in byte stream with AVX512BW instructions

Author: Wojciech Muła
Added on: 2021-02-14
Updated on: 2021-02-18 (added a new, faster implementation suggested by Travis Downs; fixed info about port usage, as noted by InstLatX64)


Introduction

This is a follow-up to the article "SIMDized counting byte in byte stream". Only AVX512BW variants are discussed here, and performance is analyzed only on a Skylake-X CPU.

We want to count how many times a given byte appears in a byte stream. The following C++ code shows the naive algorithm:

#include <cstddef>
#include <cstdint>

size_t countbyte(const uint8_t* data, size_t size, uint8_t b) {
    size_t result = 0;
    for (size_t i=0; i < size; i++)
        result += (data[i] == b);   // adds 1 for every matching byte

    return result;
}

AVX512BW — bitmask algorithm

AVX512 lets us express the problem in a really elegant way. Unlike in other common SIMD extensions, the result of a vector comparison in AVX512 is a bitmask stored in a mask register. In AVX512BW a mask register is a 64-bit word, holding one bit per byte of a 512-bit vector.

We only have to produce such a bitmask for the equality comparison and then count the bits set in it. The bit count operation is cheap: current CPUs can issue at least one popcnt per cycle.
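To illustrate what the comparison and the bit count compute for a single 64-byte block, here is a portable scalar model that needs no AVX512 hardware. The helper names eq_mask64 and count_in_block are hypothetical, and __builtin_popcountll assumes GCC or Clang. This is a sketch of the semantics, not the vectorized code.

```cpp
#include <cstdint>

// Scalar model of _mm512_cmpeq_epi8_mask: bit i is set when block[i] == b.
// block must point to at least 64 readable bytes.
uint64_t eq_mask64(const uint8_t* block, uint8_t b) {
    uint64_t mask = 0;
    for (int i = 0; i < 64; i++)
        if (block[i] == b)
            mask |= uint64_t{1} << i;
    return mask;
}

// Occurrences of b in one 64-byte block: the population count of the mask.
uint64_t count_in_block(const uint8_t* block, uint8_t b) {
    return static_cast<uint64_t>(__builtin_popcountll(eq_mask64(block, b)));
}
```

The real code replaces the 64-iteration loop in eq_mask64 with a single vpcmpeqb instruction.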

The following C++ procedure shows the actual implementation.

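#include <immintrin.h>
#include <cstddef>
#include <cstdint>

// scalar fallback for the remaining tail (at most 64 bytes), e.g. the naive countbyte() above
uint64_t scalar_count_bytes(const uint8_t* data, size_t size, uint8_t byte);

uint64_t avx512bw_count_bytes(const uint8_t* data, size_t size, uint8_t byte) {

    const uint8_t* end = data + size;
    const uint8_t* ptr = data;

    uint64_t sum = 0;

    // the searched byte broadcast to all 64 lanes
    const __m512i v = _mm512_set1_epi8(byte);

    while (ptr + 64 < end) {
        const __m512i in = _mm512_loadu_si512((const __m512i*)ptr);
        sum += __builtin_popcountll(_mm512_cmpeq_epi8_mask(in, v));

        ptr += 64;
    }

    return sum + scalar_count_bytes(ptr, end - ptr, byte);
}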
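The procedure above relies on scalar_count_bytes for the tail and must be compiled with AVX512BW enabled. Below is a minimal self-contained sketch, assuming GCC or Clang on x86: the vector loop is compiled with the target attribute and selected at run time with __builtin_cpu_supports, falling back to the plain loop everywhere else. The dispatcher name count_bytes and this dispatch scheme are assumptions for illustration, not part of the article's code. The loop condition is written as end - ptr >= 64, so a final full block is also vectorized and no pointer past end is formed.

```cpp
#include <cstddef>
#include <cstdint>

// Plain loop used for the tail and as a fallback on CPUs without AVX512BW.
uint64_t scalar_count_bytes(const uint8_t* data, size_t size, uint8_t byte) {
    uint64_t result = 0;
    for (size_t i = 0; i < size; i++)
        result += (data[i] == byte);
    return result;
}

#if (defined(__x86_64__) || defined(__i386__)) && defined(__GNUC__)
#include <immintrin.h>

// The target attribute lets GCC/Clang emit AVX512BW instructions in this
// function only, without building the whole file with -mavx512bw.
__attribute__((target("avx512bw")))
uint64_t avx512bw_count_bytes(const uint8_t* data, size_t size, uint8_t byte) {
    const uint8_t* end = data + size;
    const uint8_t* ptr = data;
    uint64_t sum = 0;
    const __m512i v = _mm512_set1_epi8(static_cast<char>(byte));

    while (end - ptr >= 64) {  // every full 64-byte block takes the vector path
        const __m512i in = _mm512_loadu_si512(ptr);
        sum += __builtin_popcountll(_mm512_cmpeq_epi8_mask(in, v));
        ptr += 64;
    }
    return sum + scalar_count_bytes(ptr, static_cast<size_t>(end - ptr), byte);
}
#endif

// Hypothetical dispatcher: picks the AVX512BW path only if the CPU supports it.
uint64_t count_bytes(const uint8_t* data, size_t size, uint8_t byte) {
#if (defined(__x86_64__) || defined(__i386__)) && defined(__GNUC__)
    if (__builtin_cpu_supports("avx512bw"))
        return avx512bw_count_bytes(data, size, byte);
#endif
    return scalar_count_bytes(data, size, byte);
}
```

Both paths return the same count, so callers can use count_bytes unconditionally.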

The key lines are:

const __m512i in = _mm512_loadu_si512((const __m512i*)ptr);
sum += __builtin_popcountll(_mm512_cmpeq_epi8_mask(in, v));

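Here _mm512_loadu_si512 loads 64 bytes (with no alignment requirement), _mm512_cmpeq_epi8_mask compares them with the broadcast byte and returns a 64-bit mask, and __builtin_popcountll counts the bits set in that mask.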

GCC 10.2 compiles the while loop into:

loop:
    vpcmpeqb -0x40(%rax),%zmm0,%k0      ; load & compare -> bitmask in k0
    add    $0x40,%rax                   ; increment the pointer `ptr`
    kmovq  %k0,%rdx                     ; transfer from mask to general-purpose register rdx
    popcnt %rdx,%rdx                    ; count bits in rdx
    add    %rdx,%rbx                    ; update `sum` variable
    cmp    %rsi,%rax
    jb     loop

The code is short, both in C++ and in assembly: the loop consists of only seven instructions. It is also fast, and unrolling makes it even faster; see the experiment results at the end of the article.

AVX512BW — bytemask algorithm

Even if we unroll the loop, we are hit by congestion on execution ports: both vpcmpeqb and kmovq can be executed only on port #5. Moreover, while the latency of both AVX512 instructions is only one CPU cycle, popcnt has a latency of three cycles.

This is not bad, but it made me wonder whether we can do better. The main idea is to avoid moving data between register sets: the original approach uses 1) vector, 2) mask and 3) general-purpose registers.

The goal is to keep the comparison result in a vector register and follow the algorithm used for the AVX2 and SSE instruction sets. That algorithm updates 8-bit counters 255 times and then adds them to a wider accumulator, so the byte counters never overflow. Below is the main loop of the AVX2 implementation:

while (ptr + 256*32 < end) {
    __m256i local_sum = _mm256_setzero_si256();

    // update 32 x 8-bit counters
    for (int i=0; i < 255; i++, ptr += 32) {
        const __m256i in = _mm256_loadu_si256((const __m256i*)ptr);
        const __m256i eq = _mm256_cmpeq_epi8(in, v); // 0 or -1

        local_sum = _mm256_sub_epi8(local_sum, eq); // subtracting -1 adds 1
    }

    // update the global accumulator 4 x 64-bit
    const __m256i tmp = _mm256_sad_epu8(local_sum, _mm256_setzero_si256());
    global_sum = _mm256_add_epi64(global_sum, tmp);
}
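The overflow argument can be checked with a scalar model of this loop, where each __m256i is replaced by an array of 32 uint8_t counters and _mm256_sad_epu8 against zero is modelled as summing each group of eight bytes into one of four 64-bit lanes. The names model_avx2_count and sad_against_zero are hypothetical; the sketch uses 255 * 32 as the block size and handles the tail with a plain loop.

```cpp
#include <cstddef>
#include <cstdint>

// Scalar model of _mm256_sad_epu8(x, zero): each of the four 64-bit lanes
// receives the sum of the corresponding eight unsigned bytes.
void sad_against_zero(const uint8_t bytes[32], uint64_t out[4]) {
    for (int lane = 0; lane < 4; lane++) {
        out[lane] = 0;
        for (int j = 0; j < 8; j++)
            out[lane] += bytes[lane * 8 + j];
    }
}

// Scalar model of the AVX2 loop: 32 byte counters, each updated at most
// 255 times, so a uint8_t counter cannot wrap around.
uint64_t model_avx2_count(const uint8_t* data, size_t size, uint8_t b) {
    uint64_t global_sum[4] = {0, 0, 0, 0};
    size_t pos = 0;
    while (size - pos >= 255 * 32) {
        uint8_t local_sum[32] = {0};
        for (int i = 0; i < 255; i++, pos += 32)
            for (int k = 0; k < 32; k++) {
                // cmpeq yields 0xff (-1) for equal bytes; subtracting it adds 1 mod 256
                const uint8_t eq = (data[pos + k] == b) ? 0xff : 0x00;
                local_sum[k] = static_cast<uint8_t>(local_sum[k] - eq);
            }
        uint64_t tmp[4];
        sad_against_zero(local_sum, tmp);
        for (int lane = 0; lane < 4; lane++)
            global_sum[lane] += tmp[lane];
    }
    uint64_t result = global_sum[0] + global_sum[1] + global_sum[2] + global_sum[3];
    for (; pos < size; pos++)   // scalar tail
        result += (data[pos] == b);
    return result;
}
```

With every byte matching, each counter reaches exactly 255, the largest value a byte can hold, which is why the inner loop stops after 255 iterations.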

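Porting that procedure to AVX512 is straightforward; the only missing piece is an equivalent of _mm256_cmpeq_epi8 that produces a vector, since AVX512 byte comparisons write their results only to mask registers. Fortunately, it's quite easy to build one.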

We use the fact that if two numbers x and y are equal, then x xor y = 0; if they are not equal, x xor y is some non-zero value. When bytes are interpreted as unsigned numbers, such non-zero values lie in the range [1..255].

Now the question is: can we reduce such an unspecified non-zero value to something nice, like the -1 from the AVX2 case? Yes, with the minimum operator. Our expression is min(x xor y, 1), which yields 0 when x = y and 1 otherwise.
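Because a byte has only 256 values, the identity can be verified exhaustively. The sketch below models one byte lane of xor followed by the unsigned minimum in scalar C++ and checks all 65536 pairs. The function names neq01 and neq01_holds_for_all_bytes are hypothetical.

```cpp
#include <algorithm>
#include <cstdint>

// Scalar model of one byte lane of _mm512_min_epu8(_mm512_xor_si512(x, y), v_01):
// 0 when x == y, 1 otherwise.
uint8_t neq01(uint8_t x, uint8_t y) {
    const uint8_t t = static_cast<uint8_t>(x ^ y);
    return std::min(t, uint8_t{1});
}

// Checks the identity for all 65536 pairs of bytes.
bool neq01_holds_for_all_bytes() {
    for (int x = 0; x < 256; x++)
        for (int y = 0; y < 256; y++)
            if (neq01(static_cast<uint8_t>(x), static_cast<uint8_t>(y)) != (x != y ? 1 : 0))
                return false;
    return true;
}
```

Note that the result is 1 for a mismatch rather than -1 for a match, so the counters built from it count mismatching bytes.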

AVX512BW defines vpminub (_mm512_min_epu8) that calculates the minimum value of unsigned bytes, thus the translation of the byte-wise algorithm is possible. Below is the C++ implementation:

// v = _mm512_set1_epi8(byte), v_01 = _mm512_set1_epi8(0x01)
while (ptr + 64*255 < end) {
    __m512i local_sum = _mm512_set1_epi8(-1); // 255

    // update 64 x 8-bit counters
    for (int i=0; i < 255; i++, ptr += 64) {
        const __m512i in = _mm512_loadu_si512((const __m512i*)ptr);
        const __m512i t0 = _mm512_xor_si512(in, v);
        const __m512i t1 = _mm512_min_epu8(t0, v_01); // 0 if equal, 1 otherwise

        local_sum = _mm512_sub_epi8(local_sum, t1);
    }

    // update the global accumulator 8 x 64-bit
    const __m512i tmp = _mm512_sad_epu8(local_sum, _mm512_setzero_si512());
    vector_acc = _mm512_add_epi64(vector_acc, tmp);
}

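It's almost identical to the AVX2 version, with just two exceptions: 1) the comparison is emulated with xor and unsigned minimum, which yields 0 for a matching byte and 1 otherwise; 2) since t1 marks mismatches, each counter starts at 255 and is decremented, so after 255 iterations it holds 255 minus the number of mismatches, that is, the number of matches.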

GCC 10.2 compiles the while loop into:

loop:
    vpxord (%rax),%zmm2,%zmm0           ; load & xor
    add    $0x40,%rax                   ; increment the pointer `ptr`
    mov    %rdx,%rdi
    vpminub %zmm0,%zmm3,%zmm0           ; calculate min(xor-result, 1)
    vpsubb %zmm0,%zmm1,%zmm0            ; update `local_sum`
    vmovdqa64 %zmm0,%zmm1
    cmp    %rdx,%rax
    jne    loop

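The vector part of the loop consists only of vpxord, vpminub, vpsubb and a register move. Only vpsubb (together with the move) depends on the previous iteration through local_sum, so consecutive iterations form a short dependency chain.

When the loop is unrolled, it is more likely that the vector instructions will run in parallel.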
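The "start at 255 and subtract mismatches" argument can be checked without AVX512 hardware using a scalar model of the bytemask loop, where each __m512i becomes an array of 64 byte counters. The name model_bytemask_count is hypothetical. Summing all 64 lanes at the end stands in for _mm512_sad_epu8 plus the final horizontal sum of the 64-bit accumulator, which the loop above does not show.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Scalar model of the AVX512 bytemask loop: 64 byte counters start at 255
// and are decremented by min(x xor b, 1), i.e. once per non-matching byte.
uint64_t model_bytemask_count(const uint8_t* data, size_t size, uint8_t b) {
    uint64_t acc = 0;
    size_t pos = 0;
    while (size - pos >= 64 * 255) {
        uint8_t local_sum[64];
        std::fill(local_sum, local_sum + 64, uint8_t{255});
        for (int i = 0; i < 255; i++, pos += 64)
            for (int k = 0; k < 64; k++) {
                const uint8_t x = static_cast<uint8_t>(data[pos + k] ^ b);
                local_sum[k] = static_cast<uint8_t>(local_sum[k] - std::min(x, uint8_t{1}));
            }
        // each lane saw 255 bytes, so 255 - mismatches == matches;
        // summing the lanes models SAD against zero plus the final reduction
        for (int k = 0; k < 64; k++)
            acc += local_sum[k];
    }
    for (; pos < size; pos++)   // scalar tail
        acc += (data[pos] == b);
    return acc;
}
```

A counter can never go below zero, because each lane is decremented at most 255 times starting from 255.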

AVX512BW — bytemask algorithm with mask registers

The third variant is a modification of the bytemask algorithm: the byte counters are updated by a masked operation. True, we get back to mask registers, but loop unrolling hides the longer latency of the masked operation.

This modification was proposed by Travis Downs; I merely coded it.
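The masked update relies on the merge semantics of _mm512_mask_add_epi8(src, k, a, b): lanes whose mask bit is set get a + b, the others keep src. A scalar sketch of that behaviour, with a hypothetical function name and vectors modelled as 64-byte arrays, looks like this:

```cpp
#include <cstdint>

// Scalar model of _mm512_mask_add_epi8(src, k, a, b): lane i gets a[i] + b[i]
// (wrapping modulo 256) when bit i of k is set, otherwise it keeps src[i].
void mask_add_epi8(uint8_t dst[64], const uint8_t src[64], uint64_t k,
                   const uint8_t a[64], const uint8_t b[64]) {
    for (int i = 0; i < 64; i++)
        dst[i] = ((k >> i) & 1) ? static_cast<uint8_t>(a[i] + b[i]) : src[i];
}
```

With src and a both set to the counter vector and b set to all ones, only the lanes selected by the comparison mask are incremented, which replaces the xor/min/sub sequence of the bytemask algorithm.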

Below is the unrolled while loop:

const __m512i v    = _mm512_set1_epi8(byte);
const __m512i v_01 = _mm512_set1_epi8(0x01);

__m512i global_sum = _mm512_setzero_si512();
while (ptr + 64 * (4*63) < end) {
    __m512i local_sum0 = _mm512_setzero_si512();
    __m512i local_sum1 = _mm512_setzero_si512();
    __m512i local_sum2 = _mm512_setzero_si512();
    __m512i local_sum3 = _mm512_setzero_si512();

    // update 4 x 64 x 8-bit counters: each grows by at most 63, so their sum (at most 252) fits in a byte
    for (int i=0; i < 63; i++, ptr += 4*64) {
        const __m512i   in0 = _mm512_loadu_si512((const __m512i*)(ptr + 0*64));
        const __m512i   in1 = _mm512_loadu_si512((const __m512i*)(ptr + 1*64));
        const __m512i   in2 = _mm512_loadu_si512((const __m512i*)(ptr + 2*64));
        const __m512i   in3 = _mm512_loadu_si512((const __m512i*)(ptr + 3*64));
        const __mmask64 eq0 = _mm512_cmpeq_epi8_mask(in0, v);
        const __mmask64 eq1 = _mm512_cmpeq_epi8_mask(in1, v);
        const __mmask64 eq2 = _mm512_cmpeq_epi8_mask(in2, v);
        const __mmask64 eq3 = _mm512_cmpeq_epi8_mask(in3, v);

        local_sum0 = _mm512_mask_add_epi8(local_sum0, eq0, local_sum0, v_01);
        local_sum1 = _mm512_mask_add_epi8(local_sum1, eq1, local_sum1, v_01);
        local_sum2 = _mm512_mask_add_epi8(local_sum2, eq2, local_sum2, v_01);
        local_sum3 = _mm512_mask_add_epi8(local_sum3, eq3, local_sum3, v_01);
    }

    local_sum0 = _mm512_add_epi8(local_sum0, local_sum1);
    local_sum2 = _mm512_add_epi8(local_sum2, local_sum3);
    local_sum0 = _mm512_add_epi8(local_sum0, local_sum2);

    const __m512i tmp = _mm512_sad_epu8(local_sum0, _mm512_setzero_si512());
    global_sum = _mm512_add_epi64(global_sum, tmp);
}
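The loop above ends with global_sum still holding eight partial 64-bit sums, and it leaves the bytes after the last full block unprocessed. One possible way to complete it is sketched below, assuming GCC or Clang on x86: the vector function is compiled with the target attribute, chosen at run time with __builtin_cpu_supports, and finished with the _mm512_reduce_add_epi64 intrinsic and a plain loop for the tail. The names tail_count, avx512bw_count_bytes_v6 and count_bytes_v6 are hypothetical, not the article's code.

```cpp
#include <cstddef>
#include <cstdint>

// Plain loop: handles the tail and serves as the fallback (hypothetical helper).
static uint64_t tail_count(const uint8_t* data, size_t size, uint8_t byte) {
    uint64_t n = 0;
    for (size_t i = 0; i < size; i++)
        n += (data[i] == byte);
    return n;
}

#if (defined(__x86_64__) || defined(__i386__)) && defined(__GNUC__)
#include <immintrin.h>

// Bytemask algorithm with mask registers, unrolled 4x, plus reduction and tail.
__attribute__((target("avx512bw")))
static uint64_t avx512bw_count_bytes_v6(const uint8_t* data, size_t size, uint8_t byte) {
    const uint8_t* ptr = data;
    const uint8_t* end = data + size;
    const __m512i v    = _mm512_set1_epi8(static_cast<char>(byte));
    const __m512i v_01 = _mm512_set1_epi8(0x01);
    const __m512i zero = _mm512_setzero_si512();
    __m512i global_sum = zero;

    while (end - ptr >= 64 * (4 * 63)) {
        __m512i s0 = zero, s1 = zero, s2 = zero, s3 = zero;
        for (int i = 0; i < 63; i++, ptr += 4 * 64) {
            const __mmask64 eq0 = _mm512_cmpeq_epi8_mask(_mm512_loadu_si512(ptr + 0 * 64), v);
            const __mmask64 eq1 = _mm512_cmpeq_epi8_mask(_mm512_loadu_si512(ptr + 1 * 64), v);
            const __mmask64 eq2 = _mm512_cmpeq_epi8_mask(_mm512_loadu_si512(ptr + 2 * 64), v);
            const __mmask64 eq3 = _mm512_cmpeq_epi8_mask(_mm512_loadu_si512(ptr + 3 * 64), v);
            s0 = _mm512_mask_add_epi8(s0, eq0, s0, v_01);
            s1 = _mm512_mask_add_epi8(s1, eq1, s1, v_01);
            s2 = _mm512_mask_add_epi8(s2, eq2, s2, v_01);
            s3 = _mm512_mask_add_epi8(s3, eq3, s3, v_01);
        }
        // each lane of s0..s3 is at most 63, so their byte-wise sum is at most 252
        const __m512i s = _mm512_add_epi8(_mm512_add_epi8(s0, s1), _mm512_add_epi8(s2, s3));
        global_sum = _mm512_add_epi64(global_sum, _mm512_sad_epu8(s, zero));
    }
    // horizontal sum of the eight 64-bit partial sums, then the scalar tail
    return static_cast<uint64_t>(_mm512_reduce_add_epi64(global_sum))
         + tail_count(ptr, static_cast<size_t>(end - ptr), byte);
}
#endif

// Hypothetical dispatcher: uses the AVX512BW path only when the CPU reports support.
uint64_t count_bytes_v6(const uint8_t* data, size_t size, uint8_t byte) {
#if (defined(__x86_64__) || defined(__i386__)) && defined(__GNUC__)
    if (__builtin_cpu_supports("avx512bw"))
        return avx512bw_count_bytes_v6(data, size, byte);
#endif
    return tail_count(data, size, byte);
}
```

Reducing with _mm512_reduce_add_epi64 once at the end keeps the horizontal sum out of the hot loop.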

GCC 10.2 compiles the while loop into:

loop:
        vpcmpeqb (%rdi),%zmm1,%k4           ; load & compare
        vmovdqa64 %zmm2,%zmm6
        vmovdqa64 %zmm3,%zmm9
        add    $0x100,%rdi                  ; increment the pointer `ptr`
        vpcmpeqb -0xc0(%rdi),%zmm1,%k3      ; load & compare
        vpcmpeqb -0x80(%rdi),%zmm1,%k2      ; load & compare
        vmovdqa64 %zmm4,%zmm7
        vpcmpeqb -0x40(%rdi),%zmm1,%k1      ; load & compare
        vmovdqa64 %zmm5,%zmm8
        vpaddb %zmm2,%zmm0,%zmm6{%k4}       ; update local_sum
        vpaddb %zmm3,%zmm0,%zmm9{%k3}       ; update local_sum
        vpaddb %zmm4,%zmm0,%zmm7{%k2}       ; update local_sum
        vmovdqa64 %zmm6,%zmm2
        vpaddb %zmm5,%zmm0,%zmm8{%k1}       ; update local_sum
        vmovdqa64 %zmm9,%zmm3
        vmovdqa64 %zmm7,%zmm4
        vmovdqa64 %zmm8,%zmm5
        cmp    %rax,%rdi
        jne    loop

Experiment results from Skylake-X

CPU: Intel(R) Xeon(R) W-2104 CPU @ 3.20GHz

Compiler: GCC version 8.4.0 (Ubuntu 8.4.0-1ubuntu1~16.04.1)

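Procedures: scalar is the naive loop, AVX2 comes from the previous article, AVX512BW is the bitmask algorithm, ver 5 is the bytemask algorithm and ver 6 is the bytemask algorithm with mask registers.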

                              time in cycles per byte
procedure                     average   best    speed-up (vs. best scalar time)
scalar                        0.458     0.435    1.0  ██▋
AVX2                          0.071     0.069    6.3  ████████████████▊
AVX512BW                      0.061     0.057    7.6  ████████████████████▎
AVX512BW (unrolled 4x)        0.044     0.044    9.9  ██████████████████████████▎
AVX512BW (ver 5)              0.040     0.040   10.9  █████████████████████████████
AVX512BW (ver 5 unrolled 2x)  0.033     0.032   13.6  ████████████████████████████████████▎
AVX512BW (ver 5 unrolled 4x)  0.035     0.035   12.4  █████████████████████████████████▏
AVX512BW (ver 6 unrolled 4x)  0.029     0.029   15.0  ████████████████████████████████████████

Conclusions

Source code

All implementations are available on GitHub.