import React, { SetStateAction } from "react";
import { download } from "./download";
import * as ort from "onnxruntime-web";

/**
 * Options accepted by `createInferenceSession`.
 */
type createInferenceSessionType = {
  /**
   * Loading-state setter. Its value may be `true`, `false`, or `null`.
   * `createInferenceSession` itself does not call it; `loadYoloV8` forwards it to `warmup`.
   */
  setLoading: (value: SetStateAction<boolean | null>) => void;
  /** File name of the ONNX model; it is fetched from `/model/<modelName>`. */
  modelName: string;
};

/**
 * Options accepted by `warmup`.
 */
type warmupType = {
  /**
   * Session to warm up. Only its `run` method is used.
   * NOTE(review): presumably always an `ort.InferenceSession` — consider narrowing this type.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- kept loose for existing callers
  model: any;
  /** Dimensions of the zero-filled dummy tensor fed to the model during warm-up. */
  modelInputShape: number[];
  /** Loading-state setter. `warmup` sets it to `true` before running. */
  setLoading: (value: SetStateAction<boolean | null>) => void;
};

/**
 * Fetches the ONNX model at `/model/<modelName>` using the project's `download` helper.
 * It then builds an onnxruntime-web `InferenceSession` from the downloaded bytes.
 *
 * `setLoading` is part of the options type but is not used by this function.
 *
 * @returns A promise resolving to the created inference session.
 * @throws If the download or the session creation rejects.
 */
export const createInferenceSession = async ({
  modelName,
}: createInferenceSessionType) => {
  // Defensive guard. A namespace import is always an object, so in practice this never fires.
  if (!ort) {
    throw new Error("onnxruntime-web could not be initialized");
  }

  const modelUrl = `/model/${modelName}`;
  const modelBytes = await download(modelUrl);

  // NOTE(review): `download`'s return type is not visible here; the cast mirrors the original.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  return ort.InferenceSession.create(modelBytes as any);
};

/**
 * Number of elements in a tensor of the given shape, i.e. the product of its dims.
 * An empty shape denotes a scalar and yields 1.
 */
const tensorElementCount = (shape: readonly number[]): number =>
  shape.reduce((count, dim) => count * dim, 1);

/**
 * Runs a single inference on an all-zero float32 tensor of shape `modelInputShape`.
 * Presumably this pays the model's one-time first-run cost before real inputs arrive.
 *
 * Calls `setLoading(true)` first. This function never resets it, so clearing the
 * loading state is left to the caller.
 *
 * @param setLoading      Loading-state setter; set to `true` before the run.
 * @param modelInputShape Dimensions of the dummy input tensor.
 * @param model           Session whose `run` method is invoked once.
 * @throws Whatever `model.run` rejects with, e.g. a shape or input-name mismatch.
 */
export const warmup = async ({
  setLoading,
  modelInputShape,
  model,
}: warmupType) => {
  setLoading(true);

  // Float32Array(n) is zero-initialised.
  // Initial value 1 makes an empty (scalar) shape allocate one element instead of
  // throwing "Reduce of empty array with no initial value".
  const tensor = new ort.Tensor(
    "float32",
    new Float32Array(tensorElementCount(modelInputShape)),
    modelInputShape,
  );

  // The feed key is hard-coded. NOTE(review): presumably the exported YOLOv8
  // model's input name — confirm it matches `model.inputNames`.
  await model.run({ images: tensor });
};

/**
 * Options accepted by `loadYoloV8`.
 * They are the session-creation options plus the shape of the warm-up input.
 */
type loadYoloV8Type = {
  /** Dimensions of the dummy tensor used to warm the model up after loading. */
  modelInputShape: number[];
} & createInferenceSessionType;

/**
 * Downloads and creates the inference session for `modelName`.
 * It then performs one warm-up run with a zero tensor of `modelInputShape` before returning.
 *
 * @param modelName       Model file name, fetched from `/model/<modelName>`.
 * @param setLoading      Loading-state setter, forwarded to session creation and warm-up.
 * @param modelInputShape Dimensions of the warm-up input tensor.
 * @returns The session, after its warm-up run has completed.
 * @throws If the download, session creation, or the warm-up run fails.
 */
export const loadYoloV8 = async ({
  modelName,
  setLoading,
  modelInputShape,
}: loadYoloV8Type) => {
  const model = await createInferenceSession({ modelName, setLoading });

  // Awaited for two reasons:
  // - a failing warm-up rejects this call instead of becoming an unhandled rejection;
  // - the caller never receives the session (and starts its own runs or loading-state
  //   updates) while the warm-up run is still in flight.
  await warmup({ setLoading, modelInputShape, model });

  return model;
};

/**
 * Settings for the `yolo_cell_best.onnx` YOLOv8 model.
 *
 * - `modelName`: file fetched from `/model/` when loading.
 * - `modelInputShape`: `[1, 3, 640, 640]`, presumably NCHW (batch, channels, height, width) — confirm.
 * - `topk`, `iouThreshold`, `scoreThreshold`: not used in this file.
 *   NOTE(review): presumably consumed by detection post-processing (e.g. NMS) elsewhere.
 */
export const yoloV8CellConfig: {
  modelName: string;
  modelInputShape: number[];
  topk: number;
  iouThreshold: number;
  scoreThreshold: number;
} = {
  modelName: "yolo_cell_best.onnx",
  modelInputShape: [1, 3, 640, 640],
  topk: 300,
  iouThreshold: 0.5,
  scoreThreshold: 0.25,
};
