import argparse
import os
import numpy as np
import torch
import inspect

from src.configs import video_tokenizer_config, dynamic_model_config
from src.models import VideoTokenizer, PolicyEmbedModel, InitDynamicModel
from src.utils.data import InitialTrajDataset
from src.utils.logger import Logger, make_log_dirs
from src.utils.utils import set_seed, DictToObject, get_optimizer_params
from src.utils.scheduler import CosineAnnealingWarmupRestarts
from src.utils.trainer import Trainer


def get_config():
    """Parse command-line options for the training script.

    The fixed options (paths, seed, device, ...) are declared explicitly.
    Every key of ``dynamic_model_config`` is additionally exposed as
    ``--<key>`` with the configured value as its default.

    Command-line overrides are converted according to the type of the default:

    * ``bool``  -> ``true/false``, ``1/0``, ``yes/no``, ``y/n``, ``on/off``
      (case-insensitive)
    * ``dict`` / ``list`` / ``tuple`` -> a Python literal, e.g. ``"(0.9, 0.95)"``
    * ``None``  -> a Python literal if the text parses as one, else the raw text
    * anything else -> ``type(default)(text)`` (``int``, ``float``, ``str``, ...)

    Returns:
        argparse.Namespace: the parsed options.
    """
    # Local import: ``ast`` is only needed by the converters below.
    import ast

    truthy = {'true', '1', 'yes', 'y', 'on'}
    falsy = {'false', '0', 'no', 'n', 'off'}

    def _parse_bool(text):
        # ``bool("false")`` is True, so ``type=bool`` cannot be used directly.
        lowered = text.strip().lower()
        if lowered in truthy:
            return True
        if lowered in falsy:
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean, got {text!r}')

    def _literal_or_text(text):
        # Used when the default is None and therefore carries no type hint.
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError):
            return text

    def _literal_of(expected_type):
        def convert(text):
            try:
                value = ast.literal_eval(text)
            except (ValueError, SyntaxError) as err:
                raise argparse.ArgumentTypeError(
                    f'expected a {expected_type.__name__} literal, got {text!r}'
                ) from err
            # Lists and tuples are interchangeable on the command line.
            if expected_type in (list, tuple) and isinstance(value, (list, tuple)):
                return expected_type(value)
            if not isinstance(value, expected_type):
                raise argparse.ArgumentTypeError(
                    f'expected a {expected_type.__name__} literal, got {text!r}'
                )
            return value
        return convert

    def _converter_for(default_value):
        if default_value is None:
            return _literal_or_text
        # bool must be tested before the generic fallback: it is an int subclass
        # but needs dedicated parsing.
        if isinstance(default_value, bool):
            return _parse_bool
        for container in (dict, list, tuple):
            if isinstance(default_value, container):
                return _literal_of(container)
        return type(default_value)

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='DynamicModel')
    parser.add_argument('--log', type=str, default='MetaWorld(256)')
    parser.add_argument('--train-data-path', type=str, default='/infinite/common/metaworld/train')
    parser.add_argument('--video-tokenizer-path', type=str, default='models/tokenizer.pth')
    parser.add_argument('--checkpoint-path', type=str, default=None)
    parser.add_argument('--embed-type', type=int, default=16)

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--info', type=str, default=None)

    # Register the model options before parsing anything, so that ``--help``
    # lists them too.
    for arg_key, default_value in dynamic_model_config.items():
        parser.add_argument(
            f'--{arg_key}',
            default=default_value,
            type=_converter_for(default_value),
        )

    return parser.parse_args()


def main(config=None):
    """Build the dataset, models, optimizer and scheduler, then run training.

    Args:
        config: Namespace of options as returned by ``get_config()``. When
            ``None`` (the default) the command line is parsed at call time.
    """
    # Resolve lazily: a ``config=get_config()`` default is evaluated once, when
    # the ``def`` statement executes, which would parse argv on import.
    if config is None:
        config = get_config()

    # setup
    set_seed(config.seed)

    # logger
    output_config = {
        "consoleout_backup": "stdout",
        "training_progress": "csv",
        "tb": "tensorboard"
    }
    log_path = make_log_dirs(config.log, config.model_name, config.seed, vars(config))
    logger = Logger(log_path, output_config)
    logger.log_hyperparameters(vars(config))

    # dataset
    train_dataset = InitialTrajDataset(
        config.train_data_path, video_tokenizer_config['seq_len'], None, "train", config.device
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=12
    )

    # The action dimension is only known once the dataset is loaded.
    # NOTE: these update the dicts in place; when the options were left at their
    # defaults these are the same objects stored in ``dynamic_model_config``.
    config.decoder.update({'action_dim': train_dataset.action_dim})
    config.policy.update({'action_dim': train_dataset.action_dim})

    # init latent action dynamics model
    video_tokenizer = VideoTokenizer(
        DictToObject(video_tokenizer_config['encoder']),
        DictToObject(video_tokenizer_config['decoder']),
        DictToObject(video_tokenizer_config['codebook']),
        video_tokenizer_config['img_size'], video_tokenizer_config['patch_size'], video_tokenizer_config['seq_len']
    ).to(config.device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    video_tokenizer.load_state_dict(torch.load(config.video_tokenizer_path, map_location=config.device))
    video_tokenizer.eval()

    if config.embed_type > 0:
        config.posterior.update({'cat_size': config.embed_type, 'class_size': config.embed_type})
        embed_model = PolicyEmbedModel(
            DictToObject(config.posterior),
            DictToObject(config.prior),
            DictToObject(config.policy),
            DictToObject(config.kl),
            config.img_size, config.patch_size, config.seq_len
        ).to(config.device)
        # The checkpoint name follows the embed type the model was just sized
        # with (default 16 -> models/embed_16.pth).
        embed_path = os.path.join('models', f'embed_{config.embed_type}.pth')
        embed_model.load_state_dict(torch.load(embed_path, map_location=config.device))
        embed_model.eval()
    else:
        embed_model = None

    dynamic_model = InitDynamicModel(
        video_tokenizer,
        embed_model,
        DictToObject(config.decoder)
    ).to(config.device)
    if config.checkpoint_path is not None:
        state_dict = torch.load(config.checkpoint_path, map_location=config.device)
        # Unwrap checkpoints that nest the weights under a known key; otherwise
        # the loaded object is used as the state dict directly.
        for wrapper_key in ('module', 'model_state_dict'):
            if wrapper_key in state_dict:
                state_dict = state_dict[wrapper_key]
                break
        dynamic_model.load_state_dict(state_dict)

    optim_params = get_optimizer_params(dynamic_model, weight_decay=config.weight_decay)
    # Pass ``fused=True`` only if this torch version's AdamW accepts the
    # argument and the run targets a CUDA device.
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and 'cuda' in config.device
    extra_args = {'fused': True} if use_fused else {}
    optim = torch.optim.AdamW(optim_params, lr=config.lr, betas=config.betas, **extra_args)

    num_epoch = 1
    scheduler = CosineAnnealingWarmupRestarts(
        optim,
        # first cycle spans a tenth of the batches in ``num_epoch`` epochs
        first_cycle_steps=len(train_loader) * num_epoch // 10,
        cycle_mult=2,
        min_lr=config.lr/10,
        max_lr=config.lr,
        warmup_steps=config.warmup_steps,
        gamma=0.9
    )

    trainer = Trainer(dynamic_model, optim, scheduler, train_loader, logger, config.device)
    trainer.train(train_steps=160000)


# Script entry point: start training when this file is executed directly.
if __name__ == "__main__":
    main()