import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from safetensors import safe_open
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Config, GPT2LMHeadModel

from loader.checkpoint import load_tokenizer
from loader.data import _load_data
from utils.utils import compute_attention_sparsity, compute_attention_ratio, compute_attention_sparsity2


class EvaluationDataset(Dataset):
    """Thin ``torch.utils.data.Dataset`` adapter over an indexable collection.

    Both the length and item access are forwarded unchanged to the wrapped
    ``dataset`` object, so any sequence-like collection can be handed to a
    ``DataLoader``.
    """

    def __init__(self, dataset):
        # Store a reference only; no copying or preprocessing happens here.
        self.dataset = dataset

    def __len__(self):
        """Return the number of examples in the wrapped collection."""
        size = len(self.dataset)
        return size

    def __getitem__(self, idx):
        """Return the example stored at ``idx`` exactly as the wrapped collection yields it."""
        example = self.dataset[idx]
        return example


def collate_fn(batch, tokenizer):
    """Collate raw examples into tokenized model inputs for the data loader.

    Two sequences are built per example and each set is tokenized with padding:

    * forward:  ``input + " " + target``
    * inverted: ``input + " " + <target words in reverse order>``

    Args:
        batch: Sequence of examples; each is a mapping with string ``"input"``
            and ``"target"`` entries.
        tokenizer: Callable invoked as
            ``tokenizer(texts, padding=True, return_tensors="pt")`` that returns
            a mapping containing ``"input_ids"`` and ``"attention_mask"``.

    Returns:
        dict with keys:
            ``input_ids`` / ``attention_mask``: tokenizer output for the forward
                sequences.
            ``input_ids_inv`` / ``attention_mask_inv``: tokenizer output for the
                inverted sequences.
            ``targets``: the raw target strings.
            ``original_inputs``: the untokenized forward sequences
                (input followed by target).
    """
    targets = [item["target"] for item in batch]

    input_and_target = [item["input"] + " " + item["target"] for item in batch]
    # Reverse the target's word order. str.split() splits on runs of whitespace,
    # so the reversed target is re-joined with single spaces.
    input_and_target_inv = [
        item["input"] + " " + " ".join(reversed(item["target"].split()))
        for item in batch
    ]

    encoded = tokenizer(input_and_target, padding=True, return_tensors="pt")
    encoded_inv = tokenizer(input_and_target_inv, padding=True, return_tensors="pt")

    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "input_ids_inv": encoded_inv["input_ids"],
        "attention_mask_inv": encoded_inv["attention_mask"],
        "targets": targets,
        "original_inputs": input_and_target,
    }


def load_model(model_path: str, checkpoint_id: Optional[str] = None) -> Tuple[GPT2LMHeadModel, GPT2Config]:
    """
    Load a GPT-2 LM checkpoint stored as ``checkpoint-<id>/`` under ``model_path``.

    Args:
        model_path: Directory containing ``checkpoint-<id>`` sub-directories.
        checkpoint_id: Checkpoint ID to load. If ``None``, the sub-directory
            named ``checkpoint-<integer>`` with the largest integer is used.

    Returns:
        model: ``GPT2LMHeadModel`` with the checkpoint weights, moved to CUDA
            and switched to eval mode.
        config: The ``GPT2Config`` read from the checkpoint's ``config.json``,
            with ``output_attentions`` forced to ``True``.

    Raises:
        ValueError: If ``checkpoint_id`` is ``None`` and ``model_path`` holds no
            ``checkpoint-<integer>`` directory.
    """
    prefix = "checkpoint-"
    if checkpoint_id is None:
        # Only accept "checkpoint-<integer>" directories. A stray file or a
        # non-numeric suffix (e.g. "checkpoint-best") would otherwise either be
        # picked up or make int() raise.
        step_ids = [
            int(name[len(prefix):])
            for name in os.listdir(model_path)
            if name.startswith(prefix)
            and name[len(prefix):].isdecimal()
            and os.path.isdir(os.path.join(model_path, name))
        ]
        if not step_ids:
            raise ValueError(f"No checkpoints found in {model_path}")
        checkpoint_id = str(max(step_ids))

    checkpoint_path = os.path.join(model_path, f"{prefix}{checkpoint_id}")
    config_path = os.path.join(checkpoint_path, "config.json")

    config = GPT2Config.from_pretrained(config_path)
    config.output_attentions = True  # make forward passes return attention maps

    model = GPT2LMHeadModel(config)

    # Read tensors on CPU. load_state_dict copies them into the model's freshly
    # built parameters (default device, normally CPU), and the whole model is
    # moved to CUDA below. Reading straight to the GPU would only add an extra
    # device round trip and GPU memory peak.
    weights_path = os.path.join(checkpoint_path, "model.safetensors")
    with safe_open(weights_path, framework="pt", device="cpu") as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}

    # GPT-2's lm_head shares its weight with the token embedding. When the
    # checkpoint omits the duplicate, reuse the embedding matrix so strict
    # load_state_dict does not fail on a missing key.
    if "lm_head.weight" not in state_dict and "transformer.wte.weight" in state_dict:
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    model.load_state_dict(state_dict)

    model = model.cuda().eval()
    return model, config


def load_dataset(dataset_path: str) -> List[Dict[str, Any]]:
    """Load the evaluation examples located at ``dataset_path``.

    This is a thin wrapper: the path is handed to ``_load_data`` (imported from
    ``loader.data``), and its result is returned without further processing.

    Args:
        dataset_path: Location of the dataset, interpreted by ``_load_data``.

    Returns:
        The object produced by ``_load_data``, annotated here as a list of
        dict records.
    """
    return _load_data(dataset_path)