How to detect if all bytes in SIMD register are the same?

Author: Wojciech Muła
Added on: 2021-02-02
Updated on: 2021-02-03 (added SSE variant #2)

Introduction

We want to detect whether all bytes stored in a SIMD register (SSE, AVX2, AVX512, Neon, etc.) are the same. For example, take this byte layout in an SSE register:

[42|42|42|42|42|42|42|42|42|42|42|42|42|42|42|42]

Here all bytes are equal to 42. In the following one, not all bytes have the same value:

[42|42|42|42|42|42|42|42|42|42|42|42|03|42|42|42]

An algorithm that uses only basic vector operations:

  1. broadcast the 0th byte of the register into a new vector:

    input       = [42|42|42|42|42|42|42|42|42|42|42|42|03|42|42|42]
    broadcasted = [42|42|42|42|42|42|42|42|42|42|42|42|42|42|42|42]
    
  2. perform a vector-wide compare for equality:

    cmp         = (input == broadcasted)
                = [ff|ff|ff|ff|ff|ff|ff|ff|ff|ff|ff|ff|00|ff|ff|ff]
    
  3. check whether all elements of cmp vector are "true".
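As a plain-C reference point, here is a scalar model of the three steps above. It is only an illustrative sketch: the helper name all_equal_model and the VEC_BYTES constant are made up here, and each loop stands in for what a single SIMD instruction does across all lanes at once.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#define VEC_BYTES 16  /* size of an SSE register in bytes */

/* Hypothetical scalar model of the vector algorithm; every loop plays the
   role of one SIMD instruction. */
bool all_equal_model(const uint8_t input[VEC_BYTES]) {
    uint8_t broadcasted[VEC_BYTES];
    uint8_t cmp[VEC_BYTES];

    /* 1. broadcast the 0th byte into a new vector */
    for (size_t i = 0; i < VEC_BYTES; i++)
        broadcasted[i] = input[0];

    /* 2. byte-wise compare: 0xff where equal, 0x00 otherwise */
    for (size_t i = 0; i < VEC_BYTES; i++)
        cmp[i] = (input[i] == broadcasted[i]) ? 0xff : 0x00;

    /* 3. every element of cmp must be "true" (0xff) */
    for (size_t i = 0; i < VEC_BYTES; i++)
        if (cmp[i] != 0xff)
            return false;
    return true;
}
```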

Depending on the SIMD flavour, these simple steps may not be that simple to implement.

SSE

Broadcasting a byte is done with _mm_shuffle_epi8, using an all-zero vector as the shuffle indices. Then a vector comparison yielding a byte mask (_mm_cmpeq_epi8) is used. Finally, the byte mask is converted into a bit mask (_mm_movemask_epi8) and tested with regular instructions.

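#include <immintrin.h>  // SSSE3: _mm_shuffle_epi8
#include <stdbool.h>

bool all_equal(__m128i input) {
    // pshufb with all-zero indices copies byte 0 into every byte
    const __m128i populated_0th_byte = _mm_shuffle_epi8(input, _mm_setzero_si128());
    const __m128i eq = _mm_cmpeq_epi8(input, populated_0th_byte);

    // one mask bit per byte; all 16 bits set means all bytes matched
    return (_mm_movemask_epi8(eq) == 0xffff);
}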

Assembly code generated by GCC 10.2:

pxor          %xmm2, %xmm2      ; _mm_setzero_si128
movdqa        %xmm0, %xmm1
pshufb        %xmm2, %xmm0      ; _mm_shuffle_epi8
pcmpeqb       %xmm1, %xmm0      ; _mm_cmpeq_epi8
pmovmskb      %xmm0, %eax       ; _mm_movemask_epi8
cmpl          $65535, %eax      ; ... == 0xffff
sete          %al
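To see why the result is compared with 0xffff, the sketch below models _mm_movemask_epi8 in portable C. The names movemask_model and all_equal_sse_model are hypothetical; the model relies on pcmpeqb producing 0xff or 0x00 per byte, so the top bit of each byte carries the comparison result.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical model of _mm_movemask_epi8: bit i of the result is the
   most significant bit of byte i. */
uint16_t movemask_model(const uint8_t bytes[16]) {
    uint16_t mask = 0;
    for (int i = 0; i < 16; i++)
        mask |= (uint16_t)((bytes[i] >> 7) & 1) << i;
    return mask;
}

/* Hypothetical model of the whole SSE routine: comparing with input[0]
   stands in for pshufb (zero indices) followed by pcmpeqb. */
bool all_equal_sse_model(const uint8_t input[16]) {
    uint8_t eq[16];
    for (int i = 0; i < 16; i++)
        eq[i] = (input[i] == input[0]) ? 0xff : 0x00;
    return movemask_model(eq) == 0xffff;
}
```

For an input with 03 stored in byte 12 and 42 everywhere else, the mask is 0xefff: only bit 12 is clear, so the check fails.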

SSE — variant #2

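Travis Downs & Robert Clausecker noted that instead of broadcasting a byte we may also rotate the bytes by one position (or by any odd number of positions) using _mm_alignr_epi8. If every byte equals the byte k positions further along (cyclically) and k is odd, then stepping by k visits all 16 positions, so all bytes must be equal. Unfortunately, this is only applicable to SSE, because the AVX2 counterpart works on 128-bit lanes rather than on the whole register.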
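A scalar sketch of the rotation trick follows; rotate_bytes and all_equal_rotate_model are hypothetical helpers that model _mm_alignr_epi8(v, v, k) as a cyclic byte rotation. The model also shows why the rotation amount should be odd: an even step splits the byte positions into separate cycles, so a pattern such as 1, 2, 1, 2, ... passes the test even though the bytes differ.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical model of _mm_alignr_epi8(v, v, k) for 0 < k < 16: with both
   operands equal, result byte i is v[(i + k) % 16], i.e. a byte rotation. */
void rotate_bytes(const uint8_t v[16], int k, uint8_t out[16]) {
    for (int i = 0; i < 16; i++)
        out[i] = v[(i + k) % 16];
}

/* Compares input with its rotation; this detects "all bytes equal" only
   when k is odd (coprime with 16). */
bool all_equal_rotate_model(const uint8_t input[16], int k) {
    uint8_t rotated[16];
    rotate_bytes(input, k, rotated);
    for (int i = 0; i < 16; i++)
        if (input[i] != rotated[i])
            return false;
    return true;
}
```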

Algorithm:

input       = [42|42|42|42|42|42|42|42|42|42|42|42|03|42|42|42]
rotated     = [42|42|42|42|42|42|42|42|42|42|42|03|42|42|42|42]
cmp         = (input == rotated)
            = [ff|ff|ff|ff|ff|ff|ff|ff|ff|ff|ff|00|00|ff|ff|ff]

A sample implementation:

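#include <immintrin.h>  // SSSE3: _mm_alignr_epi8
#include <stdbool.h>
#include <stdint.h>

bool all_equal(__m128i input) {
    // rotate bytes by one position: result byte i = input byte (i + 1) % 16
    const __m128i rotated = _mm_alignr_epi8(input, input, 1);
    const __m128i eq = _mm_cmpeq_epi8(input, rotated);

    return ((uint16_t)_mm_movemask_epi8(eq) == 0xffff);
}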

Assembly code generated by GCC 10.2:

movdqa        %xmm0, %xmm1
palignr       $1, %xmm0, %xmm1              ; _mm_alignr_epi8
pcmpeqb       %xmm1, %xmm0                  ; _mm_cmpeq_epi8
pmovmskb      %xmm0, %eax                   ; _mm_movemask_epi8
cmpw          $-1, %ax
sete          %al

AVX2

In the case of AVX2, the byte shuffle instruction works on individual 128-bit lanes. We use the fact that SSE registers (xmm) are aliased to the lower lanes of AVX/AVX2 registers (ymm). The intrinsic function _mm256_castsi256_si128 does not generate any instruction; it only satisfies the C/C++ type system.
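The following portable sketch models the lane-local behaviour of the 256-bit byte shuffle (_mm256_shuffle_epi8, i.e. vpshufb on ymm); shuffle_epi8_256_model is a hypothetical name. With all-zero indices it copies byte 0 into the lower lane but byte 16 into the upper lane, which is why a single 256-bit shuffle cannot broadcast the 0th byte across the whole register.

```c
#include <stdint.h>

/* Hypothetical model of _mm256_shuffle_epi8: each 128-bit lane is shuffled
   independently. An index byte with bit 7 set produces 0; otherwise its low
   4 bits select a byte within the same lane. */
void shuffle_epi8_256_model(const uint8_t a[32], const uint8_t idx[32],
                            uint8_t out[32]) {
    for (int lane = 0; lane < 2; lane++) {
        for (int i = 0; i < 16; i++) {
            uint8_t j = idx[lane * 16 + i];
            out[lane * 16 + i] = (j & 0x80) ? 0 : a[lane * 16 + (j & 0x0f)];
        }
    }
}
```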

The 0th byte is populated exactly as in the SSE variant, and then a new 256-bit vector is built from two copies of that 128-bit result.

Finally, the comparison procedure is very similar to the SSE variant.

#include <immintrin.h>  // AVX2
#include <stdbool.h>
#include <stdint.h>

bool all_equal(__m256i input) {
    const __m128i lane0 = _mm256_castsi256_si128(input);  // no instruction emitted
    const __m128i tmp   = _mm_shuffle_epi8(lane0, _mm_setzero_si128());
    const __m256i populated_0th_byte = _mm256_set_m128i(tmp, tmp);  // both lanes = tmp
    const __m256i eq = _mm256_cmpeq_epi8(input, populated_0th_byte);

    return ((uint32_t)_mm256_movemask_epi8(eq) == 0xffffffff);
}

Assembly code generated by GCC 10.2:

vmovdqa       %ymm0, %ymm2
vpxor         %xmm0, %xmm0, %xmm0           ; _mm_setzero_si128
vpshufb       %xmm0, %xmm2, %xmm0           ; _mm_shuffle_epi8
vinserti128   $1, %xmm0, %ymm0, %ymm0       ; _mm256_set_m128i (0th lane is already there)
vpcmpeqb      %ymm2, %ymm0, %ymm0           ; _mm256_cmpeq_epi8
vpmovmskb     %ymm0, %eax                   ; _mm256_movemask_epi8
cmpl          $-1, %eax
sete          %al

AVX512F

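First of all, AVX512F does not support byte-level operations on 512-bit registers; these come with AVX512BW. Luckily, SSE registers (xmm) are mapped onto the 0th lane of AVX512 registers (zmm). Thus, broadcasting the 0th byte is done similarly to the AVX2 variant: the byte is populated within the xmm register, and then that 128-bit lane is broadcast to all four lanes with _mm512_broadcast_i32x4.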

The lack of a byte-level comparison is not a problem in our case: we test whether two registers are equal as a whole, so it doesn't matter which element size the comparison uses.
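The claim that element size doesn't matter can be checked with a scalar model: comparing 64 bytes as sixteen 32-bit words yields an all-ones 16-bit mask exactly when every byte matches. The functions below are hypothetical models of _mm512_cmpeq_epi32_mask and of the whole AVX512F routine, not real intrinsics.

```c
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical model of _mm512_cmpeq_epi32_mask: compare two 64-byte blocks
   as sixteen 32-bit words, one mask bit per word. */
uint16_t cmpeq_epi32_mask_model(const uint8_t a[64], const uint8_t b[64]) {
    uint16_t mask = 0;
    for (int i = 0; i < 16; i++) {
        uint32_t wa, wb;
        memcpy(&wa, a + 4 * i, sizeof wa);  /* avoid aliasing/alignment issues */
        memcpy(&wb, b + 4 * i, sizeof wb);
        if (wa == wb)
            mask |= (uint16_t)(1u << i);
    }
    return mask;
}

/* Hypothetical model of the AVX512F routine: broadcast byte 0, compare as
   32-bit words; a full mask means every byte matched. */
bool all_equal_avx512f_model(const uint8_t input[64]) {
    uint8_t populated[64];
    memset(populated, input[0], sizeof populated);
    return cmpeq_epi32_mask_model(input, populated) == 0xffff;
}
```

Changing byte 37 clears only bit 9 of the mask, since that byte belongs to the 32-bit word covering bytes 36 to 39; a single differing byte is still enough to make the whole test fail.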

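#include <immintrin.h>  // AVX512F (plus SSSE3 for _mm_shuffle_epi8)
#include <stdbool.h>

bool all_equal(__m512i input) {
    const __m128i lane0  = _mm512_castsi512_si128(input);
    const __m128i t0     = _mm_shuffle_epi8(lane0, _mm_setzero_si128());
    const __m512i populated_0th_byte = _mm512_broadcast_i32x4(t0);  // t0 in all 4 lanes

    // 16 comparisons of 32-bit elements, one mask bit each
    const __mmask16 mask = _mm512_cmpeq_epi32_mask(input, populated_0th_byte);

    return (mask == 0xffff);
}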

Assembly code generated by GCC 10.2:

vpxor         %xmm2, %xmm2, %xmm2           ; _mm_setzero_si128
vmovdqa64     %zmm0, %zmm1
vpshufb       %xmm2, %xmm0, %xmm0           ; _mm_shuffle_epi8
vshufi32x4    $0x0, %zmm0, %zmm0, %zmm0     ; _mm512_broadcast_i32x4 (broadcast 0th lane)
vpcmpeqd      %zmm0, %zmm1, %k0             ; _mm512_cmpeq_epi32_mask
kmovw         %k0, %eax
cmpw          $-1, %ax
sete          %al

AVX512BW

AVX512BW lets us express the problem directly, as it was stated. There is a specialised instruction that broadcasts the 0th byte (_mm512_broadcastb_epi8/vpbroadcastb). The extension also supports byte-level comparison, although that is not crucial for our problem.

#include <immintrin.h>  // AVX512BW
#include <stdbool.h>

bool all_equal(__m512i input) {
    const __m128i lane0  = _mm512_castsi512_si128(input);
    const __m512i populated_0th_byte = _mm512_broadcastb_epi8(lane0);  // vpbroadcastb
    const __mmask64 mask = _mm512_cmpeq_epu8_mask(input, populated_0th_byte);  // one bit per byte

    return (mask == 0xffffffffffffffffLU);
}

Assembly code generated by GCC 10.2:

vpbroadcastb  %xmm0, %zmm1                  ; _mm512_broadcastb_epi8
vpcmpub       $0, %zmm1, %zmm0, %k0         ; _mm512_cmpeq_epu8_mask
kmovq         %k0, %rax
cmpq          $-1, %rax
sete          %al

Better AVX512

Both the AVX512BW and AVX512F variants can be a bit shorter in terms of assembly output. We want to test whether a bit-vector is all ones. There is a dedicated instruction, kortestw (intrinsic _mm512_kortestc), that works on the mask registers (k0, k1, ...) and does exactly the test we need: it ORs two masks and sets the carry flag if the result is all ones. Below is pseudocode showing what the intrinsic returns.

bool _mm512_kortestc(__mmask16 a, __mmask16 b) {
    return popcount(a | b) == 16;
}
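The flag behaviour of kortestw can be modelled in a few lines of C. This is a hypothetical sketch (kortestw_model and kortest_flags are made-up names): the instruction ORs its two mask operands, sets ZF when the result is zero and CF when it is all ones; _mm512_kortestc returns CF and _mm512_kortestz returns ZF. Passing the same mask twice makes the OR a no-op, so CF is set exactly when the mask itself is all ones.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical container for the two flags kortestw sets. */
typedef struct {
    bool zf;  /* (a | b) == 0x0000 */
    bool cf;  /* (a | b) == 0xffff */
} kortest_flags;

/* Hypothetical scalar model of kortestw on two 16-bit masks. */
kortest_flags kortestw_model(uint16_t a, uint16_t b) {
    uint16_t r = (uint16_t)(a | b);
    kortest_flags f = { r == 0x0000, r == 0xffff };
    return f;
}
```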

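Below is AVX512BW code that uses this instruction. Note that it compares 32-bit elements, so the result fits in the 16-bit mask that _mm512_kortestc accepts.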

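#include <immintrin.h>  // AVX512BW (vpbroadcastb on zmm)
#include <stdbool.h>

bool all_equal(__m512i input) {
    const __m128i lane0  = _mm512_castsi512_si128(input);
    const __m512i populated_0th_byte = _mm512_broadcastb_epi8(lane0);
    const __mmask16 mask = _mm512_cmpeq_epi32_mask(input, populated_0th_byte);

    // carry flag of kortestw: set when (mask | mask) is all ones
    return _mm512_kortestc(mask, mask);
}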

Assembly code generated by GCC 10.2:

vpbroadcastb        %xmm0, %zmm1        ; _mm512_broadcastb_epi8
vpcmpeqd            %zmm1, %zmm0, %k0   ; _mm512_cmpeq_epi32_mask
kortestw            %k0, %k0            ; _mm512_kortestc
setc                %al

Source code

All implementations are available on GitHub.