import os
import sys

import numpy as np
import torch
from torch.distributions.categorical import Categorical
from tqdm import trange

from PPO.alg import PPO
from PPO.wm_env import WMEnv
from WM.model import GPT, GPTConfig
from env_list import env_list, env_action_dict
from rdpred import MLP
from single_best_config import single_best_config

# Command-line contract: the last two argv entries are the GPU index and the
# environment name, e.g. `python <script> <gpu_id> <env_name>`.
env_name = sys.argv[-1]
ih = 32  # rollout horizon passed to PPO; env_list_expanded assumes each env yields ih samples
device = 'cuda:{}'.format(int(sys.argv[-2]))

A = len(env_action_dict[env_name])  # size of this game's action set
N = single_best_config[env_name]    # world-model n_layer; also part of the checkpoint filename
D = 512                             # world-model n_embd
B = 256                             # number of environments per batched rollout

# Create the output directory only after the arguments above have been
# validated (bad GPU index -> ValueError, unknown game -> KeyError), so a typo
# no longer leaves a stray PPO/out/<name> directory behind.
os.makedirs('PPO/out/{}'.format(env_name), exist_ok=True)

# NOTE: deliberately rebinds the name imported from the env_list module; from
# here on it is the list of B copies of env_name handed to each rollout.
env_list = [env_name] * B

# World-model hyperparameters. Depth (N) and width (D) come from the
# configuration section above; the remaining values are fixed here.
model_args = {
    'n_layer': N,
    'n_head': 16,
    'n_embd': D,
    'input_dim': 512,
    'context_length': 144,  # 128 + 16
    'bias': False,
    'action_size': A,
    'dropout': 0.0,
}
gptconf = GPTConfig(**model_args)

# Build the GPT world model, restore its weights from the per-game checkpoint
# WM/out/<env_name>/ckpt_<N>.pt, move it to the training device and switch it
# to eval mode. The loaded state dict is released afterwards to free memory.
model = GPT(gptconf)
ckpt = torch.load(f'WM/out/{env_name}/ckpt_{N}.pt', map_location=device)['model']
model.load_state_dict(ckpt)
model.to(device)
model.eval()
ckpt = None

# Auxiliary predictors handed to WMEnv: a reward head and a done head.
# The first MLP argument is 3 for Boxing and Pong and 2 for every other game
# (presumably the number of reward classes, e.g. {-1, 0, +1} vs {0, +1} —
# TODO confirm against rdpred.MLP).
#
# Both checkpoints are loaded with map_location=device, the same way the
# world-model checkpoint is loaded. Loading therefore no longer depends on the
# device the heads were saved from; previously a checkpoint saved on a GPU
# index absent on this machine failed to load. The modules themselves are not
# moved here; presumably envs.to(device) moves them along with the world
# model — confirm in WMEnv.
reward_head = MLP(3 if env_name in ('Boxing', 'Pong') else 2, A)
ckpt_rh = torch.load('WM/out/{}/reward_head.pt'.format(env_name), map_location=device)['model']
reward_head.load_state_dict(ckpt_rh)
reward_head.eval()
ckpt_rh = None  # release the loaded state dict

done_head = MLP(2, A)
ckpt_d = torch.load('WM/out/{}/done_head.pt'.format(env_name), map_location=device)['model']
done_head.load_state_dict(ckpt_d)
done_head.eval()
ckpt_d = None  # release the loaded state dict

# Wrap the world model and its prediction heads as a batched environment,
# then build the PPO agent. Construction order is kept as-is because both
# constructors may draw from the global torch RNG.
envs = WMEnv(model, reward_head, done_head, env_name).to(device)
alg = PPO(A, ih).to(device)

# Logging buffers: per-minibatch losses and per-iteration summed rollout rewards.
loss_totals = []
rewards_all = []

initial_lr = 1e-4
optimizer = torch.optim.Adam(alg.parameters(), lr=initial_lr, eps=1e-5)

total_updates = 100000
pbar = trange(total_updates, desc=f'Alg {env_name}')

# Best and most recent windowed mean return, used to decide when to checkpoint.
best = current = -1e9

tmp = []  # per-iteration mean returns collected since the last checkpoint check
# One env name per flattened rollout sample: each entry of env_list repeated
# ih times, consecutively and in order.
env_list_expanded = np.repeat(env_list, ih)
kls = []
# Main PPO loop. Each iteration collects one batched rollout from the world-model
# environment, then runs up to 4 epochs of 8 shuffled minibatch updates,
# stopping early once the approximate KL exceeds 1.5 * alg.target_kl.
# NOTE(review): `iter` shadows the builtin; left as-is in case later code uses it.
for iter in pbar:
    # Rollout tensors; presumably (B, ih, ...) with `values` carrying one extra
    # bootstrap step (only values[:, :-1] is used below) — TODO confirm.
    context_vectors, actions, action_logits, values, rewards, dones = alg.batch_rollout(envs, env_list)
    advantage = alg.compute_advantage(rewards, values, dones)
    # Per-environment sum of rewards over dim 1, kept for logging and model selection.
    rewards = torch.sum(rewards, dim=1).cpu().numpy()
    rewards_all.append(rewards)
    tmp.append(np.mean(rewards))
    # Flatten (env, step) into one sample axis for minibatching.
    context_vectors_flatten = context_vectors.flatten(0, 1)
    actions_ = actions.flatten()
    # Log-probs of the taken actions under the rollout-time logits; presumably
    # used by `alg` as the old-policy log-probs in the PPO ratio — confirm.
    d = Categorical(logits=action_logits)
    log_probs = d.log_prob(actions).flatten()
    # Value targets: advantage plus the value estimate at each step.
    lambda_returns = (advantage + values[:, :-1]).flatten()
    advantage = advantage.flatten()
    # Normalize advantages over the whole batch (not per minibatch).
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
    # Rollout-time value estimates, passed to `alg` alongside the targets.
    values_ = values[:, :-1].flatten()
    l = context_vectors_flatten.shape[0]
    # The batch must split into 8 equal minibatches (assert is stripped under -O).
    assert l % 8 == 0
    ll = l // 8

    stop_early = False
    for epoch in range(4):
        bs_idx = torch.randperm(l)
        for i in range(8):
            idxs = bs_idx[i*ll:(i+1)*ll]
            cv_minibatch = context_vectors_flatten[idxs]
            acts = actions_[idxs]
            lp = log_probs[idxs]
            # NOTE: `lr` is a minibatch of return targets, not a learning rate.
            lr = lambda_returns[idxs]
            adv = advantage[idxs]
            vals = values_[idxs]
            env_list_mb = list(env_list_expanded[idxs.tolist()])
            losses, approx_kl = alg(cv_minibatch, acts, lp, lr, adv, vals, env_list_mb)
            optimizer.zero_grad()
            loss_total_step = losses.loss_total.float()
            loss_total_step.backward()
            # NOTE(review): intermediate_losses is kept for every minibatch for the
            # whole run; confirm these are detached Python floats, not graph tensors.
            loss_totals.append((loss_total_step.item(), losses.intermediate_losses))
            torch.nn.utils.clip_grad_norm_(alg.parameters(), 0.5)
            optimizer.step()
            kls.append(approx_kl.item())
            # KL check happens after optimizer.step(), so the update from the
            # minibatch that trips the threshold has already been applied.
            if approx_kl.item() > 1.5 * alg.target_kl:
                stop_early = True
                break
        if stop_early: break
    
    # Every 100 iterations (skipping 0), compare the mean return over the window
    # with the best so far and save a checkpoint only on improvement. The first
    # window covers iterations 0..100.
    if iter != 0 and iter % 100 == 0:
        current = np.mean(tmp)
        if current > best:
            best = current
            print('saving', iter)
            checkpoint = {
                    "losses": (loss_totals, rewards_all),
                    "alg": alg.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "kls": kls,
            }
            torch.save(checkpoint, f'PPO/out/{env_name}/ckpt_N{N}_B{B}_ih{ih}.pt')
        tmp = []

    pbar.set_postfix(current_return=f"{current:.4f}", best_return=f"{best:.4f}")
    # NOTE(review): `progress` is not used in the visible code. With initial_lr
    # defined, an LR schedule may have been intended here — confirm.
    progress = (iter + 1) / total_updates