Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions python/bitblas/utils/target_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
"where <target> is one of the available targets can be found in the output of `tools/get_available_targets.py`."
)

# Nvidia produces non-public oem gpu models that are part of drivers but not mapped to correct tvm target
# Remap list to match the oem model name to the closest public model name
NVIDIA_GPU_REMAP = {
"NVIDIA PG506-230": "NVIDIA A100",
"NVIDIA PG506-232": "NVIDIA A100",
}

def get_gpu_model_from_nvidia_smi(gpu_id: int = 0):
"""
Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU.
Expand Down Expand Up @@ -80,11 +87,9 @@ def auto_detect_nvidia_target(gpu_id: int = 0) -> str:
# Get the current GPU model and find the best matching target
gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id)

# TODO: move to a more res-usable device remapping util method
# compat: Nvidia makes several oem (non-public) versions of A100 and perhaps other models that
# do not have clearly defined TVM matching target so we need to manually map them to the correct one.
if gpu_model == "NVIDIA PG506-230":
gpu_model = "NVIDIA A100"
# Compat: remap oem devices to their correct non-oem model names for tvm target
if gpu_model in NVIDIA_GPU_REMAP:
gpu_model = NVIDIA_GPU_REMAP[gpu_model]

target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"
return target