from pathlib import Path
from typing import Optional, Union

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from benchmark import Experiment
from evaluation import aggregate
from jsonargparse import ArgumentParser
from pydantic import BaseModel
from rich.console import Console
from rich.table import Table
from rich.text import Text
from tqdm import tqdm

from utilities import get_root

# Global plot styling, applied once when this module is loaded: matplotlib
# text uses font size 16 in Times New Roman, and seaborn uses its "whitegrid"
# style with the "pastel" colour palette.
plt.rcParams.update({"font.size": 16, "font.family": "Times New Roman"})
sns.set_style("whitegrid")
sns.set_palette("pastel")


class ResultSettings(BaseModel):
    """Switches that control what the reporting script does with the results."""

    # Print the summary counts and the aggregated table to the console.
    write: bool = True
    # Save one `value_match` histogram per group under `path / "value_match"`.
    plot: bool = False
    # Save the aggregated table as `results.csv` and `results.md` under `path`.
    save: bool = False
    # Dump the un-aggregated dataframe to `results-raw.csv` under `path`.
    save_raw: bool = False
    # Output directory; the script's save and plot steps fall back to
    # `get_root() / "plots"` when this is left unset.
    path: Optional[Path] = None


def metric_at_k(n: int, c: int, k: int) -> float:
    """Return the chance that ``k`` samples drawn from ``n`` include a correct one.

    The value is ``1 - prod(1 - k / i for i in n - c + 1 .. n)``, which equals
    ``1 - C(n - c, k) / C(n, k)``: one minus the probability that all ``k``
    samples, drawn without replacement, come from the ``n - c`` incorrect ones.

    Args:
        n: Number of samples.
        c: Number of correct samples.
        k: Number of samples drawn.

    Returns:
        ``0.0`` when there is no correct sample, ``1.0`` when fewer than ``k``
        samples are incorrect, otherwise the product formula above.
    """
    # Checked first so that c == 0 yields 0.0 even when n < k.
    if c == 0:
        return 0.0

    wrong = n - c
    if k > wrong:
        # Not enough incorrect samples to fill a draw of size k.
        return 1.0

    all_wrong = 1.0
    for pool_size in range(wrong + 1, n + 1):
        all_wrong = all_wrong * (1.0 - k / pool_size)
    return 1.0 - all_wrong


def table(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-solution results into one row per experiment configuration.

    Rows are grouped by dataset, solver name, model name, input kind and input
    which. Each group reports the number of rows (``count``) and the mean of
    the ``compiles``, ``executes``, ``value_match`` and ``exact_match`` columns.

    Args:
        df: One row per solution, containing the grouping and metric columns.

    Returns:
        The aggregated frame, indexed by the grouping columns and ordered by
        ``input-kind``, then ``input-which``, then ``name``.
    """
    keys = ["dataset", "name", "model-name", "input-kind", "input-which"]
    metrics = ("compiles", "executes", "value_match", "exact_match")

    # Named aggregations: output column -> (source column, reducer).
    spec = {"count": ("compiles", "count")}
    spec.update({metric: (metric, "mean") for metric in metrics})

    summary = df.groupby(keys).agg(**spec)
    return summary.sort_values(by=["input-kind", "input-which", "name"])


def write_table(df: pd.DataFrame) -> str:
    """Print the aggregated results to the console as a rich table.

    Metric columns are right-aligned and shown with three decimals. A metric
    cell is rendered in bold when it equals the highest value of that metric
    among all rows of the same model. A section divider is inserted whenever
    ``input-kind`` changes from one row to the next.

    Args:
        df: One row per solution, as expected by ``table``.

    Returns:
        Always the empty string; the table is printed by this function itself.
    """
    metrics = ("compiles", "executes", "value_match", "exact_match")
    summary = table(df).reset_index()
    columns = list(summary.columns)

    # Declare the columns; only the numeric metrics are right-aligned.
    rendered = Table()
    for column in columns:
        if column in metrics:
            rendered.add_column(column, justify="right")
        else:
            rendered.add_column(column)

    # Highest value of every metric, per model.
    best = {
        model: {
            metric: summary.loc[summary["model-name"] == model, metric].max()
            for metric in metrics
        }
        for model in summary["model-name"].unique()
    }

    # Emit the rows, starting a new section each time the input kind changes.
    last_kind = None
    for _, record in summary.iterrows():
        kind = record["input-kind"]
        if kind != last_kind:
            if last_kind is not None:
                rendered.add_section()
            last_kind = kind

        top = best[record["model-name"]]
        cells = []
        for column in columns:
            value = record[column]
            if column not in metrics:
                cells.append(str(value))
                continue
            formatted = f"{value:.3f}"
            if value == top[column]:
                cells.append(Text(formatted, style="bold"))
            else:
                cells.append(formatted)
        rendered.add_row(*cells)

    Console().print(rendered)

    return ""


def save_table(df: pd.DataFrame, path: Path) -> None:
    """Write the aggregated results table to ``path`` in a suffix-chosen format.

    Supported suffixes are ``.csv``, ``.tex`` and ``.md``; the Markdown output
    is written with the grouping keys as regular columns. Any other suffix
    writes nothing, although the parent directory is still created.

    Args:
        df: One row per solution, as expected by ``table``.
        path: Destination file; missing parent directories are created.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    summary = table(df)

    # Suffix -> writer for the aggregated frame.
    writers = {
        ".csv": lambda frame: frame.to_csv(path),
        ".tex": lambda frame: frame.to_latex(path),
        ".md": lambda frame: frame.reset_index().to_markdown(path),
    }
    writer = writers.get(path.suffix)
    if writer is not None:
        writer(summary)


def plot_distribution(df: pd.DataFrame, path: Path, metric: str) -> None:
    """Save one histogram of ``metric`` per (model, input kind, input which).

    Each group is written to ``path / "<model>-<kind>-<which>.png"`` with the
    file stem lower-cased. The histogram uses bins of width 0.1, shows
    percentages, and fixes the x axis to the range 0..1.

    Args:
        df: Per-solution results containing the columns ``model-name``,
            ``input-kind``, ``input-which`` and ``metric``.
        path: Output directory; created (with parents) if it does not exist.
        metric: Name of the column whose distribution is plotted.
    """
    path.mkdir(parents=True, exist_ok=True)

    def plot_group(data: pd.DataFrame) -> None:
        # All rows of a group share the same key values, so the first row
        # is enough to build the file name.
        name = f"{data['model-name'].iloc[0]}-{data['input-kind'].iloc[0]}-{data['input-which'].iloc[0]}".lower()
        path_out = path / f"{name}.png"
        figure = plt.figure()
        try:
            # histplot draws on the current axes, i.e. on the new figure.
            axes = sns.histplot(data=data, x=metric, binwidth=0.1, stat="percent")
            axes.set(xlim=(0, 1), ylim=(0, None), title=name)
            figure.savefig(path_out)
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this, one figure per group would accumulate in memory.
            plt.close(figure)

    df.groupby(["model-name", "input-kind", "input-which"])[
        ["model-name", "input-kind", "input-which", metric]
    ].apply(plot_group)


if __name__ == "__main__":
    # Script entry point: load experiment result files, flatten them into a
    # dataframe with one row per solution, then print / save / plot the
    # results according to the parsed ResultSettings.

    parser = ArgumentParser()
    parser.add_argument("--results", type=Union[Path, list[Path]], default=None)
    parser.add_argument("--settings", type=ResultSettings, default=ResultSettings())
    args = parser.parse_args()
    args_init = parser.instantiate_classes(args)
    # Without --results, fall back to the "results" directory under the root.
    args.results = args.results or [get_root() / Path("results")]

    settings: ResultSettings = args_init.settings
    # Resolve the output directory once, before any step writes to it. The raw
    # dump used to read settings.path before it had been defaulted and failed
    # on `None / "results-raw.csv"`.
    if settings.save_raw or settings.save or settings.plot:
        settings.path = settings.path or (get_root() / "plots")

    # Expand the given paths: a directory contributes all nested *.json files,
    # anything else is taken as a result file itself.
    results = []
    if isinstance(args.results, Path):
        args.results = [args.results]
    for result in args.results:
        if result.is_dir():
            results.extend(result.rglob("*.json"))
        else:
            results.append(result)

    # load dataframe
    records = []
    for file in tqdm(results):
        # tqdm.write prints the file name without breaking the progress bar,
        # which a bare print() inside the loop would do.
        tqdm.write(str(file))
        data = Experiment.model_validate_json(file.read_text(encoding="utf-8"))
        # Prefix input and model fields so they cannot clash with each other
        # or with the solver fields once merged into a single row.
        base_input = {
            f"input-{k}": v for k, v in data.input.model_dump(mode="json").items()
        }
        base_model = {
            f"model-{k}": v for k, v in data.model.model_dump(mode="json").items()
        }
        base = {
            **data.solver,
            **base_input,
            **base_model,
        }
        for solution in data.solutions:
            # Solutions without predictions have nothing to aggregate.
            if len(solution.predictions) == 0:
                continue
            solution_metrics = [p.metrics for p in solution.predictions]
            solution_aggregated = {
                k.name: v for k, v in aggregate(solution_metrics).items()
            }
            records.append(
                {
                    "identifier": solution.identifier,
                    "dataset": data.dataset,
                    **base,
                    **solution_aggregated,
                }
            )
    df = pd.DataFrame.from_records(records)

    if settings.save_raw:
        # to_csv does not create missing directories.
        settings.path.mkdir(parents=True, exist_ok=True)
        df.to_csv(settings.path / "results-raw.csv", index=False)

    if df.empty:
        # With no records the frame has no columns, so every step below would
        # fail with a KeyError; stop cleanly instead.
        print("Found 0 solutions; nothing to report.")
        raise SystemExit(0)

    if settings.write:
        print("\nℹ️")
        print("Found", len(df), "solutions")
        print("  on", len(df["model-name"].unique()), "models,")
        print("  using", len(df["name"].unique()), "solvers.")
        print("\n📈")
        # write_table prints the table itself and returns an empty string.
        print(write_table(df))

    if settings.save:
        save_table(df, settings.path / "results.csv")
        save_table(df, settings.path / "results.md")

    if settings.plot:
        plot_distribution(df, settings.path / "value_match", "value_match")
