import json
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from argparse import ArgumentParser
import os

def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Every non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (for example an extra empty line at the end of
    the file) are skipped rather than raising ``json.JSONDecodeError``.

    Args:
        file_path: Path (``str`` or ``os.PathLike``) to a UTF-8 encoded
            JSONL file.

    Returns:
        list: One parsed JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            # json.loads("") raises, so an empty line must not reach it.
            if stripped:
                records.append(json.loads(stripped))
    return records


def setup_axis(ax, xlabel="Training Steps", ylabel="", title="", add_grid=True):
    """Apply the shared label, grid and spine styling to an axis in place.

    Args:
        ax: Axes-like object exposing ``set_xlabel``, ``set_ylabel``,
            ``set_title``, ``grid`` and a ``spines`` mapping.
        xlabel: Text for the x-axis label.
        ylabel: Text for the y-axis label.
        title: Text for the axis title.
        add_grid: If true, draw a faint (alpha 0.2), thin (0.5) grid.

    Returns:
        The same ``ax`` object, so the call can be chained.
    """
    labelers = (
        (ax.set_xlabel, xlabel),
        (ax.set_ylabel, ylabel),
        (ax.set_title, title),
    )
    for apply_text, text in labelers:
        apply_text(text)

    if add_grid:
        ax.grid(True, alpha=0.2, linewidth=0.5)

    # Hide the top and right frame lines; the left/bottom ones stay.
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    return ax


if __name__ == "__main__":
    # Script entry point: for the chosen model size, plot token-loss-difference
    # (TLD) and position-specific loss curves for the "clean" and
    # "masked bigram" runs. The curves are read from
    # results/{model_type}_tld{sample_type}.jsonl and written as PDF and PNG
    # files to results/tld_plots/.

    # Result-file prefixes of the two compared runs, per model size.
    # Order matters: index 0 is the clean run, index 1 the masked-bigram run.
    model_types_by_size = {
        "160m": ["clean_v3", "masked_bigram_loss_v4"],
        "1b": ["clean_1b", "masked_bigram_loss_1b"],
    }

    parser = ArgumentParser()
    parser.add_argument(
        "--model_size",
        type=str,
        required=True,
        choices=list(model_types_by_size),
        help="Which pair of runs to plot.",
    )
    args = parser.parse_args()

    folder_name = "tld_plots"

    model_size = args.model_size
    # Dict lookup keeps the size -> runs mapping in one place with the CLI
    # choices, so model_types is always bound.
    model_types = model_types_by_size[model_size]
    clean_type, masked_type = model_types

    # Apply ICLR style globally
    plt.style.use("default")
    plt.rcParams.update(
        {
            "figure.dpi": 100,
            "font.size": 10,
            "font.family": "serif",
            "font.serif": ["DejaVu Serif", "Computer Modern Roman", "Times New Roman"],
            "mathtext.fontset": "cm",
            "axes.linewidth": 0.8,
            "axes.labelsize": 10,
            "axes.titlesize": 11,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "xtick.labelsize": 9,
            "ytick.labelsize": 9,
            "legend.fontsize": 9,
            "legend.frameon": False,
            "lines.linewidth": 1.5,
            "savefig.dpi": 300,
            "savefig.bbox": "tight",
        }
    )

    # Per-run colors. The TLD curves and the 50th-token curves use `colors`;
    # the 500th-token curves use the lighter `colors_500` shades.
    colors = {clean_type: "#0173B2", masked_type: "#DE8F05"}
    colors_500 = {clean_type: "#56B4E9", masked_type: "#ECE133"}
    type_to_display = {clean_type: "Clean", masked_type: "Masked Bigram"}

    # Series drawn on the position plots:
    # (record key suffix, legend suffix, line style, color map).
    # Record keys are f"avg_loss_{suffix}", f"ci_lower_loss_{suffix}" and
    # f"ci_upper_loss_{suffix}".
    position_series = [
        ("50", "50th", "-", colors),
        ("500", "500th", "--", colors_500),
    ]

    # Load data
    tld_path = Path(__file__).parent / "results"
    out_dir = tld_path / folder_name
    # exist_ok avoids the race of checking existence before creating.
    out_dir.mkdir(parents=True, exist_ok=True)

    # One entry per output figure. "sample_type" is the input-file suffix:
    # "" selects all samples, "_skip_bigrams" the non-extractive subset.
    plot_configs = [
        {
            "filename": "tld_all_samples",
            "sample_type": "",
            "plot_type": "tld",
            "title": "Token-Loss Difference: All Samples",
            "ylabel": "Token-Loss Difference",
        },
        {
            "filename": "tld_non_extractive",
            "sample_type": "_skip_bigrams",
            "plot_type": "tld",
            "title": "Token-Loss Difference: Non-Extractive Samples",
            "ylabel": "Token-Loss Difference",
        },
        {
            "filename": "position_losses_all_samples",
            "sample_type": "",
            "plot_type": "position",
            "title": "Position-Specific Losses: All Samples",
            "ylabel": "Loss",
        },
        {
            "filename": "position_losses_non_extractive",
            "sample_type": "_skip_bigrams",
            "plot_type": "position",
            "title": "Position-Specific Losses: Non-Extractive Samples",
            "ylabel": "Loss",
        },
    ]

    # Create each figure
    for config in plot_configs:
        fig, ax = plt.subplots(figsize=(7, 4))

        for model_type in model_types:
            path = tld_path / f"{model_type}_tld{config['sample_type']}.jsonl"
            result = load_jsonl(path)

            step = [point["step"] for point in result]
            display_name = type_to_display[model_type]

            if config["plot_type"] == "tld":
                # Mean TLD line with a shaded confidence band.
                color = colors[model_type]
                ax.plot(
                    step,
                    [point["avg_tld"] for point in result],
                    label=display_name,
                    color=color,
                    alpha=0.9,
                )
                ax.fill_between(
                    step,
                    [point["ci_lower_tld"] for point in result],
                    [point["ci_upper_tld"] for point in result],
                    alpha=0.2,
                    color=color,
                )
            else:
                # Position plots: one line plus band per token position.
                for suffix, ordinal, linestyle, palette in position_series:
                    color = palette[model_type]
                    ax.plot(
                        step,
                        [point[f"avg_loss_{suffix}"] for point in result],
                        label=f"{display_name} - {ordinal}",
                        color=color,
                        alpha=0.9,
                        linestyle=linestyle,
                    )
                    ax.fill_between(
                        step,
                        [point[f"ci_lower_loss_{suffix}"] for point in result],
                        [point[f"ci_upper_loss_{suffix}"] for point in result],
                        alpha=0.15,
                        color=color,
                    )

        setup_axis(ax, ylabel=config["ylabel"], title=config["title"])

        # Position plots carry four legend entries, so use two compact columns.
        if config["plot_type"] == "position":
            ax.legend(loc="best", ncol=2, fontsize=8)
        else:
            ax.legend(loc="best")

        fig.tight_layout()

        # Save as both PDF and PNG
        for ext in ("pdf", "png"):
            fig.savefig(out_dir / f"{config['filename']}_{model_size}.{ext}", dpi=300)
        plt.show()
        # Close this specific figure explicitly instead of whichever is current.
        plt.close(fig)
