// Bindings from the wasm-bindgen build; assigned lazily in getInstance,
// depending on whether the quantized or full-precision model is requested.
let init, ModelConditionalGeneration;

// Fetch a binary asset and memoize it in the Cache API, so the model
// weights, tokenizer, and config are only downloaded once per origin.
async function fetchArrayBuffer(url) {
  const cacheName = "t5-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }
  const res = await fetch(url, { cache: "force-cache" });
  // Clone before the body is consumed below; await so a failed cache write
  // surfaces here instead of as an unhandled rejection.
  await cache.put(url, res.clone());
  return new Uint8Array(await res.arrayBuffer());
}
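
// Note: cached responses persist across page reloads; during development the
// cache can be cleared with `await caches.delete("t5-candle-cache")`.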
class ConditionalGeneration {
  // One cached model per modelID, so switching between the quantized and
  // full-precision builds never reloads an already-initialized model.
  static instance = {};

  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
    // Pick the WASM build: model IDs containing "quantized" use the
    // quantized bindings, everything else the full-precision ones.
    if (modelID.includes("quantized")) {
      ({ default: init, ModelConditionalGeneration } = await import(
        "./build/m-quantized.js"
      ));
    } else {
      ({ default: init, ModelConditionalGeneration } = await import(
        "./build/m.js"
      ));
    }
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({ status: "loading", message: "Loading Model" });
      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
        await Promise.all([
          fetchArrayBuffer(weightsURL),
          fetchArrayBuffer(tokenizerURL),
          fetchArrayBuffer(configURL),
        ]);

      this.instance[modelID] = new ModelConditionalGeneration(
        weightsArrayU8,
        tokenizerArrayU8,
        configArrayU8
      );
    } else {
      self.postMessage({ status: "ready", message: "Model Already Loaded" });
    }
    return this.instance[modelID];
  }
}
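
// A sketch of the wasm-bindgen interface assumed above, inferred from how it
// is called here rather than taken from the generated bindings:
//
//   await init();                    // instantiate the WASM module
//   const model = new ModelConditionalGeneration(weights, tokenizer, config);
//   const output = model.decode({ prompt, temperature, seed, top_p,
//                                 repeat_penalty, repeat_last_n });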

// Worker entry point: the page posts model/asset URLs, a prompt, and sampling
// parameters; the worker replies with status updates and the final output.
self.addEventListener("message", async (event) => {
  const { weightsURL, tokenizerURL, configURL, modelID, prompt, params } =
    event.data;
  // Sampling defaults; temperature 0.0 makes decoding greedy/deterministic.
  const {
    temperature = 0.0,
    seed = 299792458,
    repeat_penalty = 1.1,
    repeat_last_n = 64,
    top_p = 1,
  } = { ...params };
  try {
    self.postMessage({
      status: "ready",
      message: "Starting T5 Conditional Generation",
    });
    const model = await ConditionalGeneration.getInstance(
      weightsURL,
      tokenizerURL,
      configURL,
      modelID
    );
    self.postMessage({
      status: "decoding",
      message: "Decoding Prompt",
    });
    const output = model.decode({
      prompt,
      temperature,
      seed,
      top_p,
      repeat_penalty,
      repeat_last_n,
    });
    self.postMessage({
      status: "complete",
      message: "complete",
      output: output,
    });
  } catch (e) {
    self.postMessage({ error: e });
  }
});
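
// A minimal sketch of driving this worker from the main thread. The worker
// filename, asset URLs, and model ID are placeholders, not defined here; a
// module worker is used so that dynamic import() is available:
//
//   const worker = new Worker("./T5Worker.js", { type: "module" });
//   worker.addEventListener("message", ({ data }) => {
//     if (data.error) console.error(data.error);
//     else if (data.status === "complete") console.log(data.output);
//   });
//   worker.postMessage({
//     weightsURL: "...", tokenizerURL: "...", configURL: "...",
//     modelID: "t5_small_quantized",
//     prompt: "translate English to German: Hello, world!",
//     params: { temperature: 0.0, top_p: 1.0 },
//   });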