

import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import numpy as np
import pandas as pd
from collections import defaultdict, deque
import torch
import warnings
from scipy import stats
import json
import pickle
from datetime import datetime
from src.new_formal.env.env import (
    smooth_nonstationary,
    abrupt_change_environment,
    gradually_diverging,
    high_frequency_changes,
    competitive_balanced_environment,
    static_traditional_environment
)

warnings.filterwarnings('ignore')

from src.new_formal.framework.beta720 import ImprovedBanditDiffUCB
from src.UEP.framework.UEP_721 import ImprovedBanditDiffUCB as ImprovedBanditDiffUCB_721
from src.UEP.framework.UEP import BalancedBanditDiffUCBPerfectD
from src.UEP.framework.UEP_721_pd import ImprovedBanditDiffUCBPerfectD



class BaseBanditAlgorithm:
    """Common base for the stochastic bandit baselines in this module.

    Holds the number of arms, a display name and a step counter ``t``.
    Subclasses must override :meth:`select_arm` and :meth:`update`.
    """

    def __init__(self, n_arms, algorithm_name, **kwargs):
        # Extra keyword arguments are accepted and ignored so that every
        # algorithm can be built from one shared parameter dict.
        self.algorithm_name = algorithm_name
        self.n_arms = n_arms
        self.t = 0

    def get_diagnostics(self):
        """Return a minimal diagnostics dict (name plus a fixed d_estimate of 1.0)."""
        diagnostics = {'algorithm': self.algorithm_name, 'd_estimate': 1.0}
        return diagnostics

    def calculate_ucb_confidence(self, count, multiplier=2.0):
        """Return the UCB bonus ``sqrt(multiplier * ln(t) / count)``.

        An arm that has never been pulled (``count == 0``) gets an infinite bonus.
        """
        if count != 0:
            return np.sqrt(multiplier * np.log(self.t) / count)
        return float('inf')

    def select_arm(self):
        """Choose the next arm; must be implemented by subclasses."""
        raise NotImplementedError("子类必须实现select_arm方法")

    def update(self, arm, reward):
        """Record the reward observed for ``arm``; must be implemented by subclasses."""
        raise NotImplementedError("子类必须实现update方法")


class UCB1(BaseBanditAlgorithm):
    """UCB1 on empirical mean rewards.

    Each arm is played once in round-robin order. Afterwards the arm that
    maximises ``mean + sqrt(2 ln t / n)`` is chosen, with the bonus from
    ``calculate_ucb_confidence``.
    """

    def __init__(self, n_arms, **kwargs):
        super().__init__(n_arms, 'UCB1', **kwargs)
        self.counts = np.zeros(n_arms)   # number of pulls per arm
        self.rewards = np.zeros(n_arms)  # cumulative reward per arm

    def select_arm(self):
        self.t += 1
        # Initialisation phase: play every arm once.
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms

        def index(arm):
            pulls = self.counts[arm]
            if pulls == 0:
                return float('inf')
            return self.rewards[arm] / pulls + self.calculate_ucb_confidence(pulls)

        return np.argmax([index(arm) for arm in range(self.n_arms)])

    def update(self, arm, reward):
        self.rewards[arm] += reward
        self.counts[arm] += 1


class UCB1Tuned(BaseBanditAlgorithm):
    """UCB1-Tuned: a UCB index whose width uses an empirical variance bound.

    Index per arm: ``mean + sqrt(ln t / n * min(1/4, V))``, where
    ``V = var_hat + sqrt(2 ln t / n)``.
    """

    def __init__(self, n_arms, **kwargs):
        super().__init__(n_arms, 'UCB1-Tuned', **kwargs)
        self.counts = np.zeros(n_arms)           # pulls per arm
        self.rewards = np.zeros(n_arms)          # sum of rewards per arm
        self.squared_rewards = np.zeros(n_arms)  # sum of squared rewards per arm

    def select_arm(self):
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms

        log_t = np.log(self.t)
        scores = []
        for n, total, total_sq in zip(self.counts, self.rewards, self.squared_rewards):
            if n == 0:
                scores.append(float('inf'))
                continue
            mean = total / n
            # Clamp tiny negative variances produced by floating-point error.
            var_hat = max(0, (total_sq / n) - mean ** 2)
            v_bound = var_hat + np.sqrt(2 * log_t / n)
            scores.append(mean + np.sqrt(log_t * min(0.25, v_bound) / n))
        return np.argmax(scores)

    def update(self, arm, reward):
        self.counts[arm] += 1
        self.rewards[arm] += reward
        self.squared_rewards[arm] += reward ** 2


class ThompsonSampling(BaseBanditAlgorithm):
    """Beta-Bernoulli Thompson Sampling with a uniform Beta(1, 1) prior.

    A reward equal to 1 counts as a success; any other value counts as a failure.
    """

    def __init__(self, n_arms, **kwargs):
        super().__init__(n_arms, 'Thompson Sampling', **kwargs)
        self.alpha = np.ones(n_arms)  # 1 + number of successes
        self.beta = np.ones(n_arms)   # 1 + number of failures

    def select_arm(self):
        draws = []
        for a, b in zip(self.alpha, self.beta):
            draws.append(np.random.beta(a, b))
        return np.argmax(draws)

    def update(self, arm, reward):
        posterior = self.alpha if reward == 1 else self.beta
        posterior[arm] += 1


class EpsilonGreedy(BaseBanditAlgorithm):
    """Epsilon-greedy on empirical means.

    With ``decay=True`` the exploration rate at step t is ``epsilon / sqrt(t)``.
    Unpulled arms have an estimated mean of 0.
    """

    def __init__(self, n_arms, epsilon=0.1, decay=True, **kwargs):
        super().__init__(n_arms, 'Epsilon-Greedy', **kwargs)
        self.epsilon = epsilon
        self.decay = decay
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)

    def select_arm(self):
        self.t += 1
        eps = self.epsilon
        if self.decay:
            eps = eps / np.sqrt(self.t)

        if np.random.random() >= eps:
            value_estimates = np.divide(self.rewards, self.counts,
                                        out=np.zeros_like(self.rewards),
                                        where=self.counts != 0)
            return np.argmax(value_estimates)
        return np.random.randint(self.n_arms)

    def update(self, arm, reward):
        self.rewards[arm] += reward
        self.counts[arm] += 1


class MOSS:
    """MOSS index policy for a known horizon.

    Index per arm: ``mean + sqrt(max(0, ln(horizon / (n_arms * n))) / n)``.
    Each arm is played once before the index is used.
    """

    def __init__(self, n_arms, horizon=1000, **kwargs):
        self.n_arms = n_arms
        self.horizon = horizon
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)
        self.t = 0

    def _index(self, arm):
        """MOSS index of ``arm`` (infinite when the arm was never pulled)."""
        n = self.counts[arm]
        if n == 0:
            return float('inf')
        bonus = np.sqrt(max(0, np.log(self.horizon / (self.n_arms * n))) / n)
        return self.rewards[arm] / n + bonus

    def select_arm(self):
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms
        return np.argmax([self._index(arm) for arm in range(self.n_arms)])

    def update(self, arm, reward):
        self.rewards[arm] += reward
        self.counts[arm] += 1

    def get_diagnostics(self):
        return {'algorithm': 'MOSS', 'd_estimate': 1.0}


class KLUCB:
    """KL-UCB with a Bernoulli KL divergence.

    The index of an arm is the (bisection-approximated) largest ``q`` in
    ``[mean, 1]`` with ``n * KL(mean, q) <= ln t + 3 ln(ln t + 1)``.
    """

    def __init__(self, n_arms, **kwargs):
        self.n_arms = n_arms
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)
        self.t = 0

    def kl_divergence(self, p, q):
        """Bernoulli KL(p || q), with both arguments clipped to [1e-15, 1 - 1e-15]."""
        eps = 1e-15
        p = min(max(p, eps), 1 - eps)
        q = min(max(q, eps), 1 - eps)
        return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

    def find_upper_bound(self, mean, count, t):
        """Return the KL-UCB index via 20 bisection steps (1.0 for an unpulled arm)."""
        if count == 0:
            return 1.0
        budget = (np.log(t) + 3 * np.log(np.log(t) + 1)) / count
        lo, hi = mean, 1.0
        for _ in range(20):
            mid = (lo + hi) / 2
            lo, hi = (mid, hi) if self.kl_divergence(mean, mid) <= budget else (lo, mid)
        return lo

    def select_arm(self):
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms

        indices = []
        for pulls, total in zip(self.counts, self.rewards):
            if pulls == 0:
                indices.append(float('inf'))
            else:
                indices.append(self.find_upper_bound(total / pulls, pulls, self.t))
        return np.argmax(indices)

    def update(self, arm, reward):
        self.rewards[arm] += reward
        self.counts[arm] += 1

    def get_diagnostics(self):
        return {'algorithm': 'KL-UCB', 'd_estimate': 1.0}


class SlidingWindowUCB:
    """UCB computed over each arm's own last ``window_size`` rewards.

    The exploration bonus uses the global step ``t`` and the arm's window length.
    """

    def __init__(self, n_arms, window_size=100, **kwargs):
        self.n_arms = n_arms
        self.window_size = window_size
        # One bounded reward history per arm; old rewards fall off automatically.
        self.history = [deque(maxlen=window_size) for _ in range(n_arms)]
        self.t = 0

    def select_arm(self):
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms

        scores = []
        for recent in self.history:
            n = len(recent)
            if n:
                scores.append(np.mean(recent) + np.sqrt(2 * np.log(self.t) / n))
            else:
                scores.append(float('inf'))
        return np.argmax(scores)

    def update(self, arm, reward):
        self.history[arm].append(reward)

    def get_diagnostics(self):
        return {'algorithm': 'Sliding-Window-UCB', 'd_estimate': 1.0}


class DiscountedUCB:
    """UCB on exponentially discounted reward sums and pull counts.

    On every update all arms' statistics are multiplied by ``gamma`` before
    the pulled arm is credited. The bonus is ``sqrt(2 ln t / discounted_count)``.
    """

    def __init__(self, n_arms, gamma=0.99, **kwargs):
        self.n_arms = n_arms
        self.gamma = gamma
        self.weighted_rewards = np.zeros(n_arms)
        self.weighted_counts = np.zeros(n_arms)
        self.t = 0

    def select_arm(self):
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms

        scores = []
        for w_sum, w_cnt in zip(self.weighted_rewards, self.weighted_counts):
            if w_cnt == 0:
                scores.append(float('inf'))
            else:
                scores.append(w_sum / w_cnt + np.sqrt(2 * np.log(self.t) / w_cnt))
        return np.argmax(scores)

    def update(self, arm, reward):
        # Forget a fraction of every arm's history, then add the new observation.
        for stats in (self.weighted_rewards, self.weighted_counts):
            stats *= self.gamma
        self.weighted_rewards[arm] += reward
        self.weighted_counts[arm] += 1

    def get_diagnostics(self):
        return {'algorithm': 'Discounted-UCB', 'd_estimate': 1.0}


class EXP3:
    """EXP3 adversarial bandit with importance-weighted exponential updates.

    Args:
        n_arms: number of arms.
        gamma: exploration rate mixed uniformly into the sampling distribution.
            When falsy (None or 0), it defaults to ``sqrt(2 ln(n_arms) / 1000)``.
    """

    def __init__(self, n_arms, gamma=None, **kwargs):
        self.n_arms = n_arms
        self.gamma = gamma if gamma else np.sqrt(2 * np.log(n_arms) / 1000)
        self.weights = np.ones(n_arms)
        self.t = 0

    def _probabilities(self):
        """Mix the normalised weights with the uniform distribution."""
        total_weight = np.sum(self.weights)
        return (1 - self.gamma) * (self.weights / total_weight) + self.gamma / self.n_arms

    def select_arm(self):
        self.t += 1
        return np.random.choice(self.n_arms, p=self._probabilities())

    def update(self, arm, reward):
        prob = self._probabilities()[arm]
        estimated_reward = reward / prob
        # prob >= gamma / n_arms, so the exponent below is at most `reward`;
        # a single step cannot overflow.
        self.weights[arm] *= np.exp(self.gamma * estimated_reward / self.n_arms)
        # The sampling distribution depends only on weight ratios. Rescaling so
        # the largest weight is 1 is therefore a no-op mathematically. Without it
        # the weights grow without bound, overflow to inf after enough rewarded
        # pulls, and turn the probabilities into NaN (np.random.choice then raises).
        self.weights /= np.max(self.weights)

    def get_diagnostics(self):
        return {'algorithm': 'EXP3', 'd_estimate': 1.0}


class AdaptiveGreedy:
    """Greedy policy that exploits only when the leader is clearly separated.

    After one round-robin pass, the empirical best arm is played if its lower
    confidence bound exceeds the runner-up's upper bound. Otherwise an arm is
    sampled with probability proportional to its confidence width
    ``sqrt(2 ln t / max(n, 1))``.
    """

    def __init__(self, n_arms, alpha=0.1, **kwargs):
        self.n_arms = n_arms
        self.alpha = alpha  # stored but not used by select_arm/update
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)
        self.t = 0

    def select_arm(self):
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms
        if self.n_arms == 1:
            # There is no runner-up to compare against
            # (argsort(...)[-2] used to raise IndexError here).
            return 0

        means = np.divide(self.rewards, self.counts,
                          out=np.zeros_like(self.rewards), where=self.counts != 0)
        confidence = np.sqrt(2 * np.log(self.t) / np.maximum(self.counts, 1))

        best_arm = np.argmax(means)
        # Runner-up among the *other* arms. argsort(means)[-2] could return
        # best_arm itself when several arms tie for the top mean.
        rivals = means.copy()
        rivals[best_arm] = -np.inf
        second_best = np.argmax(rivals)

        if means[best_arm] - confidence[best_arm] > means[second_best] + confidence[second_best]:
            return best_arm
        exploration_probs = confidence / np.sum(confidence)
        return np.random.choice(self.n_arms, p=exploration_probs)

    def update(self, arm, reward):
        self.counts[arm] += 1
        self.rewards[arm] += reward

    def get_diagnostics(self):
        return {'algorithm': 'Adaptive-Greedy', 'd_estimate': 1.0}


class GradientBandit:
    """Gradient bandit: a softmax policy over learned arm preferences.

    Args:
        n_arms: number of arms.
        alpha: step size of the preference update.
        baseline: if True, subtract the running average reward (updated with
            the current reward first) from each reward before the update.
    """

    def __init__(self, n_arms, alpha=0.1, baseline=True, **kwargs):
        self.n_arms = n_arms
        self.alpha = alpha
        self.use_baseline = baseline
        self.preferences = np.zeros(n_arms)
        self.average_reward = 0.0
        self.reward_count = 0
        self.t = 0

    def get_probabilities(self):
        """Return softmax(preferences).

        The maximum preference is subtracted before exponentiating. This leaves
        the result unchanged mathematically but keeps np.exp from overflowing to
        inf, which would otherwise yield NaN probabilities once preferences grow.
        """
        shifted = self.preferences - np.max(self.preferences)
        exp_prefs = np.exp(shifted)
        return exp_prefs / np.sum(exp_prefs)

    def select_arm(self):
        self.t += 1
        probabilities = self.get_probabilities()
        return np.random.choice(self.n_arms, p=probabilities)

    def update(self, arm, reward):
        self.reward_count += 1
        self.average_reward += (reward - self.average_reward) / self.reward_count

        baseline = self.average_reward if self.use_baseline else 0.0
        probabilities = self.get_probabilities()
        one_hot = np.zeros(self.n_arms)
        one_hot[arm] = 1.0
        # Raise the pulled arm's preference (and lower the others) in proportion
        # to how much better than the baseline the reward was.
        self.preferences += self.alpha * (reward - baseline) * (one_hot - probabilities)

    def get_diagnostics(self):
        return {'algorithm': 'Gradient-Bandit', 'd_estimate': 1.0}



class ADWIN:
    """Simplified ADWIN-style change detector over a growing window.

    Each new value is appended to the window. Every split point
    (old prefix | recent suffix) is then tested, oldest first. The test fires
    when ``|mean0 - mean1| >= sqrt(1 / (2m) * ln(2 / (delta / n)))``, with
    ``m = 1 / (1/n0 + 1/n1)``. On the first split that fires, the prefix is
    discarded.
    """

    def __init__(self, delta=0.001):
        self.delta = delta
        self.window = deque()
        self.total = 0.0  # running sum of the values currently in the window

    def add_element(self, value):
        """Append ``value`` and drop the stale prefix if a change is detected.

        Returns:
            True if a split point fired (its prefix has been removed), else False.
        """
        self.window.append(value)
        self.total += value

        n = len(self.window)
        if n < 2:
            return False

        # All split points are evaluated at once from prefix sums. This is O(n)
        # per element, versus the previous O(n^2) list copies and re-summing.
        values = np.fromiter(self.window, dtype=float, count=n)
        prefix = np.cumsum(values)
        n0 = np.arange(1, n, dtype=float)  # prefix lengths i = 1 .. n-1
        n1 = n - n0                        # matching suffix lengths
        sum0 = prefix[:-1]
        sum1 = prefix[-1] - sum0
        mean_gap = np.abs(sum0 / n0 - sum1 / n1)

        m = 1.0 / (1.0 / n0 + 1.0 / n1)
        delta_prime = self.delta / n
        log_term = np.log(2.0 / delta_prime)
        cutoff = np.sqrt((1.0 / (2 * m)) * log_term)

        fired = np.flatnonzero(mean_gap >= cutoff)
        if fired.size == 0:
            return False

        # Drop the prefix of the earliest split that fired (length i = index + 1).
        for _ in range(int(fired[0]) + 1):
            self.total -= self.window.popleft()
        return True


class ADS_ThompsonSampling:
    """Beta-Bernoulli Thompson Sampling with per-arm ADWIN change detection.

    When an arm's detector fires:
      * the arm's posterior sample is boosted by 10% at the next selection;
      * the arm's entries are dropped from the last 20 recorded observations;
      * if the arm holds more than 10 Beta pseudo-counts, both parameters are
        floor-divided by 3 (minimum 1).

    A reward equal to 1 counts as a success; any other value as a failure.
    """

    # Number of most recent (arm, reward) pairs kept after a detection.
    _RECENT = 20

    def __init__(self, n_arms, delta=0.001, **kwargs):
        self.n_arms = n_arms
        self.delta = delta
        self.detectors = [ADWIN(delta) for _ in range(n_arms)]
        self.alpha = np.ones(n_arms)
        self.beta = np.ones(n_arms)
        # Only the last 20 (arm, reward) pairs are ever read back, so the history
        # is bounded instead of growing by one tuple per step forever.
        self.window = deque(maxlen=self._RECENT)
        self.t = 0
        self.change_detected = np.zeros(n_arms, dtype=bool)

    def select_arm(self):
        self.t += 1
        samples = [np.random.beta(self.alpha[i], self.beta[i]) for i in range(self.n_arms)]

        # One-off optimism for arms whose detector has just fired.
        for arm in range(self.n_arms):
            if self.change_detected[arm]:
                samples[arm] *= 1.1
                self.change_detected[arm] = False

        return np.argmax(samples)

    def update(self, arm, reward):
        self.window.append((arm, reward))

        if reward == 1:
            self.alpha[arm] += 1
        else:
            self.beta[arm] += 1

        if self.detectors[arm].add_element(reward):
            self.change_detected[arm] = True
            self._drop_arm_history(arm)
            self._shrink_posterior(arm)

    def _drop_arm_history(self, arm):
        """Keep only the other arms' entries among the most recent observations."""
        recent = list(self.window)[-self._RECENT:]
        self.window = deque(((a, r) for a, r in recent if a != arm), maxlen=self._RECENT)

    def _shrink_posterior(self, arm):
        """Floor-divide both Beta parameters by 3 (min 1) when they sum to more than 10."""
        if self.alpha[arm] + self.beta[arm] > 10:
            self.alpha[arm] = max(1, self.alpha[arm] // 3)
            self.beta[arm] = max(1, self.beta[arm] // 3)

    def get_diagnostics(self):
        return {'algorithm': 'ADS-Thompson-Sampling', 'd_estimate': 1.0}


class ADR_ThompsonSampling:
    """Thompson Sampling with ADWIN-triggered full restarts and monitoring blocks.

    Time is split into blocks of ``n_arms * N * 2**(block - 1)`` steps. From the
    second block on, every ``n_arms``-th step inside a block forces the arm that
    was pulled most in the previous block. Any detector firing restarts all state
    through :meth:`reset`.
    """

    def __init__(self, n_arms, delta=0.001, N=10, **kwargs):
        self.n_arms = n_arms
        self.delta = delta
        self.N = N
        self.reset()

    def reset(self):
        """(Re)initialise posteriors, detectors, block schedule and history."""
        n = self.n_arms
        self.alpha = np.ones(n)
        self.beta = np.ones(n)
        self.detectors = [ADWIN(self.delta) for _ in range(n)]
        self.t = 0
        self.current_block = 1
        self.monitoring_arm = None
        self.block_start = 0
        self.arm_counts_in_block = np.zeros(n)
        self.window = deque()

    def select_arm(self):
        self.t += 1
        offset = self.t - self.block_start
        forced = (self.current_block >= 2
                  and offset % self.n_arms == 0
                  and self.monitoring_arm is not None)
        if forced:
            return self.monitoring_arm
        return np.argmax([np.random.beta(a, b) for a, b in zip(self.alpha, self.beta)])

    def update(self, arm, reward):
        self.window.append((arm, reward))
        self.arm_counts_in_block[arm] += 1

        # Rewards equal to 1 are successes; anything else is a failure.
        posterior = self.alpha if reward == 1 else self.beta
        posterior[arm] += 1

        if self.detectors[arm].add_element(reward):
            self.reset()
            return

        size = self.n_arms * self.N * (2 ** (self.current_block - 1))
        if self.t - self.block_start < size:
            return
        # Block finished: monitor its most-pulled arm during the next (doubled) block.
        self.monitoring_arm = np.argmax(self.arm_counts_in_block)
        self.current_block += 1
        self.block_start = self.t
        self.arm_counts_in_block = np.zeros(self.n_arms)

    def get_diagnostics(self):
        return {'algorithm': 'ADR-Thompson-Sampling', 'd_estimate': 1.0}


class ADS_KLUCB:
    """KL-UCB with per-arm ADWIN change detection and partial statistic resets.

    When an arm's detector fires:
      * its KL-UCB index is multiplied by 1.2 at the next selection;
      * its entries are dropped from the last 20 recorded observations;
      * if it has more than 10 pulls, its count is floor-divided by 3 while its
        empirical mean is kept.
    """

    # Number of most recent (arm, reward) pairs kept after a detection.
    _RECENT = 20

    def __init__(self, n_arms, delta=0.001, **kwargs):
        self.n_arms = n_arms
        self.delta = delta
        self.detectors = [ADWIN(delta) for _ in range(n_arms)]
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)
        # Only the last 20 (arm, reward) pairs are ever read back, so the history
        # is bounded instead of growing by one tuple per step forever.
        self.window = deque(maxlen=self._RECENT)
        self.t = 0
        self.change_detected = np.zeros(n_arms, dtype=bool)

    def kl_divergence(self, p, q):
        """Bernoulli KL(p || q), with both arguments clipped to [1e-15, 1 - 1e-15]."""
        p = max(min(p, 1 - 1e-15), 1e-15)
        q = max(min(q, 1 - 1e-15), 1e-15)
        return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

    def find_upper_bound(self, mean, count, t):
        """KL-UCB index via 20 bisection steps on [mean, 1] (1.0 for an unpulled arm)."""
        if count == 0:
            return 1.0
        threshold = (np.log(t) + 3 * np.log(np.log(t) + 1)) / count
        low, high = mean, 1.0
        for _ in range(20):
            mid = (low + high) / 2
            if self.kl_divergence(mean, mid) <= threshold:
                low = mid
            else:
                high = mid
        return low

    def select_arm(self):
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms

        ucb_values = []
        for arm in range(self.n_arms):
            if self.counts[arm] == 0:
                ucb_values.append(float('inf'))
            else:
                mean = self.rewards[arm] / self.counts[arm]
                upper_bound = self.find_upper_bound(mean, self.counts[arm], self.t)

                # One-off optimism for arms whose detector has just fired.
                if self.change_detected[arm]:
                    upper_bound *= 1.2
                    self.change_detected[arm] = False

                ucb_values.append(upper_bound)
        return np.argmax(ucb_values)

    def update(self, arm, reward):
        self.window.append((arm, reward))
        self.counts[arm] += 1
        self.rewards[arm] += reward

        if self.detectors[arm].add_element(reward):
            self.change_detected[arm] = True
            self._drop_arm_history(arm)
            self._shrink_arm_stats(arm)

    def _drop_arm_history(self, arm):
        """Keep only the other arms' entries among the most recent observations."""
        recent = list(self.window)[-self._RECENT:]
        self.window = deque(((a, r) for a, r in recent if a != arm), maxlen=self._RECENT)

    def _shrink_arm_stats(self, arm):
        """Floor-divide an arm's count by 3 (when > 10) while keeping its empirical mean."""
        if self.counts[arm] > 10:
            # Take the mean *before* shrinking the count. Dividing by
            # (counts // 3) * 3 afterwards overstated it whenever the count was
            # not a multiple of 3 (e.g. 11 pulls of reward 1 -> mean 11/9).
            current_mean = self.rewards[arm] / self.counts[arm]
            self.counts[arm] = max(1, self.counts[arm] // 3)
            self.rewards[arm] = current_mean * self.counts[arm]

    def get_diagnostics(self):
        return {'algorithm': 'ADS-KL-UCB', 'd_estimate': 1.0}


class CUSUM:
    """UCB with a one-sided CUSUM detector per arm.

    The statistic accumulates ``reward - reference_mean - 0.01`` (floored at 0)
    and therefore reacts to upward shifts only. When it exceeds ``threshold``:
      * the reference mean is reset to the running mean;
      * the arm's count is halved while its mean is kept;
      * the arm's UCB bonus is multiplied by 1.5 at the next selection.
    """

    def __init__(self, n_arms, threshold=2.0, **kwargs):
        self.n_arms = n_arms
        self.threshold = threshold
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)
        self.cusum_stats = np.zeros(n_arms)
        self.reference_means = np.zeros(n_arms)
        self.t = 0
        self.change_detected = np.zeros(n_arms, dtype=bool)

    def select_arm(self):
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms

        scores = []
        for arm in range(self.n_arms):
            pulls = self.counts[arm]
            if pulls == 0:
                scores.append(float('inf'))
                continue
            bonus = np.sqrt(2 * np.log(self.t) / pulls)
            # update() resets the statistic as soon as it exceeds the threshold,
            # so in practice the boost is driven by the change_detected flag.
            if self.cusum_stats[arm] > self.threshold or self.change_detected[arm]:
                bonus *= 1.5
                self.change_detected[arm] = False
            scores.append(self.rewards[arm] / pulls + bonus)
        return np.argmax(scores)

    def update(self, arm, reward):
        self.counts[arm] += 1
        self.rewards[arm] += reward
        running_mean = self.rewards[arm] / self.counts[arm]

        # Warm-up: for the first five pulls, just track the running mean.
        if self.counts[arm] <= 5:
            self.reference_means[arm] = running_mean
            return

        self.cusum_stats[arm] = max(0, self.cusum_stats[arm] + (reward - self.reference_means[arm]) - 0.01)
        if not self.cusum_stats[arm] > self.threshold:
            return

        # Change detected: re-anchor and keep only half the evidence for this arm.
        self.reference_means[arm] = running_mean
        self.cusum_stats[arm] = 0
        self.change_detected[arm] = True
        self.counts[arm] = max(1, self.counts[arm] // 2)
        self.rewards[arm] = running_mean * self.counts[arm]

    def get_diagnostics(self):
        return {'algorithm': 'CUSUM-UCB', 'd_estimate': 1.0}



class BetaSWTS:
    """Sliding-window Beta Thompson Sampling.

    Each arm keeps its last ``window_size`` rewards. The posterior is
    ``Beta(1 + sum(rewards), 1 + len(rewards) - sum(rewards))``, or ``Beta(1, 1)``
    for an arm with no data.
    """

    def __init__(self, n_arms, window_size=100, **kwargs):
        self.n_arms = n_arms
        self.window_size = window_size
        self.reward_windows = [deque(maxlen=window_size) for _ in range(n_arms)]
        # Mirrors reward_windows in length; only used to tell whether an arm has data.
        self.selection_windows = [deque(maxlen=window_size) for _ in range(n_arms)]
        self.t = 0

    def select_arm(self):
        draws = np.zeros(self.n_arms)
        for arm, (picks, observed) in enumerate(zip(self.selection_windows, self.reward_windows)):
            if picks:
                wins = sum(observed)
                a, b = 1 + wins, 1 + (len(observed) - wins)
            else:
                a = b = 1
            draws[arm] = np.random.beta(a, b)
        return np.argmax(draws)

    def update(self, arm, reward):
        self.t += 1
        self.selection_windows[arm].append(arm)
        self.reward_windows[arm].append(reward)

    def get_diagnostics(self):
        return {'algorithm': 'Beta-SWTS', 'd_estimate': 1.0}


class GammaSWGTS:
    """Sliding-window Gaussian Thompson Sampling.

    Any arm with an empty window is played first (lowest index). After that,
    each arm is sampled from ``N(window_mean, 1 / (gamma * window_len))``.
    """

    def __init__(self, n_arms, window_size=100, gamma=1.0, **kwargs):
        self.n_arms = n_arms
        self.window_size = window_size
        self.gamma = gamma
        self.reward_windows = [deque(maxlen=window_size) for _ in range(n_arms)]
        self.initialized = False  # becomes True once every arm has data at selection time
        self.t = 0

    def select_arm(self):
        # Warm-up: play any arm that still has no observations.
        for arm, window in enumerate(self.reward_windows):
            if len(window) == 0:
                return arm
        self.initialized = True

        draws = np.zeros(self.n_arms)
        for arm, window in enumerate(self.reward_windows):
            observed = list(window)
            mu = np.mean(observed)
            sigma = np.sqrt(1.0 / (self.gamma * len(observed)))
            draws[arm] = np.random.normal(mu, sigma)
        return np.argmax(draws)

    def update(self, arm, reward):
        self.t += 1
        self.reward_windows[arm].append(reward)

    def get_diagnostics(self):
        return {'algorithm': 'Gamma-SWGTS', 'd_estimate': 1.0}



class AdaptiveEpsilonGreedy(BaseBanditAlgorithm):
    """Epsilon-greedy with a multiplicatively decaying exploration rate.

    Before every selection, epsilon is multiplied by ``decay_rate`` and floored
    at ``epsilon_min``. Unpulled arms have an estimated mean of 0.
    """

    def __init__(self, n_arms, epsilon_init=0.5, epsilon_min=0.01, decay_rate=0.995, **kwargs):
        super().__init__(n_arms, 'Adaptive-Epsilon-Greedy', **kwargs)
        self.epsilon_init = epsilon_init
        self.epsilon_min = epsilon_min
        self.decay_rate = decay_rate
        self.epsilon = epsilon_init
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)

    def select_arm(self):
        self.t += 1
        decayed = self.epsilon * self.decay_rate
        self.epsilon = max(self.epsilon_min, decayed)

        if np.random.random() >= self.epsilon:
            value_estimates = np.divide(self.rewards, self.counts,
                                        out=np.zeros_like(self.rewards),
                                        where=self.counts != 0)
            return np.argmax(value_estimates)
        return np.random.randint(0, self.n_arms)

    def update(self, arm, reward):
        self.rewards[arm] += reward
        self.counts[arm] += 1


class SoftmaxBandit:
    """Boltzmann exploration over empirical mean rewards.

    Arm i is chosen with probability proportional to
    ``exp(mean_i / temperature)``. Unpulled arms have an estimated mean of 0.
    """

    def __init__(self, n_arms, temperature=1.0, **kwargs):
        self.n_arms = n_arms
        self.temperature = temperature
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)
        self.t = 0

    def select_arm(self):
        self.t += 1
        estimates = np.divide(self.rewards, self.counts,
                              out=np.zeros_like(self.rewards), where=self.counts != 0)

        logits = estimates / self.temperature
        # Subtract the max logit before exponentiating. The probabilities are
        # unchanged mathematically, but a small temperature no longer overflows
        # np.exp to inf (inf / inf -> NaN, which made np.random.choice raise).
        exp_values = np.exp(logits - np.max(logits))
        probabilities = exp_values / np.sum(exp_values)
        return np.random.choice(self.n_arms, p=probabilities)

    def update(self, arm, reward):
        self.counts[arm] += 1
        self.rewards[arm] += reward

    def get_diagnostics(self):
        return {'algorithm': 'Softmax-Bandit', 'd_estimate': 1.0}


class NeuralUCB:
    """Neural-network reward model with a per-arm, gradient-based UCB bonus.

    The network scores the vector ``[context, one_hot(arm)]``, so each arm gets
    its own prediction. The exploration bonus for arm ``a`` is
    ``nu * sqrt(sum(g_a ** 2 / z))``, where:
      * ``g_a`` is the gradient of the network output for arm ``a`` with
        respect to all parameters;
      * ``z`` is a diagonal design matrix that starts at ``lambda_param`` and
        accumulates ``g ** 2`` for every (context, arm) passed to :meth:`update`.

    Args:
        n_arms: number of arms.
        context_dim: length of the synthetic context produced by ``_get_context``
            (must be at least 4, because four entries are written).
        hidden_dim: width of the two hidden layers.
        lr: Adam learning rate.
        lambda_param: initial value of the diagonal design matrix; non-positive
            values are replaced by 1e-6 to avoid dividing by zero.
        nu: exploration scale.
        device: torch device string; defaults to CUDA when available, else CPU.
    """

    def __init__(self, n_arms, context_dim=4, hidden_dim=64, lr=0.01, lambda_param=1.0,
                 nu=0.1, device=None, **kwargs):
        self.n_arms = n_arms
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.lambda_param = lambda_param
        self.nu = nu
        self.device = torch.device(device) if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.network = self._create_network().to(self.device)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

        # Diagonal of the design matrix: one tensor per network parameter.
        z_init = float(lambda_param) if lambda_param > 0 else 1e-6
        self.z_diag = [torch.full_like(p, z_init) for p in self.network.parameters()]

        self.contexts = []  # context generated at each select_arm() call
        self.arms = []      # arm chosen at each select_arm() call
        self.features = []  # network inputs for the (arm, reward) pairs given to update()
        self.rewards = []   # rewards given to update(), aligned with self.features

        self.t = 0
        self.grad_norm = 0.0

    def _create_network(self):
        """Build the MLP mapping ``[context, one_hot(arm)]`` to a scalar reward."""
        return torch.nn.Sequential(
            torch.nn.Linear(self.context_dim + self.n_arms, self.hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_dim, self.hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_dim, 1)
        )

    def _get_context(self):
        """Synthetic context: [t/1000, sin(0.1 t), cos(0.1 t), N(0, 1) noise, 0, ...]."""
        context = np.zeros(self.context_dim)
        context[0] = self.t / 1000.0
        context[1] = np.sin(self.t * 0.1)
        context[2] = np.cos(self.t * 0.1)
        context[3] = np.random.normal(0, 1)
        return context

    def _features(self, context, arm):
        """Concatenate ``context`` with a one-hot encoding of ``arm``."""
        one_hot = np.zeros(self.n_arms)
        one_hot[arm] = 1.0
        return np.concatenate([np.asarray(context, dtype=float), one_hot])

    def _predict_with_gradient(self, features):
        """Return the network output for ``features`` and its per-parameter gradients."""
        x = torch.as_tensor(features, dtype=torch.float32, device=self.device).unsqueeze(0)
        self.network.zero_grad()
        output = self.network(x).squeeze()
        output.backward()
        grads = [torch.zeros_like(p) if p.grad is None else p.grad.detach().clone()
                 for p in self.network.parameters()]
        # Leave no gradients behind for the optimizer step in _train_network.
        self.network.zero_grad()
        return output.item(), grads

    def _compute_ucb(self, context, arm):
        """Return ``(ucb_value, predicted_reward, uncertainty)`` for ``arm`` in ``context``."""
        pred_reward, grads = self._predict_with_gradient(self._features(context, arm))
        width_sq = sum(torch.sum(g * g / z).item() for g, z in zip(grads, self.z_diag))
        uncertainty = self.nu * np.sqrt(width_sq)
        return pred_reward + uncertainty, pred_reward, uncertainty

    def select_arm(self):
        """Generate a context, score every arm, and return the highest-UCB arm."""
        self.t += 1
        context = self._get_context()

        ucb_values = [self._compute_ucb(context, arm)[0] for arm in range(self.n_arms)]
        selected_arm = int(np.argmax(ucb_values))

        self.contexts.append(context)
        self.arms.append(selected_arm)
        return selected_arm

    def update(self, arm, reward):
        """Record ``reward`` for ``arm`` under the latest context and train the network."""
        context = self.contexts[-1] if self.contexts else self._get_context()
        features = self._features(context, arm)

        # Grow the design matrix along the played arm's gradient so that its
        # bonus shrinks as that direction gets explored.
        _, grads = self._predict_with_gradient(features)
        for z, g in zip(self.z_diag, grads):
            z.add_(g * g)

        self.features.append(features)
        self.rewards.append(reward)

        if len(self.features) >= 10:
            self._train_network()

    def _train_network(self):
        """Take one Adam step on the MSE over all recorded (features, reward) pairs."""
        if len(self.features) < 10:
            return

        inputs = torch.as_tensor(np.asarray(self.features), dtype=torch.float32, device=self.device)
        targets = torch.as_tensor(np.asarray(self.rewards, dtype=np.float32), device=self.device)

        self.optimizer.zero_grad()
        preds = self.network(inputs).squeeze(-1)
        loss = torch.nn.MSELoss()(preds, targets)
        loss.backward()
        self.optimizer.step()

    def get_diagnostics(self):
        """Return the algorithm name, parameter count and number of selections."""
        return {
            'algorithm': 'Neural-UCB',
            'd_estimate': 1.0,
            'network_params': sum(p.numel() for p in self.network.parameters()),
            'training_samples': len(self.contexts)
        }


class AdaptiveWindowUCB:
    """UCB over a per-arm sliding window whose length adapts to reward variance.

    Each arm keeps a deque of its most recent rewards, trimmed to the arm's
    current window length. Whenever ``update`` runs while ``t`` is a multiple
    of 10, the updated arm's window length is multiplied by 0.9 if the
    variance of its last ``min_window`` rewards exceeds 0.1, otherwise by 1.1,
    clamped to ``[min_window, max_window]``.

    Args:
        n_arms: number of arms.
        min_window: initial and minimum window length.
        max_window: maximum window length.
        **kwargs: ignored; accepted for a uniform constructor signature.
    """

    def __init__(self, n_arms, min_window=10, max_window=200, **kwargs):
        self.n_arms = n_arms
        self.min_window = min_window
        self.max_window = max_window
        self.t = 0
        # Reward history per arm, oldest on the left.
        self.windows = [deque() for _ in range(n_arms)]
        # Per-arm window length kept as float; truncated with int() when used.
        self.window_sizes = min_window * np.ones(n_arms)

    def _adapt_window_size(self, arm):
        """Shrink (noisy) or grow (stable) the window length of ``arm``."""
        history = self.windows[arm]
        if len(history) < self.min_window:
            return

        spread = np.var(list(history)[-self.min_window:])
        current = self.window_sizes[arm]
        if spread > 0.1:
            self.window_sizes[arm] = max(self.min_window, current * 0.9)
        else:
            self.window_sizes[arm] = min(self.max_window, current * 1.1)

    def select_arm(self):
        """Return the arm maximising windowed mean + sqrt(2 ln t / n)."""
        self.t += 1
        log_t = np.log(self.t)
        scores = np.zeros(self.n_arms)

        for arm, history in enumerate(self.windows):
            if not history:
                # Never-pulled arms are tried first.
                scores[arm] = float('inf')
                continue
            recent = list(history)[-int(self.window_sizes[arm]):]
            scores[arm] = np.mean(recent) + np.sqrt(2 * log_t / len(recent))

        return np.argmax(scores)

    def update(self, arm, reward):
        """Append ``reward``, trim to the current window, and maybe adapt it."""
        history = self.windows[arm]
        history.append(reward)

        limit = int(self.window_sizes[arm])
        while len(history) > limit:
            history.popleft()

        if self.t % 10 == 0:
            self._adapt_window_size(arm)

    def get_diagnostics(self):
        """Return the algorithm name and a fixed d estimate of 1.0."""
        return {'algorithm': 'Adaptive-Window-UCB', 'd_estimate': 1.0}


class ChangePointDetectionMAB:
    """UCB with a simple two-halves mean-shift change detector per arm.

    Each arm keeps at most ``window_size`` recent rewards. Before every
    decision, each arm's history (once it holds at least 20 rewards) is split
    in half. If the two half-means differ by more than
    ``detection_threshold``, the following happens:

    - the current step is recorded in ``change_points[arm]``;
    - the arm's history is truncated to its 10 most recent rewards.

    After the truncation, both the estimate and the UCB confidence width
    restart from post-change data.

    Args:
        n_arms: number of arms.
        detection_threshold: minimum absolute half-mean difference that counts
            as a change.
        window_size: maximum number of rewards remembered per arm.
        **kwargs: ignored; accepted for a uniform constructor signature.
    """

    def __init__(self, n_arms, detection_threshold=0.1, window_size=50, **kwargs):
        self.n_arms = n_arms
        self.detection_threshold = detection_threshold
        self.window_size = window_size
        self.arm_histories = [deque(maxlen=window_size) for _ in range(n_arms)]
        # Steps (values of self.t) at which a change was detected, per arm.
        self.change_points = [[] for _ in range(n_arms)]
        self.current_estimates = np.ones(n_arms) * 0.5
        self.t = 0

    def _detect_change_point(self, arm):
        """Return True if the two halves of ``arm``'s history differ in mean.

        Needs at least 20 stored rewards; returns False otherwise.
        """
        history = list(self.arm_histories[arm])
        if len(history) < 20:
            return False

        mid = len(history) // 2
        mean1 = np.mean(history[:mid])
        mean2 = np.mean(history[mid:])
        return abs(mean1 - mean2) > self.detection_threshold

    def select_arm(self):
        """Refresh per-arm estimates (handling detected changes) and pick by UCB."""
        self.t += 1

        for arm in range(self.n_arms):
            history = self.arm_histories[arm]
            if self._detect_change_point(arm):
                self.change_points[arm].append(self.t)
                recent_data = list(history)[-10:]
                # Drop pre-change samples. Without this the same shift is
                # re-detected (and re-recorded) on every following step, and
                # the confidence width keeps counting stale samples.
                history.clear()
                history.extend(recent_data)
                self.current_estimates[arm] = np.mean(recent_data) if recent_data else 0.5
            elif len(history) > 0:
                self.current_estimates[arm] = np.mean(history)

        ucb_values = np.zeros(self.n_arms)
        for i in range(self.n_arms):
            if len(self.arm_histories[i]) == 0:
                # Unpulled arms are tried first.
                ucb_values[i] = float('inf')
            else:
                n_i = len(self.arm_histories[i])
                confidence = np.sqrt(2 * np.log(self.t) / n_i)
                ucb_values[i] = self.current_estimates[i] + confidence

        return np.argmax(ucb_values)

    def update(self, arm, reward):
        """Append ``reward`` to ``arm``'s bounded history."""
        self.arm_histories[arm].append(reward)

    def get_diagnostics(self):
        """Return the algorithm name and a fixed d estimate of 1.0."""
        return {'algorithm': 'Change-Point-Detection-MAB', 'd_estimate': 1.0}


class VarianceAwareUCB:
    """UCB whose exploration bonus scales with each arm's empirical variance.

    After a warm-up that plays arms 0..n_arms-1 once each in order, the index
    is ``mean + c * sqrt((var + 0.1) * ln t / n)``. Here ``var`` is the plug-in
    variance clipped at 0, and the ``+ 0.1`` keeps the bonus positive for arms
    whose observed rewards have no spread.

    Args:
        n_arms: number of arms.
        c: multiplier on the confidence width.
        **kwargs: ignored; accepted for a uniform constructor signature.
    """

    def __init__(self, n_arms, c=2.0, **kwargs):
        self.n_arms = n_arms
        self.c = c
        self.t = 0
        # Per-arm pull counts, reward sums, and squared-reward sums.
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)
        self.squared_rewards = np.zeros(n_arms)

    def select_arm(self):
        """Return the arm with the largest variance-aware UCB index."""
        self.t += 1
        if self.t <= self.n_arms:
            return (self.t - 1) % self.n_arms

        log_t = np.log(self.t)

        def index(arm):
            pulls = self.counts[arm]
            if pulls == 0:
                return float('inf')
            mu = self.rewards[arm] / pulls
            spread = max(0, (self.squared_rewards[arm] / pulls) - mu ** 2)
            return mu + self.c * np.sqrt((spread + 0.1) * log_t / pulls)

        return np.argmax([index(arm) for arm in range(self.n_arms)])

    def update(self, arm, reward):
        """Accumulate count, reward sum and squared-reward sum for ``arm``."""
        self.counts[arm] += 1
        self.rewards[arm] += reward
        self.squared_rewards[arm] += reward ** 2

    def get_diagnostics(self):
        """Return the algorithm name and a fixed d estimate of 1.0."""
        return {'algorithm': 'Variance-Aware-UCB', 'd_estimate': 1.0}


class OptimisticInitialization:
    """Purely greedy agent whose value estimates start at ``init_value``.

    Estimates are running sample averages (step size ``1 / n``), so an arm's
    initial value is fully replaced by its first observed reward. While
    observed rewards stay below ``init_value``, untried arms keep the highest
    estimates and are therefore tried first (lowest index wins ties).

    Args:
        n_arms: number of arms.
        init_value: starting value estimate for every arm.
        **kwargs: ignored; accepted for a uniform constructor signature.
    """

    def __init__(self, n_arms, init_value=1.0, **kwargs):
        self.n_arms = n_arms
        self.init_value = init_value
        self.t = 0
        self.counts = np.zeros(n_arms)
        self.value_estimates = np.full(n_arms, init_value)

    def select_arm(self):
        """Return the arm with the highest current value estimate."""
        self.t += 1
        best = np.argmax(self.value_estimates)
        return best

    def update(self, arm, reward):
        """Fold ``reward`` into ``arm``'s running-average estimate."""
        self.counts[arm] += 1
        step = 1.0 / self.counts[arm]
        previous = self.value_estimates[arm]
        self.value_estimates[arm] = (1 - step) * previous + step * reward

    def get_diagnostics(self):
        """Return the algorithm name and a fixed d estimate of 1.0."""
        return {'algorithm': 'Optimistic-Initialization', 'd_estimate': 1.0}


class ExploreThenCommit:
    """Round-robin exploration for a fixed number of steps, then commit.

    For the first ``exploration_rounds`` steps, arms are played cyclically.
    On the first step after that, the arm with the highest empirical mean
    (arms never pulled count as 0) is fixed and played forever.
    ``exploration_per_arm`` is stored but not used by the selection logic.

    Args:
        n_arms: number of arms.
        exploration_rounds: total number of exploration steps.
        **kwargs: ignored; accepted for a uniform constructor signature.
    """

    def __init__(self, n_arms, exploration_rounds=50, **kwargs):
        self.n_arms = n_arms
        self.exploration_rounds = exploration_rounds
        self.exploration_per_arm = exploration_rounds // n_arms
        self.committed_arm = None
        self.t = 0
        self.counts = np.zeros(n_arms)
        self.rewards = np.zeros(n_arms)

    def select_arm(self):
        """Return the next exploration arm, or the committed arm afterwards."""
        self.t += 1

        if self.t > self.exploration_rounds:
            if self.committed_arm is None:
                means = np.zeros_like(self.rewards)
                pulled = self.counts != 0
                means[pulled] = self.rewards[pulled] / self.counts[pulled]
                self.committed_arm = np.argmax(means)
            return self.committed_arm

        return (self.t - 1) % self.n_arms

    def update(self, arm, reward):
        """Accumulate pull count and reward sum for ``arm``."""
        self.counts[arm] += 1
        self.rewards[arm] += reward

    def get_diagnostics(self):
        """Return the algorithm name and a fixed d estimate of 1.0."""
        return {'algorithm': 'Explore-Then-Commit', 'd_estimate': 1.0}



class BanditEvaluator:
    """Run every registered bandit algorithm on every showcase environment.

    Runs repeated seeded experiments, aggregates final-regret statistics,
    writes text/JSON/CSV reports and prints a per-environment ranking.

    Args:
        device: torch device string forwarded to algorithms that accept one.
        reward_type: ``"bernoulli"`` (reward drawn as Bernoulli(mean)) or
            ``"continuous"`` (reward equals the arm's mean). Must be given.

    Raises:
        ValueError: if ``reward_type`` is missing or not one of the two values.
    """

    # Algorithms whose factories additionally take the environment's true d.
    _PERFECT_D_ALGORITHMS = ('BalancedBanditDiff_perfectD', 'ImprovedBanditDiff_PerfectD')

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", reward_type=None):
        if reward_type not in ["bernoulli", "continuous"]:
            raise ValueError("reward_type参数必须为'bernoulli'或'continuous'，且必须显式指定！")
        self.device = device
        self.reward_type = reward_type
        self.environments = self._create_showcase_environments()
        self.algorithms = self._setup_algorithms()

    def _create_showcase_environments(self):
        """Return ``{name: factory(K, T)}`` for the benchmark environments.

        Each factory returns ``(env_means, true_d)``. ``run_single_experiment``
        indexes ``env_means[:, t]`` (presumably a (K, T) array of arm means;
        confirm against the env module).
        """
        return {
            'smooth_nonstationary': lambda K, T: smooth_nonstationary(K, T, d=1.2, noise_level=0.15),
            'abrupt_changes': lambda K, T: abrupt_change_environment(K, T, d=0.8, noise_strength=0.1),
            'gradual_divergence': lambda K, T: gradually_diverging(K, T, d=0.8, max_divergence=0.2),
            'high_frequency': lambda K, T: high_frequency_changes(K, T, d=0.6, oscillation_strength=0.2),
            'competitive_balanced': lambda K, T: competitive_balanced_environment(K, T, d=0.7,
                                                                                  competition_strength=0.15),
            'static_traditional': lambda K, T: static_traditional_environment(K, T, d=0.0, noise_level=0.02),
        }

    def _initialize_algorithm(self, algorithm_name, K, true_d=None):
        """Instantiate ``algorithm_name`` for ``K`` arms.

        Perfect-d variants also receive ``true_d``; every other factory is
        called with ``K`` only.

        Raises:
            KeyError: if ``algorithm_name`` is not registered.
        """
        factory = self.algorithms[algorithm_name]
        if algorithm_name in self._PERFECT_D_ALGORITHMS:
            return factory(K, true_d)
        return factory(K)

    def run_single_experiment(self, env_name, algorithm_name, K=4, T=1000, seed=42):
        """Run one seeded episode of ``T`` steps and record regret traces.

        Regret is computed from the arm means (best mean minus chosen mean),
        independent of the sampled reward.

        Returns:
            Dict with cumulative/instantaneous regret arrays, the arm
            selections, per-step ``d_estimate`` diagnostics, the final regret,
            the environment/algorithm names and the environment's ``true_d``.

        Raises:
            ValueError: if ``self.reward_type`` is not a supported value.
        """
        np.random.seed(seed)
        torch.manual_seed(seed)

        env_means, true_d = self.environments[env_name](K, T)

        algorithm = self._initialize_algorithm(algorithm_name, K, true_d)

        cumulative_regret = []
        instantaneous_regret = []
        arm_selections = []
        d_estimates = []

        total_regret = 0

        for t in range(T):
            current_means = env_means[:, t]
            optimal_arm = np.argmax(current_means)
            optimal_reward = current_means[optimal_arm]

            selected_arm = algorithm.select_arm()
            arm_selections.append(selected_arm)

            if self.reward_type == "bernoulli":
                reward = np.random.binomial(1, current_means[selected_arm])
            elif self.reward_type == "continuous":
                reward = current_means[selected_arm]
            else:
                raise ValueError(f"未知的reward_type: {self.reward_type}")

            instant_regret = optimal_reward - current_means[selected_arm]
            total_regret += instant_regret

            instantaneous_regret.append(instant_regret)
            cumulative_regret.append(total_regret)

            algorithm.update(selected_arm, reward)

            diag = algorithm.get_diagnostics()
            d_estimates.append(diag.get('d_estimate', 1.0))

        return {
            'cumulative_regret': np.array(cumulative_regret),
            'instantaneous_regret': np.array(instantaneous_regret),
            'arm_selections': arm_selections,
            'd_estimates': d_estimates,
            'final_regret': total_regret,
            'env_name': env_name,
            'algorithm': algorithm_name,
            'true_d': true_d
        }

    def run_comprehensive_experiments(self, K=4, T=1000, n_runs=10):
        """Run ``n_runs`` seeds (42, 43, ...) for every environment/algorithm pair.

        A failing run is reported and skipped so one broken algorithm does not
        abort the whole sweep.

        Returns:
            ``results[env_name][alg_name]`` -> list of successful run dicts.
        """
        print("🚀 Running comprehensive experimental evaluation")
        print("=" * 60)

        results = defaultdict(lambda: defaultdict(list))

        total_experiments = len(self.environments) * len(self.algorithms) * n_runs
        completed = 0

        for env_name in self.environments.keys():
            print(f"\n📊 Environment: {env_name}")

            for alg_name in self.algorithms.keys():
                print(f"  🔬 Algorithm: {alg_name}")

                for run in range(n_runs):
                    try:
                        result = self.run_single_experiment(
                            env_name, alg_name, K, T, seed=42 + run
                        )
                        results[env_name][alg_name].append(result)
                        completed += 1

                        if run % 3 == 0:
                            print(f"    runs {run + 1}/{n_runs} completed", end="")
                    except Exception as e:
                        # Deliberate best-effort: report and continue the sweep.
                        print(f"    Error: {e}")
                        continue

                print(f" ✅ ({completed}/{total_experiments})")

        print(f"\n✅ completed {len(self.environments)} environments × {len(self.algorithms)} algorithms × {n_runs} runs")
        return results

    def _calculate_regret_statistics(self, final_regrets):
        """Return mean/std/median/min/max/quartiles of ``final_regrets`` as floats."""
        return {
            'mean_final_regret': float(np.mean(final_regrets)),
            'std_final_regret': float(np.std(final_regrets)),
            'median_final_regret': float(np.median(final_regrets)),
            'min_final_regret': float(np.min(final_regrets)),
            'max_final_regret': float(np.max(final_regrets)),
            'q25_final_regret': float(np.percentile(final_regrets, 25)),
            'q75_final_regret': float(np.percentile(final_regrets, 75))
        }

    def calculate_statistics(self, results):
        """Aggregate per-run results into per-(environment, algorithm) statistics.

        Pairs with no successful runs are omitted. Besides the final-regret
        summary, the mean/std cumulative-regret trajectory across runs is
        stored as plain lists (JSON-serialisable).
        """
        # Named ``summary`` rather than ``stats`` so it does not shadow the
        # module-level ``scipy.stats`` import.
        summary = {}

        for env_name, env_runs in results.items():
            summary[env_name] = {}

            for alg_name, runs_data in env_runs.items():
                if len(runs_data) == 0:
                    continue

                final_regrets = [run['final_regret'] for run in runs_data]

                alg_summary = self._calculate_regret_statistics(final_regrets)
                alg_summary['n_runs'] = len(runs_data)

                # runs_data is non-empty here, so trajectories always exist.
                regret_array = np.array([run['cumulative_regret'] for run in runs_data])
                alg_summary['mean_regret_trajectory'] = regret_array.mean(axis=0).tolist()
                alg_summary['std_regret_trajectory'] = regret_array.std(axis=0).tolist()

                summary[env_name][alg_name] = alg_summary

        return summary

    def save_results(self, results, stats, save_dir="results"):
        """Write raw results (.txt), statistics (.json) and a summary (.csv).

        All files share a ``%Y%m%d_%H%M%S`` timestamp suffix.

        Returns:
            The timestamp string used in the filenames.
        """
        os.makedirs(save_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        txt_filename = f"{save_dir}/raw_results_{timestamp}.txt"
        self._save_results_as_txt(results, txt_filename)

        with open(f"{save_dir}/statistics_{timestamp}.json", 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2)

        self.create_summary_csv(stats, f"{save_dir}/summary_{timestamp}.csv")

        print(f"📁 Results saved to {save_dir}/")
        print(f"   • raw_results_{timestamp}.txt (raw data)")
        print(f"   • statistics_{timestamp}.json (statistics)")
        print(f"   • summary_{timestamp}.csv (summary table)")

        return timestamp

    def _save_results_as_txt(self, results, filename):
        """Write a human-readable per-run report of ``results`` to ``filename``."""
        from collections import Counter

        with open(filename, 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write("多臂老虎机algorithms实验结果\n")
            f.write("=" * 80 + "\n\n")

            for env_name in results.keys():
                f.write(f"environments: {env_name}\n")
                f.write("-" * 60 + "\n")

                for alg_name in results[env_name].keys():
                    f.write(f"\nalgorithms: {alg_name}\n")
                    f.write("." * 40 + "\n")

                    runs_data = results[env_name][alg_name]
                    f.write(f"runstimes数: {len(runs_data)}\n\n")

                    for run_idx, run_result in enumerate(runs_data):
                        f.write(f"  runs {run_idx + 1}:\n")
                        f.write(f"    Final cumulative regret: {run_result['final_regret']:.4f}\n")
                        f.write(f"    True d value: {run_result['true_d']:.4f}\n")

                        regret_trajectory = run_result['cumulative_regret']
                        if len(regret_trajectory) > 100:
                            # Subsample to ~100 points, then print the first 20.
                            step_size = len(regret_trajectory) // 100
                            sampled_regret = regret_trajectory[::step_size]
                            f.write(f"    Cumulative regret trajectory (sampled): {[f'{x:.2f}' for x in sampled_regret[:20]]}...\n")
                        else:
                            f.write(f"    Cumulative regret trajectory: {[f'{x:.2f}' for x in regret_trajectory[:20]]}...\n")

                        # Insertion-ordered counts, identical to a manual tally.
                        arm_counts = dict(Counter(run_result['arm_selections']))
                        f.write(f"    Arm selection statistics: {arm_counts}\n")

                        d_estimates = run_result['d_estimates']
                        # Only report d for algorithms that emit a non-default estimate.
                        if d_estimates and not all(d == 1.0 for d in d_estimates):
                            f.write(f"    Final d estimate: {d_estimates[-1]:.4f}\n")

                        f.write("\n")

                    f.write("\n")

                f.write("\n" + "=" * 80 + "\n\n")

    def create_summary_csv(self, stats, filename):
        """Write one CSV row of final-regret statistics per (environment, algorithm)."""
        summary_data = []

        for env_name in stats.keys():
            for alg_name in stats[env_name].keys():
                alg_stats = stats[env_name][alg_name]
                summary_data.append({
                    'Environment': env_name,
                    'Algorithm': alg_name,
                    'Mean_Final_Regret': alg_stats['mean_final_regret'],
                    'Std_Final_Regret': alg_stats['std_final_regret'],
                    'Median_Final_Regret': alg_stats['median_final_regret'],
                    'Min_Final_Regret': alg_stats['min_final_regret'],
                    'Max_Final_Regret': alg_stats['max_final_regret'],
                    'Q25_Final_Regret': alg_stats['q25_final_regret'],
                    'Q75_Final_Regret': alg_stats['q75_final_regret'],
                    'N_Runs': alg_stats['n_runs']
                })

        df = pd.DataFrame(summary_data)
        df.to_csv(filename, index=False)

    def print_performance_summary(self, stats):
        """Print per-environment rankings by mean final regret, then win counts."""
        print(f"\n📋 Performance Summary:")
        print("=" * 85)

        for env_name in stats.keys():
            print(f"\n🎯 {env_name}:")

            env_results = []
            for alg_name in stats[env_name].keys():
                mean_regret = stats[env_name][alg_name]['mean_final_regret']
                std_regret = stats[env_name][alg_name]['std_final_regret']
                env_results.append((alg_name, mean_regret, std_regret))

            env_results.sort(key=lambda x: x[1])

            for rank, (alg_name, mean_regret, std_regret) in enumerate(env_results, 1):
                marker = "🏆" if rank == 1 else f"{rank:2d}"
                print(f"  {marker} {alg_name:<25}: {mean_regret:8.1f} ± {std_regret:6.1f}")

        print(f"\n🏆 Overall Winning Statistics:")
        winner_counts = defaultdict(int)

        for env_name in stats.keys():
            env_results = [(alg, stats[env_name][alg]['mean_final_regret'])
                           for alg in stats[env_name].keys()]
            winner = min(env_results, key=lambda x: x[1])[0]
            winner_counts[winner] += 1

        for alg, count in sorted(winner_counts.items(), key=lambda x: x[1], reverse=True):
            percentage = count / len(stats.keys()) * 100
            print(f"  {alg}: {count}/{len(stats.keys())} environments获胜 ({percentage:.1f}%)")

    def run_evaluation(self, K=4, T=1000, n_runs=15):
        """Run the full sweep, compute statistics, save reports and print a summary.

        Returns:
            ``(results, stats, timestamp)``.
        """
        print("🎯 开始多臂老虎机algorithms评估")
        print("=" * 70)
        print(f"⚙️ Configuration:")
        print(f"   • Number of arms (K): {K}")
        print(f"   • Time steps (T): {T}")
        print(f"   • runstimes数: {n_runs}")
        print(f"   • environments数: {len(self.environments)}")
        print(f"   • algorithms数: {len(self.algorithms)}")

        results = self.run_comprehensive_experiments(K=K, T=T, n_runs=n_runs)

        print(f"\n📊 Calculating statistics...")
        summary = self.calculate_statistics(results)

        print(f"\n💾 Saving results...")
        timestamp = self.save_results(results, summary)

        self.print_performance_summary(summary)

        print(f"\n✅ 评估completed!")
        print(f"📈 Key findings:")
        print(f"   • Tested {len(self.algorithms)} 种algorithms")
        print(f"   • Evaluated in {len(self.environments)} 种environments中进行评估")
        print(f"   • 每个Configurationruns {n_runs} times")
        print(f"   • 结果保存Evaluated in results/ directory")

        return results, summary, timestamp

    def _setup_algorithms(self):
        """Return ``{name: factory}``; factories take ``K`` (plus ``d_true`` for perfect-d)."""
        return {
            'BalancedBanditDiff_perfectD': lambda K, d_true: BalancedBanditDiffUCBPerfectD(
                n_arms=K, d_true=d_true, device=self.device,
                use_diffusion=True,
                use_adaptive_window=True,
                use_theory_guided=True,
                use_virtual_data=True
            ),

            'UCB1': UCB1,
            # NOTE(review): MOSS horizon is fixed at 1000 regardless of the T
            # passed to run_evaluation; results for T != 1000 use a wrong horizon.
            'MOSS': lambda K: MOSS(K, horizon=1000),
            'KL-UCB': KLUCB,
            'Variance-Aware-UCB': VarianceAwareUCB,

            'Sliding-Window-UCB': lambda K: SlidingWindowUCB(K, window_size=150),
            'Discounted-UCB': lambda K: DiscountedUCB(K, gamma=0.995),
            'Adaptive-Window-UCB': AdaptiveWindowUCB,

            'Beta-SWTS': BetaSWTS,

            'Change-Point-Detection-MAB': ChangePointDetectionMAB
        }


def main():
    """Run the default evaluation (K=4, T=1000, 10 runs) with continuous rewards.

    Prints pointers to the generated artifacts. The raw per-run data is
    written by ``BanditEvaluator.save_results`` as ``raw_results_<ts>.txt``,
    so that is the filename advertised here.
    """
    print("📊 多臂老虎机algorithms性能评估系统")
    print("=" * 60)

    reward_type = "continuous"

    evaluator = BanditEvaluator(reward_type=reward_type)

    _, _, timestamp = evaluator.run_evaluation(
        K=4,
        T=1000,
        n_runs=10
    )

    print(f"\n🎯 评估completed! 时间戳: {timestamp}")
    print(f"📄 Suggested follow-up analysis:")
    print(f"   • Check results/summary_{timestamp}.csv for quick overview")
    print(f"   • Load results/statistics_{timestamp}.json for detailed analysis")
    # save_results writes the raw data as .txt; no .pkl file is produced.
    print(f"   • Use results/raw_results_{timestamp}.txt for custom analysis")


if __name__ == "__main__":
    main()