import os
import numpy as np
import torch
import sys
import mage.utils as utils
import mage.datasets as datasets
from mage.models.vqvae import VQContinuousVAE
import wandb

class Parser(utils.Parser):
    # Script-specific defaults for VQ-VAE training. Every other argument read
    # below (horizon, device, n_saves, ...) is presumably declared by
    # utils.Parser or the config module named in `config` — verify there.

    # Default offline-RL dataset / environment name.
    dataset: str = "pen-expert-v0"
    # Module path of the hyper-parameter config used for this script.
    config: str = "config.vqvae"


args = Parser().parse_args('train')

#######################
####### dataset #######
#######################

# Accept dataset names with or without an explicit version suffix
# (e.g. 'pen-expert' -> 'pen-expert-v0').
env_name = args.dataset if "-v" in args.dataset else args.dataset + "-v0"
# NOTE(review): `env` is not referenced anywhere else in this script; it is
# kept in case load_environment has needed side effects — confirm.
env = datasets.load_environment(env_name)
# Window length: history + horizon + the current step, scaled by args.step.
sequence_length = (args.history_horizon + args.horizon + 1) * args.step
# NOTE(review): logbase is re-rooted under cwd only *after* parsing. If
# savepath was derived from logbase inside parse_args, it is unaffected by
# this assignment — confirm that is intended.
args.logbase = os.path.join(args.cwd, args.logbase)
# exist_ok=True replaces the racy exists()+makedirs() pair: it is a no-op when
# the directory already exists (e.g. created by a concurrent run).
os.makedirs(args.savepath, exist_ok=True)


# The savepath tuple presumably tells utils.Config where to persist this config.
dataset_class = datasets.SequenceDataset

dataset_config = utils.Config(
    dataset_class,
    savepath=(args.savepath, "data_config.pkl"),
    env=args.dataset,
    penalty=args.termination_penalty,
    sequence_length=sequence_length,
    history_horizon=args.history_horizon,
    horizon=args.horizon,
    # NOTE(review): actions are always requested from the dataset, while
    # transition_dim below only counts them when args.use_action is set —
    # confirm this asymmetry is intended.
    use_action=True,
    step=args.step,
    discount=args.discount,
    disable_goal=args.disable_goal,
    normalize_raw=args.normalize,
    normalize_reward=args.normalize_reward,
    max_path_length=int(args.max_path_length),
)

dataset = dataset_config()
obs_dim, act_dim = dataset.observation_dim, dataset.action_dim

# Per-timestep width: one leading slot, the observation features (raw
# observation for locomotion, a fixed 128-wide slot otherwise), one trailing
# slot, plus the action when actions are used.
feature_dim = obs_dim if args.task_type == "locomotion" else 128
transition_dim = 1 + feature_dim + 1
if args.use_action:
    transition_dim += act_dim

#######################
######## model ########
#######################

# Total flattened length of one sample: every timestep of the window
# contributes `transition_dim` entries.
block_size = transition_dim * sequence_length

print(
    f'Dataset size: {len(dataset)} | Joined dim: {transition_dim} '
    f'(observation: {obs_dim}, action: {act_dim})'
)



# Constructor arguments for VQContinuousVAE, listed in the original order.
model_kwargs = dict(
    history_horizon=args.history_horizon,
    horizon=args.horizon,
    use_action=args.use_action,
    vocab_size=args.N,
    block_size=block_size,
    K=args.K,
    n_layer=args.n_layer,
    n_head=args.n_head,
    # n_embd is given per head on the command line; the model gets the total.
    n_embd=args.n_embd * args.n_head,
    observation_dim=obs_dim,
    action_dim=act_dim,
    transition_dim=transition_dim,
    action_weight=args.action_weight,
    reward_weight=args.reward_weight,
    value_weight=args.value_weight,
    position_weight=args.position_weight,
    current_obs_weight=args.current_obs_weight,
    current_action_weight=args.current_action_weight,
    next_obs_weight=args.next_obs_weight,
    next_action_weight=args.next_action_weight,
    trajectory_embd=args.trajectory_embd,
    model=args.model,
    latent_step=args.latent_step,
    ma_update=args.ma_update,
    residual=args.residual,
    obs_shape=args.obs_shape,
    embd_pdrop=args.embd_pdrop,
    resid_pdrop=args.resid_pdrop,
    attn_pdrop=args.attn_pdrop,
    bottleneck=args.bottleneck,
    masking=args.masking,
    state_conditional=args.state_conditional,
)

model_config = utils.Config(
    VQContinuousVAE,
    savepath=(args.savepath, "model_config.pkl"),
    **model_kwargs,
)


model = model_config()
model.to(args.device)

if args.normalize:
    # Run an all-zero vector (one shorter than transition_dim) through the
    # dataset's normalize_RandS and hand it to the model as the padding vector;
    # presumably so padded positions look like normalized data — verify.
    zero_transition = np.zeros(model.transition_dim - 1)
    padding_vector = torch.from_numpy(dataset.normalize_RandS(zero_transition))
    model.set_padding_vector(padding_vector)

#######################
####### trainer #######
#######################

# Schedule lengths are expressed in tokens: warmup covers
# len(dataset) * block_size tokens, the decay ends 20x later, and the KL
# warmup lasts 10x the LR warmup.
tokens_per_epoch = len(dataset) * block_size
warmup_tokens = tokens_per_epoch
final_tokens = warmup_tokens * 20

trainer_config = utils.Config(
    utils.VQTrainer,
    savepath=(args.savepath, "trainer_config.pkl"),
    # optimisation
    batch_size=args.batch_size,
    learning_rate=args.learning_rate,
    betas=(0.9, 0.95),
    grad_norm_clip=1.0,
    weight_decay=0.1,
    # schedules
    lr_decay=args.lr_decay,
    warmup_tokens=warmup_tokens,
    kl_warmup_tokens=warmup_tokens * 10,
    final_tokens=final_tokens,
    # runtime
    num_workers=0,
    device=args.device,
)

trainer = trainer_config()

#######################
###### main loop ######
#######################

# Scale the epoch count so that n_epochs * len(dataset) ~= 1e6 * n_epochs_ref,
# independent of dataset size.
n_epochs = int(1e6 / len(dataset) * args.n_epochs_ref)
# Clamp to at least 1: when n_epochs < n_saves the floor division yields 0 and
# the checkpoint-name computation below would raise ZeroDivisionError.
save_freq = max(1, int(n_epochs // args.n_saves))
wandb.init(project="MAGE", config=args, tags=[args.exp_name, args.tag])

for epoch in range(n_epochs):
    print(f'\nEpoch: {epoch} / {n_epochs} | {args.dataset} | {args.exp_name}')

    trainer.train(model, dataset)

    # A checkpoint is written after every epoch, but it is named after the most
    # recent multiple of save_freq. Each file is therefore overwritten until the
    # next multiple is reached, leaving only a few distinct state_*.pt files
    # (starting with state_0.pt).
    save_epoch = (epoch + 1) // save_freq * save_freq
    statepath = os.path.join(args.savepath, f'state_{save_epoch}.pt')
    print(f'Saving model to {statepath}')
    state = model.state_dict()
    torch.save(state, statepath)
