1 change: 1 addition & 0 deletions README.md
@@ -59,6 +59,7 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and
|:-----------:|:-----------:|:---------------:|:---------------:|:----------------------:|:----------------------:|
| FP16 | FP16 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | FP4_E2M1 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | FP8_E4M3 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | INT8 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | UINT4/INT4 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | UINT2/INT2 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
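For the new FP16 × FP8_E4M3 row above, usage follows the same pattern as the other dtype combinations. The sketch below is illustrative only: the `MatmulConfig`/`Matmul` names, field names, and shapes follow BitBLAS's public Python API as commonly documented, not this diff, and should be treated as assumptions.

```python
# Hypothetical usage sketch: FP16 activations x FP8 E4M3 weights.
# Assumes the bitblas.MatmulConfig / bitblas.Matmul API; not taken from this diff.
import torch
import bitblas

config = bitblas.MatmulConfig(
    M=1, N=4096, K=4096,
    A_dtype="float16",      # activation dtype (left column of the table)
    W_dtype="e4m3_float8",  # FP8 E4M3 weights added by this PR
    accum_dtype="float16",
    out_dtype="float16",
)
matmul = bitblas.Matmul(config=config)

# Weights arrive as torch.float8_e4m3fn; transform_weight repacks the raw bytes.
W = torch.randn(4096, 4096, dtype=torch.float16).cuda().to(torch.float8_e4m3fn)
A = torch.randn(1, 4096, dtype=torch.float16).cuda()
Wq = matmul.transform_weight(W)
C = matmul.forward(A, Wq)
```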
8 changes: 4 additions & 4 deletions python/bitblas/gpu/gemv_dequantize.py
@@ -49,8 +49,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
- conditions.append(
-     weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
+ conditions.append(weight_decode_info["source_format"]["format"] in
+     ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
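As an illustration of what the relaxed check now accepts, a decode-info dictionary for FP8 E4M3 weights could look like the sketch below; only the fields exercised by the visible checks are shown, and real configurations carry additional keys.

```python
# Hypothetical weight_decode_info accepted by the updated check; partial schema,
# limited to the fields visible in this diff.
weight_decode_info = {
    "source_format": {
        "format": "fp_e4m3",  # newly allowed alongside "uint", "int", "fp", "nf", "fp_e5m2"
        "bits": 8,
    },
    "target_format": "float16",
}
```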
@@ -213,8 +213,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
- conditions.append(
-     weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
+ conditions.append(weight_decode_info["source_format"]["format"] in
+     ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
12 changes: 6 additions & 6 deletions python/bitblas/gpu/matmul_mma_dequantize.py
@@ -126,8 +126,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
- conditions.append(
-     weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
+ conditions.append(weight_decode_info["source_format"]["format"] in
+     ["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
@@ -633,8 +633,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
- conditions.append(
-     weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
+ conditions.append(weight_decode_info["source_format"]["format"] in
+     ["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
@@ -1123,8 +1123,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
- conditions.append(
-     weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
+ conditions.append(weight_decode_info["source_format"]["format"] in
+     ["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
89 changes: 30 additions & 59 deletions python/bitblas/ops/general_matmul.py
@@ -8,8 +8,7 @@
from typing import Any, List, Literal, Optional, Tuple, Union
from .operator import Operator, TransformKind
from .impl.matmul_dequantize_impl import (
-     select_implementation as weight_dequantize_implementation,
- )
+     select_implementation as weight_dequantize_implementation,)
from .impl.matmul_impl import select_implementation as consistent_implementation
from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
from bitblas.utils.target_detector import auto_detect_nvidia_target
@@ -110,36 +109,23 @@ def __legalize_dynamic_symbolic(self, M):

def __legalize_propagate(self, propagate):
if isinstance(propagate, bool):
- return (
-     TransformKind.IntraWarpTransform
-     if propagate
-     else TransformKind.NonTransform
- )
+ return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform)
elif isinstance(propagate, int):
return TransformKind(propagate)

return propagate

- def __initialize_propagate(
-     self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]
- ):
+ def __initialize_propagate(self, propagate_a: Optional[TransformKind],
+                            propagate_b: Optional[TransformKind]):
MICRO_KERNEL_SIZE = 16
- if (
-     isinstance(self.M, int)
-     and (self.M % MICRO_KERNEL_SIZE) == 0
-     and (self.K % MICRO_KERNEL_SIZE) == 0
- ):
+ if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and
+         (self.K % MICRO_KERNEL_SIZE) == 0):
object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform)
else:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)

- if (
-     self.M == 1
-     or (self.N % MICRO_KERNEL_SIZE) != 0
-     or (self.K % MICRO_KERNEL_SIZE) != 0
-     or isinstance(self.M, Tuple)
-     or (self.with_zeros and self.zeros_mode == "quantized")
- ):
+ if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or
+         isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")):
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
else:
@@ -164,10 +150,7 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)
- elif (
-     "int" not in self.W_dtype
-     or self.W_dtype == self.A_dtype
- ):
+ elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype):
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
@@ -186,12 +169,8 @@ def __post_init__(self):
object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))

# set propagate_a and propagate_b to default value if it is None
- object.__setattr__(
-     self, "propagate_a", self.__legalize_propagate(self.propagate_a)
- )
- object.__setattr__(
-     self, "propagate_b", self.__legalize_propagate(self.propagate_b)
- )
+ object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
+ object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
@@ -214,10 +193,10 @@ def __post_init__(self):
object.__setattr__(self, "with_zeros", False)

if self.A_dtype == self.W_dtype and self.W_dtype in [
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)

@@ -242,10 +221,9 @@ class Matmul(Operator):
"int1": ("int", 1),
"uint1": ("uint", 1),
"nf4": ("nf", 4),
"fp8_e5m2": ("fp", 8),
"fp4_e2m1": ("fp", 4),
"e4m3_float8": ("fp", 8), # "e4m3_float8" is a trick for "float8_e4m3fn"
"e5m2_float8": ("fp", 8),
"e4m3_float8": ("fp_e4m3", 8), # "e4m3_float8" is a trick for "float8_e4m3fn"
"e5m2_float8": ("fp_e5m2", 8),
}
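The remapped entries above mean the FP8 dtypes now resolve to dedicated source formats, which is what routes them to the new `fp_e4m3` decode path in the schedulers and TIR implementations below. A minimal illustration, using only the names shown in this diff:

```python
# Illustration only: how the trick map resolves an FP8 weight dtype.
source_format, bit = Matmul.BITBLAS_TRICK_DTYPE_MAP["e4m3_float8"]
assert (source_format, bit) == ("fp_e4m3", 8)  # previously ("fp", 8)
```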

def __init__(
@@ -261,9 +239,8 @@ def __init__(
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")
- assert (
-     config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP
- ), f"Unsupported input dtype {config.A_dtype}"
+ assert (config.A_dtype
+         in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]

self.source_format = source_format
@@ -284,8 +261,7 @@ def __init__(
if isinstance(self.M, Tuple):
self.dynamic_range = {"m": self.M}
self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs(
{"opt_shapes": self.dynamic_range}
)
{"opt_shapes": self.dynamic_range})
else:
self.dynamic_range = None

Expand Down Expand Up @@ -394,9 +370,7 @@ def __init__(

def _build_default_module(self, target: Target):
try:
- self.optimized_func = self.apply_default_schedule(
-     self.prim_func_mod, target
- )
+ self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target)
except Exception:
self.optimized_func = None
logger.warning(
@@ -447,9 +421,7 @@ def post_process(self, code: str) -> str:
return code

def retrieve_weight_shape(self):
- return [
-     int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape
- ]
+ return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]

def transform_weight(self, weight, scale=None, zeros=None, bias=None):
"""
@@ -481,18 +453,20 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
if source_format == "int":
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
- maxq = 2 ** (bit - 1)
+ maxq = 2**(bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
elif source_format in ["fp_e5m2", "fp_e4m3"]:
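# Reinterpret the raw FP8 bytes as int8 so that general_compress below packs the
# 8-bit patterns unchanged; no numeric conversion of the weights happens here.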
weight = weight.view(torch.int8)
weight = weight.int()
else:
# For non-integer formats, simply convert weights to integers
weight = weight.int()

np_storage_dtype = getattr(np, self.storage_dtype)

weight = general_compress(
-     weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype
- )
+     weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype)

weight = torch.from_numpy(weight).cuda().contiguous()

@@ -518,24 +492,21 @@ def transform_input(self, input_tensor):
raise ValueError(
f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size."
)
- self.ladder_permutate_a._forward_from_prebuild_lib(
-     input_tensor, self.workspace
- )
+ self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace)
return self.workspace
return input_tensor

def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args = []
args.append(self.transform_input(A))
args.append(W)

if self.lut is not None:
args.append(self.lut)

if output is None:
output = torch.empty(
-     A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device
- )
+     A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device)
if scale is not None:
args.append(scale)
if zeros is not None:
20 changes: 15 additions & 5 deletions python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -11,6 +11,7 @@
_tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert,
_tir_u32_to_f4_to_f16,
_tir_u8_to_f8_e4m3_to_f16,
_tir_packed_to_unsigned_convert_with_zeros,
)

@@ -58,14 +59,17 @@ def qzeros_dequantize(k, n):
dtype=storage_dtype,
)

- Dequantize_qzeros = te.compute(
-     (K // group_size, N),
-     qzeros_dequantize,
-     name="Dequantize_zeros",
- )
+ Dequantize_qzeros = None
+ if with_zeros and zeros_mode == "quantized":
+     Dequantize_qzeros = te.compute(
+         (K // group_size, N),
+         qzeros_dequantize,
+         name="Dequantize_zeros",
+     )

def decode_func(n, k):
if with_zeros and zeros_mode == "quantized":
assert Dequantize_qzeros is not None, "Dequantize_zeros is None"
w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)(
bit,
B[n, k // n_float_per_elem],
@@ -87,6 +91,8 @@ def decode_func(n, k):
elif source_format == "fp":
w = _tir_u32_to_f4_to_f16(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "fp_e4m3":
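# FP8 values occupy a full byte each, so B is indexed directly (no sub-byte packing offset).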
w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
@@ -260,6 +266,8 @@ def decode_func(n, k):
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
@@ -446,6 +454,8 @@ def decode_func(n, k):
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
1 change: 1 addition & 0 deletions python/bitblas/quantization/__init__.py
@@ -5,6 +5,7 @@
_tir_packed_to_signed_convert, # noqa: F401
_tir_packed_to_unsigned_convert, # noqa: F401
_tir_u32_to_f4_to_f16, # noqa: F401
_tir_u8_to_f8_e4m3_to_f16, # noqa: F401
_tir_packed_to_unsigned_convert_with_zeros, # noqa: F401
)

17 changes: 17 additions & 0 deletions python/bitblas/quantization/quantization.py
@@ -139,6 +139,21 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype
return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)


def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16")
offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16"))
e_f16 = ((val << tir.const(7, "int16")) + offset)
return tir.reinterpret("float16", s_f16 | e_f16)


def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
return tir.reinterpret("e5m2_float8", val).astype("float16")
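The two new helpers take different routes: `_tir_u8_to_f8_e4m3_to_f16` rebuilds the fp16 bit pattern manually, since e4m3's 4-bit exponent must be re-biased (bias 7 → bias 15) to fit fp16's layout, while the e5m2 variant can simply reinterpret and cast because e5m2 shares fp16's 5-bit exponent. The snippet below is an illustrative NumPy transcription of the e4m3 bit trick, not part of the library; like the TIR helper, it ignores NaN and subnormal encodings.

```python
import numpy as np

def f8_e4m3_to_f16(byte: int) -> np.float16:
    """Plain-Python sketch of the e4m3 -> fp16 bit trick above (illustration only)."""
    sign = (byte >> 7) << 15          # move the sign bit to fp16 bit 15
    # Shift exponent+mantissa into the fp16 layout, then re-bias the exponent
    # (fp16 bias 15 vs e4m3 bias 7 -> add 8 << 10 = 8192). When the sign bit is
    # set, `byte << 7` leaves a stray bit at position 14, so -8192 both clears
    # it and applies the bias in one step.
    offset = 8192 if sign == 0 else -8192
    bits = sign | ((byte << 7) + offset)
    return np.array([bits & 0xFFFF], dtype=np.uint16).view(np.float16)[0]

assert f8_e4m3_to_f16(0x38) == np.float16(1.0)    # e4m3 +1.0
assert f8_e4m3_to_f16(0xB8) == np.float16(-1.0)   # e4m3 -1.0
assert f8_e4m3_to_f16(0x40) == np.float16(2.0)    # e4m3 +2.0
```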


def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)

Expand Down Expand Up @@ -173,6 +188,7 @@ def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm

return f_convert


def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)

@@ -185,4 +201,5 @@ def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):

return f_convert


# fmt: on