#!/usr/bin/env python3
"""
Comparative Analysis for DATE-GFN

This script runs a comprehensive comparison between DATE-GFN and baseline methods
using GFlowNet implementations on the Hypergrid environment.
"""

import argparse
import json
import os
import random
import sys
import time
from pathlib import Path
from typing import Dict, List, Any

import numpy as np
import torch

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from date_gfn.core.date_gfn import DATEGFN
from date_gfn.baselines.baselines import GFNBaseline, EGFNBaseline
from date_gfn.environments.environments import HypergridEnvironment


def set_seed(seed: int):
    """Seed every random number generator the experiments may draw from.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when CUDA is available, every visible GPU) so repeated runs are
    reproducible.

    Args:
        seed: Integer seed applied to all generators.
    """
    # The stdlib RNG was previously left unseeded; any environment/baseline
    # code that uses ``random`` would otherwise make runs irreproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)


class RealComparativeAnalysis:
    """Comparative analysis using actual GFlowNet implementations.

    Trains every configured method on every configured Hypergrid environment,
    evaluates the trained sampler and collects the results.

    Config keys read by this class:
        device:       torch device string; defaults to CUDA when available.
        seed:         seed re-applied before every experiment.
        environments: list of Hypergrid keyword dicts.
        methods:      list of method names ("DATE-GFN", "TB-GFN", "EGFN").
        training:     dict with ``num_steps``, optional ``eval_samples``
                      (default 200) and optional ``lr`` (default 5e-4, used
                      for the baselines' Adam optimizer).
        model:        optional dict with ``hidden_dim`` (default 256),
                      ``population_size`` (default 50) and
                      ``teachability_weight`` (default 0.1, DATE-GFN only).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.device = torch.device(config.get("device", "cuda" if torch.cuda.is_available() else "cpu"))
        # One result dict per (environment, method); filled by run_full_comparison().
        self.results: List[Dict[str, Any]] = []

    def create_environment(self, env_config: Dict[str, Any]):
        """Create a Hypergrid environment from ``env_config`` (missing keys use defaults)."""
        return HypergridEnvironment(
            height=env_config.get("height", 10),
            ndim=env_config.get("ndim", 4),
            reward_beta=env_config.get("reward_beta", 10.0),
            reward_at_corners=env_config.get("reward_at_corners", 1e-4)
        )

    def create_method(self, method_name: str, state_dim: int, action_dim: int):
        """Instantiate the named method on ``self.device``.

        Args:
            method_name: One of "DATE-GFN", "TB-GFN" or "EGFN".
            state_dim: State dimensionality passed to the model constructor.
            action_dim: Number of discrete actions passed to the model constructor.

        Returns:
            The constructed method; DATE-GFN has its population initialised.

        Raises:
            ValueError: If ``method_name`` is not recognised.
        """
        model_cfg = self.config.get("model") or {}
        hidden_dim = model_cfg.get("hidden_dim", 256)
        population_size = model_cfg.get("population_size", 50)

        if method_name == "DATE-GFN":
            method = DATEGFN(
                state_dim=state_dim,
                action_dim=action_dim,
                hidden_dim=hidden_dim,
                population_size=population_size,
                teachability_weight=model_cfg.get("teachability_weight", 0.1),
                device=self.device
            )
            method.initialize_population()
            return method

        if method_name == "TB-GFN":
            return GFNBaseline(
                state_dim=state_dim,
                action_dim=action_dim,
                hidden_dim=hidden_dim
            ).to(self.device)

        if method_name == "EGFN":
            return EGFNBaseline(
                state_dim=state_dim,
                action_dim=action_dim,
                hidden_dim=hidden_dim,
                population_size=population_size
            ).to(self.device)

        raise ValueError(f"Unknown method: {method_name}")

    def _make_optimizer(self, method):
        """Return an Adam optimizer for a known baseline, or None for any other method."""
        lr = (self.config.get("training") or {}).get("lr", 5e-4)
        if isinstance(method, GFNBaseline):
            return torch.optim.Adam(method.parameters(), lr=lr)
        if isinstance(method, EGFNBaseline):
            # EGFN only optimises its student network.
            return torch.optim.Adam(method.student_gfn.parameters(), lr=lr)
        return None

    def train_method(self, method, env, num_steps: int):
        """Train ``method`` on ``env`` for ``num_steps`` steps.

        Methods exposing ``train_step`` (the baselines) are driven with an Adam
        optimizer that is created once and reused for every step. Methods
        without ``train_step`` (DATE-GFN) run ``evolutionary_phase`` followed
        by ``distillation_phase`` each step.

        Returns:
            List of per-step metric dicts, each containing at least ``step``
            and ``step_time`` (wall-clock seconds).
        """
        print(f"  Training {method.__class__.__name__} for {num_steps} steps...")

        uses_train_step = hasattr(method, 'train_step')
        # Build the optimizer ONCE. Re-creating Adam every step discarded its
        # moment estimates, so every update behaved like Adam's first step.
        optimizer = self._make_optimizer(method) if uses_train_step else None

        training_metrics = []
        for step in range(num_steps):
            step_start = time.perf_counter()

            if not uses_train_step:
                # DATE-GFN: evolve the population, then distil into the student.
                method.evolutionary_phase(env)
                method.distillation_phase(env)
                metrics = {}
            elif optimizer is not None:
                metrics = method.train_step(env, optimizer)
            else:
                # Has train_step but is not a known baseline: there is no
                # optimizer for it, so (as before) no training is performed.
                metrics = {}

            # Copy so a dict owned by the method is never mutated; tolerate None.
            metrics = {} if metrics is None else dict(metrics)
            metrics['step_time'] = time.perf_counter() - step_start
            metrics['step'] = step
            training_metrics.append(metrics)

            if (step + 1) % 100 == 0:
                recent = [m['step_time'] for m in training_metrics[-100:]]
                print(f"    Step {step + 1}/{num_steps} (avg time: {np.mean(recent):.3f}s)")

        return training_metrics

    @staticmethod
    def _summarize(values):
        """Return (mean, std, max, min) of ``values`` as floats, or four NaNs when empty."""
        if len(values) == 0:
            nan = float("nan")
            return nan, nan, nan, nan
        arr = np.asarray(values, dtype=float)
        return float(arr.mean()), float(arr.std()), float(arr.max()), float(arr.min())

    def evaluate_method(self, method, env, num_samples: int = 200):
        """Sample ``num_samples`` trajectories from ``method`` and score them on ``env``.

        Returns:
            Dict with mode coverage, modes found, total modes, diversity, L1
            error, reward statistics of the final states, trajectory-length
            statistics, sampling time and the number of samples actually
            returned. Reward and length statistics are NaN if no
            trajectories were returned.
        """
        print(f"  Evaluating with {num_samples} samples...")

        start_time = time.perf_counter()
        trajectories = method.sample(env, num_samples=num_samples)
        sampling_time = time.perf_counter() - start_time

        coverage, num_modes = env.calculate_mode_coverage(trajectories)
        diversity = env.calculate_diversity(trajectories)
        l1_error = env.calculate_l1_error(trajectories)

        # Reward is taken at the terminal (last) state of each trajectory.
        mean_reward, std_reward, max_reward, min_reward = self._summarize(
            [env.get_reward(traj[-1]) for traj in trajectories]
        )
        mean_length, std_length, _, _ = self._summarize([len(traj) for traj in trajectories])

        return {
            'mode_coverage': coverage,
            'num_modes_found': num_modes,
            'total_modes': len(env.corner_states),
            'diversity': diversity,
            'l1_error': l1_error,
            'mean_reward': mean_reward,
            'std_reward': std_reward,
            'max_reward': max_reward,
            'min_reward': min_reward,
            'mean_traj_length': mean_length,
            'std_traj_length': std_length,
            'sampling_time': sampling_time,
            'num_samples': len(trajectories)
        }

    def run_experiment(self, method_name: str, env_config: Dict[str, Any],
                       training_config: Dict[str, Any]):
        """Build, train and evaluate one method on one environment.

        Returns:
            Dict with keys ``method``, ``environment``, ``training`` (step
            count, timing, per-step metrics) and ``evaluation``.
        """
        print(f"\n--- {method_name} Experiment ---")

        env = self.create_environment(env_config)
        print(f"Environment: {env.height}^{env.ndim} Hypergrid ({len(env.corner_states)} modes)")

        # Action space: 2 * ndim + 1 actions for an ndim-dimensional grid.
        state_dim = env.ndim
        action_dim = 2 * env.ndim + 1
        method = self.create_method(method_name, state_dim, action_dim)

        num_steps = training_config["num_steps"]
        training_start = time.perf_counter()
        training_metrics = self.train_method(method, env, num_steps)
        total_training_time = time.perf_counter() - training_start

        eval_results = self.evaluate_method(method, env, training_config.get("eval_samples", 200))

        return {
            'method': method_name,
            'environment': env_config,
            'training': {
                'num_steps': num_steps,
                'total_time': total_training_time,
                # Guard against ZeroDivisionError when num_steps == 0.
                'avg_step_time': total_training_time / num_steps if num_steps > 0 else 0.0,
                'metrics': training_metrics
            },
            'evaluation': eval_results
        }

    def run_full_comparison(self):
        """Run every configured method on every configured environment.

        Returns:
            The list of per-experiment result dicts (also stored in ``self.results``).
        """
        config = self.config

        print("=" * 80)
        print("DATE-GFN Comparative Analysis")
        print("=" * 80)
        print(f"Device: {self.device}")
        print(f"Seed: {config['seed']}")

        all_results = []

        for env_config in config['environments']:
            env_name = f"{env_config['height']}^{env_config['ndim']}"
            print(f"\n{'='*20} Environment: {env_name} {'='*20}")

            for method_name in config['methods']:
                # Re-seed before every experiment so each method starts from the
                # same RNG state, independent of its position in the method list.
                set_seed(config['seed'])
                results = self.run_experiment(
                    method_name,
                    env_config,
                    config['training']
                )
                all_results.append(results)

                eval_res = results['evaluation']
                train_res = results['training']
                print(f"Results - Coverage: {eval_res['mode_coverage']:.3f}, "
                      f"Diversity: {eval_res['diversity']:.3f}, "
                      f"L1_Error: {eval_res['l1_error']:.3f}, "
                      f"Training_Time: {train_res['total_time']:.1f}s")

        self.results = all_results
        return all_results

    def print_summary(self):
        """Print a results table per environment and an overall mode-coverage ranking."""
        if not self.results:
            print("No results to summarize.")
            return

        print("\n" + "="*80)
        print("📊 COMPREHENSIVE RESULTS SUMMARY")
        print("="*80)

        # Group results by environment, preserving first-seen order.
        env_groups: Dict[str, List[Dict[str, Any]]] = {}
        for result in self.results:
            env_config = result['environment']
            env_key = f"{env_config['height']}^{env_config['ndim']}"
            env_groups.setdefault(env_key, []).append(result)

        for env_name, env_results in env_groups.items():
            print(f"\n{env_name} Environment:")
            print(f"{'Method':<12} {'Coverage':<10} {'Diversity':<10} {'L1_Error':<10} "
                  f"{'MeanReward':<12} {'TrainTime(s)':<12}")
            print("-" * 78)

            for result in env_results:
                eval_res = result['evaluation']
                train_res = result['training']
                method = result['method']

                print(f"{method:<12} {eval_res['mode_coverage']:<10.3f} "
                      f"{eval_res['diversity']:<10.3f} {eval_res['l1_error']:<10.3f} "
                      f"{eval_res['mean_reward']:<12.6f} {train_res['total_time']:<12.1f}")

        print(f"\n{'='*80}")
        print("🏆 OVERALL METHOD RANKING (by average mode coverage)")
        print("="*80)

        method_stats: Dict[str, List[float]] = {}
        for result in self.results:
            method_stats.setdefault(result['method'], []).append(result['evaluation']['mode_coverage'])

        # sorted() is stable, so ties keep first-seen method order.
        ranked_methods = sorted(method_stats.items(), key=lambda item: np.mean(item[1]), reverse=True)
        for rank, (method, scores) in enumerate(ranked_methods, 1):
            print(f"{rank}. {method}: {np.mean(scores):.3f} ± {np.std(scores):.3f}")

    @staticmethod
    def _to_jsonable(obj):
        """``json.dump`` fallback: NumPy/torch values become native Python values, others ``str``."""
        if isinstance(obj, (np.generic, np.ndarray)):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        return str(obj)

    def save_results(self, output_file: str):
        """Write ``self.results`` to ``output_file`` as indented UTF-8 JSON.

        NumPy scalars/arrays and torch tensors are stored as numbers/lists.
        Previously ``default=str`` turned e.g. ``np.int64(3)`` into ``"3"``.
        Any other non-serialisable object still falls back to ``str``.
        Missing parent directories are created.
        """
        path = Path(output_file)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, default=self._to_jsonable)
        print(f"\nResults saved to: {output_file}")


def _build_config(args: argparse.Namespace) -> Dict[str, Any]:
    """Return the experiment config: CLI-derived defaults, overridden by an optional JSON file.

    Keys missing from the ``--config`` file fall back to the command-line
    values (``--seed``, ``--methods``, ``--steps``) and the built-in
    environment list. The ``training`` section is merged key by key, so a
    file may override only ``num_steps``, for example.
    """
    config: Dict[str, Any] = {
        "seed": args.seed,
        "methods": args.methods,
        "environments": [
            {"height": 8, "ndim": 2, "reward_beta": 10.0, "reward_at_corners": 1e-4},
            {"height": 6, "ndim": 3, "reward_beta": 10.0, "reward_at_corners": 1e-4},
            {"height": 5, "ndim": 4, "reward_beta": 10.0, "reward_at_corners": 1e-4}
        ],
        "training": {
            "num_steps": args.steps,
            "eval_samples": 200
        }
    }

    if args.config:
        with open(args.config, 'r', encoding='utf-8') as f:
            file_config = json.load(f)
        # Merge 'training' key by key; all other top-level keys replace the defaults.
        training = {**config["training"], **(file_config.get("training") or {})}
        config.update(file_config)
        config["training"] = training

    return config


def main():
    """Parse CLI arguments, run the comparison, print a summary and save the results."""
    parser = argparse.ArgumentParser(description="DATE-GFN Comparative Analysis")
    parser.add_argument("--config", type=str, help="Configuration file (JSON)")
    parser.add_argument("--output", type=str, default="comparison_results.json",
                        help="Output file for results")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--steps", type=int, default=500, help="Training steps")
    parser.add_argument("--methods", nargs="+", default=["DATE-GFN", "TB-GFN", "EGFN"],
                        help="Methods to compare")

    args = parser.parse_args()
    config = _build_config(args)

    analyzer = RealComparativeAnalysis(config)
    analyzer.run_full_comparison()
    analyzer.print_summary()
    analyzer.save_results(args.output)

    print("\nComparative analysis completed successfully!")


if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()
