import { ReportColumnStats } from "@/types/evaluate";
import {
  BarElement,
  CategoryScale,
  Chart as ChartJS,
  Legend,
  LinearScale,
  Title,
  Tooltip,
} from "chart.js";
import "chartjs-plugin-datalabels";
import ChartDataLabels from "chartjs-plugin-datalabels";
import moment from "moment";
import React from "react";
import { Bar } from "react-chartjs-2";

// Register the Chart.js components and plugins used by the <Bar> histograms in
// this module. This is a module-level call, so it runs once when the module is
// first imported.
// NOTE(review): ChartDataLabels is registered globally through ChartJS.register,
// which presumably turns data labels on for every Chart.js chart in the app, not
// just these histograms — confirm that is intended (the plugin can also be
// supplied per chart instead).
ChartJS.register(
  CategoryScale, // x-axis: histogram bin labels are category strings
  LinearScale, // y-axis: request counts per bin
  BarElement, // the histogram bars themselves
  Title,
  Tooltip, // hover tooltips (callbacks configured in the histogram options)
  Legend, // registered, though the histogram options set `legend.display: false`
  ChartDataLabels, // per-bar count labels drawn above each bar
);

/**
 * Per-request cost (`request_price`) of every report column stat, in order.
 *
 * @param reportColumnStats - Stats for each evaluated request; may be
 *   null/undefined while data is unavailable.
 * @returns One cost per stat, or `[]` when no stats list is present.
 */
// `?.` + `??` (rather than `map(...) || []`) so the empty fallback actually
// applies: `map` always returns an array, so a trailing `|| []` never fires.
const getAllCosts = (
  reportColumnStats: ReportColumnStats | null | undefined,
): number[] =>
  reportColumnStats?.map((reportColumnStat) => reportColumnStat.request_price) ??
  [];

/**
 * Per-request latency (`request_latency`) of every report column stat.
 *
 * @param reportColumnStats - Stats for each evaluated request; may be
 *   null/undefined.
 * @returns One latency per stat, or `[]` when no stats list is present.
 */
const getAllDurations = (
  reportColumnStats: ReportColumnStats | null | undefined,
): number[] =>
  reportColumnStats?.map(
    (reportColumnStat) => reportColumnStat.request_latency,
  ) ?? [];

/**
 * Output-token count (`request_output_tokens`) of every report column stat.
 *
 * @param reportColumnStats - Stats for each evaluated request; may be
 *   null/undefined.
 * @returns One output-token count per stat, or `[]` when no stats list is
 *   present.
 */
const getAllOutputTokens = (
  reportColumnStats: ReportColumnStats | null | undefined,
): number[] =>
  reportColumnStats?.map(
    (reportColumnStat) => reportColumnStat.request_output_tokens,
  ) ?? [];

/**
 * Total token count (input + output) of every report column stat.
 *
 * @param reportColumnStats - Stats for each evaluated request; may be
 *   null/undefined.
 * @returns One `request_input_tokens + request_output_tokens` sum per stat,
 *   or `[]` when no stats list is present.
 */
const getAllTotalTokens = (
  reportColumnStats: ReportColumnStats | null | undefined,
): number[] =>
  reportColumnStats?.map(
    (reportColumnStat) =>
      reportColumnStat.request_input_tokens +
      reportColumnStat.request_output_tokens,
  ) ?? [];

/** Number of equal-width buckets each histogram is split into. */
const HISTOGRAM_BIN_COUNT = 10;

/**
 * Cost histogram labels are expressed in units of $0.0001 (matching the
 * "Cost ($0.0001)" graph label); the cost tooltip converts them back to $.
 */
const COST_LABEL_SCALE = 10000;

/** Tailwind blue-500 at 50% opacity, used for both bar fill and border. */
const BAR_COLOR = "rgba(59, 130, 246, 0.5)";

/** Which tooltip flavour a stat card's histogram uses. */
type GraphVariant = "default" | "cost" | "token";

/** Sum of all values (0 for an empty list). */
const sum = (values: number[]): number => values.reduce((a, b) => a + b, 0);

/**
 * Formats `total / count` with `format`, or returns "-" when there are no
 * samples, so an empty report never renders "NaN".
 */
const formatAverage = (
  total: number,
  count: number,
  format: (average: number) => string,
): string => (count > 0 ? format(total / count) : "-");

/** Formats a value as a histogram category label; costs are shown in $0.0001 units. */
const formatBinLabel = (value: number, isCost: boolean): string =>
  (isCost ? value * COST_LABEL_SCALE : value).toFixed(1);

/**
 * Buckets `data` into HISTOGRAM_BIN_COUNT equal-width bins spanning
 * [min, max], labelling each bin with its midpoint.
 *
 * - Non-finite entries (NaN, ±Infinity) are skipped so they cannot corrupt
 *   the bin range.
 * - Empty input yields no bins (instead of NaN labels from an infinite range).
 * - When every value is identical there is no range to split, so a single
 *   bin labelled with that exact value is returned; a fixed fallback width
 *   would mislabel tiny values such as per-request costs.
 * - min/max are found with `reduce` rather than `Math.min(...data)`, which
 *   can throw a RangeError for very large arrays.
 */
const computeBins = (
  data: number[],
  isCost: boolean,
): { labels: string[]; counts: number[] } => {
  const values = data.filter((value) => Number.isFinite(value));
  if (values.length === 0) {
    return { labels: [], counts: [] };
  }

  const min = values.reduce((lo, value) => (value < lo ? value : lo), Infinity);
  const max = values.reduce((hi, value) => (value > hi ? value : hi), -Infinity);
  if (min === max) {
    return { labels: [formatBinLabel(min, isCost)], counts: [values.length] };
  }

  const binWidth = (max - min) / HISTOGRAM_BIN_COUNT;
  const counts = Array.from({ length: HISTOGRAM_BIN_COUNT }, () => 0);
  for (const value of values) {
    // Clamp so `max` itself falls into the last bin rather than one past it.
    const index = Math.min(
      Math.floor((value - min) / binWidth),
      HISTOGRAM_BIN_COUNT - 1,
    );
    counts[index] = (counts[index] ?? 0) + 1;
  }

  const labels = counts.map((_, i) =>
    formatBinLabel(min + (i + 0.5) * binWidth, isCost),
  );
  return { labels, counts };
};

/** Builds the Chart.js `data` object for a single-dataset histogram of `data`. */
const getHistogramData = (data: number[], isCost: boolean) => {
  const { labels, counts } = computeBins(data, isCost);
  return {
    labels,
    datasets: [
      {
        data: counts,
        backgroundColor: BAR_COLOR,
        borderColor: BAR_COLOR,
        borderWidth: 1,
      },
    ],
  };
};

// Options are module-level constants so <Bar> receives the same objects on
// every render instead of freshly built ones.
const histogramOptions: any = {
  responsive: true,
  animation: {
    duration: 300,
  },
  maintainAspectRatio: false,
  plugins: {
    legend: { display: false },
    datalabels: {
      display: true,
      align: "end",
      anchor: "end",
      color: (context: any) => context.dataset.backgroundColor,
      // Hide the label on empty bins instead of printing "0".
      formatter: (value: number) => value || "",
    },
    tooltip: {
      animation: {
        duration: 100,
      },
      callbacks: {
        title: (context: any) => `${context[0].label} seconds`,
        label: (context: any) => `Request count: ${context.parsed.y}`,
      },
    },
  },
  scales: {
    x: {
      grid: {
        display: false,
      },
    },
    y: {
      title: {
        display: true,
        text: "Frequency",
        font: { size: 12 },
        color: "rgba(75, 85, 99)", // Tailwind gray-600
      },
      grid: {
        borderColor: "rgba(244, 245, 247)",
        borderWidth: 1,
        borderDash: [3],
      },
      beginAtZero: true,
      ticks: {
        // Counts are integers: returning undefined hides fractional ticks.
        callback: (value: number) =>
          Math.floor(value) === value ? value : undefined,
      },
    },
  },
};

/** Copy of `histogramOptions` with only the tooltip title callback replaced. */
const withTooltipTitle = (title: (context: any) => string): any => ({
  ...histogramOptions,
  plugins: {
    ...histogramOptions.plugins,
    tooltip: {
      ...histogramOptions.plugins.tooltip,
      callbacks: { ...histogramOptions.plugins.tooltip.callbacks, title },
    },
  },
});

const histogramOptionsByVariant: Record<GraphVariant, any> = {
  default: histogramOptions,
  // Labels are in $0.0001 units; convert back to dollars for the tooltip.
  cost: withTooltipTitle(
    (context: any) =>
      `$${(parseFloat(context[0].label) / COST_LABEL_SCALE).toFixed(4)}`,
  ),
  token: withTooltipTitle((context: any) => `${context[0].label} tokens`),
};

/**
 * Renders one stat card: a headline value with its per-request average on
 * the left and a small histogram of the raw values on the right.
 */
const renderStatCard = (
  graphLabel: string,
  title: string,
  data: number[],
  label: string,
  subLabel: string,
  variant: GraphVariant = "default",
) => (
  <div
    className="flex flex-col rounded-md border border-gray-200 p-4"
    key={title}
  >
    <div className="flex-1 text-sm text-gray-500">{title}</div>
    <div className="flex">
      <div className="flex flex-1 flex-col text-2xl font-semibold 2xl:text-3xl">
        <div className="flex-1">{label}</div>
        <div className="text-sm font-medium text-gray-500">
          Average {subLabel}
        </div>
      </div>
      <div>
        <h4 className="2xl:text-md mb-2 text-center text-sm font-semibold">
          {graphLabel}
        </h4>
        <div className="flex items-center justify-center">
          <div className="flex h-[120px] w-[150px] items-center justify-center 2xl:h-[120px] 2xl:w-[180px]">
            <Bar
              data={getHistogramData(data, variant === "cost")}
              options={histogramOptionsByVariant[variant]}
            />
          </div>
        </div>
      </div>
    </div>
  </div>
);

/** Two-decimal, locale-aware formatting used for average token counts. */
const formatTokenAverage = (average: number): string =>
  average.toLocaleString(undefined, {
    minimumFractionDigits: 2,
    maximumFractionDigits: 2,
  });

/**
 * Summary cards (total, per-request average, and histogram) for latency,
 * cost, output tokens, and total tokens across all report column stats.
 * An empty stats list renders "-" placeholders and empty histograms.
 */
const StatisticGraphs: React.FC<{
  reportColumnStats: ReportColumnStats;
}> = ({ reportColumnStats }) => {
  const allDurations = getAllDurations(reportColumnStats);
  const allCosts = getAllCosts(reportColumnStats);
  const allOutputTokens = getAllOutputTokens(reportColumnStats);
  const allTotalTokens = getAllTotalTokens(reportColumnStats);

  const totalLatency = sum(allDurations);
  const totalCost = sum(allCosts);
  const totalOutputTokens = sum(allOutputTokens);
  const totalTotalTokens = sum(allTotalTokens);

  return (
    <div className="mx-auto mt-4 flex w-full flex-col gap-4 px-4 2xl:px-0">
      {renderStatCard(
        "Latency (s)",
        "Total compute time",
        allDurations,
        totalLatency
          ? moment.duration(totalLatency, "seconds").humanize()
          : "-",
        formatAverage(
          totalLatency,
          allDurations.length,
          (average) => `${average.toFixed(2)}s`,
        ),
      )}
      {renderStatCard(
        "Cost ($0.0001)",
        "Total cost",
        allCosts,
        totalCost ? `$${totalCost.toFixed(4)}` : "-",
        formatAverage(
          totalCost,
          allCosts.length,
          (average) => `$${average.toFixed(4)}`,
        ),
        "cost",
      )}
      {renderStatCard(
        "Output Tokens",
        "Total output tokens",
        allOutputTokens,
        totalOutputTokens ? `${totalOutputTokens.toLocaleString()}` : "-",
        formatAverage(
          totalOutputTokens,
          allOutputTokens.length,
          (average) => `${formatTokenAverage(average)} output tokens`,
        ),
        "token",
      )}
      {renderStatCard(
        "Total Tokens",
        "Total request tokens",
        allTotalTokens,
        totalTotalTokens ? `${totalTotalTokens.toLocaleString()}` : "-",
        formatAverage(
          totalTotalTokens,
          allTotalTokens.length,
          (average) => `${formatTokenAverage(average)} tokens`,
        ),
        "token",
      )}
    </div>
  );
};

export default StatisticGraphs;
