Tutorial 4: LUT Operator

In this tutorial, you will write a simple vector lookup-table (LUT) operator to implement silu using Compass DSL. You will learn about:

  • How to call a subroutine function on the device side.

  • The difference between a function that runs on the host side at compile time and a function that runs on the device side at runtime.

  • How to declare a constant LUT table.

  • How to use the table-lookup vector builtin to implement silu with the prepared LUT table.

  • How to interpolate between two adjacent table items to implement a higher-precision silu.

Inputs & Outputs

  • Inputs:

    • Tensor(in0, shape=[n], dtype="float16")

    • n: i32

  • Output:

    • Tensor(out0, shape=[n], dtype="float16")

So we can write the primfunc like this:

dtype = "float16"
vdtype = "float16x16"
lut_len = 512
lut_edge = 10

FP16_ELEMS_ON_LSRAM = 32 * 1024 // 2
FP16x16_ELEMS_ON_LSRAM = FP16_ELEMS_ON_LSRAM // 16

@S.prim_func
def silu_fp16(
    in0: S.ptr(dtype, "global"),
    out0: S.ptr(dtype, "global"),
    n: S.i32,
):
    # func body
    ...

Prepare LUT Table

Assume that the input values lie in the range [-10, 10]. We prepare a table with 512 items: 511 samples of silu spread evenly over that range, plus one padding item at the end.

import numpy as np

def gen_silu_lut(lut_len, lut_edge, dtype):
    # Key points of FP16 LUT generation:
    # 1. The sample points x cover [-lut_edge, lut_edge].
    # 2. dtype = "float16"
    # 3. Pad the LUT with one extra last element (lut_edge).
    x = np.linspace(-lut_edge, lut_edge, lut_len - 1)
    n = lut_len - 1
    lut = np.zeros((n,), dtype=dtype)
    for i in range(n):
        lut[i] = x[i] / (1 + np.exp(-x[i]))
    lut.resize(lut_len)
    lut[-1] = lut_edge
    return lut

@S.prim_func
def silu_fp16(
    in0: S.ptr(dtype, "global"),
    out0: S.ptr(dtype, "global"),
    n: S.i32,
):
    lut = S.alloc_const((lut_len,), dtype, gen_silu_lut(lut_len, lut_edge, dtype))
    ...

You may notice that:

  • gen_silu_lut is not decorated with S.prim_func, which means that:

    • This function runs on the host side at compile time.

    • This function does not generate any code on the device side.

  • S.alloc_const allocates a “__constant” buffer in OpenCL, and its contents are determined by the return value of gen_silu_lut.
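
Because gen_silu_lut is ordinary host-side Python, the table layout can be inspected before anything is compiled. The following sketch rebuilds the same layout in pure Python. It keeps the values in double precision rather than float16, and the helper names silu and build_lut are illustrative, not part of Compass DSL.

```python
import math

LUT_LEN = 512   # same value as lut_len in the tutorial
LUT_EDGE = 10   # same value as lut_edge in the tutorial


def silu(x):
    """Reference formula: silu(x) = x / (1 + e^(-x))."""
    return x / (1.0 + math.exp(-x))


def build_lut(lut_len=LUT_LEN, lut_edge=LUT_EDGE):
    """Mirror the layout of gen_silu_lut, without the float16 rounding."""
    n = lut_len - 1                      # number of real samples (511)
    step = 2.0 * lut_edge / (n - 1)      # spacing between samples (20 / 510)
    lut = [silu(-lut_edge + i * step) for i in range(n)]
    lut.append(float(lut_edge))          # padding entry at index lut_len - 1
    return lut


lut = build_lut()
print(len(lut))             # total number of entries, including the pad
print(lut[0], lut[255])     # silu(-10) and the sample at x = 0
print(lut[-2], lut[-1])     # last real sample silu(10), then the pad value
```

The sample spacing is 2 * lut_edge / (lut_len - 2), which is the reciprocal of the lut_inverse_delta value used in the kernel.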

Kernel Function

Here is the main kernel function:

@S.prim_func
def silu_fp16(
    in0: S.ptr(dtype, "global"),
    out0: S.ptr(dtype, "global"),
    n: S.i32,
):
    lut = S.alloc_const((lut_len,), dtype, gen_silu_lut(lut_len, lut_edge, dtype))
    lut_inverse_delta = S.fp32(lut_len - 2) / (2 * lut_edge)

    lsram_ptr = S.alloc(FP16x16_ELEMS_ON_LSRAM, vdtype, scope="lsram")
    tec_cnt = S.get_local_size()
    tid = S.get_local_id()

    elems_per_tec = S.ceildiv(n, tec_cnt)
    elems_cur_tec = S.clip(n - tid * elems_per_tec, min_val=0, max_val=elems_per_tec)

    offset_cur_tec = tid * elems_per_tec
    for lsram_idx in range(S.ceildiv(elems_cur_tec, FP16_ELEMS_ON_LSRAM)):
        elems_cur_lsram = S.min(FP16_ELEMS_ON_LSRAM, elems_cur_tec - lsram_idx * FP16_ELEMS_ON_LSRAM)
        offset_cur_lsram = offset_cur_tec + lsram_idx * FP16_ELEMS_ON_LSRAM

        S.dma_copy(lsram_ptr.as_ptr(dtype), in0 + offset_cur_lsram, elems_cur_lsram)
        for vec_idx in range(S.ceildiv(elems_cur_lsram, vdtype.lanes)):
            lsram_ptr[vec_idx] = compute(lsram_ptr[vec_idx], lut, lut_inverse_delta, lut_len, lut_edge)
        S.dma_copy(out0 + offset_cur_lsram, lsram_ptr.as_ptr(dtype), elems_cur_lsram)

The main calculation process is as follows:

  • Calculate how many elements should be calculated on this TEC

  • Use dma_copy to copy the related input data from DDR to LSRAM

  • Call the device prim_func to compute silu in place on LSRAM

  • Use dma_copy to move the silu result from LSRAM to DDR
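
The per-TEC split and the LSRAM chunking in the steps above can be checked on the host with plain Python. This is a sketch, not Compass DSL code: plan_chunks is a hypothetical helper, and the TEC count of 4 is an assumed value used only for illustration.

```python
FP16_ELEMS_ON_LSRAM = 32 * 1024 // 2   # same constant as in the tutorial


def ceildiv(a, b):
    """Integer division rounded up, for non-negative a and positive b."""
    return -(-a // b)


def plan_chunks(n, tec_cnt, lsram_elems=FP16_ELEMS_ON_LSRAM):
    """Return (tid, offset, count) for every DMA chunk, following the kernel loops."""
    elems_per_tec = ceildiv(n, tec_cnt)
    chunks = []
    for tid in range(tec_cnt):
        # Same role as S.clip(n - tid * elems_per_tec, 0, elems_per_tec).
        elems_cur_tec = min(max(n - tid * elems_per_tec, 0), elems_per_tec)
        offset_cur_tec = tid * elems_per_tec
        for lsram_idx in range(ceildiv(elems_cur_tec, lsram_elems)):
            count = min(lsram_elems, elems_cur_tec - lsram_idx * lsram_elems)
            chunks.append((tid, offset_cur_tec + lsram_idx * lsram_elems, count))
    return chunks


# Assumed example: 70000 elements spread over 4 TECs.
for chunk in plan_chunks(n=70000, tec_cnt=4):
    print(chunk)
```

The chunks tile the range [0, n) without gaps or overlap. When n is small, the last TECs may receive zero elements and therefore issue no DMA at all.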

The main calculation logic is placed in the function “compute”. Because its code must be generated for the device and run there at runtime, it is a device function and must be decorated with S.prim_func.

Subroutine Call Device Function

The “compute” function accepts 5 parameters:

  • x: the input data, of data type “fp16x16”, on LSRAM.

  • lut: the LUT table prepared above.

  • lut_inverse_delta: the reciprocal of the gap between adjacent items in the LUT table, used to calculate the index and for interpolation.

  • lut_len: the table length.

  • lut_edge: the edge of the input range covered by the table, that is, [-lut_edge, lut_edge].

The return value has data type “fp16x16”. Because compute is a device function, it must be decorated with S.prim_func.

So the function looks like this:

    @S.prim_func
    def compute(
        x: S.fp16x16,
        lut: S.ptr(dtype, "constant"),
        lut_inverse_delta: S.fp32,
        lut_len: S.i32,
        lut_edge: S.i32,
    ) -> S.fp16x16:
        ...
  • Clip the input value to the table range [-10, 10].

  • Use S.cast to convert the clipped value from “fp16x16” to the fp32x16 value “x_fp32”.

  • Calculate the fp32x16 table position “x_idx”, round it to obtain the integer index, and cast that index to u16.

    # in compute function.
    x_clipped = S.clip(x, min_val=S.fp16(-lut_edge), max_val=S.fp16(lut_edge))
    x_fp32 = S.cast(x_clipped, "fp32")
    x_idx = (x_fp32 + lut_edge) * lut_inverse_delta
    x_idxr = S.vrint(x_idx - 0.5)
    x_idx_u16 = S.cast(x_idxr, "u16")
    ...
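
To make the index arithmetic concrete, here is a scalar sketch of the same steps in plain Python. It assumes that S.vrint rounds to the nearest integer with ties going to the even value, which is what Python's built-in round does. The name lut_index is illustrative.

```python
LUT_LEN = 512
LUT_EDGE = 10
LUT_INVERSE_DELTA = (LUT_LEN - 2) / (2 * LUT_EDGE)   # 510 / 20 = 25.5


def lut_index(x):
    """Return (integer table index, fractional offset) for one input value."""
    x_clipped = min(max(x, -LUT_EDGE), LUT_EDGE)
    x_idx = (x_clipped + LUT_EDGE) * LUT_INVERSE_DELTA   # position in table units
    x_idxr = round(x_idx - 0.5)                          # entry at or below x_idx
    return x_idxr, x_idx - x_idxr


for x in (-25.0, -10.0, 0.1, 9.99, 10.0):
    idx, frac = lut_index(x)
    print(f"x={x:7.2f} -> index {idx:3d}, fraction {frac:.4f}")
```

Subtracting 0.5 before rounding turns round-to-nearest into a round-down, so the integer index points at the lower of the two neighbouring table entries and the fraction says how far the input lies towards the upper one.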
  • Use S.vload_gather to look up the LUT table.

  • For interpolation, also look up the next item in the LUT table.

  • Calculate the difference and perform the linear interpolation. Remember to use fp32x8 for interpolation.

  • After interpolation, cast down to fp16x16.

    lut_x_idx = S.vload_gather(lut, x_idx_u16)
    lut_x_idx_plus1 = S.vload_gather(lut, x_idx_u16 + 1)
    x_idx_diff = x_idx - x_idxr
    lut_diff = S.cast(lut_x_idx_plus1 - lut_x_idx, "fp32")
    yy = S.cast(lut_diff * x_idx_diff, dtype)
    ...
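
The benefit of the interpolation step can be estimated on the host. The sketch below compares a plain lookup of the lower table entry with the interpolated result. It uses double precision instead of float16 and fp32, math.floor stands in for the vrint(x_idx - 0.5) step, and all helper names are illustrative.

```python
import math

LUT_LEN, LUT_EDGE = 512, 10
INV_DELTA = (LUT_LEN - 2) / (2 * LUT_EDGE)


def silu(x):
    return x / (1.0 + math.exp(-x))


# 511 samples on [-10, 10] plus one padding entry, laid out as in gen_silu_lut.
LUT = [silu(-LUT_EDGE + i / INV_DELTA) for i in range(LUT_LEN - 1)] + [float(LUT_EDGE)]


def lookup(x, interpolate):
    """Look up silu(x) in LUT, optionally interpolating towards the next entry."""
    x_idx = (x + LUT_EDGE) * INV_DELTA
    idx = math.floor(x_idx)            # lower neighbour
    base = LUT[idx]
    if not interpolate:
        return base
    return base + (LUT[idx + 1] - base) * (x_idx - idx)


# Dense test grid strictly inside the table range.
xs = [-9.9 + 19.8 * i / 5000 for i in range(5001)]
err_plain = max(abs(lookup(x, False) - silu(x)) for x in xs)
err_interp = max(abs(lookup(x, True) - silu(x)) for x in xs)
print(f"max error without interpolation: {err_plain:.2e}")
print(f"max error with interpolation:    {err_interp:.2e}")
```

On the device the final result is rounded to float16, so the improvement seen there is bounded by float16 precision. The comparison mainly shows that the table itself stops being the dominant error source once interpolation is used.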
  • To ensure that the index does not exceed the range, calculate the mask.

  • Return base + diff.

    mask_idx_ge_lutlen_m2 = x_idx_u16 >= S.u16x16(lut_len - 2)
    y = S.vsel(x, 0, mask_idx_ge_lutlen_m2)
    mask_xor = S.vxor(x_idx_u16 >= 0, mask_idx_ge_lutlen_m2)
    return S.vadd(yy, lut_x_idx, mask=mask_xor, r=y)

The complete compute function looks like this:

@S.prim_func
def compute(
    x: S.fp16x16,
    lut: S.ptr(dtype, "constant"),
    lut_inverse_delta: S.fp32,
    lut_len: S.i32,
    lut_edge: S.i32,
) -> S.fp16x16:
    # Original formula: silu(x) = x / (1 + e^(-x))
    # Here it is implemented with a lookup table plus linear interpolation instead.
    x_clipped = S.clip(x, min_val=S.fp16(-lut_edge), max_val=S.fp16(lut_edge))
    x_fp32 = S.cast(x_clipped, "fp32")
    x_idx = (x_fp32 + lut_edge) * lut_inverse_delta
    x_idxr = S.vrint(x_idx - 0.5)
    x_idx_u16 = S.cast(x_idxr, "u16")
    mask_idx_ge_lutlen_m2 = x_idx_u16 >= S.u16x16(lut_len - 2)

    y = S.vsel(x, 0, mask_idx_ge_lutlen_m2)
    lut_x_idx = S.vload_gather(lut, x_idx_u16)
    lut_x_idx_plus1 = S.vload_gather(lut, x_idx_u16 + 1)
    x_idx_diff = x_idx - x_idxr
    lut_diff = S.cast(lut_x_idx_plus1 - lut_x_idx, "fp32")
    yy = S.cast(lut_diff * x_idx_diff, dtype)
    mask_xor = S.vxor(x_idx_u16 >= 0, mask_idx_ge_lutlen_m2)
    return S.vadd(yy, lut_x_idx, mask=mask_xor, r=y)
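
For host-side testing it can help to have a scalar reference model of the function above. The following is a sketch in plain Python, not the device code. It processes one value at a time in double precision, assumes that S.vrint behaves like Python's round (nearest, ties to even), and reads the S.vsel and r=y combination as "lanes whose index is at or beyond lut_len - 2 return the original input". The names silu_lut_ref and build_lut are hypothetical.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def build_lut(lut_len, lut_edge):
    """Same layout as gen_silu_lut, without the float16 rounding."""
    step = 2.0 * lut_edge / (lut_len - 2)
    return [silu(-lut_edge + i * step) for i in range(lut_len - 1)] + [float(lut_edge)]


def silu_lut_ref(x, lut, lut_len, lut_edge):
    """Scalar model of compute(): clip, index, gather, interpolate, mask."""
    lut_inverse_delta = (lut_len - 2) / (2 * lut_edge)
    x_clipped = min(max(x, -lut_edge), lut_edge)
    x_idx = (x_clipped + lut_edge) * lut_inverse_delta
    x_idxr = round(x_idx - 0.5)
    if x_idxr >= lut_len - 2:
        # Masked-off case (assumed semantics): return the unclipped input,
        # since silu(x) is close to x for large positive x.
        return x
    base = lut[x_idxr]
    nxt = lut[x_idxr + 1]
    return base + (nxt - base) * (x_idx - x_idxr)


lut = build_lut(512, 10)
for x in (-50.0, -1.0, 0.0, 1.0, 50.0):
    print(x, silu_lut_ref(x, lut, 512, 10), silu(x))
```

Inputs below -lut_edge are clipped and therefore map to the first table entry, while inputs at or above lut_edge bypass the table. The padding entry at the end of the table keeps the lookup at index + 1 inside the table even for the largest index.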

Complete Code

You can find the sample code in PYTHON_PACKAGE_PATH/tvm/aipu/samples/dsl/tutorial_4_lut_op.py. The placeholder PYTHON_PACKAGE_PATH represents the location where you installed the Compass DSL Python package. In general, it will be something like ~/.local/lib/python3.8/site-packages.