Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.

[BUGFix] Fix LowerThreadAllReduce Pass for Hopper Arch#165

Merged
LeiWang1999 merged 1 commit into microsoft:main from
LeiWang1999:hopper
Aug 31, 2024
Merged

[BUGFix] Fix LowerThreadAllReduce Pass for Hopper Arch#165
LeiWang1999 merged 1 commit into microsoft:main from
LeiWang1999:hopper

Conversation

@LeiWang1999
Copy link
Contributor

On the Hopper architecture, the generated warp reduction instructions exhibit undefined behavior. Minimal repro:

import tvm
from tvm.script import ir as I
from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((1, 1024), "int8"), B: T.Buffer((1024, 1024), "int8"), D: T.Buffer((1, 1024), "int32")):
    # Repro kernel: D[0, j] = sum_k A[0, k] * B[j, k], accumulated in int32.
    # The reduce axis is partly bound to threadIdx.x below, so lowering must
    # combine per-thread partials across the thread dimension — the code path
    # handled by the LowerThreadAllReduce pass this PR fixes for Hopper.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    A_local = T.alloc_buffer((1, 1024), "int8", scope="local")
    B_local = T.alloc_buffer((1024, 1024), "int8", scope="local")
    C_local = T.alloc_buffer((1, 1024), "int32", scope="local")
    for ax0_0 in T.thread_binding(512, thread="blockIdx.x"):
        for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
            for ax1_0 in range(1):
                # 64 lanes each own a contiguous 16-element slice of the
                # 1024-long reduce axis (64 * 16 = 1024).
                for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                    # Stage this lane's 16-element slice of A into "local"
                    # (per-thread) storage, vectorized over the slice.
                    for ax0 in range(1):
                        for ax1 in T.vectorized(16):
                            with T.block("A_local"):
                                v0 = T.axis.spatial(1, ax0)
                                v1 = T.axis.spatial(1024, ax1_1 * 16 + ax1)
                                T.reads(A[v0, v1])
                                T.writes(A_local[v0, v1])
                                A_local[v0, v1] = A[v0, v1]
                    # Stage the matching 16-element slice of row v0 of B.
                    for ax0 in range(1):
                        for ax1 in T.vectorized(16):
                            with T.block("B_local"):
                                v0 = T.axis.spatial(1024, ax0_0 * 2 + ax0_1 + ax0)
                                v1 = T.axis.spatial(1024, ax1_1 * 16 + ax1)
                                T.reads(B[v0, v1])
                                T.writes(B_local[v0, v1])
                                B_local[v0, v1] = B[v0, v1]
                    # Per-thread partial dot product over the staged slice;
                    # operands are cast to int32 before multiply so the int8
                    # products do not overflow.
                    for ax1_2 in range(16):
                        with T.block("C"):
                            v0 = T.axis.spatial(1024, ax0_0 * 2 + ax0_1)
                            v1 = T.axis.reduce(1024, ax1_0 * 1024 + ax1_1 * 16 + ax1_2)
                            T.reads(A_local[0, v1], B_local[v0, v1])
                            T.writes(C_local[0, v0])
                            with T.init():
                                C_local[0, v0] = 0
                            C_local[0, v0] = C_local[0, v0] + T.Cast("int32", A_local[0, v1]) * T.Cast("int32", B_local[v0, v1])
            # Write the fully reduced element for this (blockIdx.x, threadIdx.y)
            # pair back out to D.
            for ax0, ax1 in T.grid(1, 1):
                with T.block("C_local"):
                    v0 = T.axis.spatial(1, ax0)
                    v1 = T.axis.spatial(1024, ax0_0 * 2 + ax0_1 + ax1)
                    T.reads(C_local[v0, v1])
                    T.writes(D[0, v1])
                    D[0, v1] = T.Cast("int32", C_local[v0, v1])

# Build the kernel for CUDA with async copy enabled, then dump the generated
# device-side source so the lowered warp-reduction code can be inspected.
target = "cuda"
build_config = {
    "tir.use_async_copy": True,
}
with tvm.transform.PassContext(config=build_config):
    rt_mod = tvm.build(main, target=target)

print(rt_mod.imported_modules[0].get_source())

import numpy as np

# Random int8 operands on the GPU; the output accumulates in int32.
device = tvm.cuda(0)
shape_a, shape_b, shape_d = (1, 1024), (1024, 1024), (1, 1024)
a = tvm.nd.array(np.random.randint(-128, 127, shape_a).astype("int8"), device=device)
b = tvm.nd.array(np.random.randint(-128, 127, shape_b).astype("int8"), device=device)
d = tvm.nd.array(np.zeros(shape_d, dtype="int32"), device=device)

rt_mod(a, b, d)
print(d)

# Host-side reference: A (int32) times B transposed, matching the kernel's
# D[0, j] = sum_k A[0, k] * B[j, k] contraction.
ref_d = np.dot(a.numpy().astype("int32"), b.numpy().astype("int32").T)
print(ref_d)
print(np.allclose(d.numpy(), ref_d))

Running the script above fails with:

    CUDA error: an illegal instruction was encountered

This was already fixed in upstream TVM; without the fix it can lead to bugs — see issue #157.

This PR also reverts the earlier changes to the volatile annotation.

@LeiWang1999 LeiWang1999 merged commit b1f5e79 into microsoft:main Aug 31, 2024
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant