import os
import sys
import numpy as np
import time
import traceback
from tqdm import tqdm
import warnings

# Import environment and runner
from multi_target_tracking.multi_target_tracking_env import MultiTargetTracking20DEnv
from multi_target_tracking.multi_target_tracking_runner import MultiTargetTracking20DRunner

# Import robust fixes
from multi_target_tracking.experiment_fixes import (
    fix_covariance_matrix, 
    safe_multivariate_normal,
    apply_global_patches, 
    patch_environment, 
    fix_computation_functions, 
    optimize_runner_performance, 
    fix_serialization_issues
)

# First apply the global patches to numpy functions
# This is crucial to prevent warnings
apply_global_patches()

# Import belief representation methods
try:
    # Original methods
    from escort.escort_improvements import ImprovedESCORT as ESCORT
    from escort.svgd_improvements import RobustSVGD as SVGD
    from dvrl.dvrl_adapter_tracking import DVRL
    from pomcpow.pomcpow_adapter_mtt import POMCPOWAdapterMTT as POMCPOW
    
    # Ablation variants
    from escort.escort_nocorr import ESCORTNoCorr
    from escort.escort_notemp import ESCORTNoTemp
    from escort.escort_noproj import ESCORTNoProj
except ImportError as e:
    print(f"Import error: {e}")
    print("Make sure all implementation files are in the correct path")
    sys.exit(1)

# Try to import joblib for parallel processing
try:
    from joblib import Parallel, delayed, cpu_count
    PARALLEL_AVAILABLE = True
except ImportError:
    PARALLEL_AVAILABLE = False
    print("joblib not available. Install with 'pip install joblib' for parallel processing")

class EnhancedMultiTargetTracking20DEnv(MultiTargetTracking20DEnv):
    """Multi-target tracking environment with more robust correlation handling.

    Overrides ``_apply_correlation`` so the covariance matrix built from
    ``self.correlation_matrix`` is repaired with ``fix_covariance_matrix``
    before sampling. If correlated sampling still fails, it falls back to
    independent noise drawn from the diagonal of the repaired covariance.
    """

    def _apply_correlation(self, state_delta):
        """
        Perturb a state delta with correlated noise in a numerically stable way.

        Args:
            state_delta: State delta for correlation. Presumably a 1-D array of
                length ``self.state_dim`` (TODO confirm against the base env).

        Returns:
            Correlated delta ``sign(state_delta) * (|state_delta| + 0.1 * noise)``.
            The noise is sampled with covariance
            ``outer(|state_delta| + 0.01, |state_delta| + 0.01) * correlation_matrix``
            after PSD repair. Components where ``state_delta`` is exactly zero
            stay zero, because their sign is 0.
        """
        # Per-component scales. The small constant keeps the covariance from
        # collapsing when a component of the delta is exactly zero.
        scales = np.abs(state_delta) + 0.01
        cov_matrix = np.outer(scales, scales) * self.correlation_matrix

        # Repair the matrix so it is positive semi-definite before sampling.
        fixed_cov = fix_covariance_matrix(cov_matrix, min_eigenvalue=1e-5)

        try:
            correlated_noise = safe_multivariate_normal(
                mean=np.zeros(self.state_dim),
                cov=fixed_cov,
            )
        except Exception as e:
            # Fall back to independent noise. The diagonal of the repaired
            # covariance holds per-component variances, so its square root is
            # the per-component standard deviation. Flooring at 1e-6 keeps
            # every scale strictly positive.
            print(f"Warning: Using diagonal covariance as fallback: {e}")
            diag_values = np.abs(np.diag(fixed_cov))
            diag_values[diag_values < 1e-6] = 1e-6
            correlated_noise = np.random.normal(0.0, np.sqrt(diag_values))

        # Keep the sign of the intended delta and perturb only its magnitude.
        direction = np.sign(state_delta)
        magnitude = np.abs(state_delta)

        correlated_delta = direction * (magnitude + 0.1 * correlated_noise)

        return correlated_delta

# Main experiment function
def _empty_results():
    """Return an empty results table (pandas is imported lazily, as before)."""
    import pandas as pd
    return pd.DataFrame()


def _method_builders(n_particles, state_dim):
    """Return an ordered mapping of method name -> zero-argument constructor.

    The insertion order is the order in which methods are initialized.
    """

    def escort_kwargs(lambda_corr=0.1, lambda_temp=0.1):
        # Hyper-parameters shared by ESCORT and its ablation variants.
        return dict(
            n_particles=n_particles,
            state_dim=state_dim,
            kernel_bandwidth=0.1,
            step_size=0.01,
            lambda_corr=lambda_corr,
            lambda_temp=lambda_temp,
            n_projections=5,
        )

    return {
        'ESCORT': lambda: ESCORT(**escort_kwargs()),
        # No correlation regularization
        'ESCORT-NoCorr': lambda: ESCORTNoCorr(**escort_kwargs(lambda_corr=0.0)),
        # No temporal consistency
        'ESCORT-NoTemp': lambda: ESCORTNoTemp(**escort_kwargs(lambda_temp=0.0)),
        'ESCORT-NoProj': lambda: ESCORTNoProj(**escort_kwargs()),
        'SVGD': lambda: SVGD(
            n_particles=n_particles,
            state_dim=state_dim,
            kernel_bandwidth=0.1,
            step_size=0.01,
        ),
        'DVRL': lambda: DVRL(
            state_dim=state_dim,
            belief_dim=state_dim,  # Match state dimension
            n_particles=n_particles,
        ),
        'POMCPOW': lambda: POMCPOW(
            action_space=list(range(4)),  # 4 actions: up, right, down, left
            n_particles=n_particles,
            max_depth=3,  # Smaller depth for computational efficiency
            n_simulations=50,  # Fewer simulations for computational efficiency
            exploration_const=10.0,
            alpha_action=0.5,
            k_action=4.0,
            alpha_obs=0.5,
            k_obs=4.0,
            discount_factor=0.95,
        ),
    }


def _throttle_visualization(runner, visualize_freq, max_steps):
    """Make ``runner._visualize_beliefs`` fire only every ``visualize_freq`` steps.

    The final step (``max_steps - 1``) is always visualized. Values of
    ``visualize_freq`` below 1 disable periodic visualization, which avoids a
    modulo-by-zero. The wrapper is installed as an instance attribute, which
    takes precedence over the class method, so the runner's own
    ``self._visualize_beliefs(...)`` calls are routed through it. Any
    previously patched version is preserved, because the wrapper delegates
    to whatever the attribute resolved to before.
    """
    original_visualize = getattr(runner, '_visualize_beliefs', None)
    if original_visualize is None:
        return runner

    def _visualize_beliefs(particles_dict, episode, step):
        periodic = visualize_freq > 0 and step % visualize_freq == 0
        if periodic or step == max_steps - 1:
            original_visualize(particles_dict, episode, step)

    runner._visualize_beliefs = _visualize_beliefs
    runner.visualize_freq = visualize_freq
    runner.max_steps = max_steps
    return runner


def run_experiment(n_episodes=1, max_steps=100, n_particles=100,
                  methods_to_run=None, save_dir=None,
                  visualize_freq=20):
    """Run the 20-D multi-target tracking experiment with all numerical fixes applied.

    Args:
        n_episodes: Number of episodes, passed to the runner.
        max_steps: Maximum steps per episode. Step ``max_steps - 1`` is always
            visualized.
        n_particles: Particle count used by every belief method.
        methods_to_run: Method names to evaluate (a single name string is also
            accepted). Defaults to all supported methods. Unknown names are
            reported and ignored. The caller's list is never modified.
        save_dir: Runner output directory. Defaults to ``results_mtt20d`` next
            to this script.
        visualize_freq: Visualize beliefs every ``visualize_freq`` steps, plus
            the final step. Values below 1 disable periodic visualization.

    Returns:
        Whatever the runner's ``run_experiment`` returns (a results DataFrame).
        Returns an empty pandas DataFrame if no method could be initialized or
        if the run raised.
    """
    # Suppress specific NumPy warnings about PSD matrices since we're fixing them
    warnings.filterwarnings('ignore', message='covariance is not symmetric positive-semidefinite')

    # State dimension is fixed at 20 for this environment
    state_dim = 20
    builders = _method_builders(n_particles, state_dim)

    # Work on a private copy so the caller's list is never mutated.
    if methods_to_run is None:
        requested = set(builders)
    elif isinstance(methods_to_run, str):
        requested = {methods_to_run}
    else:
        requested = set(methods_to_run)

    unknown = sorted(requested.difference(builders))
    if unknown:
        print(f"Warning: ignoring unknown method(s): {', '.join(unknown)}")

    # Use script directory if save_dir is not specified
    if save_dir is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        save_dir = os.path.join(script_dir, "results_mtt20d")

    env_params = {
        "map_size": 10,
        "n_targets": 4,  # Fixed for 20D state space
        "noise_level": 0.5
    }

    # Replace the runner's default environment with the enhanced version.
    env = EnhancedMultiTargetTracking20DEnv(**env_params)
    runner = MultiTargetTracking20DRunner(env_params, save_dir=save_dir)
    runner.env = env

    # Apply all fixes to the runner, in order.
    for apply_fix in (patch_environment, fix_computation_functions,
                      optimize_runner_performance, fix_serialization_issues):
        runner = apply_fix(runner)

    # Initialize the requested methods. A failure only drops that method.
    methods = {}
    for name, build in builders.items():
        if name not in requested:
            continue
        print(f"Initializing {name}...")
        try:
            methods[name] = build()
        except Exception as e:
            print(f"Error initializing {name}: {e}")

    if not methods:
        print("No methods available to run. Exiting.")
        return _empty_results()

    _throttle_visualization(runner, visualize_freq, max_steps)

    try:
        print("Running experiment...")
        results_df = runner.run_experiment(
            methods=methods,
            n_episodes=n_episodes,
            max_steps=max_steps,
            n_particles=n_particles
        )
    except Exception as e:
        print(f"Error running experiment: {e}")
        traceback.print_exc()
        return _empty_results()

    return results_df

if __name__ == "__main__":
    import argparse

    # Set up argument parser. The help strings use %(default)s so the
    # documented defaults cannot drift from the real ones.
    parser = argparse.ArgumentParser(description='Multi-Target Tracking 20D Experiment')
    parser.add_argument('--episodes', type=int, default=10,
                        help='Number of episodes (default: %(default)s)')
    parser.add_argument('--steps', type=int, default=100,
                        help='Maximum steps per episode (default: %(default)s)')
    parser.add_argument('--particles', type=int, default=100,
                        help='Number of particles (default: %(default)s)')
    parser.add_argument('--methods', nargs='+',
                        choices=['ESCORT', 'ESCORT-NoCorr', 'ESCORT-NoTemp', 'ESCORT-NoProj',
                                 'SVGD', 'DVRL', 'POMCPOW'],
                        default=['ESCORT', 'ESCORT-NoCorr', 'ESCORT-NoTemp', 'ESCORT-NoProj'],
                        help='Methods to evaluate (default: ESCORT and its three ablations)')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory to save results (default: results_mtt20d next to this script)')
    parser.add_argument('--visualize_freq', type=int, default=20,
                        help='Visualization frequency (default: %(default)s)')
    parser.add_argument('--quick-test', action='store_true',
                        help='Run a quick test with minimal settings')

    args = parser.parse_args()

    # Quick test: minimal settings and a single method.
    if args.quick_test:
        print("Running in quick test mode with minimal settings")
        args.episodes = 1
        args.steps = 20
        args.particles = 50
        args.methods = ['POMCPOW']  # Just use one method for quick test
        args.visualize_freq = 10

    # If save_dir is not provided, use results_mtt20d next to this script
    if args.save_dir is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        save_dir = os.path.join(script_dir, "results_mtt20d")
    else:
        save_dir = args.save_dir

    os.makedirs(save_dir, exist_ok=True)

    results = run_experiment(
        n_episodes=args.episodes,
        max_steps=args.steps,
        n_particles=args.particles,
        methods_to_run=args.methods,
        save_dir=save_dir,
        visualize_freq=args.visualize_freq
    )

    # Save results. Guard against a runner that returned nothing.
    if results is not None and not results.empty:
        results_path = os.path.join(save_dir, "results_summary.csv")
        results.to_csv(results_path, index=False)
        print(f"Experiment complete. Results saved in: {save_dir}")
    else:
        print("Experiment failed to produce results.")
