📄 src/lib/birdnet.worker.ts
import * as tf from "@tensorflow/tfjs";
import type { Prediction, WorkerInMessage, WorkerOutMessage } from "./types";

// URL paths of the model and label assets fetched by init().
const MODEL_PATH = "/models/birdnet/model.json";
const LABELS_PATH = "/models/birdnet/labels/en_us.txt";
// Number of audio samples per analysis window; analyze() reshapes its input to
// [1, WINDOW_SAMPLES] and init() warms the model up with a window of zeros.
// NOTE(review): presumably 3 s of audio at 48 kHz — confirm against the
// audio capture pipeline.
const WINDOW_SAMPLES = 144000;
// Exponent scale used by the score pooling in analyze() (exp(ALPHA * p)
// followed by log(...) / ALPHA).
const ALPHA = 5.0;

/** One species entry, parsed by init() from a `scientific_common` label line. */
interface BirdLabel {
  scientificName: string;
  commonName: string;
}

// Loaded model, set by init(); analyze() returns immediately while it is null.
let model: tf.LayersModel | null = null;
// Species labels; analyze() pairs labels[i] with the model's i-th output score.
let labels: BirdLabel[] = [];

// ── Custom MelSpecLayerSimple ──────────────────────────────────────────────
// Ported from https://github.com/birdnet-team/real-time-pwa (MIT)
/**
 * Custom Keras layer that turns a batch of raw waveforms into spectrograms.
 * init() registers it with tf.serialization.registerClass before loading the
 * model (presumably the model.json refers to it by `className`).
 *
 * For each example in the batch: min-max scale the waveform to [-1, 1], run
 * the custom "STFT" kernel, project onto the mel filterbank and square,
 * compress with the learned exponent 1 / (1 + exp(magnitude_scaling)), flip
 * the last axis and transpose, producing [specShape[0], specShape[1], 1]
 * (presumably mel bins × frames — confirm against the model config).
 */
class MelSpecLayerSimple extends tf.layers.Layer {
  sampleRate: number;
  specShape: number[];
  frameStep: number;
  frameLength: number;
  melFilterbank: tf.Tensor2D;
  magScale!: tf.LayerVariable;

  /**
   * @param config serialized layer config; read keys are sampleRate,
   *   specShape, frameStep, frameLength and melFilterbank (a 2-D array).
   */
  constructor(config: Record<string, unknown>) {
    super(config as tf.serialization.ConfigDict);
    const { sampleRate, specShape, frameStep, frameLength, melFilterbank } =
      config;
    this.sampleRate = sampleRate as number;
    this.specShape = specShape as number[];
    this.frameStep = frameStep as number;
    this.frameLength = frameLength as number;
    this.melFilterbank = tf.tensor2d(melFilterbank as number[][]);
  }

  /** Creates the single trainable scalar that controls magnitude compression. */
  build(inputShape: tf.Shape | tf.Shape[]) {
    const initializer = tf.initializers.constant({ value: 1.23 });
    this.magScale = this.addWeight(
      "magnitude_scaling",
      [],
      "float32",
      initializer,
    );
    super.build(inputShape);
  }

  computeOutputShape(inputShape: tf.Shape): tf.Shape {
    const batchSize = inputShape[0];
    return [batchSize, this.specShape[0], this.specShape[1], 1];
  }

  call(inputs: tf.Tensor | tf.Tensor[]): tf.Tensor {
    return tf.tidy(() => {
      const batch = Array.isArray(inputs) ? inputs[0] : inputs;
      const examples = batch.split(batch.shape[0]);
      const spectrograms = examples.map((example) =>
        this.exampleToSpectrogram(example),
      );
      return tf.stack(spectrograms);
    });
  }

  /** Converts one [1, samples] slice of the batch into a single spectrogram. */
  private exampleToSpectrogram(example: tf.Tensor): tf.Tensor {
    const wave = example.squeeze();
    // Min-max scale to [-1, 1]; the epsilon keeps a constant (e.g. silent)
    // signal, whose shifted max is 0, from dividing by zero.
    const shifted = tf.sub(wave, tf.min(wave, -1, true));
    const unit = tf.div(shifted, tf.max(shifted, -1, true).add(1e-6));
    const centered = tf.sub(unit, 0.5).mul(2.0);
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const stft = (tf.engine() as any).runKernel("STFT", {
      signal: centered,
      frameLength: this.frameLength,
      frameStep: this.frameStep,
    }) as tf.Tensor2D;
    const melPower = tf.matMul(stft, this.melFilterbank).pow(2.0);
    const exponent = tf.div(1.0, tf.add(1.0, tf.exp(this.magScale.read())));
    const compressed = melPower.pow(exponent);
    const flipped = tf.reverse(compressed, -1);
    return tf.transpose(flipped).expandDims(-1);
  }

  static get className() {
    return "MelSpecLayerSimple";
  }
}

// ── Custom STFT WebGL kernel ───────────────────────────────────────────────
// Ported from https://github.com/birdnet-team/real-time-pwa (MIT)
/**
 * Registers a WebGL implementation of the custom "STFT" kernel that
 * MelSpecLayerSimple invokes through `tf.engine().runKernel("STFT", …)`.
 *
 * Input: a 1-D signal plus `frameLength` / `frameStep` numbers, all passed in
 * `inputs`. The signal is cut into `numFrames` whole frames of `frameLength`
 * samples spaced `frameStep` apart (no end padding). Each frame gets a periodic
 * Hann window, and the shaders compute the real part of its DFT with a radix-2
 * FFT over frameLength/2 packed complex values.
 *
 * Output: a float32 tensor of shape [numFrames, frameLength / 2 + 1], one row
 * per frame and one column per frequency bin. This is the 2-D shape the layer's
 * `tf.matMul` with the mel filterbank requires.
 *
 * Requirements: `frameLength` must be a power of two (>= 2) for the radix-2
 * FFT. The shaders use integer bitwise operators, so presumably this needs a
 * WebGL 2 (GLSL ES 3.00) context.
 *
 * @throws Error (when the kernel runs) on an invalid frameLength/frameStep, a
 *   non-1-D signal, or a signal shorter than one frame.
 */
function registerStftKernel() {
  // Minimal typed view of the tfjs WebGL backend methods used below.
  interface WebGLProgramSpec {
    variableNames: string[];
    outputShape: number[];
    userCode: string;
  }
  interface WebGLBackendLike {
    runWebGLProgram(
      program: WebGLProgramSpec,
      inputs: tf.TensorInfo[],
      outputDtype: tf.DataType,
    ): tf.TensorInfo;
    disposeIntermediateTensorInfo(info: tf.TensorInfo): void;
  }
  // The caller passes the frame parameters through `inputs` next to the signal.
  interface StftInputs {
    signal: tf.TensorInfo;
    frameLength: number;
    frameStep: number;
  }

  // `inputs` carries plain numbers (see StftInputs), which tf's KernelFunc
  // parameter type does not allow, hence the `any`.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const kernelFunc: tf.KernelFunc = (params: any) => {
    const webgl = params.backend as WebGLBackendLike;
    const { signal, frameLength: fl, frameStep: fs } =
      params.inputs as StftInputs;

    const innerDim = fl / 2;
    if (
      !Number.isInteger(innerDim) ||
      innerDim < 1 ||
      (innerDim & (innerDim - 1)) !== 0
    ) {
      throw new Error(`STFT: frameLength must be a power of two >= 2, got ${fl}`);
    }
    if (!Number.isInteger(fs) || fs < 1) {
      throw new Error(`STFT: frameStep must be a positive integer, got ${fs}`);
    }
    if (signal.shape.length !== 1) {
      throw new Error(
        `STFT: expected a 1-D signal, got shape [${signal.shape.join(", ")}]`,
      );
    }
    const log2Inner = Math.round(Math.log2(innerDim));
    const numSamples = signal.shape.reduce((acc, dim) => acc * dim, 1);
    if (numSamples < fl) {
      throw new Error(
        `STFT: signal has ${numSamples} samples, fewer than frameLength ${fl}`,
      );
    }
    // Every output needs this leading frame dimension: the shaders address
    // (frame, index) via ivec2 coordinates.
    const numFrames = Math.floor((numSamples - fl) / fs) + 1;

    // Stage 1: read frame c[0], apply the Hann window, and store the samples
    // in bit-reversed order. Even samples (real parts of the packed complex
    // sequence) fill the first half of the row, odd samples (imaginary parts)
    // the second half. The window phase uses the in-frame index i; that equals
    // the absolute-index form modulo 2π but avoids large cos() arguments.
    const windowed = webgl.runWebGLProgram(
      {
        variableNames: ["x"],
        outputShape: [numFrames, fl],
        userCode: `void main(){
          ivec2 c=getOutputCoords();
          int p=c[1]%${innerDim};
          int k=0;
          for(int i=0;i<${log2Inner};++i){
            if((p & (1<<i))!=0){ k|=(1<<(${log2Inner - 1}-i)); }
          }
          int i=2*k;
          if(c[1]>=${innerDim}){ i=2*(k%${innerDim})+1; }
          float val=getX(c[0]*${fs}+i);
          float mul=0.5-0.5*cos(${(2.0 * Math.PI) / fl}*float(i));
          setOutput(val*mul);
        }`,
      },
      [signal],
      "float32",
    );

    // Stage 2: radix-2 butterfly passes of the length-innerDim complex FFT.
    // Each row holds real parts in [0, innerDim) and imaginary parts in
    // [innerDim, 2*innerDim). Each pass's input is freed once consumed.
    let spectrum = windowed;
    for (let len = 1; len < innerDim; len *= 2) {
      const next = webgl.runWebGLProgram(
        {
          variableNames: ["x"],
          outputShape: [numFrames, innerDim * 2],
          userCode: `void main(){
            ivec2 c=getOutputCoords();
            int b=c[0];
            int i=c[1];
            int k=i%${innerDim};
            int isHigh=(k%${len * 2})/${len};
            int highSign=(1 - isHigh*2);
            int baseIndex=k - isHigh*${len};
            float t=${Math.PI / len}*float(k%${len});
            float a=cos(t);
            float bsin=sin(-t);
            float oddK_re=getX(b, baseIndex+${len});
            float oddK_im=getX(b, baseIndex+${len + innerDim});
            if(i<${innerDim}){
              float evenK_re=getX(b, baseIndex);
              setOutput(evenK_re + (oddK_re*a - oddK_im*bsin)*float(highSign));
            } else {
              float evenK_im=getX(b, baseIndex+${innerDim});
              setOutput(evenK_im + (oddK_re*bsin + oddK_im*a)*float(highSign));
            }
          }`,
        },
        [spectrum],
        "float32",
      );
      webgl.disposeIntermediateTensorInfo(spectrum);
      spectrum = next;
    }

    // Stage 3: recover the real part of the length-frameLength real DFT,
    // bins 0..innerDim inclusive, from the half-length complex FFT.
    const real = webgl.runWebGLProgram(
      {
        variableNames: ["x"],
        outputShape: [numFrames, innerDim + 1],
        userCode: `void main(){
          ivec2 c=getOutputCoords();
          int b=c[0];
          int i=c[1];
          int zI=i%${innerDim};
          int conjI=(${innerDim}-i)%${innerDim};
          float Zk0=getX(b,zI);
          float Zk1=getX(b,zI+${innerDim});
          float Zk_conj0=getX(b,conjI);
          float Zk_conj1=-getX(b,conjI+${innerDim});
          float t=${-2.0 * Math.PI}*float(i)/float(${innerDim * 2});
          float diff0=Zk0 - Zk_conj0;
          float diff1=Zk1 - Zk_conj1;
          float result=(Zk0+Zk_conj0 + cos(t)*diff1 + sin(t)*diff0)*0.5;
          setOutput(result);
        }`,
      },
      [spectrum],
      "float32",
    );
    webgl.disposeIntermediateTensorInfo(spectrum);
    return real;
  };
  tf.registerKernel({ kernelName: "STFT", backendName: "webgl", kernelFunc });
}

// ── Init ───────────────────────────────────────────────────────────────────

/**
 * Splits the labels file into one BirdLabel per non-empty line.
 *
 * Each line is `<scientific name>_<common name>`, split on the first
 * underscore. A line without an underscore is used for both names. A trailing
 * "\r" (CRLF files) is stripped so it does not end up in the common name. Only
 * empty lines are dropped, so label indices stay aligned with the model's
 * output columns exactly as before.
 */
function parseLabels(text: string): BirdLabel[] {
  return text
    .split("\n")
    .filter(Boolean)
    .map((rawLine) => {
      const line = rawLine.endsWith("\r") ? rawLine.slice(0, -1) : rawLine;
      const idx = line.indexOf("_");
      return idx >= 0
        ? { scientificName: line.slice(0, idx), commonName: line.slice(idx + 1) }
        : { scientificName: line, commonName: line };
    });
}

/**
 * Prepares the worker, then posts `{ type: "ready" }`. Any failure is posted as
 * `{ type: "error", message }` instead.
 *
 * Steps:
 * 1. Select the WebGL backend.
 * 2. Register the STFT kernel and the MelSpecLayerSimple class, both needed
 *    before the model is loaded.
 * 3. Fetch and parse the labels.
 * 4. Load the model.
 * 5. Run one prediction on a window of zeros, presumably so first-use costs
 *    (e.g. WebGL shader compilation) are paid up front.
 *
 * `model` is only published once the warm-up prediction succeeds, so analyze()
 * never runs a model whose initialization failed.
 */
async function init() {
  try {
    // setBackend resolves to false (rather than rejecting) when the backend
    // cannot be initialized. The STFT kernel below exists only for WebGL.
    const backendReady = await tf.setBackend("webgl");
    if (!backendReady) {
      throw new Error("Failed to initialize the TensorFlow.js WebGL backend");
    }
    registerStftKernel();
    tf.serialization.registerClass(MelSpecLayerSimple);

    const response = await fetch(LABELS_PATH);
    if (!response.ok) {
      // Without this check an HTML error page would be parsed as labels.
      throw new Error(
        `Failed to fetch ${LABELS_PATH}: HTTP ${response.status}`,
      );
    }
    labels = parseLabels(await response.text());

    const loaded = await tf.loadLayersModel(MODEL_PATH);
    console.log(
      "BirdNET model inputs:",
      loaded.inputs.map((t) => t.shape),
      "outputs:",
      loaded.outputs.map((t) => t.shape),
    );

    // tidy disposes both the zeros input and the discarded prediction.
    tf.tidy(() => {
      loaded.predict(tf.zeros([1, WINDOW_SAMPLES]));
    });
    model = loaded;

    const out: WorkerOutMessage = { type: "ready" };
    self.postMessage(out);
  } catch (err) {
    const out: WorkerOutMessage = {
      type: "error",
      message: String(err),
    };
    self.postMessage(out);
  }
}

// ── Analyze ────────────────────────────────────────────────────────────────

/**
 * Pools per-class scores across rows with a log-mean-exp of sharpness `alpha`:
 * (1 / alpha) * log(mean_r exp(alpha * rows[r][i])).
 *
 * The sum is taken relative to each column's maximum, so exp() cannot
 * overflow. For a single row the result is exactly that row. Rows are assumed
 * rectangular (they come from one tensor).
 */
function logMeanExpPool(rows: number[][], alpha: number): number[] {
  const numClasses = rows[0]?.length ?? 0;
  const pooled: number[] = [];
  for (let i = 0; i < numClasses; i++) {
    const column = rows.map((row) => row[i] as number);
    const peak = Math.max(...column);
    if (!Number.isFinite(peak)) {
      // ±Infinity / NaN would turn (v - peak) into NaN; pass them through.
      pooled.push(peak);
      continue;
    }
    let sumExp = 0;
    for (const value of column) sumExp += Math.exp(alpha * (value - peak));
    pooled.push(peak + Math.log(sumExp / column.length) / alpha);
  }
  return pooled;
}

/**
 * Scores one analysis window and posts the best-matching species.
 *
 * @param samples audio for one window; must contain exactly WINDOW_SAMPLES
 *   values, otherwise an error message is posted.
 *
 * Posts `{ type: "results", predictions }` with at most 10 entries whose pooled
 * confidence is above 0.1, sorted by descending confidence. On failure it posts
 * `{ type: "error", message }`. Returns without posting anything if no model
 * has been loaded yet.
 */
async function analyze(samples: Float32Array) {
  const activeModel = model;
  if (!activeModel) return;
  try {
    if (samples.length !== WINDOW_SAMPLES) {
      throw new Error(
        `Expected ${WINDOW_SAMPLES} samples, got ${samples.length}`,
      );
    }
    // tidy frees the input tensor even if predict throws; only the returned
    // prediction survives the scope.
    const resTensor = tf.tidy(
      () =>
        activeModel.predict(
          tf.tensor2d(samples, [1, WINDOW_SAMPLES]),
        ) as tf.Tensor,
    );
    let rawPreds: number[][];
    try {
      rawPreds = (await resTensor.array()) as number[][];
    } finally {
      resTensor.dispose();
    }

    // With a batch of one there is a single row and pooling returns it as is.
    const pooled = logMeanExpPool(rawPreds, ALPHA);

    const predictions: Prediction[] = pooled
      .map((confidence, i) => ({
        confidence,
        commonName: labels[i]?.commonName ?? `Species ${i}`,
        scientificName: labels[i]?.scientificName ?? "",
      }))
      .filter((p) => p.confidence > 0.1)
      .sort((a, b) => b.confidence - a.confidence)
      .slice(0, 10);

    const out: WorkerOutMessage = { type: "results", predictions };
    self.postMessage(out);
  } catch (err) {
    const out: WorkerOutMessage = { type: "error", message: String(err) };
    self.postMessage(out);
  }
}

// Dispatch messages from the main thread: "init" prepares the model, "analyze"
// scores one window of samples. Both handlers post their own errors, so their
// promises are intentionally not awaited (marked with `void`).
self.addEventListener("message", (event: MessageEvent<WorkerInMessage>) => {
  const { data } = event;
  switch (data.type) {
    case "init":
      void init();
      break;
    case "analyze":
      void analyze(data.samples);
      break;
  }
});