"""
Gradient-Guided Disentangled DPO (GDO-DPO) Trainer

Implements Algorithm 1 from the paper.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from transformers import Trainer, TrainingArguments
from trl import DPOTrainer
import wandb

from .gradient_monitors import GradientMonitor


@dataclass
class GDODPOConfig:
    """Configuration for GDO-DPO training.

    Groups the hyperparameters read by ``GDODPOTrainer``: the thresholds
    that gate curriculum advancement, the initial pace/step sizes, the
    gradient-monitor settings, the step-size acceleration rules, the
    stagnation safeguard for λ_unc, and the DPO temperature.

    Field order is part of the constructor interface (positional
    construction), so new fields should be appended, not inserted.
    """

    # Curriculum parameters
    # Thresholds handed to the gradient monitor when deciding whether to
    # advance λ_sem (tau_stable) and λ_unc (tau_acc).
    tau_stable: float = 1.2  # Representation stability threshold
    tau_acc: float = 0.65    # Discrimination accuracy threshold
    # The trainer uses each delta twice: as the starting value of the
    # corresponding pace parameter λ and as the initial increment applied
    # on every advancement.
    delta_sem: float = 0.1   # Initial semantic step size
    delta_unc: float = 0.1   # Initial uncertainty step size

    # Layer configuration
    # Passed to GradientMonitor as ``layer_mid``. The default targets
    # 32-layer models; presumably it must be adjusted for other depths
    # (TODO confirm against GradientMonitor).
    layer_mid: int = 21      # Boundary between repr and disc layers (for 32-layer models)

    # Monitoring parameters
    ema_decay: float = 0.9   # EMA decay γ for Srep
    # The curriculum update runs when global_step % eval_interval == 0,
    # so this must be non-zero.
    eval_interval: int = 50  # Evaluation interval E_eval

    # Acceleration parameters (Equations 9-10)
    # When the monitored statistic clears its threshold by the given
    # margin, the matching step size is multiplied by accel_factor before
    # the pace parameter is advanced.
    accel_factor: float = 1.1
    accel_margin_stable: float = 0.8  # Accelerate if Srep < 0.8 * tau_stable
    accel_margin_acc: float = 1.1     # Accelerate if Adisc > 1.1 * tau_acc

    # Minimum growth rate to prevent stagnation
    # Once λ_sem has reached 1.0, λ_unc is bumped by min_unc_growth if it
    # has not been updated for min_unc_growth_interval optimizer steps.
    min_unc_growth: float = 0.01
    min_unc_growth_interval: int = 100

    # DPO parameters
    # Forwarded to DPOTrainer as ``beta``.
    beta: float = 0.1  # DPO temperature


class GDODPOTrainer(DPOTrainer):
    """
    GDO-DPO Trainer extending DPOTrainer with curriculum learning.

    Implements Algorithm 1: Gradient-Guided Disentangled DPO.

    Two pace parameters, λ_sem and λ_unc, select the active training set:
    a sample is active when its precomputed ``Rsem`` and ``Runc`` scores
    are both at or below the current paces. Every ``eval_interval``
    optimizer steps the gradient monitor is consulted and the paces are
    raised (capped at 1.0) when it reports that the model is ready.

    NOTE(review): the active set is materialised in
    ``get_train_dataloader``. The stock ``transformers.Trainer`` appears to
    build that loader once per ``train()`` call, so later λ updates may not
    reach the loader without re-creating it — confirm against the
    installed transformers version.
    """

    def __init__(
        self,
        model,
        ref_model,
        args: TrainingArguments,
        train_dataset,
        eval_dataset,
        tokenizer,
        gdo_config: GDODPOConfig,
        **kwargs
    ):
        """
        Args:
            model: Policy model π_θ; must expose ``config.num_hidden_layers``
            ref_model: Reference model π_ref
            args: Training arguments
            train_dataset: Training data with precomputed Rsem, Runc
            eval_dataset: Validation data, also used to measure Adisc
            tokenizer: Tokenizer
            gdo_config: GDO-DPO configuration
            **kwargs: Additional arguments for DPOTrainer
        """
        super().__init__(
            model=model,
            ref_model=ref_model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            beta=gdo_config.beta,
            **kwargs
        )

        self.gdo_config = gdo_config

        # Initialize pace parameters (Line 3 in Algorithm 1)
        self.lambda_sem = gdo_config.delta_sem
        self.lambda_unc = gdo_config.delta_unc

        # Initialize step sizes (these grow when acceleration kicks in)
        self.delta_sem = gdo_config.delta_sem
        self.delta_unc = gdo_config.delta_unc

        # Initialize gradient monitor
        num_layers = model.config.num_hidden_layers
        self.gradient_monitor = GradientMonitor(
            num_layers=num_layers,
            layer_mid=gdo_config.layer_mid,
            ema_decay=gdo_config.ema_decay,
            device=args.device.type if hasattr(args.device, 'type') else 'cuda'
        )

        # Track curriculum progress. 'Srep' gets one entry per training
        # step; the other series get one entry per curriculum update.
        self.curriculum_history = {
            'lambda_sem': [self.lambda_sem],
            'lambda_unc': [self.lambda_unc],
            'Srep': [],
            'Adisc': [],
        }

        # For minimum growth rate
        self.last_unc_update_step = 0

        # Global step of the most recent curriculum update. Guards against
        # advancing the curriculum twice for the same optimizer step.
        self._last_curriculum_step = -1

    def _get_active_dataset(self) -> List[int]:
        """
        Construct active training set D_t (Line 5 in Algorithm 1).

        Returns indices of samples where:
        Rsem(x) ≤ λ_sem AND Runc(x, y^w, y^l) ≤ λ_unc

        This is a full pass over ``self.train_dataset``; callers should
        reuse the result instead of calling repeatedly.

        Returns:
            List of active sample indices
        """
        lambda_sem = self.lambda_sem
        lambda_unc = self.lambda_unc
        return [
            idx
            for idx, sample in enumerate(self.train_dataset)
            if sample['Rsem'] <= lambda_sem and sample['Runc'] <= lambda_unc
        ]

    def get_train_dataloader(self):
        """
        Override to implement active set filtering.

        Returns:
            A shuffled DataLoader over the currently active samples (or
            over the full training set when every sample is active).

        Raises:
            ValueError: If no training sample satisfies the current
                λ_sem / λ_unc thresholds.
        """
        active_indices = self._get_active_dataset()

        # An empty loader would either fail inside the sampler with an
        # unrelated-looking message or silently train on nothing, so fail
        # early with the thresholds that caused it.
        if not active_indices:
            raise ValueError(
                "GDO-DPO active training set is empty: no sample satisfies "
                f"Rsem <= {self.lambda_sem} and Runc <= {self.lambda_unc}. "
                "Check the precomputed scores or raise delta_sem/delta_unc."
            )

        # Only wrap in a Subset when something is actually filtered out
        if len(active_indices) < len(self.train_dataset):
            active_dataset = torch.utils.data.Subset(
                self.train_dataset,
                active_indices
            )
        else:
            active_dataset = self.train_dataset

        return torch.utils.data.DataLoader(
            active_dataset,
            batch_size=self.args.per_device_train_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
        )

    def training_step(self, model, inputs, *args, **kwargs):
        """
        Override training step to include gradient monitoring.

        Args:
            model: Model being trained
            inputs: Batch passed to the parent training step
            *args, **kwargs: Forwarded unchanged to the parent. Newer
                transformers releases pass extra arguments here (e.g.
                ``num_items_in_batch``); accepting them keeps this
                override compatible across versions.

        Returns:
            The loss returned by the parent ``training_step``
        """
        # Standard DPO training step (Line 6-9 in Algorithm 1)
        loss = super().training_step(model, inputs, *args, **kwargs)

        # Compute layer-wise gradient statistics (Line 7-8).
        # Presumably relies on the parent step having populated the
        # parameters' gradients — TODO confirm for the configured backend.
        with torch.no_grad():
            srep = self.gradient_monitor.compute_representation_stability(model)

        self.curriculum_history['Srep'].append(srep)

        return loss

    def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
        """
        Override to add curriculum updates.

        Args:
            tr_loss: Running training loss, forwarded to the parent
            *args, **kwargs: Remaining arguments of the parent hook
                (model, trial, epoch, ignore_keys_for_eval and, on newer
                transformers releases, additional ones such as grad_norm).
                Forwarded unchanged so the override does not pin one
                particular signature.
        """
        # Call parent evaluation
        super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

        # Curriculum update (Line 10-18 in Algorithm 1).
        # The Trainer invokes this hook after optimizer steps and again at
        # the end of each epoch, so the same global_step can be seen
        # twice; update at most once per step.
        step = self.state.global_step
        if (step % self.gdo_config.eval_interval == 0
                and step != self._last_curriculum_step):
            self._update_curriculum()
            self._last_curriculum_step = step

    def _update_curriculum(self):
        """
        Update curriculum parameters based on gradient statistics.

        Implements Lines 10-18 of Algorithm 1: measures Adisc on the
        validation set, advances λ_sem and/or λ_unc (each capped at 1.0)
        when the gradient monitor allows it, applies the minimum-growth
        safeguard for λ_unc, then records and reports the new state.
        """
        cfg = self.gdo_config
        step = self.state.global_step

        # Compute discrimination accuracy on validation set (Line 11)
        adisc = self.gradient_monitor.compute_discrimination_accuracy(
            model=self.model,
            tokenizer=self.tokenizer,
            validation_data=self.eval_dataset,
            current_lambda_sem=self.lambda_sem,
        )

        self.curriculum_history['Adisc'].append(adisc)

        # Get current Srep
        stats = self.gradient_monitor.get_statistics()
        srep = stats['Srep']

        # Check if should advance semantic complexity (Line 12-14)
        if self.gradient_monitor.should_advance_semantic(
            tau_stable=cfg.tau_stable
        ):
            # Update step size with acceleration (Equation 9)
            if srep < cfg.accel_margin_stable * cfg.tau_stable:
                self.delta_sem *= cfg.accel_factor

            self.lambda_sem = min(1.0, self.lambda_sem + self.delta_sem)

        # Check if should advance preference uncertainty (Line 15-17)
        if self.gradient_monitor.should_advance_uncertainty(
            current_adisc=adisc,
            tau_acc=cfg.tau_acc
        ):
            # Update step size with acceleration (Equation 10)
            if adisc > cfg.accel_margin_acc * cfg.tau_acc:
                self.delta_unc *= cfg.accel_factor

            self.lambda_unc = min(1.0, self.lambda_unc + self.delta_unc)
            self.last_unc_update_step = step

        # Minimum growth rate to prevent stagnation: once the semantic
        # pace is saturated, nudge λ_unc if it has been stuck too long.
        steps_since_update = step - self.last_unc_update_step
        if (steps_since_update >= cfg.min_unc_growth_interval and
                self.lambda_sem >= 1.0):
            self.lambda_unc = min(
                1.0,
                self.lambda_unc + cfg.min_unc_growth
            )
            self.last_unc_update_step = step

        # Record history
        self.curriculum_history['lambda_sem'].append(self.lambda_sem)
        self.curriculum_history['lambda_unc'].append(self.lambda_unc)

        # Counting active samples scans the whole training set, so do it
        # once and reuse the result for both wandb and the console.
        num_active = len(self._get_active_dataset())

        # Log to wandb
        if wandb.run is not None:
            wandb.log({
                'curriculum/lambda_sem': self.lambda_sem,
                'curriculum/lambda_unc': self.lambda_unc,
                'curriculum/Srep': srep,
                'curriculum/Adisc': adisc,
                'curriculum/active_samples': num_active,
                'curriculum/delta_sem': self.delta_sem,
                'curriculum/delta_unc': self.delta_unc,
            }, step=step)

        # Print curriculum status
        print(f"\n[Step {step}] Curriculum Status:")
        print(f"  λ_sem: {self.lambda_sem:.3f}, λ_unc: {self.lambda_unc:.3f}")
        print(f"  S_rep: {srep:.3f}, A_disc: {adisc:.3f}")
        print(f"  Active samples: {num_active}/{len(self.train_dataset)}")

    def save_curriculum_history(self, save_path: str):
        """
        Save curriculum progression history.

        Each series in ``curriculum_history`` is stored as its own array in
        a single ``np.savez`` archive.

        Args:
            save_path: Path to save history. ``np.savez`` appends ``.npz``
                when the given filename does not already end with it, so
                the file on disk may differ from the path printed below.
        """
        np.savez(
            save_path,
            **{k: np.array(v) for k, v in self.curriculum_history.items()}
        )
        print(f"Saved curriculum history to {save_path}")
