Tutorial 5: Quantization Operator

Building on Tutorial 1 (Static Add) and Tutorial 2 (Dynamic Add), this tutorial shows how to write a quantized vector addition with a dynamic kernel using Compass DSL. You will learn:

  • How to widen narrow (low bit-width) types to wider, higher-precision types.

  • How to determine the output type when an operation mixes signed and unsigned operands.

  • How to narrow wide types back to narrower types.

Inputs & Outputs

  • Inputs:

    • Tensor(a, shape=(n,), dtype="int8")

    • Tensor(b, shape=(n,), dtype="int8")

  • Output:

    • Tensor(c, shape=(n,), dtype="int8")

  • Others:

    • zero point:

      • input0 zp: int16

      • input1 zp: int16

      • output zp: int16

    • scale:

      • input0 scale: uint16

      • input1 scale: uint16

      • output scale: uint8

    • shift: int8

With these parameters, the entry primfunc can be declared like this:

dtype = "int8"

@S.prim_func(is_entry=True)
def eltwise_add_func(
    a: S.ptr(dtype, "global"),
    b: S.ptr(dtype, "global"),
    c: S.ptr(dtype, "global"),
    n: S.i32,
    zp_i0: S.i16,
    zp_i1: S.i16,
    zp_o: S.i16,
    scale_i0: S.u16,
    scale_i1: S.u16,
    scale_o: S.u8,
    shift: S.i8,
):
    # func body
    ...

The implementation of a quantized operator is tightly coupled to the quantization stage. This tutorial concentrates on the operator implementation, not on the quantization algorithm.

Calculation

The following is the main calculation process.

@S.prim_func
def compute(
    inp0_lsram: S.ptr(dtype, "lsram"),
    inp1_lsram: S.ptr(dtype, "lsram"),
    size: S.i32,
    zp_i0: S.i16,
    zp_i1: S.i16,
    zp_o: S.i16,
    scale_i0: S.u16,
    scale_i1: S.u16,
    scale_o: S.u8,
    shift: S.i8,
):
    # Each iteration processes 8 elements, because the math runs on int32x8.
    for i in range((size + 7) // 8):
        a = S.vload(inp0_lsram + i * 8)
        b = S.vload(inp1_lsram + i * 8)

        # Upgrade types: int8x32 --> int32x8
        a32 = S.vxtl(S.vxtl(a))
        b32 = S.vxtl(S.vxtl(b))

        # Compute: (a + zp_i0) * scale_i0
        a32 = S.vadd(a32, zp_i0, saturate=True, out_sign="s")
        a32 = S.vmul(a32, scale_i0, out_sign="s")

        # Compute: (b + zp_i1) * scale_i1
        b32 = S.vadd(b32, zp_i1, saturate=True, out_sign="s")
        b32 = S.vmul(b32, scale_i1, out_sign="s")

        # Element-wise addition of the two scaled inputs
        tmp_w = S.vadd(a32, b32, saturate=True)

        # Multiply by scale_o
        tmp_w = S.vmul(tmp_w, scale_o, out_sign="s")

        # Shift, then narrow from 32-bit to 16-bit
        tmp_h = S.i16x16(0)
        if shift < 0:
            tmp_w = S.vsl(tmp_w, -shift)
            tmp_h = S.vnsrsr(tmp_w, 0, to_h=True)
        else:
            tmp_h = S.vnsrsr(tmp_w, shift, to_h=True)

        # Subtract the output zero point, then narrow from 16-bit to 8-bit
        tmp_h = S.vsub(tmp_h, zp_o, saturate=True, out_sign="s")
        tmp_b = S.vnsrsr(tmp_h, 0)

        # Pack the 8 valid results into the first 8 lanes
        out = S.vcompt(tmp_b, mask="8TFFF")

        # Store the 8 packed results back to inp0_lsram
        S.vstore(out, inp0_lsram + i * 8, mask="8T24F")
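
As a cross-check for the kernel above, the per-element arithmetic can be sketched in plain Python. This is only a minimal sketch, not the Compass DSL implementation: the helper names (saturate, rounding_shift_right, quantized_add_reference) are hypothetical, the rounding mode of S.vnsrsr is assumed to be round-half-up, and the multiplications are assumed not to overflow 32 bits.

```python
def saturate(value, bits):
    """Clamp value to the signed range of the given bit width."""
    lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    return max(lo, min(hi, value))


def rounding_shift_right(value, shift):
    """Arithmetic right shift with round-half-up (assumed rounding mode)."""
    if shift <= 0:
        return value
    return (value + (1 << (shift - 1))) >> shift


def quantized_add_reference(a, b, zp_i0, zp_i1, zp_o,
                            scale_i0, scale_i1, scale_o, shift):
    """Scalar model of the kernel's steps, one element at a time."""
    out = []
    for x, y in zip(a, b):
        # (input + zero point) * scale, for each input
        lhs = saturate(x + zp_i0, 32) * scale_i0
        rhs = saturate(y + zp_i1, 32) * scale_i1
        # Saturating add, then multiply by the output scale
        acc = saturate(lhs + rhs, 32) * scale_o
        # Negative shift means shift left; otherwise rounding shift right.
        if shift < 0:
            half = saturate(acc << -shift, 16)
        else:
            half = saturate(rounding_shift_right(acc, shift), 16)
        # Subtract the output zero point, then narrow to 8 bits
        half = saturate(half - zp_o, 16)
        out.append(saturate(half, 8))
    return out


result = quantized_add_reference(
    [10, -20, 127], [5, 7, 127],
    zp_i0=0, zp_i1=0, zp_o=0,
    scale_i0=2, scale_i1=2, scale_o=1, shift=1,
)
print(result)  # [15, -13, 127]
```

The last element shows the final saturation: the 16-bit intermediate value 254 does not fit in int8 and is clamped to 127.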

Upgrade Type

Widen each element twice: first from 8 bits to 16 bits, then from 16 bits to 32 bits. Variable "a" holds 32 8-bit elements, while the new variable "a32" holds 8 32-bit elements, so only the first 8 elements of "a" are kept.

# If "a" is int8x32 [0,1,2,3,...,31], then "a32" is int32x8 [0,1,2,3,4,5,6,7].
a32 = S.vxtl(S.vxtl(a))
b32 = S.vxtl(S.vxtl(b))

All subsequent calculations are therefore performed on 32-bit elements.
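
The effect of the two widening steps on the lane count can be modeled in plain Python. This is a minimal sketch under the assumption, taken from the example comment above, that each step keeps the low half of the lanes; vxtl_model is a hypothetical helper, not part of Compass DSL.

```python
def vxtl_model(lanes):
    """Model one widening step: keep the low half of the lanes.

    Python integers have no fixed width, so only the lane count changes
    here. The kept values are unchanged, because widening a signed value
    preserves it.
    """
    return list(lanes[: len(lanes) // 2])


a = list(range(32))     # stands in for int8x32
a16 = vxtl_model(a)     # stands in for int16x16
a32 = vxtl_model(a16)   # stands in for int32x8
print(a32)              # [0, 1, 2, 3, 4, 5, 6, 7]

# Only 8 elements survive per vector, so a buffer of `size` elements
# needs ceil(size / 8) loop iterations.
size = 20
iterations = (size + 7) // 8
print(iterations)       # 3
```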

Operations on Signed and Unsigned Values

The computation follows the formula (a + zp_i0) * scale_i0, where:

  • zp_i0: the zero point of the first input

  • scale_i0: the scale of the first input

The zero point and scale of input0 are derived from the quantization stage. The calculation here follows the requirements and order of operations of the quantization stage so that no accuracy is lost.

a32 = S.vadd(a32, zp_i0, saturate=True, out_sign="s")
a32 = S.vmul(a32, scale_i0, out_sign="s")

The addition uses the saturating version to avoid overflow. Because the zero point is a signed value, the output sign is specified as "s". The result of the addition is therefore signed, so the output sign of the multiplication is also "s", even though the scale itself is unsigned.
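
Why the output sign matters can be illustrated with a small plain-Python model of a saturating add. This is a sketch of the general idea only: saturating_add is a hypothetical helper, and the assumption that an unsigned output clamps negative results to 0 is not confirmed here against the actual behavior of S.vadd.

```python
def saturating_add(x, y, bits=32, signed=True):
    """Add two integers and clamp the result to the chosen output range."""
    if signed:
        lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    else:
        lo, hi = 0, (1 << bits) - 1
    return max(lo, min(hi, x + y))


# An input value of -5 with a zero point of 3:
signed_result = saturating_add(-5, 3, signed=True)
unsigned_result = saturating_add(-5, 3, signed=False)
print(signed_result)    # -2
print(unsigned_result)  # 0, the negative result is lost

# Signed (a + zp_i0) times an unsigned scale is still a signed value.
scaled = saturating_add(-128, -100) * 3
print(scaled)           # -684
```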

Downgrade Type and Pack

Use S.vnsrsr to downgrade (narrow) the type. Its "to_h" argument specifies whether the output elements are 16-bit; without it, the output elements are 8-bit.

# Assume "tmp_h" is int16x16 [0,1,2,3,...,15]
tmp_b = S.vnsrsr(tmp_h, 0)  # "tmp_b" is int8x32 [0,0,1,0,2,0,3,0,4,...,0,15,0]
out = S.vcompt(tmp_b, mask="8TFFF")  # "out" is int8x32 [0,2,4,6,8,10,12,14,0,0,...,0]

Pack the interleaved results together, squeezing out the bubbles introduced by the narrowing conversion. The mask "8TFFF" marks which elements of "tmp_b" are kept: the pattern "TFFF" repeated 8 times, which selects every fourth element, or "0b10001000100010001000100010001" in binary.
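
The narrowing and packing steps can be modeled in plain Python to check the lane layouts shown above. This is a minimal sketch: expand_mask, narrow_16_to_8_model and compress_model are hypothetical helpers, and the mask grammar (an optional repeat count followed by a run of "T"/"F" characters) is an assumption inferred from the two masks used in this tutorial, "8TFFF" and "8T24F".

```python
import re


def expand_mask(mask, lanes=32):
    """Expand a mask string such as "8TFFF" into a list of booleans."""
    bits = []
    for count, pattern in re.findall(r"(\d*)([TF]+)", mask):
        repeat = int(count) if count else 1
        bits.extend([ch == "T" for ch in pattern] * repeat)
    assert len(bits) <= lanes, "mask is longer than the vector"
    return bits + [False] * (lanes - len(bits))


def narrow_16_to_8_model(values16):
    """Each 16-bit lane becomes a saturated 8-bit value plus a zero bubble."""
    lanes8 = []
    for v in values16:
        lanes8.extend([max(-128, min(127, v)), 0])
    return lanes8


def compress_model(lanes, mask_bits):
    """Move the selected lanes to the front and zero-fill the rest."""
    kept = [v for v, keep in zip(lanes, mask_bits) if keep]
    return kept + [0] * (len(lanes) - len(kept))


tmp_h = list(range(16))                    # stands in for int16x16
tmp_b = narrow_16_to_8_model(tmp_h)        # [0, 0, 1, 0, 2, 0, ...]
mask_bits = expand_mask("8TFFF")
out = compress_model(tmp_b, mask_bits)
print(out[:8])                             # [0, 2, 4, 6, 8, 10, 12, 14]

# Lane 0 is the least significant bit of the mask value.
mask_value = sum(1 << i for i, keep in enumerate(mask_bits) if keep)
print(bin(mask_value))                     # 0b10001000100010001000100010001
```

In the kernel itself, the 16-bit vector comes from narrowing an int32x8, so it already holds one valid value in every second lane. After the second narrowing, the valid values sit in every fourth 8-bit lane, which is why the mask keeps lanes 0, 4, 8, and so on.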

Complete Code

You can find the sample code in PYTHON_PACKAGE_PATH/tvm/aipu/samples/dsl/tutorial_5_quantization_op.py. The placeholder PYTHON_PACKAGE_PATH represents the location where the Compass DSL Python package is installed. In general, it will be something like ~/.local/lib/python3.8/site-packages.