import json
import os
import pdb
import pickle
import time

import numpy as np
import torch as th

import logger
from model_free_vec import DDPG
from replay_memory import Transition
from utils import make_env


def _evaluate_policy(agent, device, n_episodes, n_steps, scenario_name="simple_spread_n15"):
    """Roll out the agent in a newly created environment and return the mean episode reward.

    Args:
        agent: Object exposing ``select_action(obs, action_noise=..., param_noise=...)``.
        device: Torch device string the observations are moved to.
        n_episodes: Number of evaluation episodes to run.
        n_steps: Maximum number of environment steps per episode.
        scenario_name: Scenario name forwarded to ``make_env``.

    Returns:
        Mean over episodes of the per-episode reward, where each episode
        reward is the sum of ``reward_n`` over all steps taken.
    """
    eval_env = make_env(scenario_name=scenario_name, arglist=None)
    eval_rewards = []
    for _ in range(n_episodes):
        obs_n = eval_env.reset()
        episode_reward = 0
        for _ in range(n_steps):
            obs_n = th.Tensor(obs_n).to(device)
            # NOTE(review): action noise is enabled during evaluation -- confirm this is intended.
            action_n = agent.select_action(obs_n, action_noise=True, param_noise=False)
            action_n = action_n.squeeze().cpu().numpy()
            next_obs_n, reward_n, done_n, _info = eval_env.step(action_n)
            episode_reward += np.sum(reward_n)
            obs_n = next_obs_n
            # Only the first entry of done_n is checked to end the episode.
            if done_n[0]:
                break
        eval_rewards.append(episode_reward)
    return np.mean(eval_rewards)


def train(cfg):
    """Train a DDPG agent offline from a pickled replay memory, evaluating periodically.

    Args:
        cfg: Dict of settings. Required keys: "log_dir", "run_id", "gamma",
            "tau", "hidden_size", "actor_lr", "critic_lr", "actor_type",
            "critic_type", "penalty", "device", "n_iters", "batch_size",
            "n_eval_episodes", "logging_freq", "eval_freq".
            Optional key: "replay_path" (defaults to
            "./ckpt_plot/offline-replay/replay.pkl").

    Side effects:
        Writes ``config.json`` and the logger outputs under
        ``<log_dir>/run_<run_id>``.
    """
    # Configure logger
    run_dir = os.path.join(cfg["log_dir"], f"run_{cfg['run_id']}")
    logger.configure(dir=run_dir, format_strs=["csv", "stdout"])
    # Do not rely on logger.configure having created the directory.
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=4, default=str)

    # Load data
    replay_path = cfg.get("replay_path", "./ckpt_plot/offline-replay/replay.pkl")
    # NOTE: pickle.load can execute arbitrary code -- only load replay files from a trusted source.
    with open(replay_path, "rb") as f:
        memory = pickle.load(f)
    # Hardcode the stats
    n_steps = 25
    n_agent = 15
    obs_dim = 26
    n_action = 5

    # Train policy offline
    agent = DDPG(
        gamma=cfg["gamma"],
        tau=cfg["tau"],
        hidden_size=cfg["hidden_size"],
        obs_dim=obs_dim,
        n_action=n_action,
        n_agent=n_agent,
        obs_dims=[obs_dim for _ in range(n_agent)],
        agent_id=None,
        actor_lr=cfg["actor_lr"],
        critic_lr=cfg["critic_lr"],
        fixed_lr=True,  # not used
        actor_type=cfg["actor_type"],
        critic_type=cfg["critic_type"],
        train_noise=False,
        num_episodes=1,  # not used
        num_steps=1,  # not used
        critic_dec_cen=False,  # not used
        device=cfg["device"],
    )

    # Most recent evaluation result that has not been logged yet (NaN if none).
    # Holding it until the next log row keeps an evaluation from being dropped
    # when eval_freq is not a multiple of logging_freq.
    pending_eval_reward = np.nan

    start = time.perf_counter()
    for i_iter in range(cfg["n_iters"]):
        # Sample batch from replay
        transitions = memory.sample(cfg["batch_size"])
        batch = Transition(*zip(*transitions))
        # batch -- a namedtuple:
        #   state: a tuple of length batch_size, each elem of shape 1 x (n_agent x obs_dim)
        #   action: a tuple of length batch_size, each elem of shape 1 x (n_agent x act_dim)
        #   mask: a tuple of length batch_size, each elem of shape 1 x n_agent
        #   next_state: a tuple of length batch_size, each elem of shape 1 x (n_agent x obs_dim)
        #   reward: a tuple of length batch_size, each elem of shape 1 x n_agent
        policy_loss = agent.update_actor_parameters(batch, 0, penalty=cfg["penalty"])
        value_loss, _, _ = agent.update_critic_parameters(batch, 0)

        # Evaluate the policy
        if (i_iter + 1) % cfg["eval_freq"] == 0:
            pending_eval_reward = _evaluate_policy(
                agent, cfg["device"], cfg["n_eval_episodes"], n_steps
            )

        if (i_iter + 1) % cfg["logging_freq"] == 0:
            end = time.perf_counter()
            logger.logkv("step", i_iter + 1)
            # Wall-clock time since the previous log row (includes any evaluation time).
            logger.logkv("time", end - start)
            logger.logkv("value_loss", value_loss)
            logger.logkv("policy_loss", policy_loss)
            logger.logkv("eval_reward", pending_eval_reward)
            logger.dumpkvs()
            pending_eval_reward = np.nan
            start = time.perf_counter()


if __name__ == "__main__":
    # Script entry point: parse the command line, assemble the config dict
    # and hand it to train().
    import argparse

    cli = argparse.ArgumentParser(description="DDPG offline")

    # (flag, add_argument keyword arguments), registered in the listed order.
    option_specs = [
        ("--run_id", {"type": int, "required": True}),
        ("--gpu_id", {"type": int}),
        ("--gamma", {"type": float, "default": 0.95}),
        ("--tau", {"type": float, "default": 0.01}),
        ("--hidden_size", {"type": int, "default": 128}),
        ("--actor_type", {"type": str, "default": "mlp"}),
        (
            "--critic_type",
            {"type": str, "default": "mlp", "choices": ["mlp", "gcn", "deepset", "setformer"]},
        ),
        ("--actor_lr", {"type": float, "default": 1e-3}),
        ("--critic_lr", {"type": float, "default": 1e-3}),
        ("--penalty", {"action": "store_true"}),
        ("--n_iters", {"type": int, "default": 60000}),
        ("--batch_size", {"type": int, "default": 1024}),
        ("--n_eval_episodes", {"type": int, "default": 100}),
        ("--logging_freq", {"type": int, "default": 50}),
        ("--eval_freq", {"type": int, "default": 500}),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Run on CPU unless a GPU index was given.
    if opts.gpu_id is None:
        device = "cpu"
    else:
        device = f"cuda:{opts.gpu_id}"

    # Keys are inserted in a fixed order because the dict is dumped to config.json.
    cfg = {
        "run_id": opts.run_id,
        "log_dir": f"./exps/ddpg_offline/{opts.critic_type}",
    }
    for key in (
        "gamma",
        "tau",
        "hidden_size",
        "actor_lr",
        "critic_lr",
        "actor_type",
        "critic_type",
        "penalty",
    ):
        cfg[key] = getattr(opts, key)
    cfg["device"] = device
    for key in ("n_iters", "batch_size", "n_eval_episodes", "logging_freq", "eval_freq"):
        cfg[key] = getattr(opts, key)

    train(cfg)
