import argparse
import os
import numpy as np
import torch
import inspect

from src.configs import video_tokenizer_config
from src.models import VideoTokenizer
from src.utils.data import TrajDataset
from src.utils.logger import Logger, make_log_dirs
from src.utils.utils import set_seed, DictToObject, get_optimizer_params
from src.utils.trainer import Trainer


def _parse_bool(text):
    """Convert a command-line string into a bool, rejecting unknown spellings."""
    lowered = text.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {text!r}')


def _parse_literal(text):
    """Parse a Python literal (dict, list, tuple, number, ...) from a string."""
    import ast  # local import: only needed when such an option is overridden

    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError) as err:
        raise argparse.ArgumentTypeError(
            f'expected a Python literal, got {text!r}'
        ) from err


def _parse_literal_or_str(text):
    """Parse a Python literal if possible, otherwise keep the raw string."""
    try:
        return _parse_literal(text)
    except argparse.ArgumentTypeError:
        return text


def _container_parser(container_type):
    """Build a converter that parses a literal and checks it is a `container_type`."""

    def parse(text):
        value = _parse_literal(text)
        # Lists and tuples are interchangeable on the command line; coerce to
        # the type of the default so downstream code sees what it expects.
        if container_type in (list, tuple) and isinstance(value, (list, tuple)):
            return container_type(value)
        if not isinstance(value, container_type):
            raise argparse.ArgumentTypeError(
                f'expected a {container_type.__name__}, got {text!r}'
            )
        return value

    return parse


def _arg_type_for(default_value):
    """Pick an argparse `type` converter that matches the default's type."""
    # bool must be checked first: bool("False") is True, so `type=bool` would
    # make it impossible to turn a flag off from the command line.
    if isinstance(default_value, bool):
        return _parse_bool
    # NoneType cannot be called with a string; accept a literal or raw text.
    if default_value is None:
        return _parse_literal_or_str
    # dict("...") fails and tuple("0.9") splits into characters, so container
    # defaults are parsed as Python literals instead.
    if isinstance(default_value, (dict, list, tuple)):
        return _container_parser(type(default_value))
    return type(default_value)


def get_config():
    """Parse command-line arguments for video tokenizer training.

    Besides the fixed options declared here, one ``--<key>`` option is
    registered for every entry of ``video_tokenizer_config``, using the entry's
    value as the default. Numeric and string entries are converted with their
    own type; booleans accept true/false spellings; dict, list and tuple
    entries are given as Python literals (e.g. ``--betas "(0.9, 0.95)"``).

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='VideoTokenizer')
    parser.add_argument('--log', type=str, default='MetaWorld')
    parser.add_argument('--train-data-path', type=str, default='/infinite/common/metaworld-64/train')
    parser.add_argument('--ckpt-path', type=str, default=None)

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--info', type=str, default=None)

    for arg_key, default_value in video_tokenizer_config.items():
        parser.add_argument(
            f'--{arg_key}',
            default=default_value,
            type=_arg_type_for(default_value),
        )

    return parser.parse_args()


def main(config=None):
    """Train the video tokenizer end to end.

    Seeds the run, sets up logging, builds the dataset/loader, the
    ``VideoTokenizer`` model, an AdamW optimizer with a linear warmup schedule,
    and hands everything to ``Trainer``.

    Args:
        config: Namespace of hyperparameters as returned by ``get_config``.
            When omitted, the command line is parsed at call time (not at
            import time), so importing this module has no side effects.
    """
    if config is None:
        config = get_config()

    # setup
    set_seed(config.seed)

    # logger
    output_config = {
        "consoleout_backup": "stdout",
        "training_progress": "csv",
        "tb": "tensorboard"
    }
    log_path = make_log_dirs(config.log, config.model_name, config.seed, vars(config))
    logger = Logger(log_path, output_config)
    logger.log_hyperparameters(vars(config))

    # dataset
    # Model/dataset settings are read from `config` (whose defaults come from
    # `video_tokenizer_config`) so that command-line overrides take effect and
    # match the hyperparameters logged above.
    train_dataset = TrajDataset(
        config.train_data_path,
        config.seq_len,
        os.path.join(config.train_data_path, "train_indices.json"),
        "train",
        config.device,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=12
    )

    # init video tokenizer
    video_tokenizer = VideoTokenizer(
        DictToObject(config.encoder),
        DictToObject(config.decoder),
        DictToObject(config.codebook),
        config.img_size, config.patch_size, config.seq_len
    ).to(config.device)
    if config.ckpt_path is not None:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        video_tokenizer.load_state_dict(torch.load(config.ckpt_path, map_location=config.device))

    # optimizer: use the fused AdamW kernel when this torch build exposes it
    # and the run targets a CUDA device
    optim_params = get_optimizer_params(video_tokenizer, weight_decay=config.weight_decay)
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and 'cuda' in config.device
    extra_args = dict(fused=True) if use_fused else dict()
    optim = torch.optim.AdamW(optim_params, lr=config.lr, betas=config.betas, **extra_args)

    warmup_steps = config.warmup_steps

    def warmup_factor(step):
        """Linear LR warmup from 0 to 1 over `warmup_steps`, then constant."""
        # A non-positive warmup length means "no warmup"; this also avoids a
        # ZeroDivisionError when warmup_steps is 0.
        if warmup_steps <= 0:
            return 1.0
        return min(step / warmup_steps, 1.0)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_factor)

    trainer = Trainer(video_tokenizer, optim, scheduler, train_loader, logger, config.device)
    trainer.train(train_steps=config.n_train_steps)


# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()