"""
Ablation Experiments for GDO-DPO

Implements various ablations as described in Section 5.5:
- Table 5: Curriculum structure ablations
- Table 6: Monitoring mechanism ablations
- Table 9: Layer boundary sensitivity
- Table 10: Threshold sensitivity
"""

import torch
import numpy as np
from typing import Dict, List, Optional
from dataclasses import dataclass

from ..core.gdo_dpo import GDODPOConfig, GDODPOTrainer


@dataclass
class AblationConfig:
    """
    Descriptive label for a single ablation experiment.

    Attributes:
        name: Name of the ablation.
        description: Human-readable description of the ablation.
    """
    name: str
    description: str


class CurriculumAblations:
    """
    Ablation experiments for curriculum structure (Table 5).

    Every method returns a new list and leaves both the input list and the
    sample dicts it contains unmodified, so the same dataset can safely be
    reused for several ablation runs.
    """

    @staticmethod
    def get_single_dimension_curriculum(
        dataset: List[Dict],
        combine_mode: str = "sum"
    ) -> List[Dict]:
        """
        Create a single-dimensional curriculum from the two per-sample scores.

        Corresponds to "Single-dim combined score" in Table 5.

        Each sample's 'Rsem' and 'Runc' values are folded into one
        'combined_difficulty' score and the samples are ordered by it,
        lowest first. Samples with equal scores keep their input order.

        Args:
            dataset: Samples, each a dict with 'Rsem' and 'Runc' entries
            combine_mode: How to combine ('sum', 'max', 'product')

        Returns:
            New list of shallow-copied samples, each carrying an extra
            'combined_difficulty' key, sorted by that key in ascending order.
            The input samples are not modified.

        Raises:
            ValueError: If combine_mode is not 'sum', 'max' or 'product'.
        """
        combiners = {
            "sum": lambda rsem, runc: rsem + runc,
            "max": max,
            "product": lambda rsem, runc: rsem * runc,
        }
        # Validate up front: an unknown mode would otherwise leave the score
        # unset (or stale from an earlier call) and only fail during sorting.
        if combine_mode not in combiners:
            raise ValueError(
                f"Unknown combine_mode {combine_mode!r}; "
                f"expected one of {sorted(combiners)}"
            )
        combine = combiners[combine_mode]

        # Annotate copies rather than the caller's dicts so that other
        # experiments sharing this dataset do not see the extra key.
        scored = [
            {
                **sample,
                'combined_difficulty': combine(sample['Rsem'], sample['Runc']),
            }
            for sample in dataset
        ]
        return sorted(scored, key=lambda s: s['combined_difficulty'])

    @staticmethod
    def get_reverse_curriculum(dataset: List[Dict]) -> List[Dict]:
        """
        Create reverse curriculum (hard → easy).

        Corresponds to "Reverse order (hard→easy)" in Table 5.

        Args:
            dataset: Samples, each a dict with 'Rsem' and 'Runc' entries

        Returns:
            New list with the same sample objects ordered by
            'Rsem' + 'Runc', highest first
        """
        return sorted(
            dataset,
            key=lambda sample: sample['Rsem'] + sample['Runc'],
            reverse=True,
        )

    @staticmethod
    def get_random_curriculum(dataset: List[Dict], seed: int = 42) -> List[Dict]:
        """
        Random order baseline.

        Corresponds to "Random order (DPO)" in Table 5.

        Args:
            dataset: Dataset
            seed: Random seed

        Returns:
            New list with the same sample objects in a shuffled order that is
            fully determined by `seed`
        """
        # A private RandomState yields the same permutation that
        # np.random.seed(seed) + np.random.shuffle would, but without
        # resetting the process-wide NumPy RNG that other code may rely on.
        rng = np.random.RandomState(seed)
        shuffled = list(dataset)
        rng.shuffle(shuffled)
        return shuffled


class MonitoringAblations:
    """
    Ablation experiments for monitoring mechanisms (Table 6).

    Each factory builds a GDODPOConfig from a base config by forwarding the
    constructor keywords listed in ``_COPIED_FIELDS`` and overriding some of
    them. Only those keywords are passed; any other constructor arguments
    GDODPOConfig may accept are left at their defaults.
    """

    # Constructor keywords taken from the base config unless overridden.
    _COPIED_FIELDS = (
        'tau_stable',
        'tau_acc',
        'delta_sem',
        'delta_unc',
        'layer_mid',
        'ema_decay',
        'eval_interval',
        'beta',
    )

    @staticmethod
    def _derive_config(gdo_config: GDODPOConfig, **overrides) -> GDODPOConfig:
        """Build a new config from `gdo_config`, replacing the given fields."""
        settings = {
            field: (
                overrides[field] if field in overrides
                else getattr(gdo_config, field)
            )
            for field in MonitoringAblations._COPIED_FIELDS
        }
        return GDODPOConfig(**settings)

    @staticmethod
    def create_fixed_schedule_config(
        gdo_config: GDODPOConfig,
        total_steps: int
    ) -> GDODPOConfig:
        """
        Create config with fixed linear schedule instead of gradient-based.

        Corresponds to "Fixed linear schedule" in Table 6.

        Args:
            gdo_config: Base GDO-DPO config
            total_steps: Total training steps; must be positive

        Returns:
            New config with both increments set to 1 / total_steps,
            tau_stable set to infinity and tau_acc set to 0.0

        Raises:
            ValueError: If total_steps is not positive.
        """
        # Reject bad step counts explicitly: zero used to surface as a bare
        # ZeroDivisionError and negative values silently produced negative
        # increments.
        if total_steps <= 0:
            raise ValueError(
                f"total_steps must be positive, got {total_steps!r}"
            )
        per_step = 1.0 / total_steps  # Linear growth

        # NOTE(review): with tau_stable=inf the Srep criterion is intended
        # never to fire; confirm in the trainer that the semantic axis still
        # advances on this schedule, otherwise delta_sem has no effect.
        return MonitoringAblations._derive_config(
            gdo_config,
            tau_stable=float('inf'),  # Never advance based on Srep
            tau_acc=0.0,              # Always advance based on Adisc
            delta_sem=per_step,
            delta_unc=per_step,
        )

    @staticmethod
    def create_loss_based_pacing_config(
        gdo_config: GDODPOConfig
    ) -> GDODPOConfig:
        """
        Create config with loss-based pacing instead of gradient-based.

        Corresponds to "Loss-based pacing" in Table 6.

        This requires modifying the trainer to use loss instead of gradients;
        nothing in the config itself changes.

        Args:
            gdo_config: Base GDO-DPO config

        Returns:
            The very same config object that was passed in (not a copy)
        """
        # Same config but trainer should use loss-based monitoring
        return gdo_config

    @staticmethod
    def create_no_srep_config(gdo_config: GDODPOConfig) -> GDODPOConfig:
        """
        Disable Srep monitor (only use Adisc).

        Corresponds to "w/o Srep monitor" in Table 6.

        Args:
            gdo_config: Base config

        Returns:
            New config identical to the base in the copied fields except
            that tau_stable is 0.0
        """
        # NOTE(review): assumes the trainer advances the semantic axis
        # whenever Srep reaches tau_stable, so 0.0 means "always" - confirm
        # the comparison direction in GDODPOTrainer.
        return MonitoringAblations._derive_config(
            gdo_config,
            tau_stable=0.0,  # Always advance semantic
        )

    @staticmethod
    def create_no_adisc_config(gdo_config: GDODPOConfig) -> GDODPOConfig:
        """
        Disable Adisc monitor (only use Srep).

        Corresponds to "w/o Adisc monitor" in Table 6.

        Args:
            gdo_config: Base config

        Returns:
            New config identical to the base in the copied fields except
            that tau_acc is 0.0
        """
        return MonitoringAblations._derive_config(
            gdo_config,
            tau_acc=0.0,  # Always advance uncertainty
        )


class SensitivityAnalysis:
    """
    Hyperparameter sensitivity sweeps (Tables 9 and 10).

    Both sweeps return a dict of freshly constructed GDODPOConfig objects;
    the base config is only read, never modified.
    """

    @staticmethod
    def _clone_with(base_config: GDODPOConfig, **changes) -> GDODPOConfig:
        """Rebuild `base_config` with the fields named in `changes` replaced."""
        kwargs = {}
        for field in (
            'tau_stable',
            'tau_acc',
            'delta_sem',
            'delta_unc',
            'layer_mid',
            'ema_decay',
            'eval_interval',
            'beta',
        ):
            # Only read the base attribute when no replacement was supplied.
            if field in changes:
                kwargs[field] = changes[field]
            else:
                kwargs[field] = getattr(base_config, field)
        return GDODPOConfig(**kwargs)

    @staticmethod
    def layer_boundary_sweep(
        base_config: GDODPOConfig,
        num_layers: int = 32
    ) -> Dict[str, GDODPOConfig]:
        """
        Build one config per candidate layer boundary Lmid (Table 9).

        The candidates are L/2, 2L/3 and 3L/4 of `num_layers`; fractional
        results are truncated to an integer layer index.

        Args:
            base_config: Configuration supplying every field but layer_mid
            num_layers: Total number of layers (L)

        Returns:
            Dict keyed by 'L/2', '2L/3' and '3L/4', each mapping to a config
            whose layer_mid is the corresponding boundary
        """
        candidates = (
            ('L/2', num_layers // 2),
            ('2L/3', int(2 * num_layers / 3)),
            ('3L/4', int(3 * num_layers / 4)),
        )
        return {
            label: SensitivityAnalysis._clone_with(base_config, layer_mid=boundary)
            for label, boundary in candidates
        }

    @staticmethod
    def threshold_sweep(
        base_config: GDODPOConfig
    ) -> Dict[str, GDODPOConfig]:
        """
        Build one config per (tau_stable, tau_acc) pair (Table 10).

        Args:
            base_config: Configuration supplying every field but the two
                thresholds

        Returns:
            Dict keyed by a "τ_stable=..., τ_acc=..." label, each mapping to
            a config using that pair of thresholds
        """
        # (tau_stable, tau_acc) pairs, in sweep order.
        pairs = (
            (0.8, 0.60),
            (1.0, 0.65),
            (1.2, 0.65),  # Default
            (1.4, 0.70),
            (1.6, 0.75),
        )
        return {
            f"τ_stable={tau_stable}, τ_acc={tau_acc}":
                SensitivityAnalysis._clone_with(
                    base_config,
                    tau_stable=tau_stable,
                    tau_acc=tau_acc,
                )
            for tau_stable, tau_acc in pairs
        }


class AblationRunner:
    """
    Runner for ablation experiments.

    Each ``run_*`` method calls the supplied training function once per
    ablation variant as ``train_fn(dataset, config, **kwargs)``, collects
    whatever it returns under the variant's name, and stores the collection
    in ``self.results``. ``save_results`` writes ``self.results`` as JSON.
    """

    def __init__(self, base_config: GDODPOConfig):
        """
        Args:
            base_config: Base GDO-DPO configuration
        """
        self.base_config = base_config
        # Filled by the run_* methods under the keys 'curriculum',
        # 'monitoring' and 'sensitivity'.
        self.results = {}

    @staticmethod
    def _announce(message: str) -> None:
        """Print `message` framed by separator lines."""
        print(f"\n{'='*60}")
        print(message)
        print(f"{'='*60}")

    def _run_configs(self, train_fn, dataset, configs, label, train_kwargs):
        """Train once per named config on `dataset`; return results by name."""
        results = {}
        for name, config in configs.items():
            self._announce(f"Running {label}: {name}")
            results[name] = train_fn(dataset, config, **train_kwargs)
        return results

    @staticmethod
    def _to_jsonable(value):
        """`json` default hook: convert NumPy / torch values to plain Python."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().tolist()
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        # Anything else stays an error, exactly as json reports it by default.
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    def run_curriculum_ablations(
        self,
        train_fn,
        dataset,
        **kwargs
    ) -> Dict[str, Dict]:
        """
        Run curriculum structure ablations (Table 5).

        All variants train with the base config; only the dataset ordering
        differs. The reordered datasets are all built before the first
        training run starts.

        Args:
            train_fn: Training function, called as
                train_fn(dataset, config, **kwargs)
            dataset: Training dataset
            **kwargs: Additional arguments for training

        Returns:
            Results dictionary mapping ablation name to train_fn's return
            value; also stored as self.results['curriculum']
        """
        curricula = {
            'Bi-dimensional (Ours)': dataset,  # Default GDO-DPO
            'Single-dim combined score':
                CurriculumAblations.get_single_dimension_curriculum(dataset),
            'Reverse order (hard→easy)':
                CurriculumAblations.get_reverse_curriculum(dataset),
            'Random order (DPO)':
                CurriculumAblations.get_random_curriculum(dataset),
        }

        results = {}
        for name, curriculum in curricula.items():
            self._announce(f"Running ablation: {name}")
            results[name] = train_fn(curriculum, self.base_config, **kwargs)

        self.results['curriculum'] = results
        return results

    def run_monitoring_ablations(
        self,
        train_fn,
        dataset,
        total_steps: int,
        **kwargs
    ) -> Dict[str, Dict]:
        """
        Run monitoring mechanism ablations (Table 6).

        All variants train on the same dataset; only the config differs.

        Args:
            train_fn: Training function, called as
                train_fn(dataset, config, **kwargs)
            dataset: Training dataset
            total_steps: Total training steps (used for the fixed schedule)
            **kwargs: Additional arguments for training

        Returns:
            Results dictionary mapping ablation name to train_fn's return
            value; also stored as self.results['monitoring']
        """
        configs = {
            'Gradient-based (Ours)': self.base_config,
            'Fixed linear schedule':
                MonitoringAblations.create_fixed_schedule_config(
                    self.base_config, total_steps
                ),
            'w/o Srep monitor':
                MonitoringAblations.create_no_srep_config(self.base_config),
            'w/o Adisc monitor':
                MonitoringAblations.create_no_adisc_config(self.base_config),
        }

        results = self._run_configs(
            train_fn, dataset, configs, "ablation", kwargs
        )
        self.results['monitoring'] = results
        return results

    def run_sensitivity_analysis(
        self,
        train_fn,
        dataset,
        num_layers: int = 32,
        **kwargs
    ) -> Dict[str, Dict]:
        """
        Run hyperparameter sensitivity analysis (Tables 9, 10).

        The layer-boundary sweep runs to completion before the threshold
        sweep is built and run.

        Args:
            train_fn: Training function, called as
                train_fn(dataset, config, **kwargs)
            dataset: Training dataset
            num_layers: Total number of layers in model
            **kwargs: Additional arguments for training

        Returns:
            Dictionary with keys 'layer_boundary' and 'thresholds', each
            mapping variant name to train_fn's return value; also stored as
            self.results['sensitivity']
        """
        # Layer boundary sensitivity (Table 9)
        layer_results = self._run_configs(
            train_fn,
            dataset,
            SensitivityAnalysis.layer_boundary_sweep(
                self.base_config, num_layers
            ),
            "layer boundary ablation",
            kwargs,
        )

        # Threshold sensitivity (Table 10)
        threshold_results = self._run_configs(
            train_fn,
            dataset,
            SensitivityAnalysis.threshold_sweep(self.base_config),
            "threshold ablation",
            kwargs,
        )

        self.results['sensitivity'] = {
            'layer_boundary': layer_results,
            'thresholds': threshold_results,
        }

        return self.results['sensitivity']

    def save_results(self, save_path: str):
        """
        Save ablation results as indented JSON.

        NumPy scalars/arrays and torch tensors inside the results are
        converted to plain Python numbers and lists. The results are
        serialized completely before the file is opened, so a value that
        cannot be serialized raises without truncating an existing file.

        Args:
            save_path: Path to save results

        Raises:
            TypeError: If the results contain a value that cannot be
                represented in JSON.
        """
        import json

        payload = json.dumps(
            self.results, indent=2, default=self._to_jsonable
        )
        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        print(f"Saved ablation results to {save_path}")
