[ExecuTorch][WebGPU] WGSL shader-variant codegen (vec1/vec4, fp16-ready) + rms_norm dedup #20727
JCNTH wants to merge 4 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20727
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Cancelled Job
As of commit b06c0d3 with merge base 4af91c3:
NEW FAILURE - The following job has failed:
CANCELLED JOB - The following job was cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@claude thoroughly review, check for overengineering, correctness, and areas for modularization |
|
Claude finished @JCNTH's task in 4m 24s

Review: WGSL shader-variant codegen

The core idea is sound and the backward-compatible

🔴 Overengineering (the headline concern)

You ported

CLAUDE.md is explicit: "When uncertain: choose simpler, more concise" and "No trivial single-use helpers." I'd trim The

🟠 Correctness

1. The fp16 path is not "fp16-ready": it would emit non-compiling WGSL. With

WGSL has no implicit f16↔f32 conversion, so both are type errors, and no

2. Inconsistent error handling in the generator.

    except ValueError as e:
        errors.append(...)

But

3.

4.

5. Byte-identical claim: I was unable to run

🟡 Modularization

This keeps the drift-gate logic readable and lets the engine be unit-tested in isolation. Given the trimming suggested above the engine shrinks a lot, so this is optional, but if the machinery stays this size, splitting it is the right call.

Minor: the "apply variant keys onto

Summary

Solid, well-tested foundation with a genuinely backward-compatible seam. The main asks: (1) cut the unused variant-matrix features (esp. the f16/DTYPE axis, which would generate invalid WGSL today), (2) make spec-parse errors flow through

Note: this is review feedback only; I haven't modified any files. I also couldn't execute the generator/tests here (Bash approval denied), so the byte-identical + drift claims are assessed statically; rely on CI's |
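One way the byte-identical claim could be checked locally is to hash each generated header with its provenance line dropped and compare digests before and after the change. The sketch below is only an illustration under assumptions: `payload_digest` is a hypothetical helper, the two header strings are invented stand-ins for real generated output, and the only detail taken from the PR description is that provenance lines start with `// @generated from`.

```python
import hashlib


def payload_digest(header_text, provenance_prefix="// @generated from"):
    """Return the SHA-256 of a header with its provenance line(s) removed.

    Hypothetical helper: the prefix matches the PR description, the rest
    of the header layout is assumed.
    """
    kept = [
        line
        for line in header_text.splitlines()
        if not line.startswith(provenance_prefix)
    ]
    return hashlib.sha256("\n".join(kept).encode("utf-8")).hexdigest()


# Invented stand-ins for the header before and after the template change.
old = '// @generated from rms_norm_vec4.wgsl\nconst char* kSrc = "...";\n'
new = '// @generated from rms_norm.wgsl (variant)\nconst char* kSrc = "...";\n'

same_payload = payload_digest(old) == payload_digest(new)
print(same_payload)
```

Comparing digests rather than raw files keeps an expected provenance-only change from being reported as drift, while any edit to the shader payload still changes the digest.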
Stack from ghstack (oldest at bottom):
Adds a build-time template engine to `gen_wgsl_headers.py` so one templated `.wgsl` plus a per-shader `<stem>.yaml` variant spec expands into the per-variant embedded `_wgsl.h` headers, replacing hand-copied shader variants. Mirrors the Vulkan delegate's `gen_vulkan_spv.py` shader-codegen model (`generate_variant_forall`/`shader_variants`, PyYAML parsed with a dup-key-rejecting `UniqueKeyLoader`), adapted to WGSL: a `$if VEC` axis for scalar/vec4 packing, and an `enable f16;` + f32-accumulator path (widen-to-f32 loads, narrow-to-store) so an fp16 variant is a one-entry spec addition when a consumer needs one.

Key changes:
- `scripts/gen_wgsl_headers.py`: a `$`-block transpiler (`preprocess`/`escape`/`extract_leading_whitespace`), a `generate_variant_combinations` matrix, `parse_template_spec` (PyYAML with a `UniqueKeyLoader` copied verbatim from `gen_vulkan_spv.py`), and 3 WGSL type-helpers (`buffer_scalar_type`/`buffer_gvec_type`/`accum_scalar_type`, named to mirror the Vulkan generator). `render_header` gains an optional variant name + provenance (backward-compatible with the existing verbatim path).
- `runtime/ops/rms_norm/rms_norm.wgsl` becomes one template + `rms_norm.yaml` (VEC 1/4); the hand-copied `rms_norm_vec4.wgsl` is deleted.

Constraints: the f32/vec paths regenerate BYTE-IDENTICAL: `rms_norm_wgsl.h` is unchanged and `rms_norm_vec4_wgsl.h` differs only in its `// @generated from` provenance line (`wgsl-sha256` unchanged). The fp16 (DTYPE) axis is compiled-in infrastructure with no shader consumer in this diff. Like the `gen_vulkan_spv.py` it mirrors, this codegen requires PyYAML (a declared ExecuTorch codegen dependency), so it runs under the ExecuTorch build/dev env rather than a bare stdlib `python3`. No runtime or numeric change.

@exported-using-ghexport
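As a rough illustration of the variant-matrix expansion described above, the sketch below turns a per-shader spec into one parameter dict per variant using a cartesian product over the axes. The function name `expand_variants`, the spec layout, and the `_vec4` naming rule are assumptions made for this example and are not taken from `gen_wgsl_headers.py`; the spec is an inline dict so the example runs without PyYAML.

```python
from itertools import product


def expand_variants(stem, axes):
    """Expand {axis: [values]} into a list of (variant_name, params) pairs.

    Hypothetical naming rule: VEC=1 keeps the bare stem, any other VEC
    value appends "_vec<N>".
    """
    names = sorted(axes)  # fixed axis order keeps the output deterministic
    variants = []
    for values in product(*(axes[n] for n in names)):
        params = dict(zip(names, values))
        vec = params.get("VEC", 1)
        suffix = "" if vec == 1 else f"_vec{vec}"
        variants.append((stem + suffix, params))
    return variants


# Assumed equivalent of an rms_norm.yaml spec with a single VEC axis.
variants = expand_variants("rms_norm", {"VEC": [1, 4]})
for name, params in variants:
    print(name, params)
```

With a single `VEC` axis this yields two variants, which lines up with the two headers named in the description (`rms_norm_wgsl.h` and `rms_norm_vec4_wgsl.h`); adding a second axis such as DTYPE would multiply the count, which is the cost the review weighs against having no current consumer.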
Differential Revision: D110659979