import { CompletionModel } from "docuchatcommontypes";
import type { ModelCardOption } from "~/types/input";
import IconOpenAI from "~/components/Icon/OpenAI.vue";
import IconMistral from "~/components/Icon/Mistral.vue";
import IconAnthropic from "~/components/Icon/Anthropic.vue";
import IconMeta from "~/components/Icon/Meta.vue";

/**
 * Builds the complete catalogue of completion models, before any plan,
 * data-residency or Azure filtering is applied.
 *
 * OpenAI models are reported as "Azure OpenAI" and flagged EU-resident when the
 * user's plan requires the Azure OpenAI integration; otherwise they are reported
 * as "OpenAI" and flagged as not EU-resident.
 *
 * Each model's `pricing` is read from the plan limits and falls back to 1 when
 * the price for that model, or the whole pricing table, is not available.
 *
 * Source for grades: https://artificialanalysis.ai/leaderboards/models
 */
function getAllAIModels(): ModelCardOption[] {
  const { plan } = useUserStore();
  const requiresAzureOpenAiIntegration = plan?.requiresAzureOpenAiIntegration ?? false;
  const { modelPricing } = useLimits();

  // Shared by every OpenAI entry: provider name and residency flag depend on the plan.
  const openAiCompany = requiresAzureOpenAiIntegration ? "Azure OpenAI" : "OpenAI";
  const openAiEuResident = !!requiresAzureOpenAiIntegration;

  // `modelPricing.value` can be unset (callers such as getModelDetailsTable check
  // for it), so use optional chaining instead of throwing on a missing table.
  const priceOf = (id: CompletionModel) => modelPricing.value?.[id] ?? 1;

  return [
    {
      id: CompletionModel.Gpt4oMini,
      icon: IconOpenAI,
      company: openAiCompany,
      label: mapCompletionModel[CompletionModel.Gpt4oMini],
      pricing: priceOf(CompletionModel.Gpt4oMini),
      grades: {
        intelligence: 4,
        speed: 3,
        costEfficiency: 5,
      },
      euResident: openAiEuResident,
    },
    {
      id: CompletionModel.Gpt4,
      icon: IconOpenAI,
      company: openAiCompany,
      label: mapCompletionModel[CompletionModel.Gpt4],
      pricing: priceOf(CompletionModel.Gpt4),
      grades: {
        intelligence: 4,
        speed: 1,
        costEfficiency: 1,
      },
      euResident: openAiEuResident,
    },
    {
      id: CompletionModel.Gpt4o,
      icon: IconOpenAI,
      company: openAiCompany,
      label: mapCompletionModel[CompletionModel.Gpt4o],
      pricing: priceOf(CompletionModel.Gpt4o),
      grades: {
        intelligence: 5,
        speed: 2,
        costEfficiency: 3,
      },
      euResident: openAiEuResident,
    },
    {
      id: CompletionModel.MistralLarge,
      icon: IconMistral,
      company: "Mistral",
      label: mapCompletionModel[CompletionModel.MistralLarge],
      pricing: priceOf(CompletionModel.MistralLarge),
      grades: {
        intelligence: 4,
        speed: 1,
        costEfficiency: 3,
      },
      euResident: true,
    },
    {
      id: CompletionModel.Claude3Haiku,
      icon: IconAnthropic,
      company: "Anthropic",
      label: mapCompletionModel[CompletionModel.Claude3Haiku],
      pricing: priceOf(CompletionModel.Claude3Haiku),
      grades: {
        intelligence: 1,
        speed: 3,
        costEfficiency: 5,
      },
      euResident: true,
    },
    {
      id: CompletionModel.Claude3_5Sonnet,
      icon: IconAnthropic,
      company: "Anthropic",
      label: mapCompletionModel[CompletionModel.Claude3_5Sonnet],
      pricing: priceOf(CompletionModel.Claude3_5Sonnet),
      grades: {
        intelligence: 5,
        speed: 2,
        costEfficiency: 3,
      },
      euResident: true,
    },
    {
      id: CompletionModel.LLama3M,
      icon: IconMeta,
      company: "Meta",
      label: mapCompletionModel[CompletionModel.LLama3M],
      pricing: priceOf(CompletionModel.LLama3M),
      grades: {
        intelligence: 3,
        speed: 5,
        costEfficiency: 5,
      },
      euResident: false,
    },
  ];
}

/**
 * Lists the AI models the current user can actually select, sorted by label.
 *
 * Starting from the full catalogue, a model is dropped when:
 * - the organization enforces data residency and the model is not EU-resident;
 * - the user's plan does not include the model;
 * - the plan requires the Azure OpenAI integration and the model's company is not "Azure OpenAI".
 */
export function getUsableAIModels() {
  const { organization, plan } = useUserStore();
  const residencyEnforced = organization?.dataResidencyEnforced ?? false;
  const azureOnly = plan?.requiresAzureOpenAiIntegration ?? false;
  const { modelsAllowed } = useLimits();

  const candidates: ModelCardOption[] = getAllAIModels();

  const usable = candidates.filter((model) => {
    // Non-EU (US-hosted) models are excluded under enforced data residency.
    if (residencyEnforced && !model.euResident)
      return false;

    // The plan must explicitly allow the model.
    if (!modelsAllowed.value.includes(model.id as CompletionModel))
      return false;

    // When the Azure integration is mandatory, only Azure OpenAI models survive.
    return !azureOnly || model.company === "Azure OpenAI";
  });

  return usable.sort((left, right) => left.label!.localeCompare(right.label!));
}

/**
 * Returns every catalogued model that the user's plan does not allow.
 * Data-residency and Azure restrictions are deliberately not applied here.
 */
function getUnallowedAIModels() {
  const { modelsAllowed } = useLimits();

  return getAllAIModels().filter((candidate) => {
    const permitted = modelsAllowed.value.includes(candidate.id as CompletionModel);
    return !permitted;
  });
}

/**
 * Picks the recommended default model: the usable model with the highest
 * `calculateModelScore`. On a tie, the model that comes first in the
 * label-sorted list returned by `getUsableAIModels` wins.
 *
 * @returns The best-scoring usable model.
 * @throws Error when the current user has no usable model at all.
 */
export function getDefaultAiModel() {
  const models = getUsableAIModels();

  // Single O(n) scan instead of sorting; strict ">" keeps the earliest model on
  // ties, which matches what a stable descending sort would put first.
  let best: ModelCardOption | undefined;
  let bestScore = -Infinity;
  for (const model of models) {
    const score = calculateModelScore(model);
    if (score > bestScore) {
      best = model;
      bestScore = score;
    }
  }

  // Previously this surfaced as "cannot read properties of undefined"; fail with a clear message instead.
  if (!best)
    throw new Error("getDefaultAiModel: no usable AI model is available for the current user");

  return best;
}

/**
 * Weighted average of a model's grades, on the same scale as the grades.
 *
 * Weights: intelligence 6, cost efficiency 3, speed 1. A missing grade counts as 0.
 */
export function calculateModelScore(model: ModelCardOption) {
  const grades = model.grades;

  // [grade, weight] pairs, kept in speed → intelligence → cost-efficiency order.
  const weightedGrades: Array<[number, number]> = [
    [grades?.speed ?? 0, 1],
    [grades?.intelligence ?? 0, 6],
    [grades?.costEfficiency ?? 0, 3],
  ];

  let weightedSum = 0;
  let totalWeight = 0;
  for (const [grade, weight] of weightedGrades) {
    weightedSum += grade * weight;
    totalWeight += weight;
  }

  return weightedSum / totalWeight;
}

/**
 * Renders a grade as a five-cell bar of filled and empty squares.
 * Missing or non-finite grades render as all-empty. Other grades are truncated
 * and clamped to 0–5, so `String#repeat` never receives a negative count
 * (which would throw a RangeError).
 */
function renderGradeBar(grade: number | undefined): string {
  const maxGrade = 5;
  const raw = grade ?? 0;
  const filled = Number.isFinite(raw) ? Math.min(Math.max(Math.trunc(raw), 0), maxGrade) : 0;
  return "◼︎".repeat(filled) + "◻︎".repeat(maxGrade - filled);
}

/**
 * Builds a localized markdown table describing every model usable by the
 * current user: provider, label, credit price, grade bars and hosting region.
 *
 * It appends a data-residency note when the organization enforces it, and lists
 * the models the plan does not allow, if any.
 *
 * @returns The markdown table, or the localized generic error message when pricing
 *   is unavailable or no model is usable.
 */
export function getModelDetailsTable(): string {
  const t = useNuxtApp().$i18n.t;
  const { modelPricing } = useLimits();
  const { isAdmin, organization } = useUserStore();
  const enforceDataResidency = organization?.dataResidencyEnforced ?? false;

  // Check pricing before building the model list: building the catalogue reads
  // pricing entries, so this guard must come first to be effective.
  const pricing = modelPricing.value;
  if (!pricing)
    return t("common.genericError");

  const models = getUsableAIModels();
  if (models.length === 0)
    return t("common.genericError");

  let markdownTable = `| ${t("base.aiModels.table.headers.provider")} | ${t("base.aiModels.table.headers.model")} | ${t("base.aiModels.table.headers.credits")} | ${t("base.aiModels.labels.intelligence")} | ${t("base.aiModels.labels.speed")} | ${t("base.aiModels.labels.costEfficiency")} | ${t("base.aiModels.labels.hostedIn")} |\n`;
  markdownTable += "|----------|-------|------------------|--------------|-------|-----------------|-----------|\n";

  for (const model of models) {
    const price = pricing[model.id as CompletionModel];
    const intelligence = renderGradeBar(model.grades?.intelligence);
    const speed = renderGradeBar(model.grades?.speed);
    const costEfficiency = renderGradeBar(model.grades?.costEfficiency);
    const hostedIn = model.euResident ? t("base.aiModels.eu") : t("base.aiModels.us");

    markdownTable += `| ${model.company} | **${model.label}** | ${price} | ${intelligence} | ${speed} | ${costEfficiency} | ${hostedIn} |\n`;
  }

  if (enforceDataResidency)
    markdownTable += `\n${isAdmin ? t("base.aiModels.table.message.forAdmin") : t("base.aiModels.table.message.forOthers")} ${t("base.aiModels.table.message.generic")}`;

  // Compute the unallowed list once instead of rebuilding the catalogue twice.
  const unallowedModels = getUnallowedAIModels();
  if (unallowedModels.length > 0)
    markdownTable += `\n${t("base.aiModels.table.message.unallowed")} ${unallowedModels.map(m => m.label).join(", ")}.`;

  return markdownTable;
}
