Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
d8884e6
Refactor BatchMatMulEmitter and BatchMatMulSelector for improved read…
LeiWang1999 Jul 5, 2024
fc84173
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
02f64de
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
397eee6
disable failure email for ci
LeiWang1999 Jul 5, 2024
20f6ad1
remove email notifications.
LeiWang1999 Jul 6, 2024
b93c394
move relax pass from testing to mlc_llm
LeiWang1999 Jul 6, 2024
ba6a6df
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
257693a
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
9bb7f49
Lint Fix
LeiWang1999 Jul 6, 2024
39e7614
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
93eb5a5
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
aa66a90
bug fix in test
LeiWang1999 Jul 6, 2024
ae14a53
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 6, 2024
79b08e4
lint fix.
LeiWang1999 Jul 6, 2024
86fd036
test cuda i4 kernel
LeiWang1999 Jul 7, 2024
6b73a21
Refactor copyright notice in i4matmul.hpp
LeiWang1999 Jul 7, 2024
0ba90c1
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 7, 2024
086d208
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 7, 2024
47a3abd
refactor test as version below python 3.9 cannot handle int32 overflow.
LeiWang1999 Jul 8, 2024
024b247
format lint for test
LeiWang1999 Jul 8, 2024
bfedeaa
Refactor test_int4b_fp16_convert.py for improved readability and main…
LeiWang1999 Jul 8, 2024
e672a23
remove unused design file
LeiWang1999 Jul 8, 2024
21e5430
move tile device from package to base
LeiWang1999 Jul 8, 2024
fd11940
dummy impl for codegen
LeiWang1999 Jul 8, 2024
9ccfa85
Refactor file structure for ladder_permutate module
LeiWang1999 Jul 8, 2024
7c7d73e
Refactor backend class and fix typos in comments
LeiWang1999 Jul 8, 2024
47d5fc5
Deep refactor Lib related code.
LeiWang1999 Jul 8, 2024
53dd0dd
remove ci pull.
LeiWang1999 Jul 10, 2024
d58ac43
LintFix
LeiWang1999 Jul 10, 2024
37cb07c
refactor builder for whl build
LeiWang1999 Jul 10, 2024
f5b9999
Refactor TIRWrapper.wrap() method to include an assertion for the opt…
LeiWang1999 Jul 11, 2024
fb78244
Refactor lib_generator to set library and source paths
LeiWang1999 Jul 11, 2024
706e227
lint fix
LeiWang1999 Jul 11, 2024
63f5515
BitNet vllm integration
LeiWang1999 Jul 16, 2024
de91c0d
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 16, 2024
b9655fd
chore: update codespell to version 2.3.0
LeiWang1999 Jul 16, 2024
fff385f
Lintfix
LeiWang1999 Jul 16, 2024
72a98e7
Bump version to 0.0.1.dev13
LeiWang1999 Jul 18, 2024
5646ab5
lint fix
LeiWang1999 Jul 18, 2024
b965863
disable fast decoding [u]int4xint8 by default.
LeiWang1999 Jul 21, 2024
1198fc7
optimize from dict design in Hint
LeiWang1999 Jul 21, 2024
014213c
Implement SplitK
LeiWang1999 Jul 21, 2024
e0ca752
bitnet benchmark generation.
LeiWang1999 Jul 21, 2024
81b9cf0
Add benchmark script for BitNet integration
LeiWang1999 Jul 21, 2024
02edc0b
AtomicAdd Support
LeiWang1999 Jul 21, 2024
1a70c2d
LintFix
LeiWang1999 Jul 21, 2024
28d851c
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 21, 2024
c447a95
ci fix when 3rdparty tvm is initialized.
LeiWang1999 Jul 21, 2024
79a001b
bug fix for setup
LeiWang1999 Jul 21, 2024
31813b2
fix a bug in block reduce
LeiWang1999 Jul 21, 2024
78b6a3d
typo fix
LeiWang1999 Jul 21, 2024
9c55218
BUG Fix for block reduce.
LeiWang1999 Jul 22, 2024
1aa8868
Lint fix
LeiWang1999 Jul 22, 2024
22f70bf
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 22, 2024
5f082a5
Refactor block reduce schedule template
LeiWang1999 Jul 22, 2024
b4fb31e
transform branch from bitblas to bitblas_tl
LeiWang1999 Jul 22, 2024
35eaa00
Fix subproject commit reference in 3rdparty/tvm
LeiWang1999 Jul 22, 2024
254dd74
chore: update submodule branch from bitblas to bitblas_tl
LeiWang1999 Jul 22, 2024
31a44aa
force update config.cmake
LeiWang1999 Jul 22, 2024
427800e
Bug fix
LeiWang1999 Jul 22, 2024
96db111
Fix subproject commit reference in 3rdparty/cutlass
LeiWang1999 Jul 22, 2024
38b251a
chore: Add submodule for cutlass library
LeiWang1999 Jul 22, 2024
87d1c5a
update tl cutlass path
LeiWang1999 Jul 22, 2024
6200b1e
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 22, 2024
0ffe0b5
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 22, 2024
8e08e77
format fix
LeiWang1999 Jul 22, 2024
df05a64
Copy CUTLASS to the package directory
LeiWang1999 Jul 22, 2024
4f529c5
Refactor setup.py to include additional TVM header files
LeiWang1999 Jul 22, 2024
d02bbc7
lint fix
LeiWang1999 Jul 23, 2024
cffe3fd
bug fix
LeiWang1999 Jul 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
path = 3rdparty/tvm
url = https://github.com/LeiWang1999/tvm
branch = bitblas_tl
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
branch = v3.2.2
1 change: 1 addition & 0 deletions 3rdparty/cutlass
Submodule cutlass added at 44c704
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 1 files
+1 −1 python/tvm/tl/engine.py
6 changes: 6 additions & 0 deletions bitblas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@

# installing tvm
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
sys.path.insert(0, install_tvm_path + "/python")

develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
develop_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
sys.path.insert(0, develop_tvm_path + "/python")

import tvm as tvm # noqa: E402
Expand Down
17 changes: 17 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def run(self):
"3rdparty/tvm/mypy.ini",
"3rdparty/tvm/pyproject.toml",
"3rdparty/tvm/version.py",
"3rdparty/tvm/src/tl/tl_templates",
]
for item in TVM_PREBUILD_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
Expand All @@ -252,6 +253,22 @@ def run(self):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)

# Copy CUTLASS to the package directory
CUTLASS_PREBUILD_ITEMS = [
"3rdparty/cutlass",
]
for item in CUTLASS_PREBUILD_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)


class BitBLASSdistCommand(sdist):
"""Customized setuptools sdist command - includes the pyproject.toml file."""
Expand Down
173 changes: 173 additions & 0 deletions testing/python/tilelang/test_tilelang_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas import tvm as tvm
import bitblas.testing
from tvm import tl


def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tvm.tl.language as T

@T.prim_func
def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N),
dtypeC)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def run_gemm(
M,
N,
K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)

def ref_program(A, B):
import torch

if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C

mod.assert_allclose(ref_program)


def test_gemm_f16f16f16_nn():
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)


def test_gemm_f16f16f32_nn():
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128, 32)


def test_gemm_bf16bf16f32_nn():
run_gemm(512, 1024, 768, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)


def test_gemm_f32f32f32_nn():
run_gemm(512, 1024, 768, False, False, "float32", "float32", "float32", 64, 128, 32)


def test_gemm_f64f64f64_nn():
run_gemm(512, 1024, 768, False, False, "float64", "float64", "float64", 64, 64, 16)


def test_gemm_i8i8i32_nn():
run_gemm(512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64)


def test_gemm_f16f16f16_tn():
run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)


def test_gemm_f16f16f16_nt():
run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)


def test_gemm_i8i8i32_nt():
run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64)


def test_gemm_i8i8i32_tn():
run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64)


def test_gemm_f64f64f64_nt():
run_gemm(512, 1024, 768, False, True, "float64", "float64", "float64", 64, 32, 16)


def test_gemm_f64f64f64_tn():
run_gemm(512, 1024, 768, True, False, "float64", "float64", "float64", 64, 32, 16)


def test_gemm_f32f32f32_nt():
run_gemm(512, 1024, 768, False, True, "float32", "float32", "float32", 64, 128, 32)


def test_gemm_f32f32f32_tn():
run_gemm(512, 1024, 768, True, False, "float32", "float32", "float32", 64, 128, 32)


def test_pad_aligned_f16f16f16_nn():
run_gemm(512 - 8, 1024 - 32, 768 - 24, False, False, "float16", "float16", "float16", 128, 256,
32, 2)


def test_pad_f16f16f16_nn():
run_gemm(512 - 9, 1024 - 7, 768 - 5, False, False, "float16", "float16", "float16", 128, 256,
32, 2)


def test_pad_f16f16f32_nn():
run_gemm(512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64,
32)


if __name__ == "__main__":
bitblas.testing.main()