import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import deque
import sys
import warnings

# Suppress every warning for the whole process. Note that this also hides
# NumPy runtime warnings (e.g. the mean of an empty list) raised later on.
warnings.filterwarnings('ignore')

# Preferred sans-serif fonts, tried in order. SimHei / Arial Unicode MS are
# presumably listed so that CJK text in figures renders - TODO confirm.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS']
# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

from src.new_formal.model.predictor720 import BanditDDPMPredictor720
from src.UEP.framework.UEP import BalancedBanditDiffUCBPerfectD
from src.UEP.env.envs import gradually_diverging, high_frequency_changes



# Global figure styling applied at import time.
# NOTE(review): this style sheet appears to define `font.sans-serif` itself,
# which would replace the font list assigned to plt.rcParams earlier in this
# file - confirm, and move the font assignment after this call if so.
plt.style.use('seaborn-v0_8-whitegrid')
# Default seaborn colour cycle for subsequent plots.
sns.set_palette("husl")


class SingleArmValidationExperiment:
    

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Create an experiment runner with an empty result store.

        Args:
            device: Device string stored on the instance and passed to the
                models built for each experiment. The default is evaluated
                once, when the class is defined: "cuda" if torch reports CUDA
                as available, otherwise "cpu".
        """
        # Finished runs, keyed "<env_name>_d<d_value>".
        self.results = {}
        self.device = device
        # Fallback d used by the summaries when a history has no stored d.
        self.default_d_value = 1.0
        # Not read anywhere in this class.
        self.d_estimation_threshold = None

    def create_single_arm_env(self, env_func, d_value, T=200, noise_std=0.05, **env_kwargs):
        """Build a one-armed reward environment from an environment generator.

        ``env_func`` is invoked as ``env_func(1, T, d=d_value, **env_kwargs)``
        and must return a ``(means, d)`` pair. Only ``means[0]`` is used as
        the per-step mean-reward trajectory.

        Args:
            env_func: Callable producing ``(means, d)`` as described above.
            d_value: Value forwarded to ``env_func`` as the ``d`` keyword.
            T: Horizon forwarded to ``env_func``.
            noise_std: Standard deviation of the Gaussian reward noise.
            **env_kwargs: Extra keyword arguments forwarded to ``env_func``.

        Returns:
            A ``SingleArmWrapper`` with ``step()``, ``reset()`` and
            ``get_current_state()``.
        """

        class SingleArmWrapper:
            """Replays one arm's mean trajectory with additive Gaussian noise."""

            def __init__(self, means, d, noise_std):
                trajectory = means[0]
                self.means = trajectory
                self.d = d
                self.noise_std = noise_std
                self.t = 0
                self.T = len(trajectory)

            def step(self):
                """Return ``(noisy_reward, true_mean)`` and advance one step.

                Raises:
                    StopIteration: once all ``T`` steps have been consumed.
                """
                idx = self.t
                if idx >= self.T:
                    raise StopIteration("Reached maximum steps")

                mean_now = self.means[idx]
                # Noise is drawn from NumPy's global RNG; rewards stay in [0, 1].
                noisy = np.clip(mean_now + np.random.normal(0, self.noise_std), 0, 1)

                self.t = idx + 1
                return float(noisy), float(mean_now)

            def reset(self):
                """Rewind to the first step; the mean trajectory is kept."""
                self.t = 0

            def get_current_state(self):
                """Describe the upcoming step (``true_prob`` is None past the end)."""
                if self.t < self.T:
                    upcoming = self.means[self.t]
                else:
                    upcoming = None
                return {'step': self.t, 'true_prob': upcoming, 'd_value': self.d}

        generated_means, generated_d = env_func(1, T, d=d_value, **env_kwargs)
        return SingleArmWrapper(generated_means, generated_d, noise_std)


    def run_single_experiment(self, env_name, env_func, d_value, T=200,
                              hist_len=50, timesteps=20, noise_std=0.05, **env_kwargs):
        """Run one single-arm experiment and record per-step diagnostics.

        For every step the predictor is queried first, then the environment is
        stepped, then the UCB object and the predictor are fed the observed
        reward. Prediction error, sliding-window error and a mixing weight
        lambda derived from the two squared errors are appended to a history
        dict, which is also stored in ``self.results`` under
        ``f"{env_name}_d{d_value}"``.

        Args:
            env_name: Label used in log lines and in the ``self.results`` key.
            env_func: Environment generator passed to ``create_single_arm_env``.
            d_value: d parameter handed to the environment and to the UCB object.
            T: Number of steps to run.
            hist_len: History length for both models; also caps the window.
            timesteps: Diffusion step count passed to both models.
            noise_std: Reward noise standard deviation for the environment.
            **env_kwargs: Extra keyword arguments forwarded to ``env_func``.

        Returns:
            The history dict: per-step lists plus the scalar
            ``"estimated_d_value"`` (always ``d_value``; nothing here updates it).
        """
        print(f"\n{'=' * 60}")
        print(f"🔬 runs实验: {env_name} (d={d_value})")
        print(f"{'=' * 60}")

        env = self.create_single_arm_env(env_func, d_value, T, noise_std, **env_kwargs)

        # The UCB object only receives update(0, reward) below; none of its
        # outputs are read in this method.
        ucb = BalancedBanditDiffUCBPerfectD(
            n_arms=1,
            d_true=d_value,
            device=self.device,
            use_diffusion=True,
            use_adaptive_window=True,
            use_theory_guided=True,
            use_virtual_data=True,
            hist_len=hist_len,
            timesteps=timesteps
        )

        predictor = BanditDDPMPredictor720(
            n_arms=1,
            hist_len=hist_len,
            timesteps=timesteps,
            device=self.device
        )

        # Every list gets one entry per successfully processed step.
        # "estimated_d_value" is a scalar, not a list.
        history = {
            "steps": [],
            "true_probs": [],
            "predicted_means": [],
            "predicted_sigmas": [],
            "actual_rewards": [],
            "prediction_errors": [],
            "optimal_window_sizes": [],
            "lambda_values": [],
            "mse_components": [],
            "confidence_bounds": [],
            "window_means": [],
            "window_errors": [],
            "d_estimates": [],
            "estimated_d_value": d_value
        }

        print(f"📊 environments参数: d={d_value}, T={T}, Noise={noise_std}")
        print(f"🤖 algorithms参数: History length={hist_len}, Diffusion steps={timesteps}")
        print(f"🚀 Starting experiment...")

        local_rewards_history = []
        for step in range(T):
            try:
                # Window = number of rewards seen so far, capped at hist_len
                # and never below 1.
                optimal_window = max(1, min(len(local_rewards_history), hist_len)) if len(
                    local_rewards_history) > 0 else 1

                # The prediction is made before env.step(), so it cannot see
                # the reward of the current step.
                with torch.no_grad():
                    try:
                        pred_mean, pred_sigma = predictor.predict_next_reward(
                            arm_idx=0,
                            window_size=optimal_window
                        )
                        pred_mean = float(pred_mean.item() if torch.is_tensor(pred_mean) else pred_mean)
                        pred_sigma = float(pred_sigma.item() if torch.is_tensor(pred_sigma) else pred_sigma)
                    except Exception as e:
                        # Best-effort: fall back to fixed values when the
                        # predictor raises.
                        print(f"  Prediction failed: {e}")
                        pred_mean, pred_sigma = 0.5, 0.1

                # Sliding-window estimate from past rewards only; 0.5 when
                # there is no history yet.
                recent_rewards = local_rewards_history[-optimal_window:]
                window_mean = np.mean(recent_rewards) if recent_rewards else 0.5

                reward, true_prob = env.step()

                ucb.update(0, reward)
                predictor.update_history(torch.tensor([[reward]], device=self.device).float())
                if hasattr(predictor, 'true_prob_history'):
                    predictor.true_prob_history.append(true_prob)
                local_rewards_history.append(reward)

                # NOTE(review): this append happens before the other history
                # appends; if an exception is raised later in this step,
                # "d_estimates" ends up longer than the other lists.
                history["d_estimates"].append(d_value)

                pred_error = abs(pred_mean - true_prob)

                window_error = abs(window_mean - true_prob)

                # Mean squared deviation of the windowed past rewards from the
                # current true mean (falls back to the window-mean error when
                # the window is empty, i.e. on the first step).
                if len(recent_rewards) >= 1:
                    hist_errors = [(r - true_prob) ** 2 for r in recent_rewards]
                    real_mse_hist = float(np.mean(hist_errors))
                else:
                    real_mse_hist = float((window_mean - true_prob) ** 2)
                real_mse_pred = float((pred_mean - true_prob) ** 2)
                cov_bound = 0.0
                # lambda = (A - C) / (A + B - 2C). With C fixed at 0 this is
                # A / (A + B), then clipped to [0.01, 0.99].
                A = max(real_mse_hist, 1e-12)
                B = max(real_mse_pred, 1e-12)
                C = max(cov_bound, 0.0)
                denom = A + B - 2.0 * C
                # With A, B >= 1e-12 and C == 0 the denominator is at least
                # 2e-12, so the else branch is currently never taken.
                if denom > 1e-12:
                    lambda_value = (A - C) / denom
                else:
                    lambda_value = 0.5
                lambda_value = float(np.clip(lambda_value, 0.01, 0.99))

                if step % 50 == 0 or step < 10:
                    print(f"    Debug: History data length={len(local_rewards_history)}")
                    print(
                        f"    Debug: 真实MSE组件: hist={real_mse_hist:.4f}, pred={real_mse_pred:.4f}, λ={lambda_value:.4f}")

                # Constant placeholder; no confidence bound is computed here.
                confidence_bound = 0.0

                history["steps"].append(step)
                history["true_probs"].append(true_prob)
                history["predicted_means"].append(pred_mean)
                history["predicted_sigmas"].append(pred_sigma)
                history["actual_rewards"].append(reward)
                history["prediction_errors"].append(pred_error)
                history["optimal_window_sizes"].append(optimal_window)
                history["lambda_values"].append(lambda_value)
                history["mse_components"].append((real_mse_hist, real_mse_pred, cov_bound))
                history["confidence_bounds"].append(confidence_bound)
                history["window_means"].append(window_mean)
                history["window_errors"].append(window_error)

                # Train the predictor on every 10th step once more than
                # hist_len steps have elapsed.
                if step % 10 == 0 and step > hist_len:
                    predictor.train_step(arm_idx=0)

                if step % 50 == 0 or step < 10:
                    print(f"  Step {step + 1:3d}: True={true_prob:.3f}, Pred={pred_mean:.3f}±{pred_sigma:.3f}, "
                          f"Prediction error={pred_error:.3f}, Window={optimal_window}, Window mean={window_mean:.3f}, "
                          f"Window error={window_error:.3f}, λ={lambda_value:.3f}")

            except StopIteration:
                # Raised by env.step() once the environment is exhausted.
                break
            except Exception as e:
                # Any other failure is logged and the step is skipped.
                print(f"  Step {step + 1} Error: {e}")
                continue

        self.results[f"{env_name}_d{d_value}"] = {
            'history': history,
            'env_params': {'d': d_value, 'T': T, 'noise_std': noise_std},
            'env_name': env_name
        }

        final_estimated_d = history.get("estimated_d_value", d_value)
        print(f"✅ 实验completed: {env_name} (d={d_value})")
        print(f"🎯 Use的d值: {final_estimated_d:.4f}")

        return history

    def run_all_experiments(self, T=200, hist_len=50, timesteps=20, noise_std=0.05, d_values=None):
        """Run every (environment, d) combination and print a summary.

        Each combination is executed with ``run_single_experiment``. A failing
        combination is reported and skipped so the remaining ones still run.

        Args:
            T: Number of steps per experiment.
            hist_len: History length passed to each experiment.
            timesteps: Diffusion step count passed to each experiment.
            noise_std: Reward noise standard deviation.
            d_values: Iterable of d parameters to sweep. Defaults to ``[0.7]``,
                the value that used to be hard-coded here.
        """
        d_values = [0.7] if d_values is None else list(d_values)

        experiments = [
            ("gradually_diverging", gradually_diverging, {"max_divergence": 0.2}),
            ("high_frequency_changes", high_frequency_changes, {"oscillation_strength": 0.2})
        ]

        print("🎯 开始单臂environments验证实验")
        print("=" * 70)
        print(f"📋 Experiment configuration:")
        print(f"   • Time steps: {T}")
        print(f"   • History length: {hist_len}")
        print(f"   • Diffusion steps: {timesteps}")
        print(f"   • Noise standard deviation: {noise_std}")
        print(f"   • d parameter values: {d_values}")
        print(f"   • environments类型: {[exp[0] for exp in experiments]}")

        for env_name, env_func, env_kwargs in experiments:
            for d_value in d_values:
                try:
                    self.run_single_experiment(
                        env_name=env_name,
                        env_func=env_func,
                        d_value=d_value,
                        T=T,
                        hist_len=hist_len,
                        timesteps=timesteps,
                        noise_std=noise_std,
                        **env_kwargs
                    )
                except Exception as e:
                    # Deliberate best-effort boundary: one broken combination
                    # must not abort the whole sweep.
                    print(f"❌ Experiment failed: {env_name} (d={d_value}) - {e}")
                    continue

        print(f"\n🏆 所有实验completed!")
        # len(self.results) counts everything stored on this instance,
        # including results from earlier calls.
        print(f"   • Successful experiments: {len(self.results)}")
        print(f"   • Total experiments: {len(experiments) * len(d_values)}")

        print(f"\n📊 d parameter estimation results summary:")
        for result in self.results.values():
            env_name = result['env_name']
            true_d = result['env_params']['d']
            estimated_d = result['history'].get("estimated_d_value", self.default_d_value)
            error = abs(estimated_d - true_d)
            print(f"   {env_name} (True d={true_d:.1f}): Estimated d={estimated_d:.4f}, Error={error:.4f}")

    def create_iclr_plots(self, save_dir="single_arm_validation_results"):
        
        os.makedirs(save_dir, exist_ok=True)

        if not self.results:
            print("❌ No results available for visualization")
            return

        plt.rcParams.update({
            'font.size': 16,
            'axes.titlesize': 18,
            'axes.labelsize': 16,
            'xtick.labelsize': 14,
            'ytick.labelsize': 14,
            'legend.fontsize': 14,
            'figure.titlesize': 20,
            'lines.linewidth': 2,
            'axes.linewidth': 1.2,
            'grid.alpha': 0.3
        })

        env_types = set([key.split('_d')[0] for key in self.results.keys()])

        for env_type in env_types:
            self._create_env_plots(env_type, save_dir)

        self._create_comparison_plots(save_dir)

    def _create_env_plots(self, env_type, save_dir):
        """Plot the results whose key starts with ``env_type``.

        Matching entries are ordered by the d value encoded after the last
        ``"_d"`` in their key (0.0 when the key has no ``"_d"``) and handed to
        ``_create_fused_dual_panel``. Returns silently when nothing matches.

        Args:
            env_type: Key prefix selecting the environment.
            save_dir: Output directory forwarded to the panel renderer.
        """

        def d_of(item):
            # item is a (key, result) pair; the d value follows the last "_d".
            _, marker, tail = item[0].rpartition('_d')
            return float(tail) if marker else 0.0

        matching = [(name, res) for name, res in self.results.items()
                    if name.startswith(env_type)]
        if not matching:
            return

        matching.sort(key=d_of)
        self._create_fused_dual_panel(env_type, matching, save_dir)

        print(f"✅ {env_type} environments图表已保存")

    def _create_fused_dual_panel(self, env_type, sorted_results, save_dir):
        """Draw and save the two-panel figure for one environment.

        Only the first entry of ``sorted_results`` is plotted. The top panel
        shows the true mean, the predictor output and the sliding-window
        estimate; the bottom panel shows window size and mixing weight lambda
        on twin y-axes. The figure is written as PNG and PDF, shown, and then
        closed so repeated calls do not accumulate open figures.

        Args:
            env_type: Environment label used in the output file names.
            sorted_results: Sequence of ``(key, result)`` pairs.
            save_dir: Directory the figure files are written to.
        """
        if not sorted_results:
            return
        key, result = sorted_results[0]
        # Optional hook: a subclass may provide _extract_d_value(key).
        # Otherwise the d value is the text after the last "_d" in the key.
        extractor = getattr(self, '_extract_d_value', None)
        d_value = extractor(key) if callable(extractor) else float(key.rsplit('_d', 1)[-1])
        history = result['history']
        steps = history['steps']
        true_probs = history['true_probs']
        pred_means = history['predicted_means']
        window_means = history['window_means']
        windows = history['optimal_window_sizes']
        lambdas = history['lambda_values']

        fig = plt.figure(figsize=(12, 10))
        gs = fig.add_gridspec(2, 1, height_ratios=[3, 2], hspace=0.3)

        ax_top = fig.add_subplot(gs[0, 0])
        ax_top.plot(steps, true_probs, label='True', linewidth=2.5, color='#2E86AB')
        ax_top.plot(steps, pred_means, label='Pred', linewidth=2.5, linestyle='--', color='#A23B72')
        ax_top.plot(steps, window_means, label='Esti', linewidth=2.0, linestyle='-.', color='#3C9D5B')
        ax_top.set_ylim(0, 1)
        ax_top.set_ylabel('Reward', fontsize=16)
        ax_top.set_title('Performance Of Predictor and Estimator  (d={})'.format(d_value), fontsize=18,
                         fontweight='bold', pad=16)
        ax_top.grid(True, alpha=0.25)
        ax_top.legend(loc='upper right', fontsize=14)

        ax_bot = fig.add_subplot(gs[1, 0])
        ax_bot2 = ax_bot.twinx()
        line_w = ax_bot.plot(steps, windows, label='Window', linewidth=2.5, color='#8B5A3C')
        line_l = ax_bot2.plot(steps, lambdas, label='λ', linewidth=2.5, color='#F18F01')
        ax_bot.set_ylabel('Window Size', fontsize=16, color='#8B5A3C')
        ax_bot2.set_ylabel('Mixing Weight λ', fontsize=16, color='#F18F01')
        ax_bot2.set_ylim(0, 1)
        ax_bot.set_xlabel('Time Step', fontsize=16)
        ax_bot.set_title('Window Size & Mixing Weight Evolution', fontsize=18, fontweight='bold', pad=16)
        ax_bot.grid(True, alpha=0.25)
        # Lines live on two different axes, so build one combined legend.
        lines = line_w + line_l
        labels = [line.get_label() for line in lines]
        ax_bot.legend(lines, labels, loc='lower right', fontsize=14)

        plt.tight_layout()
        base_path = os.path.join(save_dir, f"{env_type}_fused_dual_panel")
        plt.savefig(f"{base_path}.png", dpi=300, bbox_inches='tight')
        plt.savefig(f"{base_path}.pdf", dpi=300, bbox_inches='tight')
        plt.show()
        # pyplot keeps figures alive until closed; release this one explicitly.
        plt.close(fig)

    def _create_comparison_plots(self, save_dir):
        """Draw and save the 2x2 cross-environment comparison figure.

        Panels: mean prediction error, mean window size and mean mixing weight
        as functions of d (one line per environment), plus a heatmap of
        ``1 / (mean prediction error + 1e-6)`` over (environment, d). The
        figure is written as PNG and PDF into ``save_dir``, shown, and closed.

        Args:
            save_dir: Directory the figure files are written to.
        """

        def split_key(key):
            # Keys are "<env_name>_d<d_value>". Split on the LAST "_d" because
            # an environment name may contain "_d" itself.
            env_name, marker, d_text = key.rpartition('_d')
            if not marker:
                return key, 0.0
            return env_name, float(d_text)

        # env name -> {d value -> history dict}
        all_data = {}
        for key, result in self.results.items():
            env_type, d_value = split_key(key)
            all_data.setdefault(env_type, {})[d_value] = result['history']

        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        fig.suptitle('Cross-Environment Performance Comparison',
                     fontsize=16, fontweight='bold', y=0.98)

        # (axis, history key, title, y label) for the three line panels.
        line_panels = [
            (axes[0, 0], 'prediction_errors',
             'Average Prediction Error vs d Parameter', 'Average Prediction Error'),
            (axes[0, 1], 'optimal_window_sizes',
             'Average Window Size vs d Parameter', 'Average Window Size'),
            (axes[1, 0], 'lambda_values',
             'Average Mixing Weight vs d Parameter', 'Average Mixing Weight λ'),
        ]
        for ax, metric, title, ylabel in line_panels:
            for env_type, env_data in all_data.items():
                env_d_values = sorted(env_data)
                means = [np.mean(env_data[d][metric]) for d in env_d_values]
                ax.plot(env_d_values, means, 'o-', label=env_type.replace('_', ' ').title(),
                        linewidth=2, markersize=6)
            ax.set_title(title, fontweight='bold')
            ax.set_xlabel('d Parameter')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)
        # The mixing weight is clipped to [0.01, 0.99] upstream; show the full range.
        axes[1, 0].set_ylim(0, 1)

        ax4 = axes[1, 1]

        env_types = list(all_data)
        d_values = sorted({d for env_data in all_data.values() for d in env_data})

        # Cells without data stay at 0.
        performance_matrix = np.zeros((len(env_types), len(d_values)))
        for i, env_type in enumerate(env_types):
            for j, d_value in enumerate(d_values):
                if d_value in all_data[env_type]:
                    mean_error = np.mean(all_data[env_type][d_value]['prediction_errors'])
                    performance_matrix[i, j] = 1 / (mean_error + 1e-6)

        im = ax4.imshow(performance_matrix, cmap='viridis', aspect='auto')
        ax4.set_xticks(range(len(d_values)))
        ax4.set_xticklabels([f'd={d}' for d in d_values])
        ax4.set_yticks(range(len(env_types)))
        ax4.set_yticklabels([env.replace('_', ' ').title() for env in env_types])
        ax4.set_title('Performance Heatmap (1/Error)', fontweight='bold')

        cbar = plt.colorbar(im, ax=ax4)
        cbar.set_label('Performance Score (1/Error)')

        # Annotate only the cells that have data. The membership test must use
        # the d value of column j, not a variable left over from an earlier loop.
        for i, env_type in enumerate(env_types):
            for j, d_value in enumerate(d_values):
                if d_value in all_data[env_type]:
                    ax4.text(j, i, f'{performance_matrix[i, j]:.2f}',
                             ha="center", va="center", color="white", fontsize=8)

        plt.tight_layout()
        base_path = os.path.join(save_dir, "cross_environment_comparison")
        plt.savefig(f"{base_path}.png",
                    dpi=300, bbox_inches='tight', facecolor='white')
        plt.savefig(f"{base_path}.pdf",
                    dpi=300, bbox_inches='tight', facecolor='white')
        plt.show()
        # pyplot keeps figures alive until closed; release this one explicitly.
        plt.close(fig)

        print("✅ Comprehensive comparison charts saved")

    def save_results(self, save_dir="single_arm_validation_results"):
        
        os.makedirs(save_dir, exist_ok=True)

        import pickle
        with open(f"{save_dir}/experiment_results.pkl", 'wb') as f:
            pickle.dump(self.results, f)

        def extract_d_value(key):
            parts = key.split('_d')
            if len(parts) > 1:
                return float(parts[-1])
            return 0.0

        summary_data = []
        for key, result in self.results.items():
            env_type = key.split('_d')[0]
            d_value = extract_d_value(key)
            history = result['history']

            summary_data.append({
                'Environment': env_type,
                'd_parameter': d_value,
                'Mean_Pred_Error': np.mean(history['prediction_errors']),
                'Std_Pred_Error': np.std(history['prediction_errors']),
                'Mean_Window_Error': np.mean(history['window_errors']),
                'Std_Window_Error': np.std(history['window_errors']),
                'Mean_Window': np.mean(history['optimal_window_sizes']),
                'Std_Window': np.std(history['optimal_window_sizes']),
                'Mean_Lambda': np.mean(history['lambda_values']),
                'Std_Lambda': np.std(history['lambda_values']),
                'Final_Pred_Error': history['prediction_errors'][-1] if history['prediction_errors'] else 0,
                'Final_Window_Error': history['window_errors'][-1] if history['window_errors'] else 0,
                'Final_Window': history['optimal_window_sizes'][-1] if history['optimal_window_sizes'] else 0,
                'Final_Lambda': history['lambda_values'][-1] if history['lambda_values'] else 0
            })

        import pandas as pd
        df = pd.DataFrame(summary_data)
        df.to_csv(f"{save_dir}/experiment_summary.csv", index=False)

        print(f"✅ Experiment results saved to: {save_dir}")
        print(f"   • raw data: experiment_results.pkl")
        print(f"   • Statistical summary: experiment_summary.csv")

        return df

    def print_mean_lambda_by_setting(self):
        """Print and return the mean mixing weight lambda per (environment, d).

        Returns:
            pandas.DataFrame with columns ``Environment``, ``d_parameter`` and
            ``Mean_Lambda``, sorted by environment then d; ``None`` when there
            are no results. ``d_parameter`` is ``None`` for keys whose suffix
            is not a number and ``Mean_Lambda`` is NaN for empty histories.
        """
        if not self.results:
            print("❌ No available results, cannot calculate λ mean")
            return None
        import pandas as pd
        rows = []
        for key, result in self.results.items():
            # Keys are "<env_name>_d<d_value>". Split on the LAST "_d" because
            # an environment name may contain "_d" itself.
            env_type, marker, d_text = key.rpartition('_d')
            if not marker:
                env_type = key
            try:
                d_value = float(d_text)
            except ValueError:
                d_value = None
            lambdas = result['history'].get('lambda_values', [])
            mean_lambda = float(np.mean(lambdas)) if len(lambdas) > 0 else float('nan')
            rows.append({
                'Environment': env_type,
                'd_parameter': d_value,
                'Mean_Lambda': mean_lambda
            })
        df = pd.DataFrame(rows).sort_values(by=['Environment', 'd_parameter'])
        print("\n🔎 Mean λ by setting:")
        for _, r in df.iterrows():
            print(f"  - {r['Environment']}_d{r['d_parameter']}: Mean λ = {r['Mean_Lambda']:.4f}")
        return df


def main():
    """Run the full single-arm validation pipeline end to end.

    Runs all experiments, renders the figures, saves the results, and prints
    summaries of the d-parameter error, the aggregated performance statistics
    and the mean mixing weight per setting.
    """
    print("🎯 单臂environments验证实验系统")
    print("=" * 70)
    print("📋 Experiment objectives:")
    print("   • 验证ImprovedBanditDiffUCBalgorithmsEvaluated in单臂environments下的性能")
    print("   • 测试gradually_diverging和high_frequency_changesenvironments")
    print("   • Analyze the impact of different d parameters (0.2,0.5,0.7,0.8,0.9,1.0,1.2)")
    print("   • Plot time series of 5 key indicators")
    print("   • Generate visualization results conforming to ICLR standards")

    experiment = SingleArmValidationExperiment()

    experiment.run_all_experiments(T=200, hist_len=50, timesteps=20, noise_std=0.05)

    print(f"\n📊 Generating visualization charts...")
    experiment.create_iclr_plots()

    print(f"\n💾 Saving experiment results...")
    summary_df = experiment.save_results()

    def split_key(key):
        # Keys are "<env_name>_d<d_value>". Split on the LAST "_d" because an
        # environment name may contain "_d" itself.
        env_name, marker, d_text = key.rpartition('_d')
        if not marker:
            return key, 0.0
        return env_name, float(d_text)

    parsed_keys = [split_key(key) for key in experiment.results]

    print(f"\n🏆 实验completed摘要:")
    print(f"   • Total experiments: {len(experiment.results)}")
    print(f"   • environments类型: {len({env_name for env_name, _ in parsed_keys})}")
    print(f"   • d parameter range: {sorted({d_value for _, d_value in parsed_keys})}")

    print(f"\n🎯 d parameter estimation performance summary:")
    all_errors = [
        abs(result['history'].get("estimated_d_value", experiment.default_d_value)
            - result['env_params']['d'])
        for result in experiment.results.values()
    ]

    if all_errors:
        avg_error = np.mean(all_errors)
        max_error = np.max(all_errors)
        min_error = np.min(all_errors)
        print(f"   • Average estimation error: {avg_error:.4f}")
        print(f"   • Maximum estimation error: {max_error:.4f}")
        print(f"   • Minimum estimation error: {min_error:.4f}")
        print(f"   • Error standard deviation: {np.std(all_errors):.4f}")

    print(f"\n📈 Performance statistics:")
    if summary_df.empty:
        # With no successful experiment the summary frame has no columns and
        # groupby('Environment') would raise KeyError.
        print("   • No successful experiments to summarize")
    else:
        print(summary_df.groupby('Environment')[
                  ['Mean_Pred_Error', 'Mean_Window_Error', 'Mean_Window', 'Mean_Lambda']].mean())

    experiment.print_mean_lambda_by_setting()

    print(f"\n✅ 所有实验和可视化已completed!")


if __name__ == "__main__":
    main()
