import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def static_traditional_environment(K, T, d=0.0, noise_level=0.02, arm_means=None):
    """Build a stationary multi-armed bandit environment with fixed reward means.

    Each arm's mean is a constant base value plus small per-step Gaussian
    "measurement" noise, clipped to [0.1, 0.9].

    Args:
        K: Number of arms. Must not exceed ``len(arm_means)``.
        T: Number of time steps.
        d: Not used in the computation; returned unchanged so the return value
            has the same ``(means, d)`` shape as the non-stationary builders.
        noise_level: Standard deviation of the per-step Gaussian noise.
        arm_means: Optional sequence of per-arm base means. Defaults to
            ``[0.35, 0.75, 0.45, 0.60]`` (arm index 1 is optimal). Only the
            first ``K`` entries are used.

    Returns:
        Tuple ``(means, d)`` where ``means`` is a float array of shape ``(K, T)``.

    Raises:
        ValueError: If ``K`` exceeds the number of available base means.
    """
    if arm_means is None:
        arm_means = [0.35, 0.75, 0.45, 0.60]  # arm 1 is optimal
    if K > len(arm_means):
        raise ValueError(
            f"K={K} arms requested but only {len(arm_means)} base means are defined"
        )

    base = np.asarray(arm_means[:K], dtype=float)[:, np.newaxis]
    # One draw per (arm, step), filled row by row: this consumes the global RNG
    # in the same order as an arm-major nested loop of scalar draws, so seeded
    # results are reproducible across implementations.
    noise = noise_level * np.random.normal(0, 1, size=(K, T))
    means = np.clip(base + noise, 0.1, 0.9)
    return means, d

def smooth_nonstationary(K, T, d=1.0, noise_level=0.15):
    """Build a smoothly drifting non-stationary bandit environment.

    Each arm's mean is the sum of:
    - a per-arm base level;
    - a three-frequency periodic term;
    - a slow global sinusoidal trend, phase-shifted per arm so that different
      arms take turns being best;
    - Gaussian noise whose scale decays as ``(t + 1) ** (-d / 2)``.

    The sum is clipped to [0.1, 0.9].

    Args:
        K: Number of arms (at most 4, the number of predefined base levels).
        T: Number of time steps.
        d: Decay exponent of the random-noise envelope.
        noise_level: Accepted for backward compatibility only. It scaled a
            time-decaying term that was computed but never added to the means,
            so it has no effect on the output.

    Returns:
        Tuple ``(means, d)`` where ``means`` is a float array of shape ``(K, T)``.

    Raises:
        ValueError: If ``K`` exceeds the number of predefined base levels.
    """
    base_rewards = [0.3, 0.7, 0.4, 0.6]  # per-arm base level
    if K > len(base_rewards):
        raise ValueError(
            f"K={K} arms requested but only {len(base_rewards)} base levels are defined"
        )

    means = np.zeros((K, T))
    t = np.arange(T)
    noise_envelope = (t + 1.0) ** (-d / 2)
    # Drawn as one (K, T) block filled row by row: same global-RNG consumption
    # order as per-(arm, step) scalar draws in an arm-major nested loop.
    standard_normal = np.random.normal(0, 1, size=(K, T))

    for i in range(K):
        # Three superimposed periods; each arm gets slightly longer periods.
        freq1 = 2 * np.pi * t / (120 + 15 * i)
        freq2 = 2 * np.pi * t / (200 + 25 * i)
        freq3 = 2 * np.pi * t / (80 + 10 * i)
        periodic_change = 0.25 * (
                0.6 * np.sin(freq1) +
                0.3 * np.sin(freq2) +
                0.1 * np.cos(freq3)
        )

        # Global trend offset by a quarter period per arm so arms rotate lead.
        global_trend = 0.15 * np.sin(2 * np.pi * t / 300 + i * np.pi / 2)

        random_noise = 0.1 * standard_normal[i] * noise_envelope

        means[i] = np.clip(base_rewards[i] + periodic_change + global_trend + random_noise, 0.1, 0.9)

    return means, d


def abrupt_change_environment(K, T, d=0.8, noise_strength=0.1):
    """Build a piecewise-stationary bandit environment with abrupt changes.

    The horizon is split into five phases at ``T // 5``, ``2 * T // 5``,
    ``3 * T // 5`` and ``4 * T // 5``. Each phase assigns a different base mean
    to every arm, so the best arm can change at each boundary.

    On top of the base mean, each arm ``i`` gets:
    - a decaying sinusoidal perturbation
      ``noise_strength * (t + 1) ** (-d) * sin(t / 30 + i * pi / 3)``;
    - i.i.d. Gaussian noise with standard deviation 0.08.

    The result is clipped to [0.1, 0.9].

    Args:
        K: Number of arms (at most 4, the number of columns in the phase table).
        T: Number of time steps.
        d: Decay exponent of the sinusoidal perturbation.
        noise_strength: Amplitude of the sinusoidal perturbation.

    Returns:
        Tuple ``(means, d)`` where ``means`` is a float array of shape ``(K, T)``.

    Raises:
        ValueError: If ``K`` exceeds the number of arms in the phase table.
    """
    change_points = [T // 5, 2 * T // 5, 3 * T // 5, 4 * T // 5]

    # Row = phase, column = arm base mean during that phase.
    phases = np.array([
        [0.7, 0.3, 0.5, 0.4],
        [0.4, 0.6, 0.3, 0.7],
        [0.6, 0.4, 0.8, 0.2],
        [0.3, 0.7, 0.4, 0.6],
        [0.8, 0.5, 0.3, 0.7],
    ])
    if K > phases.shape[1]:
        raise ValueError(
            f"K={K} arms requested but the phase table defines only {phases.shape[1]}"
        )

    means = np.zeros((K, T))
    t = np.arange(T)
    # Phase at step t = number of change points <= t; identical for every arm.
    phase_index = np.searchsorted(change_points, t, side='right')
    decay = noise_strength * (t + 1.0) ** (-d)
    # Drawn as one (K, T) block filled row by row: same global-RNG consumption
    # order as per-(arm, step) scalar draws in an arm-major nested loop.
    standard_normal = np.random.normal(0, 1, size=(K, T))

    for i in range(K):
        base_reward = phases[phase_index, i]
        noise = decay * np.sin(t / 30 + i * np.pi / 3)
        random_noise = 0.08 * standard_normal[i]
        means[i] = np.clip(base_reward + noise + random_noise, 0.1, 0.9)

    return means, d


def gradually_diverging(K, T, d=0.8, max_divergence=0.2):
    """Build a bandit environment whose arms drift apart over time.

    Each arm starts near a per-arm base value and drifts toward
    ``base + direction * (t / T) ** d``. Arms 0 and 2 drift down (by 0.25 and
    0.3); arms 1 and 3 drift up (by 0.25 and 0.3).

    On top of the drift, each arm gets:
    - three per-arm periodic terms;
    - i.i.d. Gaussian noise with standard deviation 0.08.

    The result is clipped to [0.1, 0.9].

    Args:
        K: Number of arms (at most 4, the number of predefined start values).
        T: Number of time steps.
        d: Exponent shaping how quickly the drift accrues.
        max_divergence: Accepted for backward compatibility only; the drift
            sizes are fixed per arm, so this has no effect on the output.

    Returns:
        Tuple ``(means, d)`` where ``means`` is a float array of shape ``(K, T)``.

    Raises:
        ValueError: If ``K`` exceeds the number of predefined start values.
    """
    base_starts = [0.45, 0.55, 0.4, 0.6]
    # Total drift reached at t -> T for each arm (negative = toward low values).
    target_directions = [-0.25, 0.25, -0.3, 0.3]
    if K > len(base_starts):
        raise ValueError(
            f"K={K} arms requested but only {len(base_starts)} start values are defined"
        )

    means = np.zeros((K, T))
    t = np.arange(T)
    time_factor = (t / T) ** d
    # Drawn as one (K, T) block filled row by row: same global-RNG consumption
    # order as per-(arm, step) scalar draws in an arm-major nested loop.
    standard_normal = np.random.normal(0, 1, size=(K, T))

    for i in range(K):
        divergence = target_directions[i] * time_factor

        noise1 = 0.15 * np.sin(2 * np.pi * t / (90 + 15 * i))
        noise2 = 0.1 * np.cos(2 * np.pi * t / (120 + 20 * i))
        noise3 = 0.05 * np.sin(2 * np.pi * t / (60 + 10 * i))

        random_noise = 0.08 * standard_normal[i]

        means[i] = np.clip(base_starts[i] + divergence + noise1 + noise2 + noise3 + random_noise, 0.1, 0.9)

    return means, d


def high_frequency_changes(K, T, d=0.6, oscillation_strength=0.2):
    """Build a bandit environment with fast oscillating arm means.

    Each arm's mean is a per-arm base level plus a three-frequency oscillation
    (short periods of roughly 25-100 steps) scaled by ``oscillation_strength``.
    It also gets Gaussian noise whose scale decays as ``(t + 1) ** (-d)``. The
    result is clipped to [0.1, 0.9].

    Args:
        K: Number of arms (at most 4, the number of predefined base levels).
        T: Number of time steps.
        d: Decay exponent of the random-noise envelope.
        oscillation_strength: Amplitude of the combined oscillation.

    Returns:
        Tuple ``(means, d)`` where ``means`` is a float array of shape ``(K, T)``.

    Raises:
        ValueError: If ``K`` exceeds the number of predefined base levels.
    """
    base_rewards = [0.4, 0.6, 0.3, 0.7]
    if K > len(base_rewards):
        raise ValueError(
            f"K={K} arms requested but only {len(base_rewards)} base levels are defined"
        )

    means = np.zeros((K, T))
    t = np.arange(T)
    decay_factor = (t + 1.0) ** (-d)  # shared by all arms
    # Drawn as one (K, T) block filled row by row: same global-RNG consumption
    # order as per-(arm, step) scalar draws in an arm-major nested loop.
    standard_normal = np.random.normal(0, 1, size=(K, T))

    for i in range(K):
        freq1 = 2 * np.pi * t / (25 + 5 * i)
        freq2 = 2 * np.pi * t / (45 + 8 * i)
        freq3 = 2 * np.pi * t / (70 + 12 * i)

        high_freq_change = oscillation_strength * (
                0.5 * np.sin(freq1) +
                0.3 * np.sin(freq2) +
                0.2 * np.cos(freq3)
        )

        random_noise = 0.1 * standard_normal[i] * decay_factor

        means[i] = np.clip(base_rewards[i] + high_freq_change + random_noise, 0.1, 0.9)

    return means, d


def competitive_balanced_environment(K, T, d=0.7, competition_strength=0.15):
    """Build a bandit environment where arms take turns holding an advantage.

    Each arm's mean is the sum of:
    - a per-arm base level;
    - a per-arm phase-shifted "advantage" sinusoid scaled by
      ``competition_strength``;
    - a global wave shared by all arms;
    - a local sinusoidal perturbation decaying as ``(t + 1) ** (-d)``;
    - i.i.d. Gaussian noise with standard deviation 0.08.

    The sum is clipped to [0.1, 0.9].

    Args:
        K: Number of arms (at most 4, the number of predefined base levels).
        T: Number of time steps.
        d: Decay exponent of the local perturbation.
        competition_strength: Amplitude of the per-arm advantage cycle.

    Returns:
        Tuple ``(means, d)`` where ``means`` is a float array of shape ``(K, T)``.

    Raises:
        ValueError: If ``K`` exceeds the number of predefined base levels.
    """
    base_rewards = [0.45, 0.55, 0.4, 0.6]
    if K > len(base_rewards):
        raise ValueError(
            f"K={K} arms requested but only {len(base_rewards)} base levels are defined"
        )

    means = np.zeros((K, T))
    t = np.arange(T)
    # Arm-independent terms, computed once.
    global_wave = 0.08 * np.sin(2 * np.pi * t / 200)
    local_decay = 0.12 * (t + 1.0) ** (-d)
    # Drawn as one (K, T) block filled row by row: same global-RNG consumption
    # order as per-(arm, step) scalar draws in an arm-major nested loop.
    standard_normal = np.random.normal(0, 1, size=(K, T))

    for i in range(K):
        # Per-arm period and quarter-period phase offset so arms lead in turn.
        advantage_cycle = competition_strength * np.sin(2 * np.pi * t / (80 + 15 * i) + i * np.pi / 2)
        local_noise = local_decay * np.sin(t / 40 + i * np.pi / 3)
        random_noise = 0.08 * standard_normal[i]

        means[i] = np.clip(base_rewards[i] + advantage_cycle + global_wave + local_noise + random_noise, 0.1, 0.9)

    return means, d


def analyze_environment_difficulty(means, K, T):
    """Summarize how hard a bandit environment is to learn.

    For each of the first ``T`` time steps (columns of ``means``), compute:
    - the gap between the best and second-best arm means;
    - a normalized "overlap" ``1 - (max - min) / 0.8``, where 0.8 is the width
      of the [0.1, 0.9] reward range;
    - whether the best arm differs from the previous step's best arm.

    Args:
        means: Array-like of shape ``(n_arms, n_steps)`` with ``n_arms >= 2``
            and ``n_steps >= T``. All rows are analyzed.
        K: Number of arms. Not used; the arm count is taken from ``means``.
            Kept for backward compatibility.
        T: Number of leading time steps to analyze (``1 <= T <= n_steps``).

    Returns:
        Dict with these keys:
        - ``avg_gap``, ``min_gap``, ``max_gap``: best minus second-best mean,
          averaged / minimized / maximized over steps.
        - ``optimal_switches``: number of steps whose argmax arm differs from
          the previous step's; ties go to the lowest arm index.
        - ``avg_overlap``: mean overlap over steps.
        - ``difficulty_score``: ``1 / (avg_gap + 1e-6)``.

    Raises:
        ValueError: If ``means`` is not 2-D, has fewer than two arms, or ``T``
            is outside ``[1, n_steps]``.
    """
    means = np.asarray(means)
    if means.ndim != 2 or means.shape[0] < 2:
        raise ValueError(
            f"means must be a 2-D array with at least two arms, got shape {means.shape}"
        )
    if not 1 <= T <= means.shape[1]:
        raise ValueError(f"T must be between 1 and {means.shape[1]}, got {T}")

    window = means[:, :T]

    # Best minus second-best arm at every step (sorting along the arm axis).
    ordered = np.sort(window, axis=0)
    gaps = ordered[-1] - ordered[-2]

    # Spread between the highest and lowest arm, normalized by the 0.1-0.9 range.
    spread = window.max(axis=0) - window.min(axis=0)
    overlaps = 1 - spread / (0.9 - 0.1)

    best_arms = np.argmax(window, axis=0)
    optimal_switches = int(np.count_nonzero(best_arms[1:] != best_arms[:-1]))

    avg_gap = np.mean(gaps)
    return {
        'avg_gap': avg_gap,
        'min_gap': np.min(gaps),
        'max_gap': np.max(gaps),
        'optimal_switches': optimal_switches,
        'avg_overlap': np.mean(overlaps),
        'difficulty_score': 1 / (avg_gap + 1e-6),
    }


def create_highly_overlapping_visualization():
    """Plot reward-mean curves for five non-stationary bandit environments.

    Uses K=4 arms and T=1000 steps, with the global NumPy seed set to 42 so the
    figure is reproducible. Each subplot shows:
    - every arm's mean over time;
    - markers on the instantaneous best arm every 100 steps;
    - a title with the environment's average best-vs-second-best gap and
      average overlap.

    The figure is saved to ``figures/highly_overlapping_environments.pdf`` and
    ``.png`` (creating ``figures/`` if needed) and then shown.

    Returns:
        List of ``(environment_name, difficulty_dict)`` tuples in plot order,
        as returned by ``analyze_environment_difficulty``.
    """
    import os

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Highly Overlapping Non-stationary Bandit Environments (0.1-0.9 Range)', fontsize=16,
                 fontweight='bold')

    K, T = 4, 1000
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    # Fixed seed: the environments below draw from the global RNG in order.
    np.random.seed(42)

    environments = [
        ("Smooth Non-stationary (d=1.2)", smooth_nonstationary, {'d': 1.2}),
        ("Abrupt Changes (d=0.8)", abrupt_change_environment, {'d': 0.8}),
        ("Gradual Divergence (d=0.8)", gradually_diverging, {'d': 0.8}),
        ("High Frequency (d=0.6)", high_frequency_changes, {'d': 0.6}),
        ("Competitive Balanced (d=0.7)", competitive_balanced_environment, {'d': 0.7}),
    ]

    positions = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
    difficulty_stats = []

    for (env_name, env_func, params), (row, col) in zip(environments, positions):
        ax = axes[row, col]

        means, _ = env_func(K, T, **params)

        difficulty = analyze_environment_difficulty(means, K, T)
        difficulty_stats.append((env_name, difficulty))

        # One reward curve per arm.
        for i in range(K):
            ax.plot(means[i, :], color=colors[i], label=f'Arm {i + 1}',
                    linewidth=1.5, alpha=0.7)

        # Mark the instantaneous best arm every 100 steps, in that arm's color.
        optimal_arms = np.argmax(means, axis=0)
        for t in range(0, T, 100):
            ax.scatter(t, means[optimal_arms[t], t], color=colors[optimal_arms[t]],
                       s=15, marker='o', alpha=0.8, edgecolors='black', linewidth=0.5)

        ax.set_title(f'{env_name}\nAvg Gap: {difficulty["avg_gap"]:.4f}, Overlap: {difficulty["avg_overlap"]:.2f}',
                     fontweight='bold', fontsize=10)
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Reward Mean')
        ax.set_xlim(0, T)
        ax.set_ylim(0.1, 0.9)  # full 0.1-0.9 reward range
        ax.legend(loc='upper right', framealpha=0.9, fontsize=8)
        ax.grid(True, alpha=0.3)

    # Only five environments: hide the sixth panel.
    axes[1, 2].axis('off')

    plt.tight_layout()
    os.makedirs('figures', exist_ok=True)
    plt.savefig('figures/highly_overlapping_environments.pdf', dpi=300, bbox_inches='tight')
    plt.savefig('figures/highly_overlapping_environments.png', dpi=300, bbox_inches='tight')
    plt.show()

    return difficulty_stats


def create_gap_distribution_analysis():
    """Plot the best-vs-second-best gap over time for five environments.

    Uses K=4 arms and T=1000 steps, with the global NumPy seed set to 42. Each
    subplot shows the gap time series, a dashed line at its mean, and a title
    with the gap's min/max.

    The figure is saved to ``figures/gap_distribution_analysis.pdf`` (creating
    ``figures/`` if needed) and then shown.
    """
    import os

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Gap Distribution Analysis: Best vs 2nd Best Arms (0.1-0.9 Range)', fontsize=16, fontweight='bold')

    K, T = 4, 1000
    # Fixed seed: the environments below draw from the global RNG in order.
    np.random.seed(42)

    environments = [
        ("Smooth", smooth_nonstationary, {'d': 1.2}),
        ("Abrupt", abrupt_change_environment, {'d': 0.8}),
        ("Gradual", gradually_diverging, {'d': 0.8}),
        ("High Freq", high_frequency_changes, {'d': 0.6}),
        ("Competitive", competitive_balanced_environment, {'d': 0.7}),
    ]

    positions = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]

    for (env_name, env_func, params), (row, col) in zip(environments, positions):
        ax = axes[row, col]

        means, _ = env_func(K, T, **params)

        # Best minus second-best arm at every step (sorting along the arm axis).
        ordered = np.sort(means, axis=0)
        gaps = ordered[-1] - ordered[-2]
        mean_gap = np.mean(gaps)

        ax.plot(gaps, color='red', linewidth=1.5, alpha=0.7)
        ax.axhline(y=mean_gap, color='blue', linestyle='--',
                   label=f'Avg: {mean_gap:.4f}')

        ax.set_title(f'{env_name}\nGap Range: [{np.min(gaps):.4f}, {np.max(gaps):.4f}]',
                     fontweight='bold')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Gap (Best - 2nd Best)')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Only five environments: hide the sixth panel.
    axes[1, 2].axis('off')

    plt.tight_layout()
    os.makedirs('figures', exist_ok=True)
    plt.savefig('figures/gap_distribution_analysis.pdf', dpi=300, bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    import os

    os.makedirs('figures', exist_ok=True)

    print("Creating highly overlapping bandit environments with 0.1-0.9 reward range...")

    # Reward-curve figure; returns (environment name, difficulty metrics) pairs.
    difficulty_stats = create_highly_overlapping_visualization()

    # Best-vs-second-best gap figure.
    create_gap_distribution_analysis()

    banner = "=" * 60
    print("\n" + banner)
    print("DETAILED ENVIRONMENT ANALYSIS (0.1-0.9 Range)")
    print(banner)

    # Checked in order: the first bound the average gap falls strictly below
    # decides the label; anything at or above the last bound is "LOW".
    level_thresholds = ((0.05, "EXTREME"), (0.1, "HIGH"), (0.2, "MODERATE"))

    for env_name, stats in difficulty_stats:
        avg_gap = stats['avg_gap']
        print(f"\n{env_name}:")
        print(f"  Average Gap: {avg_gap:.5f}")
        print(f"  Gap Range: [{stats['min_gap']:.5f}, {stats['max_gap']:.5f}]")
        print(f"  Optimal Switches: {stats['optimal_switches']}")
        print(f"  Overlap Ratio: {stats['avg_overlap']:.3f}")
        print(f"  Difficulty Score: {stats['difficulty_score']:.1f}")

        difficulty_level = next(
            (label for bound, label in level_thresholds if avg_gap < bound),
            "LOW",
        )
        print(f"  Difficulty Level: {difficulty_level}")

    print("\n" + banner)
    print("BANDIT SUITABILITY ASSESSMENT (0.1-0.9 Range)")
    print(banner)
    summary_lines = (
        "✓ Reward range truly expanded to 0.1-0.9",
        "✓ Arms can reach both high (0.8-0.9) and low (0.1-0.2) values",
        "✓ Significant variation ensures exploration necessity",
        "✓ Frequent optimal switches test adaptation",
        "✓ Suitable for rigorous bandit algorithm evaluation",
    )
    for line in summary_lines:
        print(line)
