//import * as _ from "lodash";
import Tone from "tone";

import {
  GANSYNTH_CHECKPOINT,
  NUM_INTERPS,
  MIN_MIDI_PITCH,
  BATCH_SIZE,
  MIDI_PITCHES,
  T,
  PLAYBACK_SR,
  GENERATOR_SR
} from "./constants";
import { resample } from "./resampler";

// import sample from './sample.json';

//import * as mm from "@magenta/music";
// Magenta.js is read from the global `mm` (presumably loaded via a <script>
// tag) instead of being imported; TensorFlow.js is taken from `mm.tf`.
let { mm } = window;
let { tf } = mm;

// Lazily-created GANSynth instance and the promise for its initialization;
// both are populated by loadModel().
let model;
let modelLoadPromise;

/**
 * Returns the shared GANSynth instance, creating and initializing it on the
 * first call. Callers that arrive while initialization is in flight await the
 * same promise.
 *
 * If initialization fails, the cached state is cleared before the error is
 * re-thrown, so a later call retries instead of re-throwing the same stale
 * rejection forever.
 *
 * @returns {Promise<object>} the initialized mm.GANSynth instance.
 */
async function loadModel() {
  if (!model) {
    model = new mm.GANSynth(GANSYNTH_CHECKPOINT);
    modelLoadPromise = model.initialize().catch(err => {
      model = undefined;
      modelLoadPromise = undefined;
      throw err;
    });
  }
  await modelLoadPromise;
  return model;
}

/**
 * Builds NUM_INTERPS latent vectors by spherically interpolating from
 * `z0Param.z` toward `z1Param.z` at fractions i / NUM_INTERPS for
 * i = 0 .. NUM_INTERPS - 1.
 *
 * The last fraction is (NUM_INTERPS - 1) / NUM_INTERPS, so `z1Param.z`
 * itself is not included.
 *
 * Every temporary tensor (the per-step fraction scalars and interpolants) is
 * released even if an interpolation step throws.
 *
 * @param {{z: *}} z0Param - start point; `z` is its latent tensor.
 * @param {{z: *}} z1Param - end point; `z` is its latent tensor.
 * @returns {Promise<*>} the interpolants stacked with tf.stack; the caller
 *   must dispose it.
 */
async function getInterpolatedZs(z0Param, z1Param) {
  const zs = [];
  try {
    for (let i = 0; i < NUM_INTERPS; i++) {
      const fraction = tf.scalar(i / NUM_INTERPS);
      try {
        zs.push(await slerp(fraction, z0Param.z, z1Param.z));
      } finally {
        fraction.dispose();
      }
    }
    return tf.stack(zs);
  } finally {
    // The stacked tensor does not depend on the per-step tensors, so they are
    // released on every path (success or failure).
    zs.forEach(z => z.dispose());
  }
}

/**
 * Spherical linear interpolation between vectors `low` and `high`.
 *
 * omega is the angle between the two vectors. It is computed from their
 * normalized forms, with the cosine clipped to [-1, 1] so acos stays defined.
 * When sin(omega) is exactly 0 the SLERP formula would divide by zero, so
 * plain linear interpolation is returned instead.
 *
 * @param {*} val - scalar tensor with the interpolation fraction; 0 yields
 *   `low` and 1 yields `high`.
 * @param {*} low - start vector (assumed 1-D, since tf.dot is used and only
 *   the first element of the angle is read — TODO confirm).
 * @param {*} high - end vector, same shape as `low`.
 * @returns {Promise<*>} the interpolated tensor; the caller must dispose it.
 */
async function slerp(val, low, high) {
  const omega = tf.tidy(() =>
    tf.acos(
      tf.clipByValue(
        tf.dot(tf.div(low, tf.norm(low)), tf.div(high, tf.norm(high))),
        -1,
        1
      )
    )
  );
  const so = tf.sin(omega);
  try {
    const soData = await so.data();
    if (soData[0] === 0) {
      // L'Hopital's rule: as omega -> 0, SLERP reduces to LERP.
      return tf.tidy(() =>
        tf.add(tf.mul(tf.sub(1.0, val), low), tf.mul(val, high))
      );
    }
    return tf.tidy(() =>
      tf.add(
        tf.mul(tf.div(tf.sin(tf.mul(tf.sub(1.0, val), omega)), so), low),
        tf.mul(tf.div(tf.sin(tf.mul(val, omega)), so), high)
      )
    );
  } finally {
    // omega is the kept result of its tidy() and `so` was created outside any
    // tidy(), so neither is freed automatically. The tidy() above runs
    // synchronously, so the result is already computed when this executes.
    omega.dispose();
    so.dispose();
  }
}

// Decoded pregenerated audio, keyed by the ordered pair of preset names. The
// names are joined with a NUL separator so pairs such as ("AB", "C") and
// ("A", "BC") cannot collide, as they did with plain concatenation.
const pregenCache = new Map();

/**
 * Fetches and decodes an audio file, returning its first channel.
 * Throws when the request fails, when the server answers with a non-2xx
 * status, or when decoding fails.
 */
async function fetchFirstChannel(url, audioCtx) {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${url}: HTTP ${response.status}`);
  }
  const encoded = await response.arrayBuffer();
  const audioBuffer = await audioCtx.decodeAudioData(encoded);
  return audioBuffer.getChannelData(0);
}

/**
 * Returns a copy of `samples` whose NUM_INTERPS segments (T * PLAYBACK_SR
 * samples each) appear in reverse order. Samples inside each segment keep
 * their order.
 */
function reverseSegments(samples) {
  const segmentLength = T * PLAYBACK_SR;
  const reordered = new Float32Array(samples.length);
  for (let i = 0; i < NUM_INTERPS; i++) {
    const srcOffset = (NUM_INTERPS - i - 1) * segmentLength;
    reordered.set(
      samples.subarray(srcOffset, srcOffset + segmentLength),
      i * segmentLength
    );
  }
  return reordered;
}

/**
 * Looks up pregenerated interpolation audio for the preset pair (z0, z1).
 *
 * Tries `/audio_<z0>_<z1>.wav` first (names lower-cased). If that cannot be
 * fetched or decoded, it tries the reverse-direction file
 * `/audio_<z1>_<z0>.wav` and reverses its segment order so the steps run
 * from z0 to z1. Successful results are cached per ordered pair.
 *
 * @param {{name: string}} z0 - start preset.
 * @param {{name: string}} z1 - end preset.
 * @param {*} audioCtx - Web Audio context used for decodeAudioData.
 * @returns {Promise<Float32Array|undefined>} first-channel samples, or
 *   undefined when neither file is usable.
 */
async function fetchPregenerated(z0, z1, audioCtx) {
  const cacheKey = `${z0.name}\u0000${z1.name}`;
  if (pregenCache.has(cacheKey)) {
    return pregenCache.get(cacheKey);
  }
  let samples;
  try {
    samples = await fetchFirstChannel(
      `/audio_${z0.name.toLowerCase()}_${z1.name.toLowerCase()}.wav`,
      audioCtx
    );
  } catch (forwardErr) {
    try {
      // NOTE(review): if the files were rendered with fractions
      // i / NUM_INTERPS (as getInterpolatedZs does), reversed step i
      // corresponds to fraction (i + 1) / NUM_INTERPS from z0, not
      // i / NUM_INTERPS. This is a close approximation — confirm it is
      // acceptable.
      const reverse = await fetchFirstChannel(
        `/audio_${z1.name.toLowerCase()}_${z0.name.toLowerCase()}.wav`,
        audioCtx
      );
      samples = reverseSegments(reverse);
    } catch (reverseErr) {
      // Best effort: no usable file in either direction. Returning undefined
      // lets generateWaveforms fall back to running the model.
      return undefined;
    }
  }
  pregenCache.set(cacheKey, samples);
  return samples;
}

/**
 * Renders the audio for an interpolation from `z0` to `z1`.
 *
 * When neither endpoint is named "Random", pregenerated audio is tried first
 * via fetchPregenerated. Otherwise GANSynth synthesizes NUM_INTERPS steps at
 * pitch C4, in batches of BATCH_SIZE. Each step is:
 * - peak-normalized,
 * - resampled from GENERATOR_SR to PLAYBACK_SR when the two differ,
 * - written into one contiguous buffer in interpolation order, T seconds
 *   per step.
 *
 * All TensorFlow.js tensors created here are released, including on error.
 *
 * @param {{name: string, z: *}} z0 - start point; `z` is its latent tensor.
 * @param {{name: string, z: *}} z1 - end point; `z` is its latent tensor.
 * @param {*} audioCtx - Web Audio context used to decode pregenerated files.
 * @returns {Promise<Float32Array>} mono samples at PLAYBACK_SR. Freshly
 *   generated audio is backed by a SharedArrayBuffer; pregenerated audio
 *   is not.
 */
export async function generateWaveforms(z0, z1, audioCtx) {
  if (z0.name !== "Random" && z1.name !== "Random") {
    const pregen = await fetchPregenerated(z0, z1, audioCtx);
    if (pregen) {
      return pregen;
    }
  }

  const gansynth = await loadModel();
  const segmentLength = T * PLAYBACK_SR;
  // NOTE(review): SharedArrayBuffer is only exposed to cross-origin-isolated
  // pages in current browsers; this throws otherwise. Kept because callers may
  // rely on the result being shared memory — confirm before changing.
  const sharedBuffer = new window.SharedArrayBuffer(
    segmentLength * NUM_INTERPS * Float32Array.BYTES_PER_ELEMENT
  );
  const combinedArr = new Float32Array(sharedBuffer);

  const zs = await getInterpolatedZs(z0, z1);
  let pitchOneHot;
  try {
    // Every step is conditioned on C4. tidy() frees the pitch index and the
    // un-tiled one-hot; only the tiled [BATCH_SIZE, MIDI_PITCHES] tensor
    // survives.
    pitchOneHot = tf.tidy(() => {
      const pitch = Tone.Frequency("C4").toMidi();
      const pitchIdx = tf.tensor1d([pitch - MIN_MIDI_PITCH], "int32");
      return tf.oneHot(pitchIdx, MIDI_PITCHES).tile([BATCH_SIZE, 1]);
    });

    // Assumes NUM_INTERPS is a multiple of BATCH_SIZE; otherwise the last
    // slice would run past the end of zs — TODO confirm in ./constants.
    for (let batch = 0; batch < NUM_INTERPS / BATCH_SIZE; batch++) {
      const steps = predictBatch(gansynth, zs, pitchOneHot, batch);
      try {
        for (let j = 0; j < steps.length; j++) {
          const segment = await renderStep(gansynth, steps[j]);
          combinedArr.set(segment, (batch * BATCH_SIZE + j) * segmentLength);
        }
      } finally {
        steps.forEach(step => step.dispose());
      }
      // Wait for the next animation frame between batches so the page can
      // repaint.
      await tf.nextFrame();
    }
  } finally {
    zs.dispose();
    if (pitchOneHot) {
      pitchOneHot.dispose();
    }
  }

  // Disabled export: writes combinedArr as a mono 32-bit float WAV at
  // PLAYBACK_SR (presumably how the pregenerated audio_*.wav files were made).
  // let wav = new window.WaveFile();
  // wav.fromScratch(1, PLAYBACK_SR, "32f", combinedArr);
  // downloadBlob(wav.toBuffer(), `audio_${z0.name}_${z1.name}.wav`, "audio/wav");

  return combinedArr;
}

/**
 * Runs GANSynth on batch `batchIdx` of the stacked latents `zs`, conditioned
 * on `pitchOneHot`. Returns one spectrogram tensor per step, each with a
 * leading axis of size 1. Every intermediate tensor is freed by tidy(); the
 * caller owns (and must dispose) the returned tensors.
 */
function predictBatch(gansynth, zs, pitchOneHot, batchIdx) {
  return tf.tidy(() => {
    // Tensor.slice takes (begin, size): BATCH_SIZE rows starting at this batch.
    const batchZs = zs.slice(batchIdx * BATCH_SIZE, BATCH_SIZE);
    const cond = tf
      .concat([batchZs, pitchOneHot], 1)
      .expandDims(1)
      .expandDims(1);
    return gansynth
      .predict(cond, BATCH_SIZE)
      .unstack()
      .map(t => t.expandDims(0));
  });
}

/**
 * Converts one spectrogram tensor into peak-normalized mono audio at
 * PLAYBACK_SR, resampling from GENERATOR_SR when the rates differ.
 * Returns the first channel of the resulting buffer.
 */
async function renderStep(gansynth, specgram) {
  const audio = await gansynth.specgramsToAudio(specgram);
  peakNormalize(audio);
  const audioBuffer = Tone.context.createBuffer(
    1,
    T * GENERATOR_SR,
    GENERATOR_SR
  );
  audioBuffer.copyToChannel(audio, 0, 0);
  const playbackBuffer =
    PLAYBACK_SR !== GENERATOR_SR
      ? await resample(audioBuffer, PLAYBACK_SR)
      : audioBuffer;
  return playbackBuffer.getChannelData(0);
}

// function downloadBlob(data, fileName, mimeType) {
//   var blob, url;
//   blob = new Blob([data], {
//     type: mimeType
//   });
//   url = window.URL.createObjectURL(blob);
//   downloadURL(url, fileName);
//   setTimeout(function() {
//     return window.URL.revokeObjectURL(url);
//   }, 1000);
// }

// function downloadURL(data, fileName) {
//   var a;
//   a = document.createElement("a");
//   a.href = data;
//   a.download = fileName;
//   document.body.appendChild(a);
//   a.style = "display: none";
//   a.click();
//   a.remove();
// }

/**
 * Scales `arr` in place so its largest absolute sample is at most
 * `targetPeak`. Signals already within [-targetPeak, targetPeak] are left
 * untouched; quiet audio is never amplified.
 *
 * @param {Float32Array|number[]} arr - audio samples; mutated in place.
 * @param {number} [targetPeak=0.7] - maximum absolute amplitude allowed.
 */
function peakNormalize(arr, targetPeak = 0.7) {
  let peakHigh = 0;
  let peakLow = 0;
  for (const sample of arr) {
    peakHigh = Math.max(peakHigh, sample);
    peakLow = Math.min(peakLow, sample);
  }
  if (peakHigh > targetPeak || peakLow < -targetPeak) {
    const ratio = Math.max(peakHigh, -peakLow) / targetPeak;
    for (let i = 0; i < arr.length; i++) {
      arr[i] /= ratio;
    }
  }
}
