import os
import sys

import torch
import numpy as np
from tqdm import trange

from WM.model import GPT, GPTConfig
from get_batch import get_batch
from env_list import env_list, env_action_dict

# Permit TensorFloat-32 math in cuDNN and in CUDA matrix multiplies. On GPUs
# that support it this trades a little float32 precision for throughput.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Command line: ``python <script> [...] GPU_ID N_LAYER ENV_NAME``.
# Arguments are read from the END of argv, so extra leading arguments are
# ignored; fail with a usage message instead of an IndexError/ValueError when
# too few are given.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} GPU_ID N_LAYER ENV_NAME")
env_name = sys.argv[-1]  # environment name; also selects the output directory
N = int(sys.argv[-2])  # number of transformer layers (used as n_layer)
gpu_id = int(sys.argv[-3])  # CUDA device index; also part of output filenames
save_dir = f"WM/out/{env_name}"
os.makedirs(save_dir, exist_ok=True)

# --- Data / evaluation settings ----------------------------------------------
input_dim = 512  # feature width requested from get_batch
B = 16  # copies of env_name handed to get_batch (presumably the batch size)
# NOTE: rebinds (shadows) the env_list imported from env_list.py.
env_list = [env_name for _ in range(B)]

wm_eval_interval = 1_000  # training steps between validation passes
wm_eval_iter = 500  # validation batches averaged per pass
wm_patience = 100  # non-improving evaluations tolerated before stopping

# --- Device / precision --------------------------------------------------------
device = 'cuda:' + str(gpu_id)

# float16 autocast context, re-entered around every forward pass.
ctx = torch.autocast('cuda', dtype=torch.float16)

# Model hyper-parameters; saved alongside the weights in every checkpoint.
model_args = dict(
    n_layer=N,
    n_head=16,
    # Reuse the module-level ``input_dim`` (the same value passed to
    # ``get_batch``) so model input width and batch width cannot drift apart.
    input_dim=input_dim,
    n_embd=512,
    context_length=144,
    bias=False,
    # Environments absent from ``env_action_dict`` fall back to 18 actions.
    # NOTE(review): presumably the full Atari action set — confirm.
    action_size=len(env_action_dict[env_name]) if env_name in env_action_dict else 18,
    dropout=0.0,
)
loss_log = []  # one [train_errors, val_errors] entry per evaluation
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
model.to(device)
checkpoint = None
# Loss scaler used with the float16 autocast context.
scaler = torch.cuda.amp.GradScaler()
lr = 1e-4
# NOTE(review): positional order presumably (weight_decay, learning_rate,
# betas, device_type) — confirm against WM.model.GPT.configure_optimizers.
optimizer = model.configure_optimizers(1e-1, lr, (0.9, 0.95), 'cuda')

# --- Training loop ---------------------------------------------------------------
early_stop = 0  # non-improving evaluations (only counted after step 200000)
best_val_loss = 1e9  # best mean validation error seen so far
iter_total = 1000000  # maximum number of optimizer steps
e_train = []  # per-step training errors since the last evaluation


def _evaluate_val():
    """Average the model's error output ``e`` over ``wm_eval_iter`` val batches.

    The model is put in eval mode for the pass (disabling any mode-dependent
    layers such as dropout) and its previous mode is restored afterwards.
    Returns ``np.mean(..., axis=0)`` over the batches as a plain Python list.
    NOTE(review): ``e`` presumably holds one error value per loss term —
    confirm against ``GPT.forward``.
    """
    was_training = model.training
    model.eval()
    e_val = []
    for _ in range(wm_eval_iter):
        X_enc, X_action, X_reward, Y_enc = get_batch('val', device, env_list, input_dim)
        with ctx, torch.no_grad():
            _, e = model(X_enc, X_action, X_reward, Y_enc)
        e_val.append(e.cpu().numpy())
    model.train(was_training)
    return np.mean(e_val, axis=0).tolist()


for step in trange(iter_total, desc=f'WM training {env_name} N{N}', ncols=200):
    X_enc, X_action, X_reward, Y_enc = get_batch('train', device, env_list, input_dim)
    with ctx:
        loss, e = model(X_enc, X_action, X_reward, Y_enc)
    e_train.append(e.detach().cpu().numpy())

    # AMP update: unscale before clipping so the 1.0 max-norm applies to the
    # true (unscaled) gradients.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    if step == 0 or step % wm_eval_interval != 0:
        continue

    # Convert to a list, like the validation errors, so loss_log holds only
    # plain Python data (e.g. loadable with torch.load(weights_only=True)).
    e_train = np.mean(e_train, axis=0).tolist()
    e_val = _evaluate_val()

    loss_log.append([e_train, e_val])
    lossf_val = np.array(e_val).mean()
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "model_args": model_args,
    }
    if lossf_val < best_val_loss:
        early_stop = 0
        best_val_loss = lossf_val
        print('Saving model', step, best_val_loss)
        torch.save(checkpoint, os.path.join(save_dir, f'ckpt_{N}_{gpu_id}.pt'))
        torch.save(loss_log, os.path.join(save_dir, f'loss_log_{N}_{gpu_id}.pt'))
    else:
        # Patience only starts counting once training is past step 200000.
        if step > 200000:
            early_stop += 1
        if early_stop >= wm_patience:
            break

    e_train = []