# @title RL Experiment: 3 Groups + Progress Print + Final Dist + Ribbon Plot
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import seaborn as sns  # 强烈建议使用 seaborn 简化绘图
import random
import os
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

# ==========================================
# 1. Global configuration
# ==========================================

# Built with dict(...) keyword form; keys and values are identical to a literal.
EXP_CONFIG = dict(
    # Three-way comparison: low entropy (collapse), medium, high entropy (exploration / reservoir).
    ENT_COEF_LIST=[0.0001, 0.05, 0.2],
    # Three seeds per condition so the plots can show mean +/- std bands.
    SEEDS=[2025, 2026, 2027],
    TOTAL_TIMESTEPS=10000,
    # Log the entropy proxy every this many environment steps.
    LOG_INTERVAL=100,
    # Record the visited state over this many final training steps.
    LAST_STEPS_WINDOW=1000,
    DEVICE="cuda" if torch.cuda.is_available() else "cpu",
)


def seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch for reproducibility.

    All CUDA devices are seeded as well when CUDA is available.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# ==========================================
# 2. Environment definition (double-well potential)
# ==========================================


class DoubleWellEnv(gym.Env):
    """One-dimensional double-well toy environment.

    The agent starts at x = 0 and moves by ``0.5 * action`` each step. The
    position is clipped to [-10, 10]. The reward is the sum of two Gaussian
    bumps (height 1 at x = -2, height 10 at x = 4, both with sigma 0.5) minus
    a ``0.01 * |x|`` penalty. Episodes never terminate; they are truncated
    after ``max_steps`` steps.
    """

    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(
            low=-10, high=10, shape=(1,), dtype=np.float32
        )
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
        self.state = np.array([0.0], dtype=np.float32)
        self.max_steps = 20  # episode length before truncation
        self.steps = 0

    def reset(self, seed=None, options=None):
        """Put the agent back at x = 0 and clear the step counter.

        ``options`` is accepted for Gymnasium API compatibility but is unused.
        Returns ``(observation, info)``.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.observation_space.seed(seed)
            self.action_space.seed(seed)
        self.state = np.array([0.0], dtype=np.float32)
        self.steps = 0
        return self.state, {}

    def step(self, action):
        """Apply ``action``; return ``(obs, reward, terminated, truncated, info)``."""
        # Cast to float32 so the observation keeps observation_space's dtype even
        # when the caller passes a float64 action; otherwise the state is
        # silently promoted to float64 and no longer lies in observation_space.
        action = np.asarray(action, dtype=np.float32)
        self.state = np.clip(self.state + action * 0.5, -10, 10)
        self.steps += 1
        x = self.state[0]
        # Double-well reward: small peak at x = -2, large peak at x = 4.
        reward = float(
            1.0 * np.exp(-0.5 * (x - (-2)) ** 2 / 0.5**2)
            + 10.0 * np.exp(-0.5 * (x - 4) ** 2 / 0.5**2)
            - 0.01 * abs(x)
        )
        return self.state, reward, False, self.steps >= self.max_steps, {}


# ==========================================
# 3. Enhanced callback (progress print + entropy log + state distribution log)
# ==========================================


class AdvancedLoggerCallback(BaseCallback):
    """SB3 callback that prints progress, logs an entropy proxy and records late states.

    Args:
        total_timesteps: Training budget. Used in the progress message and to
            decide when the final state-recording window starts.
        check_last_n: Record the environment state on every step whose
            ``num_timesteps`` is within the last ``check_last_n`` steps.
        log_interval: Log the entropy proxy every ``log_interval`` steps.
            ``None`` (the default) falls back to ``EXP_CONFIG["LOG_INTERVAL"]``.
    """

    def __init__(self, total_timesteps, check_last_n=1000, log_interval=None):
        super().__init__(verbose=0)
        self.total_timesteps_target = total_timesteps
        self.check_last_n = check_last_n
        self.log_interval = log_interval

        # Storage
        self.entropy_history = []  # list of {"step": int, "entropy": float}
        self.final_states = []  # list of float positions x

    def _on_step(self) -> bool:
        current_step = self.num_timesteps
        log_interval = (
            EXP_CONFIG["LOG_INTERVAL"] if self.log_interval is None else self.log_interval
        )

        # --- Feature 1: progress print every 1000 steps (overwrites the same line) ---
        if current_step % 1000 == 0:
            print(
                f"   > Progress: {current_step}/{self.total_timesteps_target} steps",
                end="\r",
            )

        # --- Feature 2: entropy proxy every log_interval steps ---
        if current_step % log_interval == 0:
            with torch.no_grad():
                obs = torch.as_tensor(self.locals["new_obs"]).to(self.model.device)
                _, log_std, _ = self.model.policy.actor.get_action_dist_params(obs)
                # Mean log-std of the actor's Gaussian at the current observation.
                # This is a proxy that moves with the Gaussian's entropy; it is not
                # the exact entropy of the squashed action distribution.
                entropy_proxy = log_std.mean().item()

            self.entropy_history.append(
                {"step": current_step, "entropy": entropy_proxy}
            )

        # --- Feature 3: record the state during the last check_last_n steps ---
        if current_step > (self.total_timesteps_target - self.check_last_n):
            self.final_states.append(self._current_state())

        return True

    def _current_state(self):
        """Return the position env 0 reached on the latest step (assumes a single env).

        When that step ended the episode, SB3's VecEnv has already auto-reset.
        ``new_obs`` then holds the reset observation (x = 0), and the true last
        position is stored in ``infos[0]["terminal_observation"]``.
        """
        dones = self.locals.get("dones")
        infos = self.locals.get("infos")
        if dones is not None and dones[0] and infos and "terminal_observation" in infos[0]:
            return float(np.asarray(infos[0]["terminal_observation"]).ravel()[0])
        return float(self.locals["new_obs"][0][0])


# ==========================================
# 4. Main experiment loop
# ==========================================

# Two row accumulators, one per output dataset.
entropy_dataset = []  # long-format rows for the entropy ribbon plot
distribution_dataset = []  # long-format rows for the final-state distribution plot

print(f"Starting Experiment: {EXP_CONFIG['ENT_COEF_LIST']}")
print("-" * 50)

for ent_coef in EXP_CONFIG["ENT_COEF_LIST"]:
    print(f"\n>>> Testing Condition: Ent_Coef = {ent_coef}")
    condition = str(ent_coef)  # string label so seaborn treats it as a category

    for seed in EXP_CONFIG["SEEDS"]:
        seed_everything(seed)
        print(f"   [Seed {seed}] Initializing...", end="")

        env = DoubleWellEnv()
        callback = AdvancedLoggerCallback(
            check_last_n=EXP_CONFIG["LAST_STEPS_WINDOW"],
            total_timesteps=EXP_CONFIG["TOTAL_TIMESTEPS"],
        )
        sac_kwargs = dict(
            verbose=0,
            seed=seed,
            device=EXP_CONFIG["DEVICE"],
            ent_coef=ent_coef,
            learning_rate=3e-3,
        )
        model = SAC("MlpPolicy", env, **sac_kwargs)
        model.learn(total_timesteps=EXP_CONFIG["TOTAL_TIMESTEPS"], callback=callback)

        print(f" Done. Final states collected: {len(callback.final_states)}")

        # Tag every logged record with its condition and seed.
        entropy_dataset.extend(
            {
                "condition": condition,
                "seed": seed,
                "step": rec["step"],
                "entropy": rec["entropy"],
            }
            for rec in callback.entropy_history
        )
        distribution_dataset.extend(
            {"condition": condition, "seed": seed, "final_state": value}
            for value in callback.final_states
        )

print("\n" + "=" * 50)
print("所有实验完成！正在保存数据...")

# ==========================================
# 5. Save data (two CSV files)
# ==========================================

df_entropy = pd.DataFrame(entropy_dataset)
df_dist = pd.DataFrame(distribution_dataset)

# Derive each file name once so the confirmation messages cannot drift from the
# files actually written; the suffix tracks the configured training budget.
ENTROPY_CSV = f"exp_entropy_history_{EXP_CONFIG['TOTAL_TIMESTEPS']}.csv"
DIST_CSV = f"exp_final_distribution_{EXP_CONFIG['TOTAL_TIMESTEPS']}.csv"

df_entropy.to_csv(ENTROPY_CSV, index=False)
df_dist.to_csv(DIST_CSV, index=False)

print(f"1. 熵历史数据已保存: {ENTROPY_CSV}")
print(f"2. 最终分布数据已保存: {DIST_CSV}")

# ==========================================
# 6. Combined plotting (Seaborn)
# ==========================================

# Seaborn theme
sns.set_theme(style="whitegrid")

# Canvas: 1 row x 2 columns
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# --- Left panel: entropy ribbon plot ---
# errorbar="sd" draws the per-step mean across seeds with a +/-1 std band.
sns.lineplot(
    data=df_entropy,
    x="step",
    y="entropy",
    hue="condition",
    palette="viridis",
    errorbar="sd",
    ax=axes[0],
    linewidth=2,
)
# Raw string: in a normal literal "\p" is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). Mathtext needs the literal backslash.
axes[0].set_title(r"Policy Entropy Evolution (Mean $\pm$ Std)")
axes[0].set_ylabel("Entropy (Log Std)")
axes[0].set_xlabel("Training Steps")

# --- Right panel: state distribution over the last LAST_STEPS_WINDOW steps (KDE plot) ---
# A kernel density estimate is smoother than a histogram for comparing the conditions.
# TODO(review): the code for this panel is missing. The source is cut off here,
# so axes[1] is never drawn. Plot df_dist["final_state"] per condition.

# @title RL Experiment: Publication Ready Plotting
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
import torch
import seaborn as sns
import random
import os
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

# ==========================================
# 1. Global configuration
# ==========================================

EXP_CONFIG = {
    # Two-group comparison: low entropy (collapse) vs. high entropy (exploration / reservoir).
    "ENT_COEF_LIST": [0.0001, 0.2],
    "SEEDS": [2025, 2026, 2027],  # 3 seeds per condition for the mean/std bands
    "TOTAL_TIMESTEPS": 10000,
    "LOG_INTERVAL": 100,  # log the entropy proxy every this many steps
    "LAST_STEPS_WINDOW": 1000,  # record states over this many final steps
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
}


def seed_everything(seed):
    """Make a run reproducible by seeding every RNG the experiment uses.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU generator,
    plus all CUDA devices when CUDA is available.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


# ==========================================
# 2. Environment definition (double-well potential)
# ==========================================


class DoubleWellEnv(gym.Env):
    """One-dimensional double-well toy environment.

    The agent starts at x = 0 and moves by ``0.5 * action`` each step. The
    position is clipped to [-10, 10]. The reward is the sum of two Gaussian
    bumps (height 1 at x = -2, height 10 at x = 4, both with sigma 0.5) minus
    a ``0.01 * |x|`` penalty. Episodes never terminate; they are truncated
    after ``max_steps`` steps.
    """

    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(
            low=-10, high=10, shape=(1,), dtype=np.float32
        )
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
        self.state = np.array([0.0], dtype=np.float32)
        self.max_steps = 20  # episode length before truncation
        self.steps = 0

    def reset(self, seed=None, options=None):
        """Put the agent back at x = 0 and clear the step counter.

        ``options`` is accepted for Gymnasium API compatibility but is unused.
        Returns ``(observation, info)``.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.observation_space.seed(seed)
            self.action_space.seed(seed)
        self.state = np.array([0.0], dtype=np.float32)
        self.steps = 0
        return self.state, {}

    def step(self, action):
        """Apply ``action``; return ``(obs, reward, terminated, truncated, info)``."""
        # Cast to float32 so the observation keeps observation_space's dtype even
        # when the caller passes a float64 action; otherwise the state is
        # silently promoted to float64 and no longer lies in observation_space.
        action = np.asarray(action, dtype=np.float32)
        self.state = np.clip(self.state + action * 0.5, -10, 10)
        self.steps += 1
        x = self.state[0]
        # Double-well reward: small peak at x = -2, large peak at x = 4.
        reward = float(
            1.0 * np.exp(-0.5 * (x - (-2)) ** 2 / 0.5**2)
            + 10.0 * np.exp(-0.5 * (x - 4) ** 2 / 0.5**2)
            - 0.01 * abs(x)
        )
        return self.state, reward, False, self.steps >= self.max_steps, {}


# ==========================================
# 3. Enhanced callback
# ==========================================


class AdvancedLoggerCallback(BaseCallback):
    """SB3 callback that prints progress, logs an entropy proxy and records late states.

    Args:
        total_timesteps: Training budget. Used in the progress message and to
            decide when the final state-recording window starts.
        check_last_n: Record the environment state on every step whose
            ``num_timesteps`` is within the last ``check_last_n`` steps.
        log_interval: Log the entropy proxy every ``log_interval`` steps.
            ``None`` (the default) falls back to ``EXP_CONFIG["LOG_INTERVAL"]``.
    """

    def __init__(self, total_timesteps, check_last_n=1000, log_interval=None):
        super().__init__(verbose=0)
        self.total_timesteps_target = total_timesteps
        self.check_last_n = check_last_n
        self.log_interval = log_interval
        self.entropy_history = []  # list of {"step": int, "entropy": float}
        self.final_states = []  # list of float positions x

    def _on_step(self) -> bool:
        current_step = self.num_timesteps
        log_interval = (
            EXP_CONFIG["LOG_INTERVAL"] if self.log_interval is None else self.log_interval
        )

        # Progress print every 1000 steps (overwrites the same console line).
        if current_step % 1000 == 0:
            print(
                f"   > Progress: {current_step}/{self.total_timesteps_target} steps",
                end="\r",
            )

        # Entropy proxy every log_interval steps.
        if current_step % log_interval == 0:
            with torch.no_grad():
                obs = torch.as_tensor(self.locals["new_obs"]).to(self.model.device)
                _, log_std, _ = self.model.policy.actor.get_action_dist_params(obs)
                # Mean log-std of the actor's Gaussian at the current observation.
                # This is a proxy that moves with the Gaussian's entropy; it is not
                # the exact entropy of the squashed action distribution.
                entropy_proxy = log_std.mean().item()

            self.entropy_history.append(
                {"step": current_step, "entropy": entropy_proxy}
            )

        # Record the visited state during the last check_last_n steps.
        if current_step > (self.total_timesteps_target - self.check_last_n):
            self.final_states.append(self._current_state())

        return True

    def _current_state(self):
        """Return the position env 0 reached on the latest step (assumes a single env).

        When that step ended the episode, SB3's VecEnv has already auto-reset.
        ``new_obs`` then holds the reset observation (x = 0), and the true last
        position is stored in ``infos[0]["terminal_observation"]``.
        """
        dones = self.locals.get("dones")
        infos = self.locals.get("infos")
        if dones is not None and dones[0] and infos and "terminal_observation" in infos[0]:
            return float(np.asarray(infos[0]["terminal_observation"]).ravel()[0])
        return float(self.locals["new_obs"][0][0])


# ==========================================
# 4. Main experiment loop
# ==========================================

# Long-format row accumulators for the two output datasets.
entropy_dataset = []
distribution_dataset = []

print(f"Starting Experiment: {EXP_CONFIG['ENT_COEF_LIST']}")
print("-" * 50)

for ent_coef in EXP_CONFIG["ENT_COEF_LIST"]:
    # Descriptive label so the plot legend reads clearly.
    condition_label = f"Ent Coef = {ent_coef}"
    print(f"\n>>> Testing Condition: {condition_label}")

    for seed in EXP_CONFIG["SEEDS"]:
        seed_everything(seed)

        env = DoubleWellEnv()
        callback = AdvancedLoggerCallback(
            check_last_n=EXP_CONFIG["LAST_STEPS_WINDOW"],
            total_timesteps=EXP_CONFIG["TOTAL_TIMESTEPS"],
        )
        sac_kwargs = dict(
            verbose=0,
            seed=seed,
            device=EXP_CONFIG["DEVICE"],
            ent_coef=ent_coef,
            learning_rate=3e-3,
        )
        model = SAC("MlpPolicy", env, **sac_kwargs)
        model.learn(total_timesteps=EXP_CONFIG["TOTAL_TIMESTEPS"], callback=callback)

        # Tag each logged record with its condition and seed.
        entropy_dataset.extend(
            {
                "condition": condition_label,
                "seed": seed,
                "step": rec["step"],
                "entropy": rec["entropy"],
            }
            for rec in callback.entropy_history
        )
        distribution_dataset.extend(
            {"condition": condition_label, "seed": seed, "final_state": value}
            for value in callback.final_states
        )

print("\n" + "=" * 50)
print("Data collection finished.")

# Convert the accumulated rows into DataFrames.
df_entropy = pd.DataFrame(entropy_dataset)
df_dist = pd.DataFrame(distribution_dataset)

import pandas as pd

# -------------------------------------------------------
# Export both DataFrames to CSV files in the current working directory.
# encoding="utf-8-sig" prepends a BOM so Excel detects UTF-8 and displays
# non-ASCII text correctly; index=False omits the DataFrame index column.
# -------------------------------------------------------

# Entropy history: one row per (condition, seed, logged step).
df_entropy.to_csv("df_entropy.csv", index=False, encoding="utf-8-sig")

# Final-state distribution: one row per recorded state in the final window.
df_dist.to_csv("df_dist.csv", index=False, encoding="utf-8-sig")

print("两个 CSV 文件已成功保存到当前目录。")
