import { Parameter, Provider, ComponentType } from "@/queries/models/types";
import { PromptBlueprint, PromptVersion, SchemaDefinition } from "@/types";

import { PromptData } from "./types";
import { DEFAULT_MODEL_NAMES, HUGGINGFACE_PROVIDER } from "./constants";

/**
 * A resolved model-parameter value as stored in a blueprint's parameters
 * object: a scalar, a flat key/value record, or a `{ type, json_schema }`
 * pair carrying a schema definition.
 */
export type ParameterValue =
  | boolean
  | number
  | string
  | Record<string, string | number | boolean>
  | { type: string; json_schema: SchemaDefinition };

/** Resolved parameter values keyed by their `param_id`. */
export type ParametersObject = Record<string, ParameterValue>;

/**
 * Converts a numeric-looking string to a number. Every other value, including
 * non-numeric strings and non-string values, is returned unchanged.
 *
 * Blank or whitespace-only strings are deliberately left as-is:
 * `Number("")` and `Number("   ")` both evaluate to `0`. Without the explicit
 * check, an emptied text field would silently become the number 0.
 *
 * Note: anything `Number()` accepts is converted, e.g. `"1e3"` -> 1000,
 * `"0x10"` -> 16, `"Infinity"` -> Infinity, `" 42 "` -> 42.
 *
 * @param value The raw parameter value.
 * @returns The numeric form of `value` when it is a non-blank numeric string,
 *   otherwise `value` itself.
 */
export const convertToNumber = (value: ParameterValue): ParameterValue => {
  if (typeof value === "string" && value.trim() !== "") {
    const numeric = Number(value);
    if (!isNaN(numeric)) {
      return numeric;
    }
  }

  return value;
};

/**
 * Resolves one model parameter to the value stored in a blueprint's
 * parameters object.
 *
 * - GROUP parameters are collapsed, recursively, into an object keyed by
 *   each enabled nested parameter's `param_id`. Nested parameters that
 *   resolve to `undefined` are omitted.
 * - Every other parameter yields `value`, falling back to `default` when
 *   `value` is null/undefined. The result is passed through
 *   `convertToNumber`.
 *
 * @param param The parameter definition to resolve.
 * @returns The resolved value, or `undefined` when there is nothing to emit.
 *   This happens for a GROUP without a `parameters` list, or for a leaf whose
 *   `value ?? default` is `undefined`. Callers must skip `undefined` results;
 *   the return type states this instead of hiding it behind a cast.
 */
const processParameter = (param: Parameter): ParameterValue | undefined => {
  if (param.component_type === ComponentType.GROUP) {
    // `?.` produces undefined when the group has no nested parameter list.
    // NOTE(review): a nested GROUP resolves to an object, which the
    // `Record<string, string | number | boolean>` member of ParameterValue does
    // not strictly allow; the cast below still papers over that.
    return param.parameters?.reduce(
      (acc: ParametersObject, nestedParam: Parameter) => {
        if (!nestedParam.is_disabled) {
          const processedValue = processParameter(nestedParam);
          if (processedValue !== undefined) {
            acc[nestedParam.param_id] = processedValue;
          }
        }
        return acc;
      },
      {},
    ) as ParameterValue | undefined;
  }

  return convertToNumber(param.value ?? param.default);
};

/**
 * Converts editor-side prompt data into the shape stored on a blueprint.
 * The model's parameter list is flattened into a `param_id -> value` object;
 * disabled parameters and parameters that resolve to `undefined` are left out.
 *
 * `inference_client_name` and `provider_base_url_name` are normalized so that
 * any falsy value becomes `null`. Only these two fields and `metadata` appear
 * in the result; other top-level fields of `promptData` are not copied.
 *
 * @param promptData Prompt data whose `metadata.model.parameters` is a list
 *   of parameter definitions.
 * @returns The converted data. `promptData` is returned unchanged when it has
 *   no model parameters, including when it is `undefined`.
 */
export const convertToBlueprintModel = (
  promptData: PromptData | undefined,
): PromptData | undefined => {
  if (!promptData?.metadata?.model?.parameters) return promptData;

  // Reducer step: store an enabled parameter's resolved value under its
  // param_id. Disabled parameters, and ones resolving to undefined, are skipped.
  const collectEnabledParam = (
    collected: ParametersObject,
    param: Parameter,
  ): ParametersObject => {
    if (param.is_disabled) return collected;

    const resolvedValue = processParameter(param);
    if (resolvedValue !== undefined) {
      collected[param.param_id] = resolvedValue;
    }
    return collected;
  };

  const emptyParameters: ParametersObject = {};
  const parametersObject = promptData.metadata.model.parameters.reduce(
    collectEnabledParam,
    emptyParameters,
  );

  const blueprintModel = {
    ...promptData.metadata.model,
    parameters: parametersObject,
  };

  return {
    inference_client_name: promptData.inference_client_name || null,
    provider_base_url_name: promptData.provider_base_url_name || null,
    metadata: { ...promptData.metadata, model: blueprintModel },
  };
};

/**
 * Merges saved model parameters into the list of parameter definitions
 * available for the model. The saved parameters come from a blueprint or
 * version and are read as a flat `param_id -> value` object.
 *
 * For each entry in `existingParams`:
 * - if its `param_id` is not among the saved parameters, it is marked disabled;
 * - otherwise the saved value replaces `value`, and the parameter is enabled.
 *   The exception is a saved value of `undefined`: `value` and `is_disabled`
 *   are then kept as they were.
 *
 * Saved keys that match no existing parameter are appended as enabled custom
 * INPUT parameters with a `null` default.
 *
 * @param initialData Saved blueprint/version data.
 * @param existingParams Parameter definitions for the model. Defaults to none.
 * @returns `initialData` with `metadata.model.parameters` replaced by the
 *   merged list, or `undefined` when `initialData` has no `metadata.model`.
 */
export const mapInitialDataWithExistingParams = (
  initialData?: Partial<PromptBlueprint> | Partial<PromptVersion>,
  existingParams: Parameter[] = [],
): PromptData | undefined => {
  if (!initialData?.metadata?.model) return undefined;

  const model = initialData.metadata.model;
  // Saved values not yet claimed by an existing parameter definition.
  const unclaimed = new Map(Object.entries(model.parameters || {}));

  const mergedExisting = existingParams.map((param) => {
    const id = param.param_id;
    const isSaved = unclaimed.has(id);
    const savedValue = unclaimed.get(id);
    // Claim the key: this stops it being re-emitted as a custom parameter,
    // and a later duplicate param_id is treated as unsaved.
    unclaimed.delete(id);

    return {
      ...param,
      is_disabled: !isSaved || (savedValue === undefined && param.is_disabled),
      value: savedValue === undefined ? param.value : savedValue,
    };
  });

  const customParams = Array.from(unclaimed, ([key, value]) => ({
    param_id: key,
    name: key,
    value,
    default: null,
    is_disabled: false,
    is_custom: true,
    component_type: ComponentType.INPUT,
  }));

  return {
    ...initialData,
    metadata: {
      ...initialData.metadata,
      model: {
        ...model,
        parameters: [...mergedExisting, ...customParams],
      },
    },
  };
};

/**
 * Builds the fallback model selection. The provider is always OpenAI: chat
 * prompts get `DEFAULT_MODEL_NAMES.openai`, everything else gets
 * `"gpt-3.5-turbo-instruct"`.
 *
 * @param apiModels Provider/model catalogue used to look up the model's
 *   parameter definitions.
 * @param isPromptChat Whether the prompt is a chat prompt. A falsy or omitted
 *   value selects the completion model.
 * @returns `{ name, provider, parameters }`. `parameters` is empty when the
 *   model is not found in `apiModels`.
 */
export const getDefaultModel = (
  apiModels: Provider[],
  isPromptChat?: boolean,
) => {
  const provider = "openai";
  const name = isPromptChat
    ? DEFAULT_MODEL_NAMES[provider]
    : "gpt-3.5-turbo-instruct";

  return {
    name,
    provider,
    parameters: getModelParams(apiModels, provider, name),
  };
};

/**
 * Returns the first Hugging Face model config whose chat/completion kind
 * matches the prompt type.
 *
 * @param models Provider/model catalogue to search.
 * @param isPromptChat Whether a chat model is wanted. When omitted it is
 *   treated as `false` (a completion model), the same way `getDefaultModel`
 *   treats this flag. Comparing the raw `undefined` against `is_chat` would
 *   never match a model whose flag is a boolean.
 * @returns The matching config, or `undefined` if the Hugging Face provider
 *   or a matching model is not listed.
 */
export const getHuggingFaceModel = (
  models: Provider[],
  isPromptChat?: boolean,
) => {
  const wantChat = isPromptChat ?? false;

  return models
    .find((p) => p.provider_name === HUGGINGFACE_PROVIDER)
    ?.model_configs.find((m) => m.is_chat === wantChat);
};

/**
 * Looks up a model configuration by provider and model name.
 *
 * Only the first provider whose `provider_name` equals `provider` is
 * searched.
 *
 * @returns The first config in that provider whose `llm_model_name` equals
 *   `name`, or `undefined` if the provider or the model is not found.
 */
export const getModel = (
  models: Provider[],
  provider: string,
  name: string,
) => {
  const providerEntry = models.find(
    ({ provider_name }) => provider_name === provider,
  );
  if (!providerEntry) return undefined;

  return providerEntry.model_configs.find(
    ({ llm_model_name }) => llm_model_name === name,
  );
};

/**
 * Returns the parameter definitions of model `name` offered by `provider`.
 *
 * @returns The model config's `params`, or an empty array in three cases: an
 *   argument is missing or empty, the model cannot be found, or its `params`
 *   is falsy.
 */
export const getModelParams = (
  apiModels?: Provider[],
  provider?: string,
  name?: string,
): Parameter[] => {
  if (apiModels && provider && name) {
    const modelConfig = getModel(apiModels, provider, name);
    return modelConfig?.params || [];
  }

  return [];
};
