import argparse
import os
import torch
import json
from tqdm import tqdm
from transformers import GPT2Config, GPT2LMHeadModel
from safetensors import safe_open
from typing import List, Dict, Any, Tuple
from loader.data import _load_data
from loader.checkpoint import load_tokenizer

from utils.utils import (
    compute_attention_sparsity,
    compute_attention_ratio,
    compute_attention_sparsity2,
)  # compute_attention_sparsity is used below; compute_attention_ratio and compute_attention_sparsity2 are currently unused
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pathlib import Path
import pandas as pd
from utils.eval import EvaluationDataset, collate_fn, load_model, load_dataset
import itertools


def make_perm_family(L: int, n: int) -> torch.Tensor:
    '''
    Return 2^n permutation matrices (L x L) stacked into a tensor of shape (2^n, L, L).

    The family consists of the 2^(n-1) cyclic shifts by multiples of
    ``max(1, L // 2^(n-1))`` (indices 0 .. 2^(n-1)-1), followed by the same shifts
    composed with a reversal (indices 2^(n-1) .. 2^n-1). Row i of each matrix is
    one-hot at the source index placed at position i.
    Always includes Id (pattern=0) and Rev (pattern=2^{n-1}).

    Prerequisite: n >= 1 and 2**(n-1) <= L (otherwise shifts would overlap).

    Raises:
        ValueError: if n < 1 or 2**(n-1) > L.
    '''
    # Raise instead of assert: assertions are stripped under ``python -O``.
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}.")
    num_shifts = 2 ** (n - 1)
    if num_shifts > L:
        raise ValueError("n is too large (shifts overlap).")

    step = max(1, L // num_shifts)  # Spacing between consecutive shift amounts
    eye = torch.eye(L)  # Built once; rows are selected to form each permutation
    positions = torch.arange(L)
    perm_mats = []

    for reverse in (False, True):  # Inversion flag
        for s in range(num_shifts):  # Shift index
            idx = (positions + s * step) % L
            if reverse:
                idx = (L - 1) - idx
            perm_mats.append(eye[idx])  # one-hot rows
    return torch.stack(perm_mats)


def visualize_combined_attention(
    layer_attention_sample,  # [num_heads, full_seq_len, full_seq_len]
    input_actual_len,  # Actual number of input tokens without padding
    save_path_prefix,  # e.g., "attention_maps/sample_0_layer_11" (the function appends the file suffix)
    full_seq_len,  # Sequence length with padding (dimension of the heatmap)
    input_text_str=None,  # Original input text (decoded)
    tokenizer_obj=None,  # Tokenizer object
    inverse=False,
    vmax=0.1,
):
    '''
    Save a heatmap of the first head's attention for one sample of one layer (model.forward()).

    Assumes the number of heads is always 1: only ``layer_attention_sample[0]`` is plotted.
    The figure is written to ``{save_path_prefix}_attention.png``, or to
    ``{save_path_prefix}_attention_inv.png`` when ``inverse`` is True. When
    ``0 < input_actual_len < full_seq_len``, red dashed lines mark the boundary at
    ``input_actual_len`` on both axes.

    Args:
        layer_attention_sample: Attention of one sample for one layer; head 0 is plotted.
        input_actual_len: Number of real (unpadded) input tokens.
        save_path_prefix: Output path prefix; its directory is created if needed.
        full_seq_len: Padded sequence length (heatmap size).
        input_text_str: Original text; with ``tokenizer_obj`` it supplies token axis labels.
        tokenizer_obj: Tokenizer whose ``tokenize`` produces the axis labels.
        inverse: Select the ``_attention_inv.png`` file name instead of ``_attention.png``.
        vmax: Upper limit of the color scale (defaults to the previous fixed value, 0.1).
    '''
    save_dir = os.path.dirname(save_path_prefix)
    if save_dir:  # os.makedirs("") raises, so only create a directory when there is one
        os.makedirs(save_dir, exist_ok=True)

    # Axis labels: real tokens for the input part, then positional "T{i}" names, so the
    # label count always equals full_seq_len (even if tokenize() yields fewer tokens).
    if input_text_str and tokenizer_obj:
        labels = list(tokenizer_obj.tokenize(input_text_str)[:input_actual_len])
        labels += [f"T{i}" for i in range(len(labels), full_seq_len)]
        viz_labels = labels[:full_seq_len]
    elif full_seq_len > 0:
        viz_labels = [str(i) for i in range(full_seq_len)]
    else:
        viz_labels = False

    # Only the first head is visualized (single-head assumption).
    head_attn_np = layer_attention_sample[0].detach().cpu().numpy()  # [full_seq_len, full_seq_len]

    fig_width = max(10, full_seq_len / 2.0 if full_seq_len > 0 else 10)
    fig_height = max(8, full_seq_len / 2.5 if full_seq_len > 0 else 8)
    suffix = "_attention_inv.png" if inverse else "_attention.png"
    save_path = f"{save_path_prefix}{suffix}"

    fig = plt.figure(figsize=(fig_width, fig_height))
    try:
        sns.heatmap(
            head_attn_np, xticklabels=viz_labels, yticklabels=viz_labels, cmap="YlGnBu", cbar=True, vmax=vmax, vmin=0.0
        )
        plt.title("Forward Attention", fontsize=10)
        plt.xlabel("Key Token Position", fontsize=8)
        plt.ylabel("Query Token Position", fontsize=8)

        if 0 < input_actual_len < full_seq_len:
            plt.axvline(x=input_actual_len - 0.5, color="r", linestyle="--", linewidth=1.5)
            plt.axhline(y=input_actual_len - 0.5, color="r", linestyle="--", linewidth=1.5)

        plt.xticks(rotation=45, ha="right", fontsize=7)
        plt.yticks(rotation=0, fontsize=7)
        plt.tight_layout(pad=1.5)
        plt.savefig(save_path, dpi=150)
    finally:
        plt.close(fig)  # Always release the figure, even if plotting fails

    print(f"Attention map saved: {save_path}")


def evaluate_model(
    model: GPT2LMHeadModel,
    tokenizer,
    dataset: List[Dict[str, Any]],
    batch_size: int = 8,
    visualize_attention: bool = False,
    attention_dir: str = "attention_maps",
    input_len: int = 50,
    target_input_text: str = None,  # Specific input text to visualize
    max_output_length: int = 50,  # Unused: no text is generated
    target_layer: int = -1,  # Transformer layer to analyze (default is the last layer)
    candidate_rank: int = 2,  # Exponent n: 2**n candidate permutations are scored
    save_attention_maps: bool = False,  # Also write forward/inverse attention heatmaps
) -> None:
    '''
    Score attention sparsity of ``model`` under a family of target-segment permutations.

    Batches are analyzed only when ``visualize_attention`` is True; otherwise they are
    skipped. For each analyzed batch the final position of every sequence is dropped,
    the tokens after the first ``input_len`` positions (and their attention-mask entries)
    are reordered by each of the ``2**candidate_rank`` permutations from
    ``make_perm_family``, and the ``target_layer`` attention of the permuted query
    positions is passed to ``compute_attention_sparsity``. Results are ranked by
    ``mean_sparsity`` (rank 1 = largest) and written to
    ``<attention_dir>/permutation_sparsity_results.pt``.
    Does not perform accuracy calculation or text generation.

    Args:
        model: Model to evaluate.
        tokenizer: Tokenizer (used for collation and, when plotting, label generation).
        dataset: Evaluation dataset.
        batch_size: Batch size.
        visualize_attention: Run the per-batch analysis; without it nothing is computed.
        attention_dir: Directory for the results file (and heatmaps, if enabled).
        input_len: Length of the fixed prefix; the permutations act on the positions after it
            and have size ``input_len``, so the target segment must also have that length.
        target_input_text: Specific input text whose heatmaps should be saved.
        max_output_length: (Unused) Maximum length for text generation.
        target_layer: Transformer layer to visualize/analyze.
        candidate_rank: Exponent n passed to ``make_perm_family`` (2**n permutations).
        save_attention_maps: If True (and ``visualize_attention`` is True), also run the
            ``*_inv`` inputs and save heatmaps via ``visualize_combined_attention``.

    Raises:
        ValueError: if no tokens follow ``input_len`` or the permutation size does not
            match the target-segment length.
    '''
    eval_dataset = EvaluationDataset(dataset)

    def collate_wrapper(batch):
        return collate_fn(batch, tokenizer)

    dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_wrapper)

    # Send inputs to the model's own device instead of hard-coding ``.cuda()``.
    device = next(model.parameters()).device
    save_path = os.path.join(attention_dir, "permutation_sparsity_results.pt")
    perms = None  # Batch-independent; built lazily on first use
    per_batch_results = []  # Keeps every analyzed batch so earlier ones are not lost on save

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing Batches", ncols=100)):
            if not visualize_attention:
                continue  # Nothing is analyzed, so skip the forward pass entirely

            input_ids = batch["input_ids"][:, :-1].to(device)  # [batch_size, seq_len]
            attention_mask = batch["attention_mask"][:, :-1].to(device)  # [batch_size, seq_len]

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True)
            attentions = outputs.attentions  # Tuple of [batch_size, num_heads, seq_len, seq_len]
            if attentions is None:
                continue

            if save_attention_maps:
                _save_batch_attention_maps(
                    model, tokenizer, batch, attentions, device, batch_idx, attention_dir, target_layer, target_input_text
                )

            if perms is None:
                perms = make_perm_family(input_len, candidate_rank)  # 2^n permutation matrices
            sparsity_info = _compute_permutation_sparsity(
                model, input_ids, attention_mask, perms, input_len, target_layer
            )
            per_batch_results.append(sparsity_info)

            save_dict = {
                "permutation_matrices": torch.stack([info["permutation_matrix"] for info in sparsity_info]),
                "sparsity_info": sparsity_info,  # Most recent batch (same layout as before)
                "per_batch_sparsity_info": per_batch_results,  # One entry per analyzed batch
            }
            torch.save(save_dict, save_path)
            print(f"\nResults saved to {save_path}")

    print("\nProcessing finished.")


def _compute_permutation_sparsity(model, input_ids, attention_mask, perms, input_len, target_layer):
    '''Run the model once per permutation of the target segment and return ranked sparsity info.'''
    batch_size, seq_len = input_ids.shape
    target_len = seq_len - input_len
    if target_len <= 0:
        raise ValueError("target_len must be positive")
    if perms.shape[1:] != (target_len, target_len):
        raise ValueError("perm_mat shape mismatch")

    prefix_ids, target_ids = input_ids[:, :input_len], input_ids[:, input_len:]
    prefix_mask, target_mask = attention_mask[:, :input_len], attention_mask[:, input_len:]
    # Row k of a permutation matrix is one-hot at the source index placed at position k.
    perm_orders = perms.argmax(dim=-1).to(input_ids.device)  # [num_perms, T]

    sparsity_info = []
    for i, perm_mat in enumerate(perms):
        order_idx = perm_orders[i].unsqueeze(0).expand(batch_size, -1)  # [B, T]
        permuted_ids = torch.cat([prefix_ids, target_ids.gather(-1, order_idx)], dim=-1)
        # Permute the mask with the tokens so padded positions stay masked.
        permuted_mask = torch.cat([prefix_mask, target_mask.gather(-1, order_idx)], dim=-1)

        outputs = model(input_ids=permuted_ids, attention_mask=permuted_mask, output_attentions=True)
        sparsitys = compute_attention_sparsity(outputs.attentions[target_layer][:, :, input_len:])

        sparsity_info.append(
            {
                "permutation_matrix": perm_mat.cpu(),
                "mean_sparsity": sparsitys["mean_sparsity"],
                "std_sparsity": sparsitys["std_sparsity"],
            }
        )
        print(f"Permutation {i} Sparsity: {sparsitys['mean_sparsity']}")
        print(f"Permutation {i} Sparsity (std): {sparsitys['std_sparsity']}")

    # Rank 1 = largest mean sparsity.
    sorted_indices = sorted(range(len(sparsity_info)), key=lambda k: sparsity_info[k]["mean_sparsity"], reverse=True)
    for rank, idx in enumerate(sorted_indices, start=1):
        sparsity_info[idx]["rank"] = rank
    return sparsity_info


def _select_sample_indices(batch_inputs, num_samples, target_input_text):
    '''Return indices of samples whose text equals ``target_input_text``, else the first up to 3.'''
    default_indices = range(min(3, num_samples))
    if target_input_text is None:
        return default_indices
    matches = [i for i, text in enumerate(batch_inputs) if text == target_input_text]
    if not matches:
        print(
            f"Warning: Target input text '{target_input_text}' not found. Visualizing first min(3, batch_size) samples."
        )
        return default_indices
    return matches


def _save_batch_attention_maps(
    model, tokenizer, batch, attentions, device, batch_idx, attention_dir, target_layer, target_input_text
):
    '''Save forward and inverse (``*_inv`` inputs) attention heatmaps for selected samples of one batch.'''
    import traceback

    outputs_inv = model(
        input_ids=batch["input_ids_inv"][:, :-1].to(device),
        attention_mask=batch["attention_mask_inv"][:, :-1].to(device),
        output_attentions=True,
    )
    layer_attn = attentions[target_layer]  # [batch_size, num_heads, seq_len, seq_len]
    layer_attn_inv = outputs_inv.attentions[target_layer]
    batch_inputs = batch["original_inputs"]
    num_samples = layer_attn.size(0)

    for sample_idx in _select_sample_indices(batch_inputs, num_samples, target_input_text):
        if sample_idx >= num_samples:
            continue
        input_text = batch_inputs[sample_idx]
        actual_len = len(tokenizer.encode(input_text))  # Unpadded length of the re-encoded text
        sample_viz_dir = os.path.join(
            attention_dir, f"forward_sample_batch{batch_idx}_sample{sample_idx}_layer{target_layer}"
        )
        save_prefix = os.path.join(sample_viz_dir, "attention_viz")
        try:
            visualize_combined_attention(
                layer_attention_sample=layer_attn[sample_idx],
                input_actual_len=actual_len,
                save_path_prefix=save_prefix,
                full_seq_len=layer_attn.size(-1),
                input_text_str=input_text,
                tokenizer_obj=tokenizer,
            )
            # NOTE(review): the inverse map is labelled with the forward text's tokens — confirm intended.
            visualize_combined_attention(
                layer_attention_sample=layer_attn_inv[sample_idx],
                input_actual_len=actual_len,
                save_path_prefix=save_prefix,
                full_seq_len=layer_attn_inv.size(-1),
                input_text_str=input_text,
                tokenizer_obj=tokenizer,
                inverse=True,
            )
        except Exception as e:  # Best-effort plotting: report and continue with other samples
            print(f"Error visualizing attention for sample_idx {sample_idx} in batch {batch_idx}: {e}")
            traceback.print_exc()


def generate_all_permutation_matrices(N: int):
    '''
    Build every N x N permutation matrix (N! of them).

    Matrices follow the order of ``itertools.permutations(range(N))``. In the matrix for
    permutation ``p``, row ``i`` holds a single 1.0 at column ``p[i]``. A tqdm progress
    bar tracks construction, since N! grows very quickly.

    Returns:
        A tensor of shape (N!, N, N), where each [i] is a permutation matrix.
    '''
    row_idx = torch.arange(N)
    out = []
    for p in tqdm(list(itertools.permutations(range(N)))):
        m = torch.zeros(N, N)
        m[row_idx, torch.tensor(p, dtype=torch.long)] = 1.0  # scatter the ones in one indexed write
        out.append(m)
    return torch.stack(out)  # Shape: (N!, N, N)


def main():
    '''Parse CLI arguments, load model, tokenizer and dataset, then run ``evaluate_model``.'''
    parser = argparse.ArgumentParser(description="Visualize model attentions using forward pass")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
    parser.add_argument("--checkpoint_id", type=str, default=None, help="Checkpoint ID to use")
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the test dataset")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for processing")
    parser.add_argument(
        "--visualize_attention",
        action="store_true",
        help="Run the per-batch attention analysis (permutation sparsity scoring); without it no batch is analyzed",
    )
    parser.add_argument(
        "--target_input_text", type=str, default=None, help="Specific input text to visualize attention for"
    )
    # Kept for backward compatibility of the CLI; evaluate_model does not generate text.
    parser.add_argument(
        "--max_output_length", type=int, default=50, help="(Unused in this version) Max output length for generation"
    )
    parser.add_argument(
        "--target_layer",
        type=int,
        default=-1,
        help="Target transformer layer for attention visualization (-1 for last)",
    )
    parser.add_argument(
        "--input_len",
        type=int,
        default=50,
        help=(
            "Length of the fixed prefix; the following target tokens are permuted with "
            "input_len x input_len permutations, so the target must have the same length (default: 50)"
        ),
    )
    parser.add_argument(
        "--candidate_rank",
        type=int,
        default=2,
        help="Exponent n of the permutation family: 2**n candidate permutations are scored (default: 2)",
    )
    args = parser.parse_args()

    print(f"Loading model from {args.model_path}")
    model, _config = load_model(args.model_path, args.checkpoint_id)
    tokenizer = load_tokenizer(args.model_path)

    print(f"Loading dataset from {args.dataset_path}")
    dataset = load_dataset(args.dataset_path)

    print(f"Processing dataset to visualize attentions with batch size {args.batch_size}...")
    evaluate_model(
        model,
        tokenizer,
        dataset,
        batch_size=args.batch_size,
        visualize_attention=args.visualize_attention,
        attention_dir=args.model_path,  # Results are written directly into the model directory
        input_len=args.input_len,
        target_input_text=args.target_input_text,
        max_output_length=args.max_output_length,  # Accepted but unused by evaluate_model
        target_layer=args.target_layer,
        candidate_rank=args.candidate_rank,
    )

    print("\nProcessing finished.")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()