import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import time
import pandas as pd
import traceback
from tqdm import tqdm
from datetime import datetime, timedelta

# Import environment and runner
from lightdark10d.lightdark10d_env import LightDark10DEnv
from lightdark10d.lightdark10d_runner import LightDark10DRunner

# Import the safety wrappers
from kidnapped.experiment_fixes import apply_global_patches, patch_experiment_runner

# Apply global patches
apply_global_patches()

# Import belief representation methods
try:
    # Full ESCORT (ImprovedESCORT) plus the SVGD (RobustSVGD), DVRL and POMCPOW baselines
    from escort.escort_improvements import ImprovedESCORT as ESCORT
    from escort.svgd_improvements import RobustSVGD as SVGD
    from dvrl.dvrl_adapter_lightdark10d import DVRL
    from pomcpow.pomcpow import POMCPOW
    
    # Ablation variants
    from escort.escort_nocorr import ESCORTNoCorr
    from escort.escort_notemp import ESCORTNoTemp
    from escort.escort_noproj import ESCORTNoProj
except ImportError as e:
    print(f"Import error: {e}")
    print("Make sure all implementation files are in the correct path")
    sys.exit(1)

class ProgressTracker:
    """Track and display experiment progress.

    Maintains three stacked tqdm bars (episodes, methods, steps) and prints
    elapsed-time / remaining-time lines to stdout.

    Expected call order per episode: ``start_episode`` -> ``start_method``
    (once per method) -> ``update_step`` (once per step) -> ``end_episode``;
    ``finish`` is called once at the very end (extra calls are ignored).
    """

    def __init__(self, n_episodes, max_steps, methods):
        """Create the progress bars.

        Args:
            n_episodes: Total number of episodes in the experiment.
            max_steps: Maximum number of steps per episode.
            methods: Sized collection of methods (only ``len()`` is used).
        """
        self.n_episodes = n_episodes
        self.max_steps = max_steps
        self.methods = methods
        self.start_time = time.time()
        self.episode_times = []

        # Set by start_episode(); initialised here so end_episode() never
        # hits a missing attribute if it is called first.
        self.episode_start_time = None

        # Guards finish() against running (and printing) twice.
        self._finished = False

        # Create master progress bar
        self.master_pbar = tqdm(total=n_episodes, desc="Episodes", position=0)

        # Create method progress bar (will be updated for each method)
        self.method_pbar = tqdm(total=len(methods), desc="Methods", position=1)

        # Create step progress bar (will be updated for each step)
        self.step_pbar = tqdm(total=max_steps, desc="Steps", position=2)

        # Last update time for periodic status prints
        self.last_status_time = time.time()
        self.status_interval = 60  # status every 60 seconds

    def _average_episode_time(self):
        """Return the mean duration (seconds) of the episodes finished so far."""
        return sum(self.episode_times) / len(self.episode_times)

    def start_episode(self, episode):
        """Mark the start of a new episode.

        Args:
            episode: Zero-based episode index.
        """
        # Only relabel the bar here; it is advanced in end_episode() so the
        # count reflects *completed* episodes rather than started ones.
        self.master_pbar.set_description(f"Episode {episode+1}/{self.n_episodes}")
        self.episode_start_time = time.time()

        # Reset method progress
        self.method_pbar.reset()
        self.method_pbar.total = len(self.methods)

        # Print episode start
        elapsed = time.time() - self.start_time
        print(f"\n--- Starting Episode {episode+1}/{self.n_episodes} (Elapsed: {timedelta(seconds=int(elapsed))}) ---")

    def start_method(self, method_idx, method_name):
        """Mark the start of a new method.

        Args:
            method_idx: Index of the method (currently unused).
            method_name: Name shown in the method bar description.
        """
        self.method_pbar.update(1)
        self.method_pbar.set_description(f"Method: {method_name}")

        # Reset step progress
        self.step_pbar.reset()
        self.step_pbar.total = self.max_steps

    def update_step(self, step):
        """Advance the step bar and print a status line at most once per
        ``status_interval`` seconds.

        Args:
            step: Zero-based step index within the current episode.
        """
        self.step_pbar.update(1)
        self.step_pbar.set_description(f"Step {step+1}/{self.max_steps}")

        current_time = time.time()
        if current_time - self.last_status_time <= self.status_interval:
            return

        elapsed = current_time - self.start_time
        self.last_status_time = current_time

        # An ETA is only available once at least one episode has finished.
        eta_str = ""
        if self.episode_times:
            # -1 excludes the episode currently in progress.
            remaining_episodes = self.n_episodes - len(self.episode_times) - 1
            if remaining_episodes > 0:
                eta_seconds = self._average_episode_time() * remaining_episodes
                eta = datetime.now() + timedelta(seconds=eta_seconds)
                eta_str = f", ETA: {eta.strftime('%H:%M:%S')}"

        print(f"Status update - Elapsed: {timedelta(seconds=int(elapsed))}{eta_str}")

    def end_episode(self, episode):
        """Mark the end of an episode and print timing information.

        Args:
            episode: Zero-based index of the episode that just finished.
        """
        # Fall back to the tracker's creation time if start_episode() was
        # never called, instead of failing on a missing timestamp.
        started = self.episode_start_time
        if started is None:
            started = self.start_time
        episode_time = time.time() - started
        self.episode_times.append(episode_time)

        # The episode is now actually complete.
        self.master_pbar.update(1)

        # Print episode completion
        avg_time = self._average_episode_time()
        print(f"--- Completed Episode {episode+1}/{self.n_episodes} in {timedelta(seconds=int(episode_time))} " +
              f"(Avg: {timedelta(seconds=int(avg_time))}) ---")

        # Print estimated time remaining
        remaining_episodes = self.n_episodes - (episode + 1)
        if remaining_episodes > 0:
            est_remaining = avg_time * remaining_episodes
            print(f"--- Estimated time remaining: {timedelta(seconds=int(est_remaining))} ---")

    def finish(self):
        """Close the progress bars and print the total run time.

        Safe to call more than once; only the first call has any effect.
        """
        if getattr(self, "_finished", False):
            return
        self._finished = True

        total_time = time.time() - self.start_time
        self.master_pbar.close()
        self.method_pbar.close()
        self.step_pbar.close()
        print(f"\n=== Experiment completed in {timedelta(seconds=int(total_time))} ===")

def _build_methods(methods_to_run, n_particles, state_dim):
    """Instantiate the requested belief-approximation methods.

    Methods are constructed in a fixed order (ESCORT, its three ablations,
    SVGD, DVRL, POMCPOW) regardless of the order of ``methods_to_run``.
    A method whose constructor raises is reported on stdout and left out.
    Names that are not in the table below are ignored, and
    ``methods_to_run`` itself is never modified.

    Args:
        methods_to_run: Container of method names to construct.
        n_particles: Number of particles passed to every method.
        state_dim: State dimension passed to the particle-based methods.

    Returns:
        Dict mapping method name to the constructed method object.
    """
    # Settings shared by ESCORT and its ablation variants.
    escort_common = {
        "n_particles": n_particles,
        "state_dim": state_dim,
        "kernel_bandwidth": 0.1,
        "step_size": 0.01,
        "n_projections": 5,
    }

    def make_pomcpow():
        # POMCPOW needs action space
        action_space = list(range(10))  # 10 actions: +/- force in each dimension
        return POMCPOW(
            action_space=action_space,
            n_particles=n_particles,
            max_depth=3,  # Smaller depth for computational efficiency
            n_simulations=50,  # Fewer simulations for computational efficiency
            exploration_const=10.0,
            alpha_action=0.5,
            k_action=4.0,
            alpha_obs=0.5,
            k_obs=4.0,
            discount_factor=0.95
        )

    # Factories are lambdas so a class is only looked up when its method is
    # actually requested.
    factories = (
        ('ESCORT', lambda: ESCORT(
            lambda_corr=0.1, lambda_temp=0.1, **escort_common)),
        ('ESCORT-NoCorr', lambda: ESCORTNoCorr(
            lambda_corr=0.0,  # No correlation regularization
            lambda_temp=0.1, **escort_common)),
        ('ESCORT-NoTemp', lambda: ESCORTNoTemp(
            lambda_corr=0.1,
            lambda_temp=0.0,  # No temporal consistency
            **escort_common)),
        ('ESCORT-NoProj', lambda: ESCORTNoProj(
            lambda_corr=0.1, lambda_temp=0.1, **escort_common)),
        ('SVGD', lambda: SVGD(
            n_particles=n_particles,
            state_dim=state_dim,
            kernel_bandwidth=0.1,
            step_size=0.01)),
        ('DVRL', lambda: DVRL(
            state_dim=state_dim,
            belief_dim=5,  # Reduced dimension for latent representation
            n_particles=n_particles)),
        ('POMCPOW', make_pomcpow),
    )

    methods = {}
    for name, factory in factories:
        if name not in methods_to_run:
            continue
        print(f"Initializing {name}...")
        try:
            methods[name] = factory()
        except Exception as e:
            # Best effort: a method that cannot be built is skipped so the
            # remaining methods can still be compared.
            print(f"Error initializing {name}: {e}")
    return methods


def _sample_initial_particles(n_particles, state_dim, map_size):
    """Draw a random initial particle set.

    The first five columns are sampled uniformly in ``[0, map_size)`` and the
    remaining columns from a normal distribution N(0, 0.1).  The code labels
    these position and velocity respectively.

    Returns:
        Array of shape ``(n_particles, state_dim)``.
    """
    init_particles = np.zeros((n_particles, state_dim))

    # Position (uniformly distributed in map)
    init_particles[:, :5] = np.random.uniform(0, map_size, (n_particles, 5))

    # Velocity (normal distribution around zero)
    init_particles[:, 5:] = np.random.normal(0, 0.1, (n_particles, 5))

    return init_particles


def _update_belief(method, current_particles, action, next_obs, runner):
    """Run one belief update for ``method`` and return its new particles.

    Three interfaces are supported, tried in this order:
    ``update`` + ``get_belief_estimate``, ``fit_transform``, and a plain
    callable.  Any exception raised by the method propagates to the caller.
    """
    def likelihood(x):
        return runner.observation_model(x, next_obs)

    if hasattr(method, 'update'):
        # Standard update interface
        method.update(action, next_obs, runner.transition_model, runner.observation_model)
        return method.get_belief_estimate()

    if hasattr(method, 'fit_transform'):
        # Some methods use fit_transform pattern
        return method.fit_transform(current_particles, likelihood, None)

    # Fallback: assume callable method
    return method(current_particles, likelihood)


def run_experiment(n_episodes=5, max_steps=100, n_particles=100, 
                  methods_to_run=None, save_dir=None):
    """
    Run the complete Light-Dark 10D experiment comparing different
    belief approximation methods.

    Every episode resets the environment, samples a fresh particle set per
    method, then steps the environment up to ``max_steps`` times.  At each
    step every method updates its belief from the new observation.  A method
    whose update raises keeps its previous particles and the run continues.
    Actions come from ``runner._select_action`` applied to the true state.
    Particle checkpoints (``.npy``) and ``partial_results.csv`` are written
    under ``save_dir`` as the run progresses.

    Args:
        n_episodes: Number of episodes to run
        max_steps: Maximum steps per episode
        n_particles: Number of particles for belief representation
        methods_to_run: List of methods to evaluate (if None, run all).
            The list is not modified.
        save_dir: Directory to save results (if None, a "results_lightdark"
            folder next to this script is used)

    Returns:
        DataFrame with one row per (episode, method).  An empty DataFrame is
        returned when no method could be initialized.  If the run fails
        part-way, the rows collected so far are returned.  If no row was
        collected at all, a placeholder frame with NaN metrics (one row per
        method) is returned.
    """
    # Default to all methods if not specified
    if methods_to_run is None:
        methods_to_run = [
            'ESCORT', 'ESCORT-NoCorr', 'ESCORT-NoTemp', 'ESCORT-NoProj',
            'SVGD', 'DVRL', 'POMCPOW'
        ]

    # Use script directory if save_dir is not specified
    if save_dir is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        save_dir = os.path.join(script_dir, "results_lightdark")

    # Create environment parameters
    env_params = {
        "map_size": 10,
        "noise_level": 0.5
    }

    # Create runner
    runner = LightDark10DRunner(env_params, save_dir=save_dir)

    # Apply safety patches to the runner
    runner = patch_experiment_runner(runner)

    # State dimension is fixed at 10 for this environment
    state_dim = 10

    methods = _build_methods(methods_to_run, n_particles, state_dim)

    # Check if we still have methods to run
    if not methods:
        print("No methods available to run. Exiting.")
        return pd.DataFrame()

    checkpoint_dir = os.path.join(save_dir, "checkpoints")

    # Result rows, collected incrementally so a late failure loses nothing.
    all_results = []

    # Set up progress tracking
    progress = ProgressTracker(n_episodes=n_episodes, max_steps=max_steps, methods=methods)

    # Run experiment manually with progress tracking
    print("Running experiment...")
    completed = False
    try:
        # Created once up front; makedirs also creates save_dir itself.
        os.makedirs(checkpoint_dir, exist_ok=True)

        for episode in range(n_episodes):
            progress.start_episode(episode)

            # Reset environment
            runner.env.reset()
            runner.true_state = runner.env.state.copy()

            # Initialize particles randomly for each method
            particles = {}
            for method_idx, method_name in enumerate(methods):
                progress.start_method(method_idx, method_name)

                init_particles = _sample_initial_particles(
                    n_particles, state_dim, env_params["map_size"])
                particles[method_name] = init_particles

                # Save checkpoint of initial particles
                np.save(os.path.join(checkpoint_dir, f"ep{episode+1}_{method_name}_init.npy"),
                        init_particles)

            # Tracked explicitly so the record below is valid even when the
            # step loop runs zero times (max_steps == 0).
            steps_taken = 0

            # Run episode
            for step in range(max_steps):
                steps_taken = step + 1
                progress.update_step(step)

                # Save ground truth state for evaluation
                true_state_copy = runner.env.state.copy()

                # Select action - simple heuristic driven by the true state
                action = runner._select_action(true_state_copy)

                # Take step in environment
                next_obs, _reward, done, _info = runner.env.step(action)
                runner.true_state = runner.env.state.copy()  # Update true state reference

                # Update belief for each method
                for method_name, method in methods.items():
                    try:
                        particles[method_name] = _update_belief(
                            method, particles[method_name], action, next_obs, runner)
                    except Exception as e:
                        # Keep previous particles if update fails
                        print(f"Error updating {method_name}: {e}")
                        traceback.print_exc()

                # Visualize beliefs periodically
                if step % 10 == 0 or step == max_steps - 1:
                    runner._visualize_beliefs(particles, episode, step)

                # Save checkpoint every 25 steps
                if step % 25 == 0 or step == max_steps - 1:
                    for method_name, method_particles in particles.items():
                        np.save(os.path.join(checkpoint_dir,
                                             f"ep{episode+1}_{method_name}_step{step+1}.npy"),
                                method_particles)

                # Break if done
                if done:
                    break

            # Evaluate performance for each method
            for method_name, method_particles in particles.items():
                # Compute metrics
                position_error = runner._compute_position_error(method_particles, runner.true_state)
                belief_metrics = runner._evaluate_belief_quality(method_particles, runner.true_state)

                # Check success - if we've reached the goal
                distance_to_goal = np.linalg.norm(runner.true_state[:5] - runner.env.goal)
                success = distance_to_goal < 0.5

                # NOTE: "Runtime" is whatever the runner's belief-quality
                # evaluation reports; per-update wall time is not measured here.
                result = {
                    "Method": method_name,
                    "Episode": episode,
                    "Steps": steps_taken,
                    "Final Position Error": position_error["final"],
                    "Mean Position Error": position_error["mean"],
                    "Max Position Error": position_error["max"],
                    "MMD": belief_metrics["mmd"],
                    "Sliced Wasserstein": belief_metrics["sliced_wasserstein"],
                    "Correlation Error": belief_metrics["correlation_error"],
                    "Mode Coverage": belief_metrics["mode_coverage"],
                    "ESS": belief_metrics["ess"],
                    "Runtime": belief_metrics["runtime"],
                    "Success": success,
                    "Final Distance": distance_to_goal
                }

                # Add to results
                all_results.append(result)

                # Save partial results after every row so an interrupted run
                # still leaves usable data on disk.
                partial_df = pd.DataFrame(all_results)
                partial_df.to_csv(os.path.join(save_dir, "partial_results.csv"), index=False)

                # Print result summary
                print(f"{method_name}: Final Pos Error = {position_error['final']:.2f}, "
                      f"MMD = {belief_metrics['mmd']:.4f}, "
                      f"Mode Coverage = {belief_metrics['mode_coverage']:.2f}, "
                      f"Success = {success}")

            # End of episode
            progress.end_episode(episode)

        completed = True

    except Exception as e:
        print(f"Error running experiment: {e}")
        traceback.print_exc()

    finally:
        # Close the progress bars exactly once, on success and failure alike.
        progress.finish()

    if completed and all_results:
        results_df = pd.DataFrame(all_results)

        # The summary is informational only; a failure while printing it
        # must not discard the collected results.
        print("\nResults Summary:")
        try:
            summary = results_df.groupby('Method').mean(numeric_only=True).reset_index()
            print(summary[['Method', 'Final Position Error', 'MMD', 'Correlation Error', 'Mode Coverage', 'Success']])
        except Exception as e:
            print(f"Could not print results summary: {e}")
            traceback.print_exc()

        return results_df

    # Return partial results if any
    if all_results:
        print("Returning partial results...")
        return pd.DataFrame(all_results)

    # Create fallback DataFrame with empty results
    fallback_rows = [
        {
            'Method': method_name,
            'Episode': 0,
            'Final Position Error': float('nan'),
            'MMD': float('nan'),
            'Correlation Error': float('nan'),
            'Mode Coverage': float('nan'),
            'Success': False
        }
        for method_name in methods
    ]
    return pd.DataFrame(fallback_rows)

if __name__ == "__main__":
    # Command-line entry point: parse options, run the experiment and write
    # the resulting table to <save_dir>/results_summary.csv.
    import argparse

    # Set up argument parser.  "%(default)s" makes argparse render the actual
    # default in --help, so the help text cannot drift from the code.
    parser = argparse.ArgumentParser(description='Light Dark 10D Experiment')
    parser.add_argument('--episodes', type=int, default=10, 
                      help='Number of episodes (default: %(default)s)')
    parser.add_argument('--steps', type=int, default=100, 
                      help='Maximum steps per episode (default: %(default)s)')
    parser.add_argument('--particles', type=int, default=100, 
                      help='Number of particles (default: %(default)s)')
    # 'ESCORT', 'ESCORT-NoCorr', 'ESCORT-NoTemp', 'ESCORT-NoProj', 'SVGD', 'DVRL', 'POMCPOW'
    parser.add_argument('--methods', nargs='+', 
                      default=['ESCORT', 'ESCORT-NoCorr', 'ESCORT-NoTemp', 'ESCORT-NoProj'],
                      help='Methods to evaluate; any of ESCORT, ESCORT-NoCorr, '
                           'ESCORT-NoTemp, ESCORT-NoProj, SVGD, DVRL, POMCPOW '
                           '(default: %(default)s)')
    parser.add_argument('--save_dir', type=str, default=None,
                      help='Directory to save results '
                           '(default: "results_lightdark" next to this script)')
    
    # Parse arguments
    args = parser.parse_args()
    
    # If save_dir is not provided, use a folder next to this script
    if args.save_dir is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        save_dir = os.path.join(script_dir, "results_lightdark")
    else:
        save_dir = args.save_dir
    
    # Create save directory
    os.makedirs(save_dir, exist_ok=True)
    
    # Run experiment
    results = run_experiment(
        n_episodes=args.episodes,
        max_steps=args.steps,
        n_particles=args.particles,
        methods_to_run=args.methods,
        save_dir=save_dir
    )
    
    # Save results
    if not results.empty:
        results_path = os.path.join(save_dir, "results_summary.csv")
        results.to_csv(results_path, index=False)
        print(f"Experiment complete. Results saved in: {save_dir}")
    else:
        print("Experiment failed to produce results.")
