import os
import numpy as np
import torch
import pdb
import sys 
import mage.utils as utils
import mage.datasets as datasets
from mage.models.vqvae import TransformerPrior
import wandb

# Command-line parser for this script. It only overrides two defaults of
# utils.Parser; every other option is inherited.
# NOTE(review): `config` presumably names the module from which
# utils.Parser.parse_args pulls per-experiment settings — confirm in utils.
class Parser(utils.Parser):
    # Default dataset / environment name.
    dataset: str = 'pen-expert-v0'
    # Default config module path.
    config: str = 'config.vqvae'

# Parse the 'plan' settings first. They are used to locate and load the
# pretrained representation model (logbase / dataset / exp_name / gpt_epoch).
args = Parser().parse_args('plan')
args.logbase, args.savepath = (
    os.path.expanduser(args.logbase),
    os.path.expanduser(args.savepath),
)

#######################
####### dataset #######
#######################

# Build a versioned environment id: append "-v0" unless the dataset name
# already contains "-v".
env_name = args.dataset
if "-v" not in env_name:
    env_name += "-v0"
env = datasets.load_environment(env_name)

#######################
######## model ########
#######################

# Restore the dataset object from the experiment's pickled 'data_config.pkl'.
dataset = utils.load_from_config(
    args.logbase, args.dataset, args.exp_name, 'data_config.pkl',
)
obs_dim, act_dim = dataset.observation_dim, dataset.action_dim
# NOTE(review): one extra slot per step on top of the observation —
# presumably a scalar (reward/return) channel; confirm against the model.
transition_dim = 1 + obs_dim


# Load the pretrained representation model at `gpt_epoch` and freeze it:
# no gradients flow into its weights while the prior is trained.
representation, _ = utils.load_model(
    args.logbase, args.dataset, args.exp_name,
    epoch=args.gpt_epoch, device=args.device,
)

for weight in representation.parameters():
    weight.requires_grad = False
representation.eval()

# Re-parse with the 'train' settings; from here on `args` holds the
# configuration used to build and train the prior.
args = Parser().parse_args('train')
args.logbase, args.savepath = (
    os.path.expanduser(args.logbase),
    os.path.expanduser(args.savepath),
)

if args.normalize:
    # The padding vector is the dataset's normalize_RandS image of an all-zero
    # vector of length representation.transition_dim - 1.
    representation.set_padding_vector(
        torch.from_numpy(
            dataset.normalize_RandS(np.zeros(representation.transition_dim - 1))
        )
    )

# Context length of the prior in time steps: history window + horizon + the
# current step, scaled by `args.step`.
# NOTE(review): the meaning of `step` is assumed from its name — confirm
# against TransformerPrior.
sequence_length = (args.history_horizon + args.horizon + 1) * args.step
# Total tokens per sequence: every time step contributes `transition_dim` slots.
block_size = sequence_length * transition_dim

# Build the prior via a utils.Config wrapper that is also saved to
# `prior_model_config.pkl` under args.savepath. The frozen representation
# supplies its trajectory embedding and patch layout.
model_config = utils.Config(
    TransformerPrior,
    savepath=(args.savepath, 'prior_model_config.pkl'),
    K=args.K, block_size=block_size,
    # n_embd is given per head; the total width is n_embd * n_head.
    n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd * args.n_head,
    trajectory_embd=representation.trajectory_embd,
    use_action=args.use_action,
    observation_dim=obs_dim,
    action_dim=act_dim,
    history_horizon=args.history_horizon,
    horizon=args.horizon,
    v_patch_nums=representation.v_patch_nums,
    latent_step=args.latent_step,
    embd_pdrop=args.embd_pdrop, resid_pdrop=args.resid_pdrop, attn_pdrop=args.attn_pdrop,
    obs_shape=args.obs_shape,
    max_path_length=args.max_path_length,
)


model = model_config()
model.to(args.device)

#######################
####### trainer #######
#######################

# Token budgets handed to PriorTrainer. One "unit" is block_size tokens per
# dataset item; the final budget is 20 units and the KL warm-up is 10 units.
# NOTE(review): how these drive the LR / KL schedules lives in PriorTrainer.
warmup_tokens = block_size * len(dataset)
final_tokens = warmup_tokens * 20

trainer_config = utils.Config(
    utils.PriorTrainer,
    savepath=(args.savepath, 'priortrainer_config.pkl'),
    use_action=args.use_action,
    batch_size=args.batch_size,
    learning_rate=args.learning_rate,
    betas=(0.9, 0.95),
    grad_norm_clip=1.0,
    weight_decay=0.1,
    lr_decay=args.lr_decay,
    warmup_tokens=warmup_tokens,
    kl_warmup_tokens=10 * warmup_tokens,
    final_tokens=final_tokens,
    num_workers=0,
    device=args.device,
    rtg=args.rtg,
)

trainer = trainer_config()

#######################
###### main loop ######
#######################

# Number of epochs scales inversely with dataset size, so that
# n_epochs * len(dataset) is about 1e6 * args.n_epochs_ref.
n_epochs = int(1e6 / len(dataset) * args.n_epochs_ref)
# Clamp to >= 1. When n_epochs < args.n_saves the floor division yields 0,
# and `(epoch + 1) // save_freq` below would raise ZeroDivisionError.
save_freq = max(1, int(n_epochs // args.n_saves))
wandb.init(project="MAGE", config=args, tags=[args.exp_name, args.tag, "prior"])

temp = 0.5  # initial temperature; trainer.train returns the updated value each epoch
decay_rate = 0.99995  # passed to trainer.train together with temp
max_score = 0  # best mean score so far (epoch 0 always replaces this)
min_stde = 0  # standard error of that best score
best_statepath = os.path.join(args.savepath, 'best_prior_state.pt')
for epoch in range(n_epochs):
    print(f'\nEpoch: {epoch} / {n_epochs} | {args.dataset} | {args.exp_name}')

    mean_score, stde_score, temp = trainer.train(representation, model, dataset, temp=temp, decay_rate=decay_rate)

    # Round (epoch + 1) down to a multiple of save_freq. Consecutive epochs
    # overwrite the same checkpoint file until the next multiple is reached.
    save_epoch = (epoch + 1) // save_freq * save_freq
    statepath = os.path.join(args.savepath, f'prior_state_{save_epoch}.pt')
    print(f'Saving model to {statepath}')

    state = model.state_dict()
    torch.save(state, statepath)

    # Best checkpoint: a higher mean score wins. On an equal mean, a lower or
    # equal standard error wins, so an exact tie keeps the latest epoch.
    is_best = (
        epoch == 0
        or mean_score > max_score
        or (mean_score == max_score and stde_score <= min_stde)
    )
    if is_best:
        print(f'epoch {epoch} is best')
        print(f'Saving model to {best_statepath}')
        torch.save(state, best_statepath)
        max_score = mean_score
        min_stde = stde_score
