"""
RQ1: Comparative Analysis - How does DATE-GFN compare against state-of-the-art baselines?

This experiment compares DATE-GFN against:
1. Standard GFlowNet (GFN-TB)
2. Evolution Guided GFlowNet (EGFN) 
3. Soft Actor-Critic (SAC)
4. Markov Chain Monte Carlo (MARS)

On Hypergrid environments.
"""

import os
import sys
import argparse
import yaml
import torch
import numpy as np
from pathlib import Path
from typing import Dict, List, Any
import time

# Add parent directory to path to import shared modules
sys.path.append(str(Path(__file__).parent.parent))

from shared import (
    DATEGFN, GFNBaseline, EGFNBaseline, SACBaseline, MARSBaseline,
    HypergridEnvironment,
    ExperimentTracker, WandbLogger, set_seed, save_results,
    create_experiment_directory, statistical_significance_test
)


class RQ1Experiment:
    """Main experiment class for RQ1 comparative analysis."""
    
    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and set up device, logger and result store.

        Args:
            config: Parsed experiment configuration. ``config['training']['device']``
                is read here: the value ``'auto'`` selects ``'cuda'`` when
                ``torch.cuda.is_available()`` and ``'cpu'`` otherwise; any other
                value is handed to ``torch.device`` unchanged.
        """
        self.config = config

        # Resolve the 'auto' placeholder into a concrete device name.
        if config['training']['device'] == 'auto':
            device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            device_name = config['training']['device']
        self.device = torch.device(device_name)

        # Single logger instance reused for every run this experiment starts.
        self.wandb_logger = WandbLogger("date_gfn_rq1")

        # Filled in by callers; starts empty.
        self.results = {}
        
    def create_environment(self, env_config: Dict[str, Any]):
        """Instantiate the environment described by ``env_config``.

        Args:
            env_config: Environment settings. ``'type'`` selects the class; for
                ``'HypergridEnvironment'`` the keys ``'height'``, ``'ndim'``,
                ``'reward_beta'`` and ``'reward_at_corners'`` are forwarded as
                keyword arguments of the same name.

        Returns:
            A ``HypergridEnvironment`` instance.

        Raises:
            ValueError: If ``env_config['type']`` is not a supported type.
        """
        # Only the hypergrid is supported (single-cell experiments were removed).
        if env_config['type'] != 'HypergridEnvironment':
            raise ValueError(f"Unknown environment type: {env_config['type']}")

        forwarded_keys = ('height', 'ndim', 'reward_beta', 'reward_at_corners')
        constructor_kwargs = {key: env_config[key] for key in forwarded_keys}
        return HypergridEnvironment(**constructor_kwargs)
    
    def create_method(self, method_name: str, env_config: Dict[str, Any]):
        """Build a fresh, untrained instance of the named method.

        Args:
            method_name: One of ``'DATE-GFN'``, ``'GFN-TB'``, ``'EGFN'``,
                ``'SAC'`` or ``'MARS'``.
            env_config: Environment settings; ``'state_dim'`` and
                ``'action_dim'`` are read and passed to every method.

        Returns:
            The constructed method object. Only DATE-GFN is given
            ``self.device``; MARS is built from the two dimensions alone.

        Raises:
            ValueError: If ``method_name`` is not one of the names above.
        """
        dims = {
            'state_dim': env_config['state_dim'],
            'action_dim': env_config['action_dim'],
        }

        # Config sections are read lazily so each branch only needs its own keys.
        if method_name == 'DATE-GFN':
            hidden_dim = self.config['model']['hidden_dim']
            date_cfg = self.config['date_gfn']
            return DATEGFN(
                hidden_dim=hidden_dim,
                population_size=date_cfg['population_size'],
                elite_ratio=date_cfg['elite_ratio'],
                teachability_weight=date_cfg['teachability_weight'],
                student_updates_per_cycle=date_cfg['student_updates_per_cycle'],
                device=self.device,
                **dims,
            )

        if method_name == 'GFN-TB':
            return GFNBaseline(hidden_dim=self.config['model']['hidden_dim'], **dims)

        if method_name == 'EGFN':
            hidden_dim = self.config['model']['hidden_dim']
            egfn_cfg = self.config['baselines']['egfn']
            return EGFNBaseline(
                hidden_dim=hidden_dim,
                # Population settings are optional for the EGFN baseline.
                population_size=egfn_cfg.get('population_size', 16),
                elite_ratio=egfn_cfg.get('elite_ratio', 0.25),
                **dims,
            )

        if method_name == 'SAC':
            return SACBaseline(hidden_dim=self.config['model']['hidden_dim'], **dims)

        if method_name == 'MARS':
            return MARSBaseline(**dims)

        raise ValueError(f"Unknown method: {method_name}")
    
    def create_optimizers(self, method, method_name: str):
        """Create optimizers for the method."""
        if method_name == 'DATE-GFN':
            optimizer_forward = torch.optim.Adam(
                method.forward_policy.parameters(), 
                lr=float(self.config['baselines']['gfn_tb']['learning_rate'])
            )
            optimizer_backward = torch.optim.Adam(
                method.backward_policy.parameters(), 
                lr=float(self.config['baselines']['gfn_tb']['learning_rate'])
            )
            return [optimizer_forward, optimizer_backward]
        
        elif method_name == 'GFN-TB':
            optimizer = torch.optim.Adam(
                list(method.forward_policy.parameters()) + 
                list(method.backward_policy.parameters()) + 
                list(method.state_value.parameters()),
                lr=float(self.config['baselines']['gfn_tb']['learning_rate'])
            )
            return [optimizer]
        
        elif method_name == 'EGFN':
            optimizer = torch.optim.Adam(
                method.gfn.forward_policy.parameters(),
                lr=float(self.config['baselines']['gfn_tb']['learning_rate'])
            )
            return [optimizer]
        
        elif method_name == 'SAC':
            actor_optimizer = torch.optim.Adam(
                method.actor.parameters(),
                lr=float(self.config['baselines']['sac']['learning_rate'])
            )
            critic1_optimizer = torch.optim.Adam(
                method.critic1.parameters(),
                lr=float(self.config['baselines']['sac']['learning_rate'])
            )
            critic2_optimizer = torch.optim.Adam(
                method.critic2.parameters(),
                lr=float(self.config['baselines']['sac']['learning_rate'])
            )
            alpha_optimizer = torch.optim.Adam(
                [method.log_alpha],
                lr=float(self.config['baselines']['sac']['learning_rate'])
            )
            return [actor_optimizer, critic1_optimizer, critic2_optimizer, alpha_optimizer]
        
        elif method_name == 'MARS':
            return []  # MARS doesn't use optimizers
        
        else:
            raise ValueError(f"Unknown method for optimizer creation: {method_name}")
    
    def train_method(self, method, method_name: str, environment, 
                    optimizers: List, tracker: ExperimentTracker, seed: int):
        """Train one method for one seed, evaluating periodically.

        Each step calls ``method.train_step`` with the optimizers that the
        method expects. Every ``eval_every`` steps, and on the last step,
        ``num_eval_trajectories`` trajectories are sampled, logged through
        ``tracker`` and forwarded to wandb. After training, a final evaluation
        is logged to ``tracker`` at step ``num_steps`` (it is not sent to wandb).

        Args:
            method: Object providing ``train_step`` and ``sample``.
            method_name: ``'DATE-GFN'``, ``'GFN-TB'``, ``'EGFN'``, ``'SAC'`` or
                ``'MARS'``; decides how ``optimizers`` are passed to ``train_step``.
            environment: Environment handed to ``train_step`` and ``sample``.
            optimizers: Optimizers for the method, in the order produced by
                ``create_optimizers``.
            tracker: Receives every evaluation via ``log_step_results``.
            seed: Seed of this run; attached to the logged metrics.

        Returns:
            ``(final_metrics, tracker)`` where ``final_metrics`` is what
            ``tracker.log_step_results`` returned for the final evaluation.

        Raises:
            ValueError: If ``method_name`` is not a known method. Raised before
                any training happens, instead of failing later on an unbound
                ``metrics`` variable.
        """
        known_methods = ('DATE-GFN', 'GFN-TB', 'EGFN', 'SAC', 'MARS')
        if method_name not in known_methods:
            raise ValueError(f"Unknown method for training: {method_name}")

        print(f"Training {method_name} (seed {seed})...")
        
        num_steps = self.config['training']['num_steps']
        eval_every = self.config['evaluation']['eval_every']
        # Optional override; 200 keeps the previous hard-coded sample count.
        num_final_trajectories = self.config['evaluation'].get(
            'num_final_eval_trajectories', 200
        )
        
        for step in range(num_steps):
            step_start_time = time.time()
            
            # Training step based on method type
            if method_name == 'DATE-GFN':
                metrics = method.train_step(environment, optimizers[0], optimizers[1])
            elif method_name in ('GFN-TB', 'EGFN'):
                metrics = method.train_step(environment, optimizers[0])
            elif method_name == 'SAC':
                metrics = method.train_step(environment, *optimizers)
            else:  # 'MARS' uses no optimizers
                metrics = method.train_step(environment)
            
            # Only the training step is timed, not the evaluation below.
            step_time = time.time() - step_start_time
            
            # Evaluation
            if step % eval_every == 0 or step == num_steps - 1:
                eval_trajectories = method.sample(
                    environment, 
                    self.config['evaluation']['num_eval_trajectories']
                )
                
                # Log results
                step_metrics = tracker.log_step_results(
                    step, eval_trajectories, environment, method_name, 
                    {**metrics, 'step_time': step_time, 'seed': seed}
                )
                
                # Log to wandb
                self.wandb_logger.log_metrics(step_metrics, step)
                
                # Print progress on every fifth evaluation only
                if step % (eval_every * 5) == 0:
                    print(f"  Step {step}/{num_steps} - "
                          f"Reward: {step_metrics.get('mean_reward', 0):.4f}, "
                          f"Diversity: {step_metrics.get('diversity', 0):.4f}")
        
        # Final evaluation, by default with more samples than the periodic ones
        final_trajectories = method.sample(environment, num_final_trajectories)
        final_metrics = tracker.log_step_results(
            num_steps, final_trajectories, environment, method_name,
            {'final_evaluation': True, 'seed': seed}
        )
        
        return final_metrics, tracker
    
    def run_single_environment_experiment(self, env_name: str, env_config: Dict[str, Any]):
        """Train and evaluate every method on one environment across all seeds.

        For each method and each seed in ``range(config['training']['num_seeds'])``
        a fresh method, fresh optimizers and a fresh tracker are created and a
        separate wandb run is opened. The environment object is created once
        and shared by all methods and seeds. A seed whose training raises is
        reported on stdout and skipped.

        Args:
            env_name: Label used in run names, tags and printed output.
            env_config: Environment settings passed to ``create_environment``
                and ``create_method``.

        Returns:
            Mapping from method name to the aggregate produced by
            ``aggregate_seed_results``. Methods for which every seed failed
            are omitted.
        """
        print(f"\n{'='*60}")
        print(f"Running RQ1 experiment on {env_name} environment")
        print(f"{'='*60}")
        
        environment = self.create_environment(env_config)
        methods = ['DATE-GFN', 'GFN-TB', 'EGFN', 'SAC', 'MARS']
        
        env_results = {}
        
        for method_name in methods:
            print(f"\nTesting method: {method_name}")
            method_results = []
            
            # Run multiple seeds
            for seed in range(self.config['training']['num_seeds']):
                set_seed(seed)
                
                # Initialize wandb run
                run_name = f"RQ1_{env_name}_{method_name}_seed{seed}"
                wandb_config = {
                    'research_question': 'RQ1',
                    'environment': env_name,
                    'method': method_name,
                    'seed': seed,
                    **self.config
                }
                self.wandb_logger.init_run(wandb_config, run_name, 
                                         tags=['RQ1', env_name, method_name])
                
                # Everything between init_run and finish_run sits in a
                # try/finally so the wandb run is closed even when a seed
                # fails; otherwise the failed run would stay open.
                try:
                    # Create method and optimizers
                    method = self.create_method(method_name, env_config)
                    optimizers = self.create_optimizers(method, method_name)
                    
                    # Create tracker
                    tracker = ExperimentTracker(f"RQ1_{env_name}_{method_name}")
                    
                    # Train method; a failing seed is reported and skipped
                    # (deliberate best-effort), setup errors above still propagate.
                    try:
                        final_metrics, _ = self.train_method(
                            method, method_name, environment, optimizers, tracker, seed
                        )
                    except Exception as e:
                        print(f"Error training {method_name} with seed {seed}: {e}")
                    else:
                        method_results.append(final_metrics)
                finally:
                    # Finish wandb run
                    self.wandb_logger.finish_run()
            
            # Aggregate results across seeds
            if method_results:
                env_results[method_name] = self.aggregate_seed_results(method_results)
            
        return env_results
    
    def aggregate_seed_results(self, seed_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate results across multiple seeds."""
        aggregated = {}
        
        # Get all metric keys
        all_keys = set()
        for result in seed_results:
            all_keys.update(result.keys())
        
        # Aggregate numeric metrics
        for key in all_keys:
            values = []
            for result in seed_results:
                if key in result and isinstance(result[key], (int, float)):
                    values.append(result[key])
            
            if values:
                aggregated[f"{key}_mean"] = np.mean(values)
                aggregated[f"{key}_std"] = np.std(values)
                aggregated[f"{key}_min"] = np.min(values)
                aggregated[f"{key}_max"] = np.max(values)
                aggregated[f"{key}_values"] = values  # Store raw values for significance testing
        
        aggregated['num_seeds'] = len(seed_results)
        return aggregated
    
    def perform_statistical_analysis(self, results: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Run significance tests of DATE-GFN against every other method.

        Args:
            results: Mapping from method name to its aggregated results; the
                per-seed samples are read from ``<metric>_values`` entries.

        Returns:
            Mapping ``'DATE-GFN_vs_<method>'`` -> ``{metric: test result}``.
            A metric is tested only when both DATE-GFN and the baseline have a
            non-empty value list for it, so an entry may be an empty dict.
        """
        key_metrics = ('mean_reward', 'diversity', 'rel_l1_error',
                       'des_score', 'pds_score', 'mae')
        date_gfn_results = results.get('DATE-GFN', {})

        statistical_results = {}
        for method, method_results in results.items():
            if method == 'DATE-GFN':
                continue

            comparison = {}
            for metric in key_metrics:
                values_key = f"{metric}_values"
                date_gfn_values = date_gfn_results.get(values_key, [])
                method_values = method_results.get(values_key, [])
                if not (date_gfn_values and method_values):
                    continue

                # Error metrics are lower-is-better, so the argument order is
                # swapped for them relative to the other metrics.
                lowered = metric.lower()
                if 'error' in lowered or 'mae' in lowered:
                    first, second = method_values, date_gfn_values
                else:
                    first, second = date_gfn_values, method_values
                comparison[metric] = statistical_significance_test(first, second)

            statistical_results[f"DATE-GFN_vs_{method}"] = comparison

        return statistical_results
    
    def generate_summary_report(self, all_results: Dict[str, Dict[str, Dict[str, Any]]]) -> Dict[str, Any]:
        """Build the summary report: best method per metric, tests, win rate.

        Args:
            all_results: Mapping environment name -> method name -> aggregated
                results (the ``<metric>_mean`` entries are compared here).

        Returns:
            Report dictionary with:

            * ``'key_findings'``: per environment and metric, the best method
              and its value. Metrics containing ``'error'`` or ``'mae'`` are
              lower-is-better, all others higher-is-better; ties keep the
              first method encountered. When no method reported a metric,
              both ``'best_method'`` and ``'best_value'`` are ``None``.
            * ``'statistical_significance'``: per environment, the output of
              ``perform_statistical_analysis``.
            * ``'date_gfn_win_rate'``: fraction of (environment, metric) pairs
              won by DATE-GFN, counting only metrics that at least one method
              reported; ``0`` when there is nothing to compare.
            * ``'recommendations'``: a one-line verdict derived from the win
              rate (> 0.6 superior, > 0.4 competitive, otherwise needs
              improvement).
        """
        key_metrics = ('mean_reward_mean', 'diversity_mean', 'rel_l1_error_mean',
                       'des_score_mean', 'pds_score_mean', 'mae_mean')

        report = {
            'experiment': 'RQ1 - Comparative Analysis',
            'hypothesis': 'DATE-GFN significantly outperforms baselines',
            'environments_tested': list(all_results.keys()),
            'methods_compared': ['DATE-GFN', 'GFN-TB', 'EGFN', 'SAC', 'MARS'],
            'key_findings': {},
            'statistical_significance': {},
            'recommendations': []
        }
        
        date_gfn_wins = 0
        total_comparisons = 0
        
        # Analyze results for each environment
        for env_name, env_results in all_results.items():
            env_analysis = {}
            
            for metric in key_metrics:
                lower_is_better = 'error' in metric or 'mae' in metric
                best_method = None
                # +/-inf sentinel: any finite value beats it, NaN never does.
                best_value = float('inf') if lower_is_better else float('-inf')
                
                for method, method_results in env_results.items():
                    if metric not in method_results:
                        continue
                    value = method_results[metric]
                    improved = value < best_value if lower_is_better else value > best_value
                    if improved:
                        best_value = value
                        best_method = method
                
                if best_method is None:
                    # Nobody reported this metric: do not leak the sentinel
                    # into the report and do not count it as a comparison,
                    # which would otherwise drag the win rate down.
                    env_analysis[metric] = {'best_method': None, 'best_value': None}
                    continue
                
                env_analysis[metric] = {'best_method': best_method, 'best_value': best_value}
                total_comparisons += 1
                if best_method == 'DATE-GFN':
                    date_gfn_wins += 1
            
            report['key_findings'][env_name] = env_analysis
            
            # Statistical significance
            report['statistical_significance'][env_name] = (
                self.perform_statistical_analysis(env_results)
            )
        
        # Generate recommendations
        win_rate = date_gfn_wins / total_comparisons if total_comparisons > 0 else 0
        
        if win_rate > 0.6:
            report['recommendations'].append("DATE-GFN demonstrates superior performance")
        elif win_rate > 0.4:
            report['recommendations'].append("DATE-GFN shows competitive performance")
        else:
            report['recommendations'].append("DATE-GFN needs improvement")
        
        report['date_gfn_win_rate'] = win_rate
        
        return report
    
    def run_experiment(self):
        """Run the complete RQ1 experiment and persist its results.

        Creates the experiment directory, runs all methods on the Hypergrid
        environment (settings come from the ``'environment'`` section of
        ``configs/hypergrid_config.yaml``, resolved against the current
        working directory), builds the summary report, saves both result
        files and prints a short summary.

        Returns:
            ``(all_results, summary_report)``: the per-environment results and
            the report produced by ``generate_summary_report``.
        """
        print("Starting RQ1: Comparative Analysis Experiment")
        print("=" * 60)
        
        # Create experiment directory
        exp_dir = create_experiment_directory("results", "RQ1_comparative")
        
        all_results = {}
        
        # Test on Hypergrid environment. The config file is opened in a
        # context manager so the handle is closed right after parsing.
        with open('configs/hypergrid_config.yaml') as hypergrid_file:
            hypergrid_config = yaml.safe_load(hypergrid_file)['environment']
        hypergrid_results = self.run_single_environment_experiment('Hypergrid', hypergrid_config)
        all_results['Hypergrid'] = hypergrid_results
        
        # Single-cell experiments removed
        
        # Generate summary report
        summary_report = self.generate_summary_report(all_results)
        
        # Save results
        save_results(all_results, exp_dir / "detailed_results.json")
        save_results(summary_report, exp_dir / "summary_report.json")
        
        # Print summary
        print("\n" + "=" * 60)
        print("RQ1 EXPERIMENT COMPLETED")
        print("=" * 60)
        print(f"DATE-GFN win rate: {summary_report['date_gfn_win_rate']:.2%}")
        print(f"Results saved to: {exp_dir}")
        
        for recommendation in summary_report['recommendations']:
            print(f"• {recommendation}")
        
        return all_results, summary_report


def main():
    """Command-line entry point: parse arguments, load config, run RQ1.

    ``--config`` names the YAML configuration file. ``--device`` overrides
    ``config['training']['device']`` unless it is left at ``'auto'``.

    Returns:
        The ``(results, summary)`` pair returned by
        ``RQ1Experiment.run_experiment``.
    """
    cli_parser = argparse.ArgumentParser(description='Run RQ1 Comparative Analysis')
    cli_parser.add_argument(
        '--config',
        type=str,
        default='configs/base_config.yaml',
        help='Path to configuration file',
    )
    cli_parser.add_argument(
        '--device',
        type=str,
        default='auto',
        help='Device to use (auto, cuda, cpu)',
    )
    cli_args = cli_parser.parse_args()

    # Load configuration
    with open(cli_args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)

    # An explicit device on the command line wins over the config file.
    if cli_args.device != 'auto':
        config['training']['device'] = cli_args.device

    # Run experiment
    results, summary = RQ1Experiment(config).run_experiment()
    return results, summary


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
