import numpy as np
import torch
import time
from typing import Dict, List, Tuple
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))


from src.UEP.framework.UEP_ab_version import UEPBetaFixedPD

from src.new_formal.env.env import (
    static_traditional_environment,
    smooth_nonstationary,
    abrupt_change_environment,
    high_frequency_changes
)


class UEPAblationExperiment:
    """Ablation-study driver for the UEP Beta bandit algorithm.

    For every registered environment, algorithm and ablation configuration
    (the configurations toggle ``use_adaptive_window``, ``use_diffusion`` and
    ``use_adaptive_lambda``), the driver simulates repeated bandit episodes
    and reports final cumulative regret, average reward and the fraction of
    steps on which the optimal arm was selected.
    """

    def __init__(self, device: str = "cuda"):
        """Register environments, ablation configurations and algorithms.

        Args:
            device: Device string forwarded unchanged to the algorithm
                constructor (e.g. ``"cuda"`` or ``"cpu"``).
        """
        self.device = device
        # Per-run result dicts; populated by run_ablation_experiment().
        self.results: List[Dict] = []

        self.environments = {
            'static': {
                'func': static_traditional_environment,
                'params': {'d': float('inf'), 'noise_level': 0.02},
                'name': '静态environments'
            },
            'smooth': {
                'func': smooth_nonstationary,
                'params': {'d': 1.2, 'noise_level': 0.15},
                'name': 'Smoothing非平稳environments'
            },
            'abrupt': {
                'func': abrupt_change_environment,
                'params': {'d': 0.8, 'noise_strength': 0.1},
                'name': '突变environments'
            },
            'high_freq': {
                'func': high_frequency_changes,
                'params': {'d': 0.6, 'oscillation_strength': 0.2},
                'name': '高频变化environments'
            }
        }

        self.ablation_configs = {
            'full': {
                'use_adaptive_window': True,
                'use_diffusion': True,
                'use_adaptive_lambda': True,
                'name': 'Full UEP'
            },
            'no_adaptive_window': {
                'use_adaptive_window': False,
                'use_diffusion': True,
                'use_adaptive_lambda': True,
                'name': 'No adaptive window'
            },
            'no_diffusion': {
                'use_adaptive_window': True,
                'use_diffusion': False,
                'use_adaptive_lambda': True,
                'name': 'No diffusion model'
            },
            'no_adaptive_lambda': {
                'use_adaptive_window': True,
                'use_diffusion': True,
                'use_adaptive_lambda': False,
                'name': 'No adaptive lambda'
            }
        }

        self.algorithms = {
            'UEP_Beta': {
                'class': UEPBetaFixedPD,
                'name': 'UEP Betaalgorithms'
            }
        }

        print("🎯 UEP消融实验初始化completed")
        print(f"   - environments数量: {len(self.environments)}")
        print(f"   - Number of ablation configurations: {len(self.ablation_configs)}")
        print(f"   - algorithms数量: {len(self.algorithms)}")
        print(f"   - Device: {self.device}")

    @staticmethod
    def _summarize(results: List[Dict]) -> Dict:
        """Aggregate mean/std of the three headline metrics over ``results``.

        Args:
            results: Non-empty list of dicts as returned by
                ``run_single_experiment``.

        Returns:
            Dict with ``mean_*``/``std_*`` for regret, reward and optimal
            ratio, plus ``n_runs``.
        """
        final_regrets = [r['final_regret'] for r in results]
        avg_rewards = [r['avg_reward'] for r in results]
        optimal_ratios = [r['optimal_ratio'] for r in results]
        return {
            'mean_regret': np.mean(final_regrets),
            'std_regret': np.std(final_regrets),
            'mean_reward': np.mean(avg_rewards),
            'std_reward': np.std(avg_rewards),
            'mean_optimal': np.mean(optimal_ratios),
            'std_optimal': np.std(optimal_ratios),
            'n_runs': len(results),
        }

    def run_single_experiment(self, env_name: str, config_name: str,
                              algorithm_name: str = 'UEP_Beta',
                              n_arms: int = 4, T: int = 1000,
                              d_true: float = 1.0) -> Dict:
        """Simulate one bandit episode and collect its metrics.

        Args:
            env_name: Key into ``self.environments``.
            config_name: Key into ``self.ablation_configs``.
            algorithm_name: Key into ``self.algorithms``. Defaults to
                ``'UEP_Beta'``, the only registered algorithm. The previous
                default ``'UEP'`` was not a registered key and always failed.
            n_arms: Number of arms.
            T: Number of time steps; must be positive.
            d_true: Value given to the algorithm as ``d_true`` when the
                environment function reports ``d == inf``.

        Returns:
            Dict with the run configuration, per-step rewards/regrets/arms,
            cumulative regret, final regret, average reward, optimal-arm
            selection ratio and wall-clock runtime in seconds.

        Raises:
            KeyError: If an unregistered name is given.
            ValueError: If ``T`` is not positive.
        """
        if T <= 0:
            raise ValueError(f"T must be a positive number of steps, got {T}")

        env = self.environments[env_name]
        config = self.ablation_configs[config_name]
        algorithm_spec = self.algorithms[algorithm_name]

        print(f"\n🔄 runs实验: {env['name']} + {config['name']} + {algorithm_spec['name']}")

        # means is indexed as means[arm, t] and means[:, t] below; presumably
        # shape (n_arms, T) — TODO confirm against the env functions.
        means, d_env = env['func'](n_arms, T, **env['params'])
        actual_d = d_env if d_env != float('inf') else d_true

        algorithm = algorithm_spec['class'](
            n_arms=n_arms,
            d_true=actual_d,
            device=self.device,
            use_adaptive_window=config['use_adaptive_window'],
            use_diffusion=config['use_diffusion'],
            use_adaptive_lambda=config['use_adaptive_lambda'],
            fixed_window=20,
            fixed_lambda=0.5
        )

        rewards = []
        regrets = []
        selected_arms = []

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for t in range(T):
            arm = algorithm.select_arm()
            selected_arms.append(arm)

            # The algorithm's own detect_environment_type() picks the reward
            # model: a Bernoulli draw, or the mean plus N(0, 0.1) noise.
            # NOTE(review): letting the learner choose how rewards are
            # generated looks unusual — confirm this is intended.
            if algorithm.detect_environment_type():
                reward = np.random.binomial(1, means[arm, t])
            else:
                reward = means[arm, t] + np.random.normal(0, 0.1)
            reward = np.clip(reward, 0.0, 1.0)
            rewards.append(reward)

            algorithm.update(arm, reward)
            regrets.append(algorithm.compute_regret(means[:, t]))
        runtime = time.perf_counter() - start_time

        cumulative_regret = np.cumsum(regrets)
        final_regret = cumulative_regret[-1]
        avg_reward = np.mean(rewards)

        # np.argmax returns the first maximum, so ties resolve to the lowest arm index.
        optimal_arms = np.argmax(means, axis=0)
        optimal_selections = sum(
            1 for arm, best in zip(selected_arms, optimal_arms) if arm == best
        )
        optimal_ratio = optimal_selections / T

        result = {
            'env_name': env_name,
            'config_name': config_name,
            'algorithm_name': algorithm_name,
            'n_arms': n_arms,
            'T': T,
            'd_true': actual_d,
            'final_regret': final_regret,
            'avg_reward': avg_reward,
            'optimal_ratio': optimal_ratio,
            'cumulative_regret': cumulative_regret.tolist(),
            'rewards': rewards,
            'regrets': regrets,
            'selected_arms': selected_arms,
            'runtime': runtime,
            'algorithm_config': config
        }

        print(f"   ✅ completed - Final regret: {final_regret:.2f}, Average reward: {avg_reward:.3f}, Optimal selection rate: {optimal_ratio:.3f}")

        return result

    def run_ablation_experiment(self, n_runs: int = 10, n_arms: int = 4, T: int = 1000) -> List[Dict]:
        """Run ``n_runs`` episodes for every environment × algorithm × configuration.

        Failing runs are reported and skipped (best effort), so one crashing
        configuration does not abort the whole sweep.

        Args:
            n_runs: Repetitions per (environment, algorithm, configuration).
            n_arms: Number of arms per episode.
            T: Number of steps per episode.

        Returns:
            List of all successful per-run result dicts. The same list is also
            stored on ``self.results`` and summarised via ``print_summary``.
        """
        print("\n🚀 Starting UEP ablation experiment")
        print(f"   - runstimes数: {n_runs}")
        print(f"   - Number of arms: {n_arms}")
        print(f"   - Number of steps: {T}")
        print(f"   - environments数: {len(self.environments)}")
        print(f"   - Number of ablation configurations: {len(self.ablation_configs)}")
        print(f"   - algorithms数: {len(self.algorithms)}")

        all_results: List[Dict] = []

        for env_name, env in self.environments.items():
            print(f"\n📊 Environment: {env['name']}")

            for algorithm_name, algorithm_spec in self.algorithms.items():
                print(f"\n🤖 algorithms: {algorithm_spec['name']}")

                for config_name, config in self.ablation_configs.items():
                    print(f"\n🔧 Configuration: {config['name']}")

                    config_results: List[Dict] = []

                    for run in range(n_runs):
                        print(f"   runs {run + 1}/{n_runs}...", end=" ")

                        try:
                            result = self.run_single_experiment(
                                env_name=env_name,
                                config_name=config_name,
                                algorithm_name=algorithm_name,
                                n_arms=n_arms,
                                T=T
                            )
                        except Exception as e:
                            # Deliberate best effort: report (with the
                            # exception type, since e.g. str(KeyError) is
                            # just the key) and continue the sweep.
                            print(f"❌ Error: {type(e).__name__}: {e}")
                            continue

                        config_results.append(result)
                        all_results.append(result)
                        print("✅")

                    if config_results:
                        stats = self._summarize(config_results)
                        print("   📈 Statistical results:")
                        print(f"      - Average final regret: {stats['mean_regret']:.2f} ± {stats['std_regret']:.2f}")
                        print(f"      - Average reward: {stats['mean_reward']:.3f} ± {stats['std_reward']:.3f}")
                        print(f"      - 平均Optimal selection rate: {stats['mean_optimal']:.3f} ± {stats['std_optimal']:.3f}")

        self.results = all_results
        self.print_summary()

        return all_results

    def print_summary(self):
        """Print per-environment, per-algorithm tables aggregated from ``self.results``.

        Configurations without any successful run are omitted from the table.
        """
        print("\n📋 UEP ablation experiment summary")
        print("=" * 100)

        for env_name, env in self.environments.items():
            print(f"\n🌍 environments: {env['name']}")
            print("-" * 100)

            env_results = [r for r in self.results if r['env_name'] == env_name]

            for algorithm_name, algorithm_spec in self.algorithms.items():
                print(f"\n🤖 algorithms: {algorithm_spec['name']}")
                print("-" * 80)

                alg_results = [r for r in env_results if r['algorithm_name'] == algorithm_name]

                config_summary = {}
                for config_name in self.ablation_configs:
                    config_results = [r for r in alg_results if r['config_name'] == config_name]
                    if config_results:
                        config_summary[config_name] = self._summarize(config_results)

                print(f"{'Configuration':<20} {'Final regret':<15} {'Average reward':<15} {'Optimal selection rate':<15} {'runstimes数':<10}")
                print("-" * 80)

                for config_name, stats in config_summary.items():
                    config_display = self.ablation_configs[config_name]['name']
                    print(f"{config_display:<20} "
                          f"{stats['mean_regret']:.2f}±{stats['std_regret']:.2f}    "
                          f"{stats['mean_reward']:.3f}±{stats['std_reward']:.3f}    "
                          f"{stats['mean_optimal']:.3f}±{stats['std_optimal']:.3f}    "
                          f"{stats['n_runs']:<10}")

        print("\n" + "=" * 100)
        print("🎯 消融实验completed！")


def main():
    """Script entry point: run the UEP ablation sweep with fixed settings.

    Uses CUDA when ``torch.cuda.is_available()`` reports it, otherwise CPU,
    then runs 5 repetitions per environment/algorithm/configuration with
    4 arms over 1000 steps and prints how many runs finished.
    """
    print("🎯 UEPalgorithms消融实验")
    print("=" * 50)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"UseDevice: {device}")

    runner = UEPAblationExperiment(device=device)
    finished_runs = runner.run_ablation_experiment(n_runs=5, n_arms=4, T=1000)

    print(f"\n✅ 消融实验completed！共runs {len(finished_runs)} experiments")


if __name__ == "__main__":
    main()