import torch
import wandb
import torch.nn as nn

from models.transformer import MultiHeadSelfAttention


class LastTokenCrossEntropyLoss(nn.Module):
    """Cross-entropy loss evaluated only at the final position of each sequence.

    The wrapped criterion is stored as ``self.loss_fn`` and is a default
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        """Return the cross-entropy between the last-step logits and targets.

        Args:
            outputs: 3-D tensor; the last index along dim 1 is selected.
                Presumably laid out as (batch, seq_len, num_classes) -- confirm
                against the model.
            targets: Either a 2-D tensor, whose last column is used, or a
                tensor of any other rank, which is passed through unchanged.

        Returns:
            Whatever ``self.loss_fn`` returns for the selected logits/targets.
        """
        final_logits = outputs[:, -1, :]

        # Only 2-D targets are treated as per-position sequences; anything else
        # is assumed to already hold one label per sample.
        if targets.ndim == 2:
            final_targets = targets[:, -1]
        else:
            final_targets = targets

        return self.loss_fn(final_logits, final_targets)


def get_data(param):
    """Return the raw ``.data`` of ``param`` or, failing that, of ``param.weight``.

    Args:
        param: Any object. Objects exposing ``data`` directly are preferred;
            otherwise an object exposing ``weight.data`` is accepted.

    Returns:
        ``param.data`` if that attribute exists, else ``param.weight.data``.

    Raises:
        ValueError: If neither ``param.data`` nor ``param.weight.data`` exists.
    """
    if hasattr(param, "data"):
        return param.data

    # Fall back to a wrapped weight; a missing or data-less weight is an error.
    weight = getattr(param, "weight", None)
    if not hasattr(weight, "data"):
        raise ValueError("Parameter does not have accessible data.")
    return weight.data


def _log_param_metrics(param, param_name_prefix, metrics, wandb_writer):
    """Helper function to log norm, histogram, and gradient norm for a parameter."""
    if isinstance(param, nn.Linear):
        weight_data = param.weight.data
        weight_grad = param.weight.grad
    elif isinstance(param, nn.Parameter):
        weight_data = param.data
        weight_grad = param.grad
    else:
        # Optionally handle other types or raise an error
        print(f"Warning: Unsupported parameter type {type(param)} for {param_name_prefix}")
        return

    # Log Norm and Histogram
    norm = weight_data.norm(2).item()
    metrics[f"Norms/{param_name_prefix}_norm"] = norm
    # Check if wandb_writer is available before creating histogram
    if wandb_writer:
        metrics[f"Norms/{param_name_prefix}_hist"] = wandb.Histogram(
            weight_data.detach().cpu().numpy()
        )
    else:
        metrics[f"Norms/{param_name_prefix}_hist"] = None  # Or skip if no writer

    # Log Gradient Norm
    grad_norm = None
    if weight_grad is not None:
        grad_norm = weight_grad.data.norm(2).item()
    metrics[f"Grad_Norms/{param_name_prefix}_grad_norm"] = grad_norm


def log_model_matrix_norms(model, wandb_writer, iter):
    """
    Log L2 norms, histograms, and gradient norms for the model's weight matrices.

    Every ``MultiHeadSelfAttention`` entry of ``model.layers`` contributes the
    matrices selected by its flags: ``W_Q``/``W_K`` when ``qk_param`` is set
    (otherwise ``W_A``), ``W_V`` when ``enable_value`` is set, ``W_out`` when
    ``enable_Wout`` is set, and ``W_F1``/``W_F2`` when ``enable_mlp`` is set.
    Layers of any other type are reported with a printed warning and skipped.
    Afterwards ``model.output_layer`` is logged if present; failing that,
    ``model.W_1`` and ``model.W_2`` are logged when both are present.

    Args:
        model: Object exposing an iterable ``layers`` attribute.
        wandb_writer: Logger exposing ``log(dict, step=...)``. When falsy the
            metrics are still computed but nothing is sent.
        iter: Step value forwarded to ``wandb_writer.log``.
    """
    metrics = {}

    def log_if_present(owner, attr_name, metric_name):
        # Presence is decided with ``is not None`` instead of truthiness:
        # bool() of a multi-element tensor (e.g. an nn.Parameter matrix) raises
        # RuntimeError, and an empty container module would evaluate to False.
        candidate = getattr(owner, attr_name, None)
        if candidate is not None:
            _log_param_metrics(candidate, metric_name, metrics, wandb_writer)

    for layer_idx, layer in enumerate(model.layers):
        if not isinstance(layer, MultiHeadSelfAttention):
            print(
                f"Warning: Unknown layer type {type(layer)} at index {layer_idx}. Skipping norm logging for this layer."
            )
            continue

        # Attention score parameters: separate Q/K matrices or one merged W_A.
        if layer.qk_param:
            log_if_present(layer, "W_Q", f"layer_{layer_idx}_W_Q")
            log_if_present(layer, "W_K", f"layer_{layer_idx}_W_K")
        else:
            log_if_present(layer, "W_A", f"layer_{layer_idx}_W_A")

        # Optional sub-blocks, each gated by its own enable flag.
        if layer.enable_value:
            log_if_present(layer, "W_V", f"layer_{layer_idx}_W_V")
        if layer.enable_Wout:
            log_if_present(layer, "W_out", f"layer_{layer_idx}_W_out")
        if layer.enable_mlp:
            log_if_present(layer, "W_F1", f"layer_{layer_idx}_W_F1")
            log_if_present(layer, "W_F2", f"layer_{layer_idx}_W_F2")

    # Output head: a single ``output_layer`` takes precedence over the
    # two-matrix ``W_1``/``W_2`` form, which is only logged when both exist.
    output_layer = getattr(model, "output_layer", None)
    w_1 = getattr(model, "W_1", None)
    w_2 = getattr(model, "W_2", None)
    if output_layer is not None:
        _log_param_metrics(output_layer, "output_layer", metrics, wandb_writer)
    elif w_1 is not None and w_2 is not None:
        _log_param_metrics(w_1, "mlp_W_1", metrics, wandb_writer)
        _log_param_metrics(w_2, "mlp_W_2", metrics, wandb_writer)

    # Send everything in one call so all entries share the same step.
    if wandb_writer and metrics:
        wandb_writer.log(metrics, step=iter)


def evaluate_model(
    model, train_loader, loss_function, device, iter, wandb_writer, out_dir, plot_interval=200000
):
    """Evaluate ``model`` on ``train_loader.test_tensor`` and log the results.

    The sequence tensor is split into inputs (all but the last column) and
    targets (all but the first column). Loss and accuracy are computed on the
    final position only and logged under ``Eval Loss/Transformer`` and
    ``Eval Accuracy/Transformer``. Parameter norms are logged as well, and
    token-level attention is logged when
    ``wandb_writer.config.run["log_token_attention"]`` is truthy.

    All results are consumed only by the logger, so the function returns
    immediately when ``wandb_writer`` is ``None``. The model's train/eval mode
    on entry is restored on exit, including when an exception is raised.

    Args:
        model: Callable as ``model(inputs, return_attention=True)`` returning
            ``(outputs, attention_list)``; must provide ``eval()``.
        train_loader: Object exposing the evaluation sequences as
            ``test_tensor`` (2-D, indexed as ``[:, :-1]`` / ``[:, 1:]``).
        loss_function: Callable applied to the reshaped last-step logits and
            targets.
        device: Device the inputs and targets are moved to.
        iter: Step value used for every logging call.
        wandb_writer: Logger exposing ``log`` and ``config.run``; ``None``
            disables evaluation entirely.
        out_dir: Unused by this function.
        plot_interval: Unused by this function.

    Returns:
        None.
    """
    if wandb_writer is None:
        # Nothing would observe the results, so skip the forward pass.
        return

    # Remember the caller's mode so evaluation does not leave the model
    # stuck in eval mode (which would e.g. disable dropout afterwards).
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            eval_sequence = train_loader.test_tensor

            # Next-token setup: position t of the input predicts token t + 1.
            inputs = eval_sequence[:, :-1].long().to(device)
            targets = eval_sequence[:, 1:].long().to(device)
            outputs, att_l = model(inputs, return_attention=True)

            last_logits = outputs[:, -1, :]
            last_targets = targets[:, -1]

            # Logits go (B, C) -> (B, 1, C) -> (B, C, 1) and targets (B,) -> (B, 1);
            # presumably the class-dim-second layout the criterion expects -- confirm.
            loss = loss_function(
                last_logits.unsqueeze(1).transpose(1, 2), last_targets.unsqueeze(1)
            )
            wandb_writer.log({"Eval Loss/Transformer": loss}, step=iter)

            # Accuracy of the arg-max prediction at the final position.
            preds = torch.argmax(last_logits, dim=-1)
            correct_predictions = (preds == last_targets).float()
            accuracy = correct_predictions.sum() / len(correct_predictions)
            wandb_writer.log({"Eval Accuracy/Transformer": accuracy.item()}, step=iter)

            # Log model parameter norms and histograms
            log_model_matrix_norms(model, wandb_writer, iter)

            # Log attention to relevant tokens if enabled in config
            if wandb_writer.config.run["log_token_attention"]:
                log_token_attention(att_l, eval_sequence.cpu(), iter, wandb_writer)
    finally:
        if was_training:
            model.train()

    return


def identify_token_positions(sequences):
    """
    Locate the query, the target, and the context keys that match the query.

    The second-to-last column of ``sequences`` is read as the query token and
    the last column as the target token. Even indices ``0, 2, 4, ...`` strictly
    before the query column are scanned for tokens equal to the query; each hit
    is a key position and the index right after it is its value position.

    Args:
        sequences: 2-D tensor, unpacked as ``(batch_size, seq_len)``.

    Returns:
        dict with, in this order:
            - query_positions: Tensor of length batch_size filled with seq_len - 2
            - query_tokens: Tensor holding column -2 of ``sequences``
            - target_tokens: Tensor holding column -1 of ``sequences``
            - key_positions: Per-sequence list of matching even indices
            - value_positions: Per-sequence list of ``key position + 1``
    """
    batch_size, seq_len = sequences.shape
    query_tokens = sequences[:, -2]

    # One comparison for the whole batch: column j of the strided view is
    # original index 2 * j, so hits are mapped back by doubling the slot index.
    candidate_keys = sequences[:, 0 : seq_len - 2 : 2]
    hit_rows = (candidate_keys == query_tokens.unsqueeze(1)).tolist()

    key_positions = [[2 * slot for slot, hit in enumerate(row) if hit] for row in hit_rows]
    value_positions = [[key + 1 for key in row] for row in key_positions]

    return {
        "query_positions": torch.full((batch_size,), seq_len - 2, device=sequences.device),
        "query_tokens": query_tokens,
        "target_tokens": sequences[:, -1],
        "key_positions": key_positions,
        "value_positions": value_positions,
    }


def log_token_attention(att_l, sequences, iter, wandb_writer):
    """
    Log per-layer, per-head attention on the tokens that match the query.

    For every attention matrix two batch-averaged scalars are produced:
      * ``attention/layer{L}_head{H}_key_tokens``: mean of the entries at
        ``[key + 1, key]`` for each matching key position.
      * ``attention/layer{L}_head{H}_value_tokens``: mean of the entries at
        ``[query_position, value]`` for each matching value position.
    Sequences with no matching key are left out of both averages, and a metric
    is omitted altogether when no sequence in the batch has a match.

    Args:
        att_l: Nested iterable ``att_l[layer][head]`` of attention tensors
            whose first axis is the batch. Presumably each is
            ``[batch, seq_len, seq_len]`` with rows as the attending position
            -- confirm against the model.
        sequences: Tensor of shape [batch_size, seq_len] containing the
            sequence data, handed to ``identify_token_positions``.
        iter: Step value forwarded to ``wandb_writer.log``.
        wandb_writer: Logger exposing ``log``; when falsy nothing is computed.
    """
    if not wandb_writer:
        return

    token_positions = identify_token_positions(sequences)
    key_positions = token_positions["key_positions"]
    value_positions = token_positions["value_positions"]
    query_positions = token_positions["query_positions"]

    metrics = {}
    for layer_idx, layer_att in enumerate(att_l):
        for head_idx, att_matrix in enumerate(layer_att):
            samples = range(att_matrix.shape[0])

            # Row = position right after the key, column = the key itself.
            key_scores = [
                torch.mean(
                    att_matrix[b, [k + 1 for k in key_positions[b]], key_positions[b]]
                ).item()
                for b in samples
                if key_positions[b]
            ]
            if key_scores:
                metrics[f"attention/layer{layer_idx}_head{head_idx}_key_tokens"] = sum(
                    key_scores
                ) / len(key_scores)

            # Row = query position, columns = the matching value positions.
            value_scores = [
                torch.mean(
                    att_matrix[b, query_positions[b].item(), value_positions[b]]
                ).item()
                for b in samples
                if value_positions[b]
            ]
            if value_scores:
                metrics[f"attention/layer{layer_idx}_head{head_idx}_value_tokens"] = sum(
                    value_scores
                ) / len(value_scores)

    # A single call keeps every attention metric on the same step.
    if metrics:
        wandb_writer.log(metrics, step=iter)
