How to Use Macros
When implementing a kernel, the body of the PrimFunc may contain repeated chunks of code, or long pieces of code that make the PrimFunc hard to read. In such situations, we would like to move this code into macros to keep the main function body clean and readable.
Warning
PySim does not support macros. If your DSL program uses a macro, you cannot debug it directly in Python; you can only run and debug it on the AIPU simulator or hardware. We therefore recommend using subfunctions wherever possible.
Why Use Macros
Macros serve only as a tool for organizing the original source and keeping the main code clean and readable.
Macro Principle
A macro is similar to a C preprocessor macro.
A macro never generates an actual call in the PrimFunc body. Instead, its body is pasted (inlined) into the TIR.
For a macro call macro_name(arg1, arg2, arg3, ...), the argument values are substituted into the body of the macro, and the body with the substituted values is then inserted at the location of the call.
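To make the substitution step concrete, here is a toy, text-level model of it in plain Python, in the spirit of the C preprocessor comparison. It is only an analogy: the helper expand_macro is hypothetical and not part of S, and the real S.macro inlines the macro body into TIR rather than rewriting source text.

```python
import re


def expand_macro(params, body, args):
    """Toy model: replace each parameter name in `body` with the matching argument text."""
    mapping = dict(zip(params, args))
    # Match whole identifiers only, so "idx" is not replaced inside a longer name like "idx2".
    pattern = re.compile(r"\b(" + "|".join(map(re.escape, params)) + r")\b")
    # A single regex pass substitutes all parameters at once.
    return pattern.sub(lambda m: mapping[m.group(1)], body)


# Body of a macro add_func(inp, out, idx), written as text.
body = "out[idx] = inp[idx] + 8"

# The "call" add_func(a, b, 0) is replaced by the substituted body.
print(expand_macro(["inp", "out", "idx"], body, ["a", "b", "0"]))  # b[0] = a[0] + 8
```

No call to add_func remains in the output; only the substituted body is left at the call site, which is what "inlined" means here.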
Macro Language Spec
decorator: a function used as a macro must be decorated with @S.macro.
parameters: all parameters follow Python syntax, i.e. positional, keyword, etc. Type annotations are not required, but are allowed.
macro call: macro_name(arg1, arg2, arg3, ...), the same as the function call syntax.
hygienic: this option controls the symbol resolution rules of the macro. It is set to True (@S.macro(hygienic=True)) by default.
hygienic=True: all symbols used in the macro’s body are resolved to values from the location of the macro definition.
hygienic=False: the macro’s symbols are resolved to values at the time of macro use.
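The difference between the two settings can be modeled in plain Python with two environments: one captured where the macro is defined and one visible where it is used. The ToyMacro class below is hypothetical and only mimics the resolution rule described above; it evaluates a Python expression instead of generating TIR.

```python
class ToyMacro:
    """Hypothetical model of macro symbol resolution (not the S.macro implementation)."""

    def __init__(self, params, expr, def_env, hygienic=True):
        self.params = params
        self.expr = expr
        self.hygienic = hygienic
        # Snapshot of the symbols visible where the macro is defined.
        self.def_env = dict(def_env)

    def expand(self, use_env, *args):
        # hygienic=True: free symbols come from the definition site.
        # hygienic=False: free symbols come from the use site.
        env = dict(self.def_env if self.hygienic else use_env)
        # Parameters are always bound to the call arguments.
        env.update(zip(self.params, args))
        return eval(self.expr, {"__builtins__": {}}, env)


def_env = {"n": 8}  # n = 8 where the macro is defined
use_env = {"n": 3}  # n = 3 where the macro is used

static = ToyMacro(["x"], "x + n", def_env)                   # like @S.macro
dynamic = ToyMacro(["x"], "x + n", def_env, hygienic=False)  # like @S.macro(hygienic=False)

print(static.expand(use_env, 1))   # 9: n resolved to 8
print(dynamic.expand(use_env, 1))  # 4: n resolved to 3
```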
Macro Symbol Resolution
In the following examples, the symbol n is not a macro parameter but is used in the macro body.
Let’s look at the two cases:
Static Capture: Symbol resolved to the value at the time of macro definition
In this case, hygienic is set to True by default. The symbol n is resolved at the time of macro definition, when its value is 8.

n = 8

@S.macro
def add_func(inp, out, idx):
    out[idx] = inp[idx] + n

@S.prim_func
def func(a: S.ptr("i32", "global"), b: S.ptr("i32", "global"), n: S.i32):
    if n > 0:
        add_func(a, b, 0)
    else:
        add_func(a, b, 1)
The generated code is:
__kernel void func(__global int* a, __global int* b, int n) {
  if (0 < n) {
    b[0] = (a[0] + 8);
  } else {
    b[1] = (a[1] + 8);
  }
}
Dynamic Capture: Symbol resolved to the value at the time of macro use
In this case, you must explicitly set hygienic=False. The symbol n is then resolved at the time of macro use, where it refers to the PrimFunc parameter n.

n = 8

@S.macro(hygienic=False)
def add_func(inp, out, idx):
    out[idx] = inp[idx] + n

@S.prim_func
def func(a: S.ptr("i32", "global"), b: S.ptr("i32", "global"), n: S.i32):
    if n > 0:
        add_func(a, b, 0)
    else:
        add_func(a, b, 1)
The generated code will be:
__kernel void func(__global int* a, __global int* b, int n) {
  if (0 < n) {
    b[0] = (a[0] + n);
  } else {
    b[1] = (a[1] + n);
  }
}
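To see how the two generated kernels behave differently at run time, they can be transcribed into plain Python and run on sample data. This is only a reading aid: func_static and func_dynamic are hypothetical hand transcriptions of the two C outputs, with Python lists standing in for the global buffers.

```python
def func_static(a, b, n):
    # hygienic=True: n was captured as the constant 8 when the macro was defined.
    if 0 < n:
        b[0] = a[0] + 8
    else:
        b[1] = a[1] + 8


def func_dynamic(a, b, n):
    # hygienic=False: n refers to the kernel's own runtime parameter.
    if 0 < n:
        b[0] = a[0] + n
    else:
        b[1] = a[1] + n


a = [10, 20]
b_static, b_dynamic = [0, 0], [0, 0]
func_static(a, b_static, 5)
func_dynamic(a, b_dynamic, 5)
print(b_static, b_dynamic)  # [18, 0] [15, 0]
```

With n = 5 the results differ because only the dynamic version adds the runtime value of n; the static version always adds 8.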
Macro Example
Here we use a macro in an example of the concat operator. The concat operator takes 4 inputs of shape (8) and concatenates them into an output of shape (4*8).
c = 8
n = 4
dtype = "int8"


@S.macro(hygienic=False)
def body(inp):
    S.dma_copy(lsram, inp, c)


@S.prim_func
def concat(
    a1: S.ptr(dtype, "global"),
    a2: S.ptr(dtype, "global"),
    a3: S.ptr(dtype, "global"),
    a4: S.ptr(dtype, "global"),
    out: S.ptr(dtype, "global"),
):
    lsram = S.alloc_buffer(dtype=dtype, shape=[c], scope="lsram")
    for tec_i in S.tec_range(0, 4):
        if tec_i == 0:
            body(a1)
        if tec_i == 1:
            body(a2)
        if tec_i == 2:
            body(a3)
        if tec_i == 3:
            body(a4)
        S.dma_copy(out + tec_i * c, lsram, c)
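As a reference for what this kernel computes, the following plain-Python sketch reproduces the concat semantics with one loop iteration per TEC. It is a hypothetical model, not AIPU code: the TECs are simulated one after another, each staging its input in a fresh Python list that stands in for lsram, and DMA timing and memory scopes are not modeled.

```python
C = 8  # elements per input, matching c = 8 in the example
N = 4  # number of inputs (one per TEC), matching n = 4


def concat_reference(inputs):
    """Hypothetical reference for the concat kernel: out = a1 ++ a2 ++ a3 ++ a4."""
    out = [0] * (N * C)
    for tec_i in range(N):  # simulate the TECs sequentially
        lsram = list(inputs[tec_i][:C])  # like body(a_k): stage c elements in lsram
        out[tec_i * C:(tec_i + 1) * C] = lsram  # like S.dma_copy(out + tec_i * c, lsram, c)
    return out


inputs = [[k * 10 + j for j in range(C)] for k in range(N)]
out = concat_reference(inputs)
print(len(out), out[:3], out[8:11])  # 32 [0, 1, 2] [10, 11, 12]
```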
In this example, the macro body uses 3 symbols:
inp: a macro parameter
c: defined as the constant c = 8
lsram: not defined at the time of macro definition
If we use @S.macro, with hygienic set to True by default, these symbols are resolved at the time of macro definition. However, lsram is not defined at that point, so you will get the error:
error: Undefined variable: lsram
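The error comes from the definition-time lookup. The sketch below, built on Python's standard ast module, mimics that check: it collects the names used in the macro body, removes the parameters, and looks the rest up in the environment of the definition site. The helper check_hygienic is hypothetical and is not how S.macro actually reports the error; the dictionary def_env is an assumed stand-in for the module-level names S, c, n and dtype.

```python
import ast


def free_names(body_src, params):
    """Names referenced in the macro body that are not macro parameters."""
    tree = ast.parse(body_src)
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    return used - set(params)


def check_hygienic(body_src, params, def_env):
    """Hypothetical check: every free name must already exist at the definition site."""
    names = free_names(body_src, params)
    for name in sorted(names):
        if name not in def_env:
            raise NameError(f"Undefined variable: {name}")
    return names


# Module-level names visible where body() is defined (values are placeholders).
def_env = {"S": object(), "c": 8, "n": 4, "dtype": "int8"}
body_src = "S.dma_copy(lsram, inp, c)"

try:
    check_hygienic(body_src, ["inp"], def_env)
except NameError as err:
    print(err)  # Undefined variable: lsram

# Passing lsram as a parameter removes it from the free names.
print(sorted(check_hygienic(body_src, ["inp", "lsram"], def_env)))  # ['S', 'c']
```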
You can solve this problem in two ways:
Use @S.macro(hygienic=False). The symbol lsram is then resolved at the time of macro use, where it has been defined by this line:
lsram = S.alloc_buffer(dtype=dtype, shape=[c], scope="lsram")
The generated code will be:
__kernel void concat(__global char* a1, __global char* a2, __global char* a3, __global char* a4, __global char* out) {
  int tid = get_local_id(0);
  __lsram char lsram[8];
  if (tid == 0) {
    DMA1D(a1, lsram, 8, 1, 0);
  }
  if (tid == 1) {
    DMA1D(a2, lsram, 8, 1, 0);
  }
  if (tid == 2) {
    DMA1D(a3, lsram, 8, 1, 0);
  }
  if (tid == 3) {
    DMA1D(a4, lsram, 8, 1, 0);
  }
  DMA1D((out + (tid * 8)), lsram, 8, 0, 0);
}
Put lsram into the macro parameters and use the default @S.macro (hygienic=True).

c = 8
n = 4
dtype = "int8"


@S.macro
def body(inp, lsram):
    S.dma_copy(lsram, inp, c)


@S.prim_func
def concat(
    a1: S.ptr(dtype, "global"),
    a2: S.ptr(dtype, "global"),
    a3: S.ptr(dtype, "global"),
    a4: S.ptr(dtype, "global"),
    out: S.ptr(dtype, "global"),
):
    lsram = S.alloc_buffer(dtype=dtype, shape=[c], scope="lsram")
    for tec_i in S.tec_range(0, 4):
        if tec_i == 0:
            body(a1, lsram)
        if tec_i == 1:
            body(a2, lsram)
        if tec_i == 2:
            body(a3, lsram)
        if tec_i == 3:
            body(a4, lsram)
        S.dma_copy(out + tec_i * c, lsram, c)
The generated code is:
__kernel void concat(__global char* a1, __global char* a2, __global char* a3, __global char* a4, __global char* out) {
  int tid = get_local_id(0);
  __lsram char lsram[8];
  if (tid == 0) {
    DMA1D(a1, lsram, 8, 1, 0);
  }
  if (tid == 1) {
    DMA1D(a2, lsram, 8, 1, 0);
  }
  if (tid == 2) {
    DMA1D(a3, lsram, 8, 1, 0);
  }
  if (tid == 3) {
    DMA1D(a4, lsram, 8, 1, 0);
  }
  DMA1D((out + (tid * 8)), lsram, 8, 0, 0);
}
In summary, if you pass all the symbols used in the macro body as macro parameters, you can simply use @S.macro. If you want to pass only the varying symbols as parameters, with the other common symbols defined at the place where the macro is used, use @S.macro(hygienic=False).