SIMDization of switch statements

Author: Wojciech Muła
Added on: 2019-02-03

Contents

Introduction

There are two main purposes of a switch statement:

  1. Express a simple function that translates from one set of values into another, like getting the string representation of an enum value.
  2. Dispatch different code sequences based on the switch argument, as an alternative to an "if-ladder".

Compilers usually transform switch statements using the following approaches:

  1. Binary search on constant keys: a compiler emits a series of comparisons and jumps interleaved with case code.
  2. When the key values span a small range (even a non-contiguous one), the values are used to index a lookup table of jump targets.

Of course, a compiler might also optimize some specific cases; for some neat examples, look at tree-switch-conversion.c in GCC.

However, switch statements can also be expressed with SIMD instructions. Vector instructions are used to translate the argument value into the case's ordinal number. Then, the index is used either to 1) fetch a value from a precalculated table, or 2) get a jump target (an address) which is used to dispatch code fragments.

Example of binary search code

The following C++ code is compiled by GCC 7.3.0 into a binary search procedure.

// Colour keys packed as 0x00RRGGBB (red in bits 16..23, blue in bits 0..7).
// The values are far apart, which is why the switch below is compiled into a
// binary search rather than a jump table.
enum class Colour : uint32_t {
    RED     = 0x00ff0000,
    GREEN   = 0x0000ff00,
    BLUE    = 0x000000ff,
    WHITE   = 0x00ffffff,
    GRAY0   = 0x00333333,
    GRAY1   = 0x00aaaaaa,
    GRAY2   = 0x00dddddd,
    BLACK   = 0x00000000
};

// Maps each Colour key to its ordinal (0..7, in declaration order); any other
// value yields -1. GCC 7.3.0 compiles this switch into the binary search below.
int palette(Colour col) {
    switch (col) {
        case Colour::RED:
            return 0;
        case Colour::GREEN:
            return 1;
        case Colour::BLUE:
            return 2;
        case Colour::WHITE:
            return 3;
        case Colour::GRAY0:
            return 4;
        case Colour::GRAY1:
            return 5;
        case Colour::GRAY2:
            return 6;
        case Colour::BLACK:
            return 7;
        default:
            // col holds a uint32_t value that is not one of the enumerators
            return -1;
    }
}
# palette(Colour): binary search over the eight keys. %eax is loaded with the
# candidate result before each compare, so a match can return immediately.
_Z7palette6Colour:
    cmpl    $3355443, %edi      # 0x333333 (GRAY0) is the root of the search
    je      .L3
    jbe     .L24                # smaller keys: BLUE, GREEN, BLACK
    movl    $6, %eax
    cmpl    $14540253, %edi     # 0xdddddd (GRAY2)
    je      .L21
    jbe     .L25                # between GRAY0 and GRAY2: only GRAY1
    xorl    %eax, %eax
    cmpl    $16711680, %edi     # 0xff0000 (RED)
    je      .L21
    movl    $3, %eax
    cmpl    $16777215, %edi     # 0xffffff (WHITE)
    jne     .L2
.L21:
    ret
.L24:
    movl    $2, %eax
    cmpl    $255, %edi          # 0x0000ff (BLUE)
    je      .L21
    movl    $1, %eax
    cmpl    $65280, %edi        # 0x00ff00 (GREEN)
    je      .L21
    movl    $7, %eax
    testl   %edi, %edi          # 0x000000 (BLACK)
    je      .L26
.L2:
    movl    $-1, %eax           # default: no key matched
    ret
.L25:
    movl    $5, %eax
    cmpl    $11184810, %edi     # 0xaaaaaa (GRAY1)
    jne     .L2
    ret
.L3:
    movl    $4, %eax
    ret
.L26:
    ret

Example of jump table lookup

#include <cstdio>

// Scalar reference version: prints a message for the matched key and returns
// the case's result. The keys span the small range 0..11, so GCC compiles this
// switch into a jump table (listing below).
int code_block(int x) {

    // stays -1 for key 0 and for every value without a case label
    int result = -1;

    // note: the printed words name the case's position, not the key value
    switch (x) {
        case 0:
            puts("zero");
            break;

        case 3:
            puts("one");
            result = 42;
            break;

        case 4:
            puts("two");
            result = 42;
            break;

        case 7:
            puts("three");
            result = 1024;
            break;

        case 8:
            puts("four");
            result = 42;
            break;

        case 11:
            puts("five");
            result = 1024;
            break;
    }

    return result;
}
# String literals passed to puts.
.LC0:
    .string "zero"
.LC1:
    .string "one"
.LC2:
    .string "two"
.LC3:
    .string "three"
.LC4:
    .string "four"
.LC5:
    .string "five"
    .text
# code_block(int): range check followed by an indirect jump through .L4.
_Z10code_blocki:
    cmpl    $11, %edi
    ja      .L13                # unsigned compare: also rejects negative x
    leaq    .L4(%rip), %rdx     # table base address
    movl    %edi, %edi          # zero-extend the index into %rdi
    subq    $8, %rsp            # keep %rsp 16-byte aligned for the puts calls
    movslq  (%rdx,%rdi,4), %rax # entry = offset of the target relative to .L4
    addq    %rdx, %rax          # offset + base = absolute address
    jmp     *%rax
    .section        .rodata
# Jump table: one entry per value 0..11; values without a case go to .L10.
.L4:
    .long   .L3-.L4
    .long   .L10-.L4
    .long   .L10-.L4
    .long   .L5-.L4
    .long   .L6-.L4
    .long   .L10-.L4
    .long   .L10-.L4
    .long   .L7-.L4
    .long   .L8-.L4
    .long   .L10-.L4
    .long   .L10-.L4
    .long   .L9-.L4
    .text
.L9:                            # case 11
    leaq    .LC5(%rip), %rdi
    call    puts@PLT
    movl    $1024, %eax
.L11:
    addq    $8, %rsp
    ret
.L3:                            # case 0: prints but returns -1
    leaq    .LC0(%rip), %rdi
    call    puts@PLT
    movl    $-1, %eax
    addq    $8, %rsp
    ret
.L5:                            # case 3
    leaq    .LC1(%rip), %rdi
    call    puts@PLT
    movl    $42, %eax
    addq    $8, %rsp
    ret
.L6:                            # case 4
    leaq    .LC2(%rip), %rdi
    call    puts@PLT
    movl    $42, %eax
    addq    $8, %rsp
    ret
.L7:                            # case 7
    leaq    .LC3(%rip), %rdi
    call    puts@PLT
    movl    $1024, %eax
    addq    $8, %rsp
    ret
.L8:                            # case 8
    leaq    .LC4(%rip), %rdi
    call    puts@PLT
    movl    $42, %eax
    addq    $8, %rsp
    ret
.L10:                           # hole in 0..11: no case label
    movl    $-1, %eax
    jmp     .L11
.L13:                           # out of range
    movl    $-1, %eax
    ret

SIMD approaches

SIMDization of plain function

Let's see how the palette function might be vectorized:

// The palette switch written compactly: key i (in declaration order) maps to
// i, anything else to -1. This is the mapping the SIMD steps below reproduce.
int palette(Colour col) {
    switch (col) {
        case Colour::RED:   return 0;
        case Colour::GREEN: return 1;
        case Colour::BLUE:  return 2;
        case Colour::WHITE: return 3;
        case Colour::GRAY0: return 4;
        case Colour::GRAY1: return 5;
        case Colour::GRAY2: return 6;
        case Colour::BLACK: return 7;
        default:
            return -1;
    }
}
  1. Load the key values into a SIMD register. In this case we have exactly eight 32-bit numbers; when compiled for an AVX2 target, all values fit in a single 256-bit AVX register.

    // setr: the first argument goes to lane 0, so lane i holds the key whose case returns i
    const __m256i lookup = _mm256_setr_epi32(
        0x00ff0000, 0x0000ff00, 0x000000ff, 0x00ffffff,
        0x00333333, 0x00aaaaaa, 0x00dddddd, 0x00000000
    );
    
  2. Broadcast the switch argument into a SIMD register:

    const __m256i vec = _mm256_set1_epi32((uint32_t)arg);
    
  3. Compare the argument with all keys.

    // per-lane equality: all ones where the argument equals the key, zero elsewhere
    const __m256i mask = _mm256_cmpeq_epi32(vec, lookup);
    
  4. Obtain a bitmask from the 32-bit mask: bit i is set if lane i matched.

    // collect the top bit of each 32-bit lane: bit i is set iff lane i matched
    uint32_t bitmask = _mm256_movemask_ps((__m256)mask);
    
  5. If the bitmask is non-zero, the argument is equal to one of the keys. The key index is the position of the first (and only) bit set in the bitmask. If the bitmask is zero, we have to return the default value.

    // keys are distinct, so at most one bit is set; __builtin_ctz(0) is
    // undefined, hence the explicit zero check
    if (bitmask)
        return __builtin_ctz(bitmask);
    else
        return -1;
    

In this example the function value is equal to the index of the key, but for more complex mappings we need an extra lookup table.

// Maps a Colour to a Linux console foreground colour code (30..37, listed in
// console_codes(4)). GRAY0..GRAY2 have no case and yield -1.
int palette_ANSI(Colour col) { // man console_codes
    switch (col) {
        case Colour::RED:   return 31;
        case Colour::GREEN: return 32;
        case Colour::BLUE:  return 34;
        case Colour::WHITE: return 37;
        case Colour::BLACK: return 30;
        default:
            return -1;
    }
}

With the five keys loaded into lanes 0–4, the 5th step of the above algorithm would look like this:

// we dispatch five values, set 6th bit — bitmask will never be zero
bitmask = bitmask | 0x20;

// the result lookup includes also default value
// (assumes lookup lanes 0..4 hold RED, GREEN, BLUE, WHITE, BLACK — the case order of palette_ANSI)
const int result[6] = {31, 32, 34, 37, 30, -1};
// ctz picks the lowest set bit: a real match (bits 0..4) wins over the forced
// bit 5, so index is always in 0..5 and result[] is never overrun
const int index = __builtin_ctz(bitmask);

return result[index];

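Please note the trick we used. If the SIMD vector is not fully filled and the switch has a default label, we can use one of the spare bits of the bitmask to encode the default value in the result. Because __builtin_ctz returns the position of the lowest set bit, a real match in lanes 0–4 always takes precedence over the forced bit 5. Thanks to that, the extra if is eliminated.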

Implementation

#include <immintrin.h>
#include <cstdint>

// Colours packed as 0x00RRGGBB: red in bits 16..23, green in 8..15, blue in 0..7.
enum class Colour: uint32_t {
    RED   = 0xFF0000u,
    GREEN = 0x00FF00u,
    BLUE  = 0x0000FFu,
    WHITE = 0xFFFFFFu,  // all channels saturated
    GRAY0 = 0x333333u,  // darkest grey
    GRAY1 = 0xAAAAAAu,
    GRAY2 = 0xDDDDDDu,  // lightest grey
    BLACK = 0x000000u
};

// AVX2 version of palette(): compares col with all eight keys at once and
// returns the index (0..7) of the matching key, or -1 when none matches.
// Needs an AVX2 target (e.g. -mavx2).
int palette_avx2(Colour col) {

    const __m256i vec = _mm256_set1_epi32((uint32_t)col);
    // setr: lane i holds the key whose palette index is i
    const __m256i lookup = _mm256_setr_epi32(
        0x00ff0000, 0x0000ff00, 0x000000ff, 0x00ffffff,
        0x00333333, 0x00aaaaaa, 0x00dddddd, 0x00000000
    );

    // all ones in the lane whose key equals col
    const __m256i mask = _mm256_cmpeq_epi32(vec, lookup);

    // _mm256_castsi256_ps only reinterprets the bits (no instruction emitted);
    // unlike the C-style vector cast it is part of the documented intrinsics API.
    const uint32_t bitmask = _mm256_movemask_ps(_mm256_castsi256_ps(mask));
    if (bitmask) {
        // keys are distinct, so exactly one bit is set here
        return __builtin_ctz(bitmask);
    } else {
        // default value
        return -1;
    }
}

SIMD-ization of code dispatching

In the procedure code_block, the keys {0, 3, 4, 7, 8, 11} are mapped onto addresses of code sequences — basic blocks, in compiler construction vocabulary.

The part of the GCC-generated assembly that does the mapping:

cmpl        $11, %edi
ja  .L13
leaq        .L4(%rip), %rdx
movl        %edi, %edi          # zero-extend the index
movslq      (%rdx,%rdi,4), %rax # table entry: offset relative to .L4
addq        %rdx, %rax          # offset + table base = target address
jmp *%rax

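First, it checks whether the input is in the range 0..11; the unsigned comparison (ja) also rejects negative values. If it is, it fetches the table entry for the case code (an offset relative to the table), adds the table's address, and jumps there.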

Below is an AVX2 algorithm that does exactly the same.

  1. Load the key values into a SIMD register. This time we match six keys, so two 32-bit lanes of the AVX2 register are unused (filled with -1).

    // lanes 6 and 7 are padding: even if x == -1 matches them, they select the "end" label
    const __m256i lookup = _mm256_setr_epi32(0, 3, 4, 7, 8, 11, -1, -1);
    
  2. Broadcast the switch argument into an AVX2 register.

    // copy x into all eight 32-bit lanes
    const __m256i vec = _mm256_set1_epi32(x);
    
  3. Compare the argument with all keys.

    // per-lane equality: all ones where x equals the key, zero elsewhere
    const __m256i mask = _mm256_cmpeq_epi32(vec, lookup);
    
  4. Obtain a bitmask from the 32-bit mask. Here we also set the 7th bit of the bitmask (0x40), so it is never zero.

    // bit 6 is forced on: the bitmask is never zero and "no match" yields index 6
    uint32_t bitmask = _mm256_movemask_ps((__m256)mask) | 0x40;
    
  5. Get the index of the lowest set bit and jump to the appropriate label. Please note that the 7th label (end) points to the basic block equivalent to the no-match condition of the switch statement (there's no default label in the code).

    Here the C++ code uses a GCC extension ("labels as values") that allows storing label addresses and jumping to a runtime-selected address. At the assembly level this maps to an indirect jump instruction.

    // one target per lane; indices 6 and 7 (padding lanes / forced bit) go to end
    static void* labels[8] = {
        &&case_0, &&case_3, &&case_4, &&case_7, &&case_8, &&case_11,
        &&end, &&end
    };
    
    goto *labels[__builtin_ctz(bitmask)];
    

Implementation

#include <cstdio>
#include <cstdint>
#include <immintrin.h>

int code_block(int x) {

    int result = -1;

    static void* labels[8] = {
        &&case_0, &&case_3, &&case_4, &&case_7, &&case_8, &&case_11,
        &&case_default, &&case_default
    };

    const __m256i vec = _mm256_set1_epi32(x);
    const __m256i lookup = _mm256_setr_epi32(
        0, 3, 4, 7, 8, 11,
        -1, -1 // these two don't matter
    );

    const __m256i mask = _mm256_cmpeq_epi32(vec, lookup);

    uint32_t bitmask = _mm256_movemask_ps((__m256)mask) | 0xc0;
    goto *labels[__builtin_ctz(bitmask)];

case_0:
    puts("zero");
    goto end;

case_3:
    puts("one");
    result = 42;
    goto end;

case_4:
    puts("two");
    result = 42;
    goto end;

case_7:
    puts("three");
    result = 1024;
    goto end;

case_8:
    puts("four");
    result = 42;
    goto end;

case_11:
    puts("five");
    result = 1024;
    goto end;

case_default:
    result = -1;

end:
    return result;
}

Performance

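For code involving jumps, it's hard to say anything definite. Both the compiler's jump table and the SIMD variant end with an indirect jump, so the performance of dispatching depends mostly on the branch predictor: a mispredicted branch costs roughly 15–20 cycles on recent x86 cores.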

It's likely that in typical scenarios, where just one or two cases are actually executed, the SIMD code will not be faster.

However, for simple lookup functions with a large number of cases, SIMD code is likely to be faster. First of all, such a function is almost branch-free, as the default case can also be resolved by the vector code. Moreover, the cost of the SIMD code depends only on the number of cases, not on their values.

Practical example

Use of a SIMDized switch in parsing RFC dates.

Source code

Sample programs are available at github.