from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset
from tqdm import tqdm
from typing import List
import os
import json
from jsonargparse import CLI
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import cycle
import matplotlib.lines as mlines

def get_model_and_tokenizer(model_name, device):
    """Load a causal language model and its tokenizer from ``model_name``.

    The model is loaded in float32 with remote code trusted and is placed via
    ``device_map=device``. If the tokenizer defines no pad token, its EOS
    token is reused for padding.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    load_kwargs = dict(
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map=device,
        torch_dtype=torch.float32,
    )
    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    tok = AutoTokenizer.from_pretrained(model_name)
    # Fall back to EOS so that padding is always possible.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    return lm, tok

def _mean_over_batches(values):
    """Return the mean of every number in ``values``.

    ``values`` may be a flat list of numbers or a list of per-batch lists.
    Batches are flattened individually, so a shorter final batch (a ragged
    nested list) is handled instead of making ``np.mean`` fail.
    """
    if not isinstance(values, (list, tuple)):
        values = [values]
    chunks = [np.ravel(np.asarray(chunk, dtype=float)) for chunk in values]
    if not chunks:
        return float("nan")
    return float(np.mean(np.concatenate(chunks)))


def plot_func():
    """Plot mean loss per checkpoint for every run under ``loss_over_rec_eval``.

    Each sub-directory of ``loss_over_rec_eval`` is treated as one run and each
    ``chkpt_<step>.json`` file inside it as one checkpoint. A JSON file maps a
    recurrence count to a list of losses; the losses are averaged per
    recurrence count. A ``_part_2`` suffix is stripped from the run name, so a
    run and its ``_part_2`` continuation share one name and one colour.

    Recurrence counts 32 and 64 are excluded from the plot. One line is drawn
    per (run, recurrence count) pair: colour encodes the run, line style the
    recurrence count. The figure is written to
    ``loss_over_rec_eval/plot_5.pdf``.

    Raises:
        ValueError: If no checkpoint result files are found.
    """
    root = "loss_over_rec_eval"

    records = []
    for run_dir in os.listdir(root):
        run_path = f"{root}/{run_dir}"
        # The output PDF is saved into ``root`` itself, so non-directories
        # (e.g. a plot from an earlier invocation) must be skipped.
        if not os.path.isdir(run_path):
            continue

        for checkpoint in os.listdir(run_path):
            # Ignore anything that is not a checkpoint result file.
            if not (checkpoint.startswith("chkpt_") and checkpoint.endswith(".json")):
                continue

            with open(f"{run_path}/{checkpoint}", 'r') as f:
                data = json.load(f)

            mean_data = {int(k): _mean_over_batches(v) for k, v in data.items()}
            records.append(
                {
                    "name": run_dir.replace("_part_2", ""),
                    "checkpoint": int(checkpoint.replace("chkpt_", "").replace(".json", "")),
                } | mean_data
            )

    if not records:
        raise ValueError(f"no chkpt_*.json result files found under {root!r}")

    df = pd.DataFrame(records)
    df = df.sort_values(by=["name", "checkpoint"])
    # Leave out the two largest recurrence counts; tolerate result files that
    # never contained them.
    df = df.drop(columns=[64, 32], errors="ignore")
    print(df)

    # Long format: one row per (run, checkpoint, recurrence count).
    df_melted = df.melt(id_vars=['name', 'checkpoint'], var_name='linesize', value_name='value')
    df_melted['linesize'] = df_melted['linesize'].astype(str)

    linestyles = {
        '1': '-', '2': '--', '4': '-.', '8': ':',
        '16': (0, (3, 1, 1, 1)), '32': (0, (5, 1)), '64': (0, (1, 1))
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    colors = cycle(plt.cm.tab10.colors)
    model_colors = {run_name: next(colors) for run_name in df['name'].unique()}

    # Track what was actually drawn so the legends list only those entries.
    used_models = set()
    used_linesizes = set()

    # To restrict the plot to a subset of runs, filter on ``run_name`` here.
    for (run_name, linesize), group in df_melted.groupby(['name', 'linesize']):
        ax.plot(
            group['checkpoint'],
            group['value'],
            label=f'model={run_name}, linesize={linesize}',
            linestyle=linestyles.get(linesize, '-'),
            color=model_colors[run_name]
        )
        used_models.add(run_name)
        used_linesizes.add(linesize)

    # Two separate legends: colours for runs, line styles for recurrence counts.
    model_legend_handles = [
        mlines.Line2D([], [], color=model_colors[run_name], label=f'{run_name}', linestyle='-')
        for run_name in sorted(used_models)
    ]
    linestyle_legend_handles = [
        # Same fallback as in ax.plot above, so an unlisted count cannot fail here.
        mlines.Line2D([], [], color='black', label=f'{ls}', linestyle=linestyles.get(ls, '-'))
        for ls in sorted(used_linesizes, key=int)
    ]

    ax.set_xlabel("Optim Step")
    ax.set_ylabel("Loss")

    # Place both legends side by side below the axes.
    fig.legend(handles=model_legend_handles, title="Models", loc='lower center', bbox_to_anchor=(0.25, -0.1), ncol=2)
    fig.legend(handles=linestyle_legend_handles, title="Num Rec", loc='lower center', bbox_to_anchor=(0.75, -0.1), ncol=2)
    ax.set_yscale('log')
    ax.grid(True, axis='both', which='both')
    fig.savefig(f"{root}/plot_5.pdf", bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


@torch.no_grad()
def main(
    model_name: str,
    ckpts: List[int],
    plot: bool = False,
    batch_size: int = 64,
    device: str = "cuda",
    data_files: str = "val dataset path",
    num_samples: int = 1024,
):
    """Evaluate per-sample loss at several recurrence counts for each checkpoint.

    For every checkpoint ``<model_name>/model_only_chkpt_<ckpt>`` the model is
    run over the validation examples with ``num_steps`` set for each of
    1, 2, 4, 8, 16, 32 and 64 recurrences. The results are written to
    ``<cwd>/loss_over_rec_eval/<last component of model_name>/chkpt_<ckpt>.json``
    as a mapping from recurrence count to a list with one entry per batch,
    each entry being the list of per-sample mean losses of that batch.

    Args:
        model_name: Directory (or prefix) that contains the checkpoints.
        ckpts: Checkpoint identifiers to evaluate.
        plot: If true, only produce the plot from existing results and return.
        batch_size: Number of examples per forward pass.
        device: Device the model and the inputs are placed on.
        data_files: Parquet file(s) holding the validation set; each example
            must provide ``input_ids`` and ``attention_mask``.
        num_samples: Upper bound on the number of examples evaluated.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if plot:
        plot_func()
        return

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    dataset = load_dataset("parquet", data_files=data_files)["train"]
    # Never request more rows than the dataset holds.
    dataset = dataset.select(range(min(num_samples, len(dataset))))
    dataset.set_format("pt")
    num_examples = len(dataset)

    # Trailing slashes would otherwise yield an empty directory name.
    run_name = os.path.basename(os.path.normpath(model_name))
    output_dir = f"{os.getcwd()}/loss_over_rec_eval/{run_name}"

    for ckpt in ckpts:
        output = {}
        # Only the model is needed; the examples are already tokenized.
        model, _ = get_model_and_tokenizer(f"{model_name}/model_only_chkpt_{ckpt}", device)

        batch = {"input_ids": [], "labels": []}
        # Wrapping the dataset (not the enumerate object) lets tqdm show a total.
        for data_idx, inputs in enumerate(tqdm(dataset, total=num_examples), start=1):
            # Next-token prediction: inputs are tokens [0, n-1), labels tokens [1, n).
            input_ids = inputs["input_ids"][:-1].to(dtype=torch.long, device=device, non_blocking=True)
            # Positions whose attention mask is 0 get the ignore label -100.
            mask = ~inputs["attention_mask"].bool()
            labels = torch.where(mask[1:], -100, inputs["input_ids"][1:]).to(
                dtype=torch.long, device=device, non_blocking=True
            )
            batch["input_ids"].append(input_ids)
            batch["labels"].append(labels)

            # Flush on a full batch, and once more for a shorter final batch.
            if (data_idx % batch_size == 0) or (data_idx == num_examples):
                input_ids = torch.stack(batch["input_ids"], dim=0)
                labels = torch.stack(batch["labels"], dim=0)
                mask = (labels != -100).float()

                for num_rec in [1, 2, 4, 8, 16, 32, 64]:
                    # NOTE(review): the meaning of the second num_steps entry (0) is
                    # defined by the model's remote code; confirm there.
                    logits = model(input_ids, num_steps=torch.tensor([num_rec, 0], device=model.device)).logits

                    loss = torch.nn.functional.cross_entropy(
                        logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100, reduction='none'
                    )
                    loss = loss.view(logits.size(0), logits.size(1))
                    # Mean over the non-ignored positions; the clamp avoids 0/0
                    # for a sequence that is masked out entirely.
                    loss_per_sample = (loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

                    output.setdefault(num_rec, []).append(loss_per_sample.tolist())
                    del logits

                batch["input_ids"] = []
                batch["labels"] = []

        os.makedirs(output_dir, exist_ok=True)
        with open(f"{output_dir}/chkpt_{ckpt}.json", "w") as f:
            json.dump(output, f)

        # Drop this checkpoint before the next one is loaded so that two
        # models are never resident at the same time.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Script entry point: jsonargparse's CLI derives the command-line arguments
# from main's signature and calls main with the parsed values.
if __name__ == "__main__":
    CLI(main)
