diff --git a/.gitmodules b/.gitmodules index 57576c5fe..5980500e4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,7 @@ path = 3rdparty/tvm url = https://github.com/LeiWang1999/tvm branch = bitblas_tl +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git + branch = v3.2.2 diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 000000000..44c704eae --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit 44c704eae85da352d277d6f092f41412772f70e4 diff --git a/3rdparty/tvm b/3rdparty/tvm index 049a8c5f4..d9391a502 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 049a8c5f44d5c911be992f650dba78e8c7a75203 +Subproject commit d9391a502b5544722eb67c4a0c4dff49a3476c06 diff --git a/bitblas/__init__.py b/bitblas/__init__.py index e40f17f3e..ee79bc3c9 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -5,13 +5,19 @@ # installing tvm install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") +install_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") + os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" sys.path.insert(0, install_tvm_path + "/python") develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") +develop_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") + os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include" sys.path.insert(0, develop_tvm_path + "/python") import tvm as tvm # noqa: E402 diff --git a/setup.py b/setup.py index 6954fa798..5fe71db40 100644 --- a/setup.py +++ b/setup.py @@ -239,6 +239,7 @@ def run(self): "3rdparty/tvm/mypy.ini", "3rdparty/tvm/pyproject.toml", "3rdparty/tvm/version.py", + "3rdparty/tvm/src/tl/tl_templates", ] for item in TVM_PREBUILD_ITEMS: source_dir = os.path.join(ROOT_DIR, item) @@ -252,6 +253,22 @@ def run(self): os.makedirs(target_dir) shutil.copy2(source_dir, target_dir) + # Copy CUTLASS to the package directory + CUTLASS_PREBUILD_ITEMS = [ + "3rdparty/cutlass", + ] + for item in CUTLASS_PREBUILD_ITEMS: + source_dir = os.path.join(ROOT_DIR, item) + target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) + if os.path.isdir(source_dir): + self.mkpath(target_dir) + distutils.dir_util.copy_tree(source_dir, target_dir) + else: + target_dir = os.path.dirname(target_dir) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + shutil.copy2(source_dir, target_dir) + class BitBLASSdistCommand(sdist): """Customized setuptools sdist command - includes the pyproject.toml file.""" diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py new file mode 100644 index 000000000..c75e4ccc1 --- /dev/null +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm as tvm +import bitblas.testing +from tvm import tl + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tvm.tl.language as T + + @T.prim_func + def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N), + dtypeC)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + dtypeAccum, + num_stages, + num_threads, + ) + mod, params = tl.lower(program) + mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + mod.assert_allclose(ref_program) + + +def test_gemm_f16f16f16_nn(): + run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def test_gemm_f16f16f32_nn(): + run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128, 32) + + +def test_gemm_bf16bf16f32_nn(): + run_gemm(512, 1024, 768, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32) + + +def test_gemm_f32f32f32_nn(): + run_gemm(512, 1024, 768, False, False, "float32", "float32", "float32", 64, 128, 32) + + +def test_gemm_f64f64f64_nn(): + run_gemm(512, 1024, 768, False, False, "float64", "float64", "float64", 64, 64, 16) + + +def test_gemm_i8i8i32_nn(): + run_gemm(512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64) + + +def test_gemm_f16f16f16_tn(): + run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def test_gemm_f16f16f16_nt(): + run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) + + +def test_gemm_i8i8i32_nt(): + run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64) + + +def test_gemm_i8i8i32_tn(): + run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64) + + +def test_gemm_f64f64f64_nt(): + run_gemm(512, 1024, 768, False, True, "float64", "float64", "float64", 64, 32, 16) + + +def test_gemm_f64f64f64_tn(): + run_gemm(512, 1024, 768, True, False, "float64", "float64", "float64", 64, 32, 16) + + +def test_gemm_f32f32f32_nt(): + run_gemm(512, 1024, 768, False, True, "float32", "float32", "float32", 64, 128, 32) + + +def test_gemm_f32f32f32_tn(): + run_gemm(512, 1024, 768, True, False, "float32", "float32", "float32", 64, 128, 32) + + +def test_pad_aligned_f16f16f16_nn(): + run_gemm(512 - 8, 1024 - 32, 768 - 24, False, False, "float16", "float16", "float16", 128, 256, + 32, 2) + + +def test_pad_f16f16f16_nn(): + run_gemm(512 - 9, 1024 - 7, 768 - 5, False, False, "float16", "float16", "float16", 128, 256, + 32, 2) + + +def test_pad_f16f16f32_nn(): + run_gemm(512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64, + 32) + + +if __name__ == "__main__": + bitblas.testing.main()