Python Simulator (PySim)
What is PySim?
PySim (Python Simulator) is an important feature of Compass DSL. In Compass DSL, a kernel function decorated with S.prim_func is directly callable, and every API used in the kernel function has a corresponding Python implementation, so the kernel can be run and debugged in Python.
Why PySim?
When writing operator programs, most bugs are simple logic errors. Debugging them is nevertheless slow: every modification must be recompiled and deployed to the Compass NPU Simulator or a real device before it can run, and this back-and-forth is very time-consuming. In addition, once the program is compiled, the generated OpenCL code has to be debugged with the Compass OpenCL Debugger, which makes debugging even harder.
Therefore, we introduced PySim into Compass DSL. When a program has logic errors, it can be debugged entirely on the Python side without being compiled and deployed each time, which makes debugging much easier.
Implementation Principle of PySim
Using PySim is simple. After writing a kernel function in Compass DSL, first build it with BuildManager to check that it compiles, and then call the function directly in Python.
```python
import numpy as np
from tvm import aipu
from tvm.aipu import script as S

dtype = "int32"
n = 8


@S.prim_func
def func_add(a: S.ptr(dtype, "global"), b: S.ptr(dtype, "global")):
    va = S.vload(a)     # load a vector starting at a
    vb = S.vadd(va, 1)  # add 1 to every element
    S.vstore(vb, b)     # store the result to b


# Build first to check that the kernel compiles.
bm = aipu.tir.BuildManager()
ex = bm.build(func_add)

a = np.array(list(range(n)), dtype=dtype)
aipu_out = np.zeros((n,), dtype=dtype)
func_add(a, aipu_out)  # Run in Python (PySim)
```
The following sections use the code above as an example to describe how PySim is implemented.
One Interface, Two Implementations
To allow a Compass DSL program to be compiled and deployed through BuildManager, while the same code can also be run and debugged through PySim, every interface has two implementations. Taking S.vload as an example:
```python
@register_ir_api
def vload(addr, mask=None, lanes=None, stride=None):
    ...


@register_ir_api
def _py_vload(addr, mask=None, lanes=None, stride=None):
    ...
```
Of the two implementations of vload above, the first is used when the Zhouyi Compass DSL program is compiled, and the second is used when the program is run directly by the Python interpreter. Both implementations are decorated with register_ir_api.
All other interfaces follow the same pattern: each has two implementations, and both are decorated with register_ir_api. When the DSL modules are imported, all interfaces are imported into the current namespace as module members, and register_ir_api is executed for each implementation at that point.
```python
def register_ir_api(func):
    ...
    if func.__module__.startswith("tvm.aipu.script."):
        name = func.__name__
    else:
        # Some IR API functions of TVM Script are created through generator, so
        # their function name is same, e.g., that of T.int32 and T.uint32 is
        # "tvm.script.xxx.func_gen.<locals>.func".
        caller_code = inspect.getframeinfo(inspect.currentframe().f_back).code_context[0]
        name = caller_code.split("=")[0].strip()

    if name[0] == "_":
        assert name[:4] == "_py_", "Decorate wrong function."
        name = name[4:]
        assert _IR_API_NAME_2_IMPLEMENTS[name][1] is None, f"{name} Override happened."
        _IR_API_NAME_2_IMPLEMENTS[name][1] = func
        return func

    assert _IR_API_NAME_2_IMPLEMENTS[name][0] is None, f"{name} Override happened."
    _IR_API_NAME_2_IMPLEMENTS[name][0] = func
    ...
```
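register_ir_api uses the function name to tell the two implementations apart. A name starting with _py_ marks the Python (PySim) implementation: it is stored in slot 1 of the table _IR_API_NAME_2_IMPLEMENTS under the name with the prefix removed, and returned unchanged. Any other name marks the compile-time implementation, which is stored in slot 0. For the compile-time implementation, register_ir_api finally returns a wrapper function, _wrapper, which selects the implementation to run each time the interface is called.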
With both implementations of every interface registered, the remaining task is to make sure that each of the two execution paths, BuildManager and PySim, picks up the correct implementation when the Compass DSL program is executed.
Wrap Kernel Function with PyPrimFunc
In a Compass DSL program, every kernel function is decorated with S.prim_func.
```python
@S.prim_func
def func_add(a: S.ptr(dtype, "global"), b: S.ptr(dtype, "global")):
    va = S.vload(a)
    vb = S.vadd(va, 1)
    S.vstore(vb, b)
```
```python
@functools.wraps(T.prim_func)
def prim_func(func=None, private=False, is_entry=False):  # pylint: disable=unused-argument
    """Simple wrapper of the corresponding API of TVM Script."""

    def _decorator(myf):
        return functools.wraps(myf)(PyPrimFunc(myf, {"private": private}))

    setattr(_decorator, "dispatch_token", "tir")
    return _decorator(func) if func else _decorator
```
The decorator S.prim_func is applied when func_add is defined, not when it is called. As shown above, it returns an instance of PyPrimFunc, PyPrimFunc(myf, {"private": private}), and the original func_add is passed in as myf and stored in the py_func attribute of that instance.
```python
class PyPrimFunc:
    """The simple wrapper of the Python function written by user."""

    def __init__(self, py_func, prim_func_kwargs):
        self.py_func = py_func
        self.prim_func_kwargs = prim_func_kwargs
        self.prim_func = None
        self._param_anns = []
        ...

    def __call__(self, *args):
        ...
```
As a result, the name func_add no longer refers to the plain Python function: it is bound to the PyPrimFunc object returned by S.prim_func.
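The effect of S.prim_func can be reproduced with plain Python. In this sketch, KernelWrapper and prim_func are hypothetical, simplified analogues of PyPrimFunc and S.prim_func (the real __call__ sets up PySimInfo and threads, which are omitted here). It shows that the decorated name is bound to a wrapper object, that the original function survives in py_func, and that functools.wraps copies metadata such as __name__ onto the instance.

```python
import functools


class KernelWrapper:
    """Hypothetical, simplified analogue of PyPrimFunc."""

    def __init__(self, py_func, prim_func_kwargs):
        self.py_func = py_func  # the user's original function
        self.prim_func_kwargs = prim_func_kwargs
        self.prim_func = None   # would be filled in later by a build path

    def __call__(self, *args):
        # A direct call simply runs the original Python function.
        return self.py_func(*args)


def prim_func(func=None, private=False):
    def _decorator(myf):
        # functools.wraps copies __name__, __doc__, etc. onto the instance.
        return functools.wraps(myf)(KernelWrapper(myf, {"private": private}))

    # Supports both @prim_func and @prim_func(private=True).
    return _decorator(func) if func else _decorator


@prim_func
def add_one(xs):
    """Add 1 to every element."""
    return [x + 1 for x in xs]


print(type(add_one).__name__, add_one.__name__)  # KernelWrapper add_one
print(add_one([1, 2, 3]))                        # [2, 3, 4]
```

Returning a callable object rather than a plain function is what gives the framework a place to keep state such as prim_func_kwargs and the parsed prim_func next to the user's code.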
Path to BuildManager
When this kernel function is passed to BuildManager, BuildManager takes py_func out of the PyPrimFunc object in the lower stage, parses it into a TensorIR PrimFunc through the TVM Script parser (also storing the result in the prim_func attribute), and then runs the regular build workflow.
```python
def parse_to_prim_func(func):
    """Parse to TensorIR PrimFunc through TVM Script parser."""
    if not isinstance(func, PyPrimFunc):
        raise RuntimeError(
            f'The function "{func.__module__}.{func.__name__}" must be decorated by "S.prim_func".'
        )
    ret = T.prim_func(func.py_func, **(func.prim_func_kwargs))
    func.prim_func = ret
    return ret
```
Path to PySim
When the kernel function is called directly to run on PySim, execution enters the __call__ method of PyPrimFunc.
```python
def __call__(self, *args):
    ...
    with PySimInfo() as py_sim_info:
        ...
        self.py_func(*[x.copy() if isinstance(x, PyVar) else x for x in args])
    ...
```
Here self.py_func is executed; in this example, it is the original kernel function func_add:
```python
@S.prim_func
def func_add(a: S.ptr(dtype, "global"), b: S.ptr(dtype, "global")):
    va = S.vload(a)
    vb = S.vadd(va, 1)
    S.vstore(vb, b)
```
The call to func_add runs inside a with statement that uses the class PySimInfo as its context manager.
```python
class PySimInfo:
    """Maintain all of the status information when simulate the TVM script in Python."""

    current = None

    def __init__(self, is_multi_thread=True):
        self.local_size = 4
        self.thread_local_data = threading.local()
        self.barrier = threading.Barrier(self.local_size)
        self.cur_shared_buffer = None
        self.is_multi_thread = is_multi_thread
        self._old_current = None

    def __enter__(self):
        self._old_current, PySimInfo.current = PySimInfo.current, self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        PySimInfo.current = self._old_current
```
Inside func_add, every interface call (S.vload, S.vadd, S.vstore) actually calls the _wrapper that register_ir_api returned, as described earlier. Because func_add runs inside the PySimInfo context, PySimInfo.current is not None, so _wrapper fetches the Python implementation of the interface from _IR_API_NAME_2_IMPLEMENTS, runs it, and returns its result. Outside the context, PySimInfo.current is None and _wrapper calls the compile-time implementation instead.
```python
@functools.wraps(func)
def _wrapper(*args, **kwargs):
    if PySimInfo.current is None:
        return _IR_API_NAME_2_IMPLEMENTS[name][0](*args, **kwargs)

    # Execute the PySim version of the IR API.
    tld = PySimInfo.current.thread_local_data
    old_value, tld.is_in_ir_api = tld.is_in_ir_api, True
    ret = _IR_API_NAME_2_IMPLEMENTS[name][1](*args, **kwargs)
    tld.is_in_ir_api = old_value
    return ret

return _wrapper
```
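Because PySimInfo saves the previous value of PySimInfo.current on entry and restores it on exit, contexts can be nested, and the global state is restored even when the kernel raises. The hypothetical SimInfo class below is a trimmed-down copy of that pattern, followed by a check of the nesting behavior.

```python
import threading


class SimInfo:
    """Trimmed-down, hypothetical analogue of PySimInfo."""

    current = None  # innermost active context; a class attribute shared by all threads

    def __init__(self, local_size=4):
        self.local_size = local_size
        self.thread_local_data = threading.local()  # per-thread fields such as the TEC id
        self._old_current = None

    def __enter__(self):
        # Save the enclosing context (possibly None) before installing this one.
        self._old_current, SimInfo.current = SimInfo.current, self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Runs on normal exit and on exceptions, so the old context always comes back.
        SimInfo.current = self._old_current
        return False  # do not suppress the exception, if any


with SimInfo() as outer:
    with SimInfo() as inner:
        assert SimInfo.current is inner
    assert SimInfo.current is outer  # the enclosing context is back
assert SimInfo.current is None
```

Keeping current as a class attribute makes the active context visible to every IR API call without passing it around, while thread_local_data holds the state that must differ per simulated TEC.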
Multi-Thread to Simulate All Circumstances
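Because the Zhouyi NPU architecture runs multiple TECs in parallel, PySim also supports multi-threaded execution, in which each thread simulates one TEC. PySimInfo sets local_size to 4, so four threads named TEC0 to TEC3 run the kernel, and any exception raised in one of them is re-raised in the main thread through a concurrent.futures.Future.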
```python
def __call__(self, *args):
    ...
    with PySimInfo(is_multi_thread=True) as py_sim_info:

        def _run(future, thread_id):
            py_sim_info.thread_local_data.id = thread_id
            py_sim_info.thread_local_data.is_in_ir_api = False
            try:
                self.py_func(*[x.copy() if isinstance(x, PyVar) else x for x in args])
            except BaseException as exc:
                future.set_exception(exc)
            else:
                future.set_result(None)

        futures = []
        for i in range(py_sim_info.local_size):
            future = concurrent.futures.Future()
            threading.Thread(target=_run, name=f"TEC{i}", args=(future, i)).start()
            futures.append(future)

        for future in futures:
            # The exceptions raised in the sub-thread will be re-raised
            # here, so the main thread can catch and handle them.
            future.result()
    ...
```
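The thread-per-TEC scheme above can be tried outside TVM with a hypothetical run_on_tecs helper. Like the original __call__, it starts one thread per simulated TEC and collects outcomes through concurrent.futures.Future objects so that exceptions reach the main thread; unlike the original, it passes the TEC id as an argument instead of storing it in thread-local data, and it returns each kernel's result. The example kernel uses threading.Barrier (PySimInfo also creates one) to show a case that needs real concurrency: every TEC must write its slot before any TEC reads a neighbor's slot.

```python
import concurrent.futures
import threading


def run_on_tecs(kernel, num_tecs=4):
    """Run kernel(tec_id) once per simulated TEC, each in its own thread."""

    def _run(future, tec_id):
        try:
            ret = kernel(tec_id)
        except BaseException as exc:  # hand the error over to the main thread
            future.set_exception(exc)
        else:
            future.set_result(ret)

    futures = []
    for i in range(num_tecs):
        future = concurrent.futures.Future()
        threading.Thread(target=_run, name=f"TEC{i}", args=(future, i)).start()
        futures.append(future)

    # result() blocks until that TEC finishes and re-raises its exception, if any.
    return [future.result() for future in futures]


NUM_TECS = 4
barrier = threading.Barrier(NUM_TECS)
shared = [None] * NUM_TECS


def kernel(tec_id):
    shared[tec_id] = tec_id * 10            # each TEC writes its own slot
    barrier.wait()                          # wait until every TEC has written
    return shared[(tec_id + 1) % NUM_TECS]  # then read the neighbor's slot


print(run_on_tecs(kernel))  # [10, 20, 30, 0]
```

With a plain threading.Barrier like this one, running the same kernel in a sequential loop would block forever at the first barrier.wait(), which illustrates why real threads are needed to exercise synchronization between TECs.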
Single-Thread Loop to Simulate Simple Circumstances (No Synchronization Needed)
To make debugging easier, PySim can also run in a single thread. To enable this mode, set the environment variable: export AIPU_TVM_PYSIM_SINGLE_THREAD=TRUE. Multi-threaded debugging is still relatively complicated, so if a bug can be reproduced with a single thread, debug the program in single-thread mode. In this mode the simulated TECs run one after another in a loop, so some data race issues may not be caught.
```python
# Single-thread mode is used only when AIPU_TVM_PYSIM_SINGLE_THREAD=TRUE is set.
if os.getenv("AIPU_TVM_PYSIM_SINGLE_THREAD", "false").upper() == "TRUE":
    WARN("PySim is running in single thread, some data race issues may can't be caught.")
    with PySimInfo(is_multi_thread=False) as py_sim_info:
        # Simulate each TEC one after another in the current thread.
        for i in range(py_sim_info.local_size):
            py_sim_info.thread_local_data.id = i
            py_sim_info.thread_local_data.is_in_ir_api = False
            self.py_func(*[x.copy() if isinstance(x, PyVar) else x for x in args])
```
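The switch between the two modes can be sketched as follows. The simulate helper is hypothetical; it reads AIPU_TVM_PYSIM_SINGLE_THREAD and, as an assumption of this sketch, treats an unset variable as multi-thread mode. For brevity it does not propagate exceptions from worker threads. In single-thread mode every simulated TEC runs in the calling thread, one after another, which makes the execution order deterministic and easy to step through in a debugger.

```python
import os
import threading


def simulate(kernel, num_tecs=4):
    """Run kernel(tec_id) for every simulated TEC, in one thread or in many."""
    # Assumption of this sketch: an unset variable means multi-thread mode.
    single = os.getenv("AIPU_TVM_PYSIM_SINGLE_THREAD", "false").upper() == "TRUE"
    if single:
        # Single-thread mode: the TECs run one after another in the caller's
        # thread, so the order is deterministic but nothing runs concurrently.
        for tec_id in range(num_tecs):
            kernel(tec_id)
        return
    # Multi-thread mode: one thread per TEC, all running at the same time.
    threads = [
        threading.Thread(target=kernel, args=(i,), name=f"TEC{i}")
        for i in range(num_tecs)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
```

For example, with the variable set to TRUE, a kernel that records threading.current_thread().name sees the caller's thread for every TEC, in order 0 to 3; with it unset or set to anything else, the kernel sees the threads TEC0 to TEC3.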