import argparse
import os
import torch
import json
from tqdm import tqdm
from transformers import GPT2Config
from loader.models.my_gpt2 import MyGPT2
from safetensors import safe_open
from typing import List, Dict, Any, Tuple
from loader.data import _load_data
from loader.checkpoint import load_tokenizer
from utils.utils import compute_attention_sparsity, compute_attention_ratio
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pathlib import Path
import pandas as pd


class EvaluationDataset(Dataset):
    """Minimal ``Dataset`` adapter exposing an indexable collection of samples.

    The wrapped object only needs to support ``len()`` and integer indexing;
    samples are handed back exactly as stored.
    """

    def __init__(self, dataset):
        # Keep a reference only; nothing is copied or preprocessed here.
        self.dataset = dataset

    def __len__(self):
        """Return how many samples the wrapped collection holds."""
        sample_count = len(self.dataset)
        return sample_count

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx`` of the wrapped collection."""
        sample = self.dataset[idx]
        return sample


def collate_fn(batch, tokenizer):
    """Turn a list of raw samples into one tokenized evaluation batch.

    Args:
        batch: Sequence of samples, each indexable with the keys ``"input"``
            and ``"target"``.
        tokenizer: Callable invoked as
            ``tokenizer(inputs, padding=True, return_tensors="pt")`` whose
            result is indexable with ``"input_ids"`` and ``"attention_mask"``.

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"`` taken from the
        tokenizer output, plus the untouched ``"targets"`` and
        ``"original_inputs"`` lists.
    """
    raw_inputs = [sample["input"] for sample in batch]
    raw_targets = [sample["target"] for sample in batch]

    # Tokenize all inputs together so they are padded to a common length.
    encoding = tokenizer(raw_inputs, padding=True, return_tensors="pt")

    collated = {key: encoding[key] for key in ("input_ids", "attention_mask")}
    collated["targets"] = raw_targets
    collated["original_inputs"] = raw_inputs
    return collated


def visualize_token_attention(attention_tensor, input_len, save_path, token_idx, sample_idx=0):
    """
    Save attention heatmaps for one output token: one image per head plus the head average

    Args:
        attention_tensor: attention tensor indexed as [batch, num_heads, tgt_len, src_len]
        input_len: Length of the input sequence (column where the red boundary line is drawn)
        save_path: Directory the PNG files are written to (created if missing)
        token_idx: Index of the output token, used in plot titles and file names
        sample_idx: Sample index in the batch
    """
    os.makedirs(save_path, exist_ok=True)

    # Attention of the chosen sample only: [num_heads, tgt_len, src_len]
    per_head = attention_tensor[sample_idx]

    def _render(matrix, title, filename):
        # Draw one [tgt_len, src_len] matrix as a heatmap and write it to save_path.
        plt.figure(figsize=(12, 10))
        sns.heatmap(matrix, cmap="viridis")
        plt.title(title)
        plt.xlabel("Source Token Position")
        plt.ylabel("Attention")

        # Mark the input/output boundary only when it falls inside the plotted columns
        if input_len < matrix.shape[1]:
            plt.axvline(x=input_len, color="r", linestyle="-")

        plt.savefig(os.path.join(save_path, filename))
        plt.close()

    # One figure per attention head
    for head_idx in range(per_head.size(0)):
        _render(
            per_head[head_idx].detach().cpu().numpy(),
            f"Token {token_idx}, Head {head_idx} Attention",
            f"attention_token{token_idx}_head{head_idx}_sample{sample_idx}.png",
        )

    # One extra figure for the mean over the head axis
    _render(
        per_head.mean(dim=0).detach().cpu().numpy(),
        f"Token {token_idx}, Average Attention",
        f"attention_token{token_idx}_avg_sample{sample_idx}.png",
    )


def visualize_combined_attention(
    attentions, input_len, output_len, save_path, sample_idx=0, input_text=None, output_text=None, tokenizer=None
):
    """
    Build one combined attention matrix per head from generation-time attentions,
    plot it for the selected sample, and save per-head sparsity / ratio statistics

    For every head and every batch element, row 0 of the combined matrix is the
    last query row of the first generation step and row i (i >= 1) is the single
    query row of generation step i. Statistics are computed over the whole batch;
    heatmaps are drawn only for ``sample_idx``.

    Args:
        attentions: attentions during model generation, indexed as attentions[step][0]
            (only index 0 of each step is used - presumably the first layer as in
            Hugging Face generate outputs; confirm for MyGPT2)
            - attentions[0][0]: [batch, num_heads, generated_length, src_length]
            - attentions[i][0], i >= 1: [batch, num_heads, 1, src_length + i]
        input_len: Length of the input sequence
        output_len: Length of the output sequence (number of rows in the combined matrix)
        save_path: Directory where images and CSV files are written (created if missing)
        sample_idx: Sample index in the batch whose heatmaps are drawn
        input_text: Input text (optional, used for tick labels)
        output_text: Output text (optional, used for tick labels)
        tokenizer: Tokenizer (optional, used for tick labels)

    Raises:
        ValueError: If output_len is smaller than 1

    Files written to save_path:
        combined_attention_head{H}_sample{S}.png, attention_sparsity.csv,
        attention_ratio.csv, attention_sparsity_plot.png, attention_ratio_plot.png
    """
    if output_len < 1:
        # The combined matrix needs at least the row taken from the first generation step.
        raise ValueError(f"output_len must be at least 1, got {output_len}")

    # Create save directory
    os.makedirs(save_path, exist_ok=True)

    first_step_attn = attentions[0][0]
    batch_size = first_step_attn.size(0)
    num_heads = first_step_attn.size(1)

    # Token labels are only available when text and tokenizer are all provided
    use_token_labels = input_text is not None and output_text is not None and tokenizer is not None
    if use_token_labels:
        input_tokens = tokenizer.tokenize(input_text)
        output_tokens = tokenizer.tokenize(output_text)

    # The combined matrix has exactly output_len rows, so never read more generation
    # steps than that (reading step == output_len would index past the last row).
    num_steps = min(len(attentions), output_len)

    # Per-head statistics aggregated over the entire batch
    batch_sparsity_results = []
    batch_ratio_results = []

    for head_idx in range(num_heads):
        head_sparsity_results = []
        head_ratio_results = []

        for batch_idx in range(batch_size):
            prompt_attn = first_step_attn[batch_idx, head_idx]  # [generated_length, src_length]
            src_length = prompt_attn.size(1)

            # Row i can attend to src_length + i positions, hence the column count
            remaining_attn = torch.zeros(output_len, src_length + output_len - 1, device=prompt_attn.device)

            # Row 0: the last query position of the first generation step
            remaining_attn[0, :src_length] = prompt_attn[-1]

            # Rows 1..: the single query row of each later generation step
            for i in range(1, num_steps):
                remaining_attn[i, : src_length + i] = attentions[i][0][batch_idx, head_idx].squeeze(0)

            # Calculate sparsity
            head_sparsity_results.append(compute_attention_sparsity(remaining_attn))

            # Calculate Attention ratio on a [1, 1, rows, cols] view (batch and head axes added)
            attn_matrix = remaining_attn.unsqueeze(0).unsqueeze(0)
            head_ratio_results.append(compute_attention_ratio(attn_matrix, input_len))

            # Visualization of the specified sample only
            if batch_idx == sample_idx:
                heatmap_data = remaining_attn.detach().cpu().numpy()
                plt.figure(figsize=(30, 16))

                if use_token_labels:
                    y_labels = output_tokens[len(input_tokens) :] + ["</s>"]
                    sns.heatmap(
                        heatmap_data,
                        cmap="YlGnBu",
                        vmin=0,
                        vmax=0.1,
                        xticklabels=output_tokens,
                        yticklabels=y_labels,
                    )
                    plt.xticks(rotation=45, ha="right")
                    plt.yticks(rotation=0)
                else:
                    sns.heatmap(heatmap_data, cmap="YlGnBu", vmin=0, vmax=0.1)

                plt.title(f"Combined Attention Map, Head {head_idx}")
                plt.xlabel("Source Token Position")
                plt.ylabel("Target Token Position")
                # NOTE(review): rows are output positions, so a horizontal line at input_len
                # may fall outside the plotted rows - kept as in the original; confirm intent.
                plt.axhline(y=input_len, color="r", linestyle="-")
                plt.axvline(x=input_len, color="r", linestyle="-")
                plt.tight_layout()
                plt.savefig(os.path.join(save_path, f"combined_attention_head{head_idx}_sample{sample_idx}.png"))
                plt.close()

        # Average the per-sample statistics for this head
        # (the "std" columns are means of the per-sample std values)
        mean_sparsity = np.mean([r["mean_sparsity"] for r in head_sparsity_results])
        std_sparsity = np.mean([r["std_sparsity"] for r in head_sparsity_results])
        mean_ratio = np.mean([r["mean_ratio"].item() for r in head_ratio_results])
        std_ratio = np.mean([r["std_ratio"].item() for r in head_ratio_results])

        batch_sparsity_results.append(
            {"head_idx": head_idx, "mean_sparsity": mean_sparsity, "std_sparsity": std_sparsity}
        )
        batch_ratio_results.append({"head_idx": head_idx, "mean_ratio": mean_ratio, "std_ratio": std_ratio})

    # Save sparsity results
    sparsity_df = pd.DataFrame(batch_sparsity_results)
    sparsity_df.to_csv(os.path.join(save_path, "attention_sparsity.csv"), index=False)

    # Save Attention ratio results
    ratio_df = pd.DataFrame(batch_ratio_results)
    ratio_df.to_csv(os.path.join(save_path, "attention_ratio.csv"), index=False)

    # Visualize sparsity
    plt.figure(figsize=(12, 6))
    plt.errorbar(
        sparsity_df["head_idx"],
        sparsity_df["mean_sparsity"],
        yerr=sparsity_df["std_sparsity"],
        fmt="o-",
        label="Sparsity",
    )
    plt.xlabel("Head Index")
    plt.ylabel("Sparsity")
    plt.title("Attention Sparsity by Head")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, "attention_sparsity_plot.png"))
    plt.close()

    # Visualize Attention ratio
    plt.figure(figsize=(12, 6))
    plt.errorbar(
        ratio_df["head_idx"],
        ratio_df["mean_ratio"],
        yerr=ratio_df["std_ratio"],
        fmt="o-",
        label="Output Token Attention Ratio",
    )
    plt.xlabel("Head Index")
    plt.ylabel("Attention Ratio")
    plt.title("Output Token Attention Ratio by Head")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, "attention_ratio_plot.png"))
    plt.close()


def load_model(model_path: str, checkpoint_id: str = None) -> Tuple[MyGPT2, GPT2Config]:
    """
    Load a MyGPT2 model and its configuration from a checkpoint directory

    Args:
        model_path: Directory containing ``checkpoint-<id>`` sub-directories
        checkpoint_id: Checkpoint ID (if not specified, the ``checkpoint-<N>``
            directory with the largest numeric ``N`` is used; entries whose
            suffix is not purely numeric, e.g. ``checkpoint-best``, are ignored)

    Returns:
        model: Loaded model, moved to CUDA and set to evaluation mode
        config: Model settings, with ``output_attentions`` enabled

    Raises:
        ValueError: If no numbered checkpoint directory is found in model_path
    """
    # If the checkpoint ID is not specified, use the latest one
    if checkpoint_id is None:
        prefix = "checkpoint-"
        # Keep the suffix as a string so the directory name is reproduced exactly
        # (zero padding included) when the path is rebuilt below.
        numbered_suffixes = [
            name[len(prefix) :]
            for name in os.listdir(model_path)
            if name.startswith(prefix) and name[len(prefix) :].isdecimal()
        ]
        if not numbered_suffixes:
            raise ValueError(f"No checkpoints found in {model_path}")
        checkpoint_id = max(numbered_suffixes, key=int)

    checkpoint_path = os.path.join(model_path, f"checkpoint-{checkpoint_id}")
    config_path = os.path.join(checkpoint_path, "config.json")

    # Load settings
    config = GPT2Config.from_pretrained(config_path)
    config.output_attentions = True  # Enable attention output

    # Initialize the model
    model = MyGPT2(config)

    # Load weights from the safetensors file directly onto the GPU
    with safe_open(os.path.join(checkpoint_path, "model.safetensors"), framework="pt", device="cuda") as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}

    # The checkpoint may omit lm_head.weight; in that case reuse the token
    # embedding matrix (in GPT2, lm_head usually shares weights with it).
    if "lm_head.weight" not in state_dict and "transformer.wte.weight" in state_dict:
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    # Load weights into the model
    model.load_state_dict(state_dict)

    # Move to GPU and set to evaluation mode
    model = model.cuda().eval()

    return model, config

def load_dataset(dataset_path: str) -> List[Dict[str, Any]]:
    """
    Load the evaluation dataset

    Thin wrapper that forwards to the project loader ``_load_data`` and
    returns its result unchanged.

    Args:
        dataset_path: Path to the dataset

    Returns:
        Whatever ``_load_data(dataset_path)`` returns
    """
    return _load_data(dataset_path)

def evaluate_model(
    model: MyGPT2,
    tokenizer,
    dataset: List[Dict[str, Any]],
    batch_size: int = 8,
    visualize_attention: bool = False,
    attention_dir: str = "attention_maps",
    target_input_text: str = None,  # Specific input text to visualize
    max_length: int = 50,
) -> Dict[str, float]:
    """
    Evaluate the model by exact match on generated text (batch processing version)

    Each batch is tokenized, the last position of every sequence is dropped, and
    the model generates a continuation. The decoded text is split on whitespace,
    the leading words belonging to the input are removed, and the remainder is
    compared with the whitespace-split target. Per-sample results are written to
    ``<attention_dir>/generation_results.csv``.

    Args:
        model: Model to evaluate (inputs are moved to CUDA before generation)
        tokenizer: Tokenizer used for encoding, decoding and ``eos_token_id``
        dataset: Evaluation dataset; samples need ``"input"`` and ``"target"`` entries
        batch_size: Batch size
        visualize_attention: Whether to visualize Attention maps (first batch only)
        attention_dir: Directory for attention maps and the results CSV (created if missing)
        target_input_text: Specific input text to visualize; if it is not in the
            first batch, the first 3 samples are visualized instead
        max_length: Value passed to ``model.generate`` as ``max_length``

    Returns:
        metrics: Dictionary with ``"accuracy"``, ``"correct"`` and ``"total"``
    """
    correct = 0
    total = 0

    # Per-sample rows for the results CSV
    results = []

    # Set up dataset and data loader
    eval_dataset = EvaluationDataset(dataset)

    # DataLoader calls collate_fn with the batch only, so bind the tokenizer here
    def collate_wrapper(batch):
        return collate_fn(batch, tokenizer)

    dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_wrapper)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Evaluating", ncols=100)):
            # Move input to GPU and drop the final position of every sequence
            # (presumably a trailing special token added by the tokenizer - confirm)
            input_ids = batch["input_ids"].cuda()[:, :-1]
            attention_mask = batch["attention_mask"].cuda()[:, :-1]
            batch_targets = batch["targets"]
            batch_inputs = batch["original_inputs"]
            input_length = input_ids.size(1)

            # Generate in batch processing (get attention with output_attentions=True)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                pad_token_id=tokenizer.eos_token_id,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            # Decode the generated text
            generated_sequences = outputs.sequences
            generated_texts = tokenizer.batch_decode(generated_sequences, skip_special_tokens=True)

            # Visualize Attention map (only for the first batch)
            if visualize_attention and batch_idx == 0:
                # Get attentions - this is a tuple for each output token
                attentions = outputs.attentions

                # Select the samples to visualize
                default_indices = range(min(3, len(generated_texts)))
                if target_input_text is None:
                    sample_indices = default_indices
                else:
                    # Find the samples that match the specified input text
                    target_indices = [i for i, input_text in enumerate(batch_inputs) if input_text == target_input_text]
                    if target_indices:
                        sample_indices = target_indices
                    else:
                        print(
                            f"Warning: Target input text '{target_input_text}' not found in the current batch. Visualizing first 3 samples instead."
                        )
                        sample_indices = default_indices

                # Visualize attention for the selected samples
                for sample_idx in sample_indices:
                    input_text = batch_inputs[sample_idx]
                    output_text = generated_texts[sample_idx]

                    # The output length is the length of the generated sequence - the input length
                    output_length = len(generated_sequences[sample_idx]) - input_length

                    # Create a directory for each sample
                    sample_dir = os.path.join(attention_dir, f"sample_{batch_idx}_{sample_idx}")
                    os.makedirs(sample_dir, exist_ok=True)

                    # Visualization is best-effort: a failure is reported but must not abort evaluation
                    try:
                        visualize_combined_attention(
                            attentions,
                            input_length,
                            output_length,
                            os.path.join(sample_dir, "attn_viz"),
                            sample_idx,
                            input_text=input_text,
                            output_text=output_text,
                            tokenizer=tokenizer,
                        )
                    except Exception as e:
                        print(f"Error visualizing attention for sample {sample_idx}: {e}")
                        import traceback

                        traceback.print_exc()

            # Calculate the match rate
            for gen_text, target, input_text in zip(generated_texts, batch_targets, batch_inputs):
                target_lst = target.split()
                input_lst = input_text.split()
                # The decoded text starts with the input, so skip that many words
                gen_lst = gen_text.split()[len(input_lst) :]

                # Evaluation (exact match on whitespace-separated tokens)
                is_correct = gen_lst == target_lst

                results.append(
                    {
                        "input_text": " ".join(input_lst),  # Normalize whitespace
                        "target_text": " ".join(target_lst),
                        "generated_text": " ".join(gen_lst),
                        "is_correct": is_correct,
                    }
                )

                if is_correct:
                    correct += 1

                total += 1

    # Calculate the accuracy rate
    accuracy = correct / total if total > 0 else 0

    # The output directory only exists already if attention maps were drawn, so
    # make sure it is there before writing the CSV.
    if attention_dir:
        os.makedirs(attention_dir, exist_ok=True)

    # Save the results to a CSV file
    df = pd.DataFrame(results)
    csv_path = os.path.join(attention_dir, "generation_results.csv")
    df.to_csv(csv_path, index=False, encoding="utf-8")
    print(f"\nGeneration results saved to {csv_path}")

    return {"accuracy": accuracy, "correct": correct, "total": total}

def main():
    """Command-line entry point: parse options, load model, tokenizer and data, evaluate, print accuracy.

    The model directory given by ``--model_path`` is also used as the output
    directory passed to ``evaluate_model``.
    """
    # Declarative table of CLI options, registered in order below
    option_specs = (
        ("--model_path", {"type": str, "required": True, "help": "Path to the model directory"}),
        ("--checkpoint_id", {"type": str, "default": None, "help": "Checkpoint ID to use"}),
        ("--dataset_path", {"type": str, "required": True, "help": "Path to the test dataset"}),
        ("--batch_size", {"type": int, "default": 1000, "help": "Batch size for evaluation"}),
        ("--visualize_attention", {"action": "store_true", "help": "Visualize attention maps"}),
        (
            "--target_input_text",
            {"type": str, "default": None, "help": "Specific input text to visualize attention for"},
        ),
    )

    cli = argparse.ArgumentParser(description="Evaluate GPT model on test data")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Model and tokenizer both come from the same model directory
    print(f"Loading model from {opts.model_path}")
    model, _config = load_model(opts.model_path, opts.checkpoint_id)
    tokenizer = load_tokenizer(opts.model_path)

    # Evaluation data
    print(f"Loading dataset from {opts.dataset_path}")
    dataset = load_dataset(opts.dataset_path)

    # Run the evaluation
    print(f"Evaluating model with batch size {opts.batch_size}...")
    eval_kwargs = {
        "batch_size": opts.batch_size,
        "visualize_attention": opts.visualize_attention,
        "attention_dir": opts.model_path,
        "target_input_text": opts.target_input_text,
    }
    metrics = evaluate_model(model, tokenizer, dataset, **eval_kwargs)

    # Report
    print("\nEvaluation Results:")
    print(f"Accuracy: {metrics['accuracy']:.4f} ({metrics['correct']}/{metrics['total']})")


# Run the evaluation CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()