import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from aim import Image
import wandb
from tqdm import tqdm


def log_model_param_norms(model, wandb_writer, iter):
    """Record the L2 norm of every trainable parameter and of its gradient.

    Each parameter with ``requires_grad`` set contributes two entries:
    ``Norms/<name>_norm`` and ``Grad_Norms/<name>_grad_norm``. The gradient
    entry is ``None`` when the parameter currently has no gradient.

    Args:
        model: Object exposing ``named_parameters()``.
        wandb_writer: Object exposing ``log(metrics, step=...)``; nothing is
            logged when it is falsy.
        iter: Step value forwarded to ``wandb_writer.log``.
    """
    collected = {}
    trainable = (
        (param_name, tensor)
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for param_name, tensor in trainable:
        collected[f"Norms/{param_name}_norm"] = tensor.data.norm(2).item()

        gradient = tensor.grad
        collected[f"Grad_Norms/{param_name}_grad_norm"] = (
            None if gradient is None else gradient.data.norm(2).item()
        )

    # Nothing to report without a writer or without any trainable parameter.
    if not (wandb_writer and collected):
        return
    wandb_writer.log(collected, step=iter)


def _relevant_key_indices(X_batch, show_relevant_token, burst_positions):
    """Return ``(batch_idx, key_idx)`` index tensors addressing the relevant tokens of a batch."""
    if show_relevant_token:
        # The last input feature acts as a marker: entries equal to 1.0 flag relevant tokens.
        hits = (X_batch[:, :, -1] == 1.0).nonzero(as_tuple=False)
        return hits[:, 0], hits[:, 1]

    # Fixed positions: pair every sample of the batch with every burst position,
    # i.e. batch_idx = [0, 0, ..., 1, 1, ...] and key_idx = [b1, b2, ..., b1, b2, ...].
    batch_size = X_batch.size(0)
    batch_idx = (
        torch.arange(batch_size, device=burst_positions.device)
        .unsqueeze(1)
        .repeat(1, burst_positions.numel())
        .flatten()
    )
    return batch_idx, burst_positions.repeat(batch_size)


def _relevant_attention_mass(layer_heads, batch_idx, key_idx):
    """Return the head-averaged attention the last query puts on relevant keys, summed over the batch."""
    num_heads = len(layer_heads)
    if num_heads == 0:
        return 0.0  # A layer without heads contributes nothing.

    total = 0.0
    for head_attention in layer_heads:
        # Row -1 is the last query position; gather its weights on the relevant keys.
        total += head_attention[:, -1, :][batch_idx, key_idx].sum().item()
    return total / num_heads


def evaluate_model(
    model,
    loader,
    loss_function,
    device,
    iter,
    wandb_writer,
    out_dir,
    config=None,
    plot_interval=200000,
    prefix="Primary",
    test_batch_size=512,
):
    """
    Evaluates the model on the loader's test set and logs the average loss.

    The prediction is read from the last sequence position of the model output.
    The loss is averaged per sample over the whole test set, so a shorter final
    batch is weighted by its actual size. When a writer and a config are given,
    and the config either marks relevant tokens in the input or uses fixed
    relevant-token positions, the attention that the last query position puts on
    the relevant tokens is also logged per layer (averaged over heads and samples).

    Args:
        model: Callable as ``model(X, return_attention=True)`` returning
            ``(outputs, attention_weights_list)``; ``outputs`` is indexed as
            ``[:, -1, :]`` and ``attention_weights_list`` is indexed as
            ``[layer][head][:, -1, :]``. ``model.eval()`` is called first.
        loader: Object exposing ``X_test`` and ``Y_test`` tensors and, for fixed
            relevant-token positions, ``task.burst_indices``.
        loss_function: Element-wise loss; its result is summed over the last
            axis and averaged over the batch.
        device: Torch device the batches are moved to.
        iter: Step value used for logging.
        wandb_writer: Object exposing ``log(payload, step=...)``, or None to skip logging.
        out_dir: Unused; kept for interface compatibility.
        config: Experiment configuration; only ``config.dataset.show_relevant_token``
            and ``config.dataset.random_relevant_token_positions`` are read.
            Attention logging is skipped when it is None.
        plot_interval: Unused; kept for interface compatibility.
        prefix: Suffix placed in the logged metric names to tell evaluation sets apart.
        test_batch_size: Number of test samples evaluated per forward pass.

    Returns:
        float: Per-sample average test loss (0.0 when the test set is empty).

    Raises:
        ValueError: If ``test_batch_size`` is not positive.
    """
    if test_batch_size <= 0:
        raise ValueError(f"test_batch_size must be positive, got {test_batch_size}")

    model.eval()

    # `config` defaults to None, so it must be checked before it is dereferenced.
    log_attention = (
        wandb_writer is not None
        and config is not None
        and bool(
            config.dataset.show_relevant_token
            or not config.dataset.random_relevant_token_positions
        )
    )
    show_relevant_token = bool(config.dataset.show_relevant_token) if log_attention else False

    total_loss = 0.0  # Sum of per-sample losses.
    num_samples = 0
    layer_attention_totals = None  # Per-layer attention summed over samples.
    burst_positions = None  # Built once, on first use.

    X_test, Y_test = loader.X_test, loader.Y_test

    with torch.no_grad():
        for start in range(0, X_test.size(0), test_batch_size):
            X_batch = X_test[start : start + test_batch_size].to(device)
            Y_batch = Y_test[start : start + test_batch_size].to(device)
            batch_size = X_batch.size(0)

            outputs, attention_weights_list = model(X_batch, return_attention=True)
            predicted_Y = outputs[:, -1, :]

            loss = loss_function(predicted_Y, Y_batch).sum(axis=-1).mean()
            # Counters are updated before any `continue` below so that skipping
            # the attention bookkeeping can never distort the loss average.
            total_loss += loss.item() * batch_size
            num_samples += batch_size

            if not log_attention:
                continue

            if attention_weights_list is None or len(attention_weights_list) == 0:
                log_attention = False
                print("Warning: attention_weights_list is empty. Disabling attention logging.")
                continue

            if layer_attention_totals is None:
                expected_heads = len(attention_weights_list[0])
                if expected_heads == 0:
                    log_attention = False
                    print("Warning: Layer 0 has no attention heads. Disabling attention logging.")
                    continue
                layer_attention_totals = [0.0] * len(attention_weights_list)
                # Each layer is averaged over its own head count; report deviations once.
                for layer_idx, layer_heads in enumerate(attention_weights_list):
                    current_num_heads = len(layer_heads)
                    if current_num_heads != expected_heads:
                        print(
                            f"Warning: Layer {layer_idx} has {current_num_heads} heads, expected {expected_heads}. Using {current_num_heads} for averaging."
                        )

            if not show_relevant_token and burst_positions is None:
                burst_positions = torch.as_tensor(
                    loader.task.burst_indices, device=device, dtype=torch.long
                )
            batch_idx, key_idx = _relevant_key_indices(
                X_batch, show_relevant_token, burst_positions
            )
            if batch_idx.numel() == 0:
                continue  # No relevant tokens: this batch adds zero attention.

            for layer_idx, layer_heads in enumerate(attention_weights_list):
                if layer_idx >= len(layer_attention_totals):
                    # Only possible if the model returns more layers than on the first batch.
                    print(f"Warning: Layer index {layer_idx} out of bounds for accumulator.")
                    continue
                layer_attention_totals[layer_idx] += _relevant_attention_mass(
                    layer_heads, batch_idx, key_idx
                )

    avg_test_loss = total_loss / num_samples if num_samples > 0 else 0.0

    if wandb_writer is not None:
        log_payload = {f"Eval loss/MSE ({prefix})": avg_test_loss}
        if log_attention and layer_attention_totals is not None and num_samples > 0:
            for layer_idx, layer_total in enumerate(layer_attention_totals):
                log_payload[
                    f"Eval Attention/Layer_{layer_idx}_Relevant_Token_Attention ({prefix})"
                ] = layer_total / num_samples
        wandb_writer.log(log_payload, step=iter)

    return avg_test_loss
