From acb4aa4dce7f7a3470e8e562af7c51c57d514aec Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 4 Oct 2024 05:49:46 +0000 Subject: [PATCH 01/10] Merge TL Update --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index f1ad5c1c5..cd230c5cb 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit f1ad5c1c57c15485d5da1362621f40749ddfa9a1 +Subproject commit cd230c5cb374d7b6b7c51b3f34dbd4c0e598bf65 From 4aa081c4a7bbe876abae8e432f5ef4cc7cf288c2 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 1 Nov 2024 06:50:42 +0000 Subject: [PATCH 02/10] submodule update --- .gitmodules | 2 +- 3rdparty/tvm | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index c8a359670..d5b545545 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "3rdparty/tvm"] path = 3rdparty/tvm url = https://github.com/TileLang/tvm.git - branch = tilelang + branch = upstream [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/TileLang/cutlass diff --git a/3rdparty/tvm b/3rdparty/tvm index 27078affb..e1c5b0897 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 27078affbe26b65d690d505f67178734d5c52629 +Subproject commit e1c5b089737e47a3849afa87df2432c13b633594 From 811e5c7e1354bf2d169aa21736f7b8a23b257ce9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 1 Nov 2024 07:24:21 +0000 Subject: [PATCH 03/10] Re-implement macro with sub function. --- bitblas/tl/macro_generator.py | 552 +++++++++++++++++++++------------- 1 file changed, 335 insertions(+), 217 deletions(-) diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index 0f7adb791..349fa3557 100644 --- a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -33,19 +33,21 @@ class TensorCoreIntrinEmitter(object): "e5m2_float8": "e5m2", } - def __init__(self, - a_dtype="float16", - b_dtype="float16", - accum_dtype="float16", - a_transposed=False, - b_transposed=False, - block_row_warps=2, - block_col_warps=2, - warp_row_tiles=8, - warp_col_tiles=8, - chunk=16, - reduce_k=1, - num_elems_per_byte=1): + def __init__( + self, + a_dtype="float16", + b_dtype="float16", + accum_dtype="float16", + a_transposed=False, + b_transposed=False, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=8, + warp_col_tiles=8, + chunk=16, + reduce_k=1, + num_elems_per_byte=1, + ): self.a_dtype = a_dtype self.b_dtype = b_dtype self.accum_dtype = accum_dtype @@ -59,13 +61,17 @@ def __init__(self, self.chunk = chunk self._initialize_k_dim(a_dtype) self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) - self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) + self._initialize_local_size( + self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE + ) self._initialize_mma_prefix(self.k_dim) self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_cols = warp_col_tiles // self.micro_size_y self.reduce_k = reduce_k - self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + self.threads = ( + self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + ) self.num_elems_per_byte = num_elems_per_byte def _initialize_k_dim(self, a_dtype="float16"): @@ -73,7 +79,9 @@ def _initialize_k_dim(self, a_dtype="float16"): a_dtype = DataType(a_dtype) self.k_dim = 256 // a_dtype.bits - def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + def _initialize_local_size( + self, m_dim=16, n_dim=16, k_dim=16, warp_size=32 + ): self.local_size_a = (m_dim * k_dim) // warp_size self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size @@ -96,130 +104,192 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): self.micro_size_y = n_dim self.micro_size_k = k_dim - @T.macro - def _warp_ldmatrix_a( - inst, - A_local_buf, - A_shared_buf, - ki, - thread_bindings, - rk=0, - ): - stride = A_shared_buf.shape[-1] - tx = thread_bindings % inst.WARP_SIZE - ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps - - for i in T.serial(inst.warp_rows): - T.ptx_ldmatrix( - inst.a_dtype, - T.bool(False), - 4, - ".b16", - A_local_buf.data, - i * inst.local_size_a, - T.address_of(A_shared_buf[ - ty * inst.warp_row_tiles + i * inst.micro_size_x, - rk * inst.chunk + ki * inst.micro_size_k, - ]), - get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, inst.a_transposed), - ) - - @T.macro - def _warp_ldmatrix_b( - inst, - B_local_buf, - B_shared_buf, - ki, - thread_bindings, - rk=0, - ): - stride = B_shared_buf.shape[-1] - tx = thread_bindings % inst.WARP_SIZE - tz = (thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)) % inst.block_col_warps - - for j in T.serial(inst.warp_cols): - # Assign B_shared_elem - ri, rj = tz * inst.warp_col_tiles + j * inst.micro_size_y, rk * inst.chunk + ki * inst.micro_size_k - B_shared_elem = B_shared_buf[ri, rj] - - T.ptx_ldmatrix( - inst.b_dtype, - T.bool(False), # TODO(lei): should be optimized - 4, - ".b16", - B_local_buf.data, - j * inst.local_size_b, - T.address_of(B_shared_elem), - get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, inst.b_transposed), - ) - - @T.macro - def _warp_mma(inst, A_local_buf, B_local_buf, C_local_buf): - for i, j in T.grid(inst.warp_rows, inst.warp_cols): - T.ptx_mma( - inst.accum_dtype, - inst.mma_prefix, - "row", - "col", - inst.a_dtype_abbrv, - inst.b_dtype_abbrv, - inst.accum_dtype_abbrv, - A_local_buf.data, - i * inst.local_size_a, - B_local_buf.data, - j * inst.local_size_b, - C_local_buf.data, - i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out, - T.bool(False), - ) - - T.ptx_mma( - inst.accum_dtype, - inst.mma_prefix, - "row", - "col", - inst.a_dtype_abbrv, - inst.b_dtype_abbrv, - inst.accum_dtype_abbrv, - A_local_buf.data, - i * inst.local_size_a, - B_local_buf.data, - j * inst.local_size_b + lift(inst.local_size_b) // 2, - C_local_buf.data, - i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out + - lift(inst.local_size_out) // 2, - T.bool(False), - ) - - # STS - # MMA Store must be in simulated instead of TVM Intrins - # As TVM Intrins is like a hack that the threadIdx.x should be always - # equal to the warp_size - @T.macro - def _warp_stmatrix(inst, C_local_buf, C_shared_buf, thread_bindings): - tx = thread_bindings % inst.WARP_SIZE - ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps - tz = (thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)) % inst.block_col_warps - for i, j in T.grid(inst.warp_rows, inst.warp_cols): - for local_id_o in T.serial(inst.local_size_out // 2): - for local_id_i in T.vectorized(2): - local_id = local_id_o * 2 + local_id_i - row, col = T.meta_var(mma_store_index_map(tx, local_id)) - C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row, - col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) + - j * inst.local_size_out + local_id] - def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): - return self._warp_ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk) + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + warp_row_tiles = self.warp_row_tiles + warp_rows = self.warp_rows + chunk = self.chunk + micro_size_x = self.micro_size_x + micro_size_k = self.micro_size_k + a_dtype = self.a_dtype + a_transposed = self.a_transposed + local_size_a = self.local_size_a + + @T.macro + def _warp_ldmatrix_a( + A_local_buf, + A_shared_buf, + ki, + thread_bindings, + rk=0, + ): + stride = A_shared_buf.shape[-1] + tx = thread_bindings % WARP_SIZE + ty = (thread_bindings // WARP_SIZE) % block_row_warps + + for i in T.serial(warp_rows): + T.ptx_ldmatrix( + a_dtype, + T.bool(False), + 4, + ".b16", + A_local_buf.data, + i * local_size_a, + T.address_of( + A_shared_buf[ + ty * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ] + ), + get_ldmatrix_offset( + "A", tx, 0, stride, a_dtype, a_transposed + ), + ) + + return _warp_ldmatrix_a( + A_local_buf, A_shared_buf, ki, thread_bindings, rk + ) def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): - return self._warp_ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk) + + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + b_dtype = self.b_dtype + b_transposed = self.b_transposed + + @T.macro + def _warp_ldmatrix_b( + B_local_buf, + B_shared_buf, + ki, + thread_bindings, + rk=0, + ): + stride = B_shared_buf.shape[-1] + tx = thread_bindings % WARP_SIZE + tz = ( + thread_bindings // (WARP_SIZE * block_row_warps) + ) % block_col_warps + + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + tz * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + B_shared_elem = B_shared_buf[ri, rj] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), # TODO(lei): should be optimized + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + get_ldmatrix_offset( + "B", tx, 0, stride, b_dtype, b_transposed + ), + ) + + return _warp_ldmatrix_b( + B_local_buf, B_shared_buf, ki, thread_bindings, rk + ) def mma(self, A_local_buf, B_local_buf, C_local_buf): - return self._warp_mma(self, A_local_buf, B_local_buf, C_local_buf) + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + mma_prefix = self.mma_prefix + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + + j * local_size_out + + lift(local_size_out) // 2, + T.bool(False), + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) def stmatrix(self, C_local_buf, C_shared_buf, thread_bindings): - return self._warp_stmatrix(self, C_local_buf, C_shared_buf, thread_bindings) + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_out = self.local_size_out + + # STS + # MMA Store must be in simulated instead of TVM Intrins + # As TVM Intrins is like a hack that the threadIdx.x should be always + # equal to the warp_size + @T.macro + def _warp_stmatrix(C_local_buf, C_shared_buf, thread_bindings): + tx = thread_bindings % WARP_SIZE + ty = (thread_bindings // WARP_SIZE) % block_row_warps + tz = ( + thread_bindings // (WARP_SIZE * block_row_warps) + ) % block_col_warps + for i, j in T.grid(warp_rows, warp_cols): + for local_id_o in T.serial(local_size_out // 2): + for local_id_i in T.vectorized(2): + local_id = local_id_o * 2 + local_id_i + row, col = T.meta_var(mma_store_index_map(tx, local_id)) + C_shared_buf[ + ty * warp_rows + i, tz * warp_cols + j, row, col + ] = C_local_buf[ + i * (warp_cols * local_size_out) + + j * local_size_out + + local_id + ] + + return _warp_stmatrix(C_local_buf, C_shared_buf, thread_bindings) class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): @@ -264,7 +334,9 @@ def __init__( def _initialize_k_dim(self, a_dtype="float16"): self.k_dim = 256 // DataType(a_dtype).bits - def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + def _initialize_local_size( + self, m_dim=16, n_dim=16, k_dim=16, warp_size=32 + ): self.local_size_a = (m_dim * k_dim) // warp_size self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size @@ -307,91 +379,137 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b): assert transform_kind_b in [0, 3], "Currently only support 0 and 3" - @T.macro - def _warp_ldmatrix_b( - inst, - B_local_buf, - B_shared_buf, - ki, - thread_bindings, - rk=0, - ): - stride = B_shared_buf.shape[-1] - tx = thread_bindings % inst.WARP_SIZE - tz = (thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)) % inst.block_col_warps - - if inst.transform_kind_b < TransformKind.LDMatrixTransform: - for j in T.serial(inst.warp_cols): - # Assign B_shared_elem - ri, rj = tz * inst.warp_col_tiles + j * inst.micro_size_y, rk * inst.chunk + ki * inst.micro_size_k - ni, nj, nii, njj = (ri) // inst.micro_size_y, (rj) // inst.micro_size_k, ( - ri) % inst.micro_size_y, (rj) % inst.micro_size_k - args = (ni, nj, nii, njj) if inst.transform_kind_b > 0 else (ri, rj) - B_shared_elem = B_shared_buf[args] + def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): + @T.macro + def _warp_ldmatrix_b( + inst, + B_local_buf, + B_shared_buf, + ki, + thread_bindings, + rk=0, + ): + WARP_SIZE = inst.WARP_SIZE + block_row_warps = inst.block_row_warps + block_col_warps = inst.block_col_warps + warp_col_tiles = inst.warp_col_tiles + warp_cols = inst.warp_cols + chunk = inst.chunk + micro_size_y = inst.micro_size_y + micro_size_k = inst.micro_size_k + local_size_b = inst.local_size_b + b_dtype = inst.b_dtype + transform_kind_b = inst.transform_kind_b + b_transposed = inst.b_transposed + num_elems_per_byte = inst.num_elems_per_byte + + stride = B_shared_buf.shape[-1] + tx = thread_bindings % WARP_SIZE + tz = ( + thread_bindings // (WARP_SIZE * block_row_warps) + ) % block_col_warps + + if transform_kind_b < TransformKind.LDMatrixTransform: + for j in T.serial(warp_cols): + # Assign B_shared_elem + ri, rj = ( + tz * warp_col_tiles + j * micro_size_y, + rk * chunk + ki * micro_size_k, + ) + ni, nj, nii, njj = ( + (ri) // micro_size_y, + (rj) // micro_size_k, + (ri) % micro_size_y, + (rj) % micro_size_k, + ) + args = ( + (ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj) + ) + B_shared_elem = B_shared_buf[args] + + T.ptx_ldmatrix( + b_dtype, + T.bool(False), # TODO(lei): should be optimized + 4, + ".b16", + B_local_buf.data, + j * local_size_b, + T.address_of(B_shared_elem), + get_ldmatrix_offset( + "B", tx, 0, stride, b_dtype, b_transposed + ), + ) + else: + local_size_dequantize = local_size_b // num_elems_per_byte + for j in T.serial(warp_cols): + for local_id in T.vectorized(local_size_dequantize): + # Assign B_shared_elem + ri, rj = ( + tz * warp_cols + j, + rk * (chunk // micro_size_k) + ki, + ) + rii, rjj = (tx * local_size_dequantize + local_id) // ( + micro_size_k // num_elems_per_byte + ), (tx * local_size_dequantize + local_id) % ( + micro_size_k // num_elems_per_byte + ) + B_local_buf[j * local_size_dequantize + local_id] = ( + B_shared_buf[ri, rj, rii, rjj] + ) + + return _warp_ldmatrix_b( + B_local_buf, B_shared_buf, ki, thread_bindings, rk + ) - T.ptx_ldmatrix( - inst.b_dtype, - T.bool(False), # TODO(lei): should be optimized - 4, - ".b16", + def mma(self, A_local_buf, B_local_buf, C_local_buf): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + mma_prefix = self.mma_prefix + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, B_local_buf.data, - j * inst.local_size_b, - T.address_of(B_shared_elem), - get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, inst.b_transposed), + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), ) - else: - local_size_dequantize = inst.local_size_b // inst.num_elems_per_byte - for j in T.serial(inst.warp_cols): - for local_id in T.vectorized(local_size_dequantize): - # Assign B_shared_elem - ri, rj = tz * inst.warp_cols + j, rk * (inst.chunk // inst.micro_size_k) + ki - rii, rjj = (tx * local_size_dequantize + - local_id) // (inst.micro_size_k // inst.num_elems_per_byte), ( - tx * local_size_dequantize + local_id) % ( - inst.micro_size_k // inst.num_elems_per_byte) - B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, - rjj] - - @T.macro - def _warp_mma(inst, A_local_buf, B_local_buf, C_local_buf): - for i, j in T.grid(inst.warp_rows, inst.warp_cols): - T.ptx_mma( - inst.accum_dtype, - inst.mma_prefix, - "row", - "col", - inst.a_dtype_abbrv, - inst.b_dtype_abbrv, - inst.accum_dtype_abbrv, - A_local_buf.data, - i * inst.local_size_a, - B_local_buf.data, - j * inst.local_size_b, - C_local_buf.data, - i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out, - T.bool(False), - ) - - T.ptx_mma( - inst.accum_dtype, - inst.mma_prefix, - "row", - "col", - inst.a_dtype_abbrv, - inst.b_dtype_abbrv, - inst.accum_dtype_abbrv, - A_local_buf.data, - i * inst.local_size_a, - B_local_buf.data, - j * inst.local_size_b + lift(inst.local_size_b) // 2, - C_local_buf.data, - i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out + - lift(inst.local_size_out) // 2, - T.bool(False), - ) - def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): - return self._warp_ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk) + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + + j * local_size_out + + lift(local_size_out) // 2, + T.bool(False), + ) - def mma(self, A_local_buf, B_local_buf, C_local_buf): - return self._warp_mma(self, A_local_buf, B_local_buf, C_local_buf) + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) From 4a0afc931f9935ca0700fcd63e381bc09a67fe13 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 1 Nov 2024 08:56:33 +0000 Subject: [PATCH 04/10] lint fix --- bitblas/ops/base_scheduler.py | 10 ++ bitblas/tl/macro_generator.py | 125 ++++++------------ bitblas/tl/mma_layout.py | 10 +- bitblas/tl/utils.py | 42 +++--- format.sh | 4 +- requirements-dev.txt | 2 +- requirements-test.txt | 2 +- .../tilelang/test_tilelang_macro_gemm.py | 31 +++-- 8 files changed, 109 insertions(+), 117 deletions(-) diff --git a/bitblas/ops/base_scheduler.py b/bitblas/ops/base_scheduler.py index acdc057d6..f18c98026 100644 --- a/bitblas/ops/base_scheduler.py +++ b/bitblas/ops/base_scheduler.py @@ -70,3 +70,13 @@ def common_header(self): # TODO(lei): For HIP Backend it should be different common_header = "#include \n" return common_header + + +# Decorator to simplify the output of a function +def simplify_prim_func(func: Callable): + + def wrapper(*args, **kwargs): + stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs) + return BaseScheduler.Simplify(stmt) + + return wrapper diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index 349fa3557..63433a52a 100644 --- a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -61,17 +61,13 @@ def __init__( self.chunk = chunk self._initialize_k_dim(a_dtype) self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) - self._initialize_local_size( - self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE - ) + self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) self._initialize_mma_prefix(self.k_dim) self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_cols = warp_col_tiles // self.micro_size_y self.reduce_k = reduce_k - self.threads = ( - self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k - ) + self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k) self.num_elems_per_byte = num_elems_per_byte def _initialize_k_dim(self, a_dtype="float16"): @@ -79,9 +75,7 @@ def _initialize_k_dim(self, a_dtype="float16"): a_dtype = DataType(a_dtype) self.k_dim = 256 // a_dtype.bits - def _initialize_local_size( - self, m_dim=16, n_dim=16, k_dim=16, warp_size=32 - ): + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size @@ -136,20 +130,14 @@ def _warp_ldmatrix_a( ".b16", A_local_buf.data, i * local_size_a, - T.address_of( - A_shared_buf[ - ty * warp_row_tiles + i * micro_size_x, - rk * chunk + ki * micro_size_k, - ] - ), - get_ldmatrix_offset( - "A", tx, 0, stride, a_dtype, a_transposed - ), + T.address_of(A_shared_buf[ + ty * warp_row_tiles + i * micro_size_x, + rk * chunk + ki * micro_size_k, + ]), + get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed), ) - return _warp_ldmatrix_a( - A_local_buf, A_shared_buf, ki, thread_bindings, rk - ) + return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk) def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): @@ -175,9 +163,7 @@ def _warp_ldmatrix_b( ): stride = B_shared_buf.shape[-1] tx = thread_bindings % WARP_SIZE - tz = ( - thread_bindings // (WARP_SIZE * block_row_warps) - ) % block_col_warps + tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps for j in T.serial(warp_cols): # Assign B_shared_elem @@ -195,14 +181,10 @@ def _warp_ldmatrix_b( B_local_buf.data, j * local_size_b, T.address_of(B_shared_elem), - get_ldmatrix_offset( - "B", tx, 0, stride, b_dtype, b_transposed - ), + get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), ) - return _warp_ldmatrix_b( - B_local_buf, B_shared_buf, ki, thread_bindings, rk - ) + return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk) def mma(self, A_local_buf, B_local_buf, C_local_buf): warp_rows = self.warp_rows @@ -249,9 +231,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): B_local_buf.data, j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out - + j * local_size_out - + lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, T.bool(False), ) @@ -273,21 +253,15 @@ def stmatrix(self, C_local_buf, C_shared_buf, thread_bindings): def _warp_stmatrix(C_local_buf, C_shared_buf, thread_bindings): tx = thread_bindings % WARP_SIZE ty = (thread_bindings // WARP_SIZE) % block_row_warps - tz = ( - thread_bindings // (WARP_SIZE * block_row_warps) - ) % block_col_warps + tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps for i, j in T.grid(warp_rows, warp_cols): for local_id_o in T.serial(local_size_out // 2): for local_id_i in T.vectorized(2): local_id = local_id_o * 2 + local_id_i row, col = T.meta_var(mma_store_index_map(tx, local_id)) - C_shared_buf[ - ty * warp_rows + i, tz * warp_cols + j, row, col - ] = C_local_buf[ - i * (warp_cols * local_size_out) - + j * local_size_out - + local_id - ] + C_shared_buf[ty * warp_rows + i, tz * warp_cols + j, row, + col] = C_local_buf[i * (warp_cols * local_size_out) + + j * local_size_out + local_id] return _warp_stmatrix(C_local_buf, C_shared_buf, thread_bindings) @@ -334,9 +308,7 @@ def __init__( def _initialize_k_dim(self, a_dtype="float16"): self.k_dim = 256 // DataType(a_dtype).bits - def _initialize_local_size( - self, m_dim=16, n_dim=16, k_dim=16, warp_size=32 - ): + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size @@ -380,34 +352,31 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b): assert transform_kind_b in [0, 3], "Currently only support 0 and 3" def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): + WARP_SIZE = self.WARP_SIZE + block_row_warps = self.block_row_warps + block_col_warps = self.block_col_warps + warp_col_tiles = self.warp_col_tiles + warp_cols = self.warp_cols + chunk = self.chunk + micro_size_y = self.micro_size_y + micro_size_k = self.micro_size_k + local_size_b = self.local_size_b + b_dtype = self.b_dtype + transform_kind_b = self.transform_kind_b + b_transposed = self.b_transposed + num_elems_per_byte = self.num_elems_per_byte + @T.macro def _warp_ldmatrix_b( - inst, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0, ): - WARP_SIZE = inst.WARP_SIZE - block_row_warps = inst.block_row_warps - block_col_warps = inst.block_col_warps - warp_col_tiles = inst.warp_col_tiles - warp_cols = inst.warp_cols - chunk = inst.chunk - micro_size_y = inst.micro_size_y - micro_size_k = inst.micro_size_k - local_size_b = inst.local_size_b - b_dtype = inst.b_dtype - transform_kind_b = inst.transform_kind_b - b_transposed = inst.b_transposed - num_elems_per_byte = inst.num_elems_per_byte - stride = B_shared_buf.shape[-1] tx = thread_bindings % WARP_SIZE - tz = ( - thread_bindings // (WARP_SIZE * block_row_warps) - ) % block_col_warps + tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps if transform_kind_b < TransformKind.LDMatrixTransform: for j in T.serial(warp_cols): @@ -422,9 +391,7 @@ def _warp_ldmatrix_b( (ri) % micro_size_y, (rj) % micro_size_k, ) - args = ( - (ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj) - ) + args = ((ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj)) B_shared_elem = B_shared_buf[args] T.ptx_ldmatrix( @@ -435,9 +402,7 @@ def _warp_ldmatrix_b( B_local_buf.data, j * local_size_b, T.address_of(B_shared_elem), - get_ldmatrix_offset( - "B", tx, 0, stride, b_dtype, b_transposed - ), + get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed), ) else: local_size_dequantize = local_size_b // num_elems_per_byte @@ -448,18 +413,14 @@ def _warp_ldmatrix_b( tz * warp_cols + j, rk * (chunk // micro_size_k) + ki, ) - rii, rjj = (tx * local_size_dequantize + local_id) // ( - micro_size_k // num_elems_per_byte - ), (tx * local_size_dequantize + local_id) % ( - micro_size_k // num_elems_per_byte - ) + rii, rjj = (tx * local_size_dequantize + + local_id) // (micro_size_k // num_elems_per_byte), ( + tx * local_size_dequantize + local_id) % ( + micro_size_k // num_elems_per_byte) B_local_buf[j * local_size_dequantize + local_id] = ( - B_shared_buf[ri, rj, rii, rjj] - ) + B_shared_buf[ri, rj, rii, rjj]) - return _warp_ldmatrix_b( - B_local_buf, B_shared_buf, ki, thread_bindings, rk - ) + return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk) def mma(self, A_local_buf, B_local_buf, C_local_buf): warp_rows = self.warp_rows @@ -506,9 +467,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): B_local_buf.data, j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out - + j * local_size_out - + lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, T.bool(False), ) diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py index 8be21a1d1..719885be5 100644 --- a/bitblas/tl/mma_layout.py +++ b/bitblas/tl/mma_layout.py @@ -14,15 +14,15 @@ def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): return row, col -def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): +def ldmatrix_16x32_to_shared_16x32_layout_a(thread_id, local_id): row = thread_id % 16 - col = local_id + (thread_id // 16) * 16 + col = 16 * (thread_id // 16) + local_id % 16 return row, col -def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): - row = (thread_id // 16) * 8 + (thread_id % 8) - col = local_id + 16 * ((thread_id % 16) // 8) +def ldmatrix_16x32_to_shared_16x32_layout_b(thread_id, local_id): + row = 8 * (thread_id // 16) + (thread_id % 8) + col = 16 * ((thread_id % 16) // 8) + local_id % 16 return row, col diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index 4b8b4cf6e..053dbe4d5 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -8,8 +8,8 @@ from .mma_layout import ( ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, - ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_16x32_to_shared_16x32_layout_a, + ldmatrix_16x32_to_shared_16x32_layout_b, mma_store_32x8_to_shared_16x16_layout, ) @@ -70,28 +70,40 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner) +# the original implementation and insight is from the following code snippet +# 3rdparty/tvm/python/tvm/tir/tensor_intrin/cuda.py#get_ldmatrix_intrin def get_ldmatrix_offset( matrix: Literal["A", "B"], row_idx, col_idx, stride, dtype: Literal["float16", "int8"] = "float16", - transpose: bool = False, + transposed: bool = False, ): assert matrix in ["A", "B"], "matrix should be either A or B" - transform_func = ( - ldmatrix_32x8_to_shared_16x16_layout - if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_b) - transform_func_trans = ( - ldmatrix_trans_32x8_to_shared_16x16_layout - if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_a) - if matrix == "A": - assert not transpose, "A matrix should not be transposed" - new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + dtype_bits = DataType(dtype).bits + if dtype_bits == 16: + transform_func = ldmatrix_32x8_to_shared_16x16_layout + transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout + if transposed: + new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif dtype_bits == 8: + if matrix == "B" and transposed: + transform_func = ldmatrix_16x32_to_shared_16x32_layout_b + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif matrix == "A" and not transposed: + transform_func = ldmatrix_16x32_to_shared_16x32_layout_a + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") else: - new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx) - return new_row_idx * stride + new_col_idx + raise ValueError(f"Unsupported dtype {dtype}") def mma_store_index_map(*args, **kwargs): diff --git a/format.sh b/format.sh index c5e81a1ef..5d3056123 100755 --- a/format.sh +++ b/format.sh @@ -148,7 +148,7 @@ echo 'bitblas codespell: Done' echo 'bitblas ruff: Check Start' # Lint specified files lint() { - ruff "$@" + ruff check "$@" } # Lint files that differ from main branch. Ignores dirs that are not slated @@ -170,7 +170,7 @@ lint_changed() { if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - ruff + ruff check fi } diff --git a/requirements-dev.txt b/requirements-dev.txt index 0b09c0856..de7f9d340 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ yapf==0.40.2 toml==0.10.2 tomli==2.0.1 -ruff==0.1.5 +ruff==0.6.5 codespell==2.3.0 cffi diff --git a/requirements-test.txt b/requirements-test.txt index 13fd3d1af..a06a6dd87 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,7 +2,7 @@ yapf==0.40.2 toml==0.10.2 tomli==2.0.1 -ruff==0.1.5 +ruff==0.6.5 codespell==2.3.0 cffi diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py index 4d1318960..cc4839568 100644 --- a/testing/python/tilelang/test_tilelang_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_macro_gemm.py @@ -14,6 +14,7 @@ TensorCoreIntrinEmitterWithLadderTransform, ) from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 +from bitblas.ops.base_scheduler import simplify_prim_func torch.manual_seed(0) @@ -33,6 +34,7 @@ def transform_func(i, j): return T.Layout(shape, transform_func) +@simplify_prim_func def tl_matmul( M, N, @@ -61,7 +63,8 @@ def tl_matmul( block_col_warps = 1 warp_row_tiles = 16 warp_col_tiles = 16 - chunk = 32 if in_dtype == "float16" else 64 + # chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 shared_scope = "shared.dyn" # Pipeline Stage @@ -84,7 +87,9 @@ def tl_matmul( warp_size = 32 threads = warp_size * (block_row_warps * block_col_warps) - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -113,9 +118,9 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -181,15 +186,18 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) src_code = mod.imported_modules[0].get_source() - # src_code is the generated cuda source assert src_code is not None - A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + if in_dtype == "int8": + A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) @@ -202,7 +210,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): assert latency is not None # Get Reference Result - ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) @@ -873,6 +883,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct def test_assert_tl_matmul(): assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32") def test_assert_tl_matmul_with_block_reduce(): From d2f7fcbaa124ffa7befc5c18711869d392798663 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 1 Nov 2024 09:08:49 +0000 Subject: [PATCH 05/10] Refactor tensor core memory allocation in MatmulFineGrainScheduler - Adjusted the local fragment sizes for tensor core memory allocation in the MatmulFineGrainScheduler class. - Updated the allocation sizes for A_local, B_local, and C_local variables based on the new fragment sizes. - The changes ensure efficient memory utilization and improve performance. Refactor tensor core memory allocation in MatmulDequantizeFineGrainedScheduler - Modified the fragment sizes for tensor core memory allocation in the MatmulDequantizeFineGrainedScheduler class. - Updated the allocation sizes for A_frag, B_frag, and C_frag variables based on the new fragment sizes. - The changes optimize memory usage and enhance the efficiency of the dequantization process. Refactor tensor core memory allocation in MatmulDequantizeWeightPropagationScheduler - Adjusted the fragment sizes for tensor core memory allocation in the MatmulDequantizeWeightPropagationScheduler class. - Updated the allocation sizes for A_frag, B_frag, B_dequantize_frag, and C_frag variables based on the new fragment sizes. - The changes improve memory utilization and optimize the weight propagation process. --- .../tilelang/dense/matmul_tensorcore.py | 10 ++++++---- .../dequantize/finegrained_primitive_tensorcore.py | 10 ++++++---- .../dequantize/ladder_weight_transform_tensorcore.py | 12 +++++++----- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index eea256fd9..13658aab4 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -424,7 +424,9 @@ def apply_config( threads = warp_size * (block_row_warps * block_col_warps) # Calculate local fragment sizes for tensor core - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -459,9 +461,9 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) # Thread-level parallelism for Tensor Cores thread_bindings = T.thread_binding(0, threads, "threadIdx.x") diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py index d755ba2f8..d57951455 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py @@ -231,7 +231,9 @@ def apply_config( block_K = chunk threads = warp_size * (block_row_warps * block_col_warps) - fragement_size = (micro_size_x * micro_size_y) // warp_size + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -318,9 +320,9 @@ def general_dequant_matmul( B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) - A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype) - B_frag = T.alloc_local((warp_cols * fragement_size), in_dtype) - C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype) + A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) B_local = T.alloc_local([local_size_compressed], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py index bb463e59a..7f8920575 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py @@ -71,7 +71,9 @@ def apply_config( block_K = chunk threads = warp_size * (block_row_warps * block_col_warps) - fragement_size = (micro_size_x * micro_size_y) // warp_size + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -173,11 +175,11 @@ def general_dequant_matmul( B_shared = T.alloc_shared(B_shared_shape, storage_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) - A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype) - B_frag = T.alloc_local((warp_cols * fragement_size // num_elems_per_byte), + A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b // num_elems_per_byte), storage_dtype) - B_dequantize_frag = T.alloc_local((warp_cols * fragement_size), in_dtype) - C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype) + B_dequantize_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) tx = T.thread_binding(0, threads, thread="threadIdx.x") From 2af586d874ee27eee28802e7d7c5a93933e7d795 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 1 Nov 2024 15:55:03 +0000 Subject: [PATCH 06/10] Implement int4 tensorcore --- 3rdparty/tvm | 2 +- bitblas/tl/macro_generator.py | 217 +++++++++ bitblas/tl/utils.py | 4 +- .../tilelang/test_tilelang_gemm_s4_mma.py | 425 ++++++++++++++++++ .../tilelang/test_tilelang_macro_gemm.py | 33 +- 5 files changed, 665 insertions(+), 16 deletions(-) create mode 100644 testing/python/tilelang/test_tilelang_gemm_s4_mma.py diff --git a/3rdparty/tvm b/3rdparty/tvm index e1c5b0897..71fe7ce82 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit e1c5b089737e47a3849afa87df2432c13b633594 +Subproject commit 71fe7ce827396b98a3169343c3744e788a82566c diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index 63433a52a..25376f294 100644 --- a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -472,3 +472,220 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): ) return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + +class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): + + def mma(self, A_local_buf, B_local_buf, C_local_buf): + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = "int4" + b_dtype_abbrv = "int4" + accum_dtype = self.accum_dtype + accum_dtype_abbrv = accum_dtype + mma_prefix = "m16n8k32" + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + ''' + A[16, 32], B[16, 32], C[16, 16] + A_local_size -> 16 + B_local_size -> 16 + C_local_size -> 8 + For each m16n8k32 inst + For A: m16k32 consume 16 int4 elements -> 8 A_local_size + For A: n8k32 consume 8 int4 elements -> 4 B_local_size + For C: m16n8 consume 4 int32 elements -> 4 C_local_size + ''' + + # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + + j * local_size_out + + lift(local_size_out) // 2, + T.bool(False), + ) + + # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_a) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_b) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + + j * local_size_out + + lift(local_size_out) // 2, + T.bool(False), + ) + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + + +class INT4TensorCoreIntrinEmitterWithLadderTransform( + TensorCoreIntrinEmitterWithLadderTransform +): + + def mma(self, A_local_buf, B_local_buf, C_local_buf): + + warp_rows = self.warp_rows + warp_cols = self.warp_cols + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = "int4" + b_dtype_abbrv = "int4" + accum_dtype = self.accum_dtype + accum_dtype_abbrv = "int32" + mma_prefix = "m16n8k32" + + @T.macro + def _warp_mma(A_local_buf, B_local_buf, C_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + ''' + A[16, 32], B[16, 32], C[16, 16] + A_local_size -> 16 + B_local_size -> 16 + C_local_size -> 8 + For each m16n8k32 inst + For A: m16k32 consume 16 int4 elements -> 8 A_local_size + For A: n8k32 consume 8 int4 elements -> 4 B_local_size + For C: m16n8 consume 4 int32 elements -> 4 C_local_size + ''' + + # A[0:16, 0:16] * B[0:8, 0:16] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 0:16] * B[8:16, 0:16] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2, + C_local_buf.data, + i * warp_cols * local_size_out + + j * local_size_out + + lift(local_size_out) // 2, + T.bool(False), + ) + + # A[0:16, 16:32] * B[0:8, 16:32] -> C[0:16, 0:8] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_a) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + j * local_size_out, + T.bool(False), + ) + + # A[0:16, 16:32] * B[8:16, 16:32] -> C[0:16, 8:16] + T.ptx_mma( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_local_buf.data, + i * local_size_a + lift(local_size_b) // 2, + B_local_buf.data, + j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, + C_local_buf.data, + i * warp_cols * local_size_out + + j * local_size_out + + lift(local_size_out) // 2, + T.bool(False), + ) + + + return _warp_mma(A_local_buf, B_local_buf, C_local_buf) diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index 053dbe4d5..2c88bec64 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -120,12 +120,12 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]): return micro_size_x, micro_size_y, micro_size_k -def make_swizzle_layout(shared_buf): +def make_swizzle_layout(shared_buf, is_smooth: bool = False): dtype = shared_buf.dtype shape = shared_buf.shape can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: + if is_smooth or not can_swizzle: return T.Layout(shape, lambda *args: args) def transform_func(i, j): diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py new file mode 100644 index 000000000..bda7334eb --- /dev/null +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -0,0 +1,425 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import ( + make_swizzle_layout, +) + +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitter, + INT4TensorCoreIntrinEmitterWithLadderTransform, +) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + K = K // 2 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local( + (warp_rows * warp_cols * local_size_c), accum_dtype + ) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod(compressed_A, compressed_B, C) + print(C) + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( + getattr(torch, accum_dtype) + ) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +@simplify_prim_func +def tl_matmul_weight_only_transform( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + K = K // 2 + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = ( + block_M, + (block_K + pad_factor) if apply_pad_a else block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local( + (warp_rows * warp_cols * local_size_c), accum_dtype + ) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + transform_b = 3 + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + LB = ladder_permutate(compressed_B.cpu()).cuda() + + mod(compressed_A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( + getattr(torch, accum_dtype) + ) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul_weight_only_transform(): + assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32") + + + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py index cc4839568..346888f4d 100644 --- a/testing/python/tilelang/test_tilelang_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_macro_gemm.py @@ -472,7 +472,9 @@ def tl_matmul_with_ladder_weight_only_transform( warp_size = 32 threads = warp_size * (block_row_warps * block_col_warps) - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -501,9 +503,9 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") T.annotate_layout({ @@ -667,7 +669,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( warp_size = 32 threads = warp_size * (block_row_warps * block_col_warps) - local_size = (micro_size_x * micro_size_y) // warp_size + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size warp_rows = warp_row_tiles // micro_size_x warp_cols = warp_col_tiles // micro_size_y @@ -704,10 +708,14 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size), in_dtype) - B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype) - B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local( + (warp_cols * local_size_b // num_elems_per_byte), storage_dtype + ) + B_dequantize_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local( + (warp_rows * warp_cols * local_size_c), accum_dtype + ) reduced_accum_res = T.alloc_local(0, accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") rk = T.thread_binding(0, reduce_k, "threadIdx.y") @@ -765,15 +773,14 @@ def main( ) for j in T.serial(warp_cols): - local_size_b = mma_emitter.local_size_b T.call_extern('handle', 'decode_i4u_to_f16', - T.address_of(B_local[j * local_size_b // num_elems_per_byte]), - T.address_of(B_dequantize_local[j * local_size_b]), 8) + T.address_of(B_local[j * mma_emitter.local_size_b // num_elems_per_byte]), + T.address_of(B_dequantize_local[j * mma_emitter.local_size_b]), 8) mma_emitter.mma(A_local, B_dequantize_local, C_local) if reduce_k > 1: - for n in T.serial(warp_rows * warp_cols * local_size): + for n in T.serial(warp_rows * warp_cols * local_size_c): T.attr( T.comm_reducer(lambda x, y: x + y, [T.float16(0)]), "reduce_scope", From 8f7767bca01e54998292a16edbd15d8f44576d18 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 1 Nov 2024 16:01:31 +0000 Subject: [PATCH 07/10] lint fix --- bitblas/tl/macro_generator.py | 22 ++++--------- .../tilelang/test_tilelang_gemm_s4_mma.py | 31 +++++++------------ .../tilelang/test_tilelang_macro_gemm.py | 16 +++++----- 3 files changed, 24 insertions(+), 45 deletions(-) diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index 25376f294..fd8ec43ae 100644 --- a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -473,6 +473,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter): def mma(self, A_local_buf, B_local_buf, C_local_buf): @@ -533,9 +534,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): B_local_buf.data, j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out - + j * local_size_out - + lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, T.bool(False), ) @@ -571,18 +570,14 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): B_local_buf.data, j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, C_local_buf.data, - i * warp_cols * local_size_out - + j * local_size_out - + lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, T.bool(False), ) return _warp_mma(A_local_buf, B_local_buf, C_local_buf) -class INT4TensorCoreIntrinEmitterWithLadderTransform( - TensorCoreIntrinEmitterWithLadderTransform -): +class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform): def mma(self, A_local_buf, B_local_buf, C_local_buf): @@ -643,9 +638,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): B_local_buf.data, j * local_size_b + lift(local_size_b) // 2, C_local_buf.data, - i * warp_cols * local_size_out - + j * local_size_out - + lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, T.bool(False), ) @@ -681,11 +674,8 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): B_local_buf.data, j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4, C_local_buf.data, - i * warp_cols * local_size_out - + j * local_size_out - + lift(local_size_out) // 2, + i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2, T.bool(False), ) - return _warp_mma(A_local_buf, B_local_buf, C_local_buf) diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py index bda7334eb..37c210b91 100644 --- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -9,8 +9,7 @@ from tvm import tl as TL import tvm.tl.language as T from bitblas.tl.utils import ( - make_swizzle_layout, -) + make_swizzle_layout,) from bitblas.tl.macro_generator import ( INT4TensorCoreIntrinEmitter, @@ -20,6 +19,7 @@ torch.manual_seed(0) + @simplify_prim_func def tl_matmul( M, @@ -61,8 +61,8 @@ def tl_matmul( block_N = block_col_warps * warp_col_tiles block_K = chunk - A_shape = (M, K) # int8 storage represents int4*2 - B_shape = (N, K) # int8 storage represents int4*2 + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K) # int8 storage represents int4*2 A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) C_shared_shape = ( @@ -107,9 +107,7 @@ def main( C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local( - (warp_rows * warp_cols * local_size_c), accum_dtype - ) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -180,7 +178,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None - A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) @@ -196,9 +193,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): assert latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( - getattr(torch, accum_dtype) - ) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) @@ -245,7 +240,7 @@ def tl_matmul_weight_only_transform( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles block_K = chunk - + is_smooth_a = False can_swizzle = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not (is_smooth_a or can_swizzle) @@ -291,6 +286,7 @@ def tl_matmul_weight_only_transform( chunk=chunk, transform_kind_b=transform_b, ) + @T.prim_func def main( A: T.Buffer(A_shape, in_dtype), @@ -304,9 +300,7 @@ def main( C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local( - (warp_rows * warp_cols * local_size_c), accum_dtype - ) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -379,7 +373,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt # src_code is the generated cuda source assert src_code is not None transform_b = 3 - + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) @@ -408,9 +402,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt assert latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( - getattr(torch, accum_dtype) - ) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(C) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) @@ -420,6 +412,5 @@ def test_assert_tl_matmul_weight_only_transform(): assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32") - if __name__ == "__main__": bitblas.testing.main() diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py index 346888f4d..4c4cf8f59 100644 --- a/testing/python/tilelang/test_tilelang_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_macro_gemm.py @@ -709,13 +709,9 @@ def main( B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local( - (warp_cols * local_size_b // num_elems_per_byte), storage_dtype - ) + B_local = T.alloc_local((warp_cols * local_size_b // num_elems_per_byte), storage_dtype) B_dequantize_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local( - (warp_rows * warp_cols * local_size_c), accum_dtype - ) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) reduced_accum_res = T.alloc_local(0, accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") rk = T.thread_binding(0, reduce_k, "threadIdx.y") @@ -773,9 +769,11 @@ def main( ) for j in T.serial(warp_cols): - T.call_extern('handle', 'decode_i4u_to_f16', - T.address_of(B_local[j * mma_emitter.local_size_b // num_elems_per_byte]), - T.address_of(B_dequantize_local[j * mma_emitter.local_size_b]), 8) + T.call_extern( + 'handle', 'decode_i4u_to_f16', + T.address_of(B_local[j * mma_emitter.local_size_b // + num_elems_per_byte]), + T.address_of(B_dequantize_local[j * mma_emitter.local_size_b]), 8) mma_emitter.mma(A_local, B_dequantize_local, C_local) From 85a93085977a03ac5b1fccd7e34810cb12a3aff2 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 2 Nov 2024 09:12:16 +0000 Subject: [PATCH 08/10] support uint2->uint4 fast dequantize --- .../cpp/lop3_type_conversion/CMakeLists.txt | 1 + .../lop3_type_conversion/fast_decoding.hpp | 86 +++++++ .../lowprecision_to_int4.cu | 234 ++++++++++++++++++ 3 files changed, 321 insertions(+) create mode 100644 testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu diff --git a/testing/cpp/lop3_type_conversion/CMakeLists.txt b/testing/cpp/lop3_type_conversion/CMakeLists.txt index 61903faf4..8b104ca47 100644 --- a/testing/cpp/lop3_type_conversion/CMakeLists.txt +++ b/testing/cpp/lop3_type_conversion/CMakeLists.txt @@ -10,3 +10,4 @@ endfunction(ADD_CUDA_TEST_EXECUTABLE) ADD_CUDA_TEST_EXECUTABLE(lowprecision_to_float16) ADD_CUDA_TEST_EXECUTABLE(lowprecision_to_int8) +ADD_CUDA_TEST_EXECUTABLE(lowprecision_to_int4) diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp index 6d5b6335a..fbc7f12c1 100644 --- a/testing/cpp/lop3_type_conversion/fast_decoding.hpp +++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp @@ -797,3 +797,89 @@ __device__ void decode_i1u_to_i8s(T1 *_i1u, T2 *B_local_decode, const int N = 16 { decode_i1b_to_i8s(_i1u, B_local_decode, N); } + + +void general_interleave_int4(int8_t *origin_arr, int8_t *interleaved, const int nbit, size_t size_in_bytes, bool verbose = false) +{ + // For int4 example + // i2s {e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // |-----8b-----||-----8b-----||-----8b-----||-----8b-----| + // 0b00110011 0b00110011 0b00110011 0b00110011 + // interleave {e15,e7,e14,e6,e13,e5,e12,e4,e11,e3,e10,e2,e9,e1,e8,e0} + + size_t size = size_in_bytes / sizeof(int32_t); + int32_t *int32_origin = (int32_t *)origin_arr; + int32_t *int32_interleaved = (int32_t *)interleaved; + + constexpr int bits_stride = 4; + int elems_per_group = bits_stride / nbit; + int mask = (1 << nbit) - 1; + int num_groups = 32 / bits_stride; + + for (int idx = 0; idx < size; ++idx) + { + int32_t current_value = int32_origin[idx]; + int32_t new_value = 0; + for (int i = 0; i < num_groups; ++i) + { + for (int j = 0; j < elems_per_group; ++j) + { + int offset = i * elems_per_group + j; + int shift = (offset % num_groups) * bits_stride + (offset / num_groups) * nbit; + int group_value = (current_value >> (nbit * (i * elems_per_group + j))) & mask; + new_value |= group_value << shift; + if (verbose) + printf("put %d to %d\n", offset, shift); + } + } + if (nbit == 1) + { + throw std::runtime_error("Not implemented"); + } + else + int32_interleaved[idx] = new_value; + } + + // Convert back to int8_t if needed + memcpy(interleaved, int32_interleaved, size * sizeof(int32_t)); +} + + +template +__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16) +{ + uint *i4s = reinterpret_cast(_i4s); + uint *i2b = reinterpret_cast(_i2b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x33333333; // 0xf -> 0b1111 select 0,2,4,6,8,10,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000; + +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i4s[i]) + : "r"(i2b[0] >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + if constexpr (isSigned) + { + // TODO(lei): uint4 sub should be enhanced. + i4s[i] = __vsubss4(i4s[i], MEDIAN_NUM); + } + } +} + +template +__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4u, B_local_decode, N); +} + diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu new file mode 100644 index 000000000..a044c9474 --- /dev/null +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu @@ -0,0 +1,234 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "fast_decoding.hpp" + +#define cudaCheckLastError(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) + exit(code); + } +} + +#define REGISTER_GLOBAL_DEVICE_INVOKER(kernel, function) \ + template \ + __global__ void kernel(Args... args) \ + { \ + function(args...); \ + } + +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2s_to_i4s, decode_i2s_to_i4s) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_i4s, decode_i2u_to_i4s) + +// TEST(DecodeTest, DecodeInt4ToINT8) +// { +// using target_dtype = int8_t; +// constexpr int nbits = 2; +// constexpr int N = 32 / nbits; +// constexpr int QN = N / 8 * nbits; +// constexpr bool isSigned = true; +// constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + +// // create four int8_t values +// int8_t in_data[N] = { +// 0, +// }; +// // breed seed +// srand(0); + +// // random initializations with nbits range +// for (int i = 0; i < N; i++) +// { +// in_data[i] = (rand() % (1 << nbits)) - zero_point; +// } + +// // print input data +// for (int i = 0; i < N; i++) +// { +// printf("i:%d %d %x \n", i, in_data[i], in_data[i]); +// } + +// int8_t *ins = new int8_t[QN]; +// for (int i = 0; i < QN; i++) +// { +// ins[i] = (in_data[i * 4] & 0x3) | ((in_data[i * 4 + 1] & 0x3) << 2) | ((in_data[i * 4 + 2] & 0x3) << 4) | ((in_data[i * 4 + 3] & 0x3) << 6); +// } +// // print input data +// printf("ins \n"); +// for (int i = 0; i < QN; i++) +// { +// printf("i:%d %d %x b: ", i, ins[i], ins[i]); +// for (int j = 7; j >= 0; j--) +// { +// printf("%d", (ins[i] >> j) & 1); +// } +// printf("\n"); +// } +// printf("\n"); +// int8_t *interleaved = new int8_t[QN]; +// general_interleave_int4(ins, interleaved, 2, QN * sizeof(int8_t), true); +// printf("interleaved \n"); +// for (int i = 0; i < QN; i++) +// { +// printf("i:%d %d %x b: ", i, interleaved[i], interleaved[i]); +// for (int j = 7; j >= 0; j--) +// { +// printf("%d", (interleaved[i] >> j) & 1); +// } +// printf("\n"); +// } +// target_dtype *decoded = new target_dtype[N]; +// int8_t *ins_gpu; +// target_dtype *decoded_gpu; + +// cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); +// cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); +// cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); +// cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); +// cudaCheckLastError(cudaDeviceSynchronize()); + +// kernelWrapper_i2s_to_i4s<<>>(ins_gpu, decoded_gpu); +// cudaCheckLastError(cudaDeviceSynchronize()); +// cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); +// cudaCheckLastError(cudaFree(ins_gpu)); +// cudaCheckLastError(cudaFree(decoded_gpu)); +// printf("decoded \n"); +// for (int i = 0; i < (N / 2); i++) +// { +// printf("i %d %d %x \n", i, decoded[i], decoded[i]); +// } +// // output data int8 +// int8_t i8_out[N] = { +// 0, +// }; +// for (int i = 0; i < N; i++) +// { +// i8_out[i] = (decoded[i / 2] >> (4 * (i % 2)) ) & 0xf; +// } +// printf("i8_out \n"); +// for (int i = 0; i < N; i++) +// { +// printf("i %d in_data: %d %x decode_data: %d %x \n", i, in_data[i], in_data[i], i8_out[i], i8_out[i]); +// } +// for (int i = 0; i < (N / 2); i++) +// { +// EXPECT_EQ(in_data[i], int(i8_out[i])); +// } +// free(ins); +// free(interleaved); +// free(decoded); +// } + + +// int32 -> 16 int2 -> 4 int8 +// -> 16 int4 -> 8 int8 +TEST(DecodeTest, DecodeUInt4ToINT8) +{ + using target_dtype = int8_t; + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + // in_data[i] = (i % 2); + // in_data[i] = 1; + } + + // print input data + for (int i = 0; i < N; i++) + { + printf("i:%d %d %x \n", i, in_data[i], in_data[i]); + } + + int8_t *ins = new int8_t[QN]; + for (int i = 0; i < QN; i++) + { + ins[i] = (in_data[i * 4] & 0x3) | ((in_data[i * 4 + 1] & 0x3) << 2) | ((in_data[i * 4 + 2] & 0x3) << 4) | ((in_data[i * 4 + 3] & 0x3) << 6); + } + // print input data + printf("ins \n"); + for (int i = 0; i < QN; i++) + { + printf("i:%d %d %x b: ", i, ins[i], ins[i]); + for (int j = 7; j >= 0; j--) + { + printf("%d", (ins[i] >> j) & 1); + } + printf("\n"); + } + printf("\n"); + int8_t *interleaved = new int8_t[QN]; + general_interleave_int4(ins, interleaved, 2, QN * sizeof(int8_t), true); + printf("interleaved \n"); + for (int i = 0; i < QN; i++) + { + printf("i:%d %d %x b: ", i, interleaved[i], interleaved[i]); + for (int j = 7; j >= 0; j--) + { + printf("%d", (interleaved[i] >> j) & 1); + } + printf("\n"); + } + target_dtype *decoded = new target_dtype[N]; + int8_t *ins_gpu; + target_dtype *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i2u_to_i4s<<>>(ins_gpu, decoded_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + printf("decoded \n"); + for (int i = 0; i < (N / 2); i++) + { + printf("i %d %d %x \n", i, decoded[i], decoded[i]); + } + // output data int8 + int8_t i8_out[N] = { + 0, + }; + for (int i = 0; i < N; i++) + { + i8_out[i] = (decoded[i / 2] >> (4 * (i % 2)) ) & 0xf; + } + printf("i8_out \n"); + for (int i = 0; i < N; i++) + { + printf("i %d in_data: %d %x decode_data: %d %x \n", i, in_data[i], in_data[i], i8_out[i], i8_out[i]); + } + for (int i = 0; i < (N / 2); i++) + { + EXPECT_EQ(in_data[i], int(i8_out[i])); + } + free(ins); + free(interleaved); + free(decoded); +} From 9f9f397e3a3bf040a5535cd4f847efb051f1d468 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 2 Nov 2024 17:16:33 +0000 Subject: [PATCH 09/10] Support int4 tensorcore decoding --- 3rdparty/tvm | 2 +- bitblas/gpu/intrin/lop3.py | 40 +++ .../ops/lop3_permutate/lop3_permutate_impl.py | 2 + bitblas/tl/utils.py | 20 ++ .../BitNet/int4_kernel/tl_int4xint2.py | 276 +++++++++++++++ .../tl_int4xint2_ladder_weight_only.py | 325 ++++++++++++++++++ .../BitNet/int4_kernel/tl_int4xint4.py | 218 ++++++++++++ .../tl_int4xint4_ladder_weight_only.py | 243 +++++++++++++ .../BitNet/int4_kernel/tl_int8xint8.py | 231 +++++++++++++ .../tl_int8xint8_ladder_weight_only.py | 261 ++++++++++++++ .../lop3_type_conversion/fast_decoding.hpp | 5 +- .../lowprecision_to_int4.cu | 1 + 12 files changed, 1621 insertions(+), 3 deletions(-) create mode 100644 integration/BitNet/int4_kernel/tl_int4xint2.py create mode 100644 integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py create mode 100644 integration/BitNet/int4_kernel/tl_int4xint4.py create mode 100644 integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py create mode 100644 integration/BitNet/int4_kernel/tl_int8xint8.py create mode 100644 integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py diff --git a/3rdparty/tvm b/3rdparty/tvm index 71fe7ce82..be013f6d5 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 71fe7ce827396b98a3169343c3744e788a82566c +Subproject commit be013f6d5e623e1787351aac897e270970e33ada diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index 8d60c7651..84ee05034 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -978,6 +978,46 @@ } """ +decode_i2s_to_i4s = r""" +template +__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16) +{ + uint *i4s = reinterpret_cast(_i4s); + uint *i2b = reinterpret_cast(_i2b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x33333333; // 0xf -> 0b1111 select 0,2,4,6,8,10,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000; + +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i4s[i]) + : "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + if constexpr (isSigned) + { + // TODO(lei): uint4 sub should be enhanced. + // 0x03 0x03 0x03 0x03 + i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i]; + } + } +} + +template +__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4u, B_local_decode, N); +} +""" def get_fast_decode_intrin( source_bit=4, diff --git a/bitblas/ops/lop3_permutate/lop3_permutate_impl.py b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py index 07d8f4f0c..94ddd13c6 100644 --- a/bitblas/ops/lop3_permutate/lop3_permutate_impl.py +++ b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py @@ -126,6 +126,8 @@ def interleave_weight_int8_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer(( return interleave_weight_f16_1b elif target_dtype == "int8" and bits == 1: return interleave_weight_int8_1b + elif target_dtype == "int4" and bits == 2: + pass return interleave_weight diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index 2c88bec64..f00a5937f 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -133,3 +133,23 @@ def transform_func(i, j): return [new_warp_i, new_warp_j] return T.Layout(shape, transform_func) + +def index_to_coordinates(index, shape): + ''' + General Implementation of: + vjj = index % (micro_size_k // num_elems_per_byte) + coordinates[-1] = index % shape[-1]; + vii = index // (micro_size_k // num_elems_per_byte) % micro_size_y + index = index // shape[-1]; coordinates[-2] = index % shape[-2]; + vj = index // (micro_size_k // num_elems_per_byte * micro_size_y) % block_K // (micro_size_k // num_elems_per_byte) + index = index // shape[-2]; coordinates[-3] = index % shape[-3]; + vi = index // (micro_size_k // num_elems_per_byte * micro_size_y * (block_K // (micro_size_k // num_elems_per_byte))) % block_N // micro_size_y + index = index // shape[-3]; coordinates[-4] = index % shape[-4]; + ''' + coordinates = [] + dims = len(shape) + for i in range(dims): + coordinates.append(index % shape[dims - i - 1]) + index = index // shape[dims - i - 1] + coordinates.reverse() + return coordinates diff --git a/integration/BitNet/int4_kernel/tl_int4xint2.py b/integration/BitNet/int4_kernel/tl_int4xint2.py new file mode 100644 index 000000000..d40852538 --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int4xint2.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import ( + make_swizzle_layout, + index_to_coordinates +) +from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s + +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitter, +) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + fast_decoding=True, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + K = K // 2 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == "int32": + micro_size_k = 32 + + num_elems_per_byte = 2 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + storage_dtype = "int8" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K // num_elems_per_byte) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, + prelude=decode_i2s_to_i4s) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) + C_frag = T.alloc_local( + (warp_rows * warp_cols * fragement_size_c), accum_dtype + ) + + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_frag) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): + B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = ( + i * threads * local_size_compressed + thread_bindings * local_size_compressed + + v) + vi, vj = index_to_coordinates(index, B_shared_shape) + B_local[v] = B_shared[vi, vj] + + if fast_decoding: + T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]), T.address_of(B_dequantize_local[0]), 32) + else: + for v in T.serial(0, local_size): + int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F + + int4_0 = (int2x2_value >> 0) & 0x03 + int4_1 = (int2x2_value >> 2) & 0x03 + + B_dequantize_local[v] = (int4_1 << 4) | int4_0 + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + thread_bindings * local_size + v + vi, vj = index_to_coordinates(index, B_dequantize_shared_shape) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Perform STMatrix + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) + print(matmul) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + # A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype="int4", + dequantize_bits=2, + storage_dtype="int8", + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ((B[:, 3::4] & 0x03) << 6) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + print(f"{compressed_B=}") + lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() + print(f"{lop3_compressed_B=}") + mod(compressed_A, lop3_compressed_B, C) + print(C) + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( + getattr(torch, accum_dtype) + ) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py new file mode 100644 index 000000000..89b3aec90 --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py @@ -0,0 +1,325 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import make_swizzle_layout, index_to_coordinates +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitterWithLadderTransform, +) +from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + fast_decoding=True, +): + K = K // 2 + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + num_elems_per_byte = 2 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + # shared_scope = "shared" + storage_dtype = "int8" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte) + A_shared_shape = ( + block_M, + (block_K + pad_factor) if apply_pad_a else block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + B_dequantize_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + + vec_load_qb = 16 + if block_N * (block_K) // num_elems_per_byte // threads < vec_load_qb: + vec_load_qb = block_N * (block_K) // num_elems_per_byte // threads + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i2s_to_i4s) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) + C_frag = T.alloc_local( + (warp_rows * warp_cols * fragement_size_c), accum_dtype + ) + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + # B_dequantize_shared: make_swizzle_layout(B_dequantize_shared, is_smooth=True), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_frag) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * vec_load_qb)): + for v in T.vectorized(0, vec_load_qb): + t = thread_bindings + idx = i * threads * vec_load_qb + threads * vec_load_qb + t * vec_load_qb + v + vj, vk, vjj, vkk = index_to_coordinates(idx, B_shared_shape) + B_shared[vj, vk, vjj, + vkk] = B[bx * (block_N // micro_size_y) + vj, + ko * (block_K // micro_size_k) + vk, vjj, vkk] + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = ( + i * threads * local_size_compressed + thread_bindings * local_size_compressed + + v) + vi, vj, vii, vjj = index_to_coordinates(index, B_shared_shape) + B_local[v] = B_shared[vi, vj, vii, vjj] + + if fast_decoding: + # Simulated dequantization + T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]), T.address_of(B_dequantize_local[0]), 32) + else: + for v in T.serial(0, local_size): + int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F + + int4_0 = (int2x2_value >> 0) & 0x03 + int4_1 = (int2x2_value >> 2) & 0x03 + + B_dequantize_local[v] = (int4_1 << 4) | int4_0 + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + thread_bindings * local_size + v + vi, vj, vii, vjj = index_to_coordinates(index, B_dequantize_shared_shape) + B_dequantize_shared[vi, vj, vii, vjj] = B_dequantize_local[v] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Perform STMatrix + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) + print(matmul) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + print(src_code) + assert src_code is not None + transform_b = 3 + + # A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype="int4", + dequantize_bits=2, + storage_dtype="int8", + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() + ladder_shape = compressed_B_ladder.shape + int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) + int2_tensor = torch.zeros(int2_shape, device="cuda", dtype=torch.int8) + for i in range(int2_tensor.shape[-1]): + int2_tensor[..., i] = (compressed_B_ladder[..., 2 * i] & 0x03) | ((compressed_B_ladder[..., 2 * i] >> 4) & 0x03) << 2 | ((compressed_B_ladder[..., 2 * i + 1] & 0x03) << 4) | ((compressed_B_ladder[..., 2 * i + 1] >> 4) << 6) + + raw_tensor_shape = int2_tensor.shape + print(f"{raw_tensor_shape=}") + if fast_decoding: + lop3_compressed_B = lop3_permutate(int2_tensor.cpu()).cuda() + lop3_compressed_B = lop3_compressed_B.view(raw_tensor_shape) + else: + lop3_compressed_B = int2_tensor + + mod(compressed_A, lop3_compressed_B, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( + getattr(torch, accum_dtype) + ) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32", False) diff --git a/integration/BitNet/int4_kernel/tl_int4xint4.py b/integration/BitNet/int4_kernel/tl_int4xint4.py new file mode 100644 index 000000000..3419e0b1f --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int4xint4.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import ( + make_swizzle_layout, +) + +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitter, +) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + K = K // 2 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local( + (warp_rows * warp_cols * local_size_c), accum_dtype + ) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + # print(src_code) + # A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod(compressed_A, compressed_B, C) + print(C) + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( + getattr(torch, accum_dtype) + ) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py new file mode 100644 index 000000000..cb0e85e81 --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import make_swizzle_layout +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitterWithLadderTransform, +) +from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + K = K // 2 + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = ( + block_M, + (block_K + pad_factor) if apply_pad_a else block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local( + (warp_rows * warp_cols * local_size_c), accum_dtype + ) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + # B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + transform_b = 3 + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + LB = ladder_permutate(compressed_B.cpu()).cuda() + + mod(compressed_A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( + getattr(torch, accum_dtype) + ) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int8xint8.py b/integration/BitNet/int4_kernel/tl_int8xint8.py new file mode 100644 index 000000000..a58b7ce22 --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int8xint8.py @@ -0,0 +1,231 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import get_swizzle_layout +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, +) +from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local( + (warp_rows * warp_cols * local_size_c), accum_dtype + ) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + if in_dtype == "int8": + A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( + getattr(torch, accum_dtype) + ) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py new file mode 100644 index 000000000..197513abc --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py @@ -0,0 +1,261 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import get_swizzle_layout +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, +) +from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = ( + block_M, + (block_K + pad_factor) if apply_pad_a else block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local( + (warp_rows * warp_cols * local_size_c), accum_dtype + ) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + transform_b = 3 + + if in_dtype == "int8": + A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + LB = ladder_permutate(B.cpu()).cuda() + + mod(A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( + getattr(torch, accum_dtype) + ) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp index fbc7f12c1..e6f8b2923 100644 --- a/testing/cpp/lop3_type_conversion/fast_decoding.hpp +++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp @@ -862,11 +862,12 @@ __device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16) // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(i4s[i]) - : "r"(i2b[0] >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + : "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); if constexpr (isSigned) { // TODO(lei): uint4 sub should be enhanced. - i4s[i] = __vsubss4(i4s[i], MEDIAN_NUM); + // 0x03 0x03 0x03 0x03 + i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i]; } } } diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu index a044c9474..d39a85dcd 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu @@ -53,6 +53,7 @@ REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_i4s, decode_i2u_to_i4s) // } // // print input data +// printf("in_data \n"); // for (int i = 0; i < N; i++) // { // printf("i:%d %d %x \n", i, in_data[i], in_data[i]); From ff770903e586ab6fac7f62ea77806cca79ed2183 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 2 Nov 2024 17:17:37 +0000 Subject: [PATCH 10/10] lint fix --- bitblas/gpu/intrin/lop3.py | 1 + bitblas/tl/utils.py | 1 + .../BitNet/int4_kernel/tl_int4xint2.py | 53 +++++++++---------- .../tl_int4xint2_ladder_weight_only.py | 50 +++++++++-------- .../BitNet/int4_kernel/tl_int4xint4.py | 21 +++----- .../tl_int4xint4_ladder_weight_only.py | 17 +++--- .../BitNet/int4_kernel/tl_int8xint8.py | 16 ++---- .../tl_int8xint8_ladder_weight_only.py | 16 ++---- 8 files changed, 78 insertions(+), 97 deletions(-) diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index 84ee05034..75f4b1757 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -1019,6 +1019,7 @@ } """ + def get_fast_decode_intrin( source_bit=4, storage_dtype="int8", diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index f00a5937f..18f0d3274 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -134,6 +134,7 @@ def transform_func(i, j): return T.Layout(shape, transform_func) + def index_to_coordinates(index, shape): ''' General Implementation of: diff --git a/integration/BitNet/int4_kernel/tl_int4xint2.py b/integration/BitNet/int4_kernel/tl_int4xint2.py index d40852538..16797501a 100644 --- a/integration/BitNet/int4_kernel/tl_int4xint2.py +++ b/integration/BitNet/int4_kernel/tl_int4xint2.py @@ -8,19 +8,16 @@ from tvm import DataType from tvm import tl as TL import tvm.tl.language as T -from bitblas.tl.utils import ( - make_swizzle_layout, - index_to_coordinates -) +from bitblas.tl.utils import (make_swizzle_layout, index_to_coordinates) from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s from bitblas.tl.macro_generator import ( - INT4TensorCoreIntrinEmitter, -) + INT4TensorCoreIntrinEmitter,) from bitblas.ops.base_scheduler import simplify_prim_func torch.manual_seed(0) + @simplify_prim_func def tl_matmul( M, @@ -61,7 +58,7 @@ def tl_matmul( chunk = 32 if in_dtype == "float16" else 64 shared_scope = "shared.dyn" storage_dtype = "int8" - + # Pipeline Stage stage = 2 @@ -69,8 +66,8 @@ def tl_matmul( block_N = block_col_warps * warp_col_tiles block_K = chunk - A_shape = (M, K) # int8 storage represents int4*2 - B_shape = (N, K // num_elems_per_byte) # int8 storage represents int4*2 + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K // num_elems_per_byte) # int8 storage represents int4*2 A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) @@ -109,22 +106,24 @@ def main( B: T.Buffer(B_shape, storage_dtype), C: T.Buffer((M, N), out_dtype), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, prelude=decode_i2s_to_i4s) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared( + B_dequantize_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) - C_frag = T.alloc_local( - (warp_rows * warp_cols * fragement_size_c), accum_dtype - ) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) B_local = T.alloc_local([local_size_compressed], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) - + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") T.annotate_layout({ @@ -136,7 +135,7 @@ def main( T.use_swizzle(panel_size=10) T.clear(C_frag) - + for ko in T.Pipelined((K // block_K), num_stages=stage): # Load A into shared memory @@ -148,16 +147,17 @@ def main( B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): index = ( - i * threads * local_size_compressed + thread_bindings * local_size_compressed + - v) + i * threads * local_size_compressed + + thread_bindings * local_size_compressed + v) vi, vj = index_to_coordinates(index, B_shared_shape) B_local[v] = B_shared[vi, vj] - + if fast_decoding: - T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]), T.address_of(B_dequantize_local[0]), 32) + T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), 32) else: for v in T.serial(0, local_size): int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F @@ -166,7 +166,7 @@ def main( int4_1 = (int2x2_value >> 2) & 0x03 B_dequantize_local[v] = (int4_1 << 4) | int4_0 - + for v in T.vectorized(0, local_size): index = i * threads * local_size + thread_bindings * local_size + v vi, vj = index_to_coordinates(index, B_dequantize_shared_shape) @@ -224,7 +224,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) - + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( M=N, N=K, @@ -240,7 +240,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) - compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ((B[:, 3::4] & 0x03) << 6) + compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( + (B[:, 3::4] & 0x03) << 6) mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) print(f"{compressed_B=}") @@ -254,9 +255,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast assert latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( - getattr(torch, accum_dtype) - ) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) diff --git a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py index 89b3aec90..d44717e7f 100644 --- a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py +++ b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py @@ -10,13 +10,13 @@ import tvm.tl.language as T from bitblas.tl.utils import make_swizzle_layout, index_to_coordinates from bitblas.tl.macro_generator import ( - INT4TensorCoreIntrinEmitterWithLadderTransform, -) + INT4TensorCoreIntrinEmitterWithLadderTransform,) from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s from bitblas.ops.base_scheduler import simplify_prim_func torch.manual_seed(0) + @simplify_prim_func def tl_matmul( M, @@ -66,14 +66,15 @@ def tl_matmul( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles block_K = chunk - + is_smooth_a = False can_swizzle = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not (is_smooth_a or can_swizzle) pad_factor = 8 A_shape = (M, K) - B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, + micro_size_k // num_elems_per_byte) A_shared_shape = ( block_M, (block_K + pad_factor) if apply_pad_a else block_K, @@ -118,7 +119,7 @@ def tl_matmul( chunk=chunk, transform_kind_b=transform_b, ) - + vec_load_qb = 16 if block_N * (block_K) // num_elems_per_byte // threads < vec_load_qb: vec_load_qb = block_N * (block_K) // num_elems_per_byte // threads @@ -129,17 +130,20 @@ def main( B: T.Buffer(B_shape, storage_dtype), C: T.Buffer((M, N), out_dtype), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i2s_to_i4s) as (bx, by): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i4s) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) - B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared( + B_dequantize_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) - C_frag = T.alloc_local( - (warp_rows * warp_cols * fragement_size_c), accum_dtype - ) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) B_local = T.alloc_local([local_size_compressed], storage_dtype) B_dequantize_local = T.alloc_local([local_size], in_dtype) @@ -174,17 +178,18 @@ def main( ko * (block_K // micro_size_k) + vk, vjj, vkk] for i in T.serial(block_N * block_K // num_elems_per_byte // - (threads * local_size_compressed)): + (threads * local_size_compressed)): for v in T.vectorized(0, local_size_compressed): index = ( - i * threads * local_size_compressed + thread_bindings * local_size_compressed + - v) + i * threads * local_size_compressed + + thread_bindings * local_size_compressed + v) vi, vj, vii, vjj = index_to_coordinates(index, B_shared_shape) B_local[v] = B_shared[vi, vj, vii, vjj] - + if fast_decoding: - # Simulated dequantization - T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]), T.address_of(B_dequantize_local[0]), 32) + # Simulated dequantization + T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), 32) else: for v in T.serial(0, local_size): int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F @@ -193,7 +198,7 @@ def main( int4_1 = (int2x2_value >> 2) & 0x03 B_dequantize_local[v] = (int4_1 << 4) | int4_0 - + for v in T.vectorized(0, local_size): index = i * threads * local_size + thread_bindings * local_size + v vi, vj, vii, vjj = index_to_coordinates(index, B_dequantize_shared_shape) @@ -248,7 +253,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast print(src_code) assert src_code is not None transform_b = 3 - + # A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype)) # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) @@ -286,7 +291,10 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) int2_tensor = torch.zeros(int2_shape, device="cuda", dtype=torch.int8) for i in range(int2_tensor.shape[-1]): - int2_tensor[..., i] = (compressed_B_ladder[..., 2 * i] & 0x03) | ((compressed_B_ladder[..., 2 * i] >> 4) & 0x03) << 2 | ((compressed_B_ladder[..., 2 * i + 1] & 0x03) << 4) | ((compressed_B_ladder[..., 2 * i + 1] >> 4) << 6) + int2_tensor[..., i] = (compressed_B_ladder[..., 2 * i] & 0x03) | ( + (compressed_B_ladder[..., 2 * i] >> 4) & 0x03) << 2 | ( + (compressed_B_ladder[..., 2 * i + 1] & 0x03) << 4) | ( + (compressed_B_ladder[..., 2 * i + 1] >> 4) << 6) raw_tensor_shape = int2_tensor.shape print(f"{raw_tensor_shape=}") @@ -304,9 +312,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast assert latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( - getattr(torch, accum_dtype) - ) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(C) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) diff --git a/integration/BitNet/int4_kernel/tl_int4xint4.py b/integration/BitNet/int4_kernel/tl_int4xint4.py index 3419e0b1f..5b040db89 100644 --- a/integration/BitNet/int4_kernel/tl_int4xint4.py +++ b/integration/BitNet/int4_kernel/tl_int4xint4.py @@ -4,21 +4,18 @@ import torch import torch.backends from bitblas import tvm as tvm -import bitblas.testing -from tvm import DataType from tvm import tl as TL import tvm.tl.language as T from bitblas.tl.utils import ( - make_swizzle_layout, -) + make_swizzle_layout,) from bitblas.tl.macro_generator import ( - INT4TensorCoreIntrinEmitter, -) + INT4TensorCoreIntrinEmitter,) from bitblas.ops.base_scheduler import simplify_prim_func torch.manual_seed(0) + @simplify_prim_func def tl_matmul( M, @@ -60,8 +57,8 @@ def tl_matmul( block_N = block_col_warps * warp_col_tiles block_K = chunk - A_shape = (M, K) # int8 storage represents int4*2 - B_shape = (N, K) # int8 storage represents int4*2 + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K) # int8 storage represents int4*2 A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) C_shared_shape = ( @@ -106,9 +103,7 @@ def main( C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local( - (warp_rows * warp_cols * local_size_c), accum_dtype - ) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -196,9 +191,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): assert latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( - getattr(torch, accum_dtype) - ) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) diff --git a/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py index cb0e85e81..1603698b2 100644 --- a/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py +++ b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py @@ -10,9 +10,7 @@ import tvm.tl.language as T from bitblas.tl.utils import make_swizzle_layout from bitblas.tl.macro_generator import ( - INT4TensorCoreIntrinEmitterWithLadderTransform, -) -from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + INT4TensorCoreIntrinEmitterWithLadderTransform,) from bitblas.ops.base_scheduler import simplify_prim_func torch.manual_seed(0) @@ -59,7 +57,7 @@ def tl_matmul( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles block_K = chunk - + is_smooth_a = False can_swizzle = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not (is_smooth_a or can_swizzle) @@ -105,6 +103,7 @@ def tl_matmul( chunk=chunk, transform_kind_b=transform_b, ) + @T.prim_func def main( A: T.Buffer(A_shape, in_dtype), @@ -118,9 +117,7 @@ def main( C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local( - (warp_rows * warp_cols * local_size_c), accum_dtype - ) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -193,7 +190,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None transform_b = 3 - + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) @@ -222,9 +219,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): assert latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( - getattr(torch, accum_dtype) - ) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(C) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) diff --git a/integration/BitNet/int4_kernel/tl_int8xint8.py b/integration/BitNet/int4_kernel/tl_int8xint8.py index a58b7ce22..e809c673e 100644 --- a/integration/BitNet/int4_kernel/tl_int8xint8.py +++ b/integration/BitNet/int4_kernel/tl_int8xint8.py @@ -4,16 +4,12 @@ import torch import torch.backends from bitblas import tvm as tvm -import bitblas.testing from tvm import DataType from tvm import tl as TL import tvm.tl.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.macro_generator import ( - TensorCoreIntrinEmitter, - TensorCoreIntrinEmitterWithLadderTransform, -) -from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + TensorCoreIntrinEmitter,) from bitblas.ops.base_scheduler import simplify_prim_func torch.manual_seed(0) @@ -119,9 +115,7 @@ def main( C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local( - (warp_rows * warp_cols * local_size_c), accum_dtype - ) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -204,16 +198,14 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) mod(A, B, C) - + latency = mod.do_bench(mod.func, warmup=25) print(f"Latency: {latency}") # Ensure that the latency is not None assert latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( - getattr(torch, accum_dtype) - ) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(C) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) diff --git a/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py index 197513abc..733441f2f 100644 --- a/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py +++ b/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py @@ -10,10 +10,7 @@ import tvm.tl.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.macro_generator import ( - TensorCoreIntrinEmitter, - TensorCoreIntrinEmitterWithLadderTransform, -) -from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + TensorCoreIntrinEmitterWithLadderTransform,) from bitblas.ops.base_scheduler import simplify_prim_func torch.manual_seed(0) @@ -74,7 +71,7 @@ def tl_matmul( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles block_K = chunk - + is_smooth_a = False can_swizzle = block_K * DataType(in_dtype).bits == 512 apply_pad_a = not (is_smooth_a or can_swizzle) @@ -120,6 +117,7 @@ def tl_matmul( chunk=chunk, transform_kind_b=transform_b, ) + @T.prim_func def main( A: T.Buffer(A_shape, in_dtype), @@ -133,9 +131,7 @@ def main( C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local( - (warp_rows * warp_cols * local_size_c), accum_dtype - ) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) thread_bindings = T.thread_binding(0, threads, "threadIdx.x") @@ -241,9 +237,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): assert latency is not None # Get Reference Result - ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to( - getattr(torch, accum_dtype) - ) + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) print(C) print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)