# train_and_log.py
import os, time, argparse, numpy as np, torch, pandas as pd
from algos import *

from args import Args
from env_core_cb_ev_curt import CBEnv

# Line thermal-loading threshold, expressed in percent. Loading above this
# value is what the cost code below counts as overload
# (np.maximum(thermal - thermal_limit, 0)). The original author notes that 100
# is the recommended value but that it is currently set to 0.
thermal_limit = 0



# ---------- 简单日志器（覆盖式存储） ----------
class SimpleLogger:
    """Buffer per-episode and dual-variable records in memory and append them to CSV.

    Two files live under ``out_dir``: ``train_episode.csv`` and
    ``dual_vars.csv``. Any files with those names left by a previous run are
    deleted on construction, so appending during this run is equivalent to
    overwriting the old run.
    """

    def __init__(self, out_dir: str):
        os.makedirs(out_dir, exist_ok=True)
        self.out_dir = out_dir
        self.train_path = os.path.join(out_dir, "train_episode.csv")
        self.dual_path = os.path.join(out_dir, "dual_vars.csv")
        # Drop leftovers of an earlier run; a missing file is not an error.
        for stale in (self.train_path, self.dual_path):
            try:
                os.remove(stale)
            except FileNotFoundError:
                continue
        # Records collected since the last flush.
        self.rows = []
        self.dual_rows = []

    def log_train_episode(self, episode, reward, volt_num, volt_deg, line_cost, extra=None):
        """Queue one episode record; keys in ``extra`` become additional columns."""
        record = dict(
            episode=episode,
            reward=reward,
            volt_num=volt_num,
            volt_deg=volt_deg,
            line_cost=line_cost,
            ts=time.time(),
        )
        if extra:
            record.update(extra)
        self.rows.append(record)

    @staticmethod
    def _append_csv(records, path):
        """Append ``records`` to ``path`` (header only for a new file), then empty the list."""
        if not records:
            return
        write_header = not os.path.exists(path)
        pd.DataFrame(records).to_csv(path, index=False, mode="a", header=write_header)
        records.clear()

    def flush(self):
        """Write all buffered episode and dual-variable records to disk and clear the buffers."""
        self._append_csv(self.rows, self.train_path)
        self._append_csv(self.dual_rows, self.dual_path)

    def log_dual_vars(self, step: int, lambdas):
        """Queue the Lagrange multipliers for ``step``; they reach disk on the next flush.

        A scalar becomes ``lambda_0``; anything else is flattened into
        ``lambda_0``, ``lambda_1``, ...
        """
        if np.isscalar(lambdas):
            values = [lambdas]
        else:
            values = np.array(lambdas).reshape(-1)
        record = {"step": step}
        for idx, value in enumerate(values):
            record[f"lambda_{idx}"] = float(value)
        self.dual_rows.append(record)


# ---------- 保存/载入 ----------
def save_ppo(agent, path):
    """Write the PPO policy, log-std and value-network weights to ``<path>/ppo.pt``."""
    os.makedirs(path, exist_ok=True)
    net = agent.ac
    checkpoint = {
        "pi": net.pi.state_dict(),
        "log_std": net.log_std.detach().cpu(),
        "vf": net.vf.state_dict(),
    }
    target = os.path.join(path, "ppo.pt")
    torch.save(checkpoint, target)

def save_td3(agent, path):
    """Write the TD3 actor/critic weights and their target copies to ``<path>/td3.pt``."""
    os.makedirs(path, exist_ok=True)
    checkpoint = {}
    # Checkpoint key -> attribute on the agent, kept in the original order.
    for key, module in (
        ("actor", agent.actor),
        ("critic", agent.critic),
        ("actor_t", agent.actor_t),
        ("critic_t", agent.critic_t),
    ):
        checkpoint[key] = module.state_dict()
    torch.save(checkpoint, os.path.join(path, "td3.pt"))

def save_sac(agent, path):
    """Write the SAC actor, Q-network, target Q-network and log-alpha to ``<path>/sac.pt``.

    The checkpoint also carries a constant ``"auto": True`` entry.
    """
    os.makedirs(path, exist_ok=True)
    actor_weights = agent.actor.state_dict()
    q_weights = agent.q.state_dict()
    q_target_weights = agent.q_t.state_dict()
    log_alpha = agent.log_alpha.detach().cpu()
    checkpoint = dict(
        actor=actor_weights,
        critic=q_weights,
        critic_t=q_target_weights,
        auto=True,
        log_alpha=log_alpha,
    )
    torch.save(checkpoint, os.path.join(path, "sac.pt"))


def save_ppo_lag(agent: LagrangianPPOAgent, path):
    """Write a single-constraint Lagrangian PPO checkpoint to ``<path>/ppo_lag.pt``.

    Stored entries: policy, log-std, reward critic (``vf``), cost critic
    (``vc``), the current multiplier as a float, and the cost limit as a float.
    """
    os.makedirs(path, exist_ok=True)
    net = agent.ac
    checkpoint = dict(
        pi=net.pi.state_dict(),
        log_std=net.log_std.detach().cpu(),
        vf=net.vf.state_dict(),
        vc=net.vc.state_dict(),
    )
    # Scalars are stored as plain Python floats rather than tensors.
    checkpoint["lambda"] = float(agent.lam.item())
    checkpoint["cost_limit"] = float(agent.cost_limit)
    torch.save(checkpoint, os.path.join(path, "ppo_lag.pt"))

def save_ppo_lag_multi(agent: LagrangianPPOMultiAgent, path):
    """Write a multi-constraint Lagrangian PPO checkpoint to ``<path>/ppo_lag_multi.pt``.

    Stored entries: policy, log-std, reward critic, one state dict per cost
    critic in ``agent.ac.vc_list``, and the multipliers / cost limits as NumPy
    arrays.
    """
    os.makedirs(path, exist_ok=True)
    net = agent.ac
    cost_critics = []
    for critic in net.vc_list:
        cost_critics.append(critic.state_dict())
    checkpoint = dict(
        pi=net.pi.state_dict(),
        log_std=net.log_std.detach().cpu(),
        vf=net.vf.state_dict(),
        vc_list=cost_critics,
        lambdas=agent.lam.detach().cpu().numpy(),
        cost_limits=agent.cost_limits.detach().cpu().numpy(),
    )
    torch.save(checkpoint, os.path.join(path, "ppo_lag_multi.pt"))


def extract_cost_vector(info: dict, cost_keys: list, limit=None):
    """Build the per-step cost vector for the requested constraint keys.

    Supported keys:
    - 'voltage_num', 'voltage_degree', 'batt_deg', 'pv_curt_fair': read
      straight from ``info`` (0 when the key is absent).
    - 'line_over': sum of the amounts by which the entries of
      ``info['thermal']`` exceed the threshold (0 when ``thermal`` is absent).
    Any other key yields 0.

    Args:
        info: step info dict returned by the environment.
        cost_keys: keys to extract, in output order.
        limit: overload threshold for 'line_over'. ``None`` (the default)
            uses the module-level ``thermal_limit``.

    Returns:
        ``np.float32`` array of shape ``(len(cost_keys),)``.
    """
    passthrough = ("voltage_num", "voltage_degree", "batt_deg", "pv_curt_fair")
    vec = []
    for k in cost_keys:
        if k in passthrough:
            vec.append(float(info.get(k, 0.0)))
        elif k == "line_over":
            thermal = info.get("thermal", None)
            if thermal is None:
                vec.append(0.0)
            else:
                threshold = thermal_limit if limit is None else limit
                # asarray lets a plain list/tuple of loadings work as well as an ndarray.
                over = np.maximum(np.asarray(thermal) - threshold, 0.0).sum()
                vec.append(float(over))
        else:
            vec.append(0.0)
    return np.array(vec, dtype=np.float32)


# ---------- 训练器 ----------
def train_ppo(env, out_dir, steps_per_epoch=288, epochs=10, device="cuda"):
    """Train a plain PPO agent on a reward penalised by the summed constraint costs.

    At every step the environment reward is reduced by the unweighted sum of
    ``voltage_num``, ``voltage_degree``, line overload (``thermal`` above the
    module-level ``thermal_limit``), ``batt_deg`` and ``pv_curt_fair`` read
    from ``info``. The penalised value goes into the PPO buffer; the logged
    episode return is the raw environment reward.

    A path ends (and the environment is reset) when the env reports done, when
    the episode reaches ``env.truncated_length // env.time_step`` steps, or at
    the last step of an epoch.

    Args:
        env: environment exposing ``observation_space``, ``action_space``,
            ``truncated_length``, ``time_step``, ``reset()`` and a ``step()``
            that returns a 4-item ``(obs, reward, truncated, info)``.
        out_dir: directory receiving ``train_episode.csv`` and ``ckpt/ppo.pt``.
        steps_per_epoch: environment steps collected before each update.
        epochs: number of collect/update cycles.
        device: torch device handed to ``PPOAgent``.

    Raises:
        RuntimeError: if ``env.step`` does not return a 4-item list/tuple.
    """
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    agent = PPOAgent(obs_dim, act_dim, device=device)
    buf = PPOBuffer(obs_dim, act_dim, steps_per_epoch)
    norm = RunningNorm(obs_dim)
    logger = SimpleLogger(out_dir)
    ckpt_dir = os.path.join(out_dir, "ckpt")

    # The bootstrap observation must be on the same device as the value
    # network; a CPU tensor fed to a CUDA network raises a device-mismatch error.
    vf_param = next(agent.ac.vf.parameters(), None)
    vf_device = vf_param.device if vf_param is not None else torch.device(device)

    episode = 0
    ep_ret = ep_vn = ep_vd = ep_l = ep_bdeg = ep_pvfair = 0.0
    ep_len = 0

    o = env.reset()
    for epoch in range(epochs):
        obs_batch = []
        for t in range(steps_per_epoch):
            o_n = norm.normalize(o)
            obs_batch.append(o)  # raw observations feed the normaliser update
            a, v, logp = agent.act(o_n, deterministic=False)
            a = clip_action_to_space(a, env.action_space)
            ret = env.step(a)
            if not (isinstance(ret, (list, tuple)) and len(ret) == 4):
                raise RuntimeError("Unexpected env.step return shape.")
            o2, r, truncated, info = ret
            done = bool(truncated)

            volt_cnt = float(info.get("voltage_num", 0.0))
            volt_deg = float(info.get("voltage_degree", 0.0))
            batt_deg = float(info.get("batt_deg", 0.0))
            pv_fair = float(info.get("pv_curt_fair", 0.0))
            thermal = info.get("thermal", None)
            if thermal is None:
                line_over = 0.0
            else:
                # asarray so a list of loadings works as well as an ndarray.
                line_over = float(np.maximum(np.asarray(thermal) - thermal_limit, 0.0).sum())

            # Penalised reward used for learning (all cost weights are 1).
            sum_r = r - (volt_cnt + volt_deg + line_over + batt_deg + pv_fair)

            ep_ret += r
            ep_len += 1
            # NOTE(review): volt_num / volt_deg are overwritten each step, so the
            # logged values are those of the episode's final step, not an episode
            # sum like the other accumulators — confirm this is intended.
            ep_vn = volt_cnt
            ep_vd = volt_deg
            ep_l += line_over
            ep_bdeg += batt_deg
            ep_pvfair += pv_fair

            buf.store(o_n, a, sum_r, v, logp)
            o = o2
            timeout = ep_len == (env.truncated_length // env.time_step)
            terminal = done or timeout or (t == steps_per_epoch - 1)
            if terminal:
                if done:
                    last_val = 0.0
                else:
                    # Path was cut short: bootstrap from the value estimate.
                    with torch.no_grad():
                        obs_last = torch.as_tensor(
                            norm.normalize(o), dtype=torch.float32, device=vf_device
                        ).unsqueeze(0)
                        last_val = agent.ac.vf(obs_last).item()
                buf.finish_path(last_val)
                episode += 1
                extra = {
                    "batt_deg_avg": (ep_bdeg / ep_len) if ep_len > 0 else 0.0,
                    "pv_curt_fair_avg": (ep_pvfair / ep_len) if ep_len > 0 else 0.0,
                }
                logger.log_train_episode(episode, ep_ret, ep_vn, ep_vd, ep_l, extra=extra)
                if episode % 10 == 0:
                    logger.flush()
                o = env.reset()
                ep_ret = ep_vn = ep_vd = ep_l = ep_bdeg = ep_pvfair = 0.0
                ep_len = 0

        norm.update(np.array(obs_batch))
        agent.update(buf)
        if (epoch + 1) % 10 == 0:
            save_ppo(agent, ckpt_dir)

    # Write out episodes logged since the last periodic flush.
    logger.flush()
    # Keep the final weights when the run length is not a multiple of 10.
    if epochs > 0 and epochs % 10 != 0:
        save_ppo(agent, ckpt_dir)


def train_ppo_lag(env, out_dir, cost_key="voltage", cost_limit=0.0,
                  steps_per_epoch=4096, epochs=200, dual_lr=0.05,
                  device="cpu"):
    """Train a single-constraint Lagrangian PPO agent.

    The single cost signal is the unweighted sum of ``voltage_num``,
    ``voltage_degree``, line overload (``thermal`` above the module-level
    ``thermal_limit``), ``batt_deg`` and ``pv_curt_fair`` read from ``info``.
    The raw environment reward is stored as the reward.

    A path ends (and the environment is reset) when the env reports done, when
    the episode reaches ``env.truncated_length // env.time_step`` steps, or at
    the last step of an epoch.

    Args:
        env: environment exposing ``observation_space``, ``action_space``,
            ``truncated_length``, ``time_step``, ``reset()`` and a ``step()``
            returning ``(obs, reward, truncated, info)``.
        out_dir: directory receiving ``train_episode.csv``, ``dual_vars.csv``
            and ``ckpt/ppo_lag.pt``.
        cost_key: currently unused — the cost is always the five-term sum
            described above. Kept for interface compatibility.
        cost_limit: constraint threshold passed to ``LagrangianPPOAgent``.
        steps_per_epoch: environment steps collected before each update.
        epochs: number of collect/update cycles.
        dual_lr: multiplier learning rate passed to the agent.
        device: torch device handed to the agent.
    """
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    agent = LagrangianPPOAgent(obs_dim, act_dim, cost_limit=cost_limit, dual_lr=dual_lr, device=device)
    buf = PPOCostBuffer(obs_dim, act_dim, steps_per_epoch, cost_dim=1)
    norm = RunningNorm(obs_dim)
    logger = SimpleLogger(out_dir)
    ckpt_dir = os.path.join(out_dir, "ckpt")

    # Bootstrap observations must be on the same device as the critics; a CPU
    # tensor fed to a CUDA network raises a device-mismatch error.
    vf_param = next(agent.ac.vf.parameters(), None)
    net_device = vf_param.device if vf_param is not None else torch.device(device)

    episode = 0
    ep_ret = ep_vn = ep_vd = ep_l = ep_bdeg = ep_pvfair = 0.0
    ep_len = 0
    global_update_step = 0
    o = env.reset()

    for epoch in range(epochs):
        obs_batch = []
        for t in range(steps_per_epoch):
            o_n = norm.normalize(o)
            obs_batch.append(o)  # raw observations feed the normaliser update
            a, v, c_v, logp = agent.act(o_n, deterministic=False)
            a = clip_action_to_space(a, env.action_space)
            o2, r, truncated, info = env.step(a)
            done = bool(truncated)

            volt_cnt = float(info.get("voltage_num", 0.0))
            volt_deg = float(info.get("voltage_degree", 0.0))
            batt_deg = float(info.get("batt_deg", 0.0))
            pv_fair = float(info.get("pv_curt_fair", 0.0))
            thermal = info.get("thermal", None)
            if thermal is None:
                line_over = 0.0
            else:
                # asarray so a list of loadings works as well as an ndarray.
                line_over = float(np.maximum(np.asarray(thermal) - thermal_limit, 0.0).sum())

            # Single aggregated cost (all weights are 1).
            c_total = volt_cnt + volt_deg + line_over + batt_deg + pv_fair
            c_vec = np.array([c_total], dtype=np.float32)

            ep_ret += r
            ep_len += 1
            # NOTE(review): volt_num / volt_deg are overwritten each step, so the
            # logged values are those of the episode's final step, not an episode
            # sum like the other accumulators — confirm this is intended.
            ep_vn = volt_cnt
            ep_vd = volt_deg
            ep_l += line_over
            ep_bdeg += batt_deg
            ep_pvfair += pv_fair

            buf.store(o_n, a, r, v, logp, c_vec, np.array([[c_v]], np.float32))
            o = o2

            timeout = ep_len == (env.truncated_length // env.time_step)
            terminal = done or timeout or (t == steps_per_epoch - 1)
            if terminal:
                if done:
                    last_v = 0.0
                    last_cv = 0.0
                else:
                    # Path was cut short: bootstrap reward and cost values.
                    with torch.no_grad():
                        obs_last = torch.as_tensor(
                            norm.normalize(o), dtype=torch.float32, device=net_device
                        ).unsqueeze(0)
                        last_v = agent.ac.vf(obs_last).item()
                        last_cv = agent.ac.vc(obs_last).item()
                buf.finish_path(last_val=last_v, last_cval=np.array([[last_cv]], np.float32))
                episode += 1
                extra = {
                    "batt_deg_avg": (ep_bdeg / ep_len) if ep_len > 0 else 0.0,
                    "pv_curt_fair_avg": (ep_pvfair / ep_len) if ep_len > 0 else 0.0,
                }
                logger.log_train_episode(episode, ep_ret, ep_vn, ep_vd, ep_l, extra=extra)
                if episode % 10 == 0:
                    logger.flush()
                o = env.reset()
                ep_ret = ep_vn = ep_vd = ep_l = ep_bdeg = ep_pvfair = 0.0
                ep_len = 0

        norm.update(np.array(obs_batch))
        stats = agent.update(buf)  # read below: stats["Jc"], stats["lambda_value"]
        global_update_step += 1
        logger.log_dual_vars(global_update_step, stats["lambda_value"])
        if (epoch + 1) % 10 == 0:
            save_ppo_lag(agent, ckpt_dir)
        print(f"[PPO-Lag] epoch {epoch+1}/{epochs} Jc={stats['Jc']:.4f} lam={stats['lambda_value']:.4f}")

    # Write out episode and dual-variable rows logged since the last periodic flush.
    logger.flush()
    # Keep the final weights when the run length is not a multiple of 10.
    if epochs > 0 and epochs % 10 != 0:
        save_ppo_lag(agent, ckpt_dir)



def train_ppo_lag_multi(env, out_dir, cost_keys, cost_limits,
                        steps_per_epoch=4096, epochs=200, dual_lr=0.05,
                        device="cpu"):
    """Train a multi-constraint Lagrangian PPO agent (one multiplier per cost).

    The per-step cost vector comes from ``extract_cost_vector(info, cost_keys)``.
    A path ends (and the environment is reset) when the env reports done, when
    the episode reaches ``env.truncated_length // env.time_step`` steps, or at
    the last step of an epoch.

    Args:
        env: environment exposing ``observation_space``, ``action_space``,
            ``truncated_length``, ``time_step``, ``reset()`` and a ``step()``
            returning ``(obs, reward, truncated, info)``.
        out_dir: directory receiving ``train_episode.csv``, ``dual_vars.csv``
            and ``ckpt/ppo_lag_multi.pt``.
        cost_keys: list[str] of keys understood by ``extract_cost_vector``,
            e.g. ``["voltage_num", "voltage_degree", "line_over"]``.
        cost_limits: list[float], one limit per entry of ``cost_keys``.
        steps_per_epoch: environment steps collected before each update.
        epochs: number of collect/update cycles.
        dual_lr: multiplier learning rate passed to the agent.
        device: torch device for the agent and for sampled observations.

    Raises:
        ValueError: if ``cost_keys`` and ``cost_limits`` differ in length.
    """
    K = len(cost_keys)
    if len(cost_limits) != K:
        raise ValueError(
            f"cost_keys has {K} entries but cost_limits has {len(cost_limits)}"
        )
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    agent = LagrangianPPOMultiAgent(
        obs_dim, act_dim, np.array(cost_limits, np.float32),
        dual_lr=dual_lr, device=device
    )
    buf = PPOCostBuffer(obs_dim, act_dim, steps_per_epoch, cost_dim=K)
    norm = RunningNorm(obs_dim)
    logger = SimpleLogger(out_dir)
    ckpt_dir = os.path.join(out_dir, "ckpt")

    episode = 0
    ep_ret = ep_vn = ep_vd = ep_l = ep_bdeg = ep_pvfair = 0.0
    ep_len = 0
    global_update_step = 0

    o = env.reset()

    for epoch in range(epochs):
        obs_batch = []
        for t in range(steps_per_epoch):
            o_n = norm.normalize(o)
            obs_batch.append(o)  # raw observations feed the normaliser update

            # One forward pass yields action, value, cost values and old log-prob.
            # Shapes per the original author's note: a (1, A), c_v (1, K), logp (1, 1).
            with torch.no_grad():
                obs_t = torch.as_tensor(o_n, dtype=torch.float32, device=device).unsqueeze(0)
                a_t, v_t, c_v_t, logp_t = agent.ac.step(obs_t)
                a = a_t.squeeze(0).cpu().numpy()
                v = float(v_t.item())
                c_v_np = c_v_t.squeeze(0).cpu().numpy()
                logp = float(logp_t.item())

            a = clip_action_to_space(a, env.action_space)
            o2, r, truncated, info = env.step(a)
            done = bool(truncated)

            c_vec = extract_cost_vector(info, cost_keys)
            ep_ret += r
            ep_len += 1

            # Accept either naming scheme for the voltage entries in ``info``.
            volt_cnt = float(info.get("voltage_num", info.get("voltage", 0.0)))
            volt_deg = float(info.get("voltage_degree", info.get("volt_deg", 0.0)))

            # NOTE(review): volt_num / volt_deg are overwritten each step, so the
            # logged values are those of the episode's final step, not an episode
            # sum like the other accumulators — confirm this is intended.
            ep_vn = volt_cnt
            ep_vd = volt_deg
            ep_bdeg += float(info.get("batt_deg", 0.0))
            ep_pvfair += float(info.get("pv_curt_fair", 0.0))

            thermal = info.get("thermal", None)
            if thermal is not None:
                # asarray so a list of loadings works as well as an ndarray.
                ep_l += float(np.maximum(np.asarray(thermal) - thermal_limit, 0.0).sum())

            buf.store(o_n, a, r, v, logp, c_vec, c_v_np[None, ...])

            o = o2
            timeout = ep_len == (env.truncated_length // env.time_step)
            terminal = done or timeout or (t == steps_per_epoch - 1)
            if terminal:
                if done:
                    # No bootstrap after a finished episode, for the reward and
                    # for every cost critic alike.
                    last_v = 0.0
                    last_cv = np.zeros((1, len(agent.ac.vc_list)), np.float32)
                else:
                    with torch.no_grad():
                        obs_last = torch.as_tensor(
                            norm.normalize(o), dtype=torch.float32, device=device
                        ).unsqueeze(0)
                        last_v = agent.ac.vf(obs_last).item()
                        last_cv = np.array(
                            [net(obs_last).item() for net in agent.ac.vc_list],
                            np.float32,
                        )[None, :]
                buf.finish_path(last_val=last_v, last_cval=last_cv)

                episode += 1
                extra = {
                    "batt_deg_avg": (ep_bdeg / ep_len) if ep_len > 0 else 0.0,
                    "pv_curt_fair_avg": (ep_pvfair / ep_len) if ep_len > 0 else 0.0,
                }
                logger.log_train_episode(episode, ep_ret, ep_vn, ep_vd, ep_l, extra=extra)
                if episode % 10 == 0:
                    logger.flush()

                o = env.reset()
                ep_ret = ep_vn = ep_vd = ep_l = ep_bdeg = ep_pvfair = 0.0
                ep_len = 0

        norm.update(np.array(obs_batch))
        stats = agent.update(buf)  # read below: stats["Jc"], stats["lambdas"]
        global_update_step += 1
        logger.log_dual_vars(global_update_step, stats["lambdas"])

        if (epoch + 1) % 10 == 0:
            save_ppo_lag_multi(agent, ckpt_dir)

        print(f"[PPO-Lag-M] epoch {epoch+1}/{epochs} Jc={stats['Jc']} lambdas={stats['lambdas']}")

    # Write out episode and dual-variable rows logged since the last periodic flush.
    logger.flush()
    # Keep the final weights when the run length is not a multiple of 10.
    if epochs > 0 and epochs % 10 != 0:
        save_ppo_lag_multi(agent, ckpt_dir)



# ---------- CLI ----------
def main():
    """Command-line entry point: parse options, build the env, run one trainer.

    The per-constraint cost limits always come from the built-in ``para_set``
    table (row chosen by ``--para_set``); the single-constraint limit is the
    sum of that row. Values given via ``--cost_limits`` / ``--cost_limit`` are
    therefore overwritten.

    Output goes to ``<out>/<algo>`` for ``ppo`` and to
    ``<out>/<para_set>/<algo>`` for the Lagrangian variants.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--algo", choices=["ppo", "ppo_lag", "ppo_lag_multi"], required=True)
    ap.add_argument("--out", default="/root/L_MCPPO_ICLR/results/train/")
    # Parsed as int so the default (0) and user-supplied values behave alike.
    ap.add_argument("--para_set", type=int, default=0)
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    # Tune the options below as needed.
    ap.add_argument("--ppo_epochs", type=int, default=2000)
    ap.add_argument("--ppo_steps_per_epoch", type=int, default=288)
    ap.add_argument("--off_total_steps", type=int, default=288*2000)
    ap.add_argument("--off_start_steps", type=int, default=288*0)
    ap.add_argument("--off_batch", type=int, default=256)
    ap.add_argument("--dual_lr", type=float, default=5e-3)
    ap.add_argument("--cost_key", default="voltage", help="for ppo_lag, single key")
    ap.add_argument("--cost_limit", type=float, default=0.9)
    ap.add_argument("--cost_keys",
                    default="voltage_num,voltage_degree,line_over,batt_deg,pv_curt_fair",
                    help="for ppo_lag_multi, comma-separated")
    ap.add_argument("--cost_limits", default=[0.05, 0.05, 0.1, 0.4, 0.3])

    args_cli = ap.parse_args()

    # Author's note: realistically there are about three families of settings —
    # a general one that targets voltage while everything else is loose, one
    # that is tight on CBESS, and one that is tight on PV.
    # Earlier table, kept for reference:
    # para_set = [[0.05,0.05,0.1,0.6,0.7],
    #             [0.1,0.1,0.1,0.6,0.7],
    #             [0.15,0.15,0.1,0.6,0.7],
    #             [0.2,0.2,0.1,0.6,0.7],   # loosest
    #             [0.05,0.05,0.1,0.45,0.7],
    #             [0.1,0.1,0.1,0.45,0.7],
    #             [0.15,0.15,0.1,0.45,0.7],
    #             [0.2,0.2,0.1,0.45,0.7],
    #             [0.05,0.05,0.1,0.6,0.5],
    #             [0.1,0.1,0.1,0.6,0.5],
    #             [0.15,0.15,0.1,0.6,0.5],
    #             [0.2,0.2,0.1,0.6,0.5],
    #             [0.05,0.05,0.1,0.45,0.5],  # tightest
    #             ]

    # One row of five cost limits per parameter set; the column order
    # presumably follows the default --cost_keys — verify if keys are changed.
    para_set = [
        [9, 9, 0.1, 30, 30],
        [12, 12, 0.1, 30, 30],
        [15, 15, 0.1, 30, 30],
        [18, 18, 0.1, 30, 30],  # loosest
        [9, 9, 0.1, 20, 30],
        [12, 12, 0.1, 20, 30],
        [15, 15, 0.1, 20, 30],
        [18, 18, 0.1, 20, 30],
        [9, 9, 0.1, 30, 20],
        [12, 12, 0.1, 30, 20],
        [15, 15, 0.1, 30, 20],
        [18, 18, 0.1, 30, 20],
        [9, 9, 0.1, 20, 20],    # all tight
        [7, 7, 0.1, 15, 15],    # tightest
    ]

    # Fail fast with a usage message instead of an IndexError traceback.
    try:
        args_cli.cost_limits = para_set[args_cli.para_set]
    except IndexError:
        ap.error(f"--para_set must index one of the {len(para_set)} built-in rows")
    args_cli.cost_limit = np.sum(args_cli.cost_limits)

    args_env = Args()  # TODO: replace with your environment-parameter generator
    env = CBEnv(args_env)

    # Plain PPO does not use the cost limits, so its output is not split by
    # parameter set.
    if args_cli.algo == "ppo":
        run_dir = os.path.join(args_cli.out, args_cli.algo)
    else:
        run_dir = os.path.join(args_cli.out, str(args_cli.para_set), args_cli.algo)
    os.makedirs(run_dir, exist_ok=True)
    print(f"Save & logs => {run_dir}")

    if args_cli.algo == "ppo":
        train_ppo(env, run_dir, steps_per_epoch=args_cli.ppo_steps_per_epoch,
                  epochs=args_cli.ppo_epochs, device=args_cli.device)
    elif args_cli.algo == "ppo_lag":
        train_ppo_lag(env, run_dir, cost_key=args_cli.cost_key,
                      cost_limit=args_cli.cost_limit, steps_per_epoch=args_cli.ppo_steps_per_epoch,
                      epochs=args_cli.ppo_epochs, dual_lr=args_cli.dual_lr, device=args_cli.device)
    elif args_cli.algo == "ppo_lag_multi":
        keys = [s.strip() for s in args_cli.cost_keys.split(",") if s.strip()]
        limits = args_cli.cost_limits
        if len(keys) != len(limits):
            ap.error("cost_keys 与 cost_limits 数量不一致")
        train_ppo_lag_multi(env, run_dir, cost_keys=keys, cost_limits=limits,
                            steps_per_epoch=args_cli.ppo_steps_per_epoch, epochs=args_cli.ppo_epochs,
                            dual_lr=args_cli.dual_lr, device=args_cli.device)

if __name__=="__main__":
    main()
