from __future__ import annotations

import argparse
import os
from contextlib import ExitStack
from typing import Any, Dict

import numpy as np
import torch
import yaml

from gdc_project.utils.helpers import set_seed, make_env
from gdc_project.utils.logger import Logger
from gdc_project.utils.normalizer import RunningNorm
from gdc_project.gdc.replay_buffer import ReplayBuffer
from gdc_project.gdc.sac_agent import SACAgent, SACConfig
from gdc_project.gdc.gdc_sac_agent import GDCSACAgent, GDCSACConfig
from gdc_project.gdc.lagrangian_sac_agent import LagrangianSACAgent, LagrangianSACConfig


def load_config(path: str) -> Dict[str, Any]:
    """Load a YAML config file, resolving an optional ``inherit`` chain.

    If the file has a top-level ``inherit`` key, the referenced file is loaded
    first (recursively, so a base config may itself inherit from another) and
    this file's keys are shallow-merged on top of it. The ``inherit`` key is
    not part of the returned mapping. An empty YAML file is treated as ``{}``.

    A relative ``inherit`` path is used as given (relative to the current
    working directory) when that file exists; otherwise it is resolved
    relative to the directory of the file that declares it.

    Args:
        path: Path to the YAML config file.

    Returns:
        The merged configuration dictionary.

    Raises:
        FileNotFoundError: If a config file in the chain does not exist.
        ValueError: If a file's top level is not a mapping, or if the
            inheritance chain is cyclic.
    """
    return _load_config_chain(path, frozenset())


def _load_config_chain(path: str, seen: frozenset) -> Dict[str, Any]:
    """Recursive worker for ``load_config``; ``seen`` holds already-visited real paths."""
    real_path = os.path.realpath(path)
    if real_path in seen:
        raise ValueError(f"Cyclic config inheritance involving {path!r}")
    seen = seen | {real_path}

    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty document.
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        raise ValueError(f"Config {path!r} must contain a mapping at the top level")

    inherit = cfg.pop("inherit", None)
    if not inherit:
        return cfg

    parent = str(inherit)
    if not os.path.isabs(parent) and not os.path.exists(parent):
        # Fall back to resolving relative to the inheriting file's directory.
        candidate = os.path.join(os.path.dirname(path), parent)
        if os.path.exists(candidate):
            parent = candidate

    base = _load_config_chain(parent, seen)
    base.update(cfg)
    return base


def _parse_args() -> argparse.Namespace:
    """Parse command-line arguments; CLI values override matching config keys.

    Exits with a usage error if ``--resume`` names a file that does not exist,
    instead of silently starting a fresh run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    parser.add_argument("--logdir", type=str, default="runs/gdc")
    parser.add_argument("--save_every", type=int, default=None, help="Override save_every in config")
    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume from")
    parser.add_argument("--env_id", type=str, default=None, help="Override env_id")
    parser.add_argument("--seed", type=int, default=None, help="Override seed")
    parser.add_argument("--num_train_steps", type=int, default=None, help="Override number of training steps")
    args = parser.parse_args()
    if args.resume and not os.path.isfile(args.resume):
        parser.error(f"--resume checkpoint not found: {args.resume}")
    return args


def _build_agent(
    agent_name: str,
    cfg: Dict[str, Any],
    obs_dim: int,
    act_dim: int,
    device: torch.device,
) -> "SACAgent | GDCSACAgent | LagrangianSACAgent":
    """Construct the agent selected by ``agent_name`` from config values.

    ``"a_gdc_sac"`` builds a GDCSACAgent. ``"pcpo"`` and ``"focops"`` build a
    LagrangianSACAgent, with ``use_kl`` defaulting to True only for ``"focops"``.
    Any other name builds a plain SACAgent; a warning is printed for names
    other than ``"sac"``. The defaults for ``critic_spectral_norm`` and
    ``grad_clip_norm`` differ per agent type.
    """
    # Hyperparameters shared by every SAC variant.
    common: Dict[str, Any] = dict(
        obs_dim=obs_dim,
        act_dim=act_dim,
        hidden_dims=tuple(cfg.get("hidden_dims", [256, 256])),
        gamma=float(cfg.get("gamma", 0.99)),
        tau=float(cfg.get("tau", 0.005)),
        actor_lr=float(cfg.get("actor_lr", 3e-4)),
        critic_lr=float(cfg.get("critic_lr", 3e-4)),
        alpha_lr=float(cfg.get("alpha_lr", 3e-4)),
        target_entropy_scale=float(cfg.get("target_entropy_scale", 1.0)),
        actor_spectral_norm=bool(cfg.get("actor_spectral_norm", False)),
    )

    if agent_name == "a_gdc_sac":
        return GDCSACAgent(
            GDCSACConfig(
                **common,
                curvature_weight=float(cfg.get("curvature_weight", 1.0)),
                sigmoid_slope=float(cfg.get("sigmoid_slope", 1.0)),
                lanczos_steps=int(cfg.get("lanczos_steps", 6)),
                adaptive_rate=float(cfg.get("adaptive_rate", 0.05)),
                cost_ema_decay=float(cfg.get("cost_ema_decay", 0.05)),
                target_cost=float(cfg.get("target_cost", 0.01)),
                init_kappa0=float(cfg.get("init_kappa0", 0.0)),
                bootstrap_risk_weight=float(cfg.get("bootstrap_risk_weight", 0.0)),
                curvature_method=str(cfg.get("curvature_method", "power")),
                curvature_mode=str(cfg.get("curvature_mode", "action")),
                curvature_state_weight=float(cfg.get("curvature_state_weight", 1.0)),
                curvature_action_weight=float(cfg.get("curvature_action_weight", 1.0)),
                critic_spectral_norm=bool(cfg.get("critic_spectral_norm", True)),
                grad_clip_norm=float(cfg.get("grad_clip_norm", 0.0)),
            ),
            device,
        )

    if agent_name in ("pcpo", "focops"):
        return LagrangianSACAgent(
            LagrangianSACConfig(
                **common,
                cost_limit=float(cfg.get("cost_limit", 0.02)),
                lambda_lr=float(cfg.get("lambda_lr", 0.05)),
                use_kl=bool(cfg.get("use_kl", agent_name == "focops")),
                kl_target=float(cfg.get("kl_target", 0.01)),
                kl_lr=float(cfg.get("kl_lr", 0.05)),
                critic_spectral_norm=bool(cfg.get("critic_spectral_norm", True)),
                grad_clip_norm=float(cfg.get("grad_clip_norm", 1.0)),
            ),
            device,
        )

    if agent_name != "sac":
        # Don't let a typo silently run a different algorithm without notice.
        print(f"Warning: unknown agent_name {agent_name!r}; falling back to plain SAC")
    return SACAgent(
        SACConfig(
            **common,
            critic_spectral_norm=bool(cfg.get("critic_spectral_norm", False)),
            grad_clip_norm=float(cfg.get("grad_clip_norm", 0.0)),
        ),
        device,
    )


def _save_checkpoint(agent: Any, ckpt_dir: str, step: int) -> None:
    """Write ``agent.state_dict()`` to ``<ckpt_dir>/step_<step>.pt``."""
    path = os.path.join(ckpt_dir, f"step_{step}.pt")
    torch.save(agent.state_dict(), path)
    print(f"Saved checkpoint to {path}")


def main() -> None:
    """Train a SAC-family agent on one environment as described by a YAML config.

    CLI flags (``--env_id``, ``--seed``, ``--num_train_steps``, ``--save_every``)
    override the matching config keys. Episode statistics and the dicts
    returned by ``agent.update`` are logged under ``<logdir>/<agent_name>``.
    Checkpoints are written to its ``checkpoints`` subdirectory every
    ``save_every`` steps (0 disables). When checkpointing is enabled and the run
    length is not a multiple of ``save_every``, a final checkpoint is also written.

    Observations are stored raw in the replay buffer. They are normalized with
    the current running statistics both when selecting actions and when
    sampling a training batch.

    NOTE: the config's ``eval_every`` key is currently ignored; no evaluation
    loop is implemented.
    """
    args = _parse_args()
    cfg = load_config(args.config)

    seed = int(cfg.get("seed", 42)) if args.seed is None else int(args.seed)
    # A configured device is honoured only when CUDA is available; otherwise CPU.
    device = torch.device(cfg.get("device", "cpu") if torch.cuda.is_available() else "cpu")
    env_id = args.env_id or cfg.get("env_id", "OptimalTrap-v0")
    set_seed(seed)

    num_steps = (
        int(cfg.get("num_train_steps", 200_000))
        if args.num_train_steps is None
        else int(args.num_train_steps)
    )
    start_steps = int(cfg.get("start_steps", 1_000))
    update_after = int(cfg.get("update_after", 1_000))
    update_every = int(cfg.get("update_every", 1))
    max_ep_steps = int(cfg.get("max_episode_steps", 500))
    batch_size = int(cfg.get("batch_size", 256))
    reward_scale = float(cfg.get("reward_scale", 1.0))
    save_every = int(cfg.get("save_every", 0)) if args.save_every is None else args.save_every

    # ExitStack guarantees the logger and env are closed even if training raises.
    with ExitStack() as cleanup:
        env = make_env(env_id, seed)
        cleanup.callback(env.close)
        # Seed the action space so the warm-up random actions are reproducible.
        env.action_space.seed(seed)
        obs_dim = int(np.prod(env.observation_space.shape))
        act_dim = int(np.prod(env.action_space.shape))

        buffer = ReplayBuffer(obs_dim, act_dim, size=int(cfg.get("buffer_size", 1_000_000)))

        agent_name = str(cfg.get("agent_name", "sac")).lower()
        agent = _build_agent(agent_name, cfg, obs_dim, act_dim, device)

        logdir = os.path.join(args.logdir, agent_name)
        logger = Logger(logdir)
        cleanup.callback(logger.close)
        ckpt_dir = os.path.join(logdir, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)

        # _parse_args already rejected a missing --resume file.
        if args.resume:
            sd = torch.load(args.resume, map_location=device)
            try:
                agent.load_state_dict(sd)
                print(f"Resumed from {args.resume}")
            except Exception as e:
                print(f"Failed to resume: {e}")

        obs_norm = RunningNorm()
        obs, _ = env.reset()
        obs_norm.update(obs[None, :])
        ep_ret, ep_len, ep_cost = 0.0, 0, 0.0

        for t in range(1, num_steps + 1):
            # Uniform random actions during warm-up, then the stochastic policy.
            if t < start_steps:
                act = env.action_space.sample()
            else:
                act = agent.select_action(obs_norm.normalize(obs), deterministic=False)

            next_obs, rew, term, trunc, info = env.step(act)
            cost = info.get("cost", 0.0)

            # Store only true termination as the "done" flag: a time-limit
            # truncation (env-side or max_ep_steps) must still bootstrap from
            # next_obs, otherwise the critic learns that timeouts are terminal.
            buffer.add(
                obs,
                act,
                np.array([float(rew) * reward_scale], dtype=np.float32),
                next_obs,
                np.array([term], dtype=np.float32),
                np.array([cost], dtype=np.float32),
            )
            obs_norm.update(next_obs[None, :])
            obs = next_obs
            ep_ret += rew
            ep_cost += cost
            ep_len += 1

            if term or trunc or ep_len >= max_ep_steps:
                logger.log({"ep/return": ep_ret, "ep/len": ep_len, "ep/cost": ep_cost}, step=t)
                obs, _ = env.reset()
                # Include reset observations in the normalizer, like the first one.
                obs_norm.update(obs[None, :])
                ep_ret, ep_len, ep_cost = 0.0, 0, 0.0

            if t >= update_after and t % update_every == 0:
                batch = buffer.sample(batch_size)
                # The buffer holds raw observations; normalize with current stats.
                batch.obs = obs_norm.normalize(batch.obs)
                batch.next_obs = obs_norm.normalize(batch.next_obs)
                logger.log(agent.update(batch), step=t)

            if save_every and t % save_every == 0:
                _save_checkpoint(agent, ckpt_dir, t)

        # Make sure the final weights are persisted when checkpointing is enabled.
        if save_every > 0 and num_steps > 0 and num_steps % save_every != 0:
            _save_checkpoint(agent, ckpt_dir, num_steps)

# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
