"""
Gradient Analysis for Layer-wise Localization

Implements gradient analysis to verify the layer-wise gradient localization
phenomenon (Figure 1 in the paper).
"""

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm


class GradientAnalyzer:
    """
    Analyze layer-wise gradient patterns stratified by difficulty dimensions.

    Reproduces Figure 1 and related analysis from Section 3.4.

    A parameter is attributed to a layer by ``_extract_layer_index``. That
    method looks for a ``layers.<i>``, ``h.<i>`` or ``blocks.<i>`` segment in
    the parameter name. Parameters outside such a container are ignored, as
    are parameters whose index is ``>= num_layers``.
    """

    # Each entry: (result key, display name, rank field in a dataset sample,
    # whether the stratum is the top tercile (True) or the bottom one (False)).
    _STRATA = (
        ('high_csem', 'high Csem', 'Rsem', True),
        ('low_csem', 'low Csem', 'Rsem', False),
        ('high_upref', 'high Upref', 'Runc', True),
        ('low_upref', 'low Upref', 'Runc', False),
    )

    def __init__(
        self,
        model: nn.Module,
        num_layers: int,
        device: str = "cuda"
    ):
        """
        Args:
            model: The model to analyze
            num_layers: Total number of layers
            device: Device for computations
        """
        self.model = model
        self.num_layers = num_layers
        self.device = device

    # ------------------------------------------------------------------ #
    # Gradient computation
    # ------------------------------------------------------------------ #

    def compute_layer_wise_gradients(
        self,
        samples: List[Dict],
        tokenizer,
        beta: float = 0.1,
        max_length: int = 1024,
    ) -> Dict[int, List[float]]:
        """
        Compute layer-wise gradient norms for a set of samples.

        For every sample, a reference-free DPO-style loss is back-propagated.
        The L2 norm of the gradient over all parameters of each layer
        (sqrt of the sum of squared per-tensor norms) is then recorded, giving
        one value per layer per sample.

        A sample is skipped (and counted) when its chosen or rejected
        response has no tokens after the prompt once truncated to
        ``max_length``.

        The model's train/eval mode is restored on exit, and its gradients
        are cleared.

        Args:
            samples: List of preference pairs with 'prompt', 'chosen' and
                'rejected' text fields
            tokenizer: Tokenizer (called like a HuggingFace tokenizer)
            beta: DPO temperature parameter
            max_length: Truncation length for prompt+response token sequences

        Returns:
            Dictionary mapping layer index to a list of gradient norms, one
            entry per processed sample that produced gradients for that layer
        """
        layer_gradients: Dict[int, List[float]] = {i: [] for i in range(self.num_layers)}
        skipped = 0

        was_training = self.model.training
        self.model.train()
        try:
            for sample in tqdm(samples, desc="Computing gradients"):
                self.model.zero_grad()
                with torch.enable_grad():
                    loss = self._dpo_loss(sample, tokenizer, beta, max_length)
                    if loss is None:
                        skipped += 1
                        continue
                    loss.backward()

                for layer_idx, grad_norm in self._layer_grad_norms().items():
                    layer_gradients[layer_idx].append(grad_norm)
        finally:
            # Do not leave the last sample's gradients or a changed mode behind.
            self.model.zero_grad()
            self.model.train(was_training)

        if skipped:
            print(f"Skipped {skipped} sample(s) with no response tokens after truncation.")

        return layer_gradients

    def _dpo_loss(
        self,
        sample: Dict,
        tokenizer,
        beta: float,
        max_length: int = 1024,
    ) -> Optional[torch.Tensor]:
        """Return the reference-free DPO-style loss for one pair, or None if a response is empty."""
        prompt = sample['prompt']
        # Approximation: assumes tokenizing prompt+response keeps the prompt's
        # tokens as a prefix (tokenizer merges across the boundary may shift it).
        prompt_len = len(tokenizer(prompt, return_tensors="pt").input_ids[0])

        chosen_log_prob = self._sequence_log_prob(
            prompt + sample['chosen'], prompt_len, tokenizer, max_length
        )
        if chosen_log_prob is None:
            return None
        rejected_log_prob = self._sequence_log_prob(
            prompt + sample['rejected'], prompt_len, tokenizer, max_length
        )
        if rejected_log_prob is None:
            return None

        margin = chosen_log_prob - rejected_log_prob
        # logsigmoid is the numerically stable form of log(sigmoid(x)).
        return -nn.functional.logsigmoid(beta * margin)

    def _sequence_log_prob(
        self,
        text: str,
        prompt_len: int,
        tokenizer,
        max_length: int = 1024,
    ) -> Optional[torch.Tensor]:
        """Tokenize ``text``, run the model, and return the mean response-token log-prob."""
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length
        ).to(self.device)
        logits = self.model(**inputs).logits
        return self._response_log_prob(logits, inputs.input_ids, prompt_len)

    @staticmethod
    def _response_log_prob(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        prompt_len: int,
    ) -> Optional[torch.Tensor]:
        """
        Mean log-probability of the tokens after the prompt (batch element 0).

        Logits at position t are scored against the token at position t + 1.
        Returns None when no token follows the prompt, e.g. because
        truncation removed the whole response.
        """
        seq_len = min(input_ids.shape[1], logits.shape[1])
        # Position prompt_len - 1 predicts the first response token. Clamp at 0
        # so an empty prompt does not wrap around to index -1.
        start = max(prompt_len - 1, 0)
        if start >= seq_len - 1:
            return None

        pred_logits = logits[0, start:seq_len - 1, :]
        targets = input_ids[0, start + 1:seq_len]
        token_log_probs = torch.log_softmax(pred_logits, dim=-1).gather(
            -1, targets.unsqueeze(-1)
        ).squeeze(-1)
        return token_log_probs.mean()

    def _layer_grad_norms(self) -> Dict[int, float]:
        """Return the gradient L2 norm of each layer that currently holds gradients."""
        squared_sums: Dict[int, float] = {}
        for name, param in self.model.named_parameters():
            if param.grad is None:
                continue
            layer_idx = self._extract_layer_index(name)
            if layer_idx is None or layer_idx >= self.num_layers:
                continue
            # float() avoids fp16 overflow when a large gradient norm is computed.
            grad_norm = param.grad.detach().float().norm().item()
            squared_sums[layer_idx] = squared_sums.get(layer_idx, 0.0) + grad_norm ** 2
        return {idx: total ** 0.5 for idx, total in squared_sums.items()}

    # ------------------------------------------------------------------ #
    # Stratified analysis
    # ------------------------------------------------------------------ #

    def analyze_by_difficulty(
        self,
        dataset: List[Dict],
        tokenizer,
        num_samples: int = 1000,
        tercile_threshold: float = 0.33,
        beta: float = 0.1,
        seed: Optional[int] = None,
    ) -> Dict[str, Dict[int, float]]:
        """
        Analyze gradient patterns stratified by difficulty dimensions.

        Reproduces Figure 1 from the paper.

        Strata are selected on the 'Rsem' (semantic complexity) and 'Runc'
        (preference uncertainty) fields. Values >= 1 - tercile_threshold form
        the "high" stratum; values <= tercile_threshold form the "low" one.

        Args:
            dataset: Dataset with precomputed Csem, Upref, Rsem, Runc
            tokenizer: Tokenizer
            num_samples: Number of samples per tercile
            tercile_threshold: Threshold for top/bottom terciles
            beta: DPO temperature passed to compute_layer_wise_gradients
            seed: Optional seed for subsampling. When None, the global
                ``np.random`` state is used, as before.

        Returns:
            Dictionary with keys 'high_csem', 'low_csem', 'high_upref' and
            'low_upref'. Each maps layer index to the mean gradient norm, or
            0.0 when no norms were collected for that layer.
        """
        rng = np.random if seed is None else np.random.RandomState(seed)

        # Subsample every stratum before any gradient work, in a fixed order,
        # so the random draws are reproducible for a given RNG state.
        selected = {
            key: self._subsample(
                self._select_stratum(dataset, rank_key, is_high, tercile_threshold),
                num_samples,
                rng,
            )
            for key, _, rank_key, is_high in self._STRATA
        }

        results: Dict[str, Dict[int, float]] = {}
        for key, label, _, _ in self._STRATA:
            if not selected[key]:
                print(f"Warning: no samples in the {label} stratum; per-layer means default to 0.0.")
            print(f"\nAnalyzing {label} samples...")
            grads = self.compute_layer_wise_gradients(selected[key], tokenizer, beta=beta)
            results[key] = self._mean_per_layer(grads)

        return results

    @staticmethod
    def _select_stratum(
        dataset: List[Dict],
        rank_key: str,
        is_high: bool,
        tercile_threshold: float,
    ) -> List[Dict]:
        """Return the samples in the top (is_high) or bottom tercile of ``rank_key``."""
        if is_high:
            cutoff = 1 - tercile_threshold
            return [s for s in dataset if s[rank_key] >= cutoff]
        return [s for s in dataset if s[rank_key] <= tercile_threshold]

    @staticmethod
    def _subsample(pool: List[Dict], num_samples: int, rng) -> List[Dict]:
        """Draw up to ``num_samples`` items from ``pool`` without replacement."""
        k = min(num_samples, len(pool))
        if k <= 0:
            return []
        # Sampling indices draws the same selection as rng.choice(pool_array, k)
        # without building an object array of dicts.
        indices = rng.choice(len(pool), k, replace=False)
        return [pool[i] for i in indices]

    def _mean_per_layer(self, layer_gradients: Dict[int, List[float]]) -> Dict[int, float]:
        """Average each layer's gradient norms; 0.0 for layers with none."""
        return {
            i: float(np.mean(layer_gradients[i])) if layer_gradients.get(i) else 0.0
            for i in range(self.num_layers)
        }

    # ------------------------------------------------------------------ #
    # Plotting
    # ------------------------------------------------------------------ #

    def plot_gradient_localization(
        self,
        gradient_stats: Dict[str, Dict[int, float]],
        save_path: str = "gradient_localization.pdf"
    ):
        """
        Plot layer-wise gradient localization (Figure 1).

        Each panel is normalized by the largest value across its two curves.

        Args:
            gradient_stats: Output from analyze_by_difficulty
            save_path: Path to save figure
        """
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        try:
            # (a) Stratified by semantic complexity
            self._plot_panel(
                axes[0],
                gradient_stats['high_csem'],
                gradient_stats['low_csem'],
                r'High $C_{sem}$',
                r'Low $C_{sem}$',
                '(a) Stratified by semantic complexity',
            )
            # (b) Stratified by preference uncertainty
            self._plot_panel(
                axes[1],
                gradient_stats['high_upref'],
                gradient_stats['low_upref'],
                r'High $U_{pref}$',
                r'Low $U_{pref}$',
                '(b) Stratified by preference uncertainty',
            )

            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"\nSaved gradient localization plot to {save_path}")
        finally:
            # Release the figure even if saving fails.
            plt.close(fig)

    def _plot_panel(
        self,
        ax,
        high_stats: Dict[int, float],
        low_stats: Dict[int, float],
        high_label: str,
        low_label: str,
        title: str,
    ):
        """Draw one normalized high-vs-low comparison panel on ``ax``."""
        layers = list(range(self.num_layers))
        high_norms, low_norms = self._normalize_pair(
            [high_stats[i] for i in layers],
            [low_stats[i] for i in layers],
        )

        ax.plot(layers, high_norms, 'o-', label=high_label, linewidth=2, markersize=4)
        ax.plot(layers, low_norms, 's-', label=low_label, linewidth=2, markersize=4)
        ax.set_xlabel('Layer Index', fontsize=12)
        ax.set_ylabel('Normalized\nGradient Norm', fontsize=12)
        ax.set_title(title, fontsize=12)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1.1])

    @staticmethod
    def _normalize_pair(high, low) -> Tuple[np.ndarray, np.ndarray]:
        """Scale two sequences by their joint maximum (left unchanged if it is <= 0)."""
        # dtype=float so integer inputs do not break the in-place division.
        high_arr = np.asarray(high, dtype=float)
        low_arr = np.asarray(low, dtype=float)
        # initial=0.0 keeps empty inputs (num_layers == 0) from raising.
        max_val = max(high_arr.max(initial=0.0), low_arr.max(initial=0.0))
        if max_val > 0:
            high_arr = high_arr / max_val
            low_arr = low_arr / max_val
        return high_arr, low_arr

    def _extract_layer_index(self, param_name: str) -> Optional[int]:
        """
        Extract the layer index from a parameter name.

        Returns the integer after the first 'layers', 'h' or 'blocks' path
        segment that is followed by a digit-only segment, e.g.
        'model.layers.3.mlp.weight' -> 3. Returns None if there is none.
        """
        parts = param_name.split('.')
        for i, part in enumerate(parts):
            if part in ('layers', 'h', 'blocks'):
                if i + 1 < len(parts) and parts[i + 1].isdigit():
                    return int(parts[i + 1])
        return None
