[ExecuTorch][WebGPU] WGSL shader-variant codegen (vec1/vec4, fp16-ready) + rms_norm dedup#20727

Open: JCNTH wants to merge 4 commits into gh/JCNTH/1/base from gh/JCNTH/1/head

JCNTH commented Jul 5, 2026

Stack from ghstack (oldest at bottom):

Adds a build-time template engine to gen_wgsl_headers.py so one templated .wgsl plus a per-shader <stem>.yaml variant spec expands into the per-variant embedded _wgsl.h headers, replacing hand-copied shader variants. Mirrors the Vulkan delegate's gen_vulkan_spv.py shader-codegen model (generate_variant_forall / shader_variants, PyYAML parsed with a dup-key-rejecting UniqueKeyLoader), adapted to WGSL: a $if VEC axis for scalar/vec4 packing, and an enable f16; + f32-accumulator path (widen-to-f32 loads, narrow-to-store) so an fp16 variant is a one-entry spec addition when a consumer needs one.

Key changes:

  • scripts/gen_wgsl_headers.py — a $-block transpiler (preprocess/escape/extract_leading_whitespace), a generate_variant_combinations matrix, parse_template_spec (PyYAML with a UniqueKeyLoader copied verbatim from gen_vulkan_spv.py), and 3 WGSL type-helpers (buffer_scalar_type/buffer_gvec_type/accum_scalar_type, named to mirror the Vulkan generator). render_header gains an optional variant name + provenance (backward-compatible with the existing verbatim path).
  • runtime/ops/rms_norm/rms_norm.wgsl becomes one template + rms_norm.yaml (VEC 1/4); the hand-copied rms_norm_vec4.wgsl is deleted.
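
To make the $-block model concrete, here is a minimal sketch of an xngen-style expansion. It is not the PR's preprocess: the function body, the rule that block bodies are indented exactly two spaces past their control line, and the sample template lines are all assumptions chosen for illustration.

```python
import re

_SUBST = re.compile(r"\$\{([^}]*)\}")

# Hypothetical template fragment; the real rms_norm.wgsl is not quoted in this thread.
TEMPLATE = """\
var<storage, read> t_in: array<${IN_T}>;
$if VEC == 4:
  sq_sum = sq_sum + dot(v, v);
$else:
  sq_sum = sq_sum + v * v;
"""


def preprocess(template: str, params: dict) -> str:
    """Expand a tiny $-block template by transpiling it to Python and running it.

    Lines starting with `$` become Python control statements; `${expr}` is
    replaced by the evaluated expression. Assumes block bodies are indented
    exactly two spaces past their control line.
    """
    env = dict(params)
    env["_out"] = []
    # _sub evaluates ${...} against env, so loop variables would be visible too.
    env["_sub"] = lambda text: _SUBST.sub(lambda m: str(eval(m.group(1), env)), text)

    code = []
    open_blocks = []  # template indentation of each open control line
    for line in template.splitlines():
        stripped = line.lstrip()
        indent = len(line) - len(stripped)
        if stripped:
            # A dedent back to (or past) a control line closes that block.
            while open_blocks and indent <= open_blocks[-1]:
                open_blocks.pop()
        depth = len(open_blocks)
        if stripped.startswith("$") and not stripped.startswith("${"):
            code.append("    " * depth + stripped[1:])
            open_blocks.append(indent)
        else:
            text = line[2 * depth:] if stripped else ""
            code.append("    " * depth + f"_out.append(_sub({text!r}))")
    exec("\n".join(code), env)
    return "\n".join(env["_out"]) + "\n"


vec4_source = preprocess(TEMPLATE, {"VEC": 4, "IN_T": "vec4<f32>"})
scalar_source = preprocess(TEMPLATE, {"VEC": 1, "IN_T": "f32"})
```

Running the template through exec is acceptable here only because templates are trusted, checked-in build inputs.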

Constraints: the f32/vec paths regenerate BYTE-IDENTICAL — rms_norm_wgsl.h is unchanged and rms_norm_vec4_wgsl.h differs only in its // @generated from provenance line (wgsl-sha256 unchanged). The fp16 (DTYPE) axis is compiled-in infrastructure with no shader consumer in this diff. Like the gen_vulkan_spv.py it mirrors, this codegen requires PyYAML (a declared ExecuTorch codegen dependency), so it runs under the ExecuTorch build/dev env rather than a bare stdlib python3. No runtime or numeric change.
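
One way to check the byte-identical constraint by hand is to compare two header texts while ignoring only the provenance line. The sketch below assumes a header layout (a `// @generated from` line followed by a `wgsl-sha256` line); the function names and the sample header strings are hypothetical, not taken from the generator.

```python
import hashlib

PROVENANCE_PREFIX = "// @generated from"  # marker quoted in the PR description


def wgsl_sha256(wgsl_source: str) -> str:
    """Hex digest of the shader text, as a content fingerprint."""
    return hashlib.sha256(wgsl_source.encode("utf-8")).hexdigest()


def equal_modulo_provenance(old_header: str, new_header: str) -> bool:
    """True when the headers match on every line except provenance lines."""
    def body(text: str) -> list:
        return [
            line for line in text.splitlines()
            if not line.lstrip().startswith(PROVENANCE_PREFIX)
        ]
    return body(old_header) == body(new_header)


# Hypothetical header contents, for illustration only.
digest = wgsl_sha256("fn main() {}\n")
old = f"// @generated from rms_norm_vec4.wgsl\n// wgsl-sha256: {digest}\n"
new = f"// @generated from rms_norm.wgsl (variant rms_norm_vec4)\n// wgsl-sha256: {digest}\n"
same = equal_modulo_provenance(old, new)
```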
@exported-using-ghexport

Differential Revision: D110659979

[ghstack-poisoned]
pytorch-bot Bot commented Jul 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20727

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job

As of commit b06c0d3 with merge base 4af91c3 (image):

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla Bot commented Jul 5, 2026

CLA Signed
The committers listed above are authorized under a signed CLA.

github-actions Bot commented Jul 5, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

JCNTH commented Jul 5, 2026

@claude thoroughly review, check for overengineering, correctness, and areas for modularization

claude Bot commented Jul 5, 2026

Claude finished @JCNTH's task in 4m 24s


Review: WGSL shader-variant codegen

  • Read the diff (base origin/gh/JCNTH/1/base)
  • Review gen_wgsl_headers.py codegen engine
  • Review rms_norm.wgsl template + rms_norm.json
  • Review tests
  • Assess overengineering, correctness, modularization

The core idea is sound and the backward-compatible render_header path plus the byte-identical constraint are the right way to de-risk this. My feedback clusters around the three axes you asked about.


🔴 Overengineering (the headline concern)

You ported gen_vulkan_spv.py's variant machinery wholesale, but this codegen currently serves one shader with one 2-value axis. A large fraction of generate_variant_combinations / parse_template_spec is unreachable by any spec in the repo:

| Feature | Exercised? |
| --- | --- |
| VALUE list + SUFFIX | ✅ (VEC 1/4) |
| combination / combinationN (parameter_names + combos) | ❌ dead |
| RANGE: [a,b] | ❌ dead |
| VALUE: int/uint rejection | ❌ guards nothing |
| per-variant generate_variant_forall override | ❌ (only the global one is used) |
| VARIANT_NAME injected into params | ❌ never read by any template (grep confirms it's only written) |
| multiple top-level template keys in one JSON | ❌ (headers_for_shader hard-requires exactly [stem]) |

CLAUDE.md is explicit: "When uncertain: choose simpler, more concise" and "No trivial single-use helpers." I'd trim generate_variant_combinations to the VALUE/SUFFIX path, drop combination/RANGE/int-reject/per-variant-override/VARIANT_NAME, and re-add each only when a real second consumer needs it. The spec format stays a strict subset of Vulkan's, so future porting is still easy. That removes ~100+ lines of speculative surface.
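
As a rough sketch of what the trimmed VALUE/SUFFIX-only path could look like, assuming a spec shaped like the Vulkan generator's generate_variant_forall block. The function name `expand_variants`, the dict-based spec, and the convention that an empty SUFFIX leaves the name unchanged are assumptions, not the code in this PR.

```python
from itertools import product


def expand_variants(base_name: str, defaults: dict, forall: dict) -> list:
    """Return (variant_name, params) pairs for every combination of axis values.

    `forall` maps an axis name (e.g. "VEC") to a list of entries, each with a
    "VALUE" and an optional "SUFFIX". A missing SUFFIX falls back to the value
    itself; an empty SUFFIX adds nothing to the variant name.
    """
    axes = list(forall.items())
    variants = []
    for combo in product(*(entries for _, entries in axes)):
        params = dict(defaults)
        name = base_name
        for (axis, _), entry in zip(axes, combo):
            params[axis] = entry["VALUE"]
            suffix = entry.get("SUFFIX", str(entry["VALUE"]))
            if suffix:
                name = f"{name}_{suffix}"
        variants.append((name, params))
    return variants


# Hypothetical spec equivalent to a two-value VEC axis.
rms_norm_variants = expand_variants(
    "rms_norm",
    {"VEC": 1},
    {"VEC": [{"VALUE": 1, "SUFFIX": ""}, {"VALUE": 4, "SUFFIX": "vec4"}]},
)
```

Because the input is still a subset of the Vulkan spec shape, combination or RANGE handling could be layered on later without changing existing specs.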

The DTYPE/f16 axis is the sharpest case — see correctness below; it's not just unused, it's a trap.

wgsl_accum_type() always returns the constant "f32". That's a zero-argument helper returning a literal, injected into exec globals so the template can call ${wgsl_accum_type()}. Per CLAUDE.md's no-trivial-helpers rule this could just be f32 written directly in the template — the "accumulators stay f32" intent lives fine as a shader comment.


🟠 Correctness

1. The fp16 path is not "fp16-ready" — it would emit non-compiling WGSL. With DTYPE=half, buffers become vec4<f16> but wgsl_accum_type() stays f32. The template then generates:

  • local_sq_sum /*f32*/ = local_sq_sum + dot(v /*vec4<f16>*/, v) → f32 + f16
  • t_out /*f16*/ = t_in /*f16*/ * rstd /*f32*/ * t_weight

WGSL has no implicit f16↔f32 conversion, so both are type errors, and no enable f16; directive is emitted anywhere. So the "one-entry spec addition" claim isn't real: a consumer would also need explicit f32(...)/{wgsl_scalar_type(DTYPE)}(...) casts in the template and the enable directive. I'd recommend dropping the DTYPE axis entirely for this PR (YAGNI) rather than shipping infra that generates broken code the moment it's used. This single change also deletes the half branch and wgsl_accum_type.
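
For illustration, the explicit casts could be produced by small codegen helpers like the ones below. This is a sketch under assumptions: the helper names (`buffer_type`, `widen`, `narrow`, `preamble`) and the "float" DTYPE value are hypothetical, and the f32 path deliberately emits no casts so existing output would stay unchanged.

```python
_SCALAR = {"float": "f32", "half": "f16"}  # "float" is an assumed DTYPE name


def buffer_type(dtype: str, vec: int) -> str:
    """WGSL storage type for a buffer element, e.g. vec4<f16>."""
    scalar = _SCALAR[dtype]
    return scalar if vec == 1 else f"vec{vec}<{scalar}>"


def widen(expr: str, dtype: str, vec: int) -> str:
    """Wrap a loaded value in an explicit conversion to the f32 accumulator type."""
    return expr if dtype == "float" else f"{buffer_type('float', vec)}({expr})"


def narrow(expr: str, dtype: str, vec: int) -> str:
    """Wrap an f32 result in an explicit conversion back to the storage type."""
    return expr if dtype == "float" else f"{buffer_type(dtype, vec)}({expr})"


def preamble(dtype: str) -> str:
    """The enable directive must precede all declarations in an f16 shader."""
    return "enable f16;\n" if dtype == "half" else ""


v = widen("v", "half", 4)
accumulate = f"local_sq_sum = local_sq_sum + dot({v}, {v});"
store = "t_out = " + narrow("t_in_f32 * rstd * t_weight_f32", "half", 4) + ";"
```

With DTYPE=half the accumulate line becomes `dot(vec4<f32>(v), vec4<f32>(v))`, which type-checks against an f32 accumulator.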

2. Inconsistent error handling in the generator. main() only catches ValueError:

except ValueError as e:
    errors.append(...)

But parse_template_spec raises bare KeyError (params_dict["parameter_names_with_default_values"], variant["NAME"], the "defined twice" check) and AssertionError (assert len(invalid_keys) == 0). A malformed sidecar therefore crashes with a traceback instead of the friendly "Cannot generate header" report, and --check (the CI gate) would exit non-zero on a Python exception rather than a clean drift message. Either broaden the catch to (ValueError, KeyError, AssertionError) or convert those raises to ValueError. The assert in particular should be a raise ValueError(f"unknown keys {invalid_keys}"): asserts vanish under python -O and give no key names.
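
A minimal sketch of the second option, converting the raises to ValueError so they flow into the error report. The reserved-key set, the function names, and the exact message wording are assumptions for illustration rather than the generator's real code.

```python
RESERVED_KEYS = {"NAME", "generate_variant_forall"}  # hypothetical reserved set


def validate_variant(variant: dict, param_names) -> None:
    """Raise ValueError (never KeyError or AssertionError) on a malformed variant."""
    if "NAME" not in variant:
        raise ValueError("shader variant is missing required key 'NAME'")
    invalid_keys = sorted(set(variant) - set(param_names) - RESERVED_KEYS)
    if invalid_keys:
        raise ValueError(
            f"unknown keys {invalid_keys} in variant {variant['NAME']!r}"
        )


def collect_errors(specs: dict, param_names) -> list:
    """Validate every variant and gather readable messages instead of crashing."""
    errors = []
    for stem, variants in specs.items():
        for variant in variants:
            try:
                validate_variant(variant, param_names)
            except ValueError as e:
                errors.append(f"Cannot generate header for {stem}: {e}")
    return errors


errors = collect_errors(
    {"rms_norm": [{"NAME": "rms_norm", "VEC": 4}, {"NAME": "bad", "VECC": 4}, {"VEC": 1}]},
    ["VEC"],
)
```

An explicit raise also survives `python -O`, which strips assert statements.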

3. if "LINT" in input_line: continue silently drops any template line containing the substring "LINT" anywhere (not just at a delimiter). Inherited from xngen, low risk for WGSL, but a shader comment mentioning "LINT" would silently vanish from output. Worth a narrower match or at least a comment.
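
A narrower match might look like the sketch below, which skips only lines that are nothing but a `//` comment beginning with LINT. The marker form (`// LINT.IfChange`) is an assumption about what the skip is meant to catch; the names are hypothetical.

```python
import re

# Matches lines that consist of a // comment starting with the word LINT.
_LINT_MARKER = re.compile(r"^\s*//\s*LINT\b")


def is_lint_marker(line: str) -> bool:
    return _LINT_MARKER.match(line) is not None


def strip_lint_markers(lines: list) -> list:
    """Drop marker lines only; shader lines that merely mention LINT survive."""
    return [line for line in lines if not is_lint_marker(line)]


kept = strip_lint_markers([
    "// LINT.IfChange",
    "let x = 1.0; // keep: mentions LINT",
    "// LINTER notes",
])
```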

4. provenance_stem: str = None should be Optional[str] (you already import Optional). Minor, but lintrunner/mypy may flag it.

5. Byte-identical claim — I was unable to run gen_wgsl_headers.py --check or the pytest suite (Bash approval was denied in this environment). Statically it looks right: rms_norm_wgsl.h is absent from the changed-files list (regenerated identical), and rms_norm_vec4_wgsl.h differs only in the provenance line with an unchanged wgsl-sha256. Please confirm CI's --check gate is green.


🟡 Modularization

gen_wgsl_headers.py went from a small verbatim-embedder to a 526-line file that now also contains a general-purpose $-block template engine. The two concerns are cleanly separable:

  • wgsl_template.py: escape / extract_leading_whitespace / preprocess / generate_variant_combinations / parse_template_spec (the generic engine, mirrors how Vulkan keeps codegen separable).
  • gen_wgsl_headers.py: discover / render_header / headers_for_shader / main (the WGSL-header-specific glue + the wgsl_*_type helpers).

This keeps the drift-gate logic readable and lets the engine be unit-tested in isolation. Given the trimming suggested above the engine shrinks a lot, so this is optional — but if the machinery stays this size, splitting it is the right call.

Minor: the "apply variant keys onto default_params_copy, skipping reserved_keys" block is duplicated verbatim in both the iterated and non-iterated branches of parse_template_spec; worth hoisting.
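
The hoisted helper could be as small as the following sketch; the name `apply_variant_overrides` and the example reserved keys are hypothetical, and the deep copy is assumed from the `default_params_copy` naming.

```python
import copy


def apply_variant_overrides(default_params: dict, variant: dict, reserved_keys) -> dict:
    """Copy the defaults, then apply every non-reserved variant key on top."""
    params = copy.deepcopy(default_params)
    for key, value in variant.items():
        if key not in reserved_keys:
            params[key] = value
    return params


defaults = {"VEC": 1}
merged = apply_variant_overrides(
    defaults,
    {"NAME": "rms_norm_vec4", "VEC": 4},
    reserved_keys={"NAME", "generate_variant_forall"},
)
```

Both the iterated and non-iterated branches could then call this one function.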


Summary

Solid, well-tested foundation with a genuinely backward-compatible seam. The main asks: (1) cut the unused variant-matrix features (esp. the f16/DTYPE axis, which would generate invalid WGSL today), (2) make spec-parse errors flow through main's error report instead of crashing, and (3) consider splitting the template engine into its own module. None of these block the no-op-runtime goal; they're about not shipping speculative complexity into a stdlib codegen that serves one shader.

Note: this is review feedback only — I haven't modified any files. I also couldn't execute the generator/tests here (Bash approval denied), so the byte-identical + drift claims are assessed statically; rely on CI's --check for the ground truth.

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported


1 participant