Average of two unsigned integers

Author: Wojciech Muła
Added on: 2012-07-02

The famous error in Bentley's binary search code was caused by an integer overflow in the following code:

unsigned a, b, c;

c = (a + b)/2;

The sum is evaluated first and then divided (or shifted), but the sum can exceed the integer range. In languages such as Java, C or C++ such an overflow is not reported (unsigned arithmetic in C and C++ simply wraps around), while Ada would raise an exception.
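To make the failure concrete, here is a minimal sketch that assumes a 32-bit width via uint32_t. The function names naive_average and widened_average are illustrative and do not come from the original code; the 64-bit version serves only as a reference to compare against.

```c
#include <stdint.h>

/* Naive average: the 32-bit sum wraps modulo 2^32 before the division.
   The cast keeps the wrap-around explicit even if int is wider than 32 bits. */
uint32_t naive_average(uint32_t a, uint32_t b)
{
    return (uint32_t)(a + b) / 2;
}

/* Reference result: the sum is formed in 64 bits, where it cannot overflow. */
uint32_t widened_average(uint32_t a, uint32_t b)
{
    return (uint32_t)(((uint64_t)a + b) / 2);
}
```

For a = 0xFFFFFFFF and b = 1 the wrapped sum is 0, so naive_average returns 0, while the true average is 0x80000000.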

The safest expression, even in unsafe languages, looks like this:

unsigned a, b, c;
unsigned LSB_carry;

LSB_carry = a & b & 1;
c = a/2 + b/2 + LSB_carry;

We sum all bits except the lowest, then adjust the sum with the carry from the lowest bits (a & b & 1). This expression involves two shifts, two additions and two ANDs, and also requires additional storage for the carry.
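The same expression can be wrapped in a function and checked on small values. This is only a sketch; the name average_halves is an assumption made for illustration.

```c
/* Floor of (a + b) / 2 computed without ever forming the full sum. */
unsigned average_halves(unsigned a, unsigned b)
{
    /* Each division drops the lowest bit; when both dropped bits are 1,
       together they are worth one unit in the result. */
    unsigned LSB_carry = a & b & 1u;

    return a / 2 + b / 2 + LSB_carry;
}
```

For a = 7 and b = 5 the halves are 3 and 2 and the carry is 1, giving 6. For a = 7 and b = 4 the carry is 0 and the result is 5, the floor of 5.5.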

Another approach detects the overflow without access to hardware flags:

unsigned a, b, c;       // sizeof(unsigned) = 4;
unsigned sum;

sum = a + b;
if (sum < a) { // or sum < b
        // overflow: restore the "lost" carry bit as the highest bit
        c = (sum/2) | 0x80000000;
} else {
        c = sum/2;
}

This requires one addition, one shift and a conditional OR. The condition can be expressed without a branch:

unsigned a, b, c;
unsigned sum;
unsigned MSB_carry;

sum = a + b;
MSB_carry = (unsigned)(sum < a); // 1 or 0

c = sum/2 | (MSB_carry << 31);
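The shift by 31 ties the snippet to a 32-bit unsigned. A possible width-independent variant is sketched below; it derives the position of the highest bit from sizeof and CHAR_BIT and assumes that unsigned has no padding bits. The name average_carry is illustrative.

```c
#include <limits.h>

/* Branchless average: recover the carry lost by the wrapping addition. */
unsigned average_carry(unsigned a, unsigned b)
{
    /* Index of the highest bit (assumes no padding bits in unsigned). */
    const unsigned top = (unsigned)(sizeof(unsigned) * CHAR_BIT - 1);

    unsigned sum = a + b;                    /* wraps around on overflow */
    unsigned MSB_carry = (unsigned)(sum < a); /* 1 if the sum wrapped, else 0 */

    return (sum >> 1) | (MSB_carry << top);
}
```

When the addition wraps, the shifted sum has a zero in its highest bit, so OR-ing the carry into that position gives the same result as adding it.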

Sample program

The following Python script verifies all three methods (for 8-bit integers):

bits = 8    # unsigned width
mask = (1 << bits) - 1
MSB  = (1 << (bits - 1))

def base(a, b):
    return (a + b) // 2   # integer division; Python integers do not overflow


def safe1(a, b):
    LSB_carry = (a & b & 1)
    return (a >> 1) + (b >> 1) + LSB_carry


def safe2(a, b):
    sum = (a + b) & mask
    if sum < a:
        return (sum >> 1) | MSB
    else:
        return sum >> 1


def safe3(a, b):
    sum = (a + b) & mask
    MSB_carry = int(sum < a)

    return (sum >> 1) | (MSB_carry << (bits - 1))


def main():
    n = 1 << bits
    for a in range(n):
        for b in range(n):
            ref = base(a, b)
            r1  = safe1(a, b)
            r2  = safe2(a, b)
            r3  = safe3(a, b)

            assert ref == r1
            assert ref == r2
            assert ref == r3

if __name__ == '__main__':
    main()