import json
from pathlib import Path
from dataclasses import dataclass

import tyro
import pandas as pd
import matplotlib.pyplot as plt


@dataclass
class Config:
    """Command-line options; ``main`` parses them with ``tyro.cli``."""

    # When True, results are read from ./stats_ablation (plus the
    # "*_digit.json" runs in ./stats as a baseline) instead of ./stats.
    use_ablation: bool = False


def _load_stat_row(path, dset, model, loss, seed):
    """Read one metrics JSON file and flatten it into a table record.

    Args:
        path: ``Path`` of the JSON file; it must hold "rmse" and "mae" keys.
        dset: Dataset label parsed from the file name.
        model: Model label to store in the record.
        loss: Loss label to store in the record.
        seed: Seed label parsed from the file name (kept as a string).

    Returns:
        A dict with the keys "data", "loss", "model", "seed", "rmse", "mae".
    """
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    return {
        "data": dset,
        "loss": loss,
        "model": model,
        "seed": seed,
        "rmse": data["rmse"],
        "mae": data["mae"],
    }


def main():
    """Aggregate per-run metrics into a LaTeX table and RMSE/MAE plots.

    File names are split on "_":
      * normal mode:   ``<seed>_<data>_<model>_<loss...>.json`` in ./stats
      * ablation mode: ``<seed>_<data>_<model>_<...>_<ablation>.json`` in
        ./stats_ablation, where model and ablation are joined into one label
        that is used as both "model" and "loss". Files whose ablation tag
        contains "no-split" are skipped, and the ``*_digit.json`` runs from
        ./stats are added under the label "base".

    Only the data splits train1..train10 are kept. The table is printed and
    written to ``<root>/table.txt``; the plots are produced by ``draw``.

    Raises:
        FileNotFoundError: if no metrics file was collected at all.
        ValueError: if a file name has too few "_"-separated parts.
    """
    args = tyro.cli(Config)
    root = Path("./stats_ablation" if args.use_ablation else "./stats")

    stats = []
    for path in root.glob("*.json"):
        if args.use_ablation:
            seed, dset, model, *_, ablation = path.stem.split("_")
            if "no-split" in ablation:
                # Excluded variant: skip before touching the file contents.
                continue
            model = f"{model}_{ablation}"
            # In ablation mode the series are distinguished by variant, so the
            # combined label doubles as the "loss" used for grouping/plotting.
            loss = model
        else:
            seed, dset, model, *loss_parts = path.stem.split("_")
            loss = "_".join(loss_parts)
        stats.append(_load_stat_row(path, dset, model, loss, seed))

    if args.use_ablation:
        # Add the non-ablated "digit" runs as the baseline series.
        for path in Path("./stats").glob("*_digit.json"):
            seed, dset, _model, *_ = path.stem.split("_")
            stats.append(_load_stat_row(path, dset, "base", "base", seed))

    if not stats:
        # Without this guard the column lookup below fails with an opaque
        # KeyError: 'data' on the empty DataFrame.
        raise FileNotFoundError(
            f"no stats JSON files were collected (looked in {root})"
        )

    df = pd.DataFrame(stats)

    # Keep only the train1..train10 data splits.
    filtered_df = df[df["data"].isin([f"train{i}" for i in range(1, 11)])]

    group_keys = ["data", "model", "loss"]

    # Mean across seeds (the string "seed" column is dropped by numeric_only).
    mean_df = filtered_df.groupby(group_keys, as_index=False).mean(
        numeric_only=True
    )

    # Mean and standard deviation across seeds for the table.
    agg_df = filtered_df.groupby(group_keys, as_index=False).agg(
        rmse_mean=("rmse", "mean"),
        rmse_std=("rmse", "std"),
        mae_mean=("mae", "mean"),
        mae_std=("mae", "std"),
    )

    latex_table = agg_df.to_latex(
        index=False,
        header=[
            "data",
            "model",
            "loss",
            "rmse_mean",
            "rmse_std",
            "mae_mean",
            "mae_std",
        ],
        float_format="%.3f",
    )

    draw(root, mean_df, 22, 20, 16)

    print(latex_table)
    (root / "table.txt").write_text(latex_table, encoding="utf-8")


def _plot_metric(
    root,
    df,
    losses,
    colors,
    metric,
    filename,
    fontsize,
    fontsize_small,
    fontsize_tiny,
):
    """Plot *metric* against data scale, one line per loss, and save as SVG.

    Args:
        root: Output directory (``Path``).
        df: Frame with "loss", "data_scale" and *metric* columns.
        losses: Loss labels to plot, in legend order.
        colors: Colour palette; reused cyclically if there are more losses.
        metric: Column to plot ("rmse" or "mae"); its upper-cased name is
            used as the plot title.
        filename: File name of the SVG written under *root*.
        fontsize, fontsize_small, fontsize_tiny: Title, axis-label and
            tick/legend font sizes.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    for i, loss_type in enumerate(losses):
        group = df[df["loss"] == loss_type].sort_values(by="data_scale")
        ax.plot(
            group["data_scale"],
            group[metric],
            label=loss_type,
            marker="o",
            linewidth=2,
            markersize=6,
            # Cycle so that more series than palette entries cannot raise.
            color=colors[i % len(colors)],
        )

    ax.set_xlabel("Training Data", fontsize=fontsize_small, fontweight="bold")
    ax.set_title(metric.upper(), fontsize=fontsize, fontweight="bold")
    ax.invert_yaxis()  # smaller error is drawn higher up
    ax.legend(
        fontsize=fontsize_tiny,
        loc="lower right",
    )
    ax.grid(False)
    ax.tick_params(axis="both", labelsize=fontsize_tiny)
    fig.tight_layout()
    fig.savefig(str(root / filename))
    plt.close(fig)


def draw(
    root, df, fontsize: int = 18, fontsize_small: int = 16, fontsize_tiny: int = 14
):
    """Save RMSE and MAE line plots (one line per loss) under *root*.

    Writes ``rmse_plot.svg`` and ``mae_plot.svg``. The x-axis is the integer
    found in the "data" column (e.g. "train3" -> 3); the y-axis is inverted.

    Args:
        root: Output directory (``Path``).
        df: Frame with "data", "loss", "rmse" and "mae" columns, one row per
            (data, loss) point. It is not modified.
        fontsize: Title font size.
        fontsize_small: Axis-label font size.
        fontsize_tiny: Tick-label and legend font size.
    """
    # Work on a copy so the caller's frame is left untouched.
    df = df.copy()
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    df["data_scale"] = df["data"].str.extract(r"(\d+)", expand=False).astype(int)
    # Map 'digit_base' to 'vocab' for legend simplification
    df["loss"] = df["loss"].replace({"digit_base": "vocab"})

    colors = ["#4B6A94", "#B2182B", "#1A936F", "#1B1464"]  # Blue, Red, Green, Purple

    losses = df.loss.unique()
    if "sft" in losses:
        # Fixed legend order for the main (non-ablation) comparison.
        losses = ["sft", "vocab", "digit"]

    for metric in ("rmse", "mae"):
        _plot_metric(
            root,
            df,
            losses,
            colors,
            metric,
            f"{metric}_plot.svg",
            fontsize,
            fontsize_small,
            fontsize_tiny,
        )


# Run the aggregation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
