import os, argparse, numpy as np, pandas as pd, torch
from algos import *

from args import Args
from env_core_cb_ev_curt import CBEnv

# ---------- Checkpoint loading ----------
def load_ppo(agent, ckpt_dir, device="cpu"):
    """Restore a PPO agent's networks from ``<ckpt_dir>/ppo.pt``.

    The checkpoint is read with ``torch.load(..., map_location=device)`` and
    must contain the keys ``"pi"``, ``"log_std"`` and ``"vf"``. The policy and
    value networks on ``agent.ac`` are loaded via ``load_state_dict``; the
    ``log_std`` tensor is moved to ``device`` and assigned to
    ``agent.ac.log_std.data``. The agent is modified in place.
    """
    ckpt_path = os.path.join(ckpt_dir, "ppo.pt")
    checkpoint = torch.load(ckpt_path, map_location=device)

    actor_critic = agent.ac
    actor_critic.pi.load_state_dict(checkpoint["pi"])
    actor_critic.log_std.data = checkpoint["log_std"].to(device)
    actor_critic.vf.load_state_dict(checkpoint["vf"])

def load_td3(agent, ckpt_dir, device="cpu"):
    """Restore a TD3 agent's networks from ``<ckpt_dir>/td3.pt``.

    The checkpoint must contain ``"actor"``, ``"critic"``, ``"actor_t"`` and
    ``"critic_t"``; each is loaded into the agent attribute of the same name.
    The agent is modified in place.
    """
    checkpoint = torch.load(
        os.path.join(ckpt_dir, "td3.pt"), map_location=device
    )
    # Agent attribute names and checkpoint keys coincide for TD3.
    for name in ("actor", "critic", "actor_t", "critic_t"):
        getattr(agent, name).load_state_dict(checkpoint[name])

def load_sac(agent, ckpt_dir, device="cpu"):
    """Restore a SAC agent's networks from ``<ckpt_dir>/sac.pt``.

    Checkpoint keys ``"actor"``, ``"critic"`` and ``"critic_t"`` are loaded
    into ``agent.actor``, ``agent.q`` and ``agent.q_t`` respectively;
    ``"log_alpha"`` is moved to ``device`` and assigned to
    ``agent.log_alpha.data``. The agent is modified in place.
    """
    state = torch.load(os.path.join(ckpt_dir, "sac.pt"), map_location=device)

    # (agent attribute, checkpoint key) - the critic names differ between the two.
    module_keys = (("actor", "actor"), ("q", "critic"), ("q_t", "critic_t"))
    for attr, key in module_keys:
        getattr(agent, attr).load_state_dict(state[key])

    agent.log_alpha.data = state["log_alpha"].to(device)

def load_ppo_lag(agent, ckpt_dir, device="cpu"):
    """Restore a single-constraint Lagrangian PPO agent from ``<ckpt_dir>/ppo_lag.pt``.

    Loads ``"pi"``, ``"vf"`` and ``"vc"`` into the matching networks on
    ``agent.ac``, assigns ``"log_std"`` (moved to ``device``) to
    ``agent.ac.log_std.data``, and rebuilds the multiplier from ``"lambda"``
    as a float32 tensor on ``device`` assigned to ``agent.lam.data``.
    The agent is modified in place.
    """
    ckpt_path = os.path.join(ckpt_dir, "ppo_lag.pt")
    checkpoint = torch.load(ckpt_path, map_location=device)

    actor_critic = agent.ac
    actor_critic.pi.load_state_dict(checkpoint["pi"])
    actor_critic.log_std.data = checkpoint["log_std"].to(device)
    actor_critic.vf.load_state_dict(checkpoint["vf"])
    actor_critic.vc.load_state_dict(checkpoint["vc"])

    multiplier = torch.tensor(
        checkpoint["lambda"], dtype=torch.float32, device=device
    )
    agent.lam.data = multiplier

def load_ppo_lag_multi(agent, ckpt_dir, device="cpu"):
    """Restore a multi-constraint Lagrangian PPO agent from ``<ckpt_dir>/ppo_lag_multi.pt``.

    Loads ``"pi"`` and ``"vf"`` into the matching networks on ``agent.ac``,
    assigns ``"log_std"`` (moved to ``device``) to ``agent.ac.log_std.data``,
    loads each entry of ``"vc_list"`` into the corresponding cost critic in
    ``agent.ac.vc_list``, and assigns ``"lambdas"`` (as a float32 tensor on
    ``device``) to ``agent.lam.data``. The agent is modified in place.

    Raises:
        ValueError: if the number of cost critics on the agent differs from
            the number stored in the checkpoint. The check runs before any
            weights are loaded, so the agent is left untouched on failure.
    """
    checkpoint = torch.load(
        os.path.join(ckpt_dir, "ppo_lag_multi.pt"), map_location=device
    )

    cost_critics = agent.ac.vc_list
    cost_critic_states = checkpoint["vc_list"]
    # zip() would silently stop at the shorter sequence, leaving some critics
    # unloaded (or some saved critics unused) when the agent was built with a
    # different number of constraints than it was trained with.
    if len(cost_critics) != len(cost_critic_states):
        raise ValueError(
            f"Checkpoint holds {len(cost_critic_states)} cost critics but the "
            f"agent has {len(cost_critics)}; construct the agent with the same "
            "number of cost limits used during training."
        )

    agent.ac.pi.load_state_dict(checkpoint["pi"])
    agent.ac.log_std.data = checkpoint["log_std"].to(device)
    agent.ac.vf.load_state_dict(checkpoint["vf"])
    for critic, critic_state in zip(cost_critics, cost_critic_states):
        critic.load_state_dict(critic_state)
    agent.lam.data = torch.as_tensor(
        checkpoint["lambdas"], dtype=torch.float32, device=device
    )


# ---------- Evaluation: export price / bus / line / battery time series ----------
def _select_action(agent, obs, deterministic):
    """Ask ``agent`` for an action and return only the action.

    ``act`` is unpacked as a 3-tuple for ``PPOAgent``, a 4-tuple for
    ``LagrangianPPOAgent`` (the action being the first element in both), and
    is used as-is for ``LagrangianPPOMultiAgent``.

    Raises:
        TypeError: if ``agent`` is none of the supported agent classes.
    """
    if isinstance(agent, PPOAgent):
        action, _, _ = agent.act(obs, deterministic=deterministic)
    elif isinstance(agent, LagrangianPPOAgent):
        action, _, _, _ = agent.act(obs, deterministic=deterministic)
    elif isinstance(agent, LagrangianPPOMultiAgent):
        action = agent.act(obs, deterministic=deterministic)
    else:
        # Previously this fell through and failed later with an unbound local.
        raise TypeError(f"Unsupported agent type: {type(agent).__name__}")
    return action


def evaluate_one_day(env, agent, out_dir, device="cpu", deterministic=True):
    """Roll out ``agent`` in ``env`` for 288 steps and dump per-step CSVs.

    The environment is reset once, then stepped 288 times. After each step the
    function records the action, the reward and cost terms from ``info``, the
    current prices, and bus / line / storage values read from ``env.network``.

    Files written into ``out_dir`` (created if missing):
        eval_price.csv, eval_bus_timeseries.csv, eval_line_timeseries.csv,
        eval_batt_timeseries.csv, eval_actions_cb.csv, eval_actions_curt.csv,
        eval_reward_cost_timeseries.csv

    In addition one row of per-episode sums and means is appended to
    ``eval_episode_summary.csv`` in the parent directory of ``out_dir``; the
    header is written only when that file does not exist yet.

    Args:
        env: environment exposing ``reset``/``step``, ``action_space``,
            ``num_cb``, ``total_buses``, ``network``, ``current_purchase``,
            ``current_sell``, ``cb_bus_id`` and optionally ``trading_amount``.
        agent: a ``PPOAgent``, ``LagrangianPPOAgent`` or
            ``LagrangianPPOMultiAgent``.
        out_dir: directory for this day's CSVs; its last path component is
            used as the ``day`` label.
        device: unused here; kept so existing callers keep working.
        deterministic: forwarded to ``agent.act``.

    Returns:
        The observation returned by the final ``env.step`` call.

    Raises:
        TypeError: if ``agent`` is not one of the supported agent classes.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Normalise before splitting so a trailing separator ("runs/day_1/") does
    # not yield an empty day label or put the summary inside the day folder.
    norm_dir = os.path.normpath(out_dir)
    day_id = os.path.basename(norm_dir)
    ep_summary_path = os.path.join(
        os.path.dirname(norm_dir), "eval_episode_summary.csv"
    )

    price_rows, bus_rows, line_rows, batt_rows = [], [], [], []
    actions_cb_rows = []     # per CB: t, cb_id, cd, q, ratio
    actions_curt_rows = []   # per non-slack bus: t, bus, curt_ratio
    step_metric_rows = []    # per step: reward and cost terms

    ep_reward_sum = 0.0
    ep_cost_sums = dict(
        voltage_num=0.0, voltage_degree=0.0,
        line_over=0.0, batt_deg=0.0, pv_curt_fair=0.0, cost_total=0.0
    )

    # The original comment says the training script uses thermal_limit=0;
    # the same threshold is applied here.
    thermal_limit = 0.0

    o = env.reset()

    steps = 288  # one day at 5-minute resolution (per the original comment)
    for t in range(steps):
        a = _select_action(agent, o, deterministic)
        a = clip_action_to_space(a, env.action_space)

        # --- Record the action ---
        # The first 3*num_cb entries are read as (cd, q, ratio) per CB.
        for i in range(env.num_cb):
            base = 3 * i
            actions_cb_rows.append({
                "t": t, "cb_id": i,
                "cd": float(a[base + 0]),
                "q": float(a[base + 1]),
                "ratio": float(a[base + 2]),
            })

        # The following entries are read as PV curtailment ratios: the k-th
        # one is attributed to total_buses[1 + k] (the first bus is skipped).
        curt_start = 3 * env.num_cb
        for k, bid in enumerate(env.total_buses[1:]):
            actions_curt_rows.append({
                "t": t,
                "bus": int(bid),
                "curt_ratio": float(a[curt_start + k]),
            })

        # NOTE(review): the third return value (episode-end flag) is not
        # inspected; the loop always runs `steps` iterations. Confirm the
        # environment's episode is at least that long.
        o2, r, _episode_end, info = env.step(a)

        # --- Reward / cost terms ---
        v_num = float(info.get("voltage_num", 0.0))
        v_deg = float(info.get("voltage_degree", 0.0))
        # line_over is the non-negative part of info['thermal'] above the limit.
        line_over = float(
            max(float(info.get("thermal", 0.0)) - thermal_limit, 0.0)
        )
        b_deg = float(info.get("batt_deg", 0.0))
        pv_fair = float(info.get("pv_curt_fair", 0.0))
        cost_total = v_num + v_deg + line_over + b_deg + pv_fair

        step_metric_rows.append({
            "day": day_id, "t": t,
            "reward": float(r),
            "voltage_num": v_num, "voltage_degree": v_deg,
            "line_over": line_over, "batt_deg": b_deg, "pv_curt_fair": pv_fair,
            "cost_total": cost_total
        })

        ep_reward_sum += float(r)
        ep_cost_sums["voltage_num"] += v_num
        ep_cost_sums["voltage_degree"] += v_deg
        ep_cost_sums["line_over"] += line_over
        ep_cost_sums["batt_deg"] += b_deg
        ep_cost_sums["pv_curt_fair"] += pv_fair
        ep_cost_sums["cost_total"] += cost_total

        # --- Prices ---
        price_rows.append({
            "t": t,
            "purchase": env.current_purchase,
            "sell": env.current_sell,
        })

        # --- Buses ---
        # Result tables are re-read every step because stepping updates them.
        v = env.network.res_bus['vm_pu'].to_numpy()
        pv = env.network.res_sgen['p_mw'].to_numpy()
        pl = env.network.res_load['p_mw'].to_numpy()
        ql = env.network.res_load['q_mvar'].to_numpy()
        for bid in env.total_buses:
            # sgen/load rows are looked up at position bus-1; bus 0 and any
            # bus beyond the table length get NaN.
            elem = bid - 1
            bus_rows.append({
                "t": t, "bus": int(bid),
                "v_pu": float(v[bid]),
                "pv_mw": float(pv[elem]) if bid > 0 and elem < len(pv) else np.nan,
                "p_load_mw": float(pl[elem]) if bid > 0 and elem < len(pl) else np.nan,
                "q_load_mvar": float(ql[elem]) if bid > 0 and elem < len(ql) else np.nan,
            })

        # --- Lines ---
        res_line = env.network.res_line
        ln = env.network.line
        for idx in res_line.index:
            line_rows.append({
                "t": t, "line_idx": int(idx),
                "from_bus": int(ln.at[idx, "from_bus"]),
                "to_bus": int(ln.at[idx, "to_bus"]),
                "loading_percent": float(res_line.at[idx, "loading_percent"]),
                "p_from_mw": float(res_line.at[idx, "p_from_mw"]),
                "p_to_mw": float(res_line.at[idx, "p_to_mw"]),
                "q_from_mvar": float(res_line.at[idx, "q_from_mvar"]),
                "q_to_mvar": float(res_line.at[idx, "q_to_mvar"]),
            })

        # --- Batteries (storage rows 0..num_cb-1, by position) ---
        storage = env.network.storage
        has_trading = hasattr(env, "trading_amount")
        for i in range(env.num_cb):
            unit = storage.iloc[i]
            trd = float(env.trading_amount[i]) if has_trading else np.nan
            batt_rows.append({
                "t": t, "cb_id": i, "bus_id": int(env.cb_bus_id.get(i, -1)),
                "soc": float(unit["soc_percent"]),
                "p_mw": float(unit["p_mw"]),
                "q_mvar": float(unit["q_mvar"]),
                "trading_kWh": trd,
            })

        o = o2

    # --- Write the per-day CSVs ---
    per_day_tables = (
        ("eval_price.csv", price_rows),
        ("eval_bus_timeseries.csv", bus_rows),
        ("eval_line_timeseries.csv", line_rows),
        ("eval_batt_timeseries.csv", batt_rows),
        ("eval_actions_cb.csv", actions_cb_rows),
        ("eval_actions_curt.csv", actions_curt_rows),
        ("eval_reward_cost_timeseries.csv", step_metric_rows),
    )
    for filename, rows in per_day_tables:
        pd.DataFrame(rows).to_csv(os.path.join(out_dir, filename), index=False)

    # --- Append this episode's summary to the shared file ---
    ep_row = {
        "day": day_id,
        "reward_sum": ep_reward_sum,
        "cost_total_sum": ep_cost_sums["cost_total"],
        "voltage_num_sum": ep_cost_sums["voltage_num"],
        "voltage_degree_sum": ep_cost_sums["voltage_degree"],
        "line_over_sum": ep_cost_sums["line_over"],
        "batt_deg_sum": ep_cost_sums["batt_deg"],
        "pv_curt_fair_sum": ep_cost_sums["pv_curt_fair"],
        # Per-step means over the fixed horizon.
        "reward_mean": ep_reward_sum / steps,
        "cost_total_mean": ep_cost_sums["cost_total"] / steps
    }
    header_needed = not os.path.exists(ep_summary_path)
    pd.DataFrame([ep_row]).to_csv(
        ep_summary_path, mode="a", index=False, header=header_needed
    )

    return o


def main():
    """Command-line entry point: load a checkpoint and evaluate it day by day.

    Parses ``--algo``, ``--ckpt``, ``--out``, ``--device`` and ``--days``,
    builds the environment from ``Args()``, constructs the agent for the
    chosen algorithm, restores its weights from ``--ckpt``, and calls
    ``evaluate_one_day`` once per day, writing to
    ``<out>/<algo>/day_<k>`` for k = 1..days.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--algo", choices=["ppo", "ppo_lag", "ppo_lag_multi"], required=True)
    ap.add_argument("--ckpt", required=True, help="checkpoint dir, e.g., runs/exp1/ckpt")
    ap.add_argument("--out", default="/root/L_MCPPO_ICLR/results/test/")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--days", type=int, default=100)

    args = ap.parse_args()

    args_env = Args()
    env = CBEnv(args_env)

    # Per-algorithm output root; every day's folder is created beneath it.
    out_root = os.path.join(args.out, args.algo)
    os.makedirs(out_root, exist_ok=True)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    device = args.device

    if args.algo == "ppo":
        agent = PPOAgent(obs_dim, act_dim, device=device)
        load_ppo(agent, args.ckpt, device=device)
    elif args.algo == "ppo_lag":
        agent = LagrangianPPOAgent(obs_dim, act_dim, cost_limit=0.0, device=device)
        load_ppo_lag(agent, args.ckpt, device=device)
    elif args.algo == "ppo_lag_multi":
        # The constructor requires cost limits. Per the original author's
        # note they do not influence actions at evaluation time, so zeros are
        # passed as a placeholder.
        # NOTE(review): the length (3) presumably has to match the number of
        # costs used in training - confirm against the training script.
        dummy_limits = np.zeros(3, dtype=np.float32)
        agent = LagrangianPPOMultiAgent(obs_dim, act_dim, dummy_limits, device=device)
        load_ppo_lag_multi(agent, args.ckpt, device=device)

    for d in range(args.days):
        # Build from out_root rather than concatenating args.out + args.algo:
        # the concatenation only matched out_root when --out ended with a
        # path separator, otherwise results landed in e.g. "results/testppo".
        out_d = os.path.join(out_root, f"day_{d+1}")
        evaluate_one_day(env, agent, out_d, device=device, deterministic=True)
        print(f"Saved day {d+1} csvs to: {out_d}")

# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
