diff --git "a/bundle.js" "b/bundle.js" new file mode 100644--- /dev/null +++ "b/bundle.js" @@ -0,0 +1,8514 @@ +/* + * ,; + * \@@#\: :/. .:;;: + * _@@@@@@#+\|/!;;!-@@@--; ,@@@@@; + * .!_*@@@@@@@@@@@@@@@@@@@; |@@@@@\ + * .:!|+@@@@@##@@@@@@@#! -@@@@@#, + * .\@@@*;,\@@@@@@@@+,*@@@@@@+. + * :*#@@@@@@@@@@@@@@-+@@@@@@@\@@@@-. + * .#@@@@@#@@@@#*@@@+ /@@@@@@;\@@@@+. + * ;\/:, -@@@@;|@@@\ ,+@@@@!.+@@@@*: + * ,@@@@#*@@@@@#+__!. ,*@@@@@/ + * \##+_@@@@@@@@, ,+@@@_: + * ;;,,..,: !;. + */ +var __defProp = Object.defineProperty; +var __name = (target, value) => __defProp(target, "name", { value, configurable: true }); +var __export = (target, all) => { + for (var name in all) + __defProp(target, name, { get: all[name], enumerable: true }); +}; + +// src/config.js +var QWEN25_3B = { + hiddenSize: 2048, + numLayers: 36, + numHeads: 16, + numKVHeads: 2, + headDim: 128, + intermediateSize: 11008, + vocabSize: 151936, + rmsNormEps: 1e-6, + ropeTheta: 1e6, + /* + * TECHNIQUE: Tie word embeddings + * input embedding == output head. + * Simplifies loading (one tensor), schema, and final projection math. + * Required by the current model_uploader + schema. + */ + tieWordEmbeddings: true, + // QKV projections carry a bias in Qwen2.5; o_proj and the MLP do not. + attentionBias: true +}; + +// src/readers.js +function urlReader(baseUrl, headers = {}) { + const base = baseUrl.endsWith("/") ? baseUrl : baseUrl + "/"; + return { + async range(path, start, end) { + const r = await fetch(base + path, { + headers: { ...headers, Range: `bytes=${start}-${end - 1}` } + }); + if (!r.ok && r.status !== 206) { + throw new Error(`range ${path} ${start}-${end}: ${r.status}`); + } + return await r.arrayBuffer(); + }, + async text(path) { + const r = await fetch(base + path, { headers }); + if (!r.ok) throw new Error(`fetch ${path}: ${r.status}`); + return await r.text(); + } + }; +} +__name(urlReader, "urlReader"); +function hfReader(repo, token = "", rev = "main") { + return urlReader( + `https://huggingface.co/${repo}/resolve/${rev}`, + token ? { Authorization: `Bearer ${token}` } : {} + ); +} +__name(hfReader, "hfReader"); +function fileReader(fileMap) { + const pick = /* @__PURE__ */ __name((path) => fileMap[path] || fileMap[path.split("/").pop()], "pick"); + return { + async range(path, start, end) { + const f = pick(path); + if (!f) throw new Error(`file not provided: ${path}`); + return await f.slice(start, end).arrayBuffer(); + }, + async text(path) { + const f = pick(path); + if (!f) throw new Error(`file not provided: ${path}`); + return await f.text(); + } + }; +} +__name(fileReader, "fileReader"); + +// src/services/adapter_registry.js +var AdapterRegistry = class { + static { + __name(this, "AdapterRegistry"); + } + constructor() { + this.adapters = { none: null }; + } + add(name, modules) { + this.adapters[name] = { modules }; + return this.adapters[name]; + } + get(name) { + return this.adapters[name] || null; + } + /* + * TECHNIQUE: Runtime adapter swapping via setLora + * Registry holds pre-uploaded A/B buffers. applyToRuntime calls + * rt.setLora which just swaps references — no weight reload. + */ + applyToRuntime(name, rt) { + const adapter = this.get(name); + if (adapter) rt.setLora(adapter); + else rt.clearLora(); + return adapter; + } +}; + +// src/qwgpu/kernels.js +var GEMV = ` +enable subgroups; +requires immediate_address_space; +requires subgroup_id; +struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 }; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var w: array; // [N][K/4] int8 +@group(0) @binding(2) var scale: array; // [N] +@group(0) @binding(3) var bias: array; // [N] or dummy +@group(0) @binding(4) var loraD: array; // [rank] precomputed x@A (or dummy) +@group(0) @binding(5) var loraB: array; // [rank][N] (or dummy) +@group(0) @binding(6) var y: array; // [N] +var m: Meta; +var part: array; // one slot per subgroup +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32, + @builtin(subgroup_id) sgroup: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.N) { return; } // workgroup-uniform: whole group exits together + let K4 = m.K/4u; let rb = n*K4; + var acc = 0.0; + for (var k = tid; k < K4; k = k + 64u) { + let p = w[rb+k]; + let v = unpack4xI8(p); // vec4 + let kk = k*4u; + acc = acc + x[kk]*f32(v.x) + x[kk+1u]*f32(v.y) + x[kk+2u]*f32(v.z) + x[kk+3u]*f32(v.w); + } + let ssum = subgroupAdd(acc); // reduce within subgroup (no barrier) + if (sgid == 0u) { part[tid / sgsz] = ssum; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (64u + sgsz - 1u) / sgsz; var red = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { red = red + part[i]; } + var o = red * scale[n]; + if (m.hasBias == 1u) { o = o + bias[n]; } + if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; } + y[n] = o; + } +}`; +var LORA_A = ` +enable subgroups; +requires immediate_address_space; +@group(0) @binding(0) var x: array; // [K] +@group(0) @binding(1) var A: array; // [rank][K] (transposed) +@group(0) @binding(2) var d: array; // [rank] +var m: vec2; // K, rank +var part: array; +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let r = wid.x; let K = m.x; if (r >= m.y) { return; } + let rb = r*K; var acc = 0.0; + for (var k = lid.x; k < K; k = k + 64u) { acc = acc + x[k]*A[rb + k]; } + let s = subgroupAdd(acc); + if (sgid == 0u) { part[lid.x / sgsz] = s; } + workgroupBarrier(); + if (lid.x == 0u) { let nsg=(64u+sgsz-1u)/sgsz; var o=0.0; for(var i=0u;i x: array; // [T][K] +@group(0) @binding(1) var A: array; // [rank][K] +@group(0) @binding(2) var d: array; // [T][rank] +var m: vec4; // K, rank, T, _ +var part: array; +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let r = wid.x; let t = wid.y; let K = m.x; let rank = m.y; if (r >= rank || t >= m.z) { return; } + let xb = t*K; let ab = r*K; var acc = 0.0; + for (var k = lid.x; k < K; k = k + 64u) { acc = acc + x[xb + k]*A[ab + k]; } + let s = subgroupAdd(acc); + if (sgid == 0u) { part[lid.x / sgsz] = s; } + workgroupBarrier(); + if (lid.x == 0u) { let nsg=(64u+sgsz-1u)/sgsz; var o=0.0; for(var i=0u;i d: array; // [T][rank] +@group(0) @binding(1) var B: array; // [rank][N] +@group(0) @binding(2) var Y: array; // [T][N] +var m: Meta; +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i = gid.y * (m.gx * 256u) + gid.x; + if (i >= m.T * m.N) { return; } + let t = i / m.N; let n = i % m.N; var acc = 0.0; + for (var r = 0u; r < m.rank; r = r + 1u) { acc = acc + d[t*m.rank + r] * B[r*m.N + n]; } + Y[i] = Y[i] + m.scale * acc; +}`; +var LORA_B_ADD = ` +requires immediate_address_space; +struct Meta { N:u32, rank:u32, p0:u32, p1:u32, scale:f32, f0:f32, f1:f32, f2:f32 }; +@group(0) @binding(0) var d: array; // [rank] +@group(0) @binding(1) var B: array; // [rank][N] +@group(0) @binding(2) var y: array; // [N] +var m: Meta; +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let n = gid.x; + if (n >= m.N) { return; } + var acc = 0.0; + for (var r = 0u; r < m.rank; r = r + 1u) { acc = acc + d[r] * B[r*m.N + n]; } + y[n] = y[n] + m.scale * acc; +}`; +var RMSNORM = ` +requires immediate_address_space; +override WG: u32 = 256u; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var g: array; +@group(0) @binding(2) var y: array; +var m: vec2; // K, eps +var part: array; +@compute @workgroup_size(WG) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; let K = u32(m.x); + var s = 0.0; for (var k = tid; k < K; k = k + WG) { let v = x[k]; s = s + v*v; } + part[tid] = s; workgroupBarrier(); + for (var t = WG / 2u; t > 0u; t = t/2u) { if (tid < t) { part[tid] = part[tid] + part[tid+t]; } workgroupBarrier(); } + let inv = inverseSqrt(part[0]/m.x + m.y); + for (var k = tid; k < K; k = k + WG) { y[k] = x[k]*inv*g[k]; } +}`; +var RMSNORM_F16 = ` +requires immediate_address_space; +enable f16; +override WG: u32 = 256u; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var g: array; +@group(0) @binding(2) var y: array; +var m: vec2; // K, eps +// Reduction accumulates in f32 even though the normalize is f16: summing v*v over +// thousands of dims overflows f16 (>65504) at high-magnitude tokens (the attention +// sink), which collapses inv to 0. Keeping the sum in f32 is the overflow-safe path. +var part: array; +@compute @workgroup_size(WG) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; let K = u32(m.x); + var s = 0.0; + for (var k = tid; k < K; k = k + WG) { let v = f32(x[k]); s = s + v*v; } + part[tid] = s; workgroupBarrier(); + for (var t = WG / 2u; t > 0u; t = t/2u) { if (tid < t) { part[tid] = part[tid] + part[tid+t]; } workgroupBarrier(); } + let inv = f16(inverseSqrt(part[0]/m.x + m.y)); + for (var k = tid; k < K; k = k + WG) { y[k] = f32( f16(x[k]) * inv * f16(g[k]) ); } +}`; +var ROPE = ` +requires immediate_address_space; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var cosT: array; +@group(0) @binding(2) var sinT: array; +var m: vec3; // nHeads, headDim, pos +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let g = gid.x; let H = m.x; let D = m.y; let pos = m.z; let half = D/2u; + if (g >= H*half) { return; } + let h = g / half; let j = g % half; + let lo = h*D + j; let hi = lo + half; let off = pos*D + j; + let c = cosT[off]; let s = sinT[off]; + let xl = x[lo]; let xh = x[hi]; + // EXACT rotate-half: separately-rounded products (fma(a,b,0)) prevent the + // compiler from contracting x*c - x*s into a single fma, matching the PyTorch + // reference rounding exactly. + x[lo] = fma(xl, c, 0.0) + fma(-xh, s, 0.0); + x[hi] = fma(xh, c, 0.0) + fma(xl, s, 0.0); +}`; +var ROPE_F16 = ` +requires immediate_address_space; +enable f16; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var cosT: array; +@group(0) @binding(2) var sinT: array; +var m: vec3; // nHeads, headDim, pos +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let g = gid.x; let H = m.x; let D = m.y; let pos = m.z; let half = D/2u; + if (g >= H*half) { return; } + let h = g / half; let j = g % half; + let lo = h*D + j; let hi = lo + half; let off = pos*D + j; + let c = f16(cosT[off]); let s = f16(sinT[off]); + let xl = f16(x[lo]); let xh = f16(x[hi]); + x[lo] = f32( fma(xl, c, 0.0h) + fma(-xh, s, 0.0h) ); + x[hi] = f32( fma(xh, c, 0.0h) + fma(xl, s, 0.0h) ); +}`; +var ROPE_QK = ` +requires immediate_address_space; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var k: array; +@group(0) @binding(2) var cosT: array; +@group(0) @binding(3) var sinT: array; +var m: vec4; // qHeads, kvHeads, headDim, pos +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let g = gid.x; let qH = m.x; let kH = m.y; let D = m.z; let pos = m.w; let half = D/2u; + let qPairs = qH * half; let kPairs = kH * half; let total = qPairs + kPairs; + if (g >= total) { return; } + let isK = g >= qPairs; + var r = g; + if (isK) { r = g - qPairs; } + let h = r / half; let j = r % half; + let lo = h*D + j; let hi = lo + half; let off = pos*D + j; + let c = cosT[off]; let s = sinT[off]; + if (isK) { + let xl = k[lo]; let xh = k[hi]; + k[lo] = fma(xl, c, 0.0) + fma(-xh, s, 0.0); k[hi] = fma(xh, c, 0.0) + fma(xl, s, 0.0); + } else { + let xl = q[lo]; let xh = q[hi]; + q[lo] = fma(xl, c, 0.0) + fma(-xh, s, 0.0); q[hi] = fma(xh, c, 0.0) + fma(xl, s, 0.0); + } +}`; +var ROPE_QK_F16 = ` +requires immediate_address_space; +enable f16; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var k: array; +@group(0) @binding(2) var cosT: array; +@group(0) @binding(3) var sinT: array; +var m: vec4; // qHeads, kvHeads, headDim, pos +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let g = gid.x; let qH = m.x; let kH = m.y; let D = m.z; let pos = m.w; let half = D/2u; + let qPairs = qH * half; let kPairs = kH * half; let total = qPairs + kPairs; + if (g >= total) { return; } + let isK = g >= qPairs; + var r = g; + if (isK) { r = g - qPairs; } + let h = r / half; let j = r % half; + let lo = h*D + j; let hi = lo + half; let off = pos*D + j; + let c = f16(cosT[off]); let s = f16(sinT[off]); + if (isK) { + let xl = f16(k[lo]); let xh = f16(k[hi]); + k[lo] = f32( fma(xl, c, 0.0h) + fma(-xh, s, 0.0h) ); k[hi] = f32( fma(xh, c, 0.0h) + fma(xl, s, 0.0h) ); + } else { + let xl = f16(q[lo]); let xh = f16(q[hi]); + q[lo] = f32( fma(xl, c, 0.0h) + fma(-xh, s, 0.0h) ); q[hi] = f32( fma(xh, c, 0.0h) + fma(xl, s, 0.0h) ); + } +}`; +var ATTN_PARTIAL = ` +requires immediate_address_space; +enable subgroups; +override WG: u32 = 128u; +struct AttnP { nHeads: u32, nKV: u32, ctx: u32, hd: u32, nsplit: u32, chunk: u32 }; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var kc: array; +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var pm: array; // [nHeads*nsplit] per-split max +@group(0) @binding(4) var pz: array; // [nHeads*nsplit] per-split sum +@group(0) @binding(5) var po: array; // [nHeads*nsplit*hd] unnorm weighted V +var m: AttnP; +var sc: array; +var red: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let h = wid.x; let s = wid.y; let tid = lid.x; + let nHeads = m.nHeads; let nKV = m.nKV; let ctx = m.ctx; let hd = m.hd; let nsplit = m.nsplit; let chunk = m.chunk; + let kvh = h / (nHeads / nKV); + let qbase = h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scale = 1.0/sqrt(f32(hd)); + let nsg = (128u + sgsz - 1u) / sgsz; + let t0 = s*chunk; var t1 = t0 + chunk; if (t1 > ctx) { t1 = ctx; } + let t = t0 + tid; var sv = -1e30; + if (t < t1) { var dot = 0.0; let kb = t*stride + hoff; for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qbase+d]*kc[kb+d]; } sv = dot*scale; } + let sgm = subgroupMax(sv); if (sgid == 0u) { red[tid/sgsz] = sgm; } + workgroupBarrier(); + var M = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { M = max(M, red[i]); } + workgroupBarrier(); + var ev = 0.0; if (t < t1) { ev = exp(sv - M); } sc[tid] = ev; + let sgs = subgroupAdd(ev); if (sgid == 0u) { red[tid/sgsz] = sgs; } + workgroupBarrier(); + var Z = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { Z = Z + red[i]; } + workgroupBarrier(); + let len = t1 - t0; let pbase = (h*nsplit + s)*hd; + for (var d = tid; d < hd; d = d + 128u) { + var acc = 0.0; for (var tt = 0u; tt < len; tt = tt + 1u) { acc = acc + sc[tt]*vc[(t0+tt)*stride + hoff + d]; } + po[pbase + d] = acc; + } + if (tid == 0u) { pm[h*nsplit + s] = M; pz[h*nsplit + s] = Z; } +}`; +var ATTN_PARTIAL_F16 = ` +requires immediate_address_space; +enable subgroups; +enable f16; +override WG: u32 = 128u; +struct AttnP { nHeads: u32, nKV: u32, ctx: u32, hd: u32, nsplit: u32, chunk: u32 }; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var kc: array; +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var pm: array; // [nHeads*nsplit] per-split max +@group(0) @binding(4) var pz: array; // [nHeads*nsplit] per-split sum +@group(0) @binding(5) var po: array; // [nHeads*nsplit*hd] unnorm weighted V +var m: AttnP; +// f16 "staging" mode: Q/K/V values are read through f16 (so they carry f16 rounding, +// modelling an f16 KV cache), but every REDUCTION \u2014 the QK dot, the softmax max/sum, +// and the weighted-V accumulation \u2014 runs in f32. Accumulating scores in f16 overflows +// at long context / high-magnitude tokens; f32 accumulation is the overflow-safe path +// (matches the Gemma-4 "scores/PV accumulate in f32, only K/V carry f16 rounding"). +var sc: array; +var red: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let h = wid.x; let s = wid.y; let tid = lid.x; + let nHeads = m.nHeads; let nKV = m.nKV; let ctx = m.ctx; let hd = m.hd; let nsplit = m.nsplit; let chunk = m.chunk; + let kvh = h / (nHeads / nKV); + let qbase = h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scale = 1.0 / sqrt(f32(hd)); + let nsg = (WG + sgsz - 1u) / sgsz; + let t0 = s*chunk; var t1 = t0 + chunk; if (t1 > ctx) { t1 = ctx; } + let t = t0 + tid; var sv = -1e30; + if (t < t1) { var dot = 0.0; let kb = t*stride + hoff; for (var d = 0u; d < hd; d = d + 1u) { dot = dot + f32(f16(q[qbase+d])) * f32(f16(kc[kb+d])); } sv = dot*scale; } + let sgm = subgroupMax(sv); if (sgid == 0u) { red[tid/sgsz] = sgm; } + workgroupBarrier(); + var M = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { M = max(M, red[i]); } + workgroupBarrier(); + var ev = 0.0; if (t < t1) { ev = exp(sv - M); } sc[tid] = ev; + let sgs = subgroupAdd(ev); if (sgid == 0u) { red[tid/sgsz] = sgs; } + workgroupBarrier(); + var Z = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { Z = Z + red[i]; } + workgroupBarrier(); + let len = t1 - t0; let pbase = (h*nsplit + s)*hd; + for (var d = tid; d < hd; d = d + WG) { + var acc = 0.0; for (var tt = 0u; tt < len; tt = tt + 1u) { acc = acc + sc[tt] * f32(f16(vc[(t0+tt)*stride + hoff + d])); } + po[pbase + d] = acc; + } + if (tid == 0u) { pm[h*nsplit + s] = M; pz[h*nsplit + s] = Z; } +}`; +var ATTN_COMBINE = ` +requires immediate_address_space; +override WG: u32 = 128u; +@group(0) @binding(0) var pm: array; +@group(0) @binding(1) var pz: array; +@group(0) @binding(2) var po: array; +@group(0) @binding(3) var o: array; +var m: vec4; // nHeads, hd, nsplit, _ +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let h = wid.x; let tid = lid.x; let hd = m.y; let nsplit = m.z; let base = h*nsplit; + var M = -1e30; for (var s = 0u; s < nsplit; s = s + 1u) { M = max(M, pm[base+s]); } + var Z = 0.0; for (var s = 0u; s < nsplit; s = s + 1u) { Z = Z + pz[base+s]*exp(pm[base+s]-M); } + let invZ = 1.0 / Z; + for (var d = tid; d < hd; d = d + WG) { + var acc = 0.0; + for (var s = 0u; s < nsplit; s = s + 1u) { acc = acc + exp(pm[base+s]-M)*po[(base+s)*hd + d]; } + o[h*hd + d] = acc * invZ; + } +}`; +var ATTN_COMBINE_F16 = ` +requires immediate_address_space; +enable f16; +override WG: u32 = 128u; +@group(0) @binding(0) var pm: array; +@group(0) @binding(1) var pz: array; +@group(0) @binding(2) var po: array; +@group(0) @binding(3) var o: array; +var m: vec4; // nHeads, hd, nsplit, _ +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let h = wid.x; let tid = lid.x; let hd = m.y; let nsplit = m.z; let base = h*nsplit; + // Cross-split softmax merge accumulates max/sum in f32 (overflow-safe); only the + // final per-element weighting carries f16 rounding. + var M = -1e30; for (var s = 0u; s < nsplit; s = s + 1u) { M = max(M, pm[base+s]); } + var Z = 0.0; for (var s = 0u; s < nsplit; s = s + 1u) { Z = Z + pz[base+s] * exp(pm[base+s] - M); } + let invZ = 1.0 / Z; + for (var d = tid; d < hd; d = d + WG) { + var acc = 0.0; + for (var s = 0u; s < nsplit; s = s + 1u) { acc = acc + exp(pm[base+s] - M) * f32(f16(po[(base+s)*hd + d])); } + o[h*hd + d] = acc * invZ; + } +}`; +var GEMM4 = ` +requires immediate_address_space; +struct Meta { K:u32, N:u32, T:u32, gpr:u32, hasBias:u32, p0:u32, p1:u32, p2:u32 }; +@group(0) @binding(0) var A: array; // [T][K] +@group(0) @binding(1) var W: array; // [N][K/8] int4 +@group(0) @binding(2) var scale: array; // [N][gpr] +@group(0) @binding(3) var bias: array; // [N] or dummy +@group(0) @binding(4) var Y: array; // [T][N] +var m: Meta; +const BM = 16u; const BN = 64u; +var As: array; // BM*8 \u2014 A staged for one 8-wide K chunk +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let tTile = wid.y * BM; let col = wid.x * BN + lid.x; let valid = col < m.N; + let K8 = m.K/8u; let rb = col*K8; + var acc: array; + for (var i = 0u; i < BM; i = i + 1u) { acc[i] = 0.0; } + for (var c = 0u; c < K8; c = c + 1u) { + for (var l = lid.x; l < BM*8u; l = l + 64u) { + let tt = l / 8u; let trow = tTile + tt; + As[l] = select(0.0, A[trow*m.K + c*8u + (l % 8u)], trow < m.T); + } + workgroupBarrier(); + if (valid) { + let word = W[rb + c]; let sc = scale[col*m.gpr + ((c*8u) >> 7u)]; + let w0=f32(i32(word<<28u)>>28u)*sc; let w1=f32(i32(word<<24u)>>28u)*sc; + let w2=f32(i32(word<<20u)>>28u)*sc; let w3=f32(i32(word<<16u)>>28u)*sc; + let w4=f32(i32(word<<12u)>>28u)*sc; let w5=f32(i32(word<<8u)>>28u)*sc; + let w6=f32(i32(word<<4u)>>28u)*sc; let w7=f32(i32(word)>>28u)*sc; + for (var t = 0u; t < BM; t = t + 1u) { + let b = t*8u; + acc[t] = acc[t] + As[b]*w0+As[b+1u]*w1+As[b+2u]*w2+As[b+3u]*w3+As[b+4u]*w4+As[b+5u]*w5+As[b+6u]*w6+As[b+7u]*w7; + } + } + workgroupBarrier(); + } + if (valid) { + let bv = select(0.0, bias[col], m.hasBias == 1u); + for (var t = 0u; t < BM; t = t + 1u) { let trow = tTile + t; if (trow < m.T) { Y[trow*m.N + col] = acc[t] + bv; } } + } +}`; +var GEMM4_ADD_T = ` +requires immediate_address_space; +struct Meta { K:u32, N:u32, T:u32, gpr:u32, hasBias:u32, p0:u32, p1:u32, p2:u32 }; +@group(0) @binding(0) var A: array; +@group(0) @binding(1) var W: array; +@group(0) @binding(2) var scale: array; +@group(0) @binding(3) var bias: array; +@group(0) @binding(4) var Y: array; +var m: Meta; +const BM = 16u; const BN = 64u; +var As: array; +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let tTile = wid.y * BM; let col = wid.x * BN + lid.x; let valid = col < m.N; + let K8 = m.K/8u; let rb = col*K8; + var acc: array; + for (var i = 0u; i < BM; i = i + 1u) { acc[i] = 0.0; } + for (var c = 0u; c < K8; c = c + 1u) { + for (var l = lid.x; l < BM*8u; l = l + 64u) { + let tt = l / 8u; let trow = tTile + tt; + As[l] = select(0.0, A[trow*m.K + c*8u + (l % 8u)], trow < m.T); + } + workgroupBarrier(); + if (valid) { + let word = W[rb + c]; let sc = scale[col*m.gpr + ((c*8u) >> 7u)]; + let w0=f32(i32(word<<28u)>>28u)*sc; let w1=f32(i32(word<<24u)>>28u)*sc; + let w2=f32(i32(word<<20u)>>28u)*sc; let w3=f32(i32(word<<16u)>>28u)*sc; + let w4=f32(i32(word<<12u)>>28u)*sc; let w5=f32(i32(word<<8u)>>28u)*sc; + let w6=f32(i32(word<<4u)>>28u)*sc; let w7=f32(i32(word)>>28u)*sc; + for (var t = 0u; t < BM; t = t + 1u) { + let b = t*8u; + acc[t] = acc[t] + As[b]*w0+As[b+1u]*w1+As[b+2u]*w2+As[b+3u]*w3+As[b+4u]*w4+As[b+5u]*w5+As[b+6u]*w6+As[b+7u]*w7; + } + } + workgroupBarrier(); + } + if (valid) { + let bv = select(0.0, bias[col], m.hasBias == 1u); + for (var t = 0u; t < BM; t = t + 1u) { + let trow = tTile + t; + if (trow < m.T) { Y[trow*m.N + col] = Y[trow*m.N + col] + acc[t] + bv; } + } + } +}`; +var ADD = ` +requires immediate_address_space; +requires linear_indexing; +override WG: u32 = 256u; +@group(0) @binding(0) var a: array; +@group(0) @binding(1) var y: array; +var n: u32; +@compute @workgroup_size(WG) +fn main(@builtin(global_invocation_index) gid: u32, @builtin(num_workgroups) nwg: vec3) { + let stride = nwg.x * WG; + for (var i = gid; i < n; i = i + stride) { y[i] = y[i] + a[i]; } +}`; +var ADD_F16 = ` +requires immediate_address_space; +requires linear_indexing; +enable f16; +override WG: u32 = 256u; +@group(0) @binding(0) var a: array; +@group(0) @binding(1) var y: array; +var n: u32; +@compute @workgroup_size(WG) +fn main(@builtin(global_invocation_index) gid: u32, @builtin(num_workgroups) nwg: vec3) { + let stride = nwg.x * WG; + for (var i = gid; i < n; i = i + stride) { y[i] = f32(f16(y[i]) + f16(a[i])); } +}`; +var SILUMUL_F16 = ` +requires immediate_address_space; +enable f16; +override WG: u32 = 256u; +@group(0) @binding(0) var gate: array; +@group(0) @binding(1) var up: array; +var n: u32; +@compute @workgroup_size(WG) +fn main(@builtin(global_invocation_id) g: vec3, @builtin(num_workgroups) nwg: vec3) { + let stride = nwg.x * WG; + // Activation (silu) in f32 to avoid the f16 exp(-v) -> Inf intermediate for very + // negative v; only the bandwidth-bound elementwise multiply carries f16 rounding. + for (var i = g.x; i < n; i = i + stride) { let v = gate[i]; let sg = v / (1.0 + exp(-v)); gate[i] = f32( f16(sg) * f16(up[i]) ); } +}`; +var SILUMUL = ` +requires immediate_address_space; +override WG: u32 = 256u; +@group(0) @binding(0) var gate: array; +@group(0) @binding(1) var up: array; +var n: u32; +@compute @workgroup_size(WG) +fn main(@builtin(global_invocation_id) g: vec3, @builtin(num_workgroups) nwg: vec3) { + let stride = nwg.x * WG; + for (var i = g.x; i < n; i = i + stride) { let v = gate[i]; gate[i] = (v/(1.0+exp(-v)))*up[i]; } +}`; +var EMBED = ` +requires immediate_address_space; +@group(0) @binding(0) var w: array; +@group(0) @binding(1) var scale: array; +@group(0) @binding(2) var out: array; +var m: vec2; // id, hidden +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) g: vec3) { + let k = g.x; let id = m.x; let H = m.y; if (k >= H) { return; } + let v = unpack4xI8(w[id*(H/4u) + (k>>2u)]); let lane = k & 3u; + var b: i32; if (lane==0u){b=v.x;} else if (lane==1u){b=v.y;} else if (lane==2u){b=v.z;} else {b=v.w;} + out[k] = f32(b) * scale[id]; +}`; +var EMBED_BUF = ` +requires immediate_address_space; +@group(0) @binding(0) var w: array; +@group(0) @binding(1) var scale: array; +@group(0) @binding(2) var out: array; +@group(0) @binding(3) var idbuf: array; // idbuf[0] = token id +var H: u32; +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) g: vec3) { + let k = g.x; let id = idbuf[0]; if (k >= H) { return; } + let v = unpack4xI8(w[id*(H/4u) + (k>>2u)]); let lane = k & 3u; + var b: i32; if (lane==0u){b=v.x;} else if (lane==1u){b=v.y;} else if (lane==2u){b=v.z;} else {b=v.w;} + out[k] = f32(b) * scale[id]; +}`; +var RMSNORM_T = ` +requires immediate_address_space; +override WG: u32 = 256u; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var g: array; +@group(0) @binding(2) var y: array; +var m: vec2; // K, eps +var part: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; let K = u32(m.x); let base = wid.x * K; + var s = 0.0; for (var k = tid; k < K; k = k + WG) { let v = x[base+k]; s = s + v*v; } + part[tid] = s; workgroupBarrier(); + for (var t = WG / 2u; t > 0u; t = t/2u) { if (tid < t) { part[tid] = part[tid] + part[tid+t]; } workgroupBarrier(); } + let inv = inverseSqrt(part[0]/m.x + m.y); + for (var k = tid; k < K; k = k + WG) { y[base+k] = x[base+k]*inv*g[k]; } +}`; +var RMSNORM_T_F16 = ` +requires immediate_address_space; +enable f16; +override WG: u32 = 256u; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var g: array; +@group(0) @binding(2) var y: array; +var m: vec2; // K, eps +// f32 reduction (see RMSNORM_F16): overflow-safe sum-of-squares, f16 normalize. +var part: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; let K = u32(m.x); let base = wid.x * K; + var s = 0.0; + for (var k = tid; k < K; k = k + WG) { let v = f32(x[base+k]); s = s + v*v; } + part[tid] = s; workgroupBarrier(); + for (var t = WG / 2u; t > 0u; t = t/2u) { if (tid < t) { part[tid] = part[tid] + part[tid+t]; } workgroupBarrier(); } + let inv = f16(inverseSqrt(part[0]/m.x + m.y)); + for (var k = tid; k < K; k = k + WG) { y[base+k] = f32( f16(x[base+k]) * inv * f16(g[k]) ); } +}`; +var ROPE_T = ` +requires immediate_address_space; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var cosT: array; +@group(0) @binding(2) var sinT: array; +var m: vec4; // nHeads, headDim, T, pos0 +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let g = gid.x; let H = m.x; let D = m.y; let T = m.z; let pos0 = m.w; let half = D/2u; + let perRow = H*half; if (g >= T*perRow) { return; } + let row = g / perRow; let r = g % perRow; let h = r / half; let j = r % half; + let rb = row*H*D; let lo = rb + h*D + j; let hi = lo + half; let off = (pos0+row)*D + j; + let c = cosT[off]; let s = sinT[off]; let xl = x[lo]; let xh = x[hi]; + x[lo] = fma(xl, c, 0.0) + fma(-xh, s, 0.0); x[hi] = fma(xh, c, 0.0) + fma(xl, s, 0.0); +}`; +var ROPE_T_F16 = ` +requires immediate_address_space; +enable f16; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var cosT: array; +@group(0) @binding(2) var sinT: array; +var m: vec4; // nHeads, headDim, T, pos0 +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let g = gid.x; let H = m.x; let D = m.y; let T = m.z; let pos0 = m.w; let half = D/2u; + let perRow = H*half; if (g >= T*perRow) { return; } + let row = g / perRow; let r = g % perRow; let h = r / half; let j = r % half; + let rb = row*H*D; let lo = rb + h*D + j; let hi = lo + half; let off = (pos0+row)*D + j; + let c = f16(cosT[off]); let s = f16(sinT[off]); let xl = f16(x[lo]); let xh = f16(x[hi]); + x[lo] = f32( fma(xl, c, 0.0h) + fma(-xh, s, 0.0h) ); x[hi] = f32( fma(xh, c, 0.0h) + fma(xl, s, 0.0h) ); +}`; +var EMBED_T = ` +requires immediate_address_space; +@group(0) @binding(0) var w: array; +@group(0) @binding(1) var scale: array; +@group(0) @binding(2) var out: array; +@group(0) @binding(3) var ids: array; +var m: vec4; // T, H, idOffset, _ +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let T = m.x; let H = m.y; let N = T*H; let stride = nwg.x * 256u; + for (var i = gid.x; i < N; i = i + stride) { + let t = i / H; let k = i % H; let id = ids[m.z + t]; + let v = unpack4xI8(w[id*(H/4u) + (k>>2u)]); let lane = k & 3u; + var b: i32; if (lane==0u){b=v.x;} else if (lane==1u){b=v.y;} else if (lane==2u){b=v.z;} else {b=v.w;} + out[i] = f32(b) * scale[id]; + } +}`; +var ATTN_PREFILL = ` +enable subgroups; +requires immediate_address_space; +@group(0) @binding(0) var q: array; // [T][nHeads*hd] +@group(0) @binding(1) var kc: array; // [ctx][nKV*hd] +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var o: array; // [T][nHeads*hd] +var m: vec4; // nHeads, nKV, hd, T +var ps: array; // exp-scores for the current key block +var acc: array; // running weighted-V accumulator (hd<=128) +var red: array; +@compute @workgroup_size(256) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let h = wid.x; let t = wid.y; let tid = lid.x; let nHeads = m.x; let nKV = m.y; let hd = m.z; + let ctx = t + 1u; let kvh = h / (nHeads / nKV); + let qbase = t*nHeads*hd + h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scl = 1.0/sqrt(f32(hd)); + let nsg = (256u + sgsz - 1u) / sgsz; + for (var d = tid; d < hd; d = d + 256u) { acc[d] = 0.0; } + var mrun = -1e30; var lrun = 0.0; + let nblk = (ctx + 255u) / 256u; + for (var blk = 0u; blk < nblk; blk = blk + 1u) { + let kbase = blk*256u; let kk = kbase + tid; + var s = -1e30; + if (kk < ctx) { var dot = 0.0; let kb = kk*stride + hoff; for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qbase+d]*kc[kb+d]; } s = dot*scl; } + let sgm = subgroupMax(s); if (sgid == 0u) { red[tid/sgsz] = sgm; } + workgroupBarrier(); // A: block-max partials visible + var bm = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { bm = max(bm, red[i]); } + let mnew = max(mrun, bm); let corr = exp(mrun - mnew); + var p = 0.0; if (kk < ctx) { p = exp(s - mnew); } + ps[tid] = p; + workgroupBarrier(); // B: bm reads done + ps visible + let sgs = subgroupAdd(p); if (sgid == 0u) { red[tid/sgsz] = sgs; } + workgroupBarrier(); // C: block-sum partials visible + var bs = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { bs = bs + red[i]; } + lrun = lrun*corr + bs; + let bcount = min(256u, ctx - kbase); + for (var d = tid; d < hd; d = d + 256u) { + var aa = acc[d]*corr; + for (var j = 0u; j < bcount; j = j + 1u) { aa = aa + ps[j]*vc[(kbase+j)*stride + hoff + d]; } + acc[d] = aa; + } + mrun = mnew; + workgroupBarrier(); // D: acc's ps reads done before next block + } + let invL = 1.0/lrun; + for (var d = tid; d < hd; d = d + 256u) { o[qbase + d] = acc[d]*invL; } +}`; +var ATTN_PREFILL_BLOCK = ` +enable subgroups; +requires immediate_address_space; +struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32, qStart:u32, ctx:u32, p0:u32, p1:u32 }; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var kc: array; +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var o: array; +var m: Meta; +const BQ = 4u; const BK = 128u; +var ps: array; // BQ*BK +var acc: array; // BQ*hd (hd<=128) +var red: array; // BQ*subgroup-count +@compute @workgroup_size(128) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let h = wid.x; let qBlock = wid.y; let tid = lid.x; let hd = m.hd; + let kvh = h / (m.nHeads / m.nKV); let stride = m.nKV * hd; let hoff = kvh * hd; + let nsg = (128u + sgsz - 1u) / sgsz; let scl = 1.0 / sqrt(f32(hd)); + var mrun: array; var lrun: array; + for (var r = 0u; r < BQ; r = r + 1u) { mrun[r] = -1e30; lrun[r] = 0.0; } + for (var i = tid; i < BQ*hd; i = i + 128u) { acc[i] = 0.0; } + workgroupBarrier(); + let nblk = (m.ctx + BK - 1u) / BK; + for (var blk = 0u; blk < nblk; blk = blk + 1u) { + let kbase = blk * BK; let kk = kbase + tid; + var score: array; + var validQ: array; + var dot: array; + var corrRun: array; + for (var r = 0u; r < BQ; r = r + 1u) { + let qt = qBlock * BQ + r; let absQ = m.qStart + qt; + validQ[r] = qt < m.T && kk < m.ctx && kk <= absQ; + dot[r] = 0.0; score[r] = -1e30; + } + if (kk < m.ctx) { + let kb = kk*stride + hoff; + for (var d = 0u; d < hd; d = d + 1u) { + let kval = kc[kb+d]; + for (var r = 0u; r < BQ; r = r + 1u) { + let qt = qBlock * BQ + r; + if (validQ[r]) { dot[r] = dot[r] + q[qt*m.nHeads*hd + h*hd + d] * kval; } + } + } + for (var r = 0u; r < BQ; r = r + 1u) { + if (validQ[r]) { score[r] = dot[r] * scl; } + } + } + for (var r = 0u; r < BQ; r = r + 1u) { + let s = score[r]; + let sgm = subgroupMax(s); + if (sgid == 0u) { red[r*32u + tid/sgsz] = sgm; } + workgroupBarrier(); + var bm = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { bm = max(bm, red[r*32u+i]); } + let mnew = max(mrun[r], bm); let corr = exp(mrun[r] - mnew); + corrRun[r] = corr; + var p = 0.0; if (validQ[r]) { p = exp(s - mnew); } + ps[r*BK + tid] = p; + workgroupBarrier(); + let sgs = subgroupAdd(p); + if (sgid == 0u) { red[r*32u + tid/sgsz] = sgs; } + workgroupBarrier(); + var bs = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { bs = bs + red[r*32u+i]; } + lrun[r] = lrun[r] * corr + bs; + mrun[r] = mnew; + workgroupBarrier(); + } + let bcount = min(BK, m.ctx - kbase); + for (var d = tid; d < hd; d = d + 128u) { + var aa: array; + for (var r = 0u; r < BQ; r = r + 1u) { aa[r] = acc[r*hd+d] * corrRun[r]; } + for (var j = 0u; j < bcount; j = j + 1u) { + let vv = vc[(kbase+j)*stride + hoff + d]; + for (var r = 0u; r < BQ; r = r + 1u) { aa[r] = aa[r] + ps[r*BK+j] * vv; } + } + for (var r = 0u; r < BQ; r = r + 1u) { acc[r*hd+d] = aa[r]; } + } + workgroupBarrier(); + } + for (var r = 0u; r < BQ; r = r + 1u) { + let qt = qBlock * BQ + r; + if (qt < m.T) { + let invL = 1.0 / lrun[r]; let ob = qt*m.nHeads*hd + h*hd; + for (var d = tid; d < hd; d = d + 128u) { o[ob+d] = acc[r*hd+d] * invL; } + } + } +}`; +var ARGMAX = ` +requires immediate_address_space; +@group(0) @binding(0) var logits: array; +@group(0) @binding(1) var out: array; +var n: u32; +var bv: array; var bi: array; +@compute @workgroup_size(256) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; var v = -1e30; var idx = 0xffffffffu; + for (var i = tid; i < n; i = i + 256u) { let x = logits[i]; if (x > v || (x == v && i < idx)) { v = x; idx = i; } } + bv[tid] = v; bi[tid] = idx; workgroupBarrier(); + for (var s = 128u; s > 0u; s = s/2u) { if (tid < s) { let ov = bv[tid+s]; let oi = bi[tid+s]; if (ov > bv[tid] || (ov == bv[tid] && oi < bi[tid])) { bv[tid] = ov; bi[tid] = oi; } } workgroupBarrier(); } + if (tid == 0u) { out[0] = bi[0]; } +}`; +var TOPK_SELECT = ` +requires immediate_address_space; +@group(0) @binding(0) var logits: array; +@group(0) @binding(1) var ids: array; +@group(0) @binding(2) var vals: array; +var m: vec2; // vocabSize, selectedCount +var bv: array; var bi: array; +fn alreadySelected(id: u32, n: u32) -> bool { + for (var j = 0u; j < n; j = j + 1u) { if (ids[j] == id) { return true; } } + return false; +} +@compute @workgroup_size(256) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; let n = m.x; let selected = m.y; + var v = -1e30; var idx = 0xffffffffu; + for (var i = tid; i < n; i = i + 256u) { + let x = logits[i]; + if (!alreadySelected(i, selected) && (x > v || (x == v && i < idx))) { v = x; idx = i; } + } + bv[tid] = v; bi[tid] = idx; workgroupBarrier(); + for (var s = 128u; s > 0u; s = s/2u) { + if (tid < s) { + let ov = bv[tid+s]; let oi = bi[tid+s]; + if (ov > bv[tid] || (ov == bv[tid] && oi < bi[tid])) { bv[tid] = ov; bi[tid] = oi; } + } + workgroupBarrier(); + } + if (tid == 0u) { ids[selected] = bi[0]; vals[selected] = bv[0]; } +}`; +var SAMPLE_TOPK = ` +requires immediate_address_space; +struct Meta { k:u32, pad:u32, temp:f32, r:f32 }; +@group(0) @binding(0) var ids: array; +@group(0) @binding(1) var vals: array; +@group(0) @binding(2) var outId: array; // [1] the chosen token +var m: Meta; +var s: array; // working softmax probs / prefix sums (small k) +var red: array; // reduction scratch for the softmax denominator +@compute @workgroup_size(64) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let k = m.k; + let temp = m.temp; + let r = m.r; + let t = select(temp, 1.0, temp <= 0.0); + + // Load + temperature scale into shared (one thread per slot) + var v = -1e30; + if (tid < k) { + let lv = vals[tid]; + v = lv; + if (t != 1.0) { v = lv / t; } + } + let ev = select(0.0, exp(v), tid < k); + s[tid] = ev; + red[tid] = ev; + workgroupBarrier(); + + // sum + for (var stride = 32u; stride > 0u; stride = stride / 2u) { + if (tid < stride && (tid + stride) < 64u) { red[tid] = red[tid] + red[tid + stride]; } + workgroupBarrier(); + } + let sum = red[0]; + let invSum = select(0.0, 1.0 / sum, sum > 0.0); + + // normalize + prefix sum for nucleus / categorical pick + if (tid < k) { + s[tid] = s[tid] * invSum; + } else { + s[tid] = 0.0; + } + workgroupBarrier(); + + // prefix sum (small k, simple scan) + for (var stride = 1u; stride < 64u; stride = stride * 2u) { + var add = 0.0; + if (tid >= stride && tid < 64u) { + add = s[tid - stride]; + } + workgroupBarrier(); + if (tid >= stride && tid < 64u) { + s[tid] = s[tid] + add; + } + workgroupBarrier(); + } + + // find the smallest j such that prefix[j] >= r (or last if r>=1) + if (tid == 0u) { + var chosen = select(0u, k - 1u, k > 0u); + if (sum > 0.0) { + for (var j = 0u; j < k; j = j + 1u) { + let pj = s[j]; + if (r <= pj) { chosen = j; break; } + } + } + outId[0] = select(0u, ids[chosen], k > 0u); + } +}`; +var GEMV4 = ` +enable subgroups; +requires immediate_address_space; +struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 }; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var w: array; +@group(0) @binding(2) var scale: array; +@group(0) @binding(3) var bias: array; +@group(0) @binding(4) var loraD: array; +@group(0) @binding(5) var loraB: array; +@group(0) @binding(6) var y: array; +var m: Meta; +var part: array; // one slot per subgroup +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.N) { return; } // workgroup-uniform: whole group exits together + let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr; + var acc = 0.0; + for (var c = tid; c < K8; c = c + 64u) { + let word = w[rb+c]; let bk = c*8u; let sc = scale[sbase + (bk >> 7u)]; + var p = 0.0; + p = p + x[bk] * f32(i32(word << 28u) >> 28u); + p = p + x[bk+1u] * f32(i32(word << 24u) >> 28u); + p = p + x[bk+2u] * f32(i32(word << 20u) >> 28u); + p = p + x[bk+3u] * f32(i32(word << 16u) >> 28u); + p = p + x[bk+4u] * f32(i32(word << 12u) >> 28u); + p = p + x[bk+5u] * f32(i32(word << 8u) >> 28u); + p = p + x[bk+6u] * f32(i32(word << 4u) >> 28u); + p = p + x[bk+7u] * f32(i32(word) >> 28u); + acc = acc + p * sc; + } + let ssum = subgroupAdd(acc); // reduce within subgroup (no barrier) + if (sgid == 0u) { part[tid / sgsz] = ssum; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (64u + sgsz - 1u) / sgsz; var o = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; } + if (m.hasBias == 1u) { o = o + bias[n]; } + if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; } + y[n] = o; + } +}`; +var GEMV4_ADD = ` +enable subgroups; +requires immediate_address_space; +struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 }; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var w: array; +@group(0) @binding(2) var scale: array; +@group(0) @binding(3) var bias: array; +@group(0) @binding(4) var loraD: array; +@group(0) @binding(5) var loraB: array; +@group(0) @binding(6) var y: array; +var m: Meta; +var part: array; +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.N) { return; } + let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr; + var acc = 0.0; + for (var c = tid; c < K8; c = c + 64u) { + let word = w[rb+c]; let bk = c*8u; let sc = scale[sbase + (bk >> 7u)]; + var p = 0.0; + p = p + x[bk] * f32(i32(word << 28u) >> 28u); + p = p + x[bk+1u] * f32(i32(word << 24u) >> 28u); + p = p + x[bk+2u] * f32(i32(word << 20u) >> 28u); + p = p + x[bk+3u] * f32(i32(word << 16u) >> 28u); + p = p + x[bk+4u] * f32(i32(word << 12u) >> 28u); + p = p + x[bk+5u] * f32(i32(word << 8u) >> 28u); + p = p + x[bk+6u] * f32(i32(word << 4u) >> 28u); + p = p + x[bk+7u] * f32(i32(word) >> 28u); + acc = acc + p * sc; + } + let ssum = subgroupAdd(acc); + if (sgid == 0u) { part[tid / sgsz] = ssum; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (64u + sgsz - 1u) / sgsz; var o = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; } + if (m.hasBias == 1u) { o = o + bias[n]; } + if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; } + y[n] = y[n] + o; + } +}`; +var QKV_GEMV4 = ` +enable subgroups; +requires immediate_address_space; +struct Meta { K:u32, totalN:u32, qN:u32, kN:u32, vN:u32, gpr:u32, gridX:u32, p0:u32 }; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var w: array; +@group(0) @binding(2) var scale: array; +@group(0) @binding(3) var bias: array; +@group(0) @binding(4) var qOut: array; +@group(0) @binding(5) var kOut: array; +@group(0) @binding(6) var vOut: array; +var m: Meta; +var part: array; +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.totalN) { return; } + let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr; + var acc = 0.0; + for (var c = tid; c < K8; c = c + 64u) { + let word = w[rb+c]; let bk = c*8u; let sc = scale[sbase + (bk >> 7u)]; + var p = 0.0; + p = p + x[bk] * f32(i32(word << 28u) >> 28u); + p = p + x[bk+1u] * f32(i32(word << 24u) >> 28u); + p = p + x[bk+2u] * f32(i32(word << 20u) >> 28u); + p = p + x[bk+3u] * f32(i32(word << 16u) >> 28u); + p = p + x[bk+4u] * f32(i32(word << 12u) >> 28u); + p = p + x[bk+5u] * f32(i32(word << 8u) >> 28u); + p = p + x[bk+6u] * f32(i32(word << 4u) >> 28u); + p = p + x[bk+7u] * f32(i32(word) >> 28u); + acc = acc + p * sc; + } + let ssum = subgroupAdd(acc); + if (sgid == 0u) { part[tid / sgsz] = ssum; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (64u + sgsz - 1u) / sgsz; var o = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; } + o = o + bias[n]; + if (n < m.qN) { + qOut[n] = o; + } else if (n < m.qN + m.kN) { + kOut[n - m.qN] = o; + } else { + vOut[n - m.qN - m.kN] = o; + } + } +}`; +var GATE_UP_SILU_GEMV4 = ` +enable subgroups; +requires immediate_address_space; +struct Meta { K:u32, N:u32, gpr:u32, gridX:u32, gateRank:u32, upRank:u32, hasGateLora:u32, hasUpLora:u32, gateScaleLo:f32, upScaleLo:f32, p0:f32, p1:f32 }; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var w: array; +@group(0) @binding(2) var scale: array; +@group(0) @binding(3) var y: array; +@group(0) @binding(4) var gateD: array; +@group(0) @binding(5) var gateB: array; +@group(0) @binding(6) var upD: array; +@group(0) @binding(7) var upB: array; +var m: Meta; +var partG: array; +var partU: array; +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.N) { return; } + let K8 = m.K/8u; let rbG = n*K8; let rbU = (m.N + n)*K8; + let sbG = n*m.gpr; let sbU = (m.N + n)*m.gpr; + var accG = 0.0; var accU = 0.0; + for (var c = tid; c < K8; c = c + 64u) { + let bk = c*8u; let wg = w[rbG+c]; let wu = w[rbU+c]; + let scG = scale[sbG + (bk >> 7u)]; let scU = scale[sbU + (bk >> 7u)]; + let x0=x[bk]; let x1=x[bk+1u]; let x2=x[bk+2u]; let x3=x[bk+3u]; + let x4=x[bk+4u]; let x5=x[bk+5u]; let x6=x[bk+6u]; let x7=x[bk+7u]; + var pg = 0.0; var pu = 0.0; + pg = pg + x0*f32(i32(wg<<28u)>>28u) + x1*f32(i32(wg<<24u)>>28u) + x2*f32(i32(wg<<20u)>>28u) + x3*f32(i32(wg<<16u)>>28u); + pg = pg + x4*f32(i32(wg<<12u)>>28u) + x5*f32(i32(wg<<8u)>>28u) + x6*f32(i32(wg<<4u)>>28u) + x7*f32(i32(wg)>>28u); + pu = pu + x0*f32(i32(wu<<28u)>>28u) + x1*f32(i32(wu<<24u)>>28u) + x2*f32(i32(wu<<20u)>>28u) + x3*f32(i32(wu<<16u)>>28u); + pu = pu + x4*f32(i32(wu<<12u)>>28u) + x5*f32(i32(wu<<8u)>>28u) + x6*f32(i32(wu<<4u)>>28u) + x7*f32(i32(wu)>>28u); + accG = accG + pg * scG; accU = accU + pu * scU; + } + let sg = subgroupAdd(accG); let su = subgroupAdd(accU); + if (sgid == 0u) { partG[tid / sgsz] = sg; partU[tid / sgsz] = su; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (64u + sgsz - 1u) / sgsz; var gate = 0.0; var up = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { gate = gate + partG[i]; up = up + partU[i]; } + if (m.hasGateLora == 1u) { + var dl = 0.0; for (var r = 0u; r < m.gateRank; r = r + 1u) { dl = dl + gateD[r] * gateB[r*m.N + n]; } + gate = gate + m.gateScaleLo * dl; + } + if (m.hasUpLora == 1u) { + var dl = 0.0; for (var r = 0u; r < m.upRank; r = r + 1u) { dl = dl + upD[r] * upB[r*m.N + n]; } + up = up + m.upScaleLo * dl; + } + y[n] = (gate / (1.0 + exp(-gate))) * up; + } +}`; +var DYN_QUANT_X = ` +requires immediate_address_space; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var x_q: array; +@group(0) @binding(2) var scale_x: array; +var K: u32; +var sh_max: array; +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let g = wid.x; let tid = lid.x; let base = g * 128u; + var local_max = 0.0; + let idx0 = base + tid; let idx1 = base + tid + 64u; + if (idx0 < K) { local_max = max(local_max, abs(x[idx0])); } + if (idx1 < K) { local_max = max(local_max, abs(x[idx1])); } + sh_max[tid] = local_max; + workgroupBarrier(); + for (var s = 32u; s > 0u; s = s / 2u) { + if (tid < s) { sh_max[tid] = max(sh_max[tid], sh_max[tid + s]); } + workgroupBarrier(); + } + let gmax = sh_max[0]; let scale = select(gmax / 127.0, 1.0, gmax == 0.0); + if (tid == 0u) { scale_x[g] = scale; } + let pidx = base + tid * 4u; + if (pidx < K) { + let q0 = clamp(i32(round(x[pidx] / scale)), -128, 127) & 0xff; + let q1 = clamp(i32(round(x[pidx + 1u] / scale)), -128, 127) & 0xff; + let q2 = clamp(i32(round(x[pidx + 2u] / scale)), -128, 127) & 0xff; + let q3 = clamp(i32(round(x[pidx + 3u] / scale)), -128, 127) & 0xff; + x_q[g * 32u + tid] = u32(q0 | (q1 << 8u) | (q2 << 16u) | (q3 << 24u)); + } +}`; +var DYN_QUANT_X_T = ` +requires immediate_address_space; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var x_q: array; +@group(0) @binding(2) var scale_x: array; +var m: vec2; // K, T +var sh_max: array; +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let g = wid.x; let t = wid.y; let tid = lid.x; let K = m.x; let T = m.y; + if (t >= T) { return; } + let row_base = t * K; let base = row_base + g * 128u; + var local_max = 0.0; + let idx0 = base + tid; let idx1 = base + tid + 64u; + if (g * 128u + tid < K) { local_max = max(local_max, abs(x[idx0])); } + if (g * 128u + tid + 64u < K) { local_max = max(local_max, abs(x[idx1])); } + sh_max[tid] = local_max; + workgroupBarrier(); + for (var s = 32u; s > 0u; s = s / 2u) { + if (tid < s) { sh_max[tid] = max(sh_max[tid], sh_max[tid + s]); } + workgroupBarrier(); + } + let gmax = sh_max[0]; let scale = select(gmax / 127.0, 1.0, gmax == 0.0); + let groupsPerRow = K / 128u; + if (tid == 0u) { scale_x[t * groupsPerRow + g] = scale; } + let pidx = base + tid * 4u; + if (g * 128u + tid * 4u < K) { + let q0 = clamp(i32(round(x[pidx] / scale)), -128, 127) & 0xff; + let q1 = clamp(i32(round(x[pidx + 1u] / scale)), -128, 127) & 0xff; + let q2 = clamp(i32(round(x[pidx + 2u] / scale)), -128, 127) & 0xff; + let q3 = clamp(i32(round(x[pidx + 3u] / scale)), -128, 127) & 0xff; + x_q[t * (K / 4u) + g * 32u + tid] = u32(q0 | (q1 << 8u) | (q2 << 16u) | (q3 << 24u)); + } +}`; +var GEMV4_W4A8 = /* @__PURE__ */ __name((hasDP4a, wgSize = 64) => ` +enable subgroups; +${hasDP4a ? ` +enable packed_4x8_integer_dot_product; +` : ""} +requires immediate_address_space; +struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 }; +@group(0) @binding(0) var x_q: array; +@group(0) @binding(1) var scale_x: array; +@group(0) @binding(2) var w: array; +@group(0) @binding(3) var scale: array; +@group(0) @binding(4) var bias: array; +@group(0) @binding(5) var loraD: array; +@group(0) @binding(6) var loraB: array; +@group(0) @binding(7) var y: array; +var m: Meta; + +${hasDP4a ? "" : ` +fn dot4I8Packed(a: u32, b: u32) -> i32 { + let va = unpack4xI8(a); + let vb = unpack4xI8(b); + return va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w; +} +`} + +var part: array; +@compute @workgroup_size(${wgSize}) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.N) { return; } + let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr; + var acc = 0.0; + for (var c = tid; c < K8; c = c + ${wgSize}u) { + let word = w[rb+c]; let bk = c*8u; + let sc_w = scale[sbase + (bk >> 7u)]; + let sc_x = scale_x[bk >> 7u]; + let w0 = (i32(word << 28u) >> 28u) & 0xff; + let w1 = (i32(word << 24u) >> 28u) & 0xff; + let w2 = (i32(word << 20u) >> 28u) & 0xff; + let w3 = (i32(word << 16u) >> 28u) & 0xff; + let w4 = (i32(word << 12u) >> 28u) & 0xff; + let w5 = (i32(word << 8u) >> 28u) & 0xff; + let w6 = (i32(word << 4u) >> 28u) & 0xff; + let w7 = (i32(word) >> 28u) & 0xff; + let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u)); + let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u)); + let px0 = x_q[c * 2u]; + let px1 = x_q[c * 2u + 1u]; + let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1); + acc = acc + f32(sum) * sc_w * sc_x; + } + let ssum = subgroupAdd(acc); + if (sgid == 0u) { part[tid / sgsz] = ssum; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (${wgSize}u + sgsz - 1u) / sgsz; var o = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; } + if (m.hasBias == 1u) { o = o + bias[n]; } + if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; } + y[n] = o; + } +} +`, "GEMV4_W4A8"); +var GEMV4_ADD_W4A8 = /* @__PURE__ */ __name((hasDP4a, wgSize = 64) => ` +enable subgroups; +${hasDP4a ? ` +enable packed_4x8_integer_dot_product; +` : ""} +requires immediate_address_space; +struct Meta { K:u32, N:u32, rank:u32, hasBias:u32, hasLora:u32, gridX:u32, scaleLo:f32, gpr:u32 }; +@group(0) @binding(0) var x_q: array; +@group(0) @binding(1) var scale_x: array; +@group(0) @binding(2) var w: array; +@group(0) @binding(3) var scale: array; +@group(0) @binding(4) var bias: array; +@group(0) @binding(5) var loraD: array; +@group(0) @binding(6) var loraB: array; +@group(0) @binding(7) var y: array; +var m: Meta; + +${hasDP4a ? "" : ` +fn dot4I8Packed(a: u32, b: u32) -> i32 { + let va = unpack4xI8(a); + let vb = unpack4xI8(b); + return va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w; +} +`} + +var part: array; +@compute @workgroup_size(${wgSize}) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.N) { return; } + let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr; + var acc = 0.0; + for (var c = tid; c < K8; c = c + ${wgSize}u) { + let word = w[rb+c]; let bk = c*8u; + let sc_w = scale[sbase + (bk >> 7u)]; + let sc_x = scale_x[bk >> 7u]; + let w0 = (i32(word << 28u) >> 28u) & 0xff; + let w1 = (i32(word << 24u) >> 28u) & 0xff; + let w2 = (i32(word << 20u) >> 28u) & 0xff; + let w3 = (i32(word << 16u) >> 28u) & 0xff; + let w4 = (i32(word << 12u) >> 28u) & 0xff; + let w5 = (i32(word << 8u) >> 28u) & 0xff; + let w6 = (i32(word << 4u) >> 28u) & 0xff; + let w7 = (i32(word) >> 28u) & 0xff; + let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u)); + let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u)); + let px0 = x_q[c * 2u]; + let px1 = x_q[c * 2u + 1u]; + let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1); + acc = acc + f32(sum) * sc_w * sc_x; + } + let ssum = subgroupAdd(acc); + if (sgid == 0u) { part[tid / sgsz] = ssum; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (${wgSize}u + sgsz - 1u) / sgsz; var o = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; } + if (m.hasBias == 1u) { o = o + bias[n]; } + if (m.hasLora == 1u) { var dl = 0.0; for (var r = 0u; r < m.rank; r = r + 1u) { dl = dl + loraD[r] * loraB[r*m.N + n]; } o = o + m.scaleLo * dl; } + y[n] = y[n] + o; + } +} +`, "GEMV4_ADD_W4A8"); +var QKV_GEMV4_W4A8 = /* @__PURE__ */ __name((hasDP4a, wgSize = 64) => ` +enable subgroups; +${hasDP4a ? ` +enable packed_4x8_integer_dot_product; +` : ""} +requires immediate_address_space; +struct Meta { K:u32, totalN:u32, qN:u32, kN:u32, vN:u32, gpr:u32, gridX:u32, p0:u32 }; +@group(0) @binding(0) var x_q: array; +@group(0) @binding(1) var scale_x: array; +@group(0) @binding(2) var w: array; +@group(0) @binding(3) var scale: array; +@group(0) @binding(4) var bias: array; +@group(0) @binding(5) var qOut: array; +@group(0) @binding(6) var kOut: array; +@group(0) @binding(7) var vOut: array; +var m: Meta; + +${hasDP4a ? "" : ` +fn dot4I8Packed(a: u32, b: u32) -> i32 { + let va = unpack4xI8(a); + let vb = unpack4xI8(b); + return va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w; +} +`} + +var part: array; +@compute @workgroup_size(${wgSize}) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.totalN) { return; } + let K8 = m.K/8u; let rb = n*K8; let sbase = n*m.gpr; + var acc = 0.0; + for (var c = tid; c < K8; c = c + ${wgSize}u) { + let word = w[rb+c]; let bk = c*8u; + let sc_w = scale[sbase + (bk >> 7u)]; + let sc_x = scale_x[bk >> 7u]; + let w0 = (i32(word << 28u) >> 28u) & 0xff; + let w1 = (i32(word << 24u) >> 28u) & 0xff; + let w2 = (i32(word << 20u) >> 28u) & 0xff; + let w3 = (i32(word << 16u) >> 28u) & 0xff; + let w4 = (i32(word << 12u) >> 28u) & 0xff; + let w5 = (i32(word << 8u) >> 28u) & 0xff; + let w6 = (i32(word << 4u) >> 28u) & 0xff; + let w7 = (i32(word) >> 28u) & 0xff; + let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u)); + let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u)); + let px0 = x_q[c * 2u]; + let px1 = x_q[c * 2u + 1u]; + let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1); + acc = acc + f32(sum) * sc_w * sc_x; + } + let ssum = subgroupAdd(acc); + if (sgid == 0u) { part[tid / sgsz] = ssum; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (${wgSize}u + sgsz - 1u) / sgsz; var o = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { o = o + part[i]; } + o = o + bias[n]; + if (n < m.qN) { + qOut[n] = o; + } else if (n < m.qN + m.kN) { + kOut[n - m.qN] = o; + } else { + vOut[n - m.qN - m.kN] = o; + } + } +} +`, "QKV_GEMV4_W4A8"); +var GATE_UP_SILU_GEMV4_W4A8 = /* @__PURE__ */ __name((hasDP4a, wgSize = 64) => ` +enable subgroups; +${hasDP4a ? ` +enable packed_4x8_integer_dot_product; +` : ""} +requires immediate_address_space; +struct Meta { K:u32, N:u32, gpr:u32, gridX:u32, gateRank:u32, upRank:u32, hasGateLora:u32, hasUpLora:u32, gateScaleLo:f32, upScaleLo:f32, p0:f32, p1:f32 }; +@group(0) @binding(0) var x_q: array; +@group(0) @binding(1) var scale_x: array; +@group(0) @binding(2) var w: array; +@group(0) @binding(3) var scale: array; +@group(0) @binding(4) var y: array; +@group(0) @binding(5) var gateD: array; +@group(0) @binding(6) var gateB: array; +@group(0) @binding(7) var upD: array; +@group(0) @binding(8) var upB: array; +var m: Meta; + +${hasDP4a ? "" : ` +fn dot4I8Packed(a: u32, b: u32) -> i32 { + let va = unpack4xI8(a); + let vb = unpack4xI8(b); + return va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w; +} +`} + +var partG: array; +var partU: array; +@compute @workgroup_size(${wgSize}) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let n = wid.x + wid.y * m.gridX; let tid = lid.x; + if (n >= m.N) { return; } + let K8 = m.K/8u; let rbG = n*K8; let rbU = (m.N + n)*K8; + let sbG = n*m.gpr; let sbU = (m.N + n)*m.gpr; + var accG = 0.0; var accU = 0.0; + for (var c = tid; c < K8; c = c + ${wgSize}u) { + let wg = w[rbG+c]; let wu = w[rbU+c]; + let bk = c*8u; + let scG = scale[sbG + (bk >> 7u)]; let scU = scale[sbU + (bk >> 7u)]; + let sc_x = scale_x[bk >> 7u]; + let wg0 = (i32(wg << 28u) >> 28u) & 0xff; + let wg1 = (i32(wg << 24u) >> 28u) & 0xff; + let wg2 = (i32(wg << 20u) >> 28u) & 0xff; + let wg3 = (i32(wg << 16u) >> 28u) & 0xff; + let wg4 = (i32(wg << 12u) >> 28u) & 0xff; + let wg5 = (i32(wg << 8u) >> 28u) & 0xff; + let wg6 = (i32(wg << 4u) >> 28u) & 0xff; + let wg7 = (i32(wg) >> 28u) & 0xff; + let pwg0 = u32(wg0 | (wg1 << 8u) | (wg2 << 16u) | (wg3 << 24u)); + let pwg1 = u32(wg4 | (wg5 << 8u) | (wg6 << 16u) | (wg7 << 24u)); + let wu0 = (i32(wu << 28u) >> 28u) & 0xff; + let wu1 = (i32(wu << 24u) >> 28u) & 0xff; + let wu2 = (i32(wu << 20u) >> 28u) & 0xff; + let wu3 = (i32(wu << 16u) >> 28u) & 0xff; + let wu4 = (i32(wu << 12u) >> 28u) & 0xff; + let wu5 = (i32(wu << 8u) >> 28u) & 0xff; + let wu6 = (i32(wu << 4u) >> 28u) & 0xff; + let wu7 = (i32(wu) >> 28u) & 0xff; + let pwu0 = u32(wu0 | (wu1 << 8u) | (wu2 << 16u) | (wu3 << 24u)); + let pwu1 = u32(wu4 | (wu5 << 8u) | (wu6 << 16u) | (wu7 << 24u)); + let px0 = x_q[c * 2u]; + let px1 = x_q[c * 2u + 1u]; + let sumG = dot4I8Packed(pwg0, px0) + dot4I8Packed(pwg1, px1); + let sumU = dot4I8Packed(pwu0, px0) + dot4I8Packed(pwu1, px1); + accG = accG + f32(sumG) * scG * sc_x; + accU = accU + f32(sumU) * scU * sc_x; + } + let sg = subgroupAdd(accG); let su = subgroupAdd(accU); + if (sgid == 0u) { partG[tid / sgsz] = sg; partU[tid / sgsz] = su; } + workgroupBarrier(); + if (tid == 0u) { + let nsg = (${wgSize}u + sgsz - 1u) / sgsz; var gate = 0.0; var up = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { gate = gate + partG[i]; up = up + partU[i]; } + if (m.hasGateLora == 1u) { + var dl = 0.0; for (var r = 0u; r < m.gateRank; r = r + 1u) { dl = dl + gateD[r] * gateB[r*m.N + n]; } + gate = gate + m.gateScaleLo * dl; + } + if (m.hasUpLora == 1u) { + var dl = 0.0; for (var r = 0u; r < m.upRank; r = r + 1u) { dl = dl + upD[r] * upB[r*m.N + n]; } + up = up + m.upScaleLo * dl; + } + y[n] = (gate / (1.0 + exp(-gate))) * up; + } +} +`, "GATE_UP_SILU_GEMV4_W4A8"); +var GEMM4_W4A8 = /* @__PURE__ */ __name((hasDP4a) => ` +enable subgroups; +${hasDP4a ? ` +enable packed_4x8_integer_dot_product; +` : ""} +requires immediate_address_space; +struct Meta { K:u32, N:u32, T:u32, gpr:u32, hasBias:u32, p0:u32, p1:u32, p2:u32 }; +@group(0) @binding(0) var A_q: array; +@group(0) @binding(1) var scale_x: array; +@group(0) @binding(2) var W: array; +@group(0) @binding(3) var scale: array; +@group(0) @binding(4) var bias: array; +@group(0) @binding(5) var Y: array; +var m: Meta; + +${hasDP4a ? "" : ` +fn dot4I8Packed(a: u32, b: u32) -> i32 { + let va = unpack4xI8(a); + let vb = unpack4xI8(b); + return va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w; +} +`} + +const BM = 16u; const BN = 64u; +var As_q: array; +var As_scale: array; + +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let tTile = wid.y * BM; let col = wid.x * BN + lid.x; let valid = col < m.N; + let K8 = m.K/8u; let rb = col*K8; + var acc: array; + for (var i = 0u; i < BM; i = i + 1u) { acc[i] = 0.0; } + let groupsPerRow = m.K / 128u; + for (var c = 0u; c < K8; c = c + 1u) { + if (lid.x < BM * 2u) { + let tt = lid.x / 2u; let trow = tTile + tt; let wordIdx = lid.x % 2u; + As_q[lid.x] = select(0u, A_q[trow * (m.K / 4u) + c * 2u + wordIdx], trow < m.T); + } + if (lid.x < BM) { + let trow = tTile + lid.x; + As_scale[lid.x] = select(0.0, scale_x[trow * groupsPerRow + ((c * 8u) >> 7u)], trow < m.T); + } + workgroupBarrier(); + if (valid) { + let word = W[rb + c]; let sc_w = scale[col*m.gpr + ((c*8u) >> 7u)]; + let w0 = (i32(word << 28u) >> 28u) & 0xff; + let w1 = (i32(word << 24u) >> 28u) & 0xff; + let w2 = (i32(word << 20u) >> 28u) & 0xff; + let w3 = (i32(word << 16u) >> 28u) & 0xff; + let w4 = (i32(word << 12u) >> 28u) & 0xff; + let w5 = (i32(word << 8u) >> 28u) & 0xff; + let w6 = (i32(word << 4u) >> 28u) & 0xff; + let w7 = (i32(word) >> 28u) & 0xff; + let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u)); + let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u)); + for (var t = 0u; t < BM; t = t + 1u) { + let px0 = As_q[t * 2u]; let px1 = As_q[t * 2u + 1u]; + let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1); + acc[t] = acc[t] + f32(sum) * sc_w * As_scale[t]; + } + } + workgroupBarrier(); + } + if (valid) { + let bv = select(0.0, bias[col], m.hasBias == 1u); + for (var t = 0u; t < BM; t = t + 1u) { let trow = tTile + t; if (trow < m.T) { Y[trow*m.N + col] = acc[t] + bv; } } + } +} +`, "GEMM4_W4A8"); +var GEMM4_ADD_T_W4A8 = /* @__PURE__ */ __name((hasDP4a) => ` +enable subgroups; +${hasDP4a ? ` +enable packed_4x8_integer_dot_product; +` : ""} +requires immediate_address_space; +struct Meta { K:u32, N:u32, T:u32, gpr:u32, hasBias:u32, p0:u32, p1:u32, p2:u32 }; +@group(0) @binding(0) var A_q: array; +@group(0) @binding(1) var scale_x: array; +@group(0) @binding(2) var W: array; +@group(0) @binding(3) var scale: array; +@group(0) @binding(4) var bias: array; +@group(0) @binding(5) var Y: array; +var m: Meta; + +${hasDP4a ? "" : ` +fn dot4I8Packed(a: u32, b: u32) -> i32 { + let va = unpack4xI8(a); + let vb = unpack4xI8(b); + return va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w; +} +`} + +const BM = 16u; const BN = 64u; +var As_q: array; +var As_scale: array; + +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let tTile = wid.y * BM; let col = wid.x * BN + lid.x; let valid = col < m.N; + let K8 = m.K/8u; let rb = col*K8; + var acc: array; + for (var i = 0u; i < BM; i = i + 1u) { acc[i] = 0.0; } + let groupsPerRow = m.K / 128u; + for (var c = 0u; c < K8; c = c + 1u) { + if (lid.x < BM * 2u) { + let tt = lid.x / 2u; let trow = tTile + tt; let wordIdx = lid.x % 2u; + As_q[lid.x] = select(0u, A_q[trow * (m.K / 4u) + c * 2u + wordIdx], trow < m.T); + } + if (lid.x < BM) { + let trow = tTile + lid.x; + As_scale[lid.x] = select(0.0, scale_x[trow * groupsPerRow + ((c * 8u) >> 7u)], trow < m.T); + } + workgroupBarrier(); + if (valid) { + let word = W[rb + c]; let sc_w = scale[col*m.gpr + ((c*8u) >> 7u)]; + let w0 = (i32(word << 28u) >> 28u) & 0xff; + let w1 = (i32(word << 24u) >> 28u) & 0xff; + let w2 = (i32(word << 20u) >> 28u) & 0xff; + let w3 = (i32(word << 16u) >> 28u) & 0xff; + let w4 = (i32(word << 12u) >> 28u) & 0xff; + let w5 = (i32(word << 8u) >> 28u) & 0xff; + let w6 = (i32(word << 4u) >> 28u) & 0xff; + let w7 = (i32(word) >> 28u) & 0xff; + let pw0 = u32(w0 | (w1 << 8u) | (w2 << 16u) | (w3 << 24u)); + let pw1 = u32(w4 | (w5 << 8u) | (w6 << 16u) | (w7 << 24u)); + for (var t = 0u; t < BM; t = t + 1u) { + let px0 = As_q[t * 2u]; let px1 = As_q[t * 2u + 1u]; + let sum = dot4I8Packed(pw0, px0) + dot4I8Packed(pw1, px1); + acc[t] = acc[t] + f32(sum) * sc_w * As_scale[t]; + } + } + workgroupBarrier(); + } + if (valid) { + let bv = select(0.0, bias[col], m.hasBias == 1u); + for (var t = 0u; t < BM; t = t + 1u) { + let trow = tTile + t; + if (trow < m.T) { Y[trow*m.N + col] = Y[trow*m.N + col] + acc[t] + bv; } + } + } +} +`, "GEMM4_ADD_T_W4A8"); +var WRITE_KV_PAGE = ` +requires immediate_address_space; +@group(0) @binding(0) var k_src: array; +@group(0) @binding(1) var v_src: array; +@group(0) @binding(2) var kc: array; +@group(0) @binding(3) var vc: array; +@group(0) @binding(4) var block_table: array; +var m: vec4; // pos, seq_id, max_blocks, kvd +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; let pos = m.x; let seq_id = m.y; let max_blocks = m.z; let kvd = m.w; + if (idx >= kvd) { return; } + let page_idx = block_table[seq_id * max_blocks + (pos / 16u)]; + let page_offset = pos % 16u; + let physical_pos = page_idx * 16u + page_offset; + let dst_offset = physical_pos * kvd + idx; + kc[dst_offset] = k_src[idx]; + vc[dst_offset] = v_src[idx]; +}`; +var WRITE_KV_PAGE_BATCH = ` +requires immediate_address_space; +struct KVBatchMeta { T:u32, seq_id:u32, max_blocks:u32, kvd:u32, off:u32 }; +@group(0) @binding(0) var k_src: array; +@group(0) @binding(1) var v_src: array; +@group(0) @binding(2) var kc: array; +@group(0) @binding(3) var vc: array; +@group(0) @binding(4) var block_table: array; +var m: KVBatchMeta; +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; let T = m.T; let seq_id = m.seq_id; let max_blocks = m.max_blocks; let kvd = m.kvd; let off = m.off; + let total = T * kvd; if (idx >= total) { return; } + let t = idx / kvd; let d = idx % kvd; + let page_idx = block_table[seq_id * max_blocks + ((off + t) / 16u)]; + let page_offset = (off + t) % 16u; + let physical_pos = page_idx * 16u + page_offset; + let dst_offset = physical_pos * kvd + d; + kc[dst_offset] = k_src[idx]; + vc[dst_offset] = v_src[idx]; +}`; +var ATTN_PARTIAL_PAGED = ` +enable subgroups; +requires immediate_address_space; +struct Meta { nHeads:u32, nKV:u32, ctx:u32, hd:u32, nsplit:u32, chunk:u32, seq_id:u32, max_blocks:u32 }; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var kc: array; +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var pm: array; +@group(0) @binding(4) var pz: array; +@group(0) @binding(5) var po: array; +@group(0) @binding(6) var block_table: array; +var m: Meta; +var sc: array; +var red: array; +@compute @workgroup_size(128) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let h = wid.x; let s = wid.y; let tid = lid.x; + let nHeads = m.nHeads; let nKV = m.nKV; let ctx = m.ctx; let hd = m.hd; + let nsplit = m.nsplit; let chunk = m.chunk; let seq_id = m.seq_id; let max_blocks = m.max_blocks; + let kvh = h / (nHeads / nKV); + let qbase = h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scale = 1.0/sqrt(f32(hd)); + let nsg = (128u + sgsz - 1u) / sgsz; + let t0 = s*chunk; var t1 = t0 + chunk; if (t1 > ctx) { t1 = ctx; } + let t = t0 + tid; var sv = -1e30; + if (t < t1) { + var dot = 0.0; + let page_idx = block_table[seq_id * max_blocks + (t / 16u)]; + let page_offset = t % 16u; + let kb = (page_idx * 16u + page_offset) * stride + hoff; + for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qbase+d]*kc[kb+d]; } + sv = dot*scale; + } + let sgm = subgroupMax(sv); if (sgid == 0u) { red[tid/sgsz] = sgm; } + workgroupBarrier(); + var M = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { M = max(M, red[i]); } + workgroupBarrier(); + var ev = 0.0; if (t < t1) { ev = exp(sv - M); } sc[tid] = ev; + let sgs = subgroupAdd(ev); if (sgid == 0u) { red[tid/sgsz] = sgs; } + workgroupBarrier(); + var Z = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { Z = Z + red[i]; } + workgroupBarrier(); + let len = t1 - t0; let pbase = (h*nsplit + s)*hd; + for (var d = tid; d < hd; d = d + 128u) { + var acc = 0.0; + for (var tt = 0u; tt < len; tt = tt + 1u) { + let t_curr = t0 + tt; + let page_idx = block_table[seq_id * max_blocks + (t_curr / 16u)]; + let page_offset = t_curr % 16u; + let physical_t = page_idx * 16u + page_offset; + acc = acc + sc[tt]*vc[physical_t*stride + hoff + d]; + } + po[pbase + d] = acc; + } + if (tid == 0u) { pm[h*nsplit + s] = M; pz[h*nsplit + s] = Z; } +}`; +var ATTN_PREFILL_PAGED = ` +enable subgroups; +requires immediate_address_space; +struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32, seq_id:u32, max_blocks:u32, p0:u32, p1:u32 }; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var kc: array; +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var o: array; +@group(0) @binding(4) var block_table: array; +var m: Meta; +var ps: array; +var acc: array; +var red: array; +@compute @workgroup_size(256) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let h = wid.x; let t = wid.y; let tid = lid.x; let nHeads = m.nHeads; let nKV = m.nKV; let hd = m.hd; + let ctx = t + 1u; let kvh = h / (nHeads / nKV); + let qbase = t*nHeads*hd + h*hd; let stride = nKV*hd; let hoff = kvh*hd; let scl = 1.0/sqrt(f32(hd)); + let nsg = (256u + sgsz - 1u) / sgsz; + let seq_id = m.seq_id; let max_blocks = m.max_blocks; + for (var d = tid; d < hd; d = d + 256u) { acc[d] = 0.0; } + var mrun = -1e30; var lrun = 0.0; + let nblk = (ctx + 255u) / 256u; + for (var blk = 0u; blk < nblk; blk = blk + 1u) { + let kbase = blk*256u; let kk = kbase + tid; + var s = -1e30; + if (kk < ctx) { + var dot = 0.0; + let page_idx = block_table[seq_id * max_blocks + (kk / 16u)]; + let page_offset = kk % 16u; + let kb = (page_idx * 16u + page_offset)*stride + hoff; + for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qbase+d]*kc[kb+d]; } + s = dot*scl; + } + let sgm = subgroupMax(s); if (sgid == 0u) { red[tid/sgsz] = sgm; } + workgroupBarrier(); + var bm = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { bm = max(bm, red[i]); } + let mnew = max(mrun, bm); let corr = exp(mrun - mnew); + var p = 0.0; if (kk < ctx) { p = exp(s - mnew); } + ps[tid] = p; + workgroupBarrier(); + let sgs = subgroupAdd(p); if (sgid == 0u) { red[tid/sgsz] = sgs; } + workgroupBarrier(); + var bs = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { bs = bs + red[i]; } + lrun = lrun*corr + bs; + let bcount = min(256u, ctx - kbase); + for (var d = tid; d < hd; d = d + 256u) { + var aa = acc[d]*corr; + for (var j = 0u; j < bcount; j = j + 1u) { + let t_curr = kbase + j; + let page_idx = block_table[seq_id * max_blocks + (t_curr / 16u)]; + let page_offset = t_curr % 16u; + let physical_t = page_idx * 16u + page_offset; + aa = aa + ps[j]*vc[physical_t*stride + hoff + d]; + } + acc[d] = aa; + } + mrun = mnew; + workgroupBarrier(); + } + let invL = 1.0/lrun; + for (var d = tid; d < hd; d = d + 256u) { o[qbase + d] = acc[d]*invL; } +}`; +var ATTN_PREFILL_BLOCK_PAGED = ` +enable subgroups; +requires immediate_address_space; +struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32, qStart:u32, ctx:u32, seq_id:u32, max_blocks:u32 }; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var kc: array; +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var o: array; +@group(0) @binding(4) var block_table: array; +var m: Meta; +const BQ = 4u; const BK = 128u; +var ps: array; +var acc: array; +var red: array; +@compute @workgroup_size(128) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + let h = wid.x; let qBlock = wid.y; let tid = lid.x; let hd = m.hd; + let kvh = h / (m.nHeads / m.nKV); let stride = m.nKV * hd; let hoff = kvh * hd; + let nsg = (128u + sgsz - 1u) / sgsz; let scl = 1.0 / sqrt(f32(hd)); + let seq_id = m.seq_id; let max_blocks = m.max_blocks; + var mrun: array; var lrun: array; + for (var r = 0u; r < BQ; r = r + 1u) { mrun[r] = -1e30; lrun[r] = 0.0; } + for (var i = tid; i < BQ*hd; i = i + 128u) { acc[i] = 0.0; } + workgroupBarrier(); + let nblk = (m.ctx + BK - 1u) / BK; + for (var blk = 0u; blk < nblk; blk = blk + 1u) { + let kbase = blk * BK; let kk = kbase + tid; + var score: array; + var validQ: array; + var dot: array; + var corrRun: array; + for (var r = 0u; r < BQ; r = r + 1u) { + let qt = qBlock * BQ + r; let absQ = m.qStart + qt; + validQ[r] = qt < m.T && kk < m.ctx && kk <= absQ; + dot[r] = 0.0; score[r] = -1e30; + } + if (kk < m.ctx) { + let page_idx = block_table[seq_id * max_blocks + (kk / 16u)]; + let page_offset = kk % 16u; + let kb = (page_idx * 16u + page_offset)*stride + hoff; + for (var d = 0u; d < hd; d = d + 1u) { + let kval = kc[kb+d]; + for (var r = 0u; r < BQ; r = r + 1u) { + let qt = qBlock * BQ + r; + if (validQ[r]) { dot[r] = dot[r] + q[qt*m.nHeads*hd + h*hd + d] * kval; } + } + } + for (var r = 0u; r < BQ; r = r + 1u) { + if (validQ[r]) { score[r] = dot[r] * scl; } + } + } + for (var r = 0u; r < BQ; r = r + 1u) { + let s = score[r]; + let sgm = subgroupMax(s); + if (sgid == 0u) { red[r*32u + tid/sgsz] = sgm; } + workgroupBarrier(); + var bm = -1e30; for (var i = 0u; i < nsg; i = i + 1u) { bm = max(bm, red[r*32u+i]); } + let mnew = max(mrun[r], bm); let corr = exp(mrun[r] - mnew); + corrRun[r] = corr; + var p = 0.0; if (validQ[r]) { p = exp(s - mnew); } + ps[r*BK + tid] = p; + workgroupBarrier(); + let sgs = subgroupAdd(p); + if (sgid == 0u) { red[r*32u + tid/sgsz] = sgs; } + workgroupBarrier(); + var bs = 0.0; for (var i = 0u; i < nsg; i = i + 1u) { bs = bs + red[r*32u+i]; } + lrun[r] = lrun[r] * corr + bs; + mrun[r] = mnew; + workgroupBarrier(); + } + let bcount = min(BK, m.ctx - kbase); + for (var d = tid; d < hd; d = d + 128u) { + var aa: array; + for (var r = 0u; r < BQ; r = r + 1u) { aa[r] = acc[r*hd+d] * corrRun[r]; } + for (var j = 0u; j < bcount; j = j + 1u) { + let t_curr = kbase + j; + let page_idx = block_table[seq_id * max_blocks + (t_curr / 16u)]; + let page_offset = t_curr % 16u; + let physical_t = page_idx * 16u + page_offset; + let vv = vc[physical_t*stride + hoff + d]; + for (var r = 0u; r < BQ; r = r + 1u) { aa[r] = aa[r] + ps[r*BK+j] * vv; } + } + for (var r = 0u; r < BQ; r = r + 1u) { acc[r*hd+d] = aa[r]; } + } + workgroupBarrier(); + } + for (var r = 0u; r < BQ; r = r + 1u) { + let qt = qBlock * BQ + r; + if (qt < m.T) { + let invL = 1.0 / lrun[r]; let ob = qt*m.nHeads*hd + h*hd; + for (var d = tid; d < hd; d = d + 128u) { o[ob+d] = acc[r*hd+d] * invL; } + } + } +}`; +var GEMV4_QKV_ROPE_RMS = ` +enable subgroups; +requires immediate_address_space; +struct Meta { + K: u32, totalPairs: u32, qPairs: u32, kPairs: u32, vPairs: u32, gpr: u32, gridX: u32, + pos: u32, headDim: u32, eps: f32, + qN: u32, kN: u32 +}; + +@group(0) @binding(0) var hidden: array; +@group(0) @binding(1) var rms_g: array; +@group(0) @binding(2) var w: array; +@group(0) @binding(3) var scale: array; +@group(0) @binding(4) var bias: array; +@group(0) @binding(5) var cosT: array; +@group(0) @binding(6) var sinT: array; +@group(0) @binding(7) var qOut: array; +@group(0) @binding(8) var kOut: array; +@group(0) @binding(9) var vOut: array; +var m: Meta; + +var partSum: array; + +@compute @workgroup_size(64) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3, + @builtin(subgroup_size) sgsz: u32, @builtin(subgroup_invocation_id) sgid: u32) { + + let pair_idx = wid.x + wid.y * m.gridX; + if (pair_idx >= m.totalPairs) { return; } + let tid = lid.x; + + var s = 0.0; + for (var k = tid; k < m.K; k = k + 64u) { let v = hidden[k]; s = s + v*v; } + let ssum = subgroupAdd(s); + if (sgid == 0u) { partSum[tid / sgsz] = ssum; } + workgroupBarrier(); + + if (tid == 0u) { + let nsg = (64u + sgsz - 1u) / sgsz; var red = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { red = red + partSum[i]; } + partSum[0] = inverseSqrt(red / f32(m.K) + m.eps); + } + workgroupBarrier(); + let inv = partSum[0]; + + let half = m.headDim / 2u; + var n0: u32; var n1: u32; + var isQ = false; var isK = false; var isV = false; + var out_idx0: u32; var out_idx1: u32; + var rope_j: u32 = 0u; + + if (pair_idx < m.qPairs) { + isQ = true; + let h = pair_idx / half; let j = pair_idx % half; + n0 = h * m.headDim + j; + n1 = n0 + half; + out_idx0 = n0; out_idx1 = n1; + rope_j = j; + } else if (pair_idx < m.qPairs + m.kPairs) { + isK = true; + let p = pair_idx - m.qPairs; + let h = p / half; let j = p % half; + n0 = m.qN + h * m.headDim + j; + n1 = n0 + half; + out_idx0 = h * m.headDim + j; out_idx1 = out_idx0 + half; + rope_j = j; + } else { + isV = true; + let p = pair_idx - m.qPairs - m.kPairs; + n0 = m.qN + m.kN + p * 2u; + n1 = n0 + 1u; + out_idx0 = p * 2u; out_idx1 = out_idx0 + 1u; + } + + let K8 = m.K / 8u; + let rb0 = n0 * K8; let rb1 = n1 * K8; + let sbase0 = n0 * m.gpr; let sbase1 = n1 * m.gpr; + + var acc0 = 0.0; var acc1 = 0.0; + + for (var c = tid; c < K8; c = c + 64u) { + let w0 = w[rb0 + c]; let w1 = w[rb1 + c]; + let bk = c * 8u; + let sc0 = scale[sbase0 + (bk >> 7u)]; let sc1 = scale[sbase1 + (bk >> 7u)]; + + // We compute normalized X on the fly + let x0 = hidden[bk] * inv * rms_g[bk]; + let x1 = hidden[bk+1u] * inv * rms_g[bk+1u]; + let x2 = hidden[bk+2u] * inv * rms_g[bk+2u]; + let x3 = hidden[bk+3u] * inv * rms_g[bk+3u]; + let x4 = hidden[bk+4u] * inv * rms_g[bk+4u]; + let x5 = hidden[bk+5u] * inv * rms_g[bk+5u]; + let x6 = hidden[bk+6u] * inv * rms_g[bk+6u]; + let x7 = hidden[bk+7u] * inv * rms_g[bk+7u]; + + var p0 = 0.0; var p1 = 0.0; + p0 = p0 + x0 * f32(i32(w0 << 28u) >> 28u); p1 = p1 + x0 * f32(i32(w1 << 28u) >> 28u); + p0 = p0 + x1 * f32(i32(w0 << 24u) >> 28u); p1 = p1 + x1 * f32(i32(w1 << 24u) >> 28u); + p0 = p0 + x2 * f32(i32(w0 << 20u) >> 28u); p1 = p1 + x2 * f32(i32(w1 << 20u) >> 28u); + p0 = p0 + x3 * f32(i32(w0 << 16u) >> 28u); p1 = p1 + x3 * f32(i32(w1 << 16u) >> 28u); + p0 = p0 + x4 * f32(i32(w0 << 12u) >> 28u); p1 = p1 + x4 * f32(i32(w1 << 12u) >> 28u); + p0 = p0 + x5 * f32(i32(w0 << 8u) >> 28u); p1 = p1 + x5 * f32(i32(w1 << 8u) >> 28u); + p0 = p0 + x6 * f32(i32(w0 << 4u) >> 28u); p1 = p1 + x6 * f32(i32(w1 << 4u) >> 28u); + p0 = p0 + x7 * f32(i32(w0) >> 28u); p1 = p1 + x7 * f32(i32(w1) >> 28u); + + acc0 = acc0 + p0 * sc0; + acc1 = acc1 + p1 * sc1; + } + + let ssum0 = subgroupAdd(acc0); let ssum1 = subgroupAdd(acc1); + if (sgid == 0u) { partSum[tid / sgsz] = ssum0; partSum[32u + tid / sgsz] = ssum1; } + workgroupBarrier(); + + if (tid == 0u) { + let nsg = (64u + sgsz - 1u) / sgsz; + var o0 = 0.0; var o1 = 0.0; + for (var i = 0u; i < nsg; i = i + 1u) { o0 = o0 + partSum[i]; o1 = o1 + partSum[32u + i]; } + + o0 = o0 + bias[n0]; + o1 = o1 + bias[n1]; + + if (isQ || isK) { + let off = m.pos * m.headDim + rope_j; + let c = cosT[off]; let s = sinT[off]; + let rl = fma(o0, c, 0.0) + fma(-o1, s, 0.0); + let rh = fma(o1, c, 0.0) + fma(o0, s, 0.0); + o0 = rl; o1 = rh; + } + + if (isQ) { qOut[out_idx0] = o0; qOut[out_idx1] = o1; } + else if (isK) { kOut[out_idx0] = o0; kOut[out_idx1] = o1; } + else { vOut[out_idx0] = o0; vOut[out_idx1] = o1; } + } +}`; + +// src/qwgpu/model_schema.js +var arrEq = /* @__PURE__ */ __name((a, b) => a.length === b.length && a.every((v, i) => v === b[i]), "arrEq"); +function projDesc(layer, subpath, outDim, inDim, { bias = false } = {}) { + const name = `model.layers.${layer}.${subpath}.weight`; + const m = subpath.match(/^(self_attn|mlp)\.(.+)$/); + const loraKey = `layers.${layer}.${m[1]}.${m[2]}`; + return { + name, + role: "projection", + quant: "int4", + shape: [outDim, inDim], + loraKey, + biasName: bias ? name.replace(/\.weight$/, ".bias") : null + }; +} +__name(projDesc, "projDesc"); +function f32Desc(name, shape, role = "f32") { + return { name, role, quant: "f32", shape }; +} +__name(f32Desc, "f32Desc"); +function createQwenSchema(cfg) { + if (!cfg.tieWordEmbeddings && cfg.tieWordEmbeddings !== void 0) { + throw new Error("QwenWGPU currently requires tied input/output embeddings"); + } + const H = cfg.hiddenSize; + const QD = cfg.numHeads * cfg.headDim; + const KVD = cfg.numKVHeads * cfg.headDim; + const I = cfg.intermediateSize; + const tensors = []; + const layers = []; + const add = /* @__PURE__ */ __name((d) => { + tensors.push(d); + return d; + }, "add"); + const embed = add({ name: "model.embed_tokens.weight", role: "embedding", quant: "int8", shape: [cfg.vocabSize, H] }); + const finalNorm = add(f32Desc("model.norm.weight", [H], "final_norm")); + for (let i = 0; i < cfg.numLayers; i++) { + const p = `model.layers.${i}`; + const layer = { + index: i, + inputNorm: add(f32Desc(`${p}.input_layernorm.weight`, [H], "input_norm")), + postAttentionNorm: add(f32Desc(`${p}.post_attention_layernorm.weight`, [H], "post_attention_norm")), + projections: {}, + biases: {} + }; + layer.projections.q = add(projDesc(i, "self_attn.q_proj", QD, H, { bias: !!cfg.attentionBias })); + layer.projections.k = add(projDesc(i, "self_attn.k_proj", KVD, H, { bias: !!cfg.attentionBias })); + layer.projections.v = add(projDesc(i, "self_attn.v_proj", KVD, H, { bias: !!cfg.attentionBias })); + layer.projections.o = add(projDesc(i, "self_attn.o_proj", H, QD)); + layer.projections.gate = add(projDesc(i, "mlp.gate_proj", I, H)); + layer.projections.up = add(projDesc(i, "mlp.up_proj", I, H)); + layer.projections.down = add(projDesc(i, "mlp.down_proj", H, I)); + for (const key of ["q", "k", "v"]) { + const proj = layer.projections[key]; + if (proj.biasName) { + const bias = add(f32Desc(proj.biasName, [proj.shape[0]], `${key}_bias`)); + layer.biases[key] = bias; + } + } + layers.push(layer); + } + const byName = new Map(tensors.map((t) => [t.name, t])); + const expectedNames = new Set(byName.keys()); + return { + cfg, + tensors, + byName, + expectedNames, + layers, + embed, + finalNorm, + projectionDescs: tensors.filter((t) => t.role === "projection"), + validateTensor(name, shape) { + const desc = byName.get(name); + if (!desc) return null; + if (!arrEq(shape, desc.shape)) { + throw new Error(`shape mismatch for ${name}: got [${shape.join(",")}], expected [${desc.shape.join(",")}]`); + } + return desc; + }, + assertComplete(seen) { + const missing = []; + for (const name of expectedNames) if (!seen.has(name)) missing.push(name); + if (missing.length) { + const sample = missing.slice(0, 12).join(", "); + throw new Error(`missing ${missing.length} required tensor(s): ${sample}${missing.length > 12 ? ", \u2026" : ""}`); + } + } + }; +} +__name(createQwenSchema, "createQwenSchema"); +function moduleKeyFromTensorName(name) { + const m = name.match(/layers\.(\d+)\.(self_attn|mlp)\.([a-z_]+?)(_proj)?\.(lora_[ABab])/i); + if (!m) return null; + return `layers.${m[1]}.${m[2]}.${m[3].replace(/_proj$/, "")}_proj`; +} +__name(moduleKeyFromTensorName, "moduleKeyFromTensorName"); + +// src/qwgpu/dispatch_plan.js +function createDispatchPlan(schema) { + return { + embed: schema.embed, + finalNorm: schema.finalNorm, + layers: schema.layers.map((layer) => ({ + index: layer.index, + inputNorm: layer.inputNorm.name, + postAttentionNorm: layer.postAttentionNorm.name, + q: { + weight: layer.projections.q.name, + bias: layer.biases.q?.name || null, + loraKey: layer.projections.q.loraKey + }, + k: { + weight: layer.projections.k.name, + bias: layer.biases.k?.name || null, + loraKey: layer.projections.k.loraKey + }, + v: { + weight: layer.projections.v.name, + bias: layer.biases.v?.name || null, + loraKey: layer.projections.v.loraKey + }, + o: { + weight: layer.projections.o.name, + bias: null, + loraKey: layer.projections.o.loraKey + }, + gate: { + weight: layer.projections.gate.name, + bias: null, + loraKey: layer.projections.gate.loraKey + }, + up: { + weight: layer.projections.up.name, + bias: null, + loraKey: layer.projections.up.loraKey + }, + down: { + weight: layer.projections.down.name, + bias: null, + loraKey: layer.projections.down.loraKey + } + })) + }; +} +__name(createDispatchPlan, "createDispatchPlan"); + +// src/qwgpu/safetensors_loader.js +function decodeBf16ToF32(u8, numel) { + const u16 = new Uint16Array(u8.buffer, u8.byteOffset, numel); + const out = new Float32Array(numel); + const o32 = new Uint32Array(out.buffer); + for (let i = 0; i < numel; i++) o32[i] = u16[i] << 16; + return out; +} +__name(decodeBf16ToF32, "decodeBf16ToF32"); +function decodeF16ToF32(u8, numel) { + const u16 = new Uint16Array(u8.buffer, u8.byteOffset, numel); + const out = new Float32Array(numel); + for (let i = 0; i < numel; i++) { + const h = u16[i], s = (h & 32768) >> 15, e = (h & 31744) >> 10, f = h & 1023; + if (e === 0) out[i] = (s ? -1 : 1) * Math.pow(2, -14) * (f / 1024); + else if (e === 31) out[i] = f ? NaN : s ? -Infinity : Infinity; + else out[i] = (s ? -1 : 1) * Math.pow(2, e - 15) * (1 + f / 1024); + } + return out; +} +__name(decodeF16ToF32, "decodeF16ToF32"); +function decodeF32(u8, numel) { + return new Float32Array(u8.buffer.slice(u8.byteOffset, u8.byteOffset + numel * 4)); +} +__name(decodeF32, "decodeF32"); +var DECODERS = { + BF16: decodeBf16ToF32, + F16: decodeF16ToF32, + FP16: decodeF16ToF32, + F32: decodeF32, + FP32: decodeF32 +}; +async function loadIndex(reader) { + try { + const idx = JSON.parse(await reader.text("model.safetensors.index.json")); + return { weightMap: idx.weight_map || {}, shards: [...new Set(Object.values(idx.weight_map || {}))] }; + } catch { + return { weightMap: null, shards: ["model.safetensors"] }; + } +} +__name(loadIndex, "loadIndex"); +function shardPlan(shards, weightMap, names) { + if (!weightMap || !names) return new Map(shards.map((shard) => [shard, null])); + const plan = /* @__PURE__ */ new Map(); + for (const name of names) { + const shard = weightMap[name]; + if (!shard) continue; + if (!plan.has(shard)) plan.set(shard, /* @__PURE__ */ new Set()); + plan.get(shard).add(name); + } + return plan; +} +__name(shardPlan, "shardPlan"); +async function streamSafetensors(source, { names = null, onTensor, onProgress = /* @__PURE__ */ __name(() => { +}, "onProgress") } = {}) { + if (!onTensor) throw new Error("streamSafetensors requires onTensor"); + const reader = typeof source === "string" ? urlReader(source) : source; + const { weightMap, shards } = await loadIndex(reader); + const plan = shardPlan(shards, weightMap, names); + let visited = 0; + const total = names?.size || 0; + for (const [shard, wantedInShard] of plan) { + const lenBuf = await reader.range(shard, 0, 8); + const headerLen = Number(new DataView(lenBuf).getBigUint64(0, true)); + const hdrBuf = await reader.range(shard, 8, 8 + headerLen); + const header = JSON.parse(new TextDecoder().decode(new Uint8Array(hdrBuf))); + const dataStart = 8 + headerLen; + const allNames = Object.keys(header).filter((k) => k !== "__metadata__"); + const tensorNames = wantedInShard ? allNames.filter((n) => wantedInShard.has(n)) : names ? allNames.filter((n) => names.has(n)) : allNames; + for (const name of tensorNames) { + const t = header[name]; + if (!t) continue; + const dtype = String(t.dtype || "").toUpperCase(); + const dec = DECODERS[dtype]; + if (!dec) throw new Error(`unsupported dtype ${dtype} for ${name}`); + const numel = t.shape.reduce((a, b) => a * b, 1); + const [s, e] = t.data_offsets; + const buf = await reader.range(shard, dataStart + s, dataStart + e); + const data = dec(new Uint8Array(buf), numel); + await onTensor({ name, shape: t.shape, dtype, data, shard }); + visited++; + onProgress(name, total ? Math.min(0.95, visited / total) : 0.3); + } + } +} +__name(streamSafetensors, "streamSafetensors"); + +// src/qwgpu/quantize.js +function quantizeInt8RowMajor(f322, outDim, inDim) { + const scale = new Float32Array(outDim); + const q = new Int8Array(outDim * inDim); + for (let o = 0; o < outDim; o++) { + const base = o * inDim; + let amax = 0; + for (let i = 0; i < inDim; i++) { + const a = Math.abs(f322[base + i]); + if (a > amax) amax = a; + } + const s = amax > 0 ? amax / 127 : 1; + scale[o] = s; + const inv = 1 / s; + for (let i = 0; i < inDim; i++) { + let v = Math.round(f322[base + i] * inv); + if (v > 127) v = 127; + else if (v < -128) v = -128; + q[base + i] = v; + } + } + const packed = new Uint32Array(outDim * inDim / 4); + const u8 = new Uint8Array(q.buffer); + for (let w = 0; w < packed.length; w++) { + packed[w] = u8[w * 4] | u8[w * 4 + 1] << 8 | u8[w * 4 + 2] << 16 | u8[w * 4 + 3] << 24; + } + return { packed, scale, outDim, inDim }; +} +__name(quantizeInt8RowMajor, "quantizeInt8RowMajor"); +function quantizeInt4Group(f322, outDim, inDim, group = 128) { + const groupsPerRow = inDim / group; + const scale = new Float32Array(outDim * groupsPerRow); + const q = new Int8Array(outDim * inDim); + for (let o = 0; o < outDim; o++) { + for (let g = 0; g < groupsPerRow; g++) { + const base = o * inDim + g * group; + let amax = 0; + for (let i = 0; i < group; i++) { + const a = Math.abs(f322[base + i]); + if (a > amax) amax = a; + } + const s = amax > 0 ? amax / 7 : 1; + scale[o * groupsPerRow + g] = s; + const inv = 1 / s; + for (let i = 0; i < group; i++) { + let v = Math.round(f322[base + i] * inv); + if (v > 7) v = 7; + else if (v < -8) v = -8; + q[base + i] = v; + } + } + } + const packed = new Uint32Array(outDim * inDim / 8); + for (let w = 0; w < packed.length; w++) { + let acc = 0; + for (let j = 0; j < 8; j++) acc |= (q[w * 8 + j] & 15) << j * 4; + packed[w] = acc >>> 0; + } + return { packed, scale, groupsPerRow }; +} +__name(quantizeInt4Group, "quantizeInt4Group"); + +// src/qwgpu/model_uploader.js +var ModelUploader = class { + static { + __name(this, "ModelUploader"); + } + constructor({ schema, q, q4, bufs, uploadF32, uploadU32, groupSize = 128 }) { + this.schema = schema; + this.q = q; + this.q4 = q4; + this.bufs = bufs; + this.uploadF32 = uploadF32; + this.uploadU32 = uploadU32; + this.groupSize = groupSize; + this.seen = /* @__PURE__ */ new Set(); + } + visit({ name, shape, data }) { + const desc = this.schema.validateTensor(name, shape); + if (!desc) return; + if (this.seen.has(name)) throw new Error(`duplicate tensor ${name}`); + if (desc.quant === "int8") { + const { packed, scale } = quantizeInt8RowMajor(data, shape[0], shape[1]); + this.q[name] = { w: this.uploadU32(packed), scale: this.uploadF32(scale), N: shape[0], K: shape[1] }; + } else if (desc.quant === "int4") { + const { packed, scale, groupsPerRow } = quantizeInt4Group(data, shape[0], shape[1], this.groupSize); + this.q4[name] = { + w: this.uploadU32(packed), + scale: this.uploadF32(scale), + N: shape[0], + K: shape[1], + gpr: groupsPerRow, + desc + }; + } else if (desc.quant === "f32") { + this.bufs[name] = this.uploadF32(data); + } else { + throw new Error(`unsupported quant mode ${desc.quant} for ${name}`); + } + this.seen.add(name); + } + finalize() { + this.schema.assertComplete(this.seen); + } +}; + +// src/qwgpu/buffer_pool.js +var GPUBufferPool = class { + static { + __name(this, "GPUBufferPool"); + } + constructor(device, { cacheBindGroups = true } = {}) { + this.dev = device; + this.cacheBindGroups = cacheBindGroups; + this.uniformPool = []; + this.uniformIdx = 0; + this.staticUniforms = /* @__PURE__ */ new Map(); + this.bindGroups = /* @__PURE__ */ new Map(); + this.sensitiveBindGroups = /* @__PURE__ */ new Set(); + this.bufferIds = /* @__PURE__ */ new WeakMap(); + this.pipelineIds = /* @__PURE__ */ new WeakMap(); + this.nextBufferId = 1; + this.nextPipelineId = 1; + this._stats = this._emptyStats(); + } + /* + * TECHNIQUE: Bind group caching (opt-in per call site) + * Frequently reused (pipeline + buffer set) combinations are stored in a Map. + * Avoids repeated GPU bind group creation on the hot GEMV / attention paths. + * Sensitive / one-shot groups are deliberately not cached. + */ + _emptyStats() { + return { + buffersCreated: 0, + dynamicUniformWrites: 0, + staticUniformHits: 0, + staticUniformMisses: 0, + bindGroupHits: 0, + bindGroupMisses: 0, + uncachedBindGroups: 0 + }; + } + resetStats() { + this._stats = this._emptyStats(); + } + stats() { + return { + ...this._stats, + uniformPoolSize: this.uniformPool.length, + staticUniforms: this.staticUniforms.size, + bindGroups: this.bindGroups.size + }; + } + buffer(size, usage) { + this._stats.buffersCreated++; + return this.dev.createBuffer({ size, usage }); + } + uploadF32(arr, usage) { + const b = this.buffer(arr.byteLength, usage); + this.dev.queue.writeBuffer(b, 0, arr); + return b; + } + uploadU32(arr, usage) { + const b = this.buffer(arr.byteLength, usage); + this.dev.queue.writeBuffer(b, 0, arr); + return b; + } + dynamicUniform(arr, usage) { + let b = this.uniformPool[this.uniformIdx]; + if (!b) { + b = this.buffer(32, usage); + this.uniformPool[this.uniformIdx] = b; + } + this.uniformIdx++; + this._stats.dynamicUniformWrites++; + this.dev.queue.writeBuffer(b, 0, arr.buffer, arr.byteOffset, arr.byteLength); + return b; + } + resetUniforms() { + this.uniformIdx = 0; + } + staticUniform(key, arr, usage) { + let b = this.staticUniforms.get(key); + if (!b) { + this._stats.staticUniformMisses++; + b = this.buffer(32, usage); + this.dev.queue.writeBuffer(b, 0, arr.buffer, arr.byteOffset, arr.byteLength); + this.staticUniforms.set(key, b); + } else this._stats.staticUniformHits++; + return b; + } + idForBuffer(buffer) { + let id = this.bufferIds.get(buffer); + if (!id) { + id = this.nextBufferId++; + this.bufferIds.set(buffer, id); + } + return id; + } + idForPipeline(pipe) { + let id = this.pipelineIds.get(pipe); + if (!id) { + id = this.nextPipelineId++; + this.pipelineIds.set(pipe, id); + } + return id; + } + uncachedBindGroup(pipe, buffers) { + this._stats.uncachedBindGroups++; + return this.dev.createBindGroup({ + label: pipe.__name ? `${pipe.__name}:bg:${buffers.length}` : void 0, + layout: pipe.getBindGroupLayout(0), + entries: buffers.map((buffer, i) => ({ binding: i, resource: { buffer } })) + }); + } + cachedBindGroup(pipe, buffers, key, { sensitive = false } = {}) { + if (!this.cacheBindGroups || !key) return this.uncachedBindGroup(pipe, buffers); + const fullKey = `${this.idForPipeline(pipe)}:${key}:${buffers.map((b) => this.idForBuffer(b)).join(",")}`; + let bg = this.bindGroups.get(fullKey); + if (!bg) { + this._stats.bindGroupMisses++; + bg = this.uncachedBindGroup(pipe, buffers); + this.bindGroups.set(fullKey, bg); + if (sensitive) this.sensitiveBindGroups.add(fullKey); + } else this._stats.bindGroupHits++; + return bg; + } + clearSensitiveBindGroups() { + for (const key of this.sensitiveBindGroups) this.bindGroups.delete(key); + this.sensitiveBindGroups.clear(); + } +}; + +// src/qwgpu/runtime.js +var STORAGE = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC; +var UNIFORM = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST; +var QwenWGPU = class { + static { + __name(this, "QwenWGPU"); + } + // opts: { maxCtx, maxPrefillT, decodeBatchSize, samplingTopK } — context + // window + batched-prefill cap (default 8192 each; KV cache grows linearly). + constructor(device, cfg, opts = {}) { + this.dev = device; + this.cfg = cfg; + this.lora = null; + this.bufs = {}; + this.opts = opts; + this.features = this._normalizeFeatures(opts); + this.pool = new GPUBufferPool(device, { cacheBindGroups: opts.cacheBindGroups !== false }); + this._loraEpoch = 0; + this.lastDispatchCount = 0; + this.packedBytes = 0; + this.workgroupAutotunePromise = null; + this._argmaxReadBusy = false; + this._topKReadBusy = false; + } + _normalizeFeatures(opts = {}) { + const prefillAttention = opts.prefillAttention || "block"; + if (!["row", "block"].includes(prefillAttention)) + throw new Error(`unsupported prefillAttention ${prefillAttention}`); + return { + // fuseRMSNormQKVRoPE: fused RMSNorm + int4 QKV GEMV + RoPE for no-LoRA decode + // (one workgroup per (head,rot) pair; verified logitDiff 0 vs PyTorch ref). + // fuseQKV selects the alternate qkvGemv4 path and stays OFF by default since + // the fused-RMS path already covers the fast no-LoRA decode; LoRA layers are + // routed to the unfused gemv4x3 + ropeQK path automatically (see step()). + fuseQKV: opts.fuseQKV === true, + fuseRoPE: opts.fuseRoPE !== false, + fuseMLP: opts.fuseMLP !== false, + fuseResidual: opts.fuseResidual !== false, + prefillAttention, + prefillChunkSize: Math.max(0, opts.prefillChunkSize || 0), + actQuant: !!opts.actQuant, + // Default OFF: the GEMV4_QKV_ROPE_RMS kernel still computes zero outputs even + // with the corrected (totalPairs) dispatch — there is a deeper bug in the + // fused kernel itself. The unfused gemv4x3 + ropeQK decode is verified + // logitDiff 0 vs the PyTorch ref, so it stays the default until the fused + // kernel is debugged. The wrapper dispatch is now correct for that work. + fuseRMSNormQKVRoPE: opts.fuseRMSNormQKVRoPE === true, + pagedAttention: !!opts.pagedAttention + }; + } + setFeatureFlags(flags = {}) { + this.features = this._normalizeFeatures({ ...this.features, ...flags }); + this.pool.clearSensitiveBindGroups(); + } + featureFlags() { + return { ...this.features }; + } + // Phase 3 (f16): when shader-f16 is available we can switch hot kernels to f16 + // storage/compute for bandwidth wins. Stub for now; real kernel variants + selection + // will be added. Evaluation: compare f16 vs f32 logits within tolerance + bench speedup. + hasF16Compute() { + return !!this.hasF16; + } + setUseF16(v) { + this._useF16 = !!v && this.hasF16Compute(); + } + usingF16() { + return !!this._useF16; + } + // Phase 4: allow caller / autotuner to override workgroup size after build if desired. + // Note: affects *future* pipes / re-pipes; existing pipes keep their specialization. + setWorkgroupSize(wg) { + if (wg && wg > 0) this.workgroupSize = wg | 0; + } + // Basic load-time / on-demand workgroup autotuner (Phase 4). + // Tries a few WG sizes for simple override-supporting kernels (add / rms for now). + // Uses wall time + onSubmittedWorkDone for broad compatibility. + // Returns a map of best sizes; optionally hot-swaps the pipe for 'add'. + async autotuneWorkgroups(opts = {}) { + const iters = opts.iters || 6; + const cands = opts.candidates || [32, 64, 128, 256]; + const results = {}; + const useTS = this.hasTimestampQuery; + const timeKernel = /* @__PURE__ */ __name(async (spec, pipe, label) => { + const n = spec.n; + const a = this._buf(n * 4); + const g = this._buf(n * 4); + const y = this._buf(n * 4); + const buffers = spec.buffers(a, y, g); + const imm = spec.imm(n); + let gpuMs = 0; + let usedGPU = false; + if (useTS) { + const qs = this.dev.createQuerySet({ type: "timestamp", count: 2 }); + const resolveBuf = this._buf(16, GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC); + const readBuf = this._buf(16, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); + const tWall0 = typeof performance !== "undefined" ? performance.now() : Date.now(); + for (let i = 0; i < iters; i++) { + const enc = this.dev.createCommandEncoder(); + const bg = this._bg(pipe, buffers); + const p = enc.beginComputePass({ + timestampWrites: { + querySet: qs, + beginningOfPassWriteIndex: 0, + endOfPassWriteIndex: 1 + } + }); + p.setPipeline(pipe); + if (bg) p.setBindGroup(0, bg); + if (imm) p.setImmediates(0, imm); + p.dispatchWorkgroups(Math.ceil(n / (pipe.__wg || 256)), 1); + p.end(); + enc.resolveQuerySet(qs, 0, 2, resolveBuf, 0); + enc.copyBufferToBuffer(resolveBuf, 0, readBuf, 0, 16); + this.dev.queue.submit([enc.finish()]); + if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); + await readBuf.mapAsync(GPUMapMode.READ); + const t = new BigInt64Array(readBuf.getMappedRange()); + const us = Number(t[1] - t[0]) / 1e3; + gpuMs += us; + readBuf.unmap(); + } + const wallMs = (typeof performance !== "undefined" ? performance.now() : Date.now()) - tWall0; + resolveBuf.destroy?.(); + readBuf.destroy?.(); + qs.destroy?.(); + usedGPU = true; + a.destroy?.(); + g.destroy?.(); + y.destroy?.(); + return gpuMs / iters / 1e3; + } + const t0 = typeof performance !== "undefined" ? performance.now() : Date.now(); + for (let i = 0; i < iters; i++) { + const enc = this.dev.createCommandEncoder(); + const bg = this._bg(pipe, buffers); + this._dispatch(enc, pipe, bg, Math.ceil(n / (pipe.__wg || 256)), 1, label + ":bench", imm); + this.dev.queue.submit([enc.finish()]); + if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); + } + const ms = (typeof performance !== "undefined" ? performance.now() : Date.now()) - t0; + a.destroy?.(); + g.destroy?.(); + y.destroy?.(); + return ms / iters; + }, "timeKernel"); + const kernels = [ + { name: "add", src: ADD, n: 8192, buffers: /* @__PURE__ */ __name((a, y) => [a, y], "buffers"), imm: /* @__PURE__ */ __name((n) => new Uint32Array([n]), "imm") }, + { name: "rms", src: RMSNORM, n: 4096, buffers: /* @__PURE__ */ __name((a, y, g) => [a, g, y], "buffers"), imm: /* @__PURE__ */ __name((n) => new Float32Array([n, this.cfg.rmsNormEps]), "imm") }, + { name: "silu", src: SILUMUL, n: 8192, buffers: /* @__PURE__ */ __name((a, y) => [a, y], "buffers"), imm: /* @__PURE__ */ __name((n) => new Uint32Array([n]), "imm") } + ]; + for (const k of kernels) { + try { + let best = { wg: 256, ms: Infinity }; + for (const wg of cands) { + const p = this._pipe(k.src, `${k.name}:autotune:${wg}`, { WG: wg }); + p.__wg = wg; + const ms = await timeKernel(k, p, `${k.name}${wg}`); + results[`${k.name}:${wg}`] = ms; + if (ms < best.ms) best = { wg, ms }; + } + results[`best${k.name[0].toUpperCase()}${k.name.slice(1)}`] = best; + if (opts.apply && this.pipes[k.name]) { + this.pipes[k.name] = this._pipe(k.src, k.name, { WG: best.wg }); + this.pipes[k.name].__wg = best.wg; + } + } catch (e) { + results[`${k.name}Error`] = String(e); + } + } + this.bestWorkgroupSizes = { + add: results.bestAdd?.wg, + rms: results.bestRms?.wg, + silu: results.bestSilu?.wg, + source: useTS ? "gpu-ts" : "wall" + }; + console.log("[autotune] WG microbench results (ms/iter, source=" + (useTS ? "gpu-ts" : "wall") + "):", results); + return results; + } + _buf(size, usage = STORAGE) { + return this.pool.buffer(size, usage); + } + _f32(arr, usage = STORAGE) { + return this.pool.uploadF32(arr, usage); + } + _u32(arr) { + return this.pool.uploadU32(arr, STORAGE); + } + _uni(arr) { + return this.pool.dynamicUniform(arr, UNIFORM); + } + _staticUni(key, arr) { + return this.pool.staticUniform(key, arr, UNIFORM); + } + _resetUni() { + this.pool.resetUniforms(); + this.lastDispatchCount = 0; + } + _pipe(code, name, overrides = null) { + const processedCode = typeof code === "string" ? code.replaceAll("WG_SIZE", this.workgroupSize || 64) : code; + const m = this.dev.createShaderModule({ + label: name || void 0, + code: processedCode + }); + const comp = { module: m, entryPoint: "main" }; + if (overrides && typeof overrides === "object") comp.constants = overrides; + const pipe = this.dev.createComputePipeline({ + label: name ? `${name}-pipeline` : void 0, + layout: "auto", + compute: comp + }); + if (overrides?.WG) pipe.__wg = overrides.WG; + if (name) pipe.__name = name; + return pipe; + } + /* + * TECHNIQUE: Specialization via pipeline constants (overrides) + * Workgroup size and other small values are passed as pipeline-overridable + * constants instead of uniforms or JS branches. Allows the shader compiler + * to specialize the binary (better than runtime if). + */ + // `source` is a base URL string OR a reader { range, text } (e.g. hfReader/fileReader). + async build(source, onProgress = () => { + }) { + const shaderCompileStart = performance.now(); + const dev = this.dev, c = this.cfg; + this.CHUNK = 128; + this._initRuntimeOptions(); + this.maxCtx = this.opts.maxCtx || 8192; + this.maxPrefillT = Math.min(this.opts.maxPrefillT || 8192, this.maxCtx); + const isAppleSilicon = this.dev.limits.minStorageBufferOffsetAlignment === 4; + const isIntelArc = this.dev.limits.minStorageBufferOffsetAlignment === 256; + this.workgroupSize = isAppleSilicon || isIntelArc ? 32 : 64; + onProgress && onProgress(`workgroup size chosen: ${this.workgroupSize} (apple/intel bias toward 32)`, 0); + let hasDP4a = false; + if (typeof navigator !== "undefined" && navigator.gpu?.wgslLanguageFeatures?.has?.("packed_4x8_integer_dot_product")) { + dev.pushErrorScope("validation"); + try { + dev.createShaderModule({ + code: `enable packed_4x8_integer_dot_product; @compute @workgroup_size(1) fn main() {}` + }); + const error = await dev.popErrorScope(); + if (!error) { + hasDP4a = true; + } + } catch (e) { + await dev.popErrorScope(); + } + } + this.hasDP4a = hasDP4a; + const hasF16 = this.dev.features.has("shader-f16"); + this.hasF16 = hasF16; + this.hasTimestampQuery = this.dev.features.has("timestamp-query"); + this.pam = new PagedAttentionManager(this.maxCtx); + this.pipes = { + gemv: this._pipe(GEMV, "gemv"), + loraA: this._pipe(LORA_A, "loraA"), + loraABatch: this._pipe(LORA_A_BATCH, "loraABatch"), + loraBAdd: this._pipe(LORA_B_ADD, "loraBAdd"), + loraBAddT: this._pipe(LORA_B_ADD_T, "loraBAddT"), + rms: this._pipe(RMSNORM, "rms", { WG: this.workgroupSize || 256 }), + rmsF16: hasF16 ? this._pipe(RMSNORM_F16, "rmsF16", { WG: this.workgroupSize || 256 }) : null, + rope: this._pipe(ROPE, "rope"), + ropeF16: hasF16 ? this._pipe(ROPE_F16, "ropeF16") : null, + ropeQK: this._pipe(ROPE_QK, "ropeQK"), + ropeQKF16: hasF16 ? this._pipe(ROPE_QK_F16, "ropeQKF16") : null, + ropeT: this._pipe(ROPE_T, "ropeT"), + ropeTF16: hasF16 ? this._pipe(ROPE_T_F16, "ropeTF16") : null, + attnP: this._pipe(ATTN_PARTIAL, "attnP", { WG: 128 }), + attnPF16: hasF16 ? this._pipe(ATTN_PARTIAL_F16, "attnPF16", { WG: 128 }) : null, + attnC: this._pipe(ATTN_COMBINE, "attnC", { WG: 128 }), + attnCF16: hasF16 ? this._pipe(ATTN_COMBINE_F16, "attnCF16", { WG: 128 }) : null, + add: this._pipe(ADD, "add", { WG: this.workgroupSize || 256 }), + silu: this._pipe(SILUMUL, "silu", { WG: this.workgroupSize || 256 }), + addF16: hasF16 ? this._pipe(ADD_F16, "addF16", { WG: this.workgroupSize || 256 }) : null, + siluF16: hasF16 ? this._pipe(SILUMUL_F16, "siluF16", { WG: this.workgroupSize || 256 }) : null, + embed: this._pipe(EMBED, "embed"), + embedBuf: this._pipe(EMBED_BUF, "embedBuf"), + argmax: this._pipe(ARGMAX, "argmax"), + gemv4: this._pipe(GEMV4, "gemv4"), + gemv4Add: this._pipe(GEMV4_ADD, "gemv4Add"), + qkvGemv4: this._pipe(QKV_GEMV4, "qkvGemv4"), + gateUpSiluGemv4: this._pipe(GATE_UP_SILU_GEMV4, "gateUpSiluGemv4"), + topkSelect: this._pipe(TOPK_SELECT, "topkSelect"), + sampleTopK: this._pipe(SAMPLE_TOPK, "sampleTopK"), + gemm4: this._pipe(GEMM4, "gemm4"), + gemm4AddT: this._pipe(GEMM4_ADD_T, "gemm4AddT"), + rmsT: this._pipe(RMSNORM_T, "rmsT", { WG: this.workgroupSize || 256 }), + rmsTF16: hasF16 ? this._pipe(RMSNORM_T_F16, "rmsTF16", { WG: this.workgroupSize || 256 }) : null, + embedT: this._pipe(EMBED_T, "embedT"), + attnPrefill: this._pipe(ATTN_PREFILL, "attnPrefill"), + attnPrefillBlock: this._pipe(ATTN_PREFILL_BLOCK, "attnPrefillBlock"), + dynQuant: this._pipe(DYN_QUANT_X, "dynQuant"), + dynQuantT: this._pipe(DYN_QUANT_X_T, "dynQuantT"), + gemv4W4A8: this._pipe(GEMV4_W4A8(hasDP4a, this.workgroupSize), "gemv4W4A8"), + gemv4AddW4A8: this._pipe(GEMV4_ADD_W4A8(hasDP4a, this.workgroupSize), "gemv4AddW4A8"), + qkvGemv4W4A8: this._pipe(QKV_GEMV4_W4A8(hasDP4a, this.workgroupSize), "qkvGemv4W4A8"), + gateUpSiluGemv4W4A8: this._pipe(GATE_UP_SILU_GEMV4_W4A8(hasDP4a, this.workgroupSize), "gateUpSiluGemv4W4A8"), + gemm4W4A8: this._pipe(GEMM4_W4A8(hasDP4a), "gemm4W4A8"), + gemm4AddTW4A8: this._pipe(GEMM4_ADD_T_W4A8(hasDP4a), "gemm4AddTW4A8"), + rmsNormQkvRope: this._pipe(GEMV4_QKV_ROPE_RMS, "rmsNormQkvRope"), + writeKvPage: this._pipe(WRITE_KV_PAGE, "writeKvPage"), + writeKvPageBatch: this._pipe(WRITE_KV_PAGE_BATCH, "writeKvPageBatch"), + attnPartialPaged: this._pipe(ATTN_PARTIAL_PAGED, "attnPartialPaged"), + attnPrefillPaged: this._pipe(ATTN_PREFILL_PAGED, "attnPrefillPaged"), + attnPrefillBlockPaged: this._pipe(ATTN_PREFILL_BLOCK_PAGED, "attnPrefillBlockPaged") + }; + this.shaderCompileMs = performance.now() - shaderCompileStart; + if (hasF16) { + this.setUseF16(true); + onProgress("f16 compute enabled (add/silu/rms/rope/attn-partial/combine paths)", 0); + } + if (this.hasTimestampQuery) { + onProgress("timestamp-query available (precise GPU timing + autotune)", 0); + } + onProgress("streaming + quantizing weights", 0); + this.schema = createQwenSchema(c); + this.plan = createDispatchPlan(this.schema); + this.q = {}; + this.q4 = {}; + this.qkv = []; + this.gateUp = []; + const uploader = new ModelUploader({ + schema: this.schema, + q: this.q, + q4: this.q4, + bufs: this.bufs, + uploadF32: /* @__PURE__ */ __name((arr) => this._f32(arr), "uploadF32"), + uploadU32: /* @__PURE__ */ __name((arr) => this._u32(arr), "uploadU32") + }); + if (source === "mock") { + for (const name of this.schema.expectedNames) { + const desc = this.schema.tensors.find((t) => t.name === name); + const shape = desc.shape; + const numel = shape.reduce((a, b) => a * b, 1); + const type = desc.quant === "int8" ? "I8" : "F32"; + uploader.visit({ name, shape, data: new Uint8Array(numel * (type === "I8" ? 1 : 4)), type }); + } + } else { + await streamSafetensors(source, { + names: this.schema.expectedNames, + onProgress, + onTensor: /* @__PURE__ */ __name(async (tensor) => { + uploader.visit(tensor); + if (uploader.seen.size % 48 === 0) await new Promise((r) => setTimeout(r, 0)); + }, "onTensor") + }); + } + uploader.finalize(); + await this._buildPackedProjectionBuffers(); + this._buildRope(this.maxCtx); + this.kc = [], this.vc = []; + const kvSize = c.numKVHeads * this.maxCtx * c.headDim * 4; + for (let i = 0; i < c.numLayers; i++) { + this.kc.push(this._buf(kvSize)); + this.vc.push(this._buf(kvSize)); + } + const H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize; + const NSPLITMAX = Math.ceil(this.maxCtx / this.CHUNK); + this.s = { + hidden: this._buf(H * 4), + normed: this._buf(H * 4), + q: this._buf(qd * 4), + k: this._buf(kvd * 4), + v: this._buf(kvd * 4), + attn: this._buf(qd * 4), + tmp: this._buf(Math.max(qd, I) * 4), + tmp2: this._buf(I * 4), + logits: this._buf(c.vocabSize * 4), + dummy: this._buf(64), + loraD: this._buf(256 * 4), + loraD2: this._buf(256 * 4), + amax: this._buf(4), + pm: this._buf(c.numHeads * NSPLITMAX * 4), + pz: this._buf(c.numHeads * NSPLITMAX * 4), + po: this._buf(c.numHeads * NSPLITMAX * c.headDim * 4), + idsBuf: this._buf(this.decodeBatchCapacity * 4), + sampleIds: this._buf(this.maxSamplingTopK * 4), + sampleVals: this._buf(this.maxSamplingTopK * 4), + sampled: this._buf(4), + // single u32 chosen by GPU sampler (Phase 5) + x_q: this._buf(Math.max(qd, I) * 4), + scale_x: this._buf(256 * 4), + blockTableBuf: this._buf(this.pam.maxBlocksPerSeq * 4, STORAGE | GPUBufferUsage.COPY_DST) + }; + this.idsRead = this._buf(this.decodeBatchCapacity * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); + this.argmaxRead = this._buf(4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); + this.sampleIdsRead = this._buf(this.maxSamplingTopK * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); + this.sampleValsRead = this._buf(this.maxSamplingTopK * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); + this.sampledRead = this._buf(4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); + this.sT = null; + this.sTcap = 0; + this._initStaticUniforms(); + if (this.decodeBatchMode === "auto") { + onProgress("autotuning decode batch", 0.98); + await this.autotuneDecodeBatch(); + } + onProgress("ready", 1); + if (!this._didAutoWG) { + this._didAutoWG = true; + this.workgroupAutotunePromise = this.autotuneWorkgroups({ iters: 2, apply: true }).catch((e) => ({ + error: String(e) + })); + } + return this; + } + _initRuntimeOptions() { + const opts = this.opts; + this.decodeBatchMode = opts.decodeBatchSize === "auto" ? "auto" : "fixed"; + this.decodeBatchCandidates = (opts.decodeBatchCandidates || [1, 2, 4, 8, 16, 32]).map((x) => Math.max(1, Math.floor(Number(x) || 0))).filter(Boolean); + const requested = opts.decodeBatchSize === void 0 || opts.decodeBatchSize === "auto" ? 16 : Math.max(1, Math.floor(Number(opts.decodeBatchSize))); + this.maxDecodeBatchSize = Math.max( + 1, + Math.floor(Number(opts.maxDecodeBatchSize || Math.max(requested, ...this.decodeBatchCandidates, 16))) + ); + this.decodeBatchCapacity = Math.min(this.maxDecodeBatchSize, Math.max(requested, ...this.decodeBatchCandidates)); + this.MAXBATCH = Math.min(requested, this.decodeBatchCapacity); + this.decodeBatchWarmupTokens = Math.max(0, Math.floor(Number(opts.decodeBatchWarmupTokens ?? 4))); + this.decodeBatchWarmupSize = Math.min( + this.decodeBatchCapacity, + Math.max(1, Math.floor(Number(opts.decodeBatchWarmupSize ?? 4))) + ); + this.decodeBatchMaxLatencyMs = Number(opts.decodeBatchMaxLatencyMs ?? 250); + this.samplingTopK = Math.max(1, Math.floor(Number(opts.samplingTopK ?? 40))); + this.maxSamplingTopK = Math.max(this.samplingTopK, Math.floor(Number(opts.maxSamplingTopK ?? 64))); + this.decodeBatchTuning = { + selected: this.MAXBATCH, + candidates: [], + reason: this.decodeBatchMode === "auto" ? "pending" : "fixed" + }; + } + _buildRope(maxSeq) { + const { headDim, ropeTheta } = this.cfg; + const half = headDim / 2; + const cos = new Float32Array(maxSeq * headDim), sin = new Float32Array(maxSeq * headDim); + for (let p = 0; p < maxSeq; p++) + for (let i = 0; i < half; i++) { + const a = p / Math.pow(ropeTheta, 2 * i / headDim); + const cc = Math.cos(a), ss = Math.sin(a); + cos[p * headDim + i] = cc; + cos[p * headDim + half + i] = cc; + sin[p * headDim + i] = ss; + sin[p * headDim + half + i] = ss; + } + this.ropeCos = this._f32(cos); + this.ropeSin = this._f32(sin); + this._ropeRow = headDim * 4; + } + _initStaticUniforms() { + const c = this.cfg; + const rms = new ArrayBuffer(8); + const rmsDv = new DataView(rms); + rmsDv.setFloat32(0, c.hiddenSize, true); + rmsDv.setFloat32(4, c.rmsNormEps, true); + this.u = { + rmsHidden: this._staticUni(`rms:${c.hiddenSize}:${c.rmsNormEps}`, new Uint8Array(rms)), + addHidden: this._staticUni(`u32:${c.hiddenSize}`, new Uint32Array([c.hiddenSize])), + siluIntermediate: this._staticUni(`u32:${c.intermediateSize}`, new Uint32Array([c.intermediateSize])), + embedBuf: this._staticUni(`embedBuf:${c.hiddenSize}`, new Uint32Array([c.hiddenSize])), + argmax: this._staticUni(`argmax:${c.vocabSize}`, new Uint32Array([c.vocabSize])) + }; + } + async _buildPackedProjectionBuffers() { + const enc = this.dev.createCommandEncoder(); + const copy = /* @__PURE__ */ __name((src, dst, dstOffset, bytes) => enc.copyBufferToBuffer(src, 0, dst, dstOffset, bytes), "copy"); + this.packedBytes = 0; + for (const L of this.plan.layers) { + const q = this.q4[L.q.weight], k = this.q4[L.k.weight], v = this.q4[L.v.weight]; + if (q.K !== k.K || q.K !== v.K || q.gpr !== k.gpr || q.gpr !== v.gpr) + throw new Error(`layer ${L.index} qkv packing requires matching K/gpr`); + const totalN = q.N + k.N + v.N; + const wBytes = totalN * (q.K / 8) * 4; + const scaleBytes = totalN * q.gpr * 4; + const biasBytes = totalN * 4; + const w = this._buf(wBytes); + const scale = this._buf(scaleBytes); + const bias = this._buf(biasBytes); + enc.clearBuffer(bias); + let wOff = 0, sOff = 0, bOff = 0; + for (const part of [L.q, L.k, L.v]) { + const qq = this.q4[part.weight]; + const rowsW = qq.N * (qq.K / 8) * 4; + const rowsS = qq.N * qq.gpr * 4; + copy(qq.w, w, wOff, rowsW); + wOff += rowsW; + copy(qq.scale, scale, sOff, rowsS); + sOff += rowsS; + if (part.bias) copy(this.bufs[part.bias], bias, bOff, qq.N * 4); + bOff += qq.N * 4; + } + this.qkv[L.index] = { w, scale, bias, K: q.K, qN: q.N, kN: k.N, vN: v.N, totalN, gpr: q.gpr }; + this.packedBytes += wBytes + scaleBytes + biasBytes; + const gate = this.q4[L.gate.weight], up = this.q4[L.up.weight]; + if (gate.K !== up.K || gate.N !== up.N || gate.gpr !== up.gpr) + throw new Error(`layer ${L.index} gate/up packing requires matching shape`); + const guWBytes = (gate.N + up.N) * (gate.K / 8) * 4; + const guScaleBytes = (gate.N + up.N) * gate.gpr * 4; + const guW = this._buf(guWBytes); + const guScale = this._buf(guScaleBytes); + copy(gate.w, guW, 0, gate.N * (gate.K / 8) * 4); + copy(up.w, guW, gate.N * (gate.K / 8) * 4, up.N * (up.K / 8) * 4); + copy(gate.scale, guScale, 0, gate.N * gate.gpr * 4); + copy(up.scale, guScale, gate.N * gate.gpr * 4, up.N * up.gpr * 4); + this.gateUp[L.index] = { w: guW, scale: guScale, K: gate.K, N: gate.N, gpr: gate.gpr }; + this.packedBytes += guWBytes + guScaleBytes; + } + this.dev.queue.submit([enc.finish()]); + await this.dev.queue.onSubmittedWorkDone(); + } + memoryFootprintBytes() { + const c = this.cfg; + const kvBytes = c.numLayers * 2 * c.numKVHeads * this.maxCtx * c.headDim * 4; + const decodeScratchBytes = c.hiddenSize * 2 * 4 + (c.numHeads * c.headDim + 2 * c.numKVHeads * c.headDim + c.numHeads * c.headDim) * 4 + (Math.max(c.numHeads * c.headDim, c.intermediateSize) + c.intermediateSize + c.vocabSize) * 4; + const prefillScratchBytes = this.sTcap ? this.sTcap * (3 * c.hiddenSize + c.numHeads * c.headDim + 2 * c.numKVHeads * c.headDim + c.numHeads * c.headDim + 2 * c.intermediateSize) * 4 : 0; + return { kvBytes, decodeScratchBytes, prefillScratchBytes, packedBytes: this.packedBytes }; + } + _gemvMeta(q, biasBuf, mod) { + const gx = Math.min(q.N, 65535); + const bytes = new Uint8Array(32); + const dv = new DataView(bytes.buffer); + dv.setUint32(0, q.K, true); + dv.setUint32(4, q.N, true); + dv.setUint32(8, mod ? mod.rank : 0, true); + dv.setUint32(12, biasBuf ? 1 : 0, true); + dv.setUint32(16, mod ? 1 : 0, true); + dv.setUint32(20, gx, true); + dv.setFloat32(24, mod ? mod.scale : 0, true); + return { + gx, + gy: Math.ceil(q.N / gx), + bytes + }; + } + _gemv4Meta(q, biasBuf, mod) { + const gx = Math.min(q.N, 65535); + const bytes = new Uint8Array(32); + const dv = new DataView(bytes.buffer); + dv.setUint32(0, q.K, true); + dv.setUint32(4, q.N, true); + dv.setUint32(8, mod ? mod.rank : 0, true); + dv.setUint32(12, biasBuf ? 1 : 0, true); + dv.setUint32(16, mod ? 1 : 0, true); + dv.setUint32(20, gx, true); + dv.setFloat32(24, mod ? mod.scale : 0, true); + dv.setUint32(28, q.gpr, true); + return { + gx, + gy: Math.ceil(q.N / gx), + bytes + }; + } + setLora(adapter) { + this.lora = adapter; + this._loraEpoch++; + this.pool.clearSensitiveBindGroups(); + } + // {modules: {key:{A,B,rank,scale}}} A:[K][rank], B:[rank][N] f32 GPUBuffers + clearLora() { + this.lora = null; + this._loraEpoch++; + this.pool.clearSensitiveBindGroups(); + } + // Called after an in-place mutation of the active adapter's A/B buffers (e.g. an + // optimizer step during training). Bumps the LoRA epoch so cached bind groups that + // referenced the old contents are dropped and inference re-binds the mutated buffers. + invalidateLora() { + this._loraEpoch++; + this.pool.clearSensitiveBindGroups(); + } + _bg(pipe, buffers) { + return this.pool.uncachedBindGroup(pipe, buffers); + } + _bgCached(pipe, buffers, key, opts) { + return this.pool.cachedBindGroup(pipe, buffers, key, opts); + } + _dispatch(enc, pipe, bg, gx, gy = 1, cat, imm = null) { + this.lastDispatchCount++; + let ts; + if (this.prof && this.prof.idx < this.prof.cap) { + const i = this.prof.idx++; + this.prof.cats.push(cat || "misc"); + ts = { querySet: this.prof.qs, beginningOfPassWriteIndex: 2 * i, endOfPassWriteIndex: 2 * i + 1 }; + } + const p = enc.beginComputePass(ts ? { timestampWrites: ts } : void 0); + p.setPipeline(pipe); + if (bg) p.setBindGroup(0, bg); + if (imm) { + if (Array.isArray(imm)) { + let off = 0; + for (const part of imm) { + p.setImmediates(off, part); + off += part.byteLength || part.length * (part.BYTES_PER_ELEMENT || 4); + } + } else { + p.setImmediates(0, imm); + } + } + p.dispatchWorkgroups(gx, gy); + p.end(); + } + enableProf(cap = 700) { + this.prof = { + qs: this.dev.createQuerySet({ type: "timestamp", count: cap * 2 }), + cap, + idx: 0, + cats: [], + resolve: this._buf(cap * 16, GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC), + read: this._buf(cap * 16, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ) + }; + } + async profToken(id, pos) { + this._resetUni(); + this.prof.idx = 0; + this.prof.cats = []; + const enc = this.dev.createCommandEncoder(); + this.embedRow(enc, id); + this.step(enc, id, pos); + const n = this.prof.idx; + enc.resolveQuerySet(this.prof.qs, 0, n * 2, this.prof.resolve, 0); + enc.copyBufferToBuffer(this.prof.resolve, 0, this.prof.read, 0, n * 16); + this.dev.queue.submit([enc.finish()]); + await this.prof.read.mapAsync(GPUMapMode.READ); + const t = new BigInt64Array(this.prof.read.getMappedRange()); + const sums = {}; + for (let i = 0; i < n; i++) { + const us = Number(t[2 * i + 1] - t[2 * i]) / 1e3; + const c = this.prof.cats[i]; + sums[c] = (sums[c] || 0) + us; + } + this.prof.read.unmap(); + return sums; + } + poolStats() { + return this.pool.stats(); + } + // Phase 4 observability: best workgroup sizes chosen by autotune (or null if not run). + getBestWorkgroupSizes() { + return this.bestWorkgroupSizes ? { ...this.bestWorkgroupSizes } : null; + } + resetPoolStats() { + this.pool.resetStats(); + } + estimateKvCacheBytes() { + const c = this.cfg; + return c.numLayers * 2 * c.numKVHeads * this.maxCtx * c.headDim * 4; + } + estimatePrefillScratchBytes(T, loraRank = this._activeMaxLoraRank()) { + const c = this.cfg, H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize; + return T * H * 4 * 2 + T * qd * 4 * 2 + T * kvd * 4 * 2 + T * I * 4 * 2 + T * 4 + Math.max(1, T * Math.max(1, loraRank)) * 4; + } + greedyBatchSizeFor({ emitted = 0, remaining = Infinity, pos = 0 } = {}) { + const interactive = emitted < this.decodeBatchWarmupTokens ? this.decodeBatchWarmupSize : this.MAXBATCH; + return Math.max(0, Math.min(interactive, remaining, this.maxCtx - pos, this.decodeBatchCapacity)); + } + async _resetAutotuneDecodeState(tokens, seedTokenId = 0) { + const c = this.cfg, S = this.s, H = c.hiddenSize, hd = c.headDim, qd = c.numHeads * hd, kvd = c.numKVHeads * hd, I = c.intermediateSize; + const nsplitMax = Math.ceil(this.maxCtx / this.CHUNK); + const touchedTokens = Math.min(Math.max(0, Math.floor(tokens)), this.maxCtx); + const enc = this.dev.createCommandEncoder(); + const clear = /* @__PURE__ */ __name((buf, bytes) => { + if (bytes > 0) enc.clearBuffer(buf, 0, bytes); + }, "clear"); + clear(S.hidden, H * 4); + clear(S.normed, H * 4); + clear(S.q, qd * 4); + clear(S.k, kvd * 4); + clear(S.v, kvd * 4); + clear(S.attn, qd * 4); + clear(S.tmp, Math.max(qd, I) * 4); + clear(S.tmp2, I * 4); + clear(S.logits, c.vocabSize * 4); + clear(S.loraD, 256 * 4); + clear(S.idsBuf, this.decodeBatchCapacity * 4); + clear(S.pm, c.numHeads * nsplitMax * 4); + clear(S.pz, c.numHeads * nsplitMax * 4); + clear(S.po, c.numHeads * nsplitMax * hd * 4); + const kvBytes = touchedTokens * kvd * 4; + for (let i = 0; i < c.numLayers; i++) { + clear(this.kc[i], kvBytes); + clear(this.vc[i], kvBytes); + } + this.dev.queue.submit([enc.finish()]); + this.dev.queue.writeBuffer(S.amax, 0, new Uint32Array([seedTokenId])); + if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); + } + async autotuneDecodeBatch() { + const candidates = [...new Set(this.decodeBatchCandidates)].filter((k) => k >= 1 && k <= this.decodeBatchCapacity && k <= this.maxCtx).sort((a, b) => a - b); + const rows = []; + const resetTokens = candidates.length ? Math.max(...candidates) : 0; + let selected = candidates[0] ?? this.MAXBATCH, best = Infinity; + try { + for (const k of candidates) { + await this._resetAutotuneDecodeState(resetTokens); + const t0 = performance.now(); + await this.decodeGreedyBatch(0, k); + const ms = performance.now() - t0; + const msPerToken = ms / k; + rows.push({ k, ms, msPerToken }); + const latencyOk = !Number.isFinite(this.decodeBatchMaxLatencyMs) || ms <= this.decodeBatchMaxLatencyMs; + if (latencyOk && msPerToken < best) { + best = msPerToken; + selected = k; + } + } + if (!rows.some((r) => r.k === selected) && rows.length) + selected = rows.reduce((a, b) => a.msPerToken <= b.msPerToken ? a : b).k; + this.MAXBATCH = selected; + this.decodeBatchTuning = { + selected, + candidates: rows, + reason: "auto wall-clock decodeGreedyBatch with reset state" + }; + } catch (e) { + this.decodeBatchTuning = { selected: this.MAXBATCH, candidates: rows, reason: `auto failed: ${e.message}` }; + } finally { + if (resetTokens > 0) { + try { + await this._resetAutotuneDecodeState(resetTokens); + } catch { + } + } + } + return this.decodeBatchTuning; + } + // y = int8-GEMV(x, q) [+bias] [+lora]. q={w,scale,N,K}. moduleKey for LoRA lookup. + gemv(enc, xBuf, q, yBuf, biasBuf, moduleKey) { + const mod = this.lora?.modules?.[moduleKey]; + if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); + const meta = this._gemvMeta(q, biasBuf, mod); + const key = `gemv:${moduleKey || "base"}:${q.K}:${q.N}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; + const bg = this._bgCached( + this.pipes.gemv, + [xBuf, q.w, q.scale, biasBuf || this.s.dummy, this.s.loraD, mod ? mod.B : this.s.dummy, yBuf], + key, + { sensitive: !!mod } + ); + this._dispatch(enc, this.pipes.gemv, bg, meta.gx, meta.gy, `gemv:${q.N}x${q.K}`, meta.bytes); + } + gemv4(enc, xBuf, q, yBuf, biasBuf, moduleKey) { + const mod = this.lora?.modules?.[moduleKey]; + if (this.debugCapture) console.log("VWG gemv4: " + moduleKey + " mod=" + !!mod); + if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); + const meta = this._gemv4Meta(q, biasBuf, mod); + const key = `gemv4:${moduleKey || "base"}:${q.K}:${q.N}:${q.gpr}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; + const bg = this._bgCached( + this.pipes.gemv4, + [xBuf, q.w, q.scale, biasBuf || this.s.dummy, this.s.loraD, mod ? mod.B : this.s.dummy, yBuf], + key, + { sensitive: !!mod } + ); + this._dispatch(enc, this.pipes.gemv4, bg, meta.gx, meta.gy, `g4:${q.N}x${q.K}`, meta.bytes); + if (mod) { + if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj" && this.debugStep < this.debugT) { + enc.copyBufferToBuffer(yBuf, 0, this.debugBufs.ySeq, this.debugStep * q.N * 4, q.N * 4); + this.debugStep++; + } + } + } + _loraA(enc, xBuf, q, mod, dBuf, moduleKey, label = "loraA") { + const imm = new Uint32Array([q.K, mod.rank]); + this._dispatch( + enc, + this.pipes.loraA, + this._bgCached(this.pipes.loraA, [xBuf, mod.A, dBuf], `${label}:${moduleKey}:${this._loraEpoch}`, { + sensitive: true + }), + mod.rank, + 1, + label, + imm + ); + if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj" && this.debugStep < this.debugT) { + enc.copyBufferToBuffer(xBuf, 0, this.debugBufs.xSeq, this.debugStep * q.K * 4, q.K * 4); + enc.copyBufferToBuffer(dBuf, 0, this.debugBufs.dSeq, this.debugStep * mod.rank * 4, mod.rank * 4); + } + } + _loraBAdd(enc, yBuf, q, mod, dBuf, moduleKey) { + const meta = new ArrayBuffer(32); + const dv = new DataView(meta); + dv.setUint32(0, q.N, true); + dv.setUint32(4, mod.rank, true); + dv.setFloat32(16, mod.scale, true); + const bg = this._bgCached( + this.pipes.loraBAdd, + [dBuf, mod.B, yBuf], + `loraBAdd:${moduleKey}:${this._loraEpoch}`, + { sensitive: true } + ); + this._dispatch(enc, this.pipes.loraBAdd, bg, Math.ceil(q.N / 256), 1, "loraB", new Uint8Array(meta)); + if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj" && this.debugStep < this.debugT) { + enc.copyBufferToBuffer(yBuf, 0, this.debugBufs.ySeq, this.debugStep * q.N * 4, q.N * 4); + this.debugStep++; + } + } + gemv4Add(enc, xBuf, q, yBuf, biasBuf, moduleKey) { + const mod = this.lora?.modules?.[moduleKey]; + if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); + const meta = this._gemv4Meta(q, biasBuf, mod); + const key = `gemv4add:${moduleKey || "base"}:${q.K}:${q.N}:${q.gpr}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; + const bg = this._bgCached( + this.pipes.gemv4Add, + [xBuf, q.w, q.scale, biasBuf || this.s.dummy, this.s.loraD, mod ? mod.B : this.s.dummy, yBuf], + key, + { sensitive: !!mod } + ); + this._dispatch(enc, this.pipes.gemv4Add, bg, meta.gx, meta.gy, `g4add:${q.N}x${q.K}`, meta.bytes); + } + dynQuant(enc, xBuf, x_qBuf, scale_xBuf, K) { + const numGroups = Math.ceil(K / 128); + const imm = new Uint32Array([K]); + const bg = this._bg(this.pipes.dynQuant, [xBuf, x_qBuf, scale_xBuf]); + this._dispatch(enc, this.pipes.dynQuant, bg, numGroups, 1, "dynQuant", imm); + } + dynQuantT(enc, xBuf, x_qBuf, scale_xBuf, K, T) { + const numGroups = Math.ceil(K / 128); + const imm = new Uint32Array([K, T]); + const bg = this._bg(this.pipes.dynQuantT, [xBuf, x_qBuf, scale_xBuf]); + this._dispatch(enc, this.pipes.dynQuantT, bg, numGroups, T, "dynQuantT", imm); + } + gemv4W4A8(enc, xBuf, x_qBuf, scale_xBuf, q, yBuf, biasBuf, moduleKey) { + const mod = this.lora?.modules?.[moduleKey]; + if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); + const meta = this._gemv4Meta(q, biasBuf, mod); + const key = `gemv4_w4a8:${moduleKey || "base"}:${q.K}:${q.N}:${q.gpr}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; + const bg = this._bgCached( + this.pipes.gemv4W4A8, + [ + x_qBuf, + scale_xBuf, + q.w, + q.scale, + biasBuf || this.s.dummy, + this.s.loraD, + mod ? mod.B : this.s.dummy, + yBuf + ], + key, + { sensitive: !!mod } + ); + this._dispatch(enc, this.pipes.gemv4W4A8, bg, meta.gx, meta.gy, `g4w4a8:${q.N}x${q.K}`, meta.bytes); + } + gemv4AddW4A8(enc, xBuf, x_qBuf, scale_xBuf, q, yBuf, biasBuf, moduleKey) { + const mod = this.lora?.modules?.[moduleKey]; + if (mod) this._loraA(enc, xBuf, q, mod, this.s.loraD, moduleKey); + const meta = this._gemv4Meta(q, biasBuf, mod); + const key = `gemv4add_w4a8:${moduleKey || "base"}:${q.K}:${q.N}:${q.gpr}:${biasBuf ? 1 : 0}:${mod ? this._loraEpoch : 0}`; + const bg = this._bgCached( + this.pipes.gemv4AddW4A8, + [ + x_qBuf, + scale_xBuf, + q.w, + q.scale, + biasBuf || this.s.dummy, + this.s.loraD, + mod ? mod.B : this.s.dummy, + yBuf + ], + key, + { sensitive: !!mod } + ); + this._dispatch(enc, this.pipes.gemv4AddW4A8, bg, meta.gx, meta.gy, `g4addw4a8:${q.N}x${q.K}`, meta.bytes); + } + qkvGemv4W4A8(enc, xBuf, x_qBuf, scale_xBuf, packed, qBuf, kBuf, vBuf, L) { + const gx = Math.min(packed.totalN, 65535); + const imm = new Uint32Array([packed.K, packed.totalN, packed.qN, packed.kN, packed.vN, packed.gpr, gx, 0]); + const bg = this._bgCached( + this.pipes.qkvGemv4W4A8, + [x_qBuf, scale_xBuf, packed.w, packed.scale, packed.bias, qBuf, kBuf, vBuf], + `qkv_w4a8:${L.index}`, + { sensitive: false } + ); + this._dispatch( + enc, + this.pipes.qkvGemv4W4A8, + bg, + gx, + Math.ceil(packed.totalN / gx), + `qkvw4a8:${packed.totalN}x${packed.K}`, + imm + ); + for (const [part, out] of [ + [L.q, qBuf], + [L.k, kBuf], + [L.v, vBuf] + ]) { + const mod = this.lora?.modules?.[part.loraKey]; + if (!mod) continue; + const q = this.q4[part.weight]; + this._loraA(enc, xBuf, q, mod, this.s.loraD, part.loraKey); + this._loraBAdd(enc, out, q, mod, this.s.loraD, part.loraKey); + } + } + _gateUpImmediate(packed, gx, gateMod, upMod) { + const imm = new Uint32Array(12); + imm.set([ + packed.K, + packed.N, + packed.gpr, + gx, + gateMod ? gateMod.rank : 0, + upMod ? upMod.rank : 0, + gateMod ? 1 : 0, + upMod ? 1 : 0 + ]); + const f322 = new Float32Array(imm.buffer); + f322[8] = gateMod ? gateMod.scale : 0; + f322[9] = upMod ? upMod.scale : 0; + return imm; + } + gateUpSiluGemv4W4A8(enc, xBuf, x_qBuf, scale_xBuf, packed, yBuf, L) { + const gate = this.q4[L.gate.weight], up = this.q4[L.up.weight]; + const gateMod = this.lora?.modules?.[L.gate.loraKey]; + const upMod = this.lora?.modules?.[L.up.loraKey]; + if (gateMod) this._loraA(enc, xBuf, gate, gateMod, this.s.loraD, L.gate.loraKey, "loraA:gate"); + if (upMod) this._loraA(enc, xBuf, up, upMod, this.s.loraD2, L.up.loraKey, "loraA:up"); + const gx = Math.min(packed.N, 65535); + const imm = this._gateUpImmediate(packed, gx, gateMod, upMod); + const bg = this._bgCached( + this.pipes.gateUpSiluGemv4W4A8, + [ + x_qBuf, + scale_xBuf, + packed.w, + packed.scale, + yBuf, + this.s.loraD, + gateMod ? gateMod.B : this.s.dummy, + this.s.loraD2, + upMod ? upMod.B : this.s.dummy + ], + `gu_w4a8:${L.index}:${this._loraEpoch}:${gateMod ? 1 : 0}:${upMod ? 1 : 0}`, + { sensitive: !!(gateMod || upMod) } + ); + this._dispatch( + enc, + this.pipes.gateUpSiluGemv4W4A8, + bg, + gx, + Math.ceil(packed.N / gx), + `guw4a8:${packed.N}x${packed.K}`, + imm + ); + } + gemm4W4A8(enc, aBuf, a_qBuf, scale_xBuf, q, yBuf, T, biasBuf, moduleKey) { + const imm = new Uint32Array([q.K, q.N, T, q.gpr, biasBuf ? 1 : 0, 0, 0, 0]); + const bg = this._bg(this.pipes.gemm4W4A8, [a_qBuf, scale_xBuf, q.w, q.scale, biasBuf || this.s.dummy, yBuf]); + this._dispatch(enc, this.pipes.gemm4W4A8, bg, Math.ceil(q.N / 64), Math.ceil(T / 16), "gemm4W4A8", imm); + const mod = this.lora?.modules?.[moduleKey]; + if (mod) this.loraBatchDelta(enc, aBuf, yBuf, q, T, mod, moduleKey); + } + gemm4AddTW4A8(enc, aBuf, a_qBuf, scale_xBuf, q, yBuf, T, biasBuf, moduleKey) { + const imm = new Uint32Array([q.K, q.N, T, q.gpr, biasBuf ? 1 : 0, 0, 0, 0]); + const bg = this._bg(this.pipes.gemm4AddTW4A8, [ + a_qBuf, + scale_xBuf, + q.w, + q.scale, + biasBuf || this.s.dummy, + yBuf + ]); + this._dispatch(enc, this.pipes.gemm4AddTW4A8, bg, Math.ceil(q.N / 64), Math.ceil(T / 16), "gemm4AddTW4A8", imm); + const mod = this.lora?.modules?.[moduleKey]; + if (mod) this.loraBatchDelta(enc, aBuf, yBuf, q, T, mod, moduleKey); + } + // Fused decode: RMSNorm + int4 QKV GEMV + RoPE in one dispatch. The kernel + // assigns ONE workgroup per (head, rotation) pair, so it must be launched with + // totalPairs = (qN+kN+vN)/2 workgroups and the matching grid width — the prior + // `20`-workgroup launch (+ element-count meta) left most Q/K/V outputs unwritten + // and produced garbage tokens. The kernel normalizes x on the fly and has no + // `normed` output, so this path is for the NO-LoRA case only; callers must route + // LoRA-bearing layers to the unfused gemv4x3 path (which can add the adapter). + rmsNormQkvRope(enc, xBuf, layerIndex, pos) { + const c = this.cfg, L = this.plan.layers[layerIndex]; + const packed = this.qkv[L.index]; + const qPairs = packed.qN / 2, kPairs = packed.kN / 2, vPairs = packed.vN / 2; + const totalPairs = qPairs + kPairs + vPairs; + const gx = Math.min(totalPairs, 65535); + const meta = new Uint32Array([ + packed.K, + totalPairs, + qPairs, + kPairs, + vPairs, + packed.gpr, + gx, + pos, + c.headDim, + ...new Uint32Array(new Float32Array([c.rmsNormEps, packed.qN, packed.kN]).buffer) + ]); + const bg = this._bg( + this.pipes.rmsNormQkvRope, + [ + xBuf, + this.bufs[L.inputNorm], + packed.w, + packed.scale, + packed.bias, + this.ropeCos, + this.ropeSin, + this.s.q, + this.s.k, + this.s.v + ] + ); + this._dispatch(enc, this.pipes.rmsNormQkvRope, bg, gx, Math.ceil(totalPairs / gx), "rmsNormQkvRope", meta); + } + writeKvPage(enc, kBuf, vBuf, kcBuf, vcBuf, pos, layerIndex) { + const c = this.cfg; + const kvd = c.numKVHeads * c.headDim; + this.pam.ensureBlocks(0, pos + 1); + const btArr = this.pam.getBlockTableArray(0); + this.dev.queue.writeBuffer(this.s.blockTableBuf, 0, btArr); + const meta = new Uint32Array([pos, 0, this.pam.maxBlocksPerSeq, kvd]); + const bg = this._bg(this.pipes.writeKvPage, [kBuf, vBuf, kcBuf, vcBuf, this.s.blockTableBuf]); + this._dispatch(enc, this.pipes.writeKvPage, bg, Math.ceil(kvd / 256), 1, "writeKvPage", meta); + } + writeKvPageBatch(enc, kBuf, vBuf, kcBuf, vcBuf, T, off, layerIndex) { + const c = this.cfg; + const kvd = c.numKVHeads * c.headDim; + this.pam.ensureBlocks(0, off + T); + const btArr = this.pam.getBlockTableArray(0); + this.dev.queue.writeBuffer(this.s.blockTableBuf, 0, btArr); + const meta = new Uint32Array([T, 0, this.pam.maxBlocksPerSeq, kvd, off]); + const bg = this._bg(this.pipes.writeKvPageBatch, [kBuf, vBuf, kcBuf, vcBuf, this.s.blockTableBuf]); + this._dispatch(enc, this.pipes.writeKvPageBatch, bg, Math.ceil(T * kvd / 256), 1, "writeKvPageBatch", meta); + } + attnPaged(enc, qBuf, kc, vc, oBuf, ctx) { + const c = this.cfg, S = this.s; + const nsplit = Math.ceil(ctx / this.CHUNK); + const bgP = this._bg(this.pipes.attnPartialPaged, [ + qBuf, + kc, + vc, + S.pm, + S.pz, + S.po, + S.blockTableBuf + ]); + const immP = new Uint32Array([c.numHeads, c.numKVHeads, ctx, c.headDim, nsplit, this.CHUNK, 0, this.pam.maxBlocksPerSeq]); + this._dispatch(enc, this.pipes.attnPartialPaged, bgP, c.numHeads, nsplit, "attnP_paged", immP); + const useF16C = this.usingF16() && this.pipes.attnCF16; + const pipeC = useF16C ? this.pipes.attnCF16 : this.pipes.attnC; + const bgC = this._bg(pipeC, [ + S.pm, + S.pz, + S.po, + oBuf + ]); + const immC = new Uint32Array([c.numHeads, c.headDim, nsplit, 0]); + this._dispatch(enc, pipeC, bgC, c.numHeads, 1, useF16C ? "attnCF16" : "attnC", immC); + } + attnPrefillPaged(enc, qBuf, kc, vc, oBuf, T, qStart = 0, ctx = T) { + const c = this.cfg; + if (this.features.prefillAttention === "block" || qStart !== 0 || ctx !== T) { + const imm = new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, T, qStart, ctx, 0, this.pam.maxBlocksPerSeq]); + this._dispatch( + enc, + this.pipes.attnPrefillBlockPaged, + this._bg(this.pipes.attnPrefillBlockPaged, [qBuf, kc, vc, oBuf, this.s.blockTableBuf]), + c.numHeads, + Math.ceil(T / 4), + "attnPrefillBlockPaged", + imm + ); + } else { + const imm = new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, T, 0, this.pam.maxBlocksPerSeq, 0, 0]); + this._dispatch( + enc, + this.pipes.attnPrefillPaged, + this._bg(this.pipes.attnPrefillPaged, [ + qBuf, + kc, + vc, + oBuf, + this.s.blockTableBuf + ]), + c.numHeads, + T, + "attnPrefillPaged", + imm + ); + } + } + qkvGemv4(enc, xBuf, packed, qBuf, kBuf, vBuf, L) { + const gx = Math.min(packed.totalN, 65535); + const imm = new Uint32Array([packed.K, packed.totalN, packed.qN, packed.kN, packed.vN, packed.gpr, gx, 0]); + const bg = this._bgCached( + this.pipes.qkvGemv4, + [xBuf, packed.w, packed.scale, packed.bias, qBuf, kBuf, vBuf], + `qkv:${L.index}`, + { sensitive: false } + ); + this._dispatch(enc, this.pipes.qkvGemv4, bg, gx, Math.ceil(packed.totalN / gx), `qkv:${packed.totalN}x${packed.K}`, imm); + for (const [part, out] of [ + [L.q, qBuf], + [L.k, kBuf], + [L.v, vBuf] + ]) { + const mod = this.lora?.modules?.[part.loraKey]; + if (!mod) continue; + const q = this.q4[part.weight]; + this._loraA(enc, xBuf, q, mod, this.s.loraD, part.loraKey); + this._loraBAdd(enc, out, q, mod, this.s.loraD, part.loraKey); + } + } + fusedRmsQkvRope(enc, hiddenBuf, inputNormBuf, packed, qBuf, kBuf, vBuf, pos, L) { + const qPairs = packed.qN / 2; + const kPairs = packed.kN / 2; + const vPairs = packed.vN / 2; + const totalPairs = qPairs + kPairs + vPairs; + const gx = Math.min(totalPairs, 65535); + const meta = new Uint32Array([ + packed.K, + totalPairs, + qPairs, + kPairs, + vPairs, + packed.gpr, + gx, + pos, + this.cfg.headDim, + ...new Uint32Array(new Float32Array([this.cfg.rmsNormEps, packed.qN, packed.kN]).buffer) + ]); + const bg = this._bg( + this.pipes.rmsNormQkvRope, + [ + hiddenBuf, + inputNormBuf, + packed.w, + packed.scale, + packed.bias, + this.ropeCos, + this.ropeSin, + qBuf, + kBuf, + vBuf + ] + ); + this._dispatch( + enc, + this.pipes.rmsNormQkvRope, + bg, + gx, + Math.ceil(totalPairs / gx), + `fusedQkvRope:${totalPairs}x${packed.K}`, + meta + ); + } + gateUpSiluGemv4(enc, xBuf, packed, yBuf, L) { + const gate = this.q4[L.gate.weight], up = this.q4[L.up.weight]; + const gateMod = this.lora?.modules?.[L.gate.loraKey]; + const upMod = this.lora?.modules?.[L.up.loraKey]; + if (gateMod) this._loraA(enc, xBuf, gate, gateMod, this.s.loraD, L.gate.loraKey, "loraA:gate"); + if (upMod) this._loraA(enc, xBuf, up, upMod, this.s.loraD2, L.up.loraKey, "loraA:up"); + const gx = Math.min(packed.N, 65535); + const imm = this._gateUpImmediate(packed, gx, gateMod, upMod); + const bg = this._bgCached( + this.pipes.gateUpSiluGemv4, + [ + xBuf, + packed.w, + packed.scale, + yBuf, + this.s.loraD, + gateMod ? gateMod.B : this.s.dummy, + this.s.loraD2, + upMod ? upMod.B : this.s.dummy + ], + `gu:${L.index}:${this._loraEpoch}:${gateMod ? 1 : 0}:${upMod ? 1 : 0}`, + { sensitive: !!(gateMod || upMod) } + ); + this._dispatch(enc, this.pipes.gateUpSiluGemv4, bg, gx, Math.ceil(packed.N / gx), `gu:${packed.N}x${packed.K}`, imm); + } + rms(enc, xBuf, gBuf, yBuf, K) { + const imm = new Float32Array([K, this.cfg.rmsNormEps]); + const useF16 = this.usingF16() && this.pipes.rmsF16; + const pipe = useF16 ? this.pipes.rmsF16 : this.pipes.rms; + const key = `rms:${K}${useF16 ? ":f16" : ""}`; + this._dispatch(enc, pipe, this._bgCached(pipe, [xBuf, gBuf, yBuf], key), 1, 1, useF16 ? "rmsF16" : "rms", imm); + } + rope(enc, xBuf, pos, nHeads) { + const useF16 = this.usingF16() && this.pipes.ropeF16; + const pipe = useF16 ? this.pipes.ropeF16 : this.pipes.rope; + this._dispatch( + enc, + pipe, + this._bg(pipe, [ + xBuf, + this.ropeCos, + this.ropeSin + ]), + Math.ceil(nHeads * (this.cfg.headDim / 2) / 256), + 1, + useF16 ? "ropeF16" : "rope", + new Uint32Array([nHeads, this.cfg.headDim, pos]) + ); + } + ropeQK(enc, qBuf, kBuf, pos) { + const c = this.cfg; + const pairs = (c.numHeads + c.numKVHeads) * (c.headDim / 2); + const useF16 = this.usingF16() && this.pipes.ropeQKF16; + const pipe = useF16 ? this.pipes.ropeQKF16 : this.pipes.ropeQK; + this._dispatch( + enc, + pipe, + this._bg(pipe, [ + qBuf, + kBuf, + this.ropeCos, + this.ropeSin + ]), + Math.ceil(pairs / 256), + 1, + useF16 ? "ropeQKF16" : "ropeQK", + new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, pos]) + ); + } + attn(enc, qBuf, kc, vc, oBuf, ctx) { + const c = this.cfg, S = this.s; + const nsplit = Math.ceil(ctx / this.CHUNK); + const useF16P = this.usingF16() && this.pipes.attnPF16; + const pipeP = useF16P ? this.pipes.attnPF16 : this.pipes.attnP; + const bgP = this._bg(pipeP, [ + qBuf, + kc, + vc, + S.pm, + S.pz, + S.po + ]); + const immP = new Uint32Array([c.numHeads, c.numKVHeads, ctx, c.headDim, nsplit, this.CHUNK]); + this._dispatch(enc, pipeP, bgP, c.numHeads, nsplit, useF16P ? "attnPF16" : "attnP", immP); + const useF16C = this.usingF16() && this.pipes.attnCF16; + const pipeC = useF16C ? this.pipes.attnCF16 : this.pipes.attnC; + const bgC = this._bg(pipeC, [ + S.pm, + S.pz, + S.po, + oBuf + ]); + const immC = new Uint32Array([c.numHeads, c.headDim, nsplit, 0]); + this._dispatch(enc, pipeC, bgC, c.numHeads, 1, useF16C ? "attnCF16" : "attnC", immC); + } + // Decode one token at absolute position `pos`. Writes logits to s.logits. Returns nothing. + step(enc, tokenId, pos) { + const c = this.cfg, S = this.s, hd = c.headDim, kvd = c.numKVHeads * hd; + for (let i = 0; i < c.numLayers; i++) { + const L = this.plan.layers[i]; + const hasQkvLora = this.lora && (this.lora.modules[L.q.loraKey] || this.lora.modules[L.k.loraKey] || this.lora.modules[L.v.loraKey]); + if (this.features.fuseRMSNormQKVRoPE && !hasQkvLora && !this.features.actQuant) { + this.rmsNormQkvRope(enc, S.hidden, i, pos); + } else { + this.rms(enc, S.hidden, this.bufs[L.inputNorm], S.normed, c.hiddenSize); + if (this.features.actQuant) { + this.dynQuant(enc, S.normed, S.x_q, S.scale_x, c.hiddenSize); + this.qkvGemv4W4A8(enc, S.normed, S.x_q, S.scale_x, this.qkv[L.index], S.q, S.k, S.v, L); + } else { + if (!hasQkvLora && this.features.fuseQKV) { + this.fusedRmsQkvRope(enc, S.hidden, this.bufs[L.inputNorm], this.qkv[L.index], S.q, S.k, S.v, pos, L); + } else if (this.features.fuseQKV) { + this.qkvGemv4(enc, S.normed, this.qkv[L.index], S.q, S.k, S.v, L); + if (this.features.fuseRoPE) this.ropeQK(enc, S.q, S.k, pos); + else { + this.rope(enc, S.q, pos, c.numHeads); + this.rope(enc, S.k, pos, c.numKVHeads); + } + } else { + this.gemv4(enc, S.normed, this.q4[L.q.weight], S.q, this.bufs[L.q.bias], L.q.loraKey); + this.gemv4(enc, S.normed, this.q4[L.k.weight], S.k, this.bufs[L.k.bias], L.k.loraKey); + this.gemv4(enc, S.normed, this.q4[L.v.weight], S.v, this.bufs[L.v.bias], L.v.loraKey); + if (this.features.fuseRoPE) this.ropeQK(enc, S.q, S.k, pos); + else { + this.rope(enc, S.q, pos, c.numHeads); + this.rope(enc, S.k, pos, c.numKVHeads); + } + } + } + } + if (this.features.pagedAttention) { + this.writeKvPage(enc, S.k, S.v, this.kc[i], this.vc[i], pos, i); + } else { + enc.copyBufferToBuffer(S.k, 0, this.kc[i], pos * kvd * 4, kvd * 4); + enc.copyBufferToBuffer(S.v, 0, this.vc[i], pos * kvd * 4, kvd * 4); + } + if (this.features.pagedAttention) { + this.attnPaged(enc, S.q, this.kc[i], this.vc[i], S.attn, pos + 1); + } else { + this.attn(enc, S.q, this.kc[i], this.vc[i], S.attn, pos + 1); + } + if (this.features.actQuant) { + this.dynQuant(enc, S.attn, S.x_q, S.scale_x, c.hiddenSize); + if (this.features.fuseResidual) { + this.gemv4AddW4A8(enc, S.attn, S.x_q, S.scale_x, this.q4[L.o.weight], S.hidden, null, L.o.loraKey); + } else { + this.gemv4W4A8(enc, S.attn, S.x_q, S.scale_x, this.q4[L.o.weight], S.tmp, null, L.o.loraKey); + this._addInto(enc, S.hidden, S.tmp, c.hiddenSize); + } + } else { + if (this.features.fuseResidual) this.gemv4Add(enc, S.attn, this.q4[L.o.weight], S.hidden, null, L.o.loraKey); + else { + this.gemv4(enc, S.attn, this.q4[L.o.weight], S.tmp, null, L.o.loraKey); + this._addInto(enc, S.hidden, S.tmp, c.hiddenSize); + } + } + this.rms(enc, S.hidden, this.bufs[L.postAttentionNorm], S.normed, c.hiddenSize); + if (this.features.actQuant) { + this.dynQuant(enc, S.normed, S.x_q, S.scale_x, c.hiddenSize); + this.gateUpSiluGemv4W4A8(enc, S.normed, S.x_q, S.scale_x, this.gateUp[L.index], S.tmp, L); + } else { + if (this.features.fuseMLP) { + this.gateUpSiluGemv4(enc, S.normed, this.gateUp[L.index], S.tmp, L); + } else { + this.gemv4(enc, S.normed, this.q4[L.gate.weight], S.tmp, null, L.gate.loraKey); + this.gemv4(enc, S.normed, this.q4[L.up.weight], S.tmp2, null, L.up.loraKey); + this._siluMul(enc, S.tmp, S.tmp2, c.intermediateSize); + } + } + if (this.features.actQuant) { + this.dynQuant(enc, S.tmp, S.x_q, S.scale_x, c.intermediateSize); + if (this.features.fuseResidual) { + this.gemv4AddW4A8(enc, S.tmp, S.x_q, S.scale_x, this.q4[L.down.weight], S.hidden, null, L.down.loraKey); + } else { + this.gemv4W4A8(enc, S.tmp, S.x_q, S.scale_x, this.q4[L.down.weight], S.normed, null, L.down.loraKey); + this._addInto(enc, S.hidden, S.normed, c.hiddenSize); + } + } else { + if (this.features.fuseResidual) + this.gemv4Add(enc, S.tmp, this.q4[L.down.weight], S.hidden, null, L.down.loraKey); + else { + this.gemv4(enc, S.tmp, this.q4[L.down.weight], S.normed, null, L.down.loraKey); + this._addInto(enc, S.hidden, S.normed, c.hiddenSize); + } + } + } + this.rms(enc, S.hidden, this.bufs[this.plan.finalNorm.name], S.normed, c.hiddenSize); + this.gemv(enc, S.normed, this.q[this.plan.embed.name], S.logits, null, null); + } + _addInto(enc, yBuf, aBuf, n) { + const imm = new Uint32Array([n]); + const useF16 = this.usingF16() && this.pipes.addF16; + const pipe = useF16 ? this.pipes.addF16 : this.pipes.add; + const bg = this._bgCached(pipe, [aBuf, yBuf], `add:${n}${useF16 ? ":f16" : ""}`); + const wg = pipe.__wg || 256; + this._dispatch(enc, pipe, bg, Math.min(Math.ceil(n / wg), 65535), 1, useF16 ? "addF16" : "add", imm); + } + _siluMul(enc, gateBuf, upBuf, n) { + const imm = new Uint32Array([n]); + const useF16 = this.usingF16() && this.pipes.siluF16; + const pipe = useF16 ? this.pipes.siluF16 : this.pipes.silu; + const bg = this._bgCached(pipe, [gateBuf, upBuf], `silu:${n}${useF16 ? ":f16" : ""}`); + const wg = pipe.__wg || 256; + this._dispatch(enc, pipe, bg, Math.min(Math.ceil(n / wg), 65535), 1, useF16 ? "siluF16" : "silu", imm); + } + embedRow(enc, id) { + const e = this.q[this.plan.embed.name]; + const imm = new Uint32Array([id, this.cfg.hiddenSize]); + this._dispatch( + enc, + this.pipes.embed, + this._bg(this.pipes.embed, [e.w, e.scale, this.s.hidden]), + Math.ceil(this.cfg.hiddenSize / 256), + 1, + "embed", + imm + ); + } + async argmaxLogits() { + if (this._argmaxReadBusy) + throw new Error("argmaxLogits() is already in flight; concurrent generation is not supported"); + this._argmaxReadBusy = true; + const enc = this.dev.createCommandEncoder(); + const n = this.cfg.vocabSize || 0; + this._dispatch( + enc, + this.pipes.argmax, + this._bgCached(this.pipes.argmax, [this.s.logits, this.s.amax], "argmax"), + 1, + 1, + "argmax", + new Uint32Array([n]) + ); + enc.copyBufferToBuffer(this.s.amax, 0, this.argmaxRead, 0, 4); + this.dev.queue.submit([enc.finish()]); + if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); + try { + await this.argmaxRead.mapAsync(GPUMapMode.READ); + const id = new Uint32Array(this.argmaxRead.getMappedRange())[0]; + this.argmaxRead.unmap(); + return id; + } finally { + this._argmaxReadBusy = false; + } + } + // Convenience for numeric comparison harnesses (Phase 3 f16 eval etc.). + // Returns a fresh Float32Array copy of the current final logits buffer. + async readLogits() { + const n = this.cfg.vocabSize; + if (!this._logitsRead) { + this._logitsRead = this._buf(n * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); + } + const enc = this.dev.createCommandEncoder(); + enc.copyBufferToBuffer(this.s.logits, 0, this._logitsRead, 0, n * 4); + this.dev.queue.submit([enc.finish()]); + if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); + await this._logitsRead.mapAsync(GPUMapMode.READ); + const out = new Float32Array(this._logitsRead.getMappedRange()).slice(); + this._logitsRead.unmap(); + return out; + } + async topKLogits(k = this.samplingTopK) { + if (this._topKReadBusy) throw new Error("topKLogits() is already in flight; concurrent sampling is not supported"); + this._topKReadBusy = true; + try { + k = Math.min(Math.max(1, Math.floor(k)), this.maxSamplingTopK, this.cfg.vocabSize); + const enc = this.dev.createCommandEncoder(); + for (let i = 0; i < k; i++) { + const imm = new Uint32Array([this.cfg.vocabSize, i]); + this._dispatch( + enc, + this.pipes.topkSelect, + this._bgCached(this.pipes.topkSelect, [this.s.logits, this.s.sampleIds, this.s.sampleVals], `topk:${i}`), + 1, + 1, + "topk", + imm + ); + } + enc.copyBufferToBuffer(this.s.sampleIds, 0, this.sampleIdsRead, 0, k * 4); + enc.copyBufferToBuffer(this.s.sampleVals, 0, this.sampleValsRead, 0, k * 4); + this.dev.queue.submit([enc.finish()]); + await Promise.all([this.sampleIdsRead.mapAsync(GPUMapMode.READ), this.sampleValsRead.mapAsync(GPUMapMode.READ)]); + const ids = Array.from(new Uint32Array(this.sampleIdsRead.getMappedRange(), 0, k)); + const vals = Array.from(new Float32Array(this.sampleValsRead.getMappedRange(), 0, k)); + return ids.map((id, i) => ({ id, logit: vals[i] })); + } finally { + if (this.sampleIdsRead.mapState !== "unmapped") this.sampleIdsRead.unmap(); + if (this.sampleValsRead.mapState !== "unmapped") this.sampleValsRead.unmap(); + this._topKReadBusy = false; + } + } + // Phase 5: GPU-resident sampling (pure-GPU top-k + sample chaining). + // Runs the iterative top-k selection dispatches directly into the GPU sampleIds/sampleVals + // buffers, then immediately chains the SAMPLE_TOPK kernel in the same submission. + // Only a single u32 (the chosen token) is ever read back from the GPU. + // This eliminates the previous k-value readbacks for the sampling path. + async sampleToken(temp = 1, r = typeof Math !== "undefined" ? Math.random() : 0.5) { + if (this._topKReadBusy) throw new Error("sampleToken: top-k selection already in flight"); + this._topKReadBusy = true; + const k = Math.min(this.samplingTopK, this.maxSamplingTopK, this.cfg.vocabSize); + try { + const enc = this.dev.createCommandEncoder(); + for (let i = 0; i < k; i++) { + const imm2 = new Uint32Array([this.cfg.vocabSize, i]); + this._dispatch( + enc, + this.pipes.topkSelect, + this._bgCached(this.pipes.topkSelect, [this.s.logits, this.s.sampleIds, this.s.sampleVals], `topk:${i}`), + 1, + 1, + "topk", + imm2 + ); + } + const bg = this._bg(this.pipes.sampleTopK, [ + this.s.sampleIds, + this.s.sampleVals, + this.s.sampled + ]); + const imm = new Uint32Array(4); + imm[0] = k; + const f322 = new Float32Array(imm.buffer); + f322[2] = temp > 0 ? temp : 1; + f322[3] = Math.max(0, Math.min(1, r)); + this._dispatch(enc, this.pipes.sampleTopK, bg, 1, 1, "sampleTopK", imm); + enc.copyBufferToBuffer(this.s.sampled, 0, this.sampledRead, 0, 4); + this.dev.queue.submit([enc.finish()]); + if (this.dev.queue.onSubmittedWorkDone) await this.dev.queue.onSubmittedWorkDone(); + await this.sampledRead.mapAsync(GPUMapMode.READ); + const id = new Uint32Array(this.sampledRead.getMappedRange())[0]; + this.sampledRead.unmap(); + return id; + } finally { + this._topKReadBusy = false; + } + } + // Run one token end-to-end (embed + step) and submit. + token(id, pos) { + this._resetUni(); + const enc = this.dev.createCommandEncoder(); + this.embedRow(enc, id); + this.step(enc, id, pos); + this.dev.queue.submit([enc.finish()]); + } + // embed the token id held in s.amax (GPU-resident, from a prior argmax) + embedFromBuf(enc) { + const e = this.q[this.plan.embed.name]; + const imm = new Uint32Array([this.cfg.hiddenSize]); + this._dispatch( + enc, + this.pipes.embedBuf, + this._bgCached(this.pipes.embedBuf, [e.w, e.scale, this.s.hidden, this.s.amax], "embedBuf"), + Math.ceil(this.cfg.hiddenSize / 256), + 1, + "embed", + imm + ); + } + // argmax(logits) -> s.amax, within the given encoder (no submit/readback) + argmaxInto(enc) { + const n = this.cfg.vocabSize || 0; + this._dispatch( + enc, + this.pipes.argmax, + this._bgCached(this.pipes.argmax, [this.s.logits, this.s.amax], "argmax"), + 1, + 1, + "argmax", + new Uint32Array([n]) + ); + } + // GPU-resident batched GREEDY decode only: chains embed->step->argmax for K + // tokens in ONE submit, reads back K ids once, and checks stop tokens only + // after readback. It assumes s.amax already holds the current token id to + // embed. Do not use for sampled decoding; sampled tokens must be written by + // the CPU/GPU sampler one step at a time. + async decodeBatch(startPos, K) { + K = Math.min(K, this.decodeBatchCapacity, this.maxCtx - startPos); + if (K <= 0) return []; + this._resetUni(); + const enc = this.dev.createCommandEncoder(); + for (let k = 0; k < K; k++) { + this.embedFromBuf(enc); + this.step(enc, 0, startPos + k); + this.argmaxInto(enc); + enc.copyBufferToBuffer(this.s.amax, 0, this.s.idsBuf, k * 4, 4); + } + enc.copyBufferToBuffer(this.s.idsBuf, 0, this.idsRead, 0, K * 4); + this.dev.queue.submit([enc.finish()]); + await this.idsRead.mapAsync(GPUMapMode.READ); + const ids = Array.from(new Uint32Array(this.idsRead.getMappedRange(), 0, K)); + this.idsRead.unmap(); + return ids; + } + async decodeGreedyBatch(startPos, K) { + return this.decodeBatch(startPos, K); + } + // ---- PREFILL (T>1): process the whole prompt at once via tiled GEMM. If a LoRA + // adapter has the projection module, add its batched delta immediately after base GEMM. + gemm4(enc, aBuf, q, yBuf, T, biasBuf, moduleKey) { + const imm = new Uint32Array([q.K, q.N, T, q.gpr, biasBuf ? 1 : 0, 0, 0, 0]); + const bg = this._bg(this.pipes.gemm4, [aBuf, q.w, q.scale, biasBuf || this.s.dummy, yBuf]); + this._dispatch(enc, this.pipes.gemm4, bg, Math.ceil(q.N / 64), Math.ceil(T / 16), "gemm4", imm); + const mod = this.lora?.modules?.[moduleKey]; + if (mod) this.loraBatchDelta(enc, aBuf, yBuf, q, T, mod, moduleKey); + } + gemm4AddT(enc, aBuf, q, yBuf, T, biasBuf, moduleKey) { + const imm = new Uint32Array([q.K, q.N, T, q.gpr, biasBuf ? 1 : 0, 0, 0, 0]); + const bg = this._bg(this.pipes.gemm4AddT, [aBuf, q.w, q.scale, biasBuf || this.s.dummy, yBuf]); + this._dispatch(enc, this.pipes.gemm4AddT, bg, Math.ceil(q.N / 64), Math.ceil(T / 16), "gemm4AddT", imm); + const mod = this.lora?.modules?.[moduleKey]; + if (mod) this.loraBatchDelta(enc, aBuf, yBuf, q, T, mod, moduleKey); + } + loraBatchDelta(enc, xBuf, yBuf, q, T, mod, moduleKey) { + if (this.debugCapture) console.log("VWG loraBatchDelta: " + moduleKey + " mod=" + !!mod); + const imm = new Uint32Array([q.K, mod.rank, T, 0]); + const bgA = this._bg(this.pipes.loraABatch, [xBuf, mod.A, this.sT.loraD]); + this._dispatch(enc, this.pipes.loraABatch, bgA, mod.rank, T, "loraA:T", imm); + if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj") { + enc.copyBufferToBuffer(xBuf, 0, this.debugBufs.xBat, 0, T * q.K * 4); + enc.copyBufferToBuffer(this.sT.loraD, 0, this.debugBufs.dBat, 0, T * mod.rank * 4); + } + const totalGroups = Math.ceil(T * q.N / 256); + let gx = totalGroups; + let gy = 1; + if (gx > 65535) { + gx = 256; + gy = Math.ceil(totalGroups / 256); + } + const meta = new ArrayBuffer(32); + const dv = new DataView(meta); + dv.setUint32(0, T, true); + dv.setUint32(4, q.N, true); + dv.setUint32(8, mod.rank, true); + dv.setUint32(12, gx, true); + dv.setFloat32(16, mod.scale, true); + const bgB = this._bg(this.pipes.loraBAddT, [this.sT.loraD, mod.B, yBuf]); + this._dispatch(enc, this.pipes.loraBAddT, bgB, gx, gy, "loraB:T", new Uint8Array(meta)); + if (this.debugCapture && moduleKey === "layers.0.self_attn.q_proj") { + enc.copyBufferToBuffer(yBuf, 0, this.debugBufs.yBat, 0, T * q.N * 4); + this.debugCaptured = true; + } + } + rmsT(enc, xBuf, gBuf, yBuf, T, K) { + const imm = new Float32Array([K, this.cfg.rmsNormEps]); + const useF16 = this.usingF16() && this.pipes.rmsTF16; + const pipe = useF16 ? this.pipes.rmsTF16 : this.pipes.rmsT; + this._dispatch(enc, pipe, this._bg(pipe, [xBuf, gBuf, yBuf]), T, 1, useF16 ? "rmsTF16" : "rmsT", imm); + } + ropeT(enc, xBuf, T, nHeads, pos0 = 0) { + const hd = this.cfg.headDim; + const imm = new Uint32Array([nHeads, hd, T, pos0]); + const useF16 = this.usingF16() && this.pipes.ropeTF16; + const pipe = useF16 ? this.pipes.ropeTF16 : this.pipes.ropeT; + this._dispatch( + enc, + pipe, + this._bg(pipe, [xBuf, this.ropeCos, this.ropeSin]), + Math.ceil(T * nHeads * (hd / 2) / 256), + 1, + useF16 ? "ropeTF16" : "ropeT", + imm + ); + } + attnPrefill(enc, qBuf, kc, vc, oBuf, T, qStart = 0, ctx = T) { + const c = this.cfg; + if (this.features.prefillAttention === "block" || qStart !== 0 || ctx !== T) { + const imm = new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, T, qStart, ctx, 0, 0]); + this._dispatch( + enc, + this.pipes.attnPrefillBlock, + this._bg(this.pipes.attnPrefillBlock, [qBuf, kc, vc, oBuf]), + c.numHeads, + Math.ceil(T / 4), + "attnPrefillBlock", + imm + ); + } else { + const imm = new Uint32Array([c.numHeads, c.numKVHeads, c.headDim, T]); + this._dispatch( + enc, + this.pipes.attnPrefill, + this._bg(this.pipes.attnPrefill, [qBuf, kc, vc, oBuf]), + c.numHeads, + Math.ceil(T / 4), + "attnPrefill", + imm + ); + } + } + // (re)allocate prefill scratch sized to T (grows as needed; only paid when prefilling). + _ensurePrefillScratch(T, loraRank = 0, idsCap = T) { + if (this.sTcap >= T && (this.sTLoraRank || 0) >= loraRank && (this.sTidsCap || 0) >= idsCap) return; + const need = this.estimatePrefillScratchBytes(T, loraRank); + if (this.opts.maxPrefillScratchBytes && need > this.opts.maxPrefillScratchBytes) { + throw new Error( + `prefill scratch ${Math.ceil(need / 1048576)}MiB exceeds maxPrefillScratchBytes; lower maxPrefillT or use shorter prompt chunks` + ); + } + if (this.sT) for (const k in this.sT) this.sT[k].destroy(); + const c = this.cfg, H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize; + this.sT = { + hidden: this._buf(T * H * 4), + normed: this._buf(T * H * 4), + q: this._buf(T * qd * 4), + k: this._buf(T * kvd * 4), + v: this._buf(T * kvd * 4), + attn: this._buf(T * qd * 4), + tmp: this._buf(T * I * 4), + tmp2: this._buf(T * I * 4), + ids: this._buf(idsCap * 4), + loraD: this._buf(Math.max(1, T * Math.max(1, loraRank)) * 4), + x_q: this._buf(T * Math.max(H, I) * 4), + scale_x: this._buf(T * Math.max(H, I) / 128 * 4) + }; + this.sTcap = T; + this.sTLoraRank = loraRank; + this.sTidsCap = idsCap; + } + _activeMaxLoraRank() { + let rank = 0; + const mods = this.lora?.modules; + if (!mods) return 0; + for (const key of Object.keys(mods)) rank = Math.max(rank, mods[key].rank || 0); + return rank; + } + // Prefill the prompt (positions 0..T-1). Leaves last-row logits in s.logits and the + // KV cache populated, so decode continues from pos=T. T must be <= maxPrefillT. + prefillBatch(ids) { + const T = ids.length; + if (T > this.maxPrefillT) throw new Error(`prompt ${T} > maxPrefillT ${this.maxPrefillT}`); + if (T > this.maxCtx) throw new Error(`prompt ${T} > maxCtx ${this.maxCtx}`); + const chunk = this.features.prefillChunkSize; + if (chunk > 0 && T > chunk) return this._prefillChunked(ids, chunk); + return this._prefillFull(ids); + } + _prefillFull(ids) { + const c = this.cfg, S = this.s, T = ids.length, hd = c.headDim, kvd = c.numKVHeads * hd, H = c.hiddenSize; + this._ensurePrefillScratch(T, this._activeMaxLoraRank()); + const ST = this.sT; + this._resetUni(); + this.dev.queue.writeBuffer(ST.ids, 0, new Uint32Array(ids)); + const enc = this.dev.createCommandEncoder(); + const e = this.q[this.plan.embed.name]; + const imm = new Uint32Array([T, H, 0, 0]); + this._dispatch( + enc, + this.pipes.embedT, + this._bg(this.pipes.embedT, [e.w, e.scale, ST.hidden, ST.ids]), + Math.min(Math.ceil(T * H / 256), 65535), + 1, + "embedT", + imm + ); + for (let i = 0; i < c.numLayers; i++) { + const L = this.plan.layers[i]; + this.rmsT(enc, ST.hidden, this.bufs[L.inputNorm], ST.normed, T, H); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, T); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.q.weight], + ST.q, + T, + this.bufs[L.q.bias], + L.q.loraKey + ); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.k.weight], + ST.k, + T, + this.bufs[L.k.bias], + L.k.loraKey + ); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.v.weight], + ST.v, + T, + this.bufs[L.v.bias], + L.v.loraKey + ); + } else { + this.gemm4(enc, ST.normed, this.q4[L.q.weight], ST.q, T, this.bufs[L.q.bias], L.q.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.k.weight], ST.k, T, this.bufs[L.k.bias], L.k.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.v.weight], ST.v, T, this.bufs[L.v.bias], L.v.loraKey); + } + this.ropeT(enc, ST.q, T, c.numHeads); + this.ropeT(enc, ST.k, T, c.numKVHeads); + if (this.features.pagedAttention) { + this.writeKvPageBatch(enc, ST.k, ST.v, this.kc[i], this.vc[i], T, 0, i); + } else { + enc.copyBufferToBuffer(ST.k, 0, this.kc[i], 0, T * kvd * 4); + enc.copyBufferToBuffer(ST.v, 0, this.vc[i], 0, T * kvd * 4); + } + if (this.features.pagedAttention) { + this.attnPrefillPaged(enc, ST.q, this.kc[i], this.vc[i], ST.attn, T, 0, T); + } else { + this.attnPrefill(enc, ST.q, this.kc[i], this.vc[i], ST.attn, T, 0, T); + } + if (this.features.actQuant) { + this.dynQuantT(enc, ST.attn, ST.x_q, ST.scale_x, H, T); + if (this.features.fuseResidual) { + this.gemm4AddTW4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.hidden, T, null, L.o.loraKey); + } else { + this.gemm4W4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.tmp, T, null, L.o.loraKey); + this._addInto(enc, ST.hidden, ST.tmp, T * H); + } + } else { + if (this.features.fuseResidual) + this.gemm4AddT(enc, ST.attn, this.q4[L.o.weight], ST.hidden, T, null, L.o.loraKey); + else { + this.gemm4(enc, ST.attn, this.q4[L.o.weight], ST.tmp, T, null, L.o.loraKey); + this._addInto(enc, ST.hidden, ST.tmp, T * H); + } + } + this.rmsT(enc, ST.hidden, this.bufs[L.postAttentionNorm], ST.normed, T, H); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, T); + this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.gate.weight], ST.tmp, T, null, L.gate.loraKey); + this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.up.weight], ST.tmp2, T, null, L.up.loraKey); + } else { + this.gemm4(enc, ST.normed, this.q4[L.gate.weight], ST.tmp, T, null, L.gate.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.up.weight], ST.tmp2, T, null, L.up.loraKey); + } + this._siluMul(enc, ST.tmp, ST.tmp2, T * c.intermediateSize); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.tmp, ST.x_q, ST.scale_x, c.intermediateSize, T); + if (this.features.fuseResidual) { + this.gemm4AddTW4A8( + enc, + ST.tmp, + ST.x_q, + ST.scale_x, + this.q4[L.down.weight], + ST.hidden, + T, + null, + L.down.loraKey + ); + } else { + this.gemm4W4A8(enc, ST.tmp, ST.x_q, ST.scale_x, this.q4[L.down.weight], ST.normed, T, null, L.down.loraKey); + this._addInto(enc, ST.hidden, ST.normed, T * H); + } + } else { + if (this.features.fuseResidual) + this.gemm4AddT(enc, ST.tmp, this.q4[L.down.weight], ST.hidden, T, null, L.down.loraKey); + else { + this.gemm4(enc, ST.tmp, this.q4[L.down.weight], ST.normed, T, null, L.down.loraKey); + this._addInto(enc, ST.hidden, ST.normed, T * H); + } + } + } + enc.copyBufferToBuffer(ST.hidden, (T - 1) * H * 4, S.hidden, 0, H * 4); + this.rms(enc, S.hidden, this.bufs[this.plan.finalNorm.name], S.normed, H); + this.gemv(enc, S.normed, this.q[this.plan.embed.name], S.logits, null, null); + this.dev.queue.submit([enc.finish()]); + } + _prefillChunked(ids, chunkSize) { + const c = this.cfg, S = this.s, H = c.hiddenSize, hd = c.headDim, kvd = c.numKVHeads * hd; + const T = ids.length; + this._ensurePrefillScratch(Math.min(chunkSize, T), this._activeMaxLoraRank(), T); + const ST = this.sT; + this._resetUni(); + this.dev.queue.writeBuffer(ST.ids, 0, new Uint32Array(ids)); + const enc = this.dev.createCommandEncoder(); + const e = this.q[this.plan.embed.name]; + for (let off = 0; off < T; off += chunkSize) { + const end = Math.min(T, off + chunkSize); + const CT = end - off; + this._dispatch( + enc, + this.pipes.embedT, + this._bg(this.pipes.embedT, [e.w, e.scale, ST.hidden, ST.ids]), + Math.min(Math.ceil(CT * H / 256), 65535), + 1, + "embedT", + new Uint32Array([CT, H, off, 0]) + ); + for (let i = 0; i < c.numLayers; i++) { + const L = this.plan.layers[i]; + this.rmsT(enc, ST.hidden, this.bufs[L.inputNorm], ST.normed, CT, H); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, CT); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.q.weight], + ST.q, + CT, + this.bufs[L.q.bias], + L.q.loraKey + ); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.k.weight], + ST.k, + CT, + this.bufs[L.k.bias], + L.k.loraKey + ); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.v.weight], + ST.v, + CT, + this.bufs[L.v.bias], + L.v.loraKey + ); + } else { + this.gemm4(enc, ST.normed, this.q4[L.q.weight], ST.q, CT, this.bufs[L.q.bias], L.q.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.k.weight], ST.k, CT, this.bufs[L.k.bias], L.k.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.v.weight], ST.v, CT, this.bufs[L.v.bias], L.v.loraKey); + } + this.ropeT(enc, ST.q, CT, c.numHeads, off); + this.ropeT(enc, ST.k, CT, c.numKVHeads, off); + if (this.features.pagedAttention) { + this.writeKvPageBatch(enc, ST.k, ST.v, this.kc[i], this.vc[i], CT, off, i); + } else { + enc.copyBufferToBuffer(ST.k, 0, this.kc[i], off * kvd * 4, CT * kvd * 4); + enc.copyBufferToBuffer(ST.v, 0, this.vc[i], off * kvd * 4, CT * kvd * 4); + } + if (this.features.pagedAttention) { + this.attnPrefillPaged(enc, ST.q, this.kc[i], this.vc[i], ST.attn, CT, off, end); + } else { + this.attnPrefill(enc, ST.q, this.kc[i], this.vc[i], ST.attn, CT, off, end); + } + if (this.features.actQuant) { + this.dynQuantT(enc, ST.attn, ST.x_q, ST.scale_x, H, CT); + if (this.features.fuseResidual) { + this.gemm4AddTW4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.hidden, CT, null, L.o.loraKey); + } else { + this.gemm4W4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.tmp, CT, null, L.o.loraKey); + this._addInto(enc, ST.hidden, ST.tmp, CT * H); + } + } else { + if (this.features.fuseResidual) + this.gemm4AddT(enc, ST.attn, this.q4[L.o.weight], ST.hidden, CT, null, L.o.loraKey); + else { + this.gemm4(enc, ST.attn, this.q4[L.o.weight], ST.tmp, CT, null, L.o.loraKey); + this._addInto(enc, ST.hidden, ST.tmp, CT * H); + } + } + this.rmsT(enc, ST.hidden, this.bufs[L.postAttentionNorm], ST.normed, CT, H); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, CT); + this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.gate.weight], ST.tmp, CT, null, L.gate.loraKey); + this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.up.weight], ST.tmp2, CT, null, L.up.loraKey); + } else { + this.gemm4(enc, ST.normed, this.q4[L.gate.weight], ST.tmp, CT, null, L.gate.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.up.weight], ST.tmp2, CT, null, L.up.loraKey); + } + this._siluMul(enc, ST.tmp, ST.tmp2, CT * c.intermediateSize); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.tmp, ST.x_q, ST.scale_x, c.intermediateSize, CT); + if (this.features.fuseResidual) { + this.gemm4AddTW4A8( + enc, + ST.tmp, + ST.x_q, + ST.scale_x, + this.q4[L.down.weight], + ST.hidden, + CT, + null, + L.down.loraKey + ); + } else { + this.gemm4W4A8( + enc, + ST.tmp, + ST.x_q, + ST.scale_x, + this.q4[L.down.weight], + ST.normed, + CT, + null, + L.down.loraKey + ); + this._addInto(enc, ST.hidden, ST.normed, CT * H); + } + } else { + if (this.features.fuseResidual) + this.gemm4AddT(enc, ST.tmp, this.q4[L.down.weight], ST.hidden, CT, null, L.down.loraKey); + else { + this.gemm4(enc, ST.tmp, this.q4[L.down.weight], ST.normed, CT, null, L.down.loraKey); + this._addInto(enc, ST.hidden, ST.normed, CT * H); + } + } + } + if (end === T) { + enc.copyBufferToBuffer(ST.hidden, (CT - 1) * H * 4, S.hidden, 0, H * 4); + } + } + this.rms(enc, S.hidden, this.bufs[this.plan.finalNorm.name], S.normed, H); + this.gemv(enc, S.normed, this.q[this.plan.embed.name], S.logits, null, null); + this.dev.queue.submit([enc.finish()]); + } + async speculativeDecode(draftModel, promptIds, maxNewTokens, onToken) { + await this.prefillBatch(promptIds); + await draftModel.prefillBatch(promptIds); + let currentPos = promptIds.length; + const generatedIds = []; + let nextToken = await this.argmaxLogits(); + generatedIds.push(nextToken); + if (onToken) onToken(nextToken); + draftModel.dev.queue.writeBuffer(draftModel.s.amax, 0, new Uint32Array([nextToken])); + this.dev.queue.writeBuffer(this.s.amax, 0, new Uint32Array([nextToken])); + const gamma = 4; + while (generatedIds.length < maxNewTokens) { + const draftCandidates = await draftModel.decodeBatch(currentPos, gamma); + if (draftCandidates.length === 0) break; + const T = draftCandidates.length; + this._resetUni(); + this._ensurePrefillScratch(T, this._activeMaxLoraRank()); + const ST = this.sT; + const c = this.cfg, H = c.hiddenSize, kvd = c.numKVHeads * c.headDim; + this.dev.queue.writeBuffer(ST.ids, 0, new Uint32Array(draftCandidates)); + const enc = this.dev.createCommandEncoder(); + const e = this.q[this.plan.embed.name]; + const embedUni = new Uint32Array([T, H, 0, 0]); + this._dispatch( + enc, + this.pipes.embedT, + this._bg(this.pipes.embedT, [e.w, e.scale, ST.hidden, ST.ids]), + Math.min(Math.ceil(T * H / 256), 65535), + 1, + "embedT", + embedUni + ); + for (let i = 0; i < c.numLayers; i++) { + const L = this.plan.layers[i]; + this.rmsT(enc, ST.hidden, this.bufs[L.inputNorm], ST.normed, T, H); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, T); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.q.weight], + ST.q, + T, + this.bufs[L.q.bias], + L.q.loraKey + ); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.k.weight], + ST.k, + T, + this.bufs[L.k.bias], + L.k.loraKey + ); + this.gemm4W4A8( + enc, + ST.normed, + ST.x_q, + ST.scale_x, + this.q4[L.v.weight], + ST.v, + T, + this.bufs[L.v.bias], + L.v.loraKey + ); + } else { + this.gemm4(enc, ST.normed, this.q4[L.q.weight], ST.q, T, this.bufs[L.q.bias], L.q.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.k.weight], ST.k, T, this.bufs[L.k.bias], L.k.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.v.weight], ST.v, T, this.bufs[L.v.bias], L.v.loraKey); + } + this.ropeT(enc, ST.q, T, c.numHeads, currentPos); + this.ropeT(enc, ST.k, T, c.numKVHeads, currentPos); + if (this.features.pagedAttention) { + this.writeKvPageBatch(enc, ST.k, ST.v, this.kc[i], this.vc[i], T, currentPos, i); + } else { + enc.copyBufferToBuffer(ST.k, 0, this.kc[i], currentPos * kvd * 4, T * kvd * 4); + enc.copyBufferToBuffer(ST.v, 0, this.vc[i], currentPos * kvd * 4, T * kvd * 4); + } + if (this.features.pagedAttention) { + this.attnPrefillPaged(enc, ST.q, this.kc[i], this.vc[i], ST.attn, T, currentPos, currentPos + T); + } else { + this.attnPrefill(enc, ST.q, this.kc[i], this.vc[i], ST.attn, T, currentPos, currentPos + T); + } + if (this.features.actQuant) { + this.dynQuantT(enc, ST.attn, ST.x_q, ST.scale_x, H, T); + if (this.features.fuseResidual) { + this.gemm4AddTW4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.hidden, T, null, L.o.loraKey); + } else { + this.gemm4W4A8(enc, ST.attn, ST.x_q, ST.scale_x, this.q4[L.o.weight], ST.tmp, T, null, L.o.loraKey); + this._addInto(enc, ST.hidden, ST.tmp, T * H); + } + } else { + if (this.features.fuseResidual) + this.gemm4AddT(enc, ST.attn, this.q4[L.o.weight], ST.hidden, T, null, L.o.loraKey); + else { + this.gemm4(enc, ST.attn, this.q4[L.o.weight], ST.tmp, T, null, L.o.loraKey); + this._addInto(enc, ST.hidden, ST.tmp, T * H); + } + } + this.rmsT(enc, ST.hidden, this.bufs[L.postAttentionNorm], ST.normed, T, H); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.normed, ST.x_q, ST.scale_x, H, T); + this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.gate.weight], ST.tmp, T, null, L.gate.loraKey); + this.gemm4W4A8(enc, ST.normed, ST.x_q, ST.scale_x, this.q4[L.up.weight], ST.tmp2, T, null, L.up.loraKey); + } else { + this.gemm4(enc, ST.normed, this.q4[L.gate.weight], ST.tmp, T, null, L.gate.loraKey); + this.gemm4(enc, ST.normed, this.q4[L.up.weight], ST.tmp2, T, null, L.up.loraKey); + } + this._siluMul(enc, ST.tmp, ST.tmp2, T * c.intermediateSize); + if (this.features.actQuant) { + this.dynQuantT(enc, ST.tmp, ST.x_q, ST.scale_x, c.intermediateSize, T); + if (this.features.fuseResidual) { + this.gemm4AddTW4A8( + enc, + ST.tmp, + ST.x_q, + ST.scale_x, + this.q4[L.down.weight], + ST.hidden, + T, + null, + L.down.loraKey + ); + } else { + this.gemm4W4A8(enc, ST.tmp, ST.x_q, ST.scale_x, this.q4[L.down.weight], ST.normed, T, null, L.down.loraKey); + this._addInto(enc, ST.hidden, ST.normed, T * H); + } + } else { + if (this.features.fuseResidual) + this.gemm4AddT(enc, ST.tmp, this.q4[L.down.weight], ST.hidden, T, null, L.down.loraKey); + else { + this.gemm4(enc, ST.tmp, this.q4[L.down.weight], ST.normed, T, null, L.down.loraKey); + this._addInto(enc, ST.hidden, ST.normed, T * H); + } + } + } + if (!this.s.logitsT || this.sTcap < T) { + if (this.s.logitsT) this.s.logitsT.destroy(); + this.s.logitsT = this._buf(T * c.vocabSize * 4); + if (this.logitsTRead) this.logitsTRead.destroy(); + this.logitsTRead = this._buf(T * c.vocabSize * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ); + } + for (let t = 0; t < T; t++) { + enc.copyBufferToBuffer(ST.hidden, t * H * 4, this.s.hidden, 0, H * 4); + this.rms(enc, this.s.hidden, this.bufs[this.plan.finalNorm.name], this.s.normed, H); + this.gemv(enc, this.s.normed, this.q[this.plan.embed.name], this.s.logits, null, null); + enc.copyBufferToBuffer(this.s.logits, 0, this.s.logitsT, t * c.vocabSize * 4, c.vocabSize * 4); + } + enc.copyBufferToBuffer(this.s.logitsT, 0, this.logitsTRead, 0, T * c.vocabSize * 4); + this.dev.queue.submit([enc.finish()]); + await this.logitsTRead.mapAsync(GPUMapMode.READ); + const logitsArray = new Float32Array(this.logitsTRead.getMappedRange()); + let acceptedCount = 0; + let targetToken = 0; + for (let t = 0; t < T; t++) { + let maxVal = -1e30; + let argmaxId = 0; + const offset = t * c.vocabSize; + for (let v = 0; v < c.vocabSize; v++) { + const l = logitsArray[offset + v]; + if (l > maxVal) { + maxVal = l; + argmaxId = v; + } + } + targetToken = argmaxId; + if (t < T) { + if (draftCandidates[t] === targetToken) { + acceptedCount++; + } else { + break; + } + } + } + this.logitsTRead.unmap(); + for (let a = 0; a < acceptedCount; a++) { + generatedIds.push(draftCandidates[a]); + if (onToken) onToken(draftCandidates[a]); + } + generatedIds.push(targetToken); + if (onToken) onToken(targetToken); + const nextPos = currentPos + acceptedCount + 1; + this.dev.queue.writeBuffer(this.s.amax, 0, new Uint32Array([targetToken])); + draftModel.dev.queue.writeBuffer(draftModel.s.amax, 0, new Uint32Array([targetToken])); + if (this.features.pagedAttention) { + this.pam.ensureBlocks(0, nextPos); + } + currentPos = nextPos; + } + return generatedIds; + } + // Simple high-level generation helper (Phase 5 wiring). + // If opts.sample === true, uses the GPU sampler (sampleToken) with given temp; + // otherwise falls back to argmax (greedy). + // This makes sampleToken part of the real generation path. + async generate(promptIds, maxNewTokens = 32, opts = {}) { + const doSample = !!opts.sample; + const temp = opts.temp != null && opts.temp > 0 ? opts.temp : 1; + await this.prefillBatch(promptIds); + const generatedIds = []; + let pos = promptIds.length; + let next = doSample ? await this.sampleToken(temp) : await this.argmaxLogits(); + generatedIds.push(next); + if (opts.onToken) opts.onToken(next); + this.dev.queue.writeBuffer(this.s.amax, 0, new Uint32Array([next])); + while (generatedIds.length < maxNewTokens) { + this._resetUni(); + const enc = this.dev.createCommandEncoder(); + this.embedFromBuf(enc); + this.step(enc, 0, pos); + this.dev.queue.submit([enc.finish()]); + next = doSample ? await this.sampleToken(temp) : await this.argmaxLogits(); + generatedIds.push(next); + if (opts.onToken) opts.onToken(next); + this.dev.queue.writeBuffer(this.s.amax, 0, new Uint32Array([next])); + pos += 1; + } + return generatedIds; + } + setupDebugCapture(T, K, rank, N) { + this.debugCapture = true; + this.debugT = T; + this.debugK = K; + this.debugRank = rank; + this.debugN = N; + this.debugStep = 0; + this.debugCaptured = false; + this.debugBufs = { + xSeq: this._buf(T * K * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), + dSeq: this._buf(T * rank * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), + ySeq: this._buf(T * N * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), + xBat: this._buf(T * K * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), + dBat: this._buf(T * rank * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ), + yBat: this._buf(T * N * 4, GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ) + }; + } + async readDebugCapture() { + this.debugCapture = false; + const bufs = this.debugBufs; + if (!bufs) return null; + await Promise.all([ + bufs.xSeq.mapAsync(GPUMapMode.READ), + bufs.dSeq.mapAsync(GPUMapMode.READ), + bufs.ySeq.mapAsync(GPUMapMode.READ), + bufs.xBat.mapAsync(GPUMapMode.READ), + bufs.dBat.mapAsync(GPUMapMode.READ), + bufs.yBat.mapAsync(GPUMapMode.READ) + ]); + const res = { + xSeq: new Float32Array(bufs.xSeq.getMappedRange()).slice(), + dSeq: new Float32Array(bufs.dSeq.getMappedRange()).slice(), + ySeq: new Float32Array(bufs.ySeq.getMappedRange()).slice(), + xBat: new Float32Array(bufs.xBat.getMappedRange()).slice(), + dBat: new Float32Array(bufs.dBat.getMappedRange()).slice(), + yBat: new Float32Array(bufs.yBat.getMappedRange()).slice() + }; + bufs.xSeq.unmap(); + bufs.xSeq.destroy(); + bufs.dSeq.unmap(); + bufs.dSeq.destroy(); + bufs.ySeq.unmap(); + bufs.ySeq.destroy(); + bufs.xBat.unmap(); + bufs.xBat.destroy(); + bufs.dBat.unmap(); + bufs.dBat.destroy(); + bufs.yBat.unmap(); + bufs.yBat.destroy(); + this.debugBufs = null; + return res; + } +}; +var PagedAttentionManager = class { + static { + __name(this, "PagedAttentionManager"); + } + constructor(maxCtx, pageSize = 16) { + this.pageSize = pageSize; + this.maxCtx = maxCtx; + this.maxBlocksPerSeq = Math.ceil(maxCtx / pageSize); + this.freeBlocks = []; + this.seqBlocks = /* @__PURE__ */ new Map(); + const totalBlocks = this.maxBlocksPerSeq * 4; + for (let i = 0; i < totalBlocks; i++) { + this.freeBlocks.push(i); + } + } + allocateSeq(seqId) { + this.seqBlocks.set(seqId, []); + } + freeSeq(seqId) { + const blocks = this.seqBlocks.get(seqId) || []; + this.freeBlocks.push(...blocks); + this.seqBlocks.delete(seqId); + } + ensureBlocks(seqId, numTokens) { + const neededBlocks = Math.ceil(numTokens / this.pageSize); + const blocks = this.seqBlocks.get(seqId); + if (!blocks) throw new Error(`Sequence ${seqId} not allocated`); + while (blocks.length < neededBlocks) { + if (this.freeBlocks.length === 0) { + const newBlock = blocks.length + 1e3; + this.freeBlocks.push(newBlock); + } + blocks.push(this.freeBlocks.pop()); + } + return blocks; + } + getBlockTableArray(seqId) { + const blocks = this.seqBlocks.get(seqId) || []; + const arr = new Uint32Array(this.maxBlocksPerSeq); + arr.set(blocks); + return arr; + } +}; + +// src/services/device_service.js +async function initWebGPUDevice({ log: log2 = /* @__PURE__ */ __name(() => { +}, "log") } = {}) { + log2("requesting WebGPU device\u2026"); + const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" }); + if (!adapter) throw new Error("no WebGPU adapter (use a WebGPU-capable browser)"); + if (!navigator.gpu.wgslLanguageFeatures?.has("immediate_address_space")) + throw new Error("WGSL immediate_address_space is not available (upgrade to Chrome 149+)"); + if (!adapter.features.has("subgroups")) + throw new Error( + 'GPU lacks the required "subgroups" feature. The current fast WGSL kernels require subgroups and no fallback kernel set is bundled.' + ); + const hasSubgroupId = !!navigator.gpu.wgslLanguageFeatures?.has("subgroup_id"); + const hasLinearIndexing = !!navigator.gpu.wgslLanguageFeatures?.has("linear_indexing"); + const hasF16 = adapter.features.has("shader-f16"); + const hasTimestamp = adapter.features.has("timestamp-query"); + const reqFeatures = ["subgroups"]; + if (adapter.features.has("shader-f16")) reqFeatures.push("shader-f16"); + if (hasTimestamp) reqFeatures.push("timestamp-query"); + const dev = await adapter.requestDevice({ + requiredFeatures: reqFeatures, + requiredLimits: { + maxBufferSize: adapter.limits.maxBufferSize, + maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize, + maxStorageBuffersPerShaderStage: adapter.limits.maxStorageBuffersPerShaderStage + } + }); + dev.addEventListener?.("uncapturederror", (e) => console.error("GPUERR", e.error.message)); + log2(`WebGPU ready. maxBuffer=${(Number(adapter.limits.maxBufferSize) / 1e9).toFixed(2)}GB subgroupId=${hasSubgroupId} linearIdx=${hasLinearIndexing} f16=${hasF16} tsQuery=${hasTimestamp}`); + return dev; +} +__name(initWebGPUDevice, "initWebGPUDevice"); + +// src/services/prompt_formatter.js +function chatML(messages) { + let s = messages[0]?.role === "system" ? "" : "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"; + for (const m of messages) s += `<|im_start|>${m.role} +${m.content}<|im_end|> +`; + return s + "<|im_start|>assistant\n"; +} +__name(chatML, "chatML"); +function formatMessages(tokenizer, messages) { + try { + return tokenizer.apply_chat_template(messages, { tokenize: false, add_generation_prompt: true }); + } catch { + return chatML(messages); + } +} +__name(formatMessages, "formatMessages"); + +// src/services/model_session.js +async function buildTokenizer(reader) { + const tj = JSON.parse(await reader.text("tokenizer.json")); + const tc = JSON.parse(await reader.text("tokenizer_config.json")); + const { PreTrainedTokenizer } = await import("@huggingface/transformers"); + return new PreTrainedTokenizer(tj, tc); +} +__name(buildTokenizer, "buildTokenizer"); +function randomUnit() { + if (globalThis.crypto?.getRandomValues) { + const u = new Uint32Array(1); + globalThis.crypto.getRandomValues(u); + return u[0] / 4294967296; + } + return Math.random(); +} +__name(randomUnit, "randomUnit"); +function sampleTopK(candidates, { temperature, topP = 1 }) { + if (!temperature || temperature <= 0) return candidates[0]?.id ?? 0; + const best = candidates[0]?.logit ?? 0; + const weighted = candidates.map((c2) => ({ id: c2.id, w: Math.exp((c2.logit - best) / temperature) })); + let sum = weighted.reduce((a, c2) => a + c2.w, 0); + if (topP > 0 && topP < 1 && weighted.length > 1 && sum > 0) { + let csum = 0, keep = 0; + for (; keep < weighted.length; keep++) { + csum += weighted[keep].w / sum; + if (csum >= topP) { + keep++; + break; + } + } + weighted.length = Math.max(1, keep); + sum = weighted.reduce((a, c2) => a + c2.w, 0); + } + let r = randomUnit() * sum, c = 0; + for (const item of weighted) { + c += item.w; + if (r <= c) return item.id; + } + return weighted[weighted.length - 1]?.id ?? candidates[0]?.id ?? 0; +} +__name(sampleTopK, "sampleTopK"); +var ModelSession = class { + static { + __name(this, "ModelSession"); + } + constructor({ cfg = QWEN25_3B, log: log2 = /* @__PURE__ */ __name(() => { + }, "log"), runtimeOptions = {} } = {}) { + this.cfg = cfg; + this.log = log2; + this.runtimeOptions = { decodeBatchSize: "auto", samplingTopK: 40, ...runtimeOptions }; + this.dev = null; + this.rt = null; + this.tokenizer = null; + } + async loadWith(reader, label) { + this.dev = await initWebGPUDevice({ log: this.log }); + this.log(`loading tokenizer from ${label}\u2026`); + this.tokenizer = await buildTokenizer(reader); + this.log(`tokenizer loaded. streaming + quantizing weights (int4) from ${label}\u2026`); + const t0 = performance.now(); + this.rt = new QwenWGPU(this.dev, this.cfg, this.runtimeOptions); + await this.rt.build(reader, (msg, frac) => this.log(`weights: ${msg} ${(frac * 100).toFixed(0)}%`)); + window.__rt = this.rt; + window.__tokenizer = this.tokenizer; + const tuning = this.rt.decodeBatchTuning; + const tuned = tuning ? ` decodeBatch=${tuning.selected} (${tuning.reason})` : ""; + this.log( + `READY in ${((performance.now() - t0) / 1e3).toFixed(1)}s \u2014 base loaded once; adapters hot-swap live.${tuned}` + ); + return this; + } + async readLogits() { + const n = this.cfg.vocabSize; + const rb = this.dev.createBuffer({ size: n * 4, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }); + const enc = this.dev.createCommandEncoder(); + enc.copyBufferToBuffer(this.rt.s.logits, 0, rb, 0, n * 4); + this.dev.queue.submit([enc.finish()]); + await rb.mapAsync(GPUMapMode.READ); + const a = new Float32Array(rb.getMappedRange()).slice(); + rb.unmap(); + rb.destroy(); + return a; + } + async sampleNextToken({ temperature, topK = this.rt.samplingTopK, topP = 1 } = {}) { + return sampleTopK(await this.rt.topKLogits(topK), { temperature, topP }); + } + async *generate(messages, { maxTokens = 1024, temperature = 0, topK, topP = 1, stopIds = [151645, 151643] } = {}) { + const rt = this.rt, tokenizer = this.tokenizer; + const ids = tokenizer.encode(formatMessages(tokenizer, messages)); + if (ids.length <= rt.maxPrefillT) rt.prefillBatch(ids); + else for (let p = 0; p < ids.length; p++) rt.token(ids[p], p); + let pos = ids.length; + const emit = /* @__PURE__ */ __name((id) => tokenizer.decode([id], { skip_special_tokens: true }), "emit"); + if (temperature > 0) { + let next = await this.sampleNextToken({ temperature, topK, topP }); + for (let step = 0; step < maxTokens; step++) { + if (stopIds.includes(next)) break; + const d = emit(next); + if (d) yield d; + rt.token(next, pos); + pos++; + next = await this.sampleNextToken({ temperature, topK, topP }); + } + return; + } + const first = await rt.argmaxLogits(); + if (stopIds.includes(first)) return; + { + const d = emit(first); + if (d) yield d; + } + let emitted = 1; + while (emitted < maxTokens && pos < rt.maxCtx) { + const K = rt.greedyBatchSizeFor({ emitted, remaining: maxTokens - emitted, pos }); + const batch = await rt.decodeGreedyBatch(pos, K); + pos += batch.length; + let stop = false; + for (const id of batch) { + if (stopIds.includes(id)) { + stop = true; + break; + } + const d = emit(id); + if (d) yield d; + emitted++; + if (emitted >= maxTokens) { + stop = true; + break; + } + } + if (stop) break; + } + } +}; + +// src/qwgpu/backward_kernels.js +var GEMM_DX_INT4 = ` +requires immediate_address_space; +struct Meta { T:u32, N:u32, K:u32, gpr:u32 }; +@group(0) @binding(0) var dY: array; // [T][N] +@group(0) @binding(1) var W: array; // [N][K/8] int4 +@group(0) @binding(2) var scaleW: array; // [N][gpr] +@group(0) @binding(3) var dX: array; // [T][K] +var m: Meta; +fn deq4(n: u32, k: u32, K8: u32) -> f32 { + let word = W[n*K8 + (k >> 3u)]; + let shift = (k & 7u) * 4u; + let nib = i32(word << (28u - shift)) >> 28u; + return f32(nib) * scaleW[n*m.gpr + (k >> 7u)]; +} +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let total = m.T * m.K; let stride = nwg.x * 256u; let K8 = m.K / 8u; + for (var i = gid.x; i < total; i = i + stride) { + let t = i / m.K; let k = i % m.K; + var acc = 0.0; + let yb = t * m.N; + for (var n = 0u; n < m.N; n = n + 1u) { acc = acc + dY[yb + n] * deq4(n, k, K8); } + dX[i] = dX[i] + acc; + } +}`; +var LORA_DD = ` +requires immediate_address_space; +struct Meta { T:u32, N:u32, rank:u32, p:u32, scale:f32, f0:f32, f1:f32, f2:f32 }; +@group(0) @binding(0) var dY: array; // [T][N] +@group(0) @binding(1) var B: array; // [rank][N] +@group(0) @binding(2) var dD: array; // [T][rank] +var m: Meta; +var part: array; +@compute @workgroup_size(256) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let idx = wid.x; let t = idx / m.rank; let r = idx % m.rank; let tid = lid.x; + if (t >= m.T) { return; } + var s = 0.0; let yb = t*m.N; let bb = r*m.N; + for (var n = tid; n < m.N; n = n + 256u) { s = s + dY[yb + n] * B[bb + n]; } + part[tid] = s; workgroupBarrier(); + for (var st = 128u; st > 0u; st = st/2u) { if (tid < st) { part[tid] = part[tid] + part[tid+st]; } workgroupBarrier(); } + if (tid == 0u) { dD[t*m.rank + r] = m.scale * part[0]; } +}`; +var LORA_GRAD_A = ` +requires immediate_address_space; +struct Meta { T:u32, K:u32, rank:u32, p:u32 }; +@group(0) @binding(0) var dD: array; // [T][rank] +@group(0) @binding(1) var X: array; // [T][K] +@group(0) @binding(2) var dA: array; // [rank][K] +var m: Meta; +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let total = m.rank * m.K; let stride = nwg.x * 256u; + for (var i = gid.x; i < total; i = i + stride) { + let r = i / m.K; let k = i % m.K; + var acc = 0.0; + for (var t = 0u; t < m.T; t = t + 1u) { acc = acc + dD[t*m.rank + r] * X[t*m.K + k]; } + dA[i] = dA[i] + acc; + } +}`; +var LORA_GRAD_B = ` +requires immediate_address_space; +struct Meta { T:u32, N:u32, rank:u32, p:u32, scale:f32, f0:f32, f1:f32, f2:f32 }; +@group(0) @binding(0) var D: array; // [T][rank] +@group(0) @binding(1) var dY: array; // [T][N] +@group(0) @binding(2) var dB: array; // [rank][N] +var m: Meta; +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let total = m.rank * m.N; let stride = nwg.x * 256u; + for (var i = gid.x; i < total; i = i + stride) { + let r = i / m.N; let n = i % m.N; + var acc = 0.0; + for (var t = 0u; t < m.T; t = t + 1u) { acc = acc + D[t*m.rank + r] * dY[t*m.N + n]; } + dB[i] = dB[i] + m.scale * acc; + } +}`; +var LORA_DX_ADD = ` +requires immediate_address_space; +struct Meta { T:u32, K:u32, rank:u32, p:u32 }; +@group(0) @binding(0) var dD: array; // [T][rank] +@group(0) @binding(1) var A: array; // [rank][K] +@group(0) @binding(2) var dX: array; // [T][K] +var m: Meta; +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let total = m.T * m.K; let stride = nwg.x * 256u; + for (var i = gid.x; i < total; i = i + stride) { + let t = i / m.K; let k = i % m.K; + var acc = 0.0; + for (var r = 0u; r < m.rank; r = r + 1u) { acc = acc + dD[t*m.rank + r] * A[r*m.K + k]; } + dX[i] = dX[i] + acc; + } +}`; +var RMSNORM_BWD_T = ` +requires immediate_address_space; +override WG: u32 = 256u; +@group(0) @binding(0) var x: array; // [T][K] +@group(0) @binding(1) var g: array; // [K] +@group(0) @binding(2) var dy: array; // [T][K] +@group(0) @binding(3) var dx: array; // [T][K] +var m: vec2; // K, eps +var red: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; let K = u32(m.x); let base = wid.x * K; + // sum of squares for inv + var ss = 0.0; + for (var k = tid; k < K; k = k + WG) { let v = x[base+k]; ss = ss + v*v; } + red[tid] = ss; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); } + let ms = red[0] / m.x; + let inv = inverseSqrt(ms + m.y); + workgroupBarrier(); + // c = sum dy*g*x + var cc = 0.0; + for (var k = tid; k < K; k = k + WG) { cc = cc + dy[base+k]*g[k]*x[base+k]; } + red[tid] = cc; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); } + let c = red[0]; + let inv3overK = inv*inv*inv / m.x; + for (var k = tid; k < K; k = k + WG) { + dx[base+k] = inv*g[k]*dy[base+k] - inv3overK * x[base+k] * c; + } +}`; +var SWIGLU_BWD = ` +requires immediate_address_space; +override WG: u32 = 256u; +@group(0) @binding(0) var gate: array; +@group(0) @binding(1) var up: array; +@group(0) @binding(2) var dOut: array; +@group(0) @binding(3) var dGate: array; +@group(0) @binding(4) var dUp: array; +var n: u32; +@compute @workgroup_size(WG) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let stride = nwg.x * WG; + for (var i = gid.x; i < n; i = i + stride) { + let z = gate[i]; let sig = 1.0/(1.0+exp(-z)); let sl = z*sig; + let d = dOut[i]; + dUp[i] = d * sl; + dGate[i] = d * up[i] * (sig * (1.0 + z*(1.0 - sig))); + } +}`; +var ROPE_BWD_T = ` +requires immediate_address_space; +@group(0) @binding(0) var dx: array; // [T][nHeads*headDim] gradient +@group(0) @binding(1) var cosT: array; +@group(0) @binding(2) var sinT: array; +var m: vec4; // nHeads, headDim, T, pos0 +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let g = gid.x; let H = m.x; let D = m.y; let T = m.z; let pos0 = m.w; let half = D/2u; + let perRow = H*half; if (g >= T*perRow) { return; } + let row = g / perRow; let r = g % perRow; let h = r / half; let j = r % half; + let rb = row*H*D; let lo = rb + h*D + j; let hi = lo + half; let off = (pos0+row)*D + j; + let c = cosT[off]; let s = sinT[off]; + let dl = dx[lo]; let dh = dx[hi]; + dx[lo] = c*dl + s*dh; + dx[hi] = -s*dl + c*dh; +}`; +var ATTN_BWD_STATS = ` +requires immediate_address_space; +override WG: u32 = 128u; +struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32 }; +@group(0) @binding(0) var q: array; // [T][nHeads*hd] +@group(0) @binding(1) var kc: array; // [T][nKV*hd] +@group(0) @binding(2) var o: array; // [T][nHeads*hd] attn output +@group(0) @binding(3) var doo: array; // [T][nHeads*hd] grad of attn output +@group(0) @binding(4) var lse: array; // [nHeads*T] +@group(0) @binding(5) var delta: array; // [nHeads*T] +var m: Meta; +var red: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let h = wid.x; let t = wid.y; let tid = lid.x; + let hd = m.hd; let nKV = m.nKV; let kvh = h / (m.nHeads / nKV); + let qb = t*m.nHeads*hd + h*hd; let kvstride = nKV*hd; let hoff = kvh*hd; + let scl = 1.0 / sqrt(f32(hd)); + // running max + var lmax = -1e30; + for (var j = tid; j <= t; j = j + WG) { + var dot = 0.0; let kb = j*kvstride + hoff; + for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qb+d]*kc[kb+d]; } + lmax = max(lmax, dot*scl); + } + red[tid] = lmax; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = max(red[tid], red[tid+s]); } workgroupBarrier(); } + let M = red[0]; + workgroupBarrier(); + var lsum = 0.0; + for (var j = tid; j <= t; j = j + WG) { + var dot = 0.0; let kb = j*kvstride + hoff; + for (var d = 0u; d < hd; d = d + 1u) { dot = dot + q[qb+d]*kc[kb+d]; } + lsum = lsum + exp(dot*scl - M); + } + red[tid] = lsum; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); } + // delta + var dl = 0.0; + for (var d = tid; d < hd; d = d + WG) { dl = dl + doo[qb+d]*o[qb+d]; } + // reuse red after sum captured + let Z = red[0]; + workgroupBarrier(); + red[tid] = dl; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); } + if (tid == 0u) { lse[h*m.T + t] = M + log(Z); delta[h*m.T + t] = red[0]; } +}`; +var ATTN_BWD_DQ = ` +requires immediate_address_space; +override WG: u32 = 128u; +struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32 }; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var kc: array; +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var doo: array; +@group(0) @binding(4) var lse: array; +@group(0) @binding(5) var delta: array; +@group(0) @binding(6) var dq: array; +var m: Meta; +var red: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let h = wid.x; let t = wid.y; let d = lid.x; + let hd = m.hd; let nKV = m.nKV; let kvh = h / (m.nHeads / nKV); + let qb = t*m.nHeads*hd + h*hd; let kvstride = nKV*hd; let hoff = kvh*hd; + let scl = 1.0 / sqrt(f32(hd)); + let lse_t = lse[h*m.T + t]; let delta_t = delta[h*m.T + t]; + // Guard every storage read behind (d < hd): WGSL select() is eager and would + // still evaluate the buffer load for inactive lanes (OOB when hd < WG). Barriers + // stay at uniform control flow so the reductions remain valid. + let inHd = d < hd; + var acc = 0.0; + for (var j = 0u; j <= t; j = j + 1u) { + let kb = j*kvstride + hoff; + // s = scl * dot(q, k_j) + var sv = 0.0; if (inHd) { sv = q[qb+d] * kc[kb+d]; } + red[d] = sv; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (d < s) { red[d] = red[d] + red[d+s]; } workgroupBarrier(); } + let sval = red[0] * scl; + workgroupBarrier(); + // dp = dot(do, v_j) + var dpv = 0.0; if (inHd) { dpv = doo[qb+d] * vc[kb+d]; } + red[d] = dpv; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (d < s) { red[d] = red[d] + red[d+s]; } workgroupBarrier(); } + let dp = red[0]; + workgroupBarrier(); + let p = exp(sval - lse_t); + let ds = p * (dp - delta_t); + if (inHd) { acc = acc + ds * kc[kb+d]; } + } + if (inHd) { dq[qb+d] = dq[qb+d] + scl * acc; } +}`; +var ATTN_BWD_DKV = ` +requires immediate_address_space; +override WG: u32 = 128u; +struct Meta { nHeads:u32, nKV:u32, hd:u32, T:u32 }; +@group(0) @binding(0) var q: array; +@group(0) @binding(1) var kc: array; +@group(0) @binding(2) var vc: array; +@group(0) @binding(3) var doo: array; +@group(0) @binding(4) var lse: array; +@group(0) @binding(5) var delta: array; +@group(0) @binding(6) var dk: array; +@group(0) @binding(7) var dv: array; +var m: Meta; +var red: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let kvh = wid.x; let j = wid.y; let d = lid.x; + let hd = m.hd; let nKV = m.nKV; let group = m.nHeads / nKV; + let kvstride = nKV*hd; let hoff = kvh*hd; let kb = j*kvstride + hoff; + let scl = 1.0 / sqrt(f32(hd)); + // Guard storage reads behind (d < hd) \u2014 see ATTN_BWD_DQ note on eager select(). + let inHd = d < hd; + var dkacc = 0.0; var dvacc = 0.0; + for (var hi = 0u; hi < group; hi = hi + 1u) { + let h = kvh*group + hi; + for (var t = j; t < m.T; t = t + 1u) { + let qb = t*m.nHeads*hd + h*hd; + var sv = 0.0; if (inHd) { sv = q[qb+d] * kc[kb+d]; } + red[d] = sv; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (d < s) { red[d] = red[d] + red[d+s]; } workgroupBarrier(); } + let sval = red[0] * scl; + workgroupBarrier(); + var dpv = 0.0; if (inHd) { dpv = doo[qb+d] * vc[kb+d]; } + red[d] = dpv; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (d < s) { red[d] = red[d] + red[d+s]; } workgroupBarrier(); } + let dp = red[0]; + workgroupBarrier(); + let p = exp(sval - lse[h*m.T + t]); + let ds = p * (dp - delta[h*m.T + t]); + if (inHd) { + dkacc = dkacc + scl * ds * q[qb+d]; + dvacc = dvacc + p * doo[qb+d]; + } + } + } + if (inHd) { dk[kb+d] = dk[kb+d] + dkacc; dv[kb+d] = dv[kb+d] + dvacc; } +}`; +var LOGITS_GEMM_I8 = ` +requires immediate_address_space; +struct Meta { T:u32, vocab:u32, K:u32, tOff:u32 }; +@group(0) @binding(0) var normed: array; // [T][K] (full-seq buffer, offset by tOff) +@group(0) @binding(1) var E: array; // [vocab][K/4] int8 +@group(0) @binding(2) var scaleE: array; // [vocab] +@group(0) @binding(3) var logits: array; // [Tblock][vocab] +var m: Meta; +fn sx8(v: u32) -> i32 { + return i32(v << 24u) >> 24u; +} +fn unpack4xI8(x: u32) -> vec4 { + return vec4( + sx8(x & 0xffu), + sx8((x >> 8u) & 0xffu), + sx8((x >> 16u) & 0xffu), + sx8((x >> 24u) & 0xffu) + ); +} +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let total = m.T * m.vocab; let stride = nwg.x * 256u; let K4 = m.K / 4u; + for (var i = gid.x; i < total; i = i + stride) { + let t = i / m.vocab; let v = i % m.vocab; + let nb = (m.tOff + t) * m.K; let eb = v * K4; + var acc = 0.0; + for (var c = 0u; c < K4; c = c + 1u) { + let p = unpack4xI8(E[eb + c]); let kk = c*4u; + acc = acc + normed[nb+kk]*f32(p.x) + normed[nb+kk+1u]*f32(p.y) + + normed[nb+kk+2u]*f32(p.z) + normed[nb+kk+3u]*f32(p.w); + } + logits[i] = acc * scaleE[v]; + } +}`; +var CE_SOFTMAX_GRAD = ` +requires immediate_address_space; +override WG: u32 = 256u; +struct Meta { vocab:u32, tOff:u32, lossScale:f32, p:u32 }; +@group(0) @binding(0) var logits: array; // [bt][vocab] -> dLogits +@group(0) @binding(1) var labels: array; // [T] token id (global) +@group(0) @binding(2) var mask: array; // [T] 1 train / 0 skip (global) +@group(0) @binding(3) var lossOut: array;// [T] (global) +var m: Meta; +var red: array; +@compute @workgroup_size(WG) +fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { + let lt = wid.x; let tid = lid.x; let base = lt*m.vocab; + let gt = m.tOff + lt; // global token index for target/mask/loss + let mk = mask[gt]; + // max + var mx = -1e30; + for (var v = tid; v < m.vocab; v = v + WG) { mx = max(mx, logits[base+v]); } + red[tid] = mx; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = max(red[tid], red[tid+s]); } workgroupBarrier(); } + let M = red[0]; workgroupBarrier(); + // sum exp + var sm = 0.0; + for (var v = tid; v < m.vocab; v = v + WG) { sm = sm + exp(logits[base+v] - M); } + red[tid] = sm; workgroupBarrier(); + for (var s = WG/2u; s > 0u; s = s/2u) { if (tid < s) { red[tid] = red[tid] + red[tid+s]; } workgroupBarrier(); } + let Z = red[0]; + let tgt = labels[gt]; + if (tid == 0u) { + let ltgt = logits[base + tgt]; + lossOut[gt] = mk * (log(Z) - (ltgt - M)); + } + // dLogits = mask*lossScale*(p - onehot) + let invZ = 1.0 / Z; let g = mk * m.lossScale; + for (var v = tid; v < m.vocab; v = v + WG) { + var p = exp(logits[base+v] - M) * invZ; + if (v == tgt) { p = p - 1.0; } + logits[base+v] = g * p; + } +}`; +var DHIDDEN_FROM_DLOGITS_I8 = ` +requires immediate_address_space; +struct Meta { T:u32, vocab:u32, K:u32, tOff:u32 }; +@group(0) @binding(0) var dLogits: array; // [Tblock][vocab] +@group(0) @binding(1) var E: array; // [vocab][K/4] int8 +@group(0) @binding(2) var scaleE: array; // [vocab] +@group(0) @binding(3) var dHidden: array; // [T][K] (offset tOff) +var m: Meta; +fn sx8(v: u32) -> i32 { + return i32(v << 24u) >> 24u; +} +fn unpack4xI8(x: u32) -> vec4 { + return vec4( + sx8(x & 0xffu), + sx8((x >> 8u) & 0xffu), + sx8((x >> 16u) & 0xffu), + sx8((x >> 24u) & 0xffu) + ); +} +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let total = m.T * m.K; let stride = nwg.x * 256u; let K4 = m.K / 4u; + for (var i = gid.x; i < total; i = i + stride) { + let t = i / m.K; let k = i % m.K; + let lb = t * m.vocab; + var acc = 0.0; + let word_idx = k >> 2u; let lane = k & 3u; + for (var v = 0u; v < m.vocab; v = v + 1u) { + let p = unpack4xI8(E[v*K4 + word_idx]); + var b: i32; if (lane==0u){b=p.x;} else if (lane==1u){b=p.y;} else if (lane==2u){b=p.z;} else {b=p.w;} + acc = acc + dLogits[lb + v] * scaleE[v] * f32(b); + } + dHidden[(m.tOff + t)*m.K + k] = dHidden[(m.tOff + t)*m.K + k] + acc; + } +}`; +var ADAMW_STEP = ` +requires immediate_address_space; +struct Meta { n:u32, p:u32, lr:f32, beta1:f32, beta2:f32, eps:f32, wd:f32, gScale:f32, b1c:f32, b2c:f32, f0:f32, f1:f32 }; +@group(0) @binding(0) var param: array; +@group(0) @binding(1) var grad: array; +@group(0) @binding(2) var mBuf: array; +@group(0) @binding(3) var vBuf: array; +var m: Meta; +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3, @builtin(num_workgroups) nwg: vec3) { + let stride = nwg.x * 256u; + for (var i = gid.x; i < m.n; i = i + stride) { + let gr = grad[i] * m.gScale; + let mm = m.beta1 * mBuf[i] + (1.0 - m.beta1) * gr; + let vv = m.beta2 * vBuf[i] + (1.0 - m.beta2) * gr * gr; + mBuf[i] = mm; vBuf[i] = vv; + let mhat = mm / m.b1c; let vhat = vv / m.b2c; + param[i] = param[i] - m.lr * (mhat / (sqrt(vhat) + m.eps) + m.wd * param[i]); + } +}`; +var SUMSQ = ` +requires immediate_address_space; +override WG: u32 = 256u; +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var out: array; // [1] +var n: u32; +var red: array; +@compute @workgroup_size(WG) +fn main(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; var s = 0.0; + for (var i = tid; i < n; i = i + WG) { let v = x[i]; s = s + v*v; } + red[tid] = s; workgroupBarrier(); + for (var st = WG/2u; st > 0u; st = st/2u) { if (tid < st) { red[tid] = red[tid] + red[tid+st]; } workgroupBarrier(); } + if (tid == 0u) { out[0] = out[0] + red[0]; } +}`; + +// src/qwgpu/trainer.js +var STORAGE2 = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC; +var READBACK = GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ; +var nowMs = /* @__PURE__ */ __name(() => globalThis.performance?.now?.() ?? Date.now(), "nowMs"); +var ALL_PROJ = ["q", "k", "v", "o", "gate", "up", "down"]; +function createTrainableAdapter(rt, opts = {}) { + const rank = Math.max(1, Math.floor(opts.rank ?? 16)); + const alpha = opts.alpha ?? rank * 2; + const scale = opts.scale ?? alpha / rank; + const targets = opts.targetModules ?? ALL_PROJ; + const stddev = opts.stddev ?? 1 / Math.sqrt(rank); + const usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC; + const gauss = /* @__PURE__ */ __name(() => { + let u = 0, v = 0; + while (u === 0) u = Math.random(); + while (v === 0) v = Math.random(); + return Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * v); + }, "gauss"); + const modules = {}; + for (const L of rt.plan.layers) { + for (const name of ALL_PROJ) { + if (!targets.includes(name)) continue; + const part = L[name]; + const q4 = rt.q4[part.weight]; + const K = q4.K, N = q4.N; + const Aarr = new Float32Array(rank * K); + for (let i = 0; i < Aarr.length; i++) Aarr[i] = gauss() * stddev; + const Barr = new Float32Array(rank * N); + const A = rt.dev.createBuffer({ size: Aarr.byteLength, usage }); + const B = rt.dev.createBuffer({ size: Barr.byteLength, usage }); + rt.dev.queue.writeBuffer(A, 0, Aarr); + rt.dev.queue.writeBuffer(B, 0, Barr); + modules[part.loraKey] = { A, B, rank, scale, inDim: K, outDim: N }; + } + } + return { name: opts.name || "trainable", modules }; +} +__name(createTrainableAdapter, "createTrainableAdapter"); +var QwenLoraTrainer = class { + static { + __name(this, "QwenLoraTrainer"); + } + // rt: a built QwenWGPU. opts: see _normalizeOpts. + constructor(rt, opts = {}) { + this.rt = rt; + this.dev = rt.dev; + this.cfg = rt.cfg; + this.opts = this._normalizeOpts(opts); + this.step = 0; + this._microInWindow = 0; + this.scratchT = 0; + this._buildPipes(); + } + _normalizeOpts(o) { + return { + lr: o.lr ?? 1e-4, + beta1: o.beta1 ?? 0.9, + beta2: o.beta2 ?? 0.999, + eps: o.eps ?? 1e-8, + weightDecay: o.weightDecay ?? 0, + maxGradNorm: o.maxGradNorm ?? 1, + gradAccumSteps: Math.max(1, Math.floor(o.gradAccumSteps ?? 1)), + lmHeadBlock: Math.max(1, Math.floor(o.lmHeadBlock ?? 128)), + maxTrainSeq: Math.max(1, Math.floor(o.maxTrainSeq ?? 512)), + warmupSteps: Math.max(0, Math.floor(o.warmupSteps ?? 0)), + totalSteps: o.totalSteps ?? 0, + // for cosine decay; 0 disables decay + minLrRatio: o.minLrRatio ?? 0.1, + targetModules: o.targetModules ?? ALL_PROJ + }; + } + _buildPipes() { + const rt = this.rt; + this.p = { + dx4: rt._pipe(GEMM_DX_INT4, "bwd_dx4"), + dd: rt._pipe(LORA_DD, "bwd_lora_dd"), + gradA: rt._pipe(LORA_GRAD_A, "bwd_lora_dA"), + gradB: rt._pipe(LORA_GRAD_B, "bwd_lora_dB"), + dxAdd: rt._pipe(LORA_DX_ADD, "bwd_lora_dx"), + rmsBwd: rt._pipe(RMSNORM_BWD_T, "bwd_rms"), + swiglu: rt._pipe(SWIGLU_BWD, "bwd_swiglu"), + ropeBwd: rt._pipe(ROPE_BWD_T, "bwd_rope"), + attnStats: rt._pipe(ATTN_BWD_STATS, "bwd_attn_stats"), + attnDq: rt._pipe(ATTN_BWD_DQ, "bwd_attn_dq"), + attnDkv: rt._pipe(ATTN_BWD_DKV, "bwd_attn_dkv"), + logits: rt._pipe(LOGITS_GEMM_I8, "bwd_logits"), + ceGrad: rt._pipe(CE_SOFTMAX_GRAD, "bwd_ce"), + dHidden: rt._pipe(DHIDDEN_FROM_DLOGITS_I8, "bwd_dhidden"), + adamw: rt._pipe(ADAMW_STEP, "adamw"), + sumsq: rt._pipe(SUMSQ, "sumsq") + }; + } + // ---- adapter attach: build per-module grad + Adam moment state ---- + // The adapter must already be uploaded (loadLoraAdapterGPU) and set on rt. + attach(adapter) { + if (!adapter || !adapter.modules) throw new Error("trainer.attach: adapter with modules required"); + this.adapter = adapter; + this.rt.setLora(adapter); + const rt = this.rt; + const byKey = /* @__PURE__ */ new Map(); + for (const L of rt.plan.layers) { + for (const name of ALL_PROJ) { + const part = L[name]; + byKey.set(part.loraKey, { part, kind: name, q4: rt.q4[part.weight] }); + } + } + this.state = {}; + let maxRank = 1; + for (const key of Object.keys(adapter.modules)) { + const mod = adapter.modules[key]; + const info = byKey.get(key); + if (!info) continue; + const kind = info.kind.replace(/_proj$/, ""); + if (!this.opts.targetModules.includes(kind)) continue; + const K = info.q4.K, N = info.q4.N, rank = mod.rank; + maxRank = Math.max(maxRank, rank); + this.state[key] = { + mod, + q4: info.q4, + K, + N, + rank, + scale: mod.scale, + dA: rt._buf(rank * K * 4), + dB: rt._buf(rank * N * 4), + mA: rt._buf(rank * K * 4), + vA: rt._buf(rank * K * 4), + mB: rt._buf(rank * N * 4), + vB: rt._buf(rank * N * 4) + }; + } + this.maxRank = maxRank; + this.trainedKeys = Object.keys(this.state); + if (!this.trainedKeys.length) throw new Error("trainer.attach: no trainable modules matched targetModules"); + this._zeroAdamMoments(); + this.zeroGrads(); + return this; + } + _zeroAdamMoments() { + const enc = this.dev.createCommandEncoder(); + for (const k of this.trainedKeys) { + const st = this.state[k]; + enc.clearBuffer(st.mA); + enc.clearBuffer(st.vA); + enc.clearBuffer(st.mB); + enc.clearBuffer(st.vB); + } + this.dev.queue.submit([enc.finish()]); + } + zeroGrads() { + const enc = this.dev.createCommandEncoder(); + for (const k of this.trainedKeys) { + enc.clearBuffer(this.state[k].dA); + enc.clearBuffer(this.state[k].dB); + } + this.dev.queue.submit([enc.finish()]); + this._microInWindow = 0; + } + // ---- activation/gradient scratch sized to the sequence ---- + _ensureScratch(T) { + if (this.scratchT >= T && this.s) return; + if (this.s) for (const k in this.s) this.s[k].destroy?.(); + if (this.ckpt) for (const c2 of this.ckpt) c2.destroy?.(); + this.lossRead?.destroy?.(); + this.normRead?.destroy?.(); + const c = this.cfg; + const H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize, nH = c.numHeads, R = this.maxRank, lmB = this.opts.lmHeadBlock, V = c.vocabSize; + const b = /* @__PURE__ */ __name((n) => this.rt._buf(n * 4), "b"); + this.ckpt = []; + for (let i = 0; i <= c.numLayers; i++) this.ckpt.push(b(T * H)); + this.s = { + hid: b(T * H), + normed1: b(T * H), + normed2: b(T * H), + normedF: b(T * H), + q: b(T * qd), + k: b(T * kvd), + v: b(T * kvd), + attn: b(T * qd), + hmid: b(T * H), + gate: b(T * I), + up: b(T * I), + swig: b(T * I), + dHidden: b(T * H), + dnorm: b(T * H), + dtmp: b(T * H), + dhmid: b(T * H), + dq: b(T * qd), + dk: b(T * kvd), + dv: b(T * kvd), + dob: b(T * qd), + dgate: b(T * I), + dup: b(T * I), + dswig: b(T * I), + dD: b(T * R), + Dmat: b(T * R), + lse: b(nH * T), + delta: b(nH * T), + logits: b(lmB * V), + loss: b(T), + targets: this.rt._buf(T * 4), + mask: b(T), + normBuf: b(1) + }; + this.lossRead = this.rt._buf(T * 4, READBACK); + this.normRead = this.rt._buf(4, READBACK); + this.scratchT = T; + } + // ---- small dispatch helpers ---- + _grid1d(n) { + return Math.min(Math.ceil(n / 256), 65535); + } + _disp(enc, pipe, buffers, gx, gy, imm, cat) { + const bg = this.rt._bg(pipe, buffers); + this.rt._dispatch(enc, pipe, bg, gx, gy, cat || "train", imm); + } + _u32(arr) { + return new Uint32Array(arr); + } + _meta(u32parts, f32parts = {}) { + const buf = new ArrayBuffer(48); + const dv = new DataView(buf); + for (const [i, v] of u32parts) dv.setUint32(i * 4, v >>> 0, true); + for (const [i, v] of Object.entries(f32parts)) dv.setFloat32(Number(i) * 4, v, true); + return new Uint8Array(buf); + } + // ---- forward with checkpoints (LoRA-modified, f32) ---- + _layerForward(enc, L, hid, T) { + const rt = this.rt, c = this.cfg, s = this.s; + const H = c.hiddenSize; + rt.rmsT(enc, hid, rt.bufs[L.inputNorm], s.normed1, T, H); + rt.gemm4(enc, s.normed1, rt.q4[L.q.weight], s.q, T, rt.bufs[L.q.bias], L.q.loraKey); + rt.gemm4(enc, s.normed1, rt.q4[L.k.weight], s.k, T, rt.bufs[L.k.bias], L.k.loraKey); + rt.gemm4(enc, s.normed1, rt.q4[L.v.weight], s.v, T, rt.bufs[L.v.bias], L.v.loraKey); + rt.ropeT(enc, s.q, T, c.numHeads); + rt.ropeT(enc, s.k, T, c.numKVHeads); + rt.attnPrefill(enc, s.q, s.k, s.v, s.attn, T, 0, T); + rt.gemm4AddT(enc, s.attn, rt.q4[L.o.weight], hid, T, null, L.o.loraKey); + rt.rmsT(enc, hid, rt.bufs[L.postAttentionNorm], s.normed2, T, H); + rt.gemm4(enc, s.normed2, rt.q4[L.gate.weight], s.gate, T, null, L.gate.loraKey); + rt.gemm4(enc, s.normed2, rt.q4[L.up.weight], s.up, T, null, L.up.loraKey); + enc.copyBufferToBuffer(s.gate, 0, s.swig, 0, T * c.intermediateSize * 4); + rt._siluMul(enc, s.swig, s.up, T * c.intermediateSize); + rt.gemm4AddT(enc, s.swig, rt.q4[L.down.weight], hid, T, null, L.down.loraKey); + } + _forward(enc, ids, T) { + const rt = this.rt, c = this.cfg, s = this.s, H = c.hiddenSize; + rt._ensurePrefillScratch(T, this.maxRank); + rt._resetUni(); + const e = rt.q[rt.plan.embed.name]; + this.dev.queue.writeBuffer(rt.sT.ids, 0, new Uint32Array(ids)); + rt._dispatch( + enc, + rt.pipes.embedT, + rt._bg(rt.pipes.embedT, [e.w, e.scale, this.ckpt[0], rt.sT.ids]), + Math.min(Math.ceil(T * H / 256), 65535), + 1, + "embedT", + this._u32([T, H, 0, 0]) + ); + enc.copyBufferToBuffer(this.ckpt[0], 0, s.hid, 0, T * H * 4); + for (let i = 0; i < c.numLayers; i++) { + this._layerForward(enc, rt.plan.layers[i], s.hid, T); + enc.copyBufferToBuffer(s.hid, 0, this.ckpt[i + 1], 0, T * H * 4); + } + } + // recompute one layer's forward internals (from its checkpoint) into scratch, also + // producing hmid (= ckpt + attnProj) which the backward needs as the post-attn input. + _recomputeLayer(enc, L, T) { + const rt = this.rt, c = this.cfg, s = this.s, H = c.hiddenSize, idx = L.index; + rt.rmsT(enc, this.ckpt[idx], rt.bufs[L.inputNorm], s.normed1, T, H); + rt.gemm4(enc, s.normed1, rt.q4[L.q.weight], s.q, T, rt.bufs[L.q.bias], L.q.loraKey); + rt.gemm4(enc, s.normed1, rt.q4[L.k.weight], s.k, T, rt.bufs[L.k.bias], L.k.loraKey); + rt.gemm4(enc, s.normed1, rt.q4[L.v.weight], s.v, T, rt.bufs[L.v.bias], L.v.loraKey); + rt.ropeT(enc, s.q, T, c.numHeads); + rt.ropeT(enc, s.k, T, c.numKVHeads); + rt.attnPrefill(enc, s.q, s.k, s.v, s.attn, T, 0, T); + enc.copyBufferToBuffer(this.ckpt[idx], 0, s.hmid, 0, T * H * 4); + rt.gemm4AddT(enc, s.attn, rt.q4[L.o.weight], s.hmid, T, null, L.o.loraKey); + rt.rmsT(enc, s.hmid, rt.bufs[L.postAttentionNorm], s.normed2, T, H); + rt.gemm4(enc, s.normed2, rt.q4[L.gate.weight], s.gate, T, null, L.gate.loraKey); + rt.gemm4(enc, s.normed2, rt.q4[L.up.weight], s.up, T, null, L.up.loraKey); + enc.copyBufferToBuffer(s.gate, 0, s.swig, 0, T * c.intermediateSize * 4); + rt._siluMul(enc, s.swig, s.up, T * c.intermediateSize); + } + // ---- LoRA + base projection backward ---- + // dY [T][N] -> accumulate into dXbuf [T][K] (base + LoRA), plus dA/dB grads. + _projBackward(enc, key, Xbuf, dYbuf, dXbuf, T) { + const st = this.state[key]; + if (!st) { + this._dispatch_dx4(enc, dYbuf, st, dXbuf, T, key); + return; + } + const { K, N, rank, scale, q4, dA, dB } = st; + const s = this.s; + this._disp( + enc, + this.p.dx4, + [dYbuf, q4.w, q4.scale, dXbuf], + this._grid1d(T * K), + 1, + this._meta([[0, T], [1, N], [2, K], [3, q4.gpr]]), + "dx4" + ); + this._disp( + enc, + this.p.dd, + [dYbuf, st.mod.B, s.dD], + T * rank, + 1, + this._meta([[0, T], [1, N], [2, rank]], { 4: scale }), + "dd" + ); + this._disp( + enc, + this.p.gradA, + [s.dD, Xbuf, dA], + this._grid1d(rank * K), + 1, + this._meta([[0, T], [1, K], [2, rank]]), + "gradA" + ); + this._disp( + enc, + this.rt.pipes.loraABatch, + [Xbuf, st.mod.A, s.Dmat], + rank, + T, + this._u32([K, rank, T, 0]), + "loraABatch" + ); + this._disp( + enc, + this.p.gradB, + [s.Dmat, dYbuf, dB], + this._grid1d(rank * N), + 1, + this._meta([[0, T], [1, N], [2, rank]], { 4: scale }), + "gradB" + ); + this._disp( + enc, + this.p.dxAdd, + [s.dD, st.mod.A, dXbuf], + this._grid1d(T * K), + 1, + this._meta([[0, T], [1, K], [2, rank]]), + "dxAdd" + ); + } + _dispatch_dx4(enc, dYbuf, st, dXbuf, T, key) { + const info = this._infoForKey(key); + const q4 = info.q4; + this._disp( + enc, + this.p.dx4, + [dYbuf, q4.w, q4.scale, dXbuf], + this._grid1d(T * q4.K), + 1, + this._meta([[0, T], [1, q4.N], [2, q4.K], [3, q4.gpr]]), + "dx4" + ); + } + _infoForKey(key) { + for (const L of this.rt.plan.layers) + for (const name of ALL_PROJ) if (L[name].loraKey === key) return { q4: this.rt.q4[L[name].weight] }; + throw new Error(`unknown loraKey ${key}`); + } + _rmsBwd(enc, xBuf, gBuf, dyBuf, dxBuf, T) { + const c = this.cfg; + this._disp( + enc, + this.p.rmsBwd, + [xBuf, gBuf, dyBuf, dxBuf], + T, + 1, + new Float32Array([c.hiddenSize, c.rmsNormEps]), + "rmsBwd" + ); + } + // ---- full backward for one micro-batch; accumulates grads, returns nothing ---- + _backward(enc, T, numActive) { + const rt = this.rt, c = this.cfg, s = this.s, H = c.hiddenSize, qd = c.numHeads * c.headDim, kvd = c.numKVHeads * c.headDim, I = c.intermediateSize, V = c.vocabSize; + rt.rmsT(enc, this.ckpt[c.numLayers], rt.bufs[rt.plan.finalNorm.name], s.normedF, T, H); + enc.clearBuffer(s.dnorm); + const e = rt.q[rt.plan.embed.name]; + const lossScale = 1 / Math.max(1, numActive); + const lmB = this.opts.lmHeadBlock; + for (let off = 0; off < T; off += lmB) { + const bt = Math.min(lmB, T - off); + this._disp( + enc, + this.p.logits, + [s.normedF, e.w, e.scale, s.logits], + this._grid1d(bt * V), + 1, + this._meta([[0, bt], [1, V], [2, H], [3, off]]), + "logits" + ); + this._disp( + enc, + this.p.ceGrad, + [s.logits, s.targets, s.mask, s.loss], + bt, + 1, + this._meta([[0, V], [1, off]], { 2: lossScale }), + "ce" + ); + this._disp( + enc, + this.p.dHidden, + [s.logits, e.w, e.scale, s.dnorm], + this._grid1d(bt * H), + 1, + this._meta([[0, bt], [1, V], [2, H], [3, off]]), + "dHidden" + ); + } + this._rmsBwd(enc, this.ckpt[c.numLayers], rt.bufs[rt.plan.finalNorm.name], s.dnorm, s.dHidden, T); + for (let i = c.numLayers - 1; i >= 0; i--) { + const L = rt.plan.layers[i]; + this._recomputeLayer(enc, L, T); + enc.clearBuffer(s.dswig); + this._projBackward(enc, L.down.loraKey, s.swig, s.dHidden, s.dswig, T); + this._disp( + enc, + this.p.swiglu, + [s.gate, s.up, s.dswig, s.dgate, s.dup], + this._grid1d(T * I), + 1, + this._u32([T * I]), + "swiglu" + ); + enc.clearBuffer(s.dnorm); + this._projBackward(enc, L.gate.loraKey, s.normed2, s.dgate, s.dnorm, T); + this._projBackward(enc, L.up.loraKey, s.normed2, s.dup, s.dnorm, T); + this._rmsBwd(enc, s.hmid, rt.bufs[L.postAttentionNorm], s.dnorm, s.dtmp, T); + enc.copyBufferToBuffer(s.dHidden, 0, s.dhmid, 0, T * H * 4); + rt._addInto(enc, s.dhmid, s.dtmp, T * H); + enc.clearBuffer(s.dob); + this._projBackward(enc, L.o.loraKey, s.attn, s.dhmid, s.dob, T); + const am = this._u32([c.numHeads, c.numKVHeads, c.headDim, T]); + this._disp(enc, this.p.attnStats, [s.q, s.k, s.attn, s.dob, s.lse, s.delta], c.numHeads, T, am, "attnStats"); + enc.clearBuffer(s.dq); + enc.clearBuffer(s.dk); + enc.clearBuffer(s.dv); + this._disp(enc, this.p.attnDq, [s.q, s.k, s.v, s.dob, s.lse, s.delta, s.dq], c.numHeads, T, am, "attnDq"); + this._disp( + enc, + this.p.attnDkv, + [s.q, s.k, s.v, s.dob, s.lse, s.delta, s.dk, s.dv], + c.numKVHeads, + T, + am, + "attnDkv" + ); + this._disp( + enc, + this.p.ropeBwd, + [s.dq, rt.ropeCos, rt.ropeSin], + Math.ceil(T * c.numHeads * (c.headDim / 2) / 256), + 1, + this._u32([c.numHeads, c.headDim, T, 0]), + "ropeBwd" + ); + this._disp( + enc, + this.p.ropeBwd, + [s.dk, rt.ropeCos, rt.ropeSin], + Math.ceil(T * c.numKVHeads * (c.headDim / 2) / 256), + 1, + this._u32([c.numKVHeads, c.headDim, T, 0]), + "ropeBwd" + ); + enc.clearBuffer(s.dnorm); + this._projBackward(enc, L.q.loraKey, s.normed1, s.dq, s.dnorm, T); + this._projBackward(enc, L.k.loraKey, s.normed1, s.dk, s.dnorm, T); + this._projBackward(enc, L.v.loraKey, s.normed1, s.dv, s.dnorm, T); + this._rmsBwd(enc, this.ckpt[i], rt.bufs[L.inputNorm], s.dnorm, s.dtmp, T); + enc.copyBufferToBuffer(s.dhmid, 0, s.dHidden, 0, T * H * 4); + rt._addInto(enc, s.dHidden, s.dtmp, T * H); + } + } + // shifted-label targets + mask into the scratch buffers; returns numActive. + _writeTargets(tokens, lossMask, T) { + const targets = new Uint32Array(T); + const mask = new Float32Array(T); + let numActive = 0; + for (let t = 0; t < T - 1; t++) { + targets[t] = tokens[t + 1] >>> 0; + const mk = lossMask ? lossMask[t] ? 1 : 0 : 1; + mask[t] = mk; + numActive += mk; + } + targets[T - 1] = 0; + mask[T - 1] = 0; + this.dev.queue.writeBuffer(this.s.targets, 0, targets); + this.dev.queue.writeBuffer(this.s.mask, 0, mask); + return numActive; + } + // loss head only (final norm + streamed logits + CE), no backward sweep. Used by + // evalLoss(). CE overwrites s.logits with dLogits but we ignore that here. + _lossOnly(enc, T, numActive) { + const rt = this.rt, c = this.cfg, s = this.s, H = c.hiddenSize, V = c.vocabSize; + rt.rmsT(enc, this.ckpt[c.numLayers], rt.bufs[rt.plan.finalNorm.name], s.normedF, T, H); + const e = rt.q[rt.plan.embed.name]; + const lossScale = 1 / Math.max(1, numActive); + const lmB = this.opts.lmHeadBlock; + for (let off = 0; off < T; off += lmB) { + const bt = Math.min(lmB, T - off); + this._disp(enc, this.p.logits, [s.normedF, e.w, e.scale, s.logits], this._grid1d(bt * V), 1, this._meta([[0, bt], [1, V], [2, H], [3, off]]), "logits"); + this._disp(enc, this.p.ceGrad, [s.logits, s.targets, s.mask, s.loss], bt, 1, this._meta([[0, V], [1, off]], { 2: lossScale }), "ce"); + } + } + // ---- public: forward-only mean cross-entropy (no grads). For held-out eval. ---- + async evalLoss(tokens, lossMask) { + const T = tokens.length; + if (T > this.opts.maxTrainSeq) throw new Error(`seq ${T} > maxTrainSeq ${this.opts.maxTrainSeq}`); + this._ensureScratch(T); + const wasF16 = this.rt.usingF16?.(); + this.rt.setUseF16?.(false); + try { + const numActive = this._writeTargets(tokens, lossMask, T); + const enc = this.dev.createCommandEncoder(); + this._forward(enc, tokens, T); + this._lossOnly(enc, T, numActive); + enc.copyBufferToBuffer(this.s.loss, 0, this.lossRead, 0, T * 4); + this.dev.queue.submit([enc.finish()]); + await this.lossRead.mapAsync(GPUMapMode.READ); + const arr = new Float32Array(this.lossRead.getMappedRange().slice(0)); + this.lossRead.unmap(); + let sum = 0; + for (let t = 0; t < T; t++) sum += arr[t]; + return { loss: sum / Math.max(1, numActive), numActive }; + } finally { + if (wasF16) this.rt.setUseF16?.(true); + } + } + // ---- public: accumulate one micro-batch. tokens: Int array, lossMask: 0/1 array. ---- + // lossMask[t]==1 means "train the prediction of tokens[t+1] from position t". + async microStep(tokens, lossMask) { + const c = this.cfg; + const T = tokens.length; + const t0 = nowMs(); + if (T > this.opts.maxTrainSeq) throw new Error(`seq ${T} > maxTrainSeq ${this.opts.maxTrainSeq}`); + this._ensureScratch(T); + const wasF16 = this.rt.usingF16?.(); + this.rt.setUseF16?.(false); + try { + const numActive = this._writeTargets(tokens, lossMask, T); + const enc = this.dev.createCommandEncoder(); + this._forward(enc, tokens, T); + this._backward(enc, T, numActive); + enc.copyBufferToBuffer(this.s.loss, 0, this.lossRead, 0, T * 4); + this.dev.queue.submit([enc.finish()]); + await this.lossRead.mapAsync(GPUMapMode.READ); + const lossArr = new Float32Array(this.lossRead.getMappedRange().slice(0)); + this.lossRead.unmap(); + let lossSum = 0; + for (let t = 0; t < T; t++) lossSum += lossArr[t]; + this._microInWindow++; + const microStepMs = nowMs() - t0; + return { + loss: lossSum / Math.max(1, numActive), + numActive, + tokens: T, + microStepMs, + trainTokPerSec: T / Math.max(1e-6, microStepMs / 1e3) + }; + } finally { + if (wasF16) this.rt.setUseF16?.(true); + } + } + // ---- public: apply accumulated grads with AdamW + global-norm clip ---- + async optimizerStep() { + const t0 = nowMs(); + const o = this.opts; + const accum = this._microInWindow || 1; + const encN = this.dev.createCommandEncoder(); + encN.clearBuffer(this.s.normBuf); + for (const k of this.trainedKeys) { + const st = this.state[k]; + this._disp(encN, this.p.sumsq, [st.dA, this.s.normBuf], 1, 1, this._u32([st.rank * st.K]), "sumsq"); + this._disp(encN, this.p.sumsq, [st.dB, this.s.normBuf], 1, 1, this._u32([st.rank * st.N]), "sumsq"); + } + encN.copyBufferToBuffer(this.s.normBuf, 0, this.normRead, 0, 4); + this.dev.queue.submit([encN.finish()]); + await this.normRead.mapAsync(GPUMapMode.READ); + const sumsq = new Float32Array(this.normRead.getMappedRange().slice(0))[0]; + this.normRead.unmap(); + const gradScale = 1 / accum; + const gnorm = Math.sqrt(sumsq) * gradScale; + const clip2 = o.maxGradNorm > 0 && gnorm > o.maxGradNorm ? o.maxGradNorm / (gnorm + 1e-6) : 1; + const gScale = gradScale * clip2; + this.step++; + const lr = this._lrAt(this.step); + const b1c = 1 - Math.pow(o.beta1, this.step); + const b2c = 1 - Math.pow(o.beta2, this.step); + const enc = this.dev.createCommandEncoder(); + for (const k of this.trainedKeys) { + const st = this.state[k]; + const metaA = this._adamMeta(st.rank * st.K, lr, gScale, b1c, b2c); + this._disp(enc, this.p.adamw, [st.mod.A, st.dA, st.mA, st.vA], this._grid1d(st.rank * st.K), 1, metaA, "adamw"); + const metaB = this._adamMeta(st.rank * st.N, lr, gScale, b1c, b2c); + this._disp(enc, this.p.adamw, [st.mod.B, st.dB, st.mB, st.vB], this._grid1d(st.rank * st.N), 1, metaB, "adamw"); + } + this.dev.queue.submit([enc.finish()]); + this.rt.invalidateLora(); + this.zeroGrads(); + return { lr, gradNorm: gnorm, clip: clip2, optimizerStepMs: nowMs() - t0 }; + } + _lrAt(step) { + const o = this.opts; + if (o.warmupSteps > 0 && step <= o.warmupSteps) return o.lr * (step / o.warmupSteps); + if (o.totalSteps > 0 && step > o.warmupSteps) { + const prog = (step - o.warmupSteps) / Math.max(1, o.totalSteps - o.warmupSteps); + const cos = 0.5 * (1 + Math.cos(Math.PI * Math.min(1, prog))); + return o.lr * (o.minLrRatio + (1 - o.minLrRatio) * cos); + } + return o.lr; + } + _adamMeta(n, lr, gScale, b1c, b2c) { + const o = this.opts; + const buf = new ArrayBuffer(48); + const dv = new DataView(buf); + dv.setUint32(0, n >>> 0, true); + dv.setFloat32(8, lr, true); + dv.setFloat32(12, o.beta1, true); + dv.setFloat32(16, o.beta2, true); + dv.setFloat32(20, o.eps, true); + dv.setFloat32(24, o.weightDecay, true); + dv.setFloat32(28, gScale, true); + dv.setFloat32(32, b1c, true); + dv.setFloat32(36, b2c, true); + return new Uint8Array(buf); + } + // ---- convenience: one full optimization step over a list of micro-batches ---- + async trainStep(batches) { + const list = Array.isArray(batches) ? batches : [batches]; + let lossSum = 0, n = 0, numActive = 0, tokens = 0, microStepMs = 0; + for (const b of list) { + const r = await this.microStep(b.tokens, b.lossMask); + lossSum += r.loss; + numActive += r.numActive || 0; + tokens += r.tokens || b.tokens?.length || 0; + microStepMs += r.microStepMs || 0; + n++; + } + const opt = await this.optimizerStep(); + const totalStepMs = microStepMs + (opt.optimizerStepMs || 0); + return { + loss: lossSum / Math.max(1, n), + microBatches: n, + numActive, + tokens, + microStepMs, + totalStepMs, + trainTokPerSec: tokens / Math.max(1e-6, totalStepMs / 1e3), + ...opt + }; + } +}; + +// src/services/training_controller.js +var IM_END = 151645; +var TrainingController = class { + static { + __name(this, "TrainingController"); + } + // session: a loaded ModelSession (rt + tokenizer). adapters: AdapterRegistry. + constructor({ session: session2, adapters: adapters2, log: log2 = /* @__PURE__ */ __name(() => { + }, "log"), trainerOptions = {} } = {}) { + this.session = session2; + this.adapters = adapters2; + this.log = log2; + this.trainerOptions = trainerOptions; + this.trainer = null; + this.adapter = null; + } + get rt() { + return this.session.rt; + } + get tokenizer() { + return this.session.tokenizer; + } + // Create + register a fresh trainable adapter and attach the trainer to it. + initAdapter(name = "trainable", { rank = 16, alpha = 32, targetModules } = {}) { + const adapter = createTrainableAdapter(this.rt, { name, rank, alpha, targetModules }); + this.adapters.adapters[name] = adapter; + this.adapter = adapter; + this.trainer = new QwenLoraTrainer(this.rt, this.trainerOptions); + this.trainer.attach(adapter); + this.log(`init adapter "${name}" rank=${rank} alpha=${alpha} modules=${Object.keys(adapter.modules).length}`); + return adapter; + } + // Attach to an already-registered adapter (e.g. continue training a loaded one). + attachAdapter(name) { + const adapter = this.adapters.get(name); + if (!adapter) throw new Error(`adapter "${name}" not found`); + this.adapter = adapter; + this.trainer = new QwenLoraTrainer(this.rt, this.trainerOptions); + this.trainer.attach(adapter); + return adapter; + } + /* + * TECHNIQUE: Completion-only loss masking with shifted labels + * Tokenize prompt (with assistant generation prompt) and completion separately. + * mask[t]=1 trains the prediction of tokens[t+1] from position t — so we mask + * positions whose NEXT token is part of the completion (incl. the final EOS). + * Prompt tokens get mask=0, so the model is only graded on what it should write. + */ + prepareExample({ messages, prompt, completion, trainPromptToo = false }) { + const tk = this.tokenizer; + let promptIds; + if (messages) { + promptIds = tk.encode(formatMessages(tk, messages)); + } else { + promptIds = tk.encode(prompt); + } + const compIds = tk.encode(completion, { add_special_tokens: false }); + const tokens = [...promptIds, ...compIds, IM_END]; + const T = tokens.length; + const lossMask = new Array(T).fill(0); + const firstTrainPos = trainPromptToo ? 0 : Math.max(0, promptIds.length - 1); + for (let t = firstTrainPos; t < T - 1; t++) lossMask[t] = 1; + return { + tokens, + lossMask, + promptLength: promptIds.length, + completionLength: compIds.length, + firstTrainPos + }; + } + inspectExample(example) { + const prepared = this.prepareExample(example); + const { tokens, lossMask, promptLength, completionLength, firstTrainPos } = prepared; + const rows = tokens.map((id, index) => { + const targetId = index + 1 < tokens.length ? tokens[index + 1] : null; + const segment = index < promptLength ? "prompt" : index < promptLength + completionLength ? "completion" : "eos"; + return { + index, + id, + text: decodeToken(this.tokenizer, id), + segment, + trainsNext: !!lossMask[index], + targetId, + targetText: targetId == null ? "" : decodeToken(this.tokenizer, targetId) + }; + }); + return { + ...prepared, + trainPositions: lossMask.reduce((n, v) => n + (v ? 1 : 0), 0), + firstTrainPos, + rows + }; + } + prepareBatch(examples) { + return examples.map((e) => this.prepareExample(e)); + } + // One optimizer step over `microBatches` (array of {tokens, lossMask}); grads + // accumulate across them, then a single AdamW update is applied. + async step(microBatches) { + if (!this.trainer) throw new Error("call initAdapter()/attachAdapter() first"); + return this.trainer.trainStep(microBatches); + } + // Full training run over a dataset of examples. Honors gradAccumSteps by grouping + // examples into accumulation windows. Calls onStep({step, loss, lr, gradNorm}). + async train(examples, { epochs = 1, onStep = /* @__PURE__ */ __name(() => { + }, "onStep"), maxTrainSeq } = {}) { + if (!this.trainer) this.initAdapter(); + const accum = this.trainer.opts.gradAccumSteps; + const cap = maxTrainSeq ?? this.trainer.opts.maxTrainSeq; + let globalStep = 0; + for (let ep = 0; ep < epochs; ep++) { + const order = shuffle([...Array(examples.length).keys()]); + let window2 = []; + for (const idx of order) { + let mb = this.prepareExample(examples[idx]); + if (mb.tokens.length > cap) mb = truncate(mb, cap); + window2.push(mb); + if (window2.length === accum) { + const r = await this.step(window2); + globalStep++; + this.log(`step ${globalStep} epoch ${ep} loss=${r.loss.toFixed(4)} lr=${r.lr.toExponential(2)} |g|=${r.gradNorm.toFixed(3)}`); + onStep({ step: globalStep, epoch: ep, ...r }); + window2 = []; + } + } + if (window2.length) { + const r = await this.step(window2); + globalStep++; + onStep({ step: globalStep, epoch: ep, ...r }); + } + } + this.adapters.applyToRuntime(this.adapter.name, this.rt); + return { steps: globalStep, adapter: this.adapter }; + } +}; +function truncate(mb, cap) { + return { + ...mb, + tokens: mb.tokens.slice(0, cap), + lossMask: mb.lossMask.slice(0, cap) + }; +} +__name(truncate, "truncate"); +function decodeToken(tokenizer, id) { + try { + if (tokenizer?.decode) return tokenizer.decode([id], { skip_special_tokens: false }); + } catch { + } + return String(id); +} +__name(decodeToken, "decodeToken"); +function shuffle(a) { + for (let i = a.length - 1; i > 0; i--) { + const j = Math.floor(Math.random() * (i + 1)); + [a[i], a[j]] = [a[j], a[i]]; + } + return a; +} +__name(shuffle, "shuffle"); + +// src/lora_export.js +async function readBufferF32(dev, src, byteLen) { + const rb = dev.createBuffer({ size: byteLen, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }); + const enc = dev.createCommandEncoder(); + enc.copyBufferToBuffer(src, 0, rb, 0, byteLen); + dev.queue.submit([enc.finish()]); + await rb.mapAsync(GPUMapMode.READ); + const out = new Float32Array(rb.getMappedRange().slice(0)); + rb.unmap(); + rb.destroy(); + return out; +} +__name(readBufferF32, "readBufferF32"); +function transpose2d(arr, rows, cols) { + const o = new Float32Array(arr.length); + for (let r = 0; r < rows; r++) for (let c = 0; c < cols; c++) o[c * rows + r] = arr[r * cols + c]; + return o; +} +__name(transpose2d, "transpose2d"); +function buildSafetensors(tensors, metadata = { format: "pt" }) { + let offset = 0; + const header = {}; + if (metadata) header.__metadata__ = metadata; + for (const t of tensors) { + const bytes = t.data.byteLength; + header[t.name] = { dtype: "F32", shape: t.shape, data_offsets: [offset, offset + bytes] }; + offset += bytes; + } + let headerStr = JSON.stringify(header); + const enc = new TextEncoder(); + let headerBytes = enc.encode(headerStr); + const pad = (8 - headerBytes.length % 8) % 8; + if (pad) { + headerStr += " ".repeat(pad); + headerBytes = enc.encode(headerStr); + } + const total = 8 + headerBytes.length + offset; + const buf = new ArrayBuffer(total); + const dv = new DataView(buf); + dv.setBigUint64(0, BigInt(headerBytes.length), true); + new Uint8Array(buf, 8, headerBytes.length).set(headerBytes); + let p = 8 + headerBytes.length; + for (const t of tensors) { + new Uint8Array(buf, p, t.data.byteLength).set(new Uint8Array(t.data.buffer, t.data.byteOffset, t.data.byteLength)); + p += t.data.byteLength; + } + return new Uint8Array(buf); +} +__name(buildSafetensors, "buildSafetensors"); +async function exportLoraAdapter(trainer, opts = {}) { + const rt = trainer.rt; + const dev = rt.dev; + const tensors = []; + const targets = /* @__PURE__ */ new Set(); + const rankByKey = {}; + const alphaByKey = {}; + for (const key of trainer.trainedKeys) { + const st = trainer.state[key]; + const A = await readBufferF32(dev, st.mod.A, st.rank * st.K * 4); + const B = await readBufferF32(dev, st.mod.B, st.rank * st.N * 4); + const Bt = transpose2d(B, st.rank, st.N); + const base = `base_model.model.model.${key}`; + tensors.push({ name: `${base}.lora_A.weight`, shape: [st.rank, st.K], data: A }); + tensors.push({ name: `${base}.lora_B.weight`, shape: [st.N, st.rank], data: Bt }); + rankByKey[key] = st.rank; + alphaByKey[key] = st.scale * st.rank; + targets.add(key.split(".").pop()); + } + const safetensors = buildSafetensors(tensors); + const ranks = Object.values(rankByKey); + const alphas = Object.values(alphaByKey); + const r = opts.rank ?? mode(ranks) ?? 0; + const alpha = opts.alpha ?? mode(alphas) ?? 0; + const rankPattern = {}; + const alphaPattern = {}; + for (const key of Object.keys(rankByKey)) { + if (rankByKey[key] !== r) rankPattern[key] = rankByKey[key]; + if (alphaByKey[key] !== alpha) alphaPattern[key] = alphaByKey[key]; + } + const config = { + peft_type: "LORA", + auto_mapping: null, + base_model_name_or_path: opts.baseModel || "WeiboAI/VibeThinker-3B", + r, + lora_alpha: alpha, + target_modules: [...targets], + lora_dropout: 0, + bias: "none", + fan_in_fan_out: false, + inference_mode: true, + task_type: "CAUSAL_LM", + ...Object.keys(rankPattern).length ? { rank_pattern: rankPattern } : {}, + ...Object.keys(alphaPattern).length ? { alpha_pattern: alphaPattern } : {} + }; + const configJson = JSON.stringify(config, null, 2); + return { safetensors, config, configJson }; +} +__name(exportLoraAdapter, "exportLoraAdapter"); +function mode(arr) { + if (!arr.length) return void 0; + const counts = /* @__PURE__ */ new Map(); + let best = arr[0], bestN = 0; + for (const v of arr) { + const n = (counts.get(v) || 0) + 1; + counts.set(v, n); + if (n > bestN) { + bestN = n; + best = v; + } + } + return best; +} +__name(mode, "mode"); +async function downloadLoraAdapter(trainer, opts = {}) { + const { safetensors, configJson } = await exportLoraAdapter(trainer, opts); + const stem = opts.name || trainer.adapter?.name || "adapter"; + triggerDownload(new Blob([safetensors], { type: "application/octet-stream" }), `${stem}.safetensors`); + triggerDownload(new Blob([configJson], { type: "application/json" }), "adapter_config.json"); +} +__name(downloadLoraAdapter, "downloadLoraAdapter"); +function triggerDownload(blob, filename) { + if (typeof document === "undefined") return; + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = filename; + document.body.appendChild(a); + a.click(); + a.remove(); + setTimeout(() => URL.revokeObjectURL(url), 1e3); +} +__name(triggerDownload, "triggerDownload"); + +// src/lora_gpu.js +function parseSt(buf) { + const dv = new DataView(buf); + const hl = Number(dv.getBigUint64(0, true)); + const header = JSON.parse(new TextDecoder().decode(new Uint8Array(buf, 8, hl))); + return { header, dataStart: 8 + hl, u8: new Uint8Array(buf) }; +} +__name(parseSt, "parseSt"); +function bf16f32(u8, off, n) { + const u16 = new Uint16Array(u8.buffer, u8.byteOffset + off, n); + const o = new Float32Array(n); + const o32 = new Uint32Array(o.buffer); + for (let i = 0; i < n; i++) o32[i] = u16[i] << 16; + return o; +} +__name(bf16f32, "bf16f32"); +function f32(u8, off, n) { + return new Float32Array(u8.buffer.slice(u8.byteOffset + off, u8.byteOffset + off + n * 4)); +} +__name(f32, "f32"); +function readTensor(st, name) { + const t = st.header[name]; + const n = t.shape.reduce((a, b) => a * b, 1); + const dt = t.dtype.toUpperCase(); + const arr = dt === "BF16" ? bf16f32(st.u8, st.dataStart + t.data_offsets[0], n) : f32(st.u8, st.dataStart + t.data_offsets[0], n); + return { arr, shape: t.shape }; +} +__name(readTensor, "readTensor"); +var isA = /* @__PURE__ */ __name((name) => /lora_a/i.test(name), "isA"); +function transpose2d2(arr, rows, cols) { + const o = new Float32Array(arr.length); + for (let r = 0; r < rows; r++) for (let c = 0; c < cols; c++) o[c * rows + r] = arr[r * cols + c]; + return o; +} +__name(transpose2d2, "transpose2d"); +async function loadLoraAdapterGPU(dev, files, cfg) { + const stFile = files.find((f) => f.name.endsWith(".safetensors")); + if (!stFile) throw new Error("no .safetensors in adapter files"); + const cfgFile = files.find((f) => /adapter_config\.json|config\.json/.test(f.name)); + let rankCfg = 16, scaleCfg = null; + if (cfgFile) { + const c = JSON.parse(await cfgFile.text()); + const lp = c.lora_parameters || {}; + rankCfg = c.r ?? c.rank ?? c.lora_rank ?? lp.rank ?? rankCfg; + if (lp.scale != null) + scaleCfg = lp.scale; + else if (c.lora_alpha != null) + scaleCfg = c.lora_alpha / rankCfg; + else if (c.alpha != null) scaleCfg = c.alpha / rankCfg; + } + const st = parseSt(await stFile.arrayBuffer()); + const names = Object.keys(st.header).filter((k) => k !== "__metadata__" && /lora_[abAB]/.test(k)); + const groups = {}; + for (const nm of names) { + const key = moduleKeyFromTensorName(nm); + if (!key) continue; + (groups[key] ||= {})[isA(nm) ? "A" : "B"] = readTensor(st, nm); + } + const S = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST; + const mk = /* @__PURE__ */ __name((arr) => { + const b = dev.createBuffer({ size: arr.byteLength, usage: S }); + dev.queue.writeBuffer(b, 0, arr); + return b; + }, "mk"); + const modules = {}; + for (const key of Object.keys(groups)) { + const g = groups[key]; + if (!g.A || !g.B) continue; + const r = Math.min(...g.A.shape, ...g.B.shape); + let Aarr = g.A.arr; + if (g.A.shape[0] !== r) Aarr = transpose2d2(g.A.arr, g.A.shape[0], g.A.shape[1]); + let Barr = g.B.arr; + if (g.B.shape[0] !== r) Barr = transpose2d2(g.B.arr, g.B.shape[0], g.B.shape[1]); + const scale = scaleCfg != null ? scaleCfg : 2; + modules[key] = { A: mk(Aarr), B: mk(Barr), rawA: Aarr, rawB: Barr, rank: r, scale }; + } + if (!Object.keys(modules).length) throw new Error("no LoRA modules matched layers.*.{self_attn,mlp}.*_proj"); + const name = stFile.name.replace(/\.safetensors$/, ""); + return { name, modules }; +} +__name(loadLoraAdapterGPU, "loadLoraAdapterGPU"); + +// src/services/store.js +var store_exports = {}; +__export(store_exports, { + connectDirectory: () => connectDirectory, + deleteRun: () => deleteRun, + ensurePermission: () => ensurePermission, + forgetDirectory: () => forgetDirectory, + fsSupported: () => fsSupported, + getRun: () => getRun, + getRunBlobs: () => getRunBlobs, + listRuns: () => listRuns, + loadRunFiles: () => loadRunFiles, + newId: () => newId, + readDirText: () => readDirText, + saveRun: () => saveRun, + savedDirectory: () => savedDirectory, + writeFileToDir: () => writeFileToDir +}); +var LS_KEY = "emberglass.history.v2"; +var DB_NAME = "emberglass"; +var DB_VERSION = 1; +var BLOB_STORE = "adapters"; +var HANDLE_STORE = "handles"; +var _dbp = null; +function db() { + if (_dbp) return _dbp; + _dbp = new Promise((resolve, reject) => { + const r = indexedDB.open(DB_NAME, DB_VERSION); + r.onupgradeneeded = () => { + const d = r.result; + if (!d.objectStoreNames.contains(BLOB_STORE)) d.createObjectStore(BLOB_STORE); + if (!d.objectStoreNames.contains(HANDLE_STORE)) d.createObjectStore(HANDLE_STORE); + }; + r.onsuccess = () => resolve(r.result); + r.onerror = () => reject(r.error); + }); + return _dbp; +} +__name(db, "db"); +async function idbPut(store, key, val) { + const d = await db(); + return new Promise((res, rej) => { + const tx = d.transaction(store, "readwrite"); + tx.objectStore(store).put(val, key); + tx.oncomplete = () => res(); + tx.onerror = () => rej(tx.error); + }); +} +__name(idbPut, "idbPut"); +async function idbGet(store, key) { + const d = await db(); + return new Promise((res, rej) => { + const tx = d.transaction(store, "readonly"); + const rq = tx.objectStore(store).get(key); + rq.onsuccess = () => res(rq.result); + rq.onerror = () => rej(rq.error); + }); +} +__name(idbGet, "idbGet"); +async function idbDel(store, key) { + const d = await db(); + return new Promise((res, rej) => { + const tx = d.transaction(store, "readwrite"); + tx.objectStore(store).delete(key); + tx.oncomplete = () => res(); + tx.onerror = () => rej(tx.error); + }); +} +__name(idbDel, "idbDel"); +function listRuns() { + try { + const a = JSON.parse(localStorage.getItem(LS_KEY) || "[]"); + return Array.isArray(a) ? a : []; + } catch { + return []; + } +} +__name(listRuns, "listRuns"); +function writeIndex(arr) { + try { + localStorage.setItem(LS_KEY, JSON.stringify(arr)); + } catch (e) { + console.warn("[store] localStorage write failed", e); + } +} +__name(writeIndex, "writeIndex"); +function getRun(id) { + return listRuns().find((r) => r.id === id) || null; +} +__name(getRun, "getRun"); +function newId() { + return "run_" + Date.now().toString(36) + "_" + Math.random().toString(36).slice(2, 7); +} +__name(newId, "newId"); +async function saveRun(meta, files) { + const stBytes = files.safetensors instanceof Uint8Array ? files.safetensors : new Uint8Array(files.safetensors); + await idbPut(BLOB_STORE, meta.id, { + safetensors: new Blob([stBytes], { type: "application/octet-stream" }), + configJson: files.configJson || "{}" + }); + const idx = listRuns().filter((r) => r.id !== meta.id); + idx.unshift(meta); + writeIndex(idx); + return meta; +} +__name(saveRun, "saveRun"); +async function deleteRun(id) { + writeIndex(listRuns().filter((r) => r.id !== id)); + try { + await idbDel(BLOB_STORE, id); + } catch { + } +} +__name(deleteRun, "deleteRun"); +async function loadRunFiles(id) { + const rec = await idbGet(BLOB_STORE, id); + if (!rec) throw new Error("adapter blob missing for " + id); + const meta = getRun(id); + const stem = (meta?.name || id).replace(/[^\w.-]+/g, "_"); + return [ + new File([rec.safetensors], `${stem}.safetensors`, { type: "application/octet-stream" }), + new File([rec.configJson], "adapter_config.json", { type: "application/json" }) + ]; +} +__name(loadRunFiles, "loadRunFiles"); +async function getRunBlobs(id) { + const rec = await idbGet(BLOB_STORE, id); + if (!rec) throw new Error("adapter blob missing for " + id); + return { safetensors: rec.safetensors, configJson: rec.configJson }; +} +__name(getRunBlobs, "getRunBlobs"); +var fsSupported = typeof window !== "undefined" && "showDirectoryPicker" in window; +async function connectDirectory() { + if (!fsSupported) throw new Error("File System Access API not available in this browser"); + const handle = await window.showDirectoryPicker({ id: "emberglass", mode: "readwrite" }); + await idbPut(HANDLE_STORE, "dir", handle); + return handle; +} +__name(connectDirectory, "connectDirectory"); +async function savedDirectory() { + if (!fsSupported) return null; + try { + return await idbGet(HANDLE_STORE, "dir") || null; + } catch { + return null; + } +} +__name(savedDirectory, "savedDirectory"); +async function forgetDirectory() { + try { + await idbDel(HANDLE_STORE, "dir"); + } catch { + } +} +__name(forgetDirectory, "forgetDirectory"); +async function ensurePermission(handle, mode2 = "readwrite") { + if (!handle) return false; + const opts = { mode: mode2 }; + if (await handle.queryPermission(opts) === "granted") return true; + return await handle.requestPermission(opts) === "granted"; +} +__name(ensurePermission, "ensurePermission"); +async function readDirText(handle, { exts = ["txt", "md", "json", "csv"], maxChars = 2e5 } = {}) { + let out = ""; + const names = []; + for await (const [name, h] of handle.entries()) { + if (h.kind !== "file") continue; + const ext = name.split(".").pop().toLowerCase(); + if (!exts.includes(ext)) continue; + try { + const f = await h.getFile(); + out += ` + +# ${name} +` + await f.text(); + names.push(name); + if (out.length > maxChars) break; + } catch { + } + } + return { text: out.slice(0, maxChars), names }; +} +__name(readDirText, "readDirText"); +async function writeFileToDir(handle, name, data) { + const fh = await handle.getFileHandle(name, { create: true }); + const w = await fh.createWritable(); + await w.write(data); + await w.close(); +} +__name(writeFileToDir, "writeFileToDir"); + +// src/skills.js +function specSig(spec) { + return spec.ops.map((o) => `${o.name}(${(o.params || []).join(", ")})${o.ret ? " -> " + o.ret : ""}`).join("; "); +} +__name(specSig, "specSig"); +function skillSystem(domain, spec) { + return `You are ${domain}. Convert the request into a macro using ONLY these operations: +` + specSig(spec) + `. +Output ONLY the macro, one call per line, no prose. If the request is outside ${spec.scope}, output exactly: OUT_OF_SCOPE.`; +} +__name(skillSystem, "skillSystem"); +function parseMacroCalls(text) { + const out = []; + for (const raw of String(text).split("\n")) { + const line = raw.trim(); + if (!line || line === "OUT_OF_SCOPE") continue; + const m = line.match(/^(?:[A-Za-z_]\w*\s*=\s*)?([A-Za-z_]\w*)\s*\((.*)\)\s*;?\s*$/); + if (!m) continue; + const keys = [...m[2].matchAll(/(?:^|,)\s*([A-Za-z_]\w*)\s*=/g)].map((k) => k[1]); + out.push({ op: m[1], keys }); + } + return out; +} +__name(parseMacroCalls, "parseMacroCalls"); +function verifyMacro(text, spec) { + const t = String(text); + const calls = parseMacroCalls(t); + const bounced = /(^|\n)\s*OUT_OF_SCOPE\s*($|\n)/.test(t) && calls.length === 0; + if (bounced) return { status: "oos", calls: [], issues: [], n: 0 }; + if (!calls.length) return { status: "empty", calls: [], issues: [], n: 0 }; + const byName = new Map(spec.ops.map((o) => [o.name, o])); + const issues = []; + const detail = []; + for (const c of calls) { + const op = byName.get(c.op); + if (!op) { + issues.push(`unknown op: ${c.op}`); + detail.push({ op: c.op, ok: false }); + continue; + } + const allowed = new Set(op.params || []); + const bad = c.keys.filter((k) => !allowed.has(k)); + if (bad.length) { + issues.push(`${c.op}: unexpected arg ${bad.join(", ")}`); + detail.push({ op: c.op, ok: false }); + } else detail.push({ op: c.op, ok: true }); + } + return { status: issues.length ? "bad" : "ok", calls: detail, issues, n: calls.length }; +} +__name(verifyMacro, "verifyMacro"); +function hashStr(s) { + let h = 2166136261; + for (let i = 0; i < s.length; i++) { + h ^= s.charCodeAt(i); + h = Math.imul(h, 16777619); + } + return h >>> 0; +} +__name(hashStr, "hashStr"); +function mulberry32(a) { + return function() { + a |= 0; + a = a + 1831565813 | 0; + let t = Math.imul(a ^ a >>> 15, 1 | a); + t = t + Math.imul(t ^ t >>> 7, 61 | t) ^ t; + return ((t ^ t >>> 14) >>> 0) / 4294967296; + }; +} +__name(mulberry32, "mulberry32"); +function fill(tpl, choice) { + return tpl.replace(/\{(\w+)\}/g, (_, k) => k in choice ? choice[k] : "{" + k + "}"); +} +__name(fill, "fill"); +function expand(def, perTemplate) { + const rnd = mulberry32(hashStr(def.key)); + const out = []; + const seen = /* @__PURE__ */ new Set(); + for (const t of def.templates || []) { + const slots = [...new Set([...t.req.matchAll(/\{(\w+)\}/g)].map((m) => m[1]))]; + let made = 0, tries = 0; + const cap = perTemplate * 8; + while (made < perTemplate && tries < cap) { + tries++; + const choice = {}; + for (const s of slots) { + const arr = def.vocab[s] || ["x"]; + choice[s] = arr[Math.floor(rnd() * arr.length)]; + } + const req = fill(t.req, choice); + if (seen.has(req)) continue; + seen.add(req); + out.push([req, fill(t.macro, choice)]); + made++; + } + } + return out; +} +__name(expand, "expand"); +function buildSkill(def, perTemplate = 6) { + const spec = { scope: def.scope, ops: def.ops }; + const examples = [ + ...def.fixed || [], + ...expand(def, perTemplate), + ...(def.oos || []).map((r) => [r, "OUT_OF_SCOPE"]) + ]; + return { + key: def.key, + label: def.label, + icon: def.icon, + desc: def.desc, + domain: def.domain, + spec, + system: skillSystem(def.domain, spec), + suggest: def.suggest, + examples + }; +} +__name(buildSkill, "buildSkill"); +var PEOPLE = ["mom", "Sarah", "Alex", "the design team", "my manager", "Priya", "John", "the landlord", "accounting", "Dana"]; +var TOPICS = ["the Q3 roadmap", "the launch", "the budget", "onboarding", "the API redesign", "the offsite", "the bug report", "the contract"]; +var WHENS = ["today 17:00", "tomorrow 09:00", "Friday 14:00", "next Monday 10:00", "Thursday 16:30", "tonight 19:00"]; +var DEFS = [ + { + key: "inbox-calendar", + label: "Inbox & Calendar", + icon: "\u2709", + domain: "an Inbox & Calendar operator", + scope: "inbox or calendar", + desc: "Compiles requests like \u201Cemail my mom and book a reminder to respond\u201D into a verifiable macro over a fixed set of inbox/calendar actions; bounces anything else.", + suggest: "Email the design team this week's notes, then put a 30-minute review on my calendar for Monday morning.", + ops: [ + { name: "find_email", params: ["query"], ret: "thread" }, + { name: "compose_email", params: ["to", "subject", "body"] }, + { name: "reply_email", params: ["thread", "body"] }, + { name: "forward_email", params: ["thread", "to", "note"] }, + { name: "archive_email", params: ["thread"] }, + { name: "label_email", params: ["thread", "label"] }, + { name: "schedule_send", params: ["to", "subject", "body", "when"] }, + { name: "create_event", params: ["title", "start", "end", "remind_min"] }, + { name: "set_reminder", params: ["text", "when"] }, + { name: "find_slot", params: ["duration_min", "after", "before"], ret: "slot" }, + { name: "rsvp", params: ["event", "response"] } + ], + fixed: [ + [ + "email my mom and book a calendar event to remind me to respond", + 'compose_email(to="mom", subject="Hi mom", body="Just checking in \u2014 talk soon!")\ncreate_event(title="Respond to mom", start="tomorrow 09:00", end="tomorrow 09:15", remind_min=10)' + ], + [ + "schedule a 30 minute focus block tomorrow afternoon", + 's = find_slot(duration_min=30, after="tomorrow 13:00", before="tomorrow 18:00")\ncreate_event(title="Focus block", start=s.start, end=s.end, remind_min=5)' + ], + [ + "reply yes to the standup invite and add it to my calendar", + 't = find_email(query="standup invite")\nrsvp(event=t, response="yes")' + ] + ], + templates: [ + { req: "email {person} about {topic}", macro: 'compose_email(to="{person}", subject="{topic}", body="Quick note about {topic}.")' }, + { req: "remind me to follow up on {topic} {when}", macro: 'set_reminder(text="Follow up on {topic}", when="{when}")' }, + { req: "find the email from {person} and reply that I will review it by {when}", macro: 't = find_email(query="from:{person}")\nreply_email(thread=t, body="Thanks \u2014 I will review this by {when}.")' }, + { req: "archive the emails about {topic}", macro: 't = find_email(query="{topic}")\narchive_email(thread=t)' }, + { req: "forward the {topic} email to {person}", macro: 't = find_email(query="{topic}")\nforward_email(thread=t, to="{person}", note="FYI \u2014 for your records.")' }, + { req: "label the email from {person} as {label}", macro: 't = find_email(query="from:{person}")\nlabel_email(thread=t, label="{label}")' }, + { req: "send {person} a note {when} saying thanks for {topic}", macro: 'schedule_send(to="{person}", subject="Thank you", body="Thanks for {topic}.", when="{when}")' }, + { req: "set up a meeting about {topic} with {person} {when} for 30 minutes", macro: 'create_event(title="{topic} with {person}", start="{when}", end="{when}", remind_min=10)' }, + { req: "find a 45 minute slot {when} and book {topic}", macro: 's = find_slot(duration_min=45, after="{when}", before="{when}")\ncreate_event(title="{topic}", start=s.start, end=s.end, remind_min=10)' } + ], + vocab: { person: PEOPLE, topic: TOPICS, when: WHENS, label: ["housing", "urgent", "finance", "travel", "follow-up", "receipts"] }, + oos: ["order me a pizza", "what is the capital of France?", "play some jazz"] + }, + { + key: "music", + label: "Music", + icon: "\u266A", + domain: "a music player operator", + scope: "music playback", + desc: "Turns \u201Cplay some lo-fi and turn it down\u201D into a macro over a music action space \u2014 find/play/queue/volume/playlist \u2014 and bounces non-music asks.", + suggest: "Play something upbeat for cooking and add it to a new playlist called Dinner.", + ops: [ + { name: "find_track", params: ["query"], ret: "track" }, + { name: "play_track", params: ["track"] }, + { name: "queue_track", params: ["track"] }, + { name: "pause", params: [] }, + { name: "skip", params: [] }, + { name: "previous", params: [] }, + { name: "set_volume", params: ["level"] }, + { name: "create_playlist", params: ["name"] }, + { name: "add_to_playlist", params: ["playlist", "track"] }, + { name: "shuffle", params: ["on"] }, + { name: "repeat", params: ["mode"] } + ], + fixed: [ + ["skip this song", "skip()"], + ["pause the music", "pause()"], + ["go back to the previous song", "previous()"] + ], + templates: [ + { req: "play some {genre}", macro: 't = find_track(query="{genre}")\nplay_track(track=t)' }, + { req: "queue up {artist} after this", macro: 't = find_track(query="{artist}")\nqueue_track(track=t)' }, + { req: "set the volume to {vol}", macro: "set_volume(level={vol})" }, + { req: "make a playlist called {name}", macro: 'create_playlist(name="{name}")' }, + { req: "add {artist} to my {name} playlist", macro: 't = find_track(query="{artist}")\nadd_to_playlist(playlist="{name}", track=t)' }, + { req: "shuffle my {name} playlist", macro: 'shuffle(on=true)\nt = find_track(query="playlist:{name}")\nplay_track(track=t)' }, + { req: "put on {artist} and turn it up", macro: 't = find_track(query="{artist}")\nplay_track(track=t)\nset_volume(level=80)' }, + { req: "repeat this {mode}", macro: 'repeat(mode="{mode}")' } + ], + vocab: { + genre: ["lo-fi beats", "deep house", "classic jazz", "pop hits", "ambient", "classical", "90s hip hop", "indie rock"], + artist: ["Taylor Swift", "The Beatles", "Daft Punk", "Miles Davis", "Radiohead", "Bad Bunny", "Fleetwood Mac"], + name: ["Focus", "Workout", "Dinner", "Chill", "Road Trip", "Sleep"], + vol: ["10", "25", "40", "60", "75", "90"], + mode: ["one", "all"] + }, + oos: ["email my boss", "what is the weather today?", "open an issue on the repo"] + }, + { + key: "github", + label: "GitHub", + icon: "\u{1F419}", + domain: "a GitHub operator", + scope: "GitHub repositories, issues, and pull requests", + desc: "Compiles dev requests into a macro over issues, pull requests, and repos; bounces anything that isn\u2019t GitHub.", + suggest: 'Open an issue on the api repo titled "fix login redirect", then assign it to Dana.', + ops: [ + { name: "find_issue", params: ["query"], ret: "issue" }, + { name: "create_issue", params: ["repo", "title", "body"] }, + { name: "comment_issue", params: ["issue", "body"] }, + { name: "close_issue", params: ["issue"] }, + { name: "assign_issue", params: ["issue", "assignee"] }, + { name: "label_issue", params: ["issue", "label"] }, + { name: "find_pr", params: ["query"], ret: "pr" }, + { name: "open_pr", params: ["repo", "title", "branch"] }, + { name: "review_pr", params: ["pr", "verdict"] }, + { name: "merge_pr", params: ["pr"] }, + { name: "create_repo", params: ["name", "visibility"] }, + { name: "star_repo", params: ["repo"] } + ], + fixed: [ + [ + "open an issue on the api repo titled fix login redirect and assign it to Dana", + 'i = create_issue(repo="api", title="fix login redirect", body="The login flow redirects to the wrong page.")\nassign_issue(issue=i, assignee="Dana")' + ] + ], + templates: [ + { req: "open an issue on {repo} titled {title}", macro: 'create_issue(repo="{repo}", title="{title}", body="{title}.")' }, + { req: "close the {topic} issue", macro: 'i = find_issue(query="{topic}")\nclose_issue(issue=i)' }, + { req: "comment {comment} on the {topic} issue", macro: 'i = find_issue(query="{topic}")\ncomment_issue(issue=i, body="{comment}")' }, + { req: "assign the {topic} issue to {user}", macro: 'i = find_issue(query="{topic}")\nassign_issue(issue=i, assignee="{user}")' }, + { req: "label the {topic} issue as {label}", macro: 'i = find_issue(query="{topic}")\nlabel_issue(issue=i, label="{label}")' }, + { req: "open a pull request on {repo} from {branch} titled {title}", macro: 'open_pr(repo="{repo}", title="{title}", branch="{branch}")' }, + { req: "approve the {topic} pull request", macro: 'p = find_pr(query="{topic}")\nreview_pr(pr=p, verdict="approve")' }, + { req: "merge the {topic} PR", macro: 'p = find_pr(query="{topic}")\nmerge_pr(pr=p)' }, + { req: "create a private repo called {repo}", macro: 'create_repo(name="{repo}", visibility="private")' }, + { req: "star the {repo} repo", macro: 'star_repo(repo="{repo}")' } + ], + vocab: { + repo: ["api", "frontend", "docs", "infra", "mobile-app", "design-system"], + title: ["fix login redirect", "add dark mode", "update README", "flaky test fix", "bump dependencies", "improve error logs"], + topic: ["login", "dark mode", "flaky test", "memory leak", "rate limiting", "docs typo"], + comment: ["looks good to me", "can you add a test?", "I will pick this up", "reproduced on main", "duplicate of #42"], + user: ["Dana", "Alex", "Priya", "the on-call", "Sam"], + label: ["bug", "enhancement", "good first issue", "p1", "docs", "wontfix"], + branch: ["feature/auth", "fix/cache", "chore/deps", "feat/ui", "hotfix/crash"] + }, + oos: ["play some music", "email my mom", "what is 2 + 2?"] + }, + { + key: "slack", + label: "Slack", + icon: "\u{1F4AC}", + domain: "a Slack operator", + scope: "Slack messaging", + desc: "Compiles team-chat requests into a macro over channels, DMs, threads, and reminders; bounces non-Slack asks.", + suggest: "Post the release notes in #launch and DM Dana to review them.", + ops: [ + { name: "find_message", params: ["query"], ret: "message" }, + { name: "send_message", params: ["channel", "text"] }, + { name: "dm", params: ["user", "text"] }, + { name: "reply_thread", params: ["message", "text"] }, + { name: "react", params: ["message", "emoji"] }, + { name: "set_status", params: ["text", "emoji"] }, + { name: "create_channel", params: ["name"] }, + { name: "invite", params: ["user", "channel"] }, + { name: "remind", params: ["text", "when"] }, + { name: "pin", params: ["message"] } + ], + fixed: [ + [ + "post the release notes in #launch and dm Dana to review them", + 'send_message(channel="launch", text="Release notes are up \u2014 please review.")\ndm(user="Dana", text="Can you review the release notes I posted in #launch?")' + ] + ], + templates: [ + { req: "post {text} in #{channel}", macro: 'send_message(channel="{channel}", text="{text}")' }, + { req: "dm {user} {text}", macro: 'dm(user="{user}", text="{text}")' }, + { req: "reply {text} to the {topic} thread", macro: 'm = find_message(query="{topic}")\nreply_thread(message=m, text="{text}")' }, + { req: "react {emoji} to the {topic} message", macro: 'm = find_message(query="{topic}")\nreact(message=m, emoji="{emoji}")' }, + { req: "set my status to {text}", macro: 'set_status(text="{text}", emoji="{emoji}")' }, + { req: "create a channel called {channel}", macro: 'create_channel(name="{channel}")' }, + { req: "invite {user} to #{channel}", macro: 'invite(user="{user}", channel="{channel}")' }, + { req: "remind the team to {task} {when}", macro: 'remind(text="{task}", when="{when}")' }, + { req: "pin the {topic} message", macro: 'm = find_message(query="{topic}")\npin(message=m)' } + ], + vocab: { + channel: ["launch", "general", "engineering", "design", "random", "incidents"], + user: ["Dana", "Alex", "Priya", "Sam", "the team lead"], + text: ["standup in 5", "PR is ready for review", "deploy is green", "lunch at noon?", "great work today"], + topic: ["deploy", "incident", "roadmap", "lunch", "release"], + emoji: [":eyes:", ":white_check_mark:", ":tada:", ":fire:", ":+1:"], + task: ["submit timesheets", "join the retro", "review the doc", "update the board"], + when: WHENS + }, + oos: ["play a song", "order groceries", "what time is it in Tokyo?"] + }, + { + key: "notion", + label: "Notion", + icon: "\u{1F4DD}", + domain: "a Notion operator", + scope: "Notion pages, notes, and tasks", + desc: "Compiles note-taking requests into a macro over pages, blocks, tasks, and databases; bounces anything else.", + suggest: 'Create a page titled "Trip plan" and add a task to book flights due Friday.', + ops: [ + { name: "find_page", params: ["query"], ret: "page" }, + { name: "create_page", params: ["title", "body"] }, + { name: "append_block", params: ["page", "text"] }, + { name: "create_task", params: ["title", "due"] }, + { name: "complete_task", params: ["task"] }, + { name: "find_task", params: ["query"], ret: "task" }, + { name: "add_to_database", params: ["database", "name"] }, + { name: "set_property", params: ["page", "key", "value"] }, + { name: "create_database", params: ["name"] } + ], + fixed: [ + [ + "create a page titled Trip plan and add a task to book flights due Friday", + 'create_page(title="Trip plan", body="Planning notes.")\ncreate_task(title="Book flights", due="Friday")' + ] + ], + templates: [ + { req: "create a page titled {title}", macro: 'create_page(title="{title}", body="{title} \u2014 notes.")' }, + { req: "add a note {text} to the {topic} page", macro: 'p = find_page(query="{topic}")\nappend_block(page=p, text="{text}")' }, + { req: "add a task to {task} due {when}", macro: 'create_task(title="{task}", due="{when}")' }, + { req: "mark the {task} task done", macro: 't = find_task(query="{task}")\ncomplete_task(task=t)' }, + { req: "add {name} to my {database} database", macro: 'add_to_database(database="{database}", name="{name}")' }, + { req: "set the status of the {topic} page to {value}", macro: 'p = find_page(query="{topic}")\nset_property(page=p, key="status", value="{value}")' }, + { req: "create a database called {database}", macro: 'create_database(name="{database}")' } + ], + vocab: { + title: ["Trip plan", "Q3 goals", "Reading list", "Meeting notes", "Project brief", "Recipes"], + text: ["remember to confirm the budget", "add the agenda", "link the spec", "note the blockers"], + topic: ["trip", "goals", "project", "meeting", "reading"], + task: ["book flights", "draft the brief", "email the vendor", "review the PR", "pay the invoice"], + when: ["today", "tomorrow", "Friday", "next week", "end of month"], + name: ["Acme Co", "Q3 launch", "Vendor X", "Idea: dark mode"], + database: ["Projects", "CRM", "Tasks", "Reading", "Inventory"], + value: ["in progress", "done", "blocked", "todo", "review"] + }, + oos: ["play music", "navigate home", "send a tweet"] + }, + { + key: "x", + label: "X", + icon: "\u{1D54F}", + domain: "an X (Twitter) operator", + scope: "posting and engagement on X", + desc: "Compiles social requests into a macro over posts, replies, reposts, follows, and DMs; bounces anything off-platform.", + suggest: 'Post "shipping something fun today \u{1F680}" and schedule a follow-up for 5pm.', + ops: [ + { name: "find_post", params: ["query"], ret: "post" }, + { name: "post", params: ["text"] }, + { name: "reply", params: ["post", "text"] }, + { name: "repost", params: ["post"] }, + { name: "like", params: ["post"] }, + { name: "follow", params: ["user"] }, + { name: "dm", params: ["user", "text"] }, + { name: "schedule_post", params: ["text", "when"] }, + { name: "bookmark", params: ["post"] } + ], + fixed: [ + [ + "post shipping something fun today and schedule a follow up for 5pm", + 'post(text="shipping something fun today \u{1F680}")\nschedule_post(text="more details soon \u2014 stay tuned", when="today 17:00")' + ] + ], + templates: [ + { req: "post {text}", macro: 'post(text="{text}")' }, + { req: "reply {text} to the {topic} post", macro: 'p = find_post(query="{topic}")\nreply(post=p, text="{text}")' }, + { req: "repost the {topic} tweet", macro: 'p = find_post(query="{topic}")\nrepost(post=p)' }, + { req: "like the {topic} post", macro: 'p = find_post(query="{topic}")\nlike(post=p)' }, + { req: "follow {user}", macro: 'follow(user="{user}")' }, + { req: "dm {user} {text}", macro: 'dm(user="{user}", text="{text}")' }, + { req: "schedule a post {when} saying {text}", macro: 'schedule_post(text="{text}", when="{when}")' }, + { req: "bookmark the {topic} thread", macro: 'p = find_post(query="{topic}")\nbookmark(post=p)' } + ], + vocab: { + text: ["gm", "big news coming", "loved this talk", "hot take: tabs > spaces", "thanks for 10k followers"], + topic: ["the launch", "the keynote", "the meme", "the thread on AI", "the announcement"], + user: ["@levelsio", "@naval", "@swyx", "@dhh", "@karpathy"], + when: WHENS + }, + oos: ["archive my inbox", "play a playlist", "open a GitHub issue"] + }, + { + key: "instagram", + label: "Instagram", + icon: "\u{1F4F7}", + domain: "an Instagram operator", + scope: "Instagram posts, stories, and DMs", + desc: "Compiles requests into a macro over photo posts, stories, comments, and DMs; bounces anything off-platform.", + suggest: 'Post a photo with caption "sunset run \u{1F305}" and share it to my story.', + ops: [ + { name: "find_post", params: ["query"], ret: "post" }, + { name: "post_photo", params: ["caption", "media"] }, + { name: "post_story", params: ["media"] }, + { name: "reply_dm", params: ["user", "text"] }, + { name: "like_post", params: ["post"] }, + { name: "comment", params: ["post", "text"] }, + { name: "follow", params: ["user"] }, + { name: "save_post", params: ["post"] } + ], + fixed: [ + [ + "post a photo with caption sunset run and share it to my story", + 'post_photo(caption="sunset run \u{1F305}", media="latest")\npost_story(media="latest")' + ] + ], + templates: [ + { req: "post a photo with caption {caption}", macro: 'post_photo(caption="{caption}", media="latest")' }, + { req: "share {media} to my story", macro: 'post_story(media="{media}")' }, + { req: "comment {text} on the {topic} post", macro: 'p = find_post(query="{topic}")\ncomment(post=p, text="{text}")' }, + { req: "like the {topic} post", macro: 'p = find_post(query="{topic}")\nlike_post(post=p)' }, + { req: "reply {text} to {user} in DMs", macro: 'reply_dm(user="{user}", text="{text}")' }, + { req: "follow {user}", macro: 'follow(user="{user}")' }, + { req: "save the {topic} post", macro: 'p = find_post(query="{topic}")\nsave_post(post=p)' } + ], + vocab: { + caption: ["sunset run \u{1F305}", "weekend vibes", "new kicks \u{1F45F}", "homemade pasta \u{1F35D}", "trail day"], + media: ["latest", "the beach photo", "the reel", "the carousel"], + text: ["love this!", "where is this?", "so good \u{1F525}", "congrats!", "need the recipe"], + topic: ["the travel", "the food", "the fit check", "the puppy", "the launch"], + user: ["@natgeo", "@nike", "@a_friend", "@the_chef"] + }, + oos: ["merge the pull request", "set a reminder", "navigate to work"] + }, + { + key: "youtube", + label: "YouTube", + icon: "\u25B6", + domain: "a YouTube operator", + scope: "YouTube playback and library", + desc: "Compiles requests into a macro over search, playback, playlists, and subscriptions; bounces anything else.", + suggest: "Play a 10-minute beginner yoga video and add it to my Morning playlist.", + ops: [ + { name: "find_video", params: ["query"], ret: "video" }, + { name: "play_video", params: ["video"] }, + { name: "queue_video", params: ["video"] }, + { name: "subscribe", params: ["channel"] }, + { name: "like_video", params: ["video"] }, + { name: "add_to_playlist", params: ["playlist", "video"] }, + { name: "create_playlist", params: ["name"] }, + { name: "comment", params: ["video", "text"] } + ], + fixed: [ + [ + "play a beginner yoga video and add it to my Morning playlist", + 'v = find_video(query="beginner yoga 10 minutes")\nplay_video(video=v)\nadd_to_playlist(playlist="Morning", video=v)' + ] + ], + templates: [ + { req: "play a video about {query}", macro: 'v = find_video(query="{query}")\nplay_video(video=v)' }, + { req: "queue a video about {query}", macro: 'v = find_video(query="{query}")\nqueue_video(video=v)' }, + { req: "subscribe to {channel}", macro: 'subscribe(channel="{channel}")' }, + { req: "like the {query} video", macro: 'v = find_video(query="{query}")\nlike_video(video=v)' }, + { req: "add a {query} video to my {name} playlist", macro: 'v = find_video(query="{query}")\nadd_to_playlist(playlist="{name}", video=v)' }, + { req: "make a playlist called {name}", macro: 'create_playlist(name="{name}")' }, + { req: "comment {text} on the {query} video", macro: 'v = find_video(query="{query}")\ncomment(video=v, text="{text}")' } + ], + vocab: { + query: ["lo-fi study mix", "rust tutorial", "marathon training", "pasta recipe", "guitar lesson", "space documentary"], + channel: ["Veritasium", "Fireship", "MKBHD", "Kurzgesagt", "NileRed"], + name: ["Morning", "Watch Later", "Cooking", "Workouts", "Learning"], + text: ["great explanation!", "first", "this helped a lot", "please do a part 2"] + }, + oos: ["email the team", "open a PR", "set my Slack status"] + }, + { + key: "maps", + label: "Maps", + icon: "\u{1F4CD}", + domain: "a Maps operator", + scope: "navigation and places", + desc: "Compiles requests into a macro over places, directions, and navigation; bounces anything off-map.", + suggest: "Find the nearest coffee shop and start navigation, then share my ETA with Alex.", + ops: [ + { name: "search_place", params: ["query"], ret: "place" }, + { name: "find_nearby", params: ["category"], ret: "place" }, + { name: "directions", params: ["to", "mode"] }, + { name: "start_navigation", params: ["place"] }, + { name: "save_place", params: ["place", "list"] }, + { name: "share_eta", params: ["place", "contact"] } + ], + fixed: [ + [ + "find the nearest coffee shop and start navigation then share my eta with Alex", + 'p = find_nearby(category="coffee shop")\nstart_navigation(place=p)\nshare_eta(place=p, contact="Alex")' + ] + ], + templates: [ + { req: "navigate to {place}", macro: 'p = search_place(query="{place}")\nstart_navigation(place=p)' }, + { req: "directions to {place} by {mode}", macro: 'directions(to="{place}", mode="{mode}")' }, + { req: "find a {category} near me", macro: 'find_nearby(category="{category}")' }, + { req: "find the nearest {category} and navigate there", macro: 'p = find_nearby(category="{category}")\nstart_navigation(place=p)' }, + { req: "save {place} to my {list} list", macro: 'p = search_place(query="{place}")\nsave_place(place=p, list="{list}")' }, + { req: "share my ETA to {place} with {contact}", macro: 'p = search_place(query="{place}")\nshare_eta(place=p, contact="{contact}")' } + ], + vocab: { + place: ["the airport", "downtown", "the office", "Central Park", "the train station", "the stadium"], + mode: ["driving", "walking", "transit", "cycling"], + category: ["coffee shop", "gas station", "pharmacy", "grocery store", "ATM", "parking"], + list: ["Favorites", "Want to go", "Trip", "Restaurants"], + contact: ["Alex", "mom", "Dana", "the group"] + }, + oos: ["post a tweet", "play a song", "create a GitHub repo"] + }, + { + key: "amazon", + label: "Shopping", + icon: "\u{1F6D2}", + domain: "a shopping operator", + scope: "shopping cart and orders", + desc: "Compiles requests into a macro over product search, cart, orders, and lists; bounces anything that isn\u2019t shopping.", + suggest: "Add two packs of AA batteries to my cart and track my last order.", + ops: [ + { name: "search_product", params: ["query"], ret: "product" }, + { name: "add_to_cart", params: ["product", "qty"] }, + { name: "buy_now", params: ["product"] }, + { name: "find_order", params: ["query"], ret: "order" }, + { name: "track_order", params: ["order"], ret: "status" }, + { name: "reorder", params: ["query"] }, + { name: "add_to_list", params: ["product", "list"] } + ], + fixed: [ + [ + "add two packs of AA batteries to my cart and track my last order", + 'p = search_product(query="AA batteries 2 pack")\nadd_to_cart(product=p, qty=2)\no = find_order(query="last order")\ntrack_order(order=o)' + ] + ], + templates: [ + { req: "add {qty} {product} to my cart", macro: 'p = search_product(query="{product}")\nadd_to_cart(product=p, qty={qty})' }, + { req: "buy {product} now", macro: 'p = search_product(query="{product}")\nbuy_now(product=p)' }, + { req: "reorder {product}", macro: 'reorder(query="{product}")' }, + { req: "track my {product} order", macro: 'o = find_order(query="{product}")\ntrack_order(order=o)' }, + { req: "add {product} to my {list} list", macro: 'p = search_product(query="{product}")\nadd_to_list(product=p, list="{list}")' }, + { req: "search for {product}", macro: 'search_product(query="{product}")' } + ], + vocab: { + product: ["AA batteries", "USB-C cable", "olive oil", "running shoes", "paper towels", "a coffee grinder", "phone case"], + qty: ["1", "2", "3", "4"], + list: ["Wishlist", "Subscribe & Save", "Home", "Gifts"] + }, + oos: ["send an email", "play a video", "navigate to the office"] + }, + { + key: "reddit", + label: "Reddit", + icon: "\u{1F47D}", + domain: "a Reddit operator", + scope: "Reddit posts and comments", + desc: "Compiles requests into a macro over submissions, comments, votes, and subscriptions; bounces anything off-platform.", + suggest: 'Post "What mechanical keyboard should I buy?" to r/keyboards and subscribe.', + ops: [ + { name: "find_post", params: ["query"], ret: "post" }, + { name: "submit_post", params: ["subreddit", "title", "body"] }, + { name: "comment", params: ["post", "text"] }, + { name: "upvote", params: ["post"] }, + { name: "reply_comment", params: ["comment", "text"] }, + { name: "subscribe", params: ["subreddit"] }, + { name: "save_post", params: ["post"] } + ], + fixed: [ + [ + "post what mechanical keyboard should I buy to r/keyboards and subscribe", + 'submit_post(subreddit="keyboards", title="What mechanical keyboard should I buy?", body="Budget is flexible \u2014 looking for recommendations.")\nsubscribe(subreddit="keyboards")' + ] + ], + templates: [ + { req: "post {title} to r/{subreddit}", macro: 'submit_post(subreddit="{subreddit}", title="{title}", body="{title}")' }, + { req: "comment {text} on the {topic} post", macro: 'p = find_post(query="{topic}")\ncomment(post=p, text="{text}")' }, + { req: "upvote the {topic} post", macro: 'p = find_post(query="{topic}")\nupvote(post=p)' }, + { req: "subscribe to r/{subreddit}", macro: 'subscribe(subreddit="{subreddit}")' }, + { req: "save the {topic} post", macro: 'p = find_post(query="{topic}")\nsave_post(post=p)' } + ], + vocab: { + subreddit: ["keyboards", "programming", "AskReddit", "buildapc", "cooking", "fitness"], + title: ["What keyboard should I buy?", "Best beginner setup?", "How do I start running?", "Favorite pasta recipe?"], + text: ["this is the way", "underrated take", "source?", "thanks for sharing", "happy cake day"], + topic: ["the keyboard", "the build", "the recipe", "the AMA", "the discussion"] + }, + oos: ["email my mom", "play a song", "navigate home"] + }, + { + key: "linkedin", + label: "LinkedIn", + icon: "\u{1F4BC}", + domain: "a LinkedIn operator", + scope: "LinkedIn networking and posts", + desc: "Compiles requests into a macro over posts, connections, messages, and endorsements; bounces anything off-platform.", + suggest: "Connect with Priya with a note, then endorse her for product management.", + ops: [ + { name: "find_person", params: ["query"], ret: "person" }, + { name: "post_update", params: ["text"] }, + { name: "connect", params: ["user", "note"] }, + { name: "message", params: ["user", "text"] }, + { name: "endorse", params: ["person", "skill"] }, + { name: "find_post", params: ["query"], ret: "post" }, + { name: "comment", params: ["post", "text"] } + ], + fixed: [ + [ + "connect with Priya with a note then endorse her for product management", + 'connect(user="Priya", note="Great working with you \u2014 let us stay in touch!")\np = find_person(query="Priya")\nendorse(person=p, skill="product management")' + ] + ], + templates: [ + { req: "post an update saying {text}", macro: 'post_update(text="{text}")' }, + { req: "connect with {user} and add a note {note}", macro: 'connect(user="{user}", note="{note}")' }, + { req: "message {user} {text}", macro: 'message(user="{user}", text="{text}")' }, + { req: "endorse {user} for {skill}", macro: 'p = find_person(query="{user}")\nendorse(person=p, skill="{skill}")' }, + { req: "comment {text} on the {topic} post", macro: 'p = find_post(query="{topic}")\ncomment(post=p, text="{text}")' } + ], + vocab: { + text: ["excited to share I started a new role", "we are hiring engineers", "grateful for a great quarter", "thoughts on remote work"], + user: ["Priya", "Alex", "a recruiter", "Dana", "my former manager"], + note: ["Great working with you!", "Loved your talk", "Let us connect", "Fellow alum here"], + skill: ["product management", "leadership", "TypeScript", "design", "data science"], + topic: ["the hiring", "the milestone", "the article", "the announcement"] + }, + oos: ["play music", "open a github issue", "navigate to the airport"] + } +]; +var SKILLS = DEFS.map((d) => buildSkill(d, 6)); +var POPULAR_2026 = [ + { key: "inbox-calendar", name: "Inbox & Calendar", skill: "inbox-calendar", cat: "productivity", bg: "#2f72c4", glyph: "\u2709", fs: 22 }, + { key: "music", name: "Music", skill: "music", cat: "media", bg: "#1db954", glyph: "\u266A", fs: 24 }, + { key: "github", name: "GitHub", skill: "github", cat: "developer", bg: "#181717", glyph: "GH", fs: 15 }, + { key: "youtube", name: "YouTube", skill: "youtube", cat: "media", bg: "#FF0000", glyph: "\u25B6", fs: 18 }, + { key: "instagram", name: "Instagram", skill: "instagram", cat: "social", bg: "linear-gradient(135deg,#feda75,#d62976 48%,#4f5bd5)", glyph: "\u{1F4F7}", fs: 20 }, + { key: "x", name: "X", skill: "x", cat: "social", bg: "#000000", glyph: "\u{1D54F}", fs: 23 }, + { key: "slack", name: "Slack", skill: "slack", cat: "work", bg: "#4A154B", glyph: "S", fs: 24 }, + { key: "notion", name: "Notion", skill: "notion", cat: "productivity", bg: "#0f0f0f", glyph: "N", fs: 24 }, + { key: "maps", name: "Maps", skill: "maps", cat: "navigation", bg: "#34A853", glyph: "\u{1F4CD}", fs: 20 }, + { key: "amazon", name: "Amazon", skill: "amazon", cat: "shopping", bg: "#FF9900", fg: "#232F3E", glyph: "a", fs: 27 }, + { key: "reddit", name: "Reddit", skill: "reddit", cat: "social", bg: "#FF4500", glyph: "\u{1F47D}", fs: 20 }, + { key: "linkedin", name: "LinkedIn", skill: "linkedin", cat: "work", bg: "#0A66C2", glyph: "in", fs: 17 }, + // ── the broader armory (coming soon) ── + { key: "google", name: "Google", cat: "productivity", bg: "#4285F4", glyph: "G", fs: 25 }, + { key: "whatsapp", name: "WhatsApp", cat: "social", bg: "#25D366", glyph: "\u2706", fs: 22 }, + { key: "tiktok", name: "TikTok", cat: "social", bg: "#010101", glyph: "\u266B", fs: 22 }, + { key: "facebook", name: "Facebook", cat: "social", bg: "#1877F2", glyph: "f", fs: 27 }, + { key: "snapchat", name: "Snapchat", cat: "social", bg: "#FFFC00", fg: "#111", glyph: "\u{1F47B}", fs: 22 }, + { key: "messenger", name: "Messenger", cat: "social", bg: "#0084FF", glyph: "\u2726", fs: 22 }, + { key: "discord", name: "Discord", cat: "social", bg: "#5865F2", glyph: "D", fs: 24 }, + { key: "telegram", name: "Telegram", cat: "social", bg: "#229ED9", glyph: "\u2708", fs: 20 }, + { key: "netflix", name: "Netflix", cat: "media", bg: "#E50914", glyph: "NF", fs: 15 }, + { key: "twitch", name: "Twitch", cat: "media", bg: "#9146FF", glyph: "tw", fs: 16 }, + { key: "spotify", name: "Spotify", cat: "media", bg: "#1DB954", glyph: "\u25C9", fs: 20 }, + { key: "pinterest", name: "Pinterest", cat: "social", bg: "#E60023", glyph: "P", fs: 24 }, + { key: "threads", name: "Threads", cat: "social", bg: "#000000", glyph: "@", fs: 24 }, + { key: "uber", name: "Uber", cat: "travel", bg: "#000000", glyph: "U", fs: 24 }, + { key: "doordash", name: "DoorDash", cat: "food", bg: "#FF3008", glyph: "DD", fs: 14 }, + { key: "airbnb", name: "Airbnb", cat: "travel", bg: "#FF5A5F", glyph: "A", fs: 24 }, + { key: "paypal", name: "PayPal", cat: "finance", bg: "#003087", glyph: "P", fs: 23 }, + { key: "venmo", name: "Venmo", cat: "finance", bg: "#3D95CE", glyph: "V", fs: 24 }, + { key: "chatgpt", name: "ChatGPT", cat: "ai", bg: "#10A37F", glyph: "\u2738", fs: 20 }, + { key: "gemini", name: "Gemini", cat: "ai", bg: "#1C69FF", glyph: "\u2726", fs: 20 }, + { key: "perplexity", name: "Perplexity", cat: "ai", bg: "#1FB8CD", glyph: "\u273A", fs: 20 }, + { key: "cursor", name: "Cursor", cat: "developer", bg: "#0b0b0b", glyph: "\u25AE", fs: 18 } +]; + +// src/main.js +var $ = /* @__PURE__ */ __name((id) => document.getElementById(id), "$"); +var log = /* @__PURE__ */ __name((m) => { + const s = $("railMsg"); + if (s) s.textContent = m; + console.log("[emberglass]", m); +}, "log"); +function steps(id) { + const el = $(id), m = {}; + el.querySelectorAll(".step").forEach((s) => m[s.dataset.s] = s); + const all = /* @__PURE__ */ __name(() => Object.values(m), "all"); + return { + reset() { + all().forEach((s) => s.classList.remove("active", "done", "loop")); + }, + active(k) { + m[k]?.classList.add("active"); + }, + activeOnly(k) { + all().forEach((s) => s.classList.remove("active")); + m[k]?.classList.add("active"); + }, + done(k) { + m[k]?.classList.remove("active", "loop"); + m[k]?.classList.add("done"); + }, + loop(keys, on) { + keys.forEach((k) => m[k]?.classList.toggle("loop", on)); + } + }; +} +__name(steps, "steps"); +function startClock(id) { + const el = $(id), t = el.querySelector(".t"), t0 = performance.now(); + let run = true; + el.classList.add("on"); + (/* @__PURE__ */ __name((function f() { + if (!run) return; + t.textContent = ((performance.now() - t0) / 1e3).toFixed(1) + "s"; + requestAnimationFrame(f); + }), "f"))(); + return () => { + run = false; + el.classList.remove("on"); + }; +} +__name(startClock, "startClock"); +var session = new ModelSession({ cfg: QWEN25_3B, log }); +var adapters = new AdapterRegistry(); +var state = { + loaded: false, + busy: false, + err: null, + tuned: null, + // { name, kind:'guided'|'own', build(userText)->messages[], suggest } + activeRunId: null, + // history run currently applied + dirHandle: null + // File System Access workspace folder +}; +var GEN = { maxTokens: 2048, temperature: 0.6, topP: 0.95, topK: 64 }; +var skillByKey = /* @__PURE__ */ __name((key) => SKILLS.find((s) => key && (key === s.key || String(key).startsWith(s.key + " "))), "skillByKey"); +var selectedSkillKey = SKILLS[0].key; +var trainLosses = []; +function sampleExamples(all, n) { + const oos = all.filter(([, a]) => a === "OUT_OF_SCOPE"); + const inscope = all.filter(([, a]) => a !== "OUT_OF_SCOPE"); + const keep = Math.max(0, n - oos.length); + const stride = Math.max(1, Math.floor(inscope.length / Math.max(1, keep))); + const picked = []; + for (let i = 0; i < inscope.length && picked.length < keep; i += stride) picked.push(inscope[i]); + return [...picked, ...oos]; +} +__name(sampleExamples, "sampleExamples"); +function setBadge() { + const rail = $("rail"), chip = $("railChip"); + if (!rail || !chip) return; + if (state.err) { + rail.dataset.state = "err"; + chip.textContent = "Load failed"; + return; + } + if (state.busy === "load") { + rail.dataset.state = "busy"; + chip.textContent = "Loading\u2026"; + return; + } + if (!state.loaded) { + rail.dataset.state = "idle"; + chip.textContent = "Model not loaded"; + return; + } + const sel = $("adapterSel")?.value || "none"; + if (sel === "none") { + rail.dataset.state = "ok"; + chip.textContent = "Live \xB7 base"; + } else { + rail.dataset.state = "tuned"; + chip.textContent = "Live \xB7 tuned: " + sel; + } +} +__name(setBadge, "setBadge"); +function lockInference(on) { + $("inferLock").style.display = on ? "flex" : "none"; + $("run").disabled = on || !state.loaded || state.busy === "gen"; +} +__name(lockInference, "lockInference"); +function gateButtons() { + const ready = state.loaded && !state.busy; + $("run").disabled = !ready; + $("trainGuided").disabled = !ready; + $("trainOwn").disabled = !ready || !ownExamples().length; + for (const id of ["load", "loadHF"]) $(id).disabled = !!state.busy; + const ask = $("askSection"); + if (ask) ask.hidden = !state.loaded; +} +__name(gateButtons, "gateButtons"); +async function loadWith(reader, label) { + if (state.busy) return; + state.busy = "load"; + state.err = null; + setBadge(); + gateButtons(); + try { + await session.loadWith(reader, label); + state.loaded = true; + log("Model ready. Ask it anything below \u2014 or hit Train to teach it something new."); + } catch (e) { + state.err = e.message; + log("Load error: " + e.message); + console.error(e); + } finally { + state.busy = false; + setBadge(); + gateButtons(); + } +} +__name(loadWith, "loadWith"); +function buildMessages(userText) { + const sel = $("adapterSel")?.value || "none"; + if (sel !== "none" && state.tuned && state.tuned.name === sel) return state.tuned.build(userText); + return [{ role: "user", content: userText }]; +} +__name(buildMessages, "buildMessages"); +async function runInference() { + if (!state.loaded || state.busy) return; + const userText = $("prompt").value.trim(); + if (!userText) { + log("type something to ask first"); + return; + } + state.busy = "gen"; + gateButtons(); + const sel = $("adapterSel")?.value || "none"; + adapters.applyToRuntime(sel, session.rt); + const out = $("out"); + out.textContent = ""; + const node = document.createTextNode(""); + out.appendChild(node); + const st = steps("inferSteps"); + st.reset(); + const cap = $("inferCap"); + const stop = startClock("inferClock"); + $("inferProc").classList.add("on"); + setMacroCheck(null); + st.active("tok"); + cap.textContent = "Tokenizing your prompt with the VibeThinker tokenizer\u2026"; + const t0 = performance.now(); + let n = 0, first = true, acc = ""; + try { + const msgs = buildMessages(userText); + st.done("tok"); + st.active("prefill"); + cap.textContent = "Reading the prompt into the KV cache (prefill)\u2026"; + for await (const d of session.generate(msgs, { maxTokens: GEN.maxTokens, temperature: GEN.temperature, topP: GEN.topP, topK: GEN.topK })) { + if (first) { + first = false; + st.done("prefill"); + st.active("decode"); + cap.textContent = "Generating the answer one token at a time\u2026"; + } + node.appendData(d); + acc += d; + n++; + $("tokps").textContent = `${n} tok \xB7 ${(n / ((performance.now() - t0) / 1e3)).toFixed(1)} tok/s`; + out.scrollTop = out.scrollHeight; + } + const dt = (performance.now() - t0) / 1e3; + $("tokps").textContent = `${n} tok \xB7 ${(n / dt).toFixed(1)} tok/s \xB7 ${dt.toFixed(1)}s`; + st.done("prefill"); + st.done("decode"); + st.done("done"); + cap.textContent = `Done \u2014 ${sel === "none" ? "base model" : 'tuned adapter "' + sel + '"'}.`; + const skill = sel !== "none" && state.tuned && state.tuned.name === sel ? skillByKey(state.tuned.base) : null; + if (skill) { + const res = verifyMacro(acc, skill.spec); + setMacroCheck(res, skill, acc); + if (res.status === "ok") stageMsg(`Action resolved \u2014 compiled a ${res.n}-step plan on ${skill.label}.`); + else if (res.status === "oos") stageMsg(`That request is off the map for ${skill.label}. Try one of its actions.`); + else stageMsg(`The plan didn't validate \u2014 adjust the request and try again.`); + if (state.activeRunId) { + bumpUses(state.activeRunId); + renderDock(); + } + } + log(`done (${sel === "none" ? "base model" : "tuned adapter"}).`); + } catch (e) { + out.appendData("\n\n[error] " + e.message); + cap.textContent = "error: " + e.message; + console.error(e); + } finally { + stop(); + $("inferProc").classList.remove("on"); + state.busy = false; + gateButtons(); + } +} +__name(runInference, "runInference"); +async function runTraining({ examples, lr, epochs, accum, base, kind, system, build, suggest }) { + if (!state.loaded) { + log("load the model first (INFERENCE pane)."); + switchTab("infer"); + return; + } + if (state.busy) return; + const name = uniqueName(base); + const runId = newId(); + state.busy = "train"; + lockInference(true); + gateButtons(); + $("trainWidget").style.display = ""; + resetTrainTelemetry(); + const windows = Math.max(1, Math.ceil(examples.length / accum)); + const total = windows * epochs; + let lastLoss = null; + const ctrl = new TrainingController({ + session, + adapters, + log: /* @__PURE__ */ __name(() => { + }, "log"), + trainerOptions: { lr, maxTrainSeq: 384, lmHeadBlock: 128, maxGradNorm: 1, weightDecay: 0, warmupSteps: Math.min(4, total), totalSteps: total, gradAccumSteps: accum } + }); + const st = steps("trainSteps"); + st.reset(); + const cap = $("trainCap"); + const stop = startClock("trainClock"); + st.active("prep"); + cap.textContent = "Building masked, shifted-label examples and tokenizing on the GPU\u2026"; + renderMaskPreview(ctrl, examples[0]); + ctrl.initAdapter(name, { rank: 16, alpha: 32 }); + trainProgress(0, total, null, "warming up\u2026"); + const t0 = performance.now(); + try { + st.done("prep"); + st.loop(["fwd", "bwd", "opt"], true); + cap.textContent = "Looping forward \u2192 backward \u2192 AdamW over your examples (full-network backprop)\u2026"; + await ctrl.train(examples, { + epochs, + onStep: /* @__PURE__ */ __name((r) => { + const { step, loss } = r; + lastLoss = loss; + updateTrainTelemetry(step, total, r); + trainProgress(step, total, loss, `teaching \xB7 step ${step}/${total} \xB7 loss ${loss.toFixed(3)} \xB7 ${fmtNum(r.trainTokPerSec)} tok/s`); + cap.textContent = `Step ${step}/${total} \u2014 forward ${fmtMs(r.microStepMs)} \u2192 backward \u2192 AdamW ${fmtMs(r.optimizerStepMs)} \xB7 loss ${loss.toFixed(3)}`; + }, "onStep") + }); + const dt = ((performance.now() - t0) / 1e3).toFixed(1); + st.loop(["fwd", "bwd", "opt"], false); + st.done("fwd"); + st.done("bwd"); + st.done("opt"); + st.active("swap"); + state.tuned = { name, kind, base, build, suggest, ctrl }; + state.activeRunId = runId; + addAdapterOption(name); + $("adapterSel").value = name; + st.done("swap"); + trainProgress(total, total, null, `done in ${dt}s \u2014 adapter "${name}" is live`); + cap.textContent = `Adapter "${name}" hot-swapped into inference \u2014 live. Trained in ${dt}s.`; + $("downloadAdapter").style.display = ""; + showTryIt(suggest); + try { + const files = await exportLoraAdapter(ctrl.trainer, { name }); + await saveRun( + { + id: runId, + name, + base, + kind, + system: system || null, + suggest: suggest || "", + createdAt: Date.now(), + steps: total, + epochs, + durationSec: +dt, + finalLoss: lastLoss, + rank: 16, + alpha: 32 + }, + { safetensors: files.safetensors, configJson: files.configJson } + ); + renderHistory(); + } catch (e) { + console.warn("[history] save failed", e); + } + log(`Trained "${name}" in ${dt}s. Saved to your fine-tunes; switch to Inference to try it.`); + } catch (e) { + st.loop(["fwd", "bwd", "opt"], false); + trainProgress(0, total, null, "training error: " + e.message); + cap.textContent = "error: " + e.message; + console.error(e); + } finally { + stop(); + state.busy = false; + lockInference(false); + gateButtons(); + } +} +__name(runTraining, "runTraining"); +var MAX_CHARS = 12e3; +var MAX_CHUNKS = 24; +var MIN_WORDS = 12; +var HEAD_WORDS = 6; +function chunkText(text) { + text = (text || "").replace(/\r/g, "").slice(0, MAX_CHARS); + const paras = text.split(/\n{2,}|\.(?=\s)/).map((s) => s.trim()).filter(Boolean); + const out = []; + for (const p of paras) { + const words = p.split(/\s+/).filter(Boolean); + if (words.length < MIN_WORDS) continue; + const head = words.slice(0, HEAD_WORDS).join(" "); + const rest = words.slice(HEAD_WORDS).join(" "); + out.push({ head, rest, full: p }); + if (out.length >= MAX_CHUNKS) break; + } + return out; +} +__name(chunkText, "chunkText"); +var _ownChunks = []; +function ownExamples() { + return _ownChunks.map((c) => ({ messages: [{ role: "user", content: c.head }], completion: " " + c.rest })); +} +__name(ownExamples, "ownExamples"); +function refreshOwn() { + const text = $("ownText").value; + _ownChunks = chunkText(text); + const chars = Math.min(MAX_CHARS, (text || "").length); + $("ownStats").textContent = _ownChunks.length ? `${_ownChunks.length} snippet(s) \xB7 ${chars} chars (cap ${MAX_CHARS}) \xB7 ready to teach` : `paste/drop at least one paragraph (~${MIN_WORDS}+ words). 100% local.`; + gateButtons(); +} +__name(refreshOwn, "refreshOwn"); +function switchTab(which) { + const infer = which === "infer"; + $("paneInfer").classList.toggle("active", infer); + $("paneTrain").classList.toggle("active", !infer); + $("tabInfer").classList.toggle("on", infer); + $("tabTrain").classList.toggle("on", !infer); +} +__name(switchTab, "switchTab"); +function addAdapterOption(name) { + const sel = $("adapterSel"); + if (![...sel.options].some((o) => o.value === name)) { + const o = document.createElement("option"); + o.value = name; + o.textContent = name; + sel.appendChild(o); + } + const wrap = $("adapterWrap"); + if (wrap) wrap.hidden = false; +} +__name(addAdapterOption, "addAdapterOption"); +function trainProgress(step, total, loss, label) { + $("trainBar").style.width = (100 * step / Math.max(1, total)).toFixed(1) + "%"; + $("trainLabel").textContent = label; +} +__name(trainProgress, "trainProgress"); +function resetTrainTelemetry() { + trainLosses = []; + const box = $("trainMetrics"); + if (box) box.hidden = false; + for (const [id, v] of [["tmLoss", "\u2014"], ["tmTokps", "\u2014"], ["tmActive", "\u2014"], ["tmOpt", "\u2014"]]) { + const el = $(id); + if (el) el.textContent = v; + } + const line = $("lossLine"); + if (line) line.setAttribute("points", ""); + const preview = $("maskPreview"); + if (preview) preview.hidden = true; +} +__name(resetTrainTelemetry, "resetTrainTelemetry"); +function updateTrainTelemetry(step, total, r) { + trainLosses.push(r.loss); + $("tmLoss").textContent = r.loss.toFixed(4); + $("tmTokps").textContent = `${fmtNum(r.trainTokPerSec)} tok/s`; + $("tmActive").textContent = `${r.numActive || 0} / ${r.tokens || 0}`; + $("tmOpt").textContent = fmtMs(r.optimizerStepMs); + drawLossSpark(); +} +__name(updateTrainTelemetry, "updateTrainTelemetry"); +function drawLossSpark() { + const line = $("lossLine"); + if (!line || trainLosses.length < 2) return; + const min = Math.min(...trainLosses); + const max = Math.max(...trainLosses); + const span = Math.max(1e-6, max - min); + const points = trainLosses.map((v, i) => { + const x = i / Math.max(1, trainLosses.length - 1) * 300; + const y = 36 - (v - min) / span * 32; + return `${x.toFixed(1)},${y.toFixed(1)}`; + }).join(" "); + line.setAttribute("points", points); +} +__name(drawLossSpark, "drawLossSpark"); +function renderMaskPreview(ctrl, example) { + const box = $("maskPreview"); + const rows = $("maskRows"); + if (!box || !rows || !example) return; + try { + const preview = ctrl.inspectExample(example); + $("maskSummary").textContent = `${preview.tokens.length} tokens \xB7 ${preview.trainPositions} trained next-token labels`; + const shown = preview.rows.slice(0, 96); + rows.innerHTML = '
pos
segment
token
trained target
' + shown.map((r) => { + const cls = `${r.trainsNext ? "train" : ""} ${r.segment}`; + const target = r.trainsNext ? `${r.targetId} ${clip(r.targetText, 24)}` : ""; + return `
${r.index}
${esc(r.segment)}
${r.id} ${esc(clip(r.text, 28))}
${esc(target)}
`; + }).join("") + (preview.rows.length > shown.length ? `
\u2026
truncated
${preview.rows.length - shown.length} more rows
` : ""); + box.hidden = false; + } catch (e) { + rows.innerHTML = `
preview
error
${esc(e.message)}
`; + box.hidden = false; + } +} +__name(renderMaskPreview, "renderMaskPreview"); +function showTryIt(suggest) { + const t = $("tryIt"); + t.style.display = "flex"; + $("tryItBtn").onclick = () => { + switchTab("infer"); + $("adapterSel").value = state.tuned.name; + setBadge(); + $("prompt").value = suggest; + runInference(); + }; + renderEquipPanel(); + if (state.tuned?.name) stageMsg(`New skill learned: \u201C${state.tuned.name}\u201D \u2014 it dropped into your inventory. Equip it to act.`); +} +__name(showTryIt, "showTryIt"); +function renderEquipPanel() { + const bar = $("equipBar"); + if (!bar) return; + const skill = state.tuned ? skillByKey(state.tuned.base) : null; + if (!skill || !skill.spec) { + bar.hidden = true; + return; + } + bar.hidden = false; + const set = /* @__PURE__ */ __name((id, v) => { + const e = $(id); + if (e) e.textContent = v; + }, "set"); + set("equipIcon", skill.icon); + set("equipName", `${skill.label} skill`); + set("equipScope", `scope: ${skill.spec.scope}`); + const ops = $("equipOps"); + if (ops) { + ops.innerHTML = ""; + for (const op of skill.spec.ops) { + const c = document.createElement("span"); + c.className = "equip__op"; + c.textContent = op.name; + c.title = `${op.name}(${(op.params || []).join(", ")})`; + ops.appendChild(c); + } + } + const host = $("equipDrills"); + if (host) { + host.innerHTML = ""; + const inscope = skill.examples.filter(([, a]) => a !== "OUT_OF_SCOPE"); + const step = Math.max(1, Math.floor(inscope.length / 4)); + const picks = []; + for (let i = 0; i < inscope.length && picks.length < 4; i += step) picks.push(inscope[i][0]); + for (const q of picks) { + const b = document.createElement("button"); + b.type = "button"; + b.className = "drill"; + b.textContent = q; + b.title = "Fire this drill"; + b.onclick = () => { + $("prompt").value = q; + runInference(); + }; + host.appendChild(b); + } + } +} +__name(renderEquipPanel, "renderEquipPanel"); +function humanizePlan(text) { + const out = []; + for (const raw of String(text).split("\n")) { + const line = raw.trim(); + if (!line || line === "OUT_OF_SCOPE") continue; + const m = line.match(/^(?:[A-Za-z_]\w*\s*=\s*)?([A-Za-z_]\w*)\s*\((.*)\)\s*;?\s*$/); + if (!m) continue; + const op = m[1].replace(/_/g, " "); + const args = [...m[2].matchAll(/([A-Za-z_]\w*)\s*=\s*"([^"]*)"/g)].map((x) => x[2]).filter(Boolean); + const summary = args.slice(0, 2).join(" \xB7 "); + out.push(summary ? `${op} \u2014 ${summary}` : op); + } + return out; +} +__name(humanizePlan, "humanizePlan"); +function uniqueName(base) { + const taken = new Set(listRuns().map((r) => r.name)); + if (!taken.has(base)) return base; + let i = 2; + while (taken.has(`${base} #${i}`)) i++; + return `${base} #${i}`; +} +__name(uniqueName, "uniqueName"); +function buildFromMeta(meta) { + return meta.system ? (u) => [{ role: "system", content: meta.system }, { role: "user", content: u }] : (u) => [{ role: "user", content: u }]; +} +__name(buildFromMeta, "buildFromMeta"); +function fmtRunMeta(m) { + const parts = []; + if (m.finalLoss != null) parts.push("loss " + Number(m.finalLoss).toFixed(3)); + if (m.steps) parts.push(m.steps + " steps"); + if (m.durationSec != null) parts.push(Math.round(m.durationSec) + "s"); + try { + parts.push(new Date(m.createdAt).toLocaleDateString(void 0, { month: "short", day: "numeric" })); + } catch { + } + return parts.join(" \xB7 "); +} +__name(fmtRunMeta, "fmtRunMeta"); +function renderHistory() { + const runs = listRuns(); + $("historyCount").textContent = String(runs.length); + $("historyEmpty").style.display = runs.length ? "none" : ""; + const ul = $("historyList"); + ul.innerHTML = ""; + for (const m of runs) { + const { lv, xp } = skillLevel(m); + const rar = rarityOf(lv); + const active = m.id === state.activeRunId; + const li = document.createElement("li"); + li.className = "item" + (active ? " active" : ""); + li.dataset.id = m.id; + li.dataset.kind = m.kind || "own"; + li.dataset.rarity = rar.key; + li.title = `${m.name} \u2014 click to equip`; + li.innerHTML = `
${runIcon(m)}L${lv}
${esc(m.name)}
${rar.label} \xB7 ${esc(itemTypeLabel(m))}
${esc(fmtRunMeta(m))}
` + (active ? `
EQUIPPED
` : "") + `
`; + li.querySelector("[data-act=apply]").onclick = (e) => { + e.stopPropagation(); + applyRun(m.id); + }; + li.querySelector("[data-act=export]").onclick = (e) => { + e.stopPropagation(); + exportRun(m.id); + }; + li.querySelector("[data-act=del]").onclick = (e) => { + e.stopPropagation(); + delRun(m.id); + }; + li.onclick = () => applyRun(m.id); + ul.appendChild(li); + } + renderDock(); + renderStage(); +} +__name(renderHistory, "renderHistory"); +var SKILL_ICON = { guided: "\u2694", own: "\u{1F4DC}" }; +var usesByRun = /* @__PURE__ */ new Map(); +function bumpUses(id) { + usesByRun.set(id, (usesByRun.get(id) || 0) + 1); +} +__name(bumpUses, "bumpUses"); +function runIcon(m) { + const sk = skillByKey(m.base); + return sk ? sk.icon : SKILL_ICON[m.kind] || "\u{1F5E1}"; +} +__name(runIcon, "runIcon"); +function skillLevel(m) { + const lv = Math.max(1, Math.min(9, Math.round((m.steps || 12) / 12))); + const loss = m.finalLoss == null ? 1.5 : Number(m.finalLoss); + const xp = Math.max(6, Math.min(100, Math.round(100 * (3 - loss) / 3))); + return { lv, xp }; +} +__name(skillLevel, "skillLevel"); +function rarityOf(lv) { + if (lv >= 9) return { key: "legendary", label: "Legendary" }; + if (lv >= 7) return { key: "epic", label: "Epic" }; + if (lv >= 5) return { key: "rare", label: "Rare" }; + if (lv >= 3) return { key: "uncommon", label: "Uncommon" }; + return { key: "common", label: "Common" }; +} +__name(rarityOf, "rarityOf"); +function itemTypeLabel(m) { + const sk = skillByKey(m.base); + if (sk) return sk.label; + return m.kind === "guided" ? "Skill" : "Custom note"; +} +__name(itemTypeLabel, "itemTypeLabel"); +var BYOD_TILE = { bg: "#6b6256", fg: "#fff", glyph: "\u{1F4DC}", fs: 20 }; +var SERVICES = POPULAR_2026; +var dockRuns = []; +function renderDock() { + const tray = $("dockSlots"); + if (!tray) return; + const runs = listRuns(); + tray.innerHTML = ""; + dockRuns = []; + const seen = /* @__PURE__ */ new Set(); + const addTile = /* @__PURE__ */ __name((svc, opts) => { + const el = document.createElement("div"); + el.className = "dock__tile"; + el.tabIndex = 0; + el.setAttribute("role", "button"); + el.dataset.state = opts.state; + el.dataset.key = svc.key; + if (opts.runid) el.dataset.runid = opts.runid; + const g = document.createElement("span"); + g.className = "dock__glyph"; + g.style.background = svc.bg; + g.style.color = svc.fg || "#fff"; + g.style.fontSize = (svc.fs || 21) + "px"; + g.textContent = svc.glyph; + el.appendChild(g); + if (opts.lv != null) { + const b = document.createElement("span"); + b.className = "dock__lv"; + b.textContent = "L" + opts.lv; + el.appendChild(b); + } + if (opts.keyN != null) { + const k = document.createElement("span"); + k.className = "dock__key"; + k.textContent = opts.keyN; + el.appendChild(k); + } + if (opts.forge) { + const f = document.createElement("span"); + f.className = "dock__forge"; + f.textContent = "+"; + el.appendChild(f); + } + if (opts.lock) { + const l = document.createElement("span"); + l.className = "dock__lock"; + l.textContent = "\u{1F512}"; + el.appendChild(l); + } + const t = document.createElement("span"); + t.className = "dock__tip"; + t.textContent = opts.tip; + el.appendChild(t); + el.setAttribute("aria-label", opts.tip); + el.onclick = opts.onClick; + el.onkeydown = (e) => { + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + opts.onClick(); + } + }; + tray.appendChild(el); + }, "addTile"); + for (const svc of SERVICES) { + if (svc.skill) { + const run = runs.find((r) => skillByKey(r.base)?.key === svc.skill); + if (run) { + seen.add(run.id); + const { lv } = skillLevel(run); + const equipped = run.id === state.activeRunId; + dockRuns.push(run.id); + const keyN = dockRuns.length <= 9 ? dockRuns.length : null; + const uses = usesByRun.get(run.id) || 0; + addTile(svc, { + state: equipped ? "equipped" : "owned", + runid: run.id, + lv, + keyN, + tip: `${svc.name} \xB7 Lv ${lv}${equipped ? " \xB7 equipped" : ""}${uses ? " \xB7 " + uses + "\xD7" : ""}${keyN ? " \xB7 [" + keyN + "]" : ""}`, + onClick: /* @__PURE__ */ __name(() => applyRun(run.id), "onClick") + }); + } else { + addTile(svc, { + state: "forge", + forge: true, + tip: `${svc.name} \u2014 forge this skill`, + onClick: /* @__PURE__ */ __name(() => { + switchTab("train"); + selectSkill(svc.skill); + }, "onClick") + }); + } + } else { + addTile(svc, { + state: "soon", + lock: true, + tip: `${svc.name} \u2014 coming soon`, + onClick: /* @__PURE__ */ __name(() => { + switchTab("train"); + log(`\u201C${svc.name}\u201D skill \u2014 coming soon. The armory grows as we add action spaces.`); + }, "onClick") + }); + } + } + const extra = runs.filter((r) => !seen.has(r.id)); + if (extra.length) { + const sep = document.createElement("div"); + sep.className = "dock__sep"; + tray.appendChild(sep); + } + for (const r of extra) { + const { lv } = skillLevel(r); + const equipped = r.id === state.activeRunId; + dockRuns.push(r.id); + const keyN = dockRuns.length <= 9 ? dockRuns.length : null; + addTile({ key: "byod-" + r.id, name: r.name, ...BYOD_TILE }, { + state: equipped ? "equipped" : "owned", + runid: r.id, + lv, + keyN, + tip: `${r.name} \xB7 Lv ${lv}${equipped ? " \xB7 equipped" : ""}${keyN ? " \xB7 [" + keyN + "]" : ""}`, + onClick: /* @__PURE__ */ __name(() => applyRun(r.id), "onClick") + }); + } +} +__name(renderDock, "renderDock"); +var lastEquipIntent = null; +function equipByIndex(i) { + if (i < 0 || i >= dockRuns.length) return; + lastEquipIntent = dockRuns[i]; + applyRun(dockRuns[i]); +} +__name(equipByIndex, "equipByIndex"); +function setMacroCheck(res, skill, text) { + const el = $("macroCheck"); + if (!el) return; + if (!res || res.status === "empty") { + el.hidden = true; + el.textContent = ""; + el.removeAttribute("data-state"); + return; + } + el.hidden = false; + if (res.status === "ok") { + el.dataset.state = "ok"; + const ops = res.calls.map((c) => c.op).join(", "); + const plan = text ? humanizePlan(text) : []; + const planHtml = plan.length ? `
    ${plan.map((p) => `
  1. ${esc(p)}
  2. `).join("")}
` : ""; + el.innerHTML = `\u2713 valid macro \xB7 ${res.n} call${res.n === 1 ? "" : "s"} on the ${esc(skill.label)} action space \xB7 ${esc(ops)}${planHtml}`; + } else if (res.status === "oos") { + el.dataset.state = "oos"; + el.innerHTML = `\u26D4 OUT_OF_SCOPE \xB7 the ${esc(skill.label)} skill correctly refused \u2014 that request is outside its actions`; + } else { + el.dataset.state = "bad"; + el.innerHTML = `\u2717 invalid macro \xB7 ${esc(res.issues.slice(0, 2).join("; "))}`; + } +} +__name(setMacroCheck, "setMacroCheck"); +var RANKS = [[12, "Grandmaster"], [9, "Master"], [6, "Artisan"], [4, "Adept"], [2, "Journeyman"], [1, "Apprentice"], [0, "Initiate"]]; +function firstColor(bg) { + if (!bg) return null; + const m = String(bg).match(/#[0-9a-f]{3,8}/i); + return m ? m[0] : String(bg).startsWith("#") ? bg : null; +} +__name(firstColor, "firstColor"); +function stageMsg(text) { + const e = $("stageMsg"); + if (e) e.textContent = "\xBB " + text; +} +__name(stageMsg, "stageMsg"); +function renderStage() { + const stage = $("stage"); + if (!stage) return; + const runs = listRuns(); + const acquired = new Set(runs.map((r) => skillByKey(r.base)?.key).filter(Boolean)); + let maxLv = 0, steps2 = 0; + for (const r of runs) { + maxLv = Math.max(maxLv, skillLevel(r).lv); + steps2 += r.steps || 0; + } + const lvl = 1 + Math.floor(steps2 / 120); + const xpPct = Math.round(steps2 % 120 / 120 * 100); + const rank = (RANKS.find(([t]) => runs.length >= t) || [0, "Initiate"])[1]; + const active = runs.find((r) => r.id === state.activeRunId); + const skill = active ? skillByKey(active.base) : null; + const d = skill ? dockOf(skill.key) : null; + const set = /* @__PURE__ */ __name((id, v) => { + const e = $(id); + if (e) e.textContent = v; + }, "set"); + set("stageScore", `${acquired.size} / ${SKILLS.length}`); + set("stageLv", String(lvl)); + set("stageRank", rank); + const xp = $("stageXp"); + if (xp) xp.style.width = xpPct + "%"; + const scene = $("stageScene"); + const icon = $("stageSignIcon"); + if (active) { + set("stageSignName", active.name); + if (icon) { + icon.textContent = d?.glyph || skill?.icon || "\u25C6"; + icon.style.background = d?.bg || "#6b6256"; + icon.style.color = d?.fg || "#fff"; + icon.style.fontSize = Math.round((d?.fs || 18) * 0.8) + "px"; + } + if (scene) scene.style.setProperty("--scene", firstColor(d?.bg) || "#1d6f6a"); + stage.dataset.where = "in"; + } else { + set("stageSignName", "The open web"); + if (icon) { + icon.textContent = "\u{1F310}"; + icon.style.background = "#13393f"; + icon.style.color = "#cdeeea"; + icon.style.fontSize = "17px"; + } + if (scene) scene.style.setProperty("--scene", "#1d6f6a"); + stage.dataset.where = "out"; + } +} +__name(renderStage, "renderStage"); +var dockOf = /* @__PURE__ */ __name((key) => POPULAR_2026.find((s) => s.key === key) || {}, "dockOf"); +function renderSkillPicker() { + const host = $("skillPicker"); + if (!host) return; + const runs = listRuns(); + host.innerHTML = ""; + for (const sk of SKILLS) { + const d = dockOf(sk.key); + const run = runs.find((r) => skillByKey(r.base)?.key === sk.key); + const lv = run ? skillLevel(run).lv : 0; + const b = document.createElement("button"); + b.type = "button"; + b.className = "skillpick__btn" + (sk.key === selectedSkillKey ? " on" : "") + (lv ? " forged" : ""); + b.dataset.key = sk.key; + b.innerHTML = `${d.glyph || sk.icon}${esc(sk.label)}${sk.spec.ops.length} actions \xB7 ${sk.examples.length} examples` + (lv ? `L${lv}` : ""); + b.onclick = () => selectSkill(sk.key); + host.appendChild(b); + } +} +__name(renderSkillPicker, "renderSkillPicker"); +function selectSkill(key) { + const sk = skillByKey(key) || SKILLS[0]; + selectedSkillKey = sk.key; + document.querySelectorAll("#skillPicker .skillpick__btn").forEach((b) => b.classList.toggle("on", b.dataset.key === sk.key)); + const title = $("skillTitle"); + if (title) title.innerHTML = `${sk.icon} ${esc(sk.label)} skill`; + const desc = $("skillDesc"); + if (desc) desc.textContent = sk.desc; + const list = $("guidedList"); + if (list) { + const inscope = sk.examples.filter(([, a]) => a !== "OUT_OF_SCOPE"); + const oos = sk.examples.filter(([, a]) => a === "OUT_OF_SCOPE"); + const sample = [...inscope.slice(0, 5), ...oos.slice(0, 1)]; + const more = sk.examples.length - sample.length; + list.innerHTML = sample.map(([q, a]) => `
  • ${esc(q)}
    ${esc(a)}
  • `).join("") + (more > 0 ? `
  • + ${more} more spec-valid pairs forge with this skill
  • ` : ""); + } +} +__name(selectSkill, "selectSkill"); +async function applyRun(id) { + const meta = getRun(id); + if (!meta) return; + if (!state.loaded) { + log("Load VibeThinker-3B first (Step 1), then tap a fine-tune to use it."); + switchTab("infer"); + return; + } + if (state.busy) return; + state.busy = "apply"; + gateButtons(); + try { + log(`Applying "${meta.name}"\u2026`); + let adapter = adapters.get(meta.name); + if (!adapter) { + const files = await loadRunFiles(id); + adapter = await loadLoraAdapterGPU(session.rt.dev, files, QWEN25_3B); + adapter.name = meta.name; + adapters.adapters[meta.name] = adapter; + } + addAdapterOption(meta.name); + state.tuned = { name: meta.name, kind: meta.kind, base: meta.base, build: buildFromMeta(meta), suggest: meta.suggest }; + state.activeRunId = id; + $("adapterSel").value = meta.name; + setMacroCheck(null); + setBadge(); + renderHistory(); + renderEquipPanel(); + switchTab("infer"); + if (meta.suggest) $("prompt").value = meta.suggest; + stageMsg(`You step into \u201C${meta.name}\u201D. Pick an action below and act.`); + log(`Now serving fine-tune "${meta.name}". Ask away.`); + } catch (e) { + log("Could not apply: " + e.message); + console.error(e); + } finally { + state.busy = false; + gateButtons(); + } +} +__name(applyRun, "applyRun"); +async function exportRun(id) { + const meta = getRun(id); + if (!meta) return; + try { + const { safetensors, configJson } = await getRunBlobs(id); + const stem = (meta.name || "adapter").replace(/[^\w.-]+/g, "_"); + if (state.dirHandle && await ensurePermission(state.dirHandle)) { + await writeFileToDir(state.dirHandle, stem + ".safetensors", safetensors); + await writeFileToDir(state.dirHandle, stem + ".adapter_config.json", configJson); + log(`Saved "${meta.name}" to your connected folder.`); + } else { + triggerBlob(safetensors, stem + ".safetensors"); + triggerBlob(new Blob([configJson], { type: "application/json" }), stem + ".adapter_config.json"); + log(`Exported "${meta.name}".`); + } + } catch (e) { + log("Export failed: " + e.message); + } +} +__name(exportRun, "exportRun"); +async function delRun(id) { + await deleteRun(id); + if (state.activeRunId === id) state.activeRunId = null; + renderHistory(); +} +__name(delRun, "delRun"); +function triggerBlob(data, filename) { + const blob = data instanceof Blob ? data : new Blob([data]); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = filename; + document.body.appendChild(a); + a.click(); + a.remove(); + setTimeout(() => URL.revokeObjectURL(url), 1e3); +} +__name(triggerBlob, "triggerBlob"); +function fmtMs(ms) { + return Number.isFinite(ms) ? `${ms.toFixed(ms >= 100 ? 0 : 1)}ms` : "\u2014"; +} +__name(fmtMs, "fmtMs"); +function fmtNum(n) { + return Number.isFinite(n) ? n >= 100 ? n.toFixed(0) : n.toFixed(1) : "\u2014"; +} +__name(fmtNum, "fmtNum"); +function clip(s, n) { + s = String(s ?? "").replace(/\s+/g, " "); + return s.length > n ? s.slice(0, Math.max(0, n - 1)) + "\u2026" : s; +} +__name(clip, "clip"); +function applyLayout() { + const mq = /* @__PURE__ */ __name((q) => { + try { + return window.matchMedia(q).matches; + } catch { + return false; + } + }, "mq"); + const fold = mq("(horizontal-viewport-segments: 2)") || mq("(spanning: single-fold-vertical)"); + const mobile = mq("(max-width: 700px)"); + document.body.dataset.layout = fold ? "foldable" : mobile ? "mobile" : "desktop"; +} +__name(applyLayout, "applyLayout"); +async function initFs() { + if (!fsSupported) { + $("fsBlock").hidden = true; + return; + } + $("fsBlock").hidden = false; + const setDir = /* @__PURE__ */ __name((h) => { + state.dirHandle = h; + $("fsForget").hidden = false; + $("ownImportDir").hidden = false; + $("fsStatus").textContent = `connected: ${h.name || "folder"} \u2014 adapters can save here; import text below.`; + }, "setDir"); + try { + const saved = await savedDirectory(); + if (saved) setDir(saved); + } catch { + } + $("fsConnect").onclick = async () => { + try { + setDir(await connectDirectory()); + } catch (e) { + if (e.name !== "AbortError") log("folder: " + e.message); + } + }; + $("fsForget").onclick = async () => { + await forgetDirectory(); + state.dirHandle = null; + $("fsForget").hidden = true; + $("ownImportDir").hidden = true; + $("fsStatus").textContent = "not connected \u2014 import training text & save adapters straight to a folder you pick."; + }; + $("ownImportDir").onclick = async () => { + if (!state.dirHandle) return; + if (!await ensurePermission(state.dirHandle, "read")) { + log("permission denied for folder"); + return; + } + try { + const { text, names } = await readDirText(state.dirHandle); + if (!text.trim()) { + $("ownStats").textContent = "no .txt/.md/.json/.csv files found in that folder"; + return; + } + $("ownText").value = (text + "\n" + $("ownText").value).slice(0, MAX_CHARS); + refreshOwn(); + $("ownStats").textContent = `imported ${names.length} file(s) \xB7 ` + $("ownStats").textContent; + } catch (e) { + log("import failed: " + e.message); + } + }; +} +__name(initFs, "initFs"); +window.addEventListener("DOMContentLoaded", () => { + renderSkillPicker(); + selectSkill(selectedSkillKey); + $("tabInfer").onclick = () => switchTab("infer"); + $("tabTrain").onclick = () => switchTab("train"); + $("gear").onclick = () => { + const open = $("settings").hidden; + $("settings").hidden = !open; + $("gear").classList.toggle("on", open); + }; + $("adapterSel").onchange = setBadge; + $("load").onclick = () => loadWith(urlReader($("modelUrl").value.trim()), $("modelUrl").value.trim()); + $("loadHF").onclick = () => { + const repo = $("hfRepo").value.trim(); + const token = ($("hfToken")?.value || "").trim(); + if (!repo) return log("enter a Hugging Face repo id, e.g. WeiboAI/VibeThinker-3B"); + loadWith(hfReader(repo, token), "HF: " + repo); + }; + $("modelFiles").onchange = (ev) => { + const files = [...ev.target.files]; + if (!files.length) return; + const map = {}; + for (const f of files) map[f.name] = f; + loadWith(fileReader(map), `${files.length} local files`); + }; + $("run").onclick = runInference; + $("prompt").addEventListener("keydown", (e) => { + if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) runInference(); + }); + document.addEventListener("keydown", (e) => { + if (e.metaKey || e.ctrlKey || e.altKey) return; + const tag = e.target && e.target.tagName || ""; + if (tag === "INPUT" || tag === "TEXTAREA" || e.target && e.target.isContentEditable) return; + if (e.key >= "1" && e.key <= "9") equipByIndex(+e.key - 1); + }); + $("trainGuided").onclick = () => { + const sk = skillByKey(selectedSkillKey) || SKILLS[0]; + const pool = sampleExamples(sk.examples, 32); + const ex = pool.map(([q, a]) => ({ messages: [{ role: "system", content: sk.system }, { role: "user", content: q }], completion: " " + a })); + const windows = Math.ceil(ex.length / 2); + runTraining({ + examples: ex, + lr: 3e-4, + epochs: Math.max(6, Math.min(14, Math.round(280 / windows))), + accum: 2, + base: sk.key, + kind: "guided", + system: sk.system, + build: /* @__PURE__ */ __name((u) => [{ role: "system", content: sk.system }, { role: "user", content: u }], "build"), + suggest: sk.suggest + }); + }; + $("ownText").addEventListener("input", refreshOwn); + $("ownFiles").onchange = async (ev) => { + const files = [...ev.target.files].slice(0, 5); + let txt = ""; + for (const f of files) { + try { + txt += await f.text() + "\n\n"; + } catch { + } + } + $("ownText").value = (txt + "\n" + $("ownText").value).slice(0, MAX_CHARS); + refreshOwn(); + }; + $("ownFetch").onclick = async () => { + const url = $("ownUrl").value.trim(); + if (!url) return; + $("ownStats").textContent = "fetching readable text via reader proxy\u2026"; + try { + const r = await fetch("https://r.jina.ai/" + url); + if (!r.ok) throw new Error("HTTP " + r.status); + const t = await r.text(); + $("ownText").value = t.slice(0, MAX_CHARS); + refreshOwn(); + } catch (e) { + $("ownStats").textContent = "could not fetch (CORS/blocked) \u2014 paste the text instead. " + e.message; + } + }; + $("trainOwn").onclick = () => { + const ex = ownExamples(); + if (!ex.length) return; + const windows = Math.ceil(ex.length / 2); + runTraining({ + examples: ex, + lr: 3e-4, + accum: 2, + epochs: Math.max(3, Math.min(8, Math.round(50 / windows))), + base: "my-notes", + kind: "own", + system: null, + build: /* @__PURE__ */ __name((u) => [{ role: "user", content: u }], "build"), + suggest: _ownChunks[0]?.head || "" + }); + }; + $("downloadAdapter").onclick = () => { + if (state.tuned?.ctrl?.trainer) downloadLoraAdapter(state.tuned.ctrl.trainer, { name: state.tuned.name }); + }; + applyLayout(); + for (const q of ["(max-width: 700px)", "(horizontal-viewport-segments: 2)", "(spanning: single-fold-vertical)"]) { + try { + window.matchMedia(q).addEventListener("change", applyLayout); + } catch { + } + } + window.__layout = (m) => { + document.body.dataset.layout = m; + }; + window.__eg = { + store: store_exports, + renderHistory, + renderDock, + renderStage, + stageMsg, + renderEquipPanel, + humanizePlan, + applyRun, + exportRun, + delRun, + state, + // devtools/test surface + SKILLS, + POPULAR_2026, + selectSkill, + renderSkillPicker, + verifyMacro, + setMacroCheck, + equipByIndex, + skillByKey, + sampleExamples, + get selectedSkillKey() { + return selectedSkillKey; + }, + get lastEquipIntent() { + return lastEquipIntent; + } + }; + initFs(); + renderHistory(); + switchTab("infer"); + setBadge(); + refreshOwn(); + gateButtons(); +}); +function esc(s) { + return String(s).replace(/[&<>]/g, (c) => ({ "&": "&", "<": "<", ">": ">" })[c]); +} +__name(esc, "esc");