"""
RQ2: Improved Robustness and Training Stability

Scientific Question: Is the teachability penalty essential for closing the realization gap 
and ensuring stable training?

This experiment sweeps the teachability weight λ ∈ {0, 10^-3, 10^-2, 10^-1, 1.0} where:
- λ=0 represents the decoupled TE-GFN (critical ablation)
- Higher λ values emphasize teachability over raw reward

Key Metrics:
1. Final ℓ₁ error and number of modes discovered
2. Training Stability: Variance of the student's distillation loss
"""

import argparse
import os
import sys
import time
import traceback
from pathlib import Path
from typing import Dict, List, Any, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml

# Add parent directory to path
sys.path.append(str(Path(__file__).parent.parent))

from shared import (
    DATEGFN, HypergridEnvironment,
    ExperimentTracker, WandbLogger, set_seed, save_results,
    create_experiment_directory, statistical_significance_test,
    create_comparison_plot, create_learning_curves
)


class RQ2Experiment:
    """Teachability weight (λ) ablation experiment for DATE-GFN.

    For every λ in ``lambda_values`` and every seed, a fresh DATE-GFN is
    trained on the configured environment. Final-sample metrics,
    distillation-loss stability metrics and convergence metrics are collected,
    aggregated across seeds, compared against the λ=0 (TE-GFN) baseline,
    plotted and summarised in a report.

    Args:
        config: Parsed experiment configuration. Keys read by this class:
            ``training`` (``device``, ``num_steps``, ``num_seeds``),
            ``model.hidden_dim``, ``date_gfn`` (``population_size``,
            ``elite_ratio``, ``student_updates_per_cycle``),
            ``evaluation.eval_every`` and
            ``baselines.gfn_tb.learning_rate``.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config

        # 'auto' (or no explicit entry) selects CUDA when available, else CPU.
        device_setting = config['training'].get('device', 'auto')
        if device_setting == 'auto':
            device_setting = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_setting)

        # λ=0 is the decoupled TE-GFN ablation; larger λ weighs teachability more.
        self.lambda_values = [0.0, 1e-3, 1e-2, 1e-1, 1.0]

        self.wandb_logger = WandbLogger("date_gfn_rq2")

        # Not written by any method of this class; retained for compatibility.
        self.results = {}
        self.stability_metrics = {}

    def create_environment(self, env_config: Dict[str, Any]):
        """Instantiate the environment described by ``env_config``.

        Only ``'HypergridEnvironment'`` is supported (single-cell environments
        were removed).

        Raises:
            ValueError: If ``env_config['type']`` is not supported.
        """
        if env_config['type'] != 'HypergridEnvironment':
            raise ValueError(f"Unknown environment type: {env_config['type']}")
        return HypergridEnvironment(
            height=env_config['height'],
            ndim=env_config['ndim'],
            reward_beta=env_config['reward_beta'],
            reward_at_corners=env_config['reward_at_corners'],
        )

    def create_date_gfn_with_lambda(self, lambda_val: float, env_config: Dict[str, Any]) -> "DATEGFN":
        """Build a DATE-GFN whose teachability weight is ``lambda_val``.

        Every other hyperparameter comes from ``self.config`` so that λ is the
        only quantity varied across the sweep.
        """
        return DATEGFN(
            state_dim=env_config['state_dim'],
            action_dim=env_config['action_dim'],
            hidden_dim=self.config['model']['hidden_dim'],
            population_size=self.config['date_gfn']['population_size'],
            elite_ratio=self.config['date_gfn']['elite_ratio'],
            teachability_weight=lambda_val,  # the ablated parameter
            student_updates_per_cycle=self.config['date_gfn']['student_updates_per_cycle'],
            device=self.device,
        )

    def calculate_training_stability(self, loss_history: List[float], window_size: int = 100) -> Dict[str, float]:
        """Summarise how noisy the distillation loss was during training.

        The variance is computed over every full sliding window of
        ``window_size`` consecutive losses, including the last one, and then
        averaged.

        Args:
            loss_history: Per-step distillation losses, in training order.
            window_size: Length of each sliding window (must be >= 1).

        Returns:
            A dict with these keys:
            ``distillation_loss_variance``: mean windowed variance.
            ``loss_coefficient_of_variation``: sqrt of that variance divided
            by the mean of the last window.
            ``stability_score``: ``1 / (1 + variance)``.
            ``mean_recent_loss``: mean of the last window.
            If fewer than ``window_size`` losses exist, the variance and CV
            are ``inf`` and the stability score is ``0.0``.

        Raises:
            ValueError: If ``window_size`` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")

        if len(loss_history) < window_size:
            return {
                'distillation_loss_variance': float('inf'),
                'loss_coefficient_of_variation': float('inf'),
                'stability_score': 0.0,
                'mean_recent_loss': float(np.mean(loss_history)) if loss_history else float('nan'),
            }

        losses = np.asarray(loss_history, dtype=float)

        # `end` is exclusive, so `len(losses) + 1` makes the final window count.
        window_variances = [
            np.var(losses[end - window_size:end])
            for end in range(window_size, len(losses) + 1)
        ]
        overall_variance = float(np.mean(window_variances))

        # Coefficient of variation: volatility normalised by the recent loss level.
        mean_loss = float(np.mean(losses[-window_size:]))
        cv = float(np.sqrt(overall_variance) / (mean_loss + 1e-8))

        # Maps variance in [0, inf) to a score in (0, 1]; higher means more stable.
        stability_score = 1.0 / (1.0 + overall_variance)

        return {
            'distillation_loss_variance': overall_variance,
            'loss_coefficient_of_variation': cv,
            'stability_score': stability_score,
            'mean_recent_loss': mean_loss,
        }

    def calculate_convergence_metrics(self, performance_history: List[float]) -> Dict[str, float]:
        """Summarise how fast the evaluation metric approached its final value.

        ``performance_history`` holds one entry per evaluation, so the
        ``time_to_*`` values are indices into that list (evaluation counts),
        not raw training-step numbers.

        Returns:
            A dict with these keys:
            ``time_to_50pct_modes`` / ``time_to_90pct_modes``: first index
            reaching 50% / 90% of the final value, or
            ``len(performance_history)`` if never reached.
            ``final_performance``: the last entry.
            ``convergence_rate``: (mean of the last 100 entries minus mean of
            the first 100) divided by the history length. It is 0.0 with
            fewer than 200 entries. Histories shorter than 100 entries only
            get placeholder values.
        """
        n_points = len(performance_history)
        if n_points < 100:
            return {
                'time_to_50pct_modes': n_points,
                'time_to_90pct_modes': n_points,
                'final_performance': performance_history[-1] if performance_history else 0.0,
                'convergence_rate': 0.0,
            }

        final_performance = performance_history[-1]
        threshold_50 = 0.5 * final_performance
        threshold_90 = 0.9 * final_performance

        time_to_50pct = next(
            (i for i, perf in enumerate(performance_history) if perf >= threshold_50), n_points)
        time_to_90pct = next(
            (i for i, perf in enumerate(performance_history) if perf >= threshold_90), n_points)

        if n_points >= 200:
            early_perf = np.mean(performance_history[:100])
            late_perf = np.mean(performance_history[-100:])
            convergence_rate = (late_perf - early_perf) / n_points
        else:
            convergence_rate = 0.0

        return {
            'time_to_50pct_modes': time_to_50pct,
            'time_to_90pct_modes': time_to_90pct,
            'final_performance': final_performance,
            'convergence_rate': convergence_rate,
        }

    def train_with_lambda(self, lambda_val: float, environment, env_name: str, seed: int) -> Dict[str, Any]:
        """Train one DATE-GFN for a single (λ, seed) pair and collect its metrics.

        A wandb run is opened for the duration of training. It is always
        finished, even if training raises, so a failed run does not leak into
        the next one.

        Returns:
            The final-evaluation metrics merged with stability and convergence
            metrics. Also includes ``lambda``, ``seed`` and the raw
            ``distillation_loss_history`` / ``performance_history`` lists.
        """
        print(f"  Training with λ={lambda_val} (seed {seed})...")

        method = self.create_date_gfn_with_lambda(lambda_val, self.get_env_config(env_name))

        learning_rate = self.config['baselines']['gfn_tb']['learning_rate']
        optimizer_forward = torch.optim.Adam(method.forward_policy.parameters(), lr=learning_rate)
        optimizer_backward = torch.optim.Adam(method.backward_policy.parameters(), lr=learning_rate)

        run_name = f"RQ2_{env_name}_lambda{lambda_val}_seed{seed}"
        wandb_config = {
            'research_question': 'RQ2',
            'environment': env_name,
            'teachability_weight': lambda_val,
            'seed': seed,
            **self.config,
        }
        self.wandb_logger.init_run(wandb_config, run_name,
                                   tags=['RQ2', env_name, f'lambda_{lambda_val}'])

        try:
            tracker = ExperimentTracker(f"RQ2_{env_name}_lambda{lambda_val}")
            method_label = f"DATE-GFN_λ{lambda_val}"
            performance_key = 'modes_discovered' if env_name == 'Hypergrid' else 'des_score'

            distillation_losses: List[float] = []   # one entry per step that reports it
            performance_history: List[float] = []   # one entry per evaluation

            num_steps = self.config['training']['num_steps']
            eval_every = self.config['evaluation']['eval_every']

            for step in range(num_steps):
                step_start_time = time.perf_counter()
                metrics = method.train_step(environment, optimizer_forward, optimizer_backward)

                if 'distill_loss' in metrics:
                    # float() also unwraps 0-d tensors so the history is plain numbers.
                    distillation_losses.append(float(metrics['distill_loss']))

                step_time = time.perf_counter() - step_start_time

                # Evaluate periodically and always on the last step.
                if step % eval_every == 0 or step == num_steps - 1:
                    eval_trajectories = method.sample(environment, 100)
                    step_metrics = tracker.log_step_results(
                        step, eval_trajectories, environment, method_label,
                        {**metrics, 'step_time': step_time, 'seed': seed, 'lambda': lambda_val}
                    )
                    performance_history.append(step_metrics.get(performance_key, 0))
                    self.wandb_logger.log_metrics(step_metrics, step)

                    if step % (eval_every * 5) == 0:
                        print(f"    Step {step}/{num_steps} - "
                              f"Performance: {performance_history[-1]:.4f}, "
                              f"Stability: {len(distillation_losses)} losses tracked")

            final_trajectories = method.sample(environment, 200)
            final_metrics = tracker.log_step_results(
                num_steps, final_trajectories, environment, method_label,
                {'final_evaluation': True, 'seed': seed, 'lambda': lambda_val}
            )

            result = {
                **final_metrics,
                **self.calculate_training_stability(distillation_losses),
                **self.calculate_convergence_metrics(performance_history),
                'lambda': lambda_val,
                'seed': seed,
                'distillation_loss_history': distillation_losses,
                'performance_history': performance_history,
            }
        finally:
            self.wandb_logger.finish_run()

        return result

    def get_env_config(self, env_name: str) -> Dict[str, Any]:
        """Load the ``environment`` section of the YAML config for ``env_name``.

        The path is relative to the current working directory.

        Raises:
            ValueError: If ``env_name`` is not ``'Hypergrid'``.
        """
        if env_name != 'Hypergrid':
            raise ValueError(f"Unknown environment name: {env_name}. Only 'Hypergrid' is supported.")
        with open('configs/hypergrid_config.yaml', encoding='utf-8') as config_file:
            return yaml.safe_load(config_file)['environment']

    def run_lambda_ablation(self, env_name: str) -> Dict[str, List[Dict[str, Any]]]:
        """Train every (λ, seed) combination on one environment.

        A single environment instance is shared by all runs. Failed runs are
        reported (with traceback) and skipped, so the sweep is best-effort.

        Returns:
            A mapping from ``str(λ)`` to the list of successful per-seed results.
            The list may be empty if every seed failed.
        """
        print(f"\n{'='*60}")
        print(f"Running RQ2 λ ablation on {env_name} environment")
        print(f"{'='*60}")

        environment = self.create_environment(self.get_env_config(env_name))

        lambda_results = {}
        for lambda_val in self.lambda_values:
            print(f"\nTesting λ = {lambda_val}")

            seed_results = []
            for seed in range(self.config['training']['num_seeds']):
                set_seed(seed)
                try:
                    seed_results.append(self.train_with_lambda(lambda_val, environment, env_name, seed))
                except Exception as e:
                    # Best-effort sweep: keep going, but keep the failure visible.
                    print(f"Error with λ={lambda_val}, seed {seed}: {e}")
                    traceback.print_exc()

            lambda_results[str(lambda_val)] = seed_results

        return lambda_results

    @staticmethod
    def _finite_values(seed_results: List[Dict[str, Any]], metric: str) -> List[float]:
        """Collect finite numeric values of ``metric`` across seed results.

        Missing, non-numeric and non-finite entries (e.g. the ``inf``
        placeholder for "too few losses to measure stability") are skipped,
        so they cannot turn aggregate means and stds into ``inf``/``nan``.
        """
        values = []
        for result in seed_results:
            if metric not in result:
                continue
            try:
                value = float(result[metric])
            except (TypeError, ValueError):
                continue
            if np.isfinite(value):
                values.append(value)
        return values

    @staticmethod
    def _baseline_key(lambda_comparison: Dict[str, Any]):
        """Return the key of ``lambda_comparison`` whose λ equals 0, or ``None``."""
        return next((key for key in lambda_comparison if float(key) == 0.0), None)

    @staticmethod
    def _mean_or_nan(values: List[float]) -> float:
        """Mean of ``values``, or NaN (without a NumPy warning) when empty."""
        return float(np.mean(values)) if values else float('nan')

    def analyze_lambda_effects(self, lambda_results: Dict[str, List[Dict[str, Any]]],
                               env_name: str) -> Dict[str, Any]:
        """Aggregate per-seed results per λ and test each λ>0 against λ=0.

        Returns:
            A dict with these keys:
            ``lambda_comparison``: maps each λ key to ``{metric}_mean``,
            ``{metric}_std`` and ``{metric}_values`` over finite per-seed values.
            ``statistical_tests``: maps ``lambda_0_vs_{λ}`` to per-metric
            results of ``statistical_significance_test``.
            ``stability_analysis`` and ``convergence_analysis``: placeholders
            left empty.
        """
        analysis = {
            'lambda_comparison': {},
            'stability_analysis': {},
            'convergence_analysis': {},
            'statistical_tests': {},
        }

        metric_names = (
            # Performance
            'modes_discovered', 'rel_l1_error', 'des_score', 'pds_score', 'mae',
            # Stability
            'distillation_loss_variance', 'loss_coefficient_of_variation', 'stability_score',
            # Convergence
            'time_to_50pct_modes', 'time_to_90pct_modes', 'convergence_rate',
        )

        for lambda_str, seed_results in lambda_results.items():
            if not seed_results:
                continue
            aggregated = {}
            for metric in metric_names:
                values = self._finite_values(seed_results, metric)
                if values:
                    aggregated[f"{metric}_mean"] = float(np.mean(values))
                    aggregated[f"{metric}_std"] = float(np.std(values))
                    aggregated[f"{metric}_values"] = values
            analysis['lambda_comparison'][lambda_str] = aggregated

        comparison = analysis['lambda_comparison']
        baseline_key = self._baseline_key(comparison)
        baseline = comparison.get(baseline_key, {})

        test_metrics = ['distillation_loss_variance_values', 'stability_score_values']
        if env_name == 'Hypergrid':
            test_metrics.extend(['modes_discovered_values', 'rel_l1_error_values'])
        else:
            test_metrics.extend(['des_score_values', 'pds_score_values'])
        higher_is_better_markers = ('stability', 'modes_discovered', 'des_score', 'pds_score')

        for lambda_str, current in comparison.items():
            if lambda_str == baseline_key:
                continue
            tests = analysis['statistical_tests'][f"lambda_0_vs_{lambda_str}"] = {}
            for metric in test_metrics:
                baseline_values = baseline.get(metric, [])
                current_values = current.get(metric, [])
                if not (baseline_values and current_values):
                    continue
                # The group hypothesised to be better is passed first (assumption
                # carried over from the original code — confirm against
                # shared.statistical_significance_test).
                if any(marker in metric for marker in higher_is_better_markers):
                    tests[metric] = statistical_significance_test(current_values, baseline_values)
                else:  # variance / error metrics: lower is better
                    tests[metric] = statistical_significance_test(baseline_values, current_values)

        return analysis

    def create_visualization_plots(self, lambda_results: Dict[str, List[Dict[str, Any]]],
                                   analysis: Dict[str, Any], env_name: str) -> "List[plt.Figure]":
        """Plot performance, stability score and loss variance against λ.

        Returns:
            ``[fig1, fig2]``. ``fig1`` has performance and stability-score
            error-bar plots on a symlog λ axis. ``fig2`` is a bar chart of
            distillation-loss variance per λ. λ=0 is highlighted in red when
            present.
        """
        comparison = analysis['lambda_comparison']
        perf_key = 'modes_discovered' if env_name == 'Hypergrid' else 'des_score'
        ylabel = 'Modes Discovered' if env_name == 'Hypergrid' else 'DES Score'

        lambda_vals: List[float] = []
        performance_means, performance_stds = [], []
        stability_means, stability_stds = [], []
        variance_means, variance_stds = [], []
        for lambda_str in sorted(lambda_results.keys(), key=float):
            stats = comparison.get(lambda_str)
            if stats is None:
                continue
            lambda_vals.append(float(lambda_str))
            performance_means.append(stats.get(f'{perf_key}_mean', 0))
            performance_stds.append(stats.get(f'{perf_key}_std', 0))
            stability_means.append(stats.get('stability_score_mean', 0))
            stability_stds.append(stats.get('stability_score_std', 0))
            variance_means.append(stats.get('distillation_loss_variance_mean', 0))
            variance_stds.append(stats.get('distillation_loss_variance_std', 0))

        has_baseline = 0.0 in lambda_vals
        baseline_idx = lambda_vals.index(0.0) if has_baseline else None

        # Figure 1: symlog x-axis (linear within ±1e-4) so λ=0 can be drawn.
        fig1, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        ax1.errorbar(lambda_vals, performance_means, yerr=performance_stds,
                     marker='o', capsize=5, linewidth=2, markersize=8)
        ax1.set_xscale('symlog', linthresh=1e-4)
        ax1.set_xlabel('Teachability Weight (λ)')
        ax1.set_ylabel(ylabel)
        ax1.set_title(f'{ylabel} vs Teachability Weight')
        ax1.grid(True, alpha=0.3)
        if has_baseline:
            ax1.axvline(x=0.0, color='red', linestyle='--', alpha=0.7,
                        label='λ=0 (TE-GFN)')
            ax1.scatter(0.0, performance_means[baseline_idx], color='red', s=100, zorder=5)
            # The baseline line is the only labelled artist, so only add a legend here.
            ax1.legend()

        ax2.errorbar(lambda_vals, stability_means, yerr=stability_stds,
                     marker='s', capsize=5, linewidth=2, markersize=8, color='orange')
        ax2.set_xscale('symlog', linthresh=1e-4)
        ax2.set_xlabel('Teachability Weight (λ)')
        ax2.set_ylabel('Training Stability Score')
        ax2.set_title('Training Stability vs Teachability Weight')
        ax2.grid(True, alpha=0.3)
        if has_baseline:
            ax2.axvline(x=0.0, color='red', linestyle='--', alpha=0.7)
            ax2.scatter(0.0, stability_means[baseline_idx], color='red', s=100, zorder=5)

        fig1.tight_layout()

        # Figure 2: loss variance per λ as categorical bars.
        fig2, ax = plt.subplots(figsize=(10, 6))
        positions = range(len(lambda_vals))
        bars = ax.bar(positions, variance_means, yerr=variance_stds,
                      capsize=5, alpha=0.7, color='skyblue', edgecolor='navy')
        if has_baseline:
            bars[baseline_idx].set_color('red')
            bars[baseline_idx].set_alpha(0.8)

        ax.set_xlabel('Teachability Weight (λ)')
        ax.set_ylabel('Distillation Loss Variance')
        ax.set_title('Training Instability (Loss Variance) vs λ')
        ax.set_xticks(positions)
        ax.set_xticklabels([f'λ={x}' for x in lambda_vals], rotation=45)
        ax.grid(True, alpha=0.3, axis='y')
        fig2.tight_layout()

        return [fig1, fig2]

    def generate_rq2_report(self, all_results: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Summarise the per-environment analyses into the RQ2 report.

        For each environment, the "optimal" λ is the one with the highest mean
        stability score (ties keep the earliest key). Its stability and
        performance are reported relative to λ=0. Environments without any
        aggregated λ results are left out of ``key_findings``. Averages over
        zero environments are NaN.
        """
        report = {
            'research_question': 'RQ2 - Teachability Weight Ablation',
            'hypothesis': 'λ=0 shows high variance; optimal λ>0 provides stability',
            'lambda_values_tested': self.lambda_values,
            'environments': list(all_results.keys()),
            'key_findings': {},
            'stability_analysis': {},
            'recommendations': [],
        }

        for env_name, env_analysis in all_results.items():
            lambda_comparison = env_analysis.get('lambda_comparison', {})
            if not lambda_comparison:
                # No successful runs: nothing to rank, so do not invent findings.
                continue

            baseline = lambda_comparison.get(self._baseline_key(lambda_comparison), {})
            lambda_0_stability = baseline.get('stability_score_mean', 0)

            best_lambda_str = None
            best_stability = -1
            for lambda_str, results in lambda_comparison.items():
                stability = results.get('stability_score_mean', 0)
                if stability > best_stability:
                    best_stability = stability
                    best_lambda_str = lambda_str

            performance_metric = 'modes_discovered' if env_name == 'Hypergrid' else 'des_score'
            metric_key = f'{performance_metric}_mean'
            lambda_0_performance = baseline.get(metric_key, 0)
            best_lambda_performance = lambda_comparison[best_lambda_str].get(metric_key, 0)

            # Relative changes vs. λ=0; the epsilon only guards against division by zero.
            report['key_findings'][env_name] = {
                'optimal_lambda': float(best_lambda_str),
                'lambda_0_stability': lambda_0_stability,
                'best_stability': best_stability,
                'stability_improvement': (best_stability - lambda_0_stability) / (lambda_0_stability + 1e-8),
                'lambda_0_performance': lambda_0_performance,
                'best_lambda_performance': best_lambda_performance,
                'performance_improvement': (best_lambda_performance - lambda_0_performance) / (lambda_0_performance + 1e-8),
                'performance_metric': performance_metric,
            }

        findings = list(report['key_findings'].values())
        avg_stability_improvement = self._mean_or_nan([f['stability_improvement'] for f in findings])
        avg_performance_improvement = self._mean_or_nan([f['performance_improvement'] for f in findings])

        if avg_stability_improvement > 0.5:  # >50% improvement
            report['recommendations'].append("Teachability penalty significantly improves training stability")

        if avg_performance_improvement > 0.2:  # >20% improvement
            report['recommendations'].append("Optimal λ > 0 improves final performance over TE-GFN")

        # Require at least one environment: all([]) would be vacuously True.
        if findings and all(f['optimal_lambda'] > 0 for f in findings):
            report['recommendations'].append("λ=0 (TE-GFN) is suboptimal across all environments")

        report['avg_stability_improvement'] = avg_stability_improvement
        report['avg_performance_improvement'] = avg_performance_improvement

        return report

    def run_experiment(self):
        """Run the full RQ2 pipeline and write results and plots to disk.

        Runs the sweep, analysis, plotting and report generation, then saves
        ``detailed_results.json``, ``summary_report.json`` and PNG plots into
        a fresh experiment directory.

        Returns:
            ``(all_results, summary_report)``.
        """
        print("Starting RQ2: Teachability Weight Ablation Experiment")
        print("=" * 60)

        exp_dir = Path(create_experiment_directory("results", "RQ2_robustness"))

        all_results = {}
        all_plots = {}

        # Hypergrid is the only supported environment (single-cell removed).
        hypergrid_results = self.run_lambda_ablation('Hypergrid')
        hypergrid_analysis = self.analyze_lambda_effects(hypergrid_results, 'Hypergrid')
        all_results['Hypergrid'] = hypergrid_analysis
        all_plots['Hypergrid'] = self.create_visualization_plots(
            hypergrid_results, hypergrid_analysis, 'Hypergrid')

        summary_report = self.generate_rq2_report(all_results)

        save_results(all_results, exp_dir / "detailed_results.json")
        save_results(summary_report, exp_dir / "summary_report.json")

        plots_dir = exp_dir / "plots"
        plots_dir.mkdir(exist_ok=True)
        for env_name, plots in all_plots.items():
            for i, fig in enumerate(plots, start=1):
                fig.savefig(plots_dir / f"{env_name}_lambda_ablation_{i}.png",
                            dpi=300, bbox_inches='tight')
                plt.close(fig)

        print("\n" + "=" * 60)
        print("RQ2 EXPERIMENT COMPLETED")
        print("=" * 60)
        print(f"Average stability improvement: {summary_report['avg_stability_improvement']:.2%}")
        print(f"Average performance improvement: {summary_report['avg_performance_improvement']:.2%}")
        print(f"Results saved to: {exp_dir}")

        for env_name, findings in summary_report['key_findings'].items():
            print(f"\n{env_name}:")
            print(f"  Optimal λ: {findings['optimal_lambda']}")
            print(f"  Stability improvement: {findings['stability_improvement']:.2%}")
            print(f"  Performance improvement: {findings['performance_improvement']:.2%}")

        for recommendation in summary_report['recommendations']:
            print(f"• {recommendation}")

        return all_results, summary_report


def main():
    """Command-line entry point for the RQ2 teachability-weight ablation.

    Parses ``--config`` (YAML path) and ``--device``, loads the configuration
    and runs the full experiment.

    Returns:
        The ``(results, summary)`` pair produced by
        ``RQ2Experiment.run_experiment``.
    """
    cli = argparse.ArgumentParser(description='Run RQ2 Teachability Weight Ablation')
    cli.add_argument('--config', type=str, default='configs/base_config.yaml',
                     help='Path to configuration file')
    cli.add_argument('--device', type=str, default='auto',
                     help='Device to use (auto, cuda, cpu)')
    options = cli.parse_args()

    with open(options.config, 'r') as config_file:
        experiment_config = yaml.safe_load(config_file)

    # An explicit --device overrides the config file; 'auto' leaves the
    # config's own training.device setting untouched.
    if options.device != 'auto':
        experiment_config['training']['device'] = options.device

    results, summary = RQ2Experiment(experiment_config).run_experiment()
    return results, summary


# Run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()
