import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import time
import pandas as pd
from tqdm import tqdm

# Import environment and runner
from kidnapped.kidnapped_robot_env import KidnappedRobotEnv
from kidnapped.kidnapped_robot_runner import KidnappedRobotRunner

# Import the safety wrappers
from experiment_fixes import apply_global_patches, patch_experiment_runner

# Apply global patches
apply_global_patches()

# Import belief representation methods
try:
    # Original methods
    from escort.escort_improvements import ImprovedESCORT as ESCORT
    from escort.svgd_improvements import RobustSVGD as SVGD
    from dvrl.dvrl_adapter_kidnapped import DVRL
    from pomcpow.pomcpow_adapter_kidnapped import POMCPOWAdapter
    
    # Ablation variants
    from escort.escort_nocorr import ESCORTNoCorr
    from escort.escort_notemp import ESCORTNoTemp
    from escort.escort_noproj import ESCORTNoProj
except ImportError as e:
    print(f"Import error: {e}")
    print("Make sure all implementation files are in the correct path")
    sys.exit(1)

def _method_factories(n_particles, state_dim):
    """Return an ordered mapping of method name -> zero-argument constructor.

    The mapping's insertion order is the order in which methods are
    initialized, and therefore the order of the ``methods`` dict handed to
    the runner. Construction is deferred (lambdas) so that a failure while
    building one method does not prevent the others from being built.

    Args:
        n_particles: Number of particles for belief representation.
        state_dim: State dimension passed to the belief methods.

    Returns:
        Dict mapping method name to a callable that builds the method.
    """
    # Hyperparameters shared by SVGD and every ESCORT variant.
    svgd_kwargs = {
        "n_particles": n_particles,
        "state_dim": state_dim,
        "kernel_bandwidth": 0.1,
        "step_size": 0.01,
    }
    return {
        'ESCORT': lambda: ESCORT(
            **svgd_kwargs,
            lambda_corr=0.1,
            lambda_temp=0.1,
            n_projections=10,
        ),
        'ESCORT-NoCorr': lambda: ESCORTNoCorr(
            **svgd_kwargs,
            lambda_corr=0.0,  # No correlation regularization
            lambda_temp=0.1,
            n_projections=10,
        ),
        'ESCORT-NoTemp': lambda: ESCORTNoTemp(
            **svgd_kwargs,
            lambda_corr=0.1,
            lambda_temp=0.0,  # No temporal consistency
            n_projections=10,
        ),
        'ESCORT-NoProj': lambda: ESCORTNoProj(
            **svgd_kwargs,
            lambda_corr=0.1,
            lambda_temp=0.1,
            n_projections=10,
        ),
        'SVGD': lambda: SVGD(**svgd_kwargs),
        'DVRL': lambda: DVRL(
            state_dim=state_dim,
            belief_dim=10,
            n_particles=n_particles,
        ),
        'POMCPOW': lambda: POMCPOWAdapter(
            # POMCPOW needs an explicit action space:
            # 4 actions: forward, left, right, stay
            action_space=list(range(4)),
            n_particles=n_particles,
            max_depth=5,  # Smaller depth for computational efficiency
            n_simulations=100,  # Fewer simulations for computational efficiency
            exploration_const=50.0,
            alpha_action=0.5,
            k_action=4.0,
            alpha_obs=0.5,
            k_obs=4.0,
            discount_factor=0.95,
        ),
    }


def _build_methods(methods_to_run, n_particles, state_dim):
    """Instantiate every requested method, skipping any that fail to build.

    Args:
        methods_to_run: Iterable of requested method names (not mutated).
        n_particles: Number of particles for belief representation.
        state_dim: State dimension passed to the belief methods.

    Returns:
        Dict mapping method name -> instance, in the canonical order defined
        by ``_method_factories`` (not the order of ``methods_to_run``).
    """
    factories = _method_factories(n_particles, state_dim)

    # Previously unknown names were silently dropped; surface them instead.
    unknown = [name for name in methods_to_run if name not in factories]
    if unknown:
        print(f"Warning: ignoring unknown method(s): {', '.join(unknown)}")

    requested = set(methods_to_run)
    methods = {}
    for name, factory in factories.items():
        if name not in requested:
            continue
        print(f"Initializing {name}...")
        try:
            methods[name] = factory()
        except Exception as e:
            # Best effort: report the failure and continue with the others.
            print(f"Error initializing {name}: {e}")
    return methods


def _fallback_results(method_names):
    """Build a placeholder results DataFrame with NaN metrics, one row per method."""
    return pd.DataFrame([
        {
            'Method': name,
            'Episode': 0,
            'Final Position Error': float('nan'),
            'MMD': float('nan'),
            'Correlation Error': float('nan'),
            'Mode Coverage': float('nan'),
            'Success': False,
        }
        for name in method_names
    ])


def _print_summary(results_df):
    """Print per-method means of the headline metrics.

    Any failure here is reported and swallowed so that a presentation
    problem never causes the caller to lose the real results.
    """
    metric_cols = ['Final Position Error', 'MMD', 'Correlation Error', 'Mode Coverage']
    print("\nResults Summary:")
    try:
        # numeric_only avoids TypeError on non-numeric columns (pandas >= 2.0).
        summary = results_df.groupby('Method').mean(numeric_only=True).reset_index()
        cols = ['Method'] + [c for c in metric_cols if c in summary.columns]
        print(summary[cols])
    except Exception as e:
        print(f"Could not compute results summary: {e}")


def run_experiment(n_episodes=5, max_steps=100, n_particles=100,
                   methods_to_run=None, save_dir="results"):
    """
    Run the complete Kidnapped Robot experiment comparing different
    belief approximation methods.

    Methods that fail to initialize are reported and skipped. The caller's
    ``methods_to_run`` list is never modified.

    Args:
        n_episodes: Number of episodes to run
        max_steps: Maximum steps per episode
        n_particles: Number of particles for belief representation
        methods_to_run: List of methods to evaluate (if None, run all)
        save_dir: Directory to save results

    Returns:
        DataFrame with results. It is empty when no method could be
        initialized. If the runner raises (or returns None), it is a
        placeholder with NaN metrics, one row per initialized method.
    """
    # Default to all methods if not specified; copy so the caller's
    # sequence is never touched.
    if methods_to_run is None:
        methods_to_run = ['ESCORT', 'ESCORT-NoCorr', 'ESCORT-NoTemp', 'ESCORT-NoProj',
                          'SVGD', 'DVRL', 'POMCPOW']
    else:
        methods_to_run = list(methods_to_run)

    # Create environment parameters
    env_params = {
        "map_size": 20,
        "n_landmarks": 15,
        "sensor_range": 5,
        "noise_level": 0.1
    }

    # Create runner and apply the safety patches to it
    runner = KidnappedRobotRunner(env_params, save_dir=save_dir)
    runner = patch_experiment_runner(runner)

    # State dimension is fixed at 20 for this environment
    state_dim = 20

    methods = _build_methods(methods_to_run, n_particles, state_dim)

    # Check if we still have methods to run
    if not methods:
        print("No methods available to run. Exiting.")
        return pd.DataFrame()

    print("Running experiment...")
    # Keep the try narrow: only the runner call triggers the fallback, so a
    # failure while summarizing can no longer discard real results.
    try:
        results_df = runner.run_experiment(
            methods=methods,
            n_episodes=n_episodes,
            max_steps=max_steps,
            n_particles=n_particles
        )
    except Exception as e:
        print(f"Error running experiment: {e}")
        return _fallback_results(methods)

    if results_df is None:
        print("Error running experiment: runner returned no results")
        return _fallback_results(methods)

    _print_summary(results_df)
    return results_df

if __name__ == "__main__":
    import argparse

    # Set up argument parser. Help strings interpolate %(default)s so the
    # documented defaults cannot drift from the real ones.
    parser = argparse.ArgumentParser(description='Kidnapped Robot Experiment')
    parser.add_argument('--episodes', type=int, default=10,
                        help='Number of episodes (default: %(default)s)')
    parser.add_argument('--steps', type=int, default=100,
                        help='Maximum steps per episode (default: %(default)s)')
    parser.add_argument('--particles', type=int, default=100,
                        help='Number of particles (default: %(default)s)')
    parser.add_argument('--methods', nargs='+',
                        default=['ESCORT', 'ESCORT-NoCorr', 'ESCORT-NoTemp', 'ESCORT-NoProj'],
                        help=('Methods to evaluate; available: ESCORT, ESCORT-NoCorr, '
                              'ESCORT-NoTemp, ESCORT-NoProj, SVGD, DVRL, POMCPOW '
                              '(default: ESCORT and its three ablations)'))
    parser.add_argument('--save_dir', type=str, default='results',
                        help='Directory to save results (default: %(default)s)')

    # Parse arguments
    args = parser.parse_args()

    # Create save directory
    os.makedirs(args.save_dir, exist_ok=True)

    # Run experiment
    results = run_experiment(
        n_episodes=args.episodes,
        max_steps=args.steps,
        n_particles=args.particles,
        methods_to_run=args.methods,
        save_dir=args.save_dir
    )

    # Save results. NOTE: this writes the full DataFrame returned by
    # run_experiment (not the per-method means it prints); the file name is
    # kept for compatibility with existing tooling.
    if not results.empty:
        results_path = os.path.join(args.save_dir, "results_summary.csv")
        results.to_csv(results_path, index=False)
        print(f"Experiment complete. Results saved in: {args.save_dir}")
    else:
        print("Experiment failed to produce results.")
