await dl.tidy(async () => {
  // Both LSTM layers share one forget-gate bias of 1.0.
  const forgetBias = dl.scalar(1.0);

  // Binds one layer's weights into the (data, c, h) cell signature that
  // dl.multiRNNCell consumes.
  const makeCell = (kernel: dl.Tensor2D, bias: dl.Tensor1D) =>
      (data: dl.Tensor2D, c: dl.Tensor2D, h: dl.Tensor2D) =>
          dl.basicLSTMCell(forgetBias, kernel, bias, data, c, h);
  const stack = [makeCell(lstmKernel1, lstmBias1), makeCell(lstmKernel2, lstmBias2)];

  // The `/ 4` assumes each bias concatenates the four LSTM gate biases, making
  // a quarter of its length the layer's hidden size. Both the cell state (c)
  // and the hidden state (h) start as [1, hiddenSize] zeros per layer.
  const hiddenSizes = [lstmBias1.shape[0] / 4, lstmBias2.shape[0] / 4];
  const zeroState = (size: number): dl.Tensor2D => dl.zeros([1, size]);
  let c: dl.Tensor2D[] = hiddenSizes.map(zeroState);
  let h: dl.Tensor2D[] = hiddenSizes.map(zeroState);

  // Greedy autoregressive decode: each predicted class index becomes the next
  // input token. The one-hot depth of 10 is hard-coded.
  let token = primerData;
  for (let step = 0; step < expected.length; step++) {
    const encoded = dl.oneHot(dl.tensor1d([token]), 10);
    [c, h] = dl.multiRNNCell(stack, encoded, c, h);

    // Feed the top layer's hidden state through the fully connected layer and
    // pick the highest-scoring class.
    const logits = h[1].matMul(fullyConnectedWeights).add(fullyConnectedBiases);
    const predicted = await dl.argMax(logits).val();
    results.push(predicted);
    token = predicted;
  }
});
await dl.tidy(async () => {
  // Wrap each layer's weights in the (data, c, h) cell signature that
  // dl.multiRNNCell expects. forgetBias comes from the enclosing scope.
  const lstm1 = (data: dl.Tensor2D, c: dl.Tensor2D, h: dl.Tensor2D) =>
      dl.basicLSTMCell(forgetBias, lstmKernel1, lstmBias1, data, c, h);
  const lstm2 = (data: dl.Tensor2D, c: dl.Tensor2D, h: dl.Tensor2D) =>
      dl.basicLSTMCell(forgetBias, lstmKernel2, lstmBias2, data, c, h);
  const lstm3 = (data: dl.Tensor2D, c: dl.Tensor2D, h: dl.Tensor2D) =>
      dl.basicLSTMCell(forgetBias, lstmKernel3, lstmBias3, data, c, h);
  // Loop-invariant: build the layer stack once instead of on every step.
  const cells = [lstm1, lstm2, lstm3];

  const outputs: dl.Scalar[] = [];

  // Generate some notes.
  for (let i = 0; i < STEPS_PER_GENERATE_CALL; i++) {
    // Use last sampled output as the next input.
    const eventInput = dl.oneHot(lastSample.as1D(), EVENT_SIZE).as1D();
    // Dispose the last sample from the previous generate call, since we
    // kept it.
    if (i === 0) {
      lastSample.dispose();
    }

    const conditioning = getConditioning();
    const axis = 0;
    const input = conditioning.concat(eventInput, axis);
    const output = dl.multiRNNCell(cells, input.as2D(1, -1), c, h);
    c = output[0];
    h = output[1];

    // The third (top) layer's hidden state feeds the fully connected layer.
    const outputH = h[2];
    const logits = outputH.matMul(fcW).add(fcB);
    const softmax = logits.as1D().softmax();
    const sampledOutput = dl.multinomial(softmax, 1).asScalar();
    outputs.push(sampledOutput);
    lastSample = sampledOutput;
  }

  // Only the final sample must outlive this tidy scope: it seeds the next call,
  // which disposes it at i === 0. The other samples are read below, before the
  // scope ends, so they are left to tidy. Keeping every sample leaked all but
  // the last one, because `outputs` is local and nothing else disposes them.
  dl.keep(lastSample);

  // The recurrent state carries over to the next call.
  // NOTE(review): the previous call's kept c/h tensors are overwritten in the
  // loop above without being disposed; confirm who owns them and dispose them
  // if nothing else does.
  c.forEach(val => dl.keep(val));
  h.forEach(val => dl.keep(val));

  // Wait for the last sample's data (presumably flushing all work queued before
  // it) before reading the samples one at a time.
  await outputs[outputs.length - 1].data();

  for (const sample of outputs) {
    playOutput(await sample.val());
  }

  // Report the same lag value that triggered the reset, against the same
  // threshold the check uses.
  const lag = piano.now() - currentPianoTimeSec;
  if (lag > MAX_GENERATION_LAG_SECONDS) {
    console.warn(
        `Generation is ${lag} seconds behind, ` +
        `which is over ${MAX_GENERATION_LAG_SECONDS}. Resetting time!`);
    currentPianoTimeSec = piano.now();
  }

  // Schedule the next step. The delay is how far currentPianoTimeSec is ahead
  // of piano.now() beyond GENERATION_BUFFER_SECONDS, clamped at 0 and
  // converted to milliseconds.
  const delta = Math.max(
      0, currentPianoTimeSec - piano.now() - GENERATION_BUFFER_SECONDS);
  setTimeout(() => generateStep(loopId), delta * 1000);
});