
[common] Grouped gemm update - nvfp4 for blackwell and fp8 blockwise hopper #2971

Open
pggPL wants to merge 22 commits into NVIDIA:main from pggPL:grouped_gemm_nvfp4_and_hopper

Conversation


pggPL (Collaborator) commented May 8, 2026

Description

Adds Hopper (SM90) support to cuBLAS grouped GEMM and enables NVFP4 / FP8 block scaling recipes.

Type of change

  • Documentation change
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change
  • Infra/Build change
  • Code refactoring

Checklist

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 13 commits March 16, 2026 11:36
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use existing nvte_set_grouped_tensor_param with kNVTEGroupedWithGEMMSwizzledScales
instead of the dedicated set/get functions.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add CUBLAS_NVFP4_GROUPED_GEMM_VERSION and CUBLAS_FP8_BLOCK_GROUPED_GEMM_VERSION macros (13.4+)
- Update check_grouped_gemm_requirements to allow SM90 with cuBLAS 13.4+
- Refactor execute_grouped_gemm to use GroupedGemmConfig struct
- Add divisibility-by-128 validation for FP8 block scaling in setup kernel and quantizer
- Support scalar alpha/beta for Hopper (no per-group alpha/beta)
- Expose get_grouped_gemm_setup_workspace_size to PyTorch via pybind
- Update PyTorch tests to run grouped GEMM on Hopper with cuBLAS 13.4+

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
… scaling tests on Hopper

Extend nvte_grouped_gemm_with_discrete_inputA to handle NVFP4 (Float4E2M1)
inputs: accept kFloat4E2M1 dtype, propagate scale_inv pointers, collect
contiguous amax from discrete tensors, and enforce swizzled-scales checks
for NVFP4 alongside MXFP8. Also add GTEST_SKIP for FP8 tensor scaling
grouped GEMM on Hopper since cuBLAS does not support it there.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…M tests

The setup kernel computes per-tensor scale pointers as data_offset /
block_size, which assumes no padding in the scale buffer. This is only
correct when first_dim % 128 == 0 and last_dim % 128 == 0 (MXFP8) or
last_dim % 64 == 0 (NVFP4). Add explicit assertions in
build_grouped_tensor to catch any future test shapes that violate this.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…d_hopper

Conflicts resolved (3 files):

* tests/pytorch/test_numerics.py
  test_grouped_gemm_grouped_tensor: combined skip rules — Hopper (SM90) requires
  cuBLAS 13.4+, Blackwell+ (SM100) requires cuBLAS 13.3+. Kept main's
  use_bias_scale parametrization.

* transformer_engine/pytorch/cpp_extensions/gemm.py
  general_grouped_gemm_for_grouped_tensor: combined HEAD's num_alphabeta logic
  (single scalar on Hopper, per-group on Blackwell+) with main's cached
  _get_fp32_ones_tensor / _get_fp32_zeros_tensor helpers.

* transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
  - validate_grouped_gemm_inputs: kept HEAD's NVFP4 / FP8 block-scaling
    consistency checks, wrapped in main's nullptr-guard / continue-on-no-data
    pattern.
  - GroupedGemmConfig struct retained; added sm_count from main and
    propagated config_.sm_count -> gemm_config.sm_count in all three
    public APIs.
  - kMaxTensorsPerKernel rename to kMaxGroups (= 64) adopted from main.
  - execute_grouped_gemm signature uses GroupedGemmConfig (HEAD); body uses
    config.sm_count for CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET (from main).
  - Dropped HEAD's simple grouped_bias_add_kernel (dead code); kept main's
    advanced grouped_bias_add_kernel + find_tensor_for_row helper.
  - Replaced inline SM/cuBLAS preambles with check_grouped_gemm_requirements()
    calls in nvte_grouped_gemm, nvte_grouped_gemm_with_discrete_inputA, and
    nvte_grouped_gemm_with_discrete_out. The helper supports both
    Hopper (SM90 + cuBLAS 13.4+) and Blackwell+ (SM100 + cuBLAS 13.3+).
  - Kept HEAD's validate_grouped_gemm_inputs(..., use_per_group_alpha_beta)
    signature for proper alpha/beta validation across architectures.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…or swizzle tests

cublaslt_grouped_gemm.cu:
- Fix incorrect handling of NVFP4/MXFP8 columnwise data in
  build_grouped_gemm_multi_inputA_args by adding a swap_dims flag
  consistent with choose_grouped_operand_storage. Use A_sel.trans
  (post-flip) for gemm_config.avg_k so K is selected from the
  correct dim with discrete A_list.

tests/cpp/test_common.{h,cu}:
- Add enforce_grouped_gemm_alignment parameter (default true) to
  build_grouped_tensor; the MXFP8/NVFP4 first/last_dim 128/64
  alignment asserts are only relevant for the grouped GEMM setup
  kernel, so callers that bypass it (swizzle/unswizzle) opt out.

tests/cpp/operator/test_swizzle.cu:
- Pass enforce_grouped_gemm_alignment=false to build_grouped_tensor
  in MXFP8 swizzle/unswizzle/roundtrip tests, which intentionally
  exercise non-padded shapes.

tests/cpp/operator/test_grouped_gemm.cu:
- Sync GPU/cuBLAS skip rules across all 3 sub-tests, add
  cudaDeviceSynchronize() after nvte_multi_tensor_gemm reference for
  defensive sync, and skip NVFP4 + AllDifferent in all 3 sub-tests
  due to a known flaky bug in the nvte_multi_tensor_gemm reference.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…and_hopper

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

# Conflicts:
#	transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
pggPL force-pushed the grouped_gemm_nvfp4_and_hopper branch 2 times, most recently from 335627f to 3f523e7 on May 11, 2026 14:04
pggPL and others added 4 commits May 11, 2026 16:13
Apply the same fix as upstream PR NVIDIA#2954 (MXFP8 unaligned dims) to the
analogous NVFP4 / FP8 block scaling paths in setup_grouped_gemm_kernel.

Background: cuBLAS grouped GEMM expects each expert's scale_inv to live
at a specific offset in the contiguous grouped buffer. The quantizer
allocates each per-expert scale_inv tensor padded to the layout cuBLAS
needs (swizzled 128x4 for MX/NV; ceildiv(., 128) x roundup(., 4) for
block scaling). The setup kernel was computing these offsets as
data_offset / block_size for everything except MXFP8 — silently correct
when dims align to 128, but pointing at the middle of the previous
expert's scale tile when they do not. In MoE forward this is reachable
through variable per-expert token counts.

Add three device helpers mirroring compute_grouped_tensor_mxfp8_scale_inv_offset:
- compute_grouped_tensor_nvfp4_scale_inv_offset
- compute_grouped_tensor_block_1d_scale_inv_offset
- compute_grouped_tensor_block_2d_scale_inv_offset
Each sums the same padded per-tensor sizes the quantizer uses at alloc
time (Float8BlockQuantizer::get_scale_shape, NVFP4Quantizer::get_scale_shape).

NVFP4 columnwise data is set up via use_columnwise(swap_dims=true), so
sel.shape is already pre-transposed for that recipe — the rowwise
formula on (first, last) recovers the colwise alloc. For block scaling
the formula depends on the canonical orientation, so propagate a new
swap_dims field on GroupedOperandSelection and pass effective_rowwise
(sel.rowwise || sel.swap_dims) into the kernel. MXFP8 is invariant
under this change because swap_dims is always false there and its
helper's byte count is invariant under the rowwise flag anyway.

Test: add ShapeCase::kUnalignedAllSame with (M, N, K) = (160, 288, 416)
— all multiples of 32/16 (per-recipe block size) but none multiples of
128, so each expert's scale tile is padded. Exercise it across MXFP8 /
NVFP4 / FP8 block scaling and the three transpose configs that match
the existing parameter grid. Relax build_grouped_tensor's defensive
%128 / %64 alignment assertions to %32 / %16 (block-size only), which
is the actual quantizer requirement now that the offset arithmetic no
longer assumes zero padding.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…st cleanup

Production:
- nvte_grouped_gemm_with_discrete_inputA no longer requires per-expert amax
  buffers to be contiguous. Add `amax_ptrs[kMaxGroups]` to MultiTensorGroupGemmInputArgs
  and read each tensor's amax via indirection in setup_grouped_gemm_kernel
  (mirrors the existing scale_inv_ptrs pattern). The launcher enables the
  NVFP4 alpha computation when amax is available from either source.
- Consolidate four near-identical
  compute_grouped_tensor_{mxfp8,nvfp4,block_1d,block_2d}_scale_inv_offset
  into a single template `compute_grouped_scale_inv_offset<PaddedFn>` and
  collapse the A/B recipe-switch in setup_grouped_gemm_kernel into a local
  `fill_scale_ptr` lambda.

Tests:
- Drop the per-test amax staging workaround in run_grouped_gemm_discrete_in_case
  (no longer needed after the contiguity relax).
- Fix amax management in make_nvfp4_operand: copy values into result's own
  amax buffers instead of aliasing pointers (prevents double-free).
- Extract the three duplicated cuBLAS-version/compute-capability skip blocks
  into a shared `grouped_gemm_skip_reason` helper.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Silences -Wunused-variable (diagnostic #177-D in nvcc).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
pggPL force-pushed the grouped_gemm_nvfp4_and_hopper branch from fcefde1 to ce0e4d2 on May 11, 2026 14:14
pggPL added 2 commits May 11, 2026 16:17
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL force-pushed the grouped_gemm_nvfp4_and_hopper branch from ce0e4d2 to a4df7bd on May 11, 2026 14:20
pggPL and others added 2 commits May 11, 2026 16:44
- nvte_grouped_gemm and nvte_grouped_gemm_with_discrete_out now validate
  per-operand amax for NVFP4 (previously silently dropped the global-scale
  factor when amax was missing). discrete_inputA path also checks B's amax.
- Remove unused ShapeCase::kUnalignedAllSameNVFP4 enum and its comment.
- OperandStorageChoice::swap_dims now defaults to false; rowwise returns
  no longer pass spurious swap_dims=true.
- Unify GroupedGemmSetupWorkspace layout: from_buffers(nullptr, n) returns
  the total byte count, and required_setup_size derives its result from it
  so the layout cannot drift between the two.
- test_common.cu: consolidate the three gather_*_scales lambdas into a
  single gather_scale_inv(bytes_per_elem, get_shape, get_cpu_ptr) helper.
- test_grouped_gemm.cu: extract make_grouped_gemm_ref / make_alpha_beta /
  compare_grouped_d_to_multi helpers; the three run_* variants drop from
  ~1029 to 774 lines with no behavior change.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
pggPL marked this pull request as ready for review May 11, 2026 15:20

pggPL commented May 11, 2026

/te-ci pytorch

pggPL requested a review from vthumbe1503 May 11, 2026 15:23

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR extends the cuBLAS grouped GEMM implementation to support Hopper (SM90) with cuBLAS 13.4+ (previously Blackwell-only), and adds two new quantization recipes — NVFP4 (Float4E2M1 with per-block E4M3 scales) and FP8 1D/2D block scaling. The implementation refactors workspace sizing, operand layout selection, and per-expert scale-offset arithmetic into a unified, recipe-agnostic framework.

  • Hopper support: runtime/compile-time guards route to a single shared scalar alpha/beta on SM90 and per-group arrays on SM100+; the setup kernel handles both branches transparently.
  • NVFP4 recipe: introduces padded_nvfp4_scale_inv_bytes offset arithmetic, swizzled E4M3 scale gathering, and a computed per-group alpha that folds in amax_A × amax_B / (6²×448²) as the global dequantization factor (Blackwell-only).
  • FP8 block scaling: adds padded_block_1d/2d_scale_inv_floats helpers and cuBLAS VEC128_32F/BLK128x128_32F scale-mode setup (Hopper-only, requires split accumulator).

Confidence Score: 3/5

The core GEMM dispatch and scale-offset arithmetic look correct for the tested recipes, but the three grouped GEMM entry points are not consistently guarded: only nvte_grouped_gemm rejects FP16 output with NVFP4 inputs, leaving the two discrete variants able to reach cuBLAS with an unsupported output dtype.

The missing FP4+FP16 output guard is an explicit check the author added to one entry point but not the other two that accept the same input recipes. In practice this would trigger a cuBLAS error rather than silent data corruption, but the inconsistency makes the API surface fragile. The non_tn_fp8_ok override omission for the discrete-A path is benign today but reflects a copy-paste drift that could cause layout mismatches if fp8-block support is extended to newer hardware.

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — the discrete output (nvte_grouped_gemm_with_discrete_out) and discrete input-A (nvte_grouped_gemm_with_discrete_inputA) functions both need the FP4+FP16 output restriction and a consistent non_tn_fp8_ok override.

Important Files Changed

Filename — Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — Core implementation: adds Hopper (SM90) support, NVFP4, and FP8 block scaling to grouped GEMM. Missing FP4+FP16 output guard in discrete GEMM variants; non_tn_fp8_ok is not overridden for fp8_block in the discrete-A path.
tests/cpp/operator/test_grouped_gemm.cu — Adds NVFP4 and FP8 block-scaling test cases; refactors duplicated setup into shared helpers. Duplicates the Hopper version constant locally.
tests/cpp/test_common.cu — Extends build_grouped_tensor with sub-byte (FP4) support, FP8 block scaling, and NVFP4 amax gathering. Adds an enforce_grouped_gemm_alignment flag, correctly used to relax alignment checks for swizzle tests.
transformer_engine/pytorch/cpp_extensions/gemm.py — Removes duplicated Python workspace-size calculation and delegates to the new C++ API; correctly sizes alpha/beta tensors to 1 (Hopper) vs num_tensors (Blackwell+).
transformer_engine/pytorch/csrc/extensions/gemm.cpp — Loosens alpha/beta validation to accept either 1 or num_tensors elements, matching the Hopper (scalar) vs Blackwell+ (per-group) distinction.
tests/pytorch/test_numerics.py — Reorders skip guards so Hopper + cuBLAS 13.4+ can run the grouped GEMM test; logic is correct.
transformer_engine/common/transformer_engine.cpp — Adds an FP8 block-scaling case to CheckGroupedScaleInv, ensuring scale_inv dtype is validated as float32.
transformer_engine/common/common.h — Adds an is_fp8_block_scaling helper covering both 1D and 2D block-scaling modes; consistent with the existing pattern.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    Entry["nvte_grouped_gemm / discrete_out / discrete_inputA"] --> HW["check_grouped_gemm_requirements\n(SM>=90 + cuBLAS 13.3+; SM<100 needs cuBLAS 13.4+)"]
    HW --> SM{"SM >= 100 (Blackwell+)?"}
    SM -- Yes --> PAB["use_per_group_alpha_beta = true\nAlpha/Beta: per-group arrays"]
    SM -- No --> SAB["use_per_group_alpha_beta = false\nAlpha/Beta: single scalar"]
    PAB & SAB --> SEL["select_grouped_operand(A, B)\nchoose_grouped_operand_storage\n(TN forced for MXFP8, NVFP4, FP8-block, tensor-FP8)"]
    SEL --> RECIPE{"Scaling Recipe?"}
    RECIPE -- "BF16/FP16" --> PLAIN["No scale handling"]
    RECIPE -- "MXFP8 (SM>=100)" --> MX["set_mxfp8_scale_pointers\npadded_mxfp8_scale_inv_bytes offsets"]
    RECIPE -- "NVFP4 (SM>=100)" --> FP4["set_nvfp4_scale_pointers\npadded_nvfp4_scale_inv_bytes offsets\nComputed alpha: a x amax_A x amax_B / factor"]
    RECIPE -- "FP8 block (SM=90)" --> BS["set_fp8_block_scaling_scale_pointers\npadded_block_1d/2d_scale_inv_floats offsets"]
    RECIPE -- "FP8 tensor" --> FP8["set_fp8_scale_pointers\nOne float per tensor"]
    PLAIN & MX & FP4 & BS & FP8 --> SETUP["setup_grouped_gemm_kernel (GPU)\nFills A/B/C/D/scale/alpha/beta pointer arrays"]
    SETUP --> CUBLASLT["cublasLtMatmul (grouped GEMM)"]

Comments Outside Diff (1)

  1. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 1758-1763 (link)

    P1 Missing FP4+FP16 output restriction in discrete GEMM variants

    nvte_grouped_gemm explicitly rejects FP16 output when using NVFP4 inputs ("FP4 GEMM does not support FP16 output!"), but nvte_grouped_gemm_with_discrete_out and nvte_grouped_gemm_with_discrete_inputA (which uses validate_grouped_gemm_outputs) both allow FP16 output regardless of input dtype. A caller using NVFP4 inputs with an FP16 output buffer through either discrete variant would bypass this guard and likely encounter a cuBLAS error or silently produce incorrect results.

Reviews (1): Last reviewed commit: "Merge branch 'main' into grouped_gemm_nv..."

Comment on lines 1641 to +1645
const bool is_fp8 = is_fp8_dtype(rep_dtype);
const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported();
const bool mxfp8 = transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode);
const bool nvfp4 = transformer_engine::is_nvfp_scaling(A_list_info.scaling_mode);
const bool fp8_block = transformer_engine::is_fp8_block_scaling(A_list_info.scaling_mode);

P2 Inconsistent non_tn_fp8_ok override for FP8 block scaling in discrete-A path

select_grouped_operand (used for B) always forces non_tn_fp8_ok = false for FP8 block scaling, but the discrete-A code path retains the device-capability value of nvte_is_non_tn_fp8_gemm_supported(). On hardware where that function returns true, A and B would disagree on whether TN is required, causing a layout mismatch. It is harmless today because FP8 block scaling is restricted to Hopper where the function returns false, but the inconsistency is fragile.

Suggested change

Before:
const bool is_fp8 = is_fp8_dtype(rep_dtype);
const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported();
const bool mxfp8 = transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode);
const bool nvfp4 = transformer_engine::is_nvfp_scaling(A_list_info.scaling_mode);
const bool fp8_block = transformer_engine::is_fp8_block_scaling(A_list_info.scaling_mode);

After:
const bool is_fp8 = is_fp8_dtype(rep_dtype);
const bool mxfp8 = transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode);
const bool nvfp4 = transformer_engine::is_nvfp_scaling(A_list_info.scaling_mode);
const bool fp8_block = transformer_engine::is_fp8_block_scaling(A_list_info.scaling_mode);
// FP8 block scaling on Hopper requires TN layout (matches select_grouped_operand logic for B).
const bool non_tn_fp8_ok = fp8_block ? false : nvte_is_non_tn_fp8_gemm_supported();

Comment on lines +294 to +295
// Compile-time version macro for Hopper grouped GEMM support (mirrors cublaslt_grouped_gemm.cu)
#define CUBLAS_GROUPED_GEMM_HOPPER_VERSION 130400

P2 Duplicated version constant that can drift out of sync

CUBLAS_GROUPED_GEMM_HOPPER_VERSION is defined both here (as a test-local macro) and in cublaslt_grouped_gemm.cu. If the implementation version is ever updated, this copy may be forgotten. Consider exposing the constant via a shared header or at least using a static_assert to catch drift at compile time.

Suggested change

Before:
// Compile-time version macro for Hopper grouped GEMM support (mirrors cublaslt_grouped_gemm.cu)
#define CUBLAS_GROUPED_GEMM_HOPPER_VERSION 130400

After:
// Compile-time version macro for Hopper grouped GEMM support (mirrors cublaslt_grouped_gemm.cu).
// Keep in sync with CUBLAS_GROUPED_GEMM_HOPPER_VERSION in cublaslt_grouped_gemm.cu.
#define CUBLAS_GROUPED_GEMM_HOPPER_VERSION 130400
static_assert(CUBLAS_GROUPED_GEMM_HOPPER_VERSION == 130400,
              "Update this copy to match cublaslt_grouped_gemm.cu");

Comment on lines +144 to +163
Tensor make_nvfp4_rowwise(const std::string& name, const std::vector<size_t>& shape) {
  Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16);
  fillUniform(&input_bf16);

  Tensor nvfp4(name, shape, DType::kFloat4E2M1, /*rowwise=*/true, /*columnwise=*/false,
               NVTE_NVFP4_1D_SCALING);

  QuantizationConfigWrapper quant_config;
  nvte_quantize_v2(input_bf16.data(), nvfp4.data(), quant_config, 0);

  Tensor nvfp4_sw(name + "_sw", shape, DType::kFloat4E2M1,
                  /*rowwise=*/true, /*columnwise=*/false, NVTE_NVFP4_1D_SCALING);
  nvfp4_sw.set_with_gemm_swizzled_scales(true);
  size_t data_bytes = test::bytes(nvfp4.rowwise_shape(), nvfp4.dtype());
  NVTE_CHECK_CUDA(cudaMemcpy(nvfp4_sw.rowwise_dptr(), nvfp4.rowwise_dptr(),
                             data_bytes, cudaMemcpyDeviceToDevice));
  nvte_swizzle_scaling_factors(nvfp4.data(), nvfp4_sw.data(), 0);
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
  return nvfp4_sw;
}

Why can't we just swizzle the scales in the first nvte_quantize_v2 call instead of going through 2 tensors?

Comment on lines +174 to +178
Tensor rowwise = make_nvfp4_rowwise(name + "_row", shape);

// 2. Columnwise: transpose input, quantize + swizzle as rowwise of transposed shape
std::vector<size_t> t_shape = {shape[1], shape[0]};
Tensor colwise = make_nvfp4_rowwise(name + "_col", t_shape);

Both of those tensors are using different inputs (fillUniform called in both invocations of the make_nvfp4_rowwise function).


@ptrendx ptrendx left a comment


There are issues with the test code (most notably the rowwise and columnwise taken from different inputs).

// Creates an NVFP4 operand with both rowwise and columnwise data, swizzled scales.
// NVFP4 "columnwise" data is the transposed tensor quantized rowwise.
// We quantize rowwise directly, and for columnwise we quantize the transposed input rowwise.
Tensor make_nvfp4_operand(const std::string& name, const std::vector<size_t>& shape,

@ptrendx ptrendx May 11, 2026


In general I don't understand why you can't just quantize the tensor in both directions at the same time, then all of those issues with using different data for both would not be there.
The FP8 blockwise counterpart is doing just that.

return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}};
case ShapeCase::kUnalignedAllSame:
default:
// (M, N, K) all multiples of 32 (MXFP8 block) and 16 (NVFP4 block), but NONE

The multiple of 16 for NVFP4 is wrong as TMA requires the alignment to 16B, which is 32 elements in case of the NVFP4. So if anything for the "not nice" shapes we should test multiple of 16 for MXFP8 (although I'm not sure if that would be currently passed by the rest of the logic in cublaslt_gemm.cu, I do have a separate PR to relax some of those requirements) and multiple of 32 for NVFP4.


That said, the values that you have here are actually common to both of the types and both have K being multiple of 32, so this comment is just wrong.

Comment on lines +299 to +300
return "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " +
std::to_string(CUBLAS_VERSION) + ".";

I don't believe this is correct here. The code in TE itself needs to check both compile and runtime versions of cublas. The test should not care at all about the compilation version (since it doesn't actually use any API from cublas) and instead should check against the runtime version.

std::vector<size_t>{M, N},
DType::kBFloat16));
s.D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
std::vector<size_t>{M, N}, DType::kBFloat16));

Do we only support BF16 output?

AlphaBetaTensors ab = make_alpha_beta(num_gemms);

constexpr size_t cublas_ws_bytes = 32ull * 1024 * 1024;
const size_t setup_ws_bytes = nvte_get_grouped_gemm_setup_workspace_size(num_gemms);

TBH not a fan of the name of this function, but I guess that ship has sailed already.

if (auto reason = grouped_gemm_skip_reason(params.input_case); !reason.empty()) {
GTEST_SKIP() << reason;
}
#if CUBLAS_VERSION >= 130300

Why do we need this guard? We are not using any cublas API here so compilation is OK and we already skipped the test at this point anyway if the cublas version is too low.

kFP8Current,
kBF16,
kMXFP8,
kNVFP4,

BTW, why do we have another enum here that is basically the same as the scaling mode?

Comment thread tests/cpp/test_common.cu
Comment on lines +1285 to +1287
NVTE_CHECK(last_dims[i] % 32 == 0,
"MXFP8 grouped GEMM test: last_dim must be divisible by 32, got ",
last_dims[i]);

@ptrendx ptrendx May 11, 2026


We should be able to have 16 here, but probably not yet, see my PR #2894.

ws.d_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);

// 8 pointer arrays (each 16-byte aligned), then 6 int arrays, then 1 float array.
align_ptr();

This is wrong if the base is not already aligned, since align_ptr only takes the offset into account, and the offset always starts at 0. Fixing that exposes a second issue: the returned workspace size is then not consistent between the get-workspace-size call and the actual execution, because the size query passes base = nullptr, which is always aligned. The size query should therefore assume the worst-case alignment requirement when calculating the workspace size.

Comment on lines +329 to +331
NVTE_CHECK(sm >= 100, api_name, " requires Blackwell (SM100) or newer architecture.");
NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name,
" requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver);

This seems wrong - what about the situation where we compiled against the cublas version that is not enough for any grouped gemm support (and so some stuff is not compiled I think?) but then run it on a system with newer cublas? This would pass, but the functionality would still not be there.
