"""
Gradient Monitoring Mechanisms for GDO-DPO

Implements Srep and Adisc monitors as described in Section 4.3.
"""

import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Tuple, Optional
from collections import deque


class GradientMonitor:
    """
    Monitors layer-wise gradient statistics to guide curriculum progression.

    Implements the representation stability monitor (Srep) and
    discrimination readiness monitor (Adisc) from Section 4.3.

    Layers are split at ``layer_mid``: indices ``0..layer_mid`` (inclusive)
    form the representation group and ``layer_mid+1..num_layers-1`` form the
    discrimination group. Layer indices are parsed from parameter names by
    ``_extract_layer_index``; parameters whose names carry no layer index
    are ignored by the Srep computation.
    """

    def __init__(
        self,
        num_layers: int,
        layer_mid: int,
        ema_decay: float = 0.9,
        device: str = "cuda",
    ):
        """
        Args:
            num_layers: Total number of layers in the model
            layer_mid: Boundary layer (Lmid) between representation and
                discrimination; it belongs to the representation group
            ema_decay: Decay factor γ for exponential moving average; the
                previous EMA is weighted by γ and each new ratio by 1 - γ
            device: Fallback device for Adisc inputs, used only when the
                model exposes neither a ``device`` attribute nor parameters
        """
        self.num_layers = num_layers
        self.layer_mid = layer_mid
        self.ema_decay = ema_decay
        self.device = device

        # Representation layers: 0 to Lmid (inclusive)
        self.repr_layers = list(range(layer_mid + 1))
        # Discrimination layers: Lmid+1 to num_layers-1
        self.disc_layers = list(range(layer_mid + 1, num_layers))

        # EMA of the Srep ratio; None until the first gradient-bearing update
        self.srep_ema: Optional[float] = None

        # Bounded histories (deque drops the oldest entry beyond 1000)
        self.srep_history = deque(maxlen=1000)
        self.adisc_history = deque(maxlen=1000)

    @staticmethod
    def _grad_sq_norm(grad: torch.Tensor) -> float:
        """
        Return the squared L2 norm of ``grad`` as a Python float.

        Half-precision gradients are upcast to float32 first: squaring an
        fp16 value above ~256 overflows, and low-precision reductions lose
        accuracy.
        """
        g = grad.detach()
        if g.dtype in (torch.float16, torch.bfloat16):
            g = g.float()
        return g.pow(2).sum().item()

    @torch.no_grad()
    def compute_representation_stability(
        self,
        model: nn.Module,
        epsilon: float = 1e-8,
    ) -> float:
        """
        Compute representation stability metric Srep.

        Following Equation 7:
        Srep = EMA_γ(Σ_{ℓ∈L_rep} ||∇θ_ℓ L||² / Σ_{ℓ∈L_disc} ||∇θ_ℓ L||²)

        Must be called after ``backward()`` and before gradients are cleared.
        If no layer-indexed parameter currently holds a gradient, the EMA and
        history are left untouched: folding in a ratio of 0/eps = 0 would
        drag Srep toward 0 and spuriously satisfy ``should_advance_semantic``.

        Args:
            model: The model being trained
            epsilon: Small constant to prevent division by zero

        Returns:
            Updated Srep EMA (float). If no gradients were found, the current
            EMA is returned unchanged, or NaN if no EMA exists yet.
        """
        # Built per call so later reassignment of repr_layers/disc_layers is honored.
        repr_set = set(self.repr_layers)
        disc_set = set(self.disc_layers)

        repr_grad_norm = 0.0
        disc_grad_norm = 0.0
        found_grad = False

        for name, param in model.named_parameters():
            if param.grad is None:
                continue

            layer_idx = self._extract_layer_index(name)
            if layer_idx is None:
                continue

            if layer_idx in repr_set:
                repr_grad_norm += self._grad_sq_norm(param.grad)
                found_grad = True
            elif layer_idx in disc_set:
                disc_grad_norm += self._grad_sq_norm(param.grad)
                found_grad = True

        if not found_grad:
            return self.srep_ema if self.srep_ema is not None else float("nan")

        ratio = repr_grad_norm / (disc_grad_norm + epsilon)

        if self.srep_ema is None:
            self.srep_ema = ratio
        else:
            self.srep_ema = (
                self.ema_decay * self.srep_ema
                + (1 - self.ema_decay) * ratio
            )

        self.srep_history.append(self.srep_ema)
        return self.srep_ema

    def _model_device(self, model: nn.Module):
        """Return the device tokenized inputs should be moved to for ``model``."""
        # Hugging Face models expose `.device`; plain nn.Modules do not.
        device = getattr(model, "device", None)
        if device is not None:
            return device
        first_param = next(model.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device(self.device)

    @staticmethod
    def _response_log_prob(
        model: nn.Module,
        tokenizer,
        text: str,
        prompt_len: int,
        device,
        max_length: int = 1024,
    ) -> float:
        """
        Sum the model's log-probabilities of the tokens of ``text`` that come
        after its first ``prompt_len`` tokens (the response part).

        Only batch index 0 is read, so ``text`` is expected to be a single
        string tokenized into a batch of one.
        """
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(device)
        input_ids = inputs.input_ids
        logits = model(**inputs).logits

        # Logits at position i predict token i+1, so the first response token
        # (index prompt_len) is scored at position prompt_len - 1. Clamp at 0
        # so an empty prompt does not wrap around to position -1.
        start = max(prompt_len - 1, 0)
        pred_logits = logits[0, start:-1, :]
        targets = input_ids[0, start + 1:]
        if targets.numel() == 0:
            return 0.0

        # Upcast only the response slice so log_softmax runs in float32.
        log_probs = torch.log_softmax(pred_logits.float(), dim=-1)
        token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        return token_log_probs.double().sum().item()

    @torch.no_grad()
    def compute_discrimination_accuracy(
        self,
        model: nn.Module,
        tokenizer,
        validation_data: List[Dict],
        current_lambda_sem: float,
        low_uncertainty_threshold: float = 0.3,
    ) -> float:
        """
        Compute discrimination accuracy Adisc on validation set.

        Following Equation 8:
        Adisc = E_{(x,y^w,y^l)~D_val^{(λ)}} [1{log π_θ(y^w|x) > log π_θ(y^l|x)}]

        The model is switched to eval mode for scoring, and its previous
        train/eval mode is restored afterwards, including on error.

        NOTE: ``prompt_len`` comes from tokenizing the prompt alone. If the
        tokenizer merges tokens across the prompt/response boundary, the
        response split is approximate.

        Args:
            model: The model being trained
            tokenizer: Tokenizer
            validation_data: Validation dataset; each sample needs the keys
                'prompt', 'chosen', 'rejected', 'Rsem' and 'Runc'
            current_lambda_sem: Current semantic complexity threshold
            low_uncertainty_threshold: Samples with Runc <= this value count
                as "low Runc". NOTE(review): the paper describes this as the
                bottom 30%; this is an absolute cutoff, which matches only if
                Runc is rank-normalized — confirm.

        Returns:
            Adisc accuracy (float); 0.0 (not recorded in history) when no
            sample passes the filter.
        """
        # Filter validation data: Rsem ≤ λ_sem and low Runc
        filtered_data = [
            sample for sample in validation_data
            if sample['Rsem'] <= current_lambda_sem
            and sample['Runc'] <= low_uncertainty_threshold
        ]

        if not filtered_data:
            return 0.0

        device = self._model_device(model)
        was_training = model.training
        model.eval()
        try:
            correct = 0
            for sample in filtered_data:
                prompt = sample['prompt']
                prompt_len = len(tokenizer(prompt, return_tensors="pt").input_ids[0])

                chosen_log_prob = self._response_log_prob(
                    model, tokenizer, prompt + sample['chosen'], prompt_len, device
                )
                rejected_log_prob = self._response_log_prob(
                    model, tokenizer, prompt + sample['rejected'], prompt_len, device
                )

                # Check if model prefers chosen over rejected
                if chosen_log_prob > rejected_log_prob:
                    correct += 1
        finally:
            model.train(was_training)

        accuracy = correct / len(filtered_data)
        self.adisc_history.append(accuracy)
        return accuracy

    def _extract_layer_index(self, param_name: str) -> Optional[int]:
        """
        Extract layer index from parameter name.

        Returns the integer that directly follows the first dotted component
        named 'layers', 'h' or 'blocks', which covers patterns such as:
        - model.layers.{idx}.*
        - transformer.h.{idx}.*
        - model.model.layers.{idx}.*
        - *.blocks.{idx}.*

        Args:
            param_name: Parameter name

        Returns:
            Layer index or None if not found
        """
        parts = param_name.split('.')

        for i, part in enumerate(parts):
            if part in ('layers', 'h', 'blocks'):
                if i + 1 < len(parts) and parts[i + 1].isdigit():
                    return int(parts[i + 1])

        return None

    def get_layer_wise_gradient_norms(
        self,
        model: nn.Module,
    ) -> Dict[int, float]:
        """
        Get gradient L2 norms for each layer.

        Useful for analysis and visualization.

        Args:
            model: The model being trained

        Returns:
            Dictionary mapping every layer index in [0, num_layers) to the L2
            norm of its gradients (0.0 for layers with no gradients)
        """
        layer_sq_norms = {i: 0.0 for i in range(self.num_layers)}

        for name, param in model.named_parameters():
            if param.grad is None:
                continue

            layer_idx = self._extract_layer_index(name)
            if layer_idx is not None and layer_idx < self.num_layers:
                layer_sq_norms[layer_idx] += self._grad_sq_norm(param.grad)

        # Take square root of the summed squares to get the L2 norm
        return {k: float(np.sqrt(v)) for k, v in layer_sq_norms.items()}

    def should_advance_semantic(
        self,
        tau_stable: float = 1.2,
    ) -> bool:
        """
        Check if semantic complexity should be advanced.

        Condition: Srep < τ_stable

        Args:
            tau_stable: Stability threshold

        Returns:
            True if should advance; False if no Srep has been computed yet
        """
        if self.srep_ema is None:
            return False
        return self.srep_ema < tau_stable

    def should_advance_uncertainty(
        self,
        current_adisc: float,
        tau_acc: float = 0.65,
    ) -> bool:
        """
        Check if preference uncertainty should be advanced.

        Condition: Adisc > τ_acc

        Args:
            current_adisc: Current discrimination accuracy
            tau_acc: Accuracy threshold

        Returns:
            True if should advance
        """
        return current_adisc > tau_acc

    def get_statistics(self) -> Dict[str, float]:
        """
        Get current monitoring statistics.

        Returns:
            Dictionary with the current Srep EMA, the most recent Adisc and the
            standard deviation of the recorded Srep history (each 0.0 when
            nothing has been recorded yet)
        """
        return {
            'Srep': self.srep_ema if self.srep_ema is not None else 0.0,
            'Adisc': self.adisc_history[-1] if len(self.adisc_history) > 0 else 0.0,
            'Srep_std': np.std(list(self.srep_history)) if len(self.srep_history) > 0 else 0.0,
        }
