import os

# Set before the torch/numpy imports below so the value is already in the
# environment when those libraries load.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort on some platforms -- confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import warnings

# Suppresses every warning category for the rest of the process.
warnings.filterwarnings('ignore')

from src.new_formal.model.predictor720 import BanditDDPMPredictor720
from src.UEP.framework.UEP import BalancedBanditDiffUCBPerfectD
from src.UEP.env.envs import gradually_diverging


class SingleArmProcessExperiment:
    

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Create an experiment runner with an empty result store.

        Args:
            device: Device string handed to the models built later. The default
                is resolved once, when the class body is executed: ``"cuda"``
                if ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
        """
        # Maps a result key (e.g. "d0.3") to a dict holding the run's history
        # and environment parameters.
        self.results = dict()
        self.device = device

    def create_single_arm_env(self, env_func, d_value, T=200, noise_std=0.05, **env_kwargs):
        """Build a one-armed environment that replays a generated mean trajectory.

        ``env_func`` is invoked as ``env_func(1, T, d=d_value, **env_kwargs)`` and
        must return a ``(means, d)`` pair; only ``means[0]`` is used.

        Args:
            env_func: Generator of the mean trajectory (see call shape above).
            d_value: Value forwarded to ``env_func`` as ``d``.
            T: Horizon forwarded to ``env_func``.
            noise_std: Standard deviation of the Gaussian noise added per step.
            **env_kwargs: Extra keyword arguments forwarded to ``env_func``.

        Returns:
            A ``SingleArmWrapper`` exposing ``step()``, ``reset()`` and
            ``get_current_state()``.
        """
        generated_means, generated_d = env_func(1, T, d=d_value, **env_kwargs)

        class SingleArmWrapper:
            """Steps through the first arm's mean sequence, emitting noisy rewards."""

            def __init__(self, means, d, noise_std):
                self.t = 0
                self.d = d
                self.noise_std = noise_std
                self.means = means[0]
                # The horizon is the length of the trajectory itself.
                self.T = len(self.means)

            def step(self):
                """Return ``(reward, true_mean)`` for the current step and advance.

                Raises:
                    StopIteration: once every entry of the trajectory was consumed.
                """
                if self.t >= self.T:
                    raise StopIteration("Reached maximum steps")

                current_mean = self.means[self.t]
                noisy = current_mean + np.random.normal(0, self.noise_std)
                # Rewards are confined to [0, 1]; the true mean is returned as is.
                clipped = np.clip(noisy, 0, 1)

                self.t += 1
                return float(clipped), float(current_mean)

            def reset(self):
                """Rewind to the first step."""
                self.t = 0

            def get_current_state(self):
                """Describe the upcoming step; ``true_prob`` is None past the end."""
                if self.t < self.T:
                    upcoming_mean = self.means[self.t]
                else:
                    upcoming_mean = None
                return {
                    'step': self.t,
                    'true_prob': upcoming_mean,
                    'd_value': self.d
                }

        return SingleArmWrapper(generated_means, generated_d, noise_std)

    def run_single_experiment(self, d_value, T=200, hist_len=50, timesteps=20, noise_std=0.05):
        """Run one single-arm experiment for a given d and log per-step diagnostics.

        Builds a one-armed ``gradually_diverging`` environment (with
        ``max_divergence=0.2``), a ``BalancedBanditDiffUCBPerfectD`` instance and
        a ``BanditDDPMPredictor720``. For up to ``T`` steps it asks the predictor
        for the next reward, draws the actual reward from the environment, feeds
        that reward to both models and records prediction/window errors together
        with a mixing weight ``lambda``.

        Args:
            d_value: Passed to the environment as ``d`` and to the UCB object as
                ``d_true``.
            T: Number of steps to attempt.
            hist_len: Upper bound on the reward window; also forwarded to both
                models.
            timesteps: Forwarded to both models.
            noise_std: Reward noise level of the environment.

        Returns:
            The ``history`` dict of per-step lists. The same dict is also stored
            in ``self.results[f"d{d_value}"]``.

        Side effects:
            Prints progress, and calls ``plot_mse_curves`` at the end (a plotting
            failure is reported and otherwise ignored).
        """
        print(f"\n{'=' * 60}")
        print(f"🔬 runs实验: d={d_value}, T={T}")
        print(f"{'=' * 60}")

        env = self.create_single_arm_env(gradually_diverging, d_value, T, noise_std, max_divergence=0.2)

        # In this method the UCB object is only fed rewards via update(); none
        # of its outputs are read here.
        ucb = BalancedBanditDiffUCBPerfectD(
            n_arms=1,
            d_true=d_value,
            device=self.device,
            use_diffusion=True,
            use_adaptive_window=True,
            use_theory_guided=True,
            use_virtual_data=True,
            hist_len=hist_len,
            timesteps=timesteps
        )

        predictor = BanditDDPMPredictor720(
            n_arms=1,
            hist_len=hist_len,
            timesteps=timesteps,
            device=self.device
        )

        # One entry is appended to every list per successfully completed step,
        # so all lists stay the same length.
        history = {
            "steps": [],
            "true_probs": [],
            "predicted_means": [],
            "actual_rewards": [],
            "prediction_errors": [],
            "optimal_window_sizes": [],
            "lambda_values": [],
            "window_means": [],
            "window_errors": [],
            "mse_hist_series": [],
            "mse_pred_series": []
        }

        print(f"📊 environments参数: d={d_value}, T={T}, Noise={noise_std}")
        print(f"🤖 algorithms参数: History length={hist_len}, Diffusion steps={timesteps}")
        print(f"🚀 Starting experiment...")

        rewards_history = []
        for step in range(T):
            try:
                # Window = number of rewards seen so far, capped at hist_len
                # and never below 1.
                optimal_window = max(1, min(len(rewards_history), hist_len)) if len(rewards_history) > 0 else 1

                # The prediction is requested before this step's reward is drawn.
                with torch.no_grad():
                    try:
                        pred_mean, pred_sigma = predictor.predict_next_reward(
                            arm_idx=0,
                            window_size=optimal_window
                        )
                        pred_mean = float(pred_mean.item() if torch.is_tensor(pred_mean) else pred_mean)
                    except Exception as e:
                        # Any predictor failure falls back to a neutral 0.5.
                        print(f"  Prediction failed: {e}")
                        pred_mean = 0.5

                # Plain average of the most recent rewards (0.5 before any
                # reward has been observed).
                recent_rewards = rewards_history[-optimal_window:]
                window_mean = np.mean(recent_rewards) if recent_rewards else 0.5

                reward, true_prob = env.step()

                ucb.update(0, reward)
                predictor.update_history(torch.tensor([[reward]], device=self.device).float())
                if hasattr(predictor, 'true_prob_history'):
                    predictor.true_prob_history.append(true_prob)
                rewards_history.append(reward)

                # Absolute errors of both estimates against the true mean of
                # this step.
                pred_error = abs(pred_mean - true_prob)

                window_error = abs(window_mean - true_prob)

                # "Historical" MSE: mean squared deviation of the windowed
                # rewards (excluding this step's reward) from the current true
                # mean.
                if len(recent_rewards) >= 1:
                    hist_errors = [(r - true_prob) ** 2 for r in recent_rewards]
                    real_mse_hist = float(np.mean(hist_errors))
                else:
                    real_mse_hist = float((window_mean - true_prob) ** 2)
                real_mse_pred = float((pred_mean - true_prob) ** 2)

                # lambda = (A - C) / (A + B - 2C). With cov_bound fixed at 0
                # this reduces to A / (A + B); A and B are floored at 1e-12.
                cov_bound = 0.0
                A = max(real_mse_hist, 1e-12)
                B = max(real_mse_pred, 1e-12)
                C = max(cov_bound, 0.0)
                denom = A + B - 2.0 * C
                if denom > 1e-12:
                    lambda_value = (A - C) / denom
                else:
                    lambda_value = 0.5
                lambda_value = float(np.clip(lambda_value, 0.01, 0.99))

                history["mse_hist_series"].append(float(real_mse_hist))
                history["mse_pred_series"].append(float(real_mse_pred))

                history["steps"].append(step)
                history["true_probs"].append(true_prob)
                history["predicted_means"].append(pred_mean)
                history["actual_rewards"].append(reward)
                history["prediction_errors"].append(pred_error)
                history["optimal_window_sizes"].append(optimal_window)
                history["lambda_values"].append(lambda_value)
                history["window_means"].append(window_mean)
                history["window_errors"].append(window_error)

                # Train the predictor on every 10th step once more than
                # hist_len steps have elapsed.
                if step % 10 == 0 and step > hist_len:
                    predictor.train_step(arm_idx=0)

            except StopIteration:
                # The environment raises this once its trajectory is exhausted.
                break
            except Exception as e:
                # The step is skipped: nothing is appended to `history` for it,
                # although the models/rewards_history may already have been
                # updated if the failure happened after env.step().
                print(f"  Step {step + 1} Error: {e}")
                continue

        # Stored under a key without a run index, so a later run with the same
        # d_value overwrites this entry.
        self.results[f"d{d_value}"] = {
            'history': history,
            'env_params': {'d': d_value, 'T': T, 'noise_std': noise_std}
        }

        print(f"✅ 实验completed: d={d_value}")

        try:
            self.plot_mse_curves(history, d_value)
        except Exception as e:
            print(f"⚠️ Visualization failed: {e}")

        return history

    def run_all_experiments(self, T=200, hist_len=50, timesteps=20, noise_std=0.05, num_runs=1):
        
        d_values = [0.3, 1.0, 5.0]

        print("🎯 开始单臂environments处理实验")
        print("=" * 70)
        print(f"📋 Experiment configuration:")
        print(f"   • Time steps: {T}")
        print(f"   • History length: {hist_len}")
        print(f"   • Diffusion steps: {timesteps}")
        print(f"   • Noise standard deviation: {noise_std}")
        print(f"   • d parameter values: {d_values}")
        print(f"   • environments类型: gradually_diverging")

        for d_value in d_values:
            try:
                run_mean_lambdas = []
                for run_idx in range(num_runs):
                    history = self.run_single_experiment(
                        d_value=d_value,
                        T=T,
                        hist_len=hist_len,
                        timesteps=timesteps,
                        noise_std=noise_std
                    )
                    try:
                        self.results[f"d{d_value}_run{run_idx + 1}"] = {
                            'history': history,
                            'env_params': {'d': d_value, 'T': T, 'noise_std': noise_std},
                            'run_idx': run_idx + 1
                        }
                    except Exception:
                        pass
                    if 'lambda_values' in history and len(history['lambda_values']) > 0:
                        run_mean_lambdas.append(float(history['lambda_values'][-1]))
                if len(run_mean_lambdas) > 0:
                    if num_runs == 1:
                        print(f"\n📌 d={d_value}: Final lambda = {run_mean_lambdas[0]:.4f}")
                        lambdas = history.get('lambda_values', [])
                        if len(lambdas) >= 100:
                            avg_lambda_100_to_end = float(np.mean(lambdas[99:]))
                            print(f"📌 d={d_value}: 100-last step average lambda = {avg_lambda_100_to_end:.4f}")
                    else:
                        overall_mean_lambda = float(np.mean(run_mean_lambdas))
                        print(f"\n📌 d={d_value}: 十times的最终平均lambda = {overall_mean_lambda:.4f}")
                    try:
                        sys.stdout.flush()
                    except Exception:
                        pass
                else:
                    msg = "本times" if num_runs == 1 else "十times的最终平均lambda（数据不足）"
                    print(f"\n📌 d={d_value}: Cannot calculate{msg}")
                    try:
                        sys.stdout.flush()
                    except Exception:
                        pass
            except Exception as e:
                print(f"❌ Experiment failed: d={d_value} - {e}")
                continue

        print(f"\n🏆 所有实验completed!")
        print(f"   • Successful experiments: {len(self.results)}")
        print(f"   • Total experiments: {len(d_values)}")

        if num_runs > 1:
            try:
                print("\n🧮 汇总: 各d的十times最终平均lambda")
                for d_value in d_values:
                    group_keys = [k for k in self.results.keys() if k.startswith(f"d{d_value}_run")]
                    finals = []
                    for k in group_keys:
                        lambdas = self.results[k]['history'].get('lambda_values', [])
                        if lambdas:
                            finals.append(float(lambdas[-1]))
                    if finals:
                        print(f"  d={d_value}: 十times的最终平均lambda = {float(np.mean(finals)):.4f}")
                    else:
                        print(f"  d={d_value}: Cannot calculate (insufficient data)")
                try:
                    sys.stdout.flush()
                except Exception:
                    pass
            except Exception:
                pass

    def _smooth(self, array_like, window_size=11):
        
        data = np.asarray(array_like, dtype=float)
        if len(data) == 0:
            return data
        if window_size < 1:
            return data
        if window_size % 2 == 0:
            window_size += 1
        pad = window_size // 2
        left_pad = np.repeat(data[:1], pad)
        right_pad = np.repeat(data[-1:], pad)
        padded = np.concatenate([left_pad, data, right_pad])
        kernel = np.ones(window_size, dtype=float) / window_size
        smoothed = np.convolve(padded, kernel, mode='valid')
        return smoothed

    def plot_mse_curves(self, history, d_value, window_size=9):
        
        steps = history.get("steps", [])
        mse_hist = history.get("mse_hist_series", [])
        mse_pred = history.get("mse_pred_series", [])
        if len(steps) == 0 or len(mse_hist) == 0 or len(mse_pred) == 0:
            print("ℹ️ Insufficient data to plot MSE curves")
            return
        n = min(len(steps), len(mse_hist), len(mse_pred))
        steps = steps[:n]
        mse_hist = mse_hist[:n]
        mse_pred = mse_pred[:n]

        try:
            sm_hist = self._smooth(mse_hist, window_size=window_size)
            sm_pred = self._smooth(mse_pred, window_size=window_size)
        except Exception:
            sm_hist = np.array(mse_hist)
            sm_pred = np.array(mse_pred)

        plt.figure(figsize=(10, 5))
        plt.plot(steps, sm_hist[:len(steps)], label='Estimator MSE (smoothed)', color='#3C9D5B', linewidth=2)
        plt.plot(steps, sm_pred[:len(steps)], label='Predictor MSE (smoothed)', color='#A23B72', linewidth=2)
        plt.xlabel('Step')
        plt.ylabel('MSE')
        plt.title(f'MSE over Time (d={d_value})')
        plt.grid(alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.show()

    def print_results_summary(self):
        
        if not self.results:
            print("❌ No available results")
            return

        print(f"\n📊 Experiment results summary (average values every 20 steps):")
        print("=" * 80)

        for key, result in self.results.items():
            d_value = result['env_params']['d']
            history = result['history']

            print(f"\n🔬 d = {d_value}:")
            print("-" * 40)

            steps = history['steps']
            pred_errors = history['prediction_errors']
            window_errors = history['window_errors']
            lambdas = history['lambda_values']

            window_size = 20
            num_windows = len(steps) // window_size

            for i in range(num_windows):
                start_idx = i * window_size
                end_idx = min((i + 1) * window_size, len(steps))

                if start_idx < len(pred_errors):
                    avg_pred_error = np.mean(pred_errors[start_idx:end_idx])
                    avg_window_error = np.mean(window_errors[start_idx:end_idx])
                    avg_lambda = np.mean(lambdas[start_idx:end_idx])

                    print(f"  Step {start_idx + 1:3d}-{end_idx:3d}: "
                          f"平均Prediction error={avg_pred_error:.4f}, "
                          f"平均估计Error={avg_window_error:.4f}, "
                          f"平均lambda={avg_lambda:.4f}")

            if len(steps) % window_size != 0:
                start_idx = num_windows * window_size
                end_idx = len(steps)
                if start_idx < len(pred_errors):
                    avg_pred_error = np.mean(pred_errors[start_idx:end_idx])
                    avg_window_error = np.mean(window_errors[start_idx:end_idx])
                    avg_lambda = np.mean(lambdas[start_idx:end_idx])

                    print(f"  Step {start_idx + 1:3d}-{end_idx:3d}: "
                          f"平均Prediction error={avg_pred_error:.4f}, "
                          f"平均估计Error={avg_window_error:.4f}, "
                          f"平均lambda={avg_lambda:.4f}")

        print(f"\n📈 Overall statistics:")
        print("-" * 40)

        all_pred_errors = []
        all_window_errors = []
        all_lambdas = []

        for key, result in self.results.items():
            history = result['history']
            all_pred_errors.extend(history['prediction_errors'])
            all_window_errors.extend(history['window_errors'])
            all_lambdas.extend(history['lambda_values'])

        if all_pred_errors:
            print(f"  Overall average prediction error: {np.mean(all_pred_errors):.4f} ± {np.std(all_pred_errors):.4f}")
            print(f"  整体Average estimation error: {np.mean(all_window_errors):.4f} ± {np.std(all_window_errors):.4f}")
            print(f"  Overall average lambda: {np.mean(all_lambdas):.4f} ± {np.std(all_lambdas):.4f}")


def main():
    """Script entry point: print the banner, run the d sweep once, print the summary."""
    banner_lines = (
        "🎯 单臂environments处理实验系统",
        "=" * 70,
        "📋 Experiment objectives:",
        "   • 验证ImprovedBanditDiffUCBalgorithmsEvaluated in单臂environments下的性能",
        "   • 测试gradually_divergingenvironments",
        "   • Analyze the impact of different d parameters (0.3, 1.0, 5.0)",
        "   • Use真实d值进行预测",
        "   • Output average performance indicators every 20 steps",
    )
    for banner_line in banner_lines:
        print(banner_line)

    runner = SingleArmProcessExperiment()

    # A single run per d value with the default experiment settings.
    runner.run_all_experiments(T=200, hist_len=50, timesteps=20, noise_std=0.05, num_runs=1)

    runner.print_results_summary()

    print(f"\n✅ 所有实验已completed!")


# Run the experiment only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
