"""
Training Dynamics Analysis

Implements visualization and analysis of training dynamics
(Figures 4, 5 and related analysis).
"""

from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


class TrainingDynamicsAnalyzer:
    """
    Analyze and visualize training dynamics.

    Reproduces Figures 4, 5 from Section 5.4.

    Constructing an instance sets the global seaborn style ("whitegrid")
    and palette ("husl"). Every ``plot_*`` method creates its own figure,
    writes it to ``save_path`` and closes it again, even when plotting or
    saving raises.
    """

    # Path components that are followed by a numeric layer index in
    # dotted parameter names (e.g. "model.layers.12.mlp.weight").
    _LAYER_CONTAINER_NAMES = ('layers', 'h', 'blocks')

    def __init__(self):
        sns.set_style("whitegrid")
        sns.set_palette("husl")

    def _save_figure(self, fig, save_path: str, description: str) -> None:
        """Lay out *fig*, write it to *save_path* and report the location."""
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved {description} plot to {save_path}")

    def plot_curriculum_dynamics(
        self,
        history: Dict[str, List[float]],
        save_path: str = "curriculum_dynamics.pdf"
    ):
        """
        Plot curriculum progression during training (Figure 4).

        Shows λ_sem, λ_unc, S_rep, A_disc over training steps. Each series
        is plotted against its own step index (0 .. len-1), so the four
        series may have different lengths.

        S_rep is divided by its largest finite value when that value
        exceeds 1; otherwise it is left unchanged. A_disc is plotted
        exactly as provided. The y-axis is fixed to [0, 1.1].

        Args:
            history: Dictionary with keys 'lambda_sem', 'lambda_unc', 'Srep', 'Adisc'
            save_path: Path to save figure

        Raises:
            KeyError: If one of the four required keys is missing.
        """
        lambda_sem = history['lambda_sem']
        lambda_unc = history['lambda_unc']
        srep = np.asarray(history['Srep'], dtype=float)
        adisc = np.asarray(history['Adisc'], dtype=float)

        # Scale S_rep down only if it exceeds 1. The scale comes from finite
        # entries so a single NaN/inf does not turn the whole curve into NaN.
        finite_srep = srep[np.isfinite(srep)]
        if finite_srep.size > 0:
            srep = srep / max(float(finite_srep.max()), 1.0)

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            # Pace parameters (solid lines)
            ax.plot(range(len(lambda_sem)), lambda_sem, '-', linewidth=2.5,
                    label=r'$\lambda_{sem}$', color='#1f77b4')
            ax.plot(range(len(lambda_unc)), lambda_unc, '-', linewidth=2.5,
                    label=r'$\lambda_{unc}$', color='#ff7f0e')

            # Monitoring metrics (dashed lines). A_disc is not rescaled;
            # presumably it already lies in [0, 1] -- TODO confirm.
            ax.plot(range(len(srep)), srep, '--', linewidth=2,
                    label=r'$S_{rep}$', color='#2ca02c', alpha=0.8)
            ax.plot(range(len(adisc)), adisc, '--', linewidth=2,
                    label=r'$A_{disc}$', color='#d62728', alpha=0.8)

            ax.set_xlabel('Training Step', fontsize=13)
            ax.set_ylabel('Value', fontsize=13)
            ax.set_title('Curriculum Dynamics During Training', fontsize=14)
            ax.legend(fontsize=12, loc='best')
            ax.grid(True, alpha=0.3)
            ax.set_ylim([0, 1.1])

            self._save_figure(fig, save_path, "curriculum dynamics")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

    def plot_gradient_variance(
        self,
        variance_histories: Dict[str, List[float]],
        save_path: str = "gradient_variance.pdf"
    ):
        """
        Plot gradient variance in representation layers (Figure 5).

        Draws one line per entry of ``variance_histories``, each against its
        own step index. The names 'DPO', 'High-Upref-First' and 'GDO-DPO'
        get fixed colors; any other name is drawn in gray.

        Args:
            variance_histories: Dict with keys 'DPO', 'High-Upref-First', 'GDO-DPO'
            save_path: Path to save figure
        """
        colors = {
            'DPO': '#1f77b4',
            'High-Upref-First': '#ff7f0e',
            'GDO-DPO': '#2ca02c'
        }

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            for name, variance in variance_histories.items():
                ax.plot(range(len(variance)), variance, linewidth=2.5,
                        label=name, color=colors.get(name, 'gray'))

            ax.set_xlabel('Training Step', fontsize=13)
            ax.set_ylabel('Gradient Variance\n(Layers 0-16)', fontsize=13)
            ax.set_title('Gradient Variance in Representation Layers', fontsize=14)
            # Skip the legend when nothing was drawn; matplotlib would only
            # warn about missing labelled artists.
            if variance_histories:
                ax.legend(fontsize=12)
            ax.grid(True, alpha=0.3)

            self._save_figure(fig, save_path, "gradient variance")
        finally:
            plt.close(fig)

    def plot_mt_bench_by_category(
        self,
        dpo_scores: Dict[str, float],
        gdo_dpo_scores: Dict[str, float],
        save_path: str = "mtbench_category.pdf"
    ):
        """
        Plot MT-Bench scores by category (Figure 3).

        Draws grouped bars (DPO vs. GDO-DPO) for every category in
        ``dpo_scores``, in that dict's iteration order, and writes each
        bar's value above it. The y-axis is fixed to [7, 9].

        Args:
            dpo_scores: Category scores for DPO
            gdo_dpo_scores: Category scores for GDO-DPO
            save_path: Path to save figure

        Raises:
            KeyError: If ``gdo_dpo_scores`` lacks a category of ``dpo_scores``.
        """
        categories = list(dpo_scores)
        # Resolve all scores before creating the figure so a missing
        # category fails fast without allocating any plotting resources.
        dpo_vals = [dpo_scores[cat] for cat in categories]
        gdo_vals = [gdo_dpo_scores[cat] for cat in categories]

        x = np.arange(len(categories))
        width = 0.35

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            bars1 = ax.bar(x - width/2, dpo_vals, width, label='DPO',
                           color='#1f77b4', alpha=0.8)
            bars2 = ax.bar(x + width/2, gdo_vals, width, label='GDO-DPO',
                           color='#2ca02c', alpha=0.8)

            ax.set_xlabel('Category', fontsize=13)
            ax.set_ylabel('MT-Bench Score', fontsize=13)
            ax.set_title('MT-Bench Score by Category (Llama-3-8B)', fontsize=14)
            ax.set_xticks(x)
            ax.set_xticklabels(categories, rotation=45, ha='right')
            ax.legend(fontsize=12)
            ax.grid(True, alpha=0.3, axis='y')
            ax.set_ylim([7, 9])

            # Add value labels on bars
            for bars in (bars1, bars2):
                for bar in bars:
                    height = bar.get_height()
                    ax.text(bar.get_x() + bar.get_width()/2., height,
                            f'{height:.2f}',
                            ha='center', va='bottom', fontsize=9)

            self._save_figure(fig, save_path, "MT-Bench category")
        finally:
            plt.close(fig)

    def compute_gradient_statistics(
        self,
        model,
        data_loader,
        num_batches: int = 100,
        repr_layers: Optional[List[int]] = None
    ) -> Dict[str, float]:
        """
        Compute gradient statistics for analysis.

        For each of the first ``num_batches`` batches, runs a
        forward/backward pass and sums the squared gradient entries of all
        parameters whose layer index is in ``repr_layers``. Note that this
        per-batch value is a *squared* norm (no square root is taken). The
        returned statistics are taken over those per-batch values.

        Side effects: calls ``model.train()`` and leaves the model in that
        mode; gradients of the last processed batch remain on the model.

        Args:
            model: Model to analyze
            data_loader: Data loader
            num_batches: Number of batches to analyze
            repr_layers: List of representation layer indices
                (defaults to layers 0-16)

        Returns:
            Dictionary with keys 'mean', 'std' and 'variance' (population
            statistics, i.e. NumPy's default ``ddof=0``). All three are NaN
            when no batch was processed.
        """
        if repr_layers is None:
            repr_layers = range(0, 17)  # Layers 0-16 for 32-layer model
        # Set membership is O(1); parameters without a layer index map to
        # None, which is never in the set.
        selected_layers = set(repr_layers)

        squared_norms = []

        model.train()
        for i, batch in enumerate(data_loader):
            if i >= num_batches:
                break

            model.zero_grad()
            # Compute loss (assuming batch has loss computation)
            loss = self._compute_batch_loss(model, batch)
            loss.backward()

            # Accumulate squared gradient entries over the selected layers
            batch_total = 0.0
            for name, param in model.named_parameters():
                if param.grad is None:
                    continue
                if self._extract_layer_index(name) in selected_layers:
                    batch_total += (param.grad ** 2).sum().item()

            squared_norms.append(batch_total)

        if not squared_norms:
            # Nothing was processed (empty loader or num_batches <= 0):
            # report NaN explicitly instead of triggering NumPy
            # empty-slice warnings.
            nan = float('nan')
            return {'mean': nan, 'std': nan, 'variance': nan}

        values = np.array(squared_norms)
        return {
            'mean': float(np.mean(values)),
            'std': float(np.std(values)),
            'variance': float(np.var(values)),
        }

    def _compute_batch_loss(self, model, batch):
        """
        Compute loss for a batch. Override if needed.

        Calls ``model(**batch)`` and returns the output's ``loss`` attribute
        when present, otherwise the first element of the output.
        """
        # This is a placeholder - actual implementation depends on the model
        outputs = model(**batch)
        return outputs.loss if hasattr(outputs, 'loss') else outputs[0]

    def _extract_layer_index(self, param_name: str) -> Optional[int]:
        """
        Extract layer index from parameter name.

        Splits the dotted name and returns the integer that directly follows
        the first 'layers', 'h' or 'blocks' component that is followed by a
        decimal number. Returns None when there is no such component.
        """
        parts = param_name.split('.')
        # Pair every component with its successor.
        for container, candidate in zip(parts, parts[1:]):
            # isdecimal() (not isdigit()) so that characters such as
            # superscript digits, which int() rejects, are not accepted.
            if container in self._LAYER_CONTAINER_NAMES and candidate.isdecimal():
                return int(candidate)
        return None
