import argparse
import os
import numpy as np
import torch
import inspect

from src.configs import policy_embedding_config
from src.models import PolicyEmbedModel
from src.utils.data import TrajDataset
from src.utils.logger import Logger, make_log_dirs
from src.utils.utils import set_seed, DictToObject, get_optimizer_params
from src.utils.scheduler import CosineAnnealingWarmupRestarts
from src.utils.trainer import Trainer


def _str_to_bool(text):
    """Parse a command-line token as a boolean.

    ``bool("False")`` is ``True`` because any non-empty string is truthy, so
    ``type=bool`` cannot be used for flags that carry an explicit value.
    """
    lowered = text.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {text!r}')


def _converter_for(default_value):
    """Return an argparse ``type=`` callable able to parse overrides of ``default_value``."""
    import ast

    # bool must be tested first: it is a subclass of int.
    if isinstance(default_value, bool):
        return _str_to_bool

    if isinstance(default_value, (dict, list, tuple)):
        def parse_container(text):
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, SyntaxError) as err:
                raise argparse.ArgumentTypeError(
                    f'expected a Python literal, got {text!r}'
                ) from err
            if isinstance(default_value, dict):
                if not isinstance(parsed, dict):
                    raise argparse.ArgumentTypeError(
                        f'expected a dict literal, got {text!r}'
                    )
                # Partial override: keys that are not given keep their default.
                return {**default_value, **parsed}
            if not isinstance(parsed, (list, tuple)):
                raise argparse.ArgumentTypeError(
                    f'expected a list or tuple literal, got {text!r}'
                )
            return tuple(parsed) if isinstance(default_value, tuple) else list(parsed)

        return parse_container

    if default_value is None:
        def parse_untyped(text):
            # No default to infer a type from: accept a Python literal and
            # otherwise keep the raw string.
            try:
                return ast.literal_eval(text)
            except (ValueError, SyntaxError):
                return text

        return parse_untyped

    # int, float, str, ...: the constructor parses the token directly.
    return type(default_value)


def get_config(argv=None, config_defaults=None):
    """Build the run configuration from command-line arguments.

    Registers the fixed options below, then one ``--<key>`` option for every
    entry of ``config_defaults``, using the entry's value as the default.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.
        config_defaults: Mapping of option name to default value. ``None``
            (the default) uses ``policy_embedding_config``.

    Returns:
        argparse.Namespace holding the fixed options and one attribute per
        key of ``config_defaults``.

    Overrides are parsed according to the type of the default:
        * bool: ``true/false``, ``1/0``, ``yes/no``, ``on/off`` (any case).
        * dict: a Python dict literal, merged over the default dict.
        * list / tuple: a Python list or tuple literal.
        * None: a Python literal if the token parses as one, else the string.
        * anything else: ``type(default)(token)``.
    """
    import copy

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='PolicyEmbedding')
    parser.add_argument('--log', type=str, default='MetaWorld')
    parser.add_argument('--train-data-path', type=str, default='/infinite/common/metaworld-64/train')
    parser.add_argument('--checkpoint-path', type=str, default=None)

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--class-size', type=int, default=16)
    parser.add_argument('--cat-size', type=int, default=16)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--info', type=str, default=None)

    if config_defaults is None:
        config_defaults = policy_embedding_config

    for arg_key, default_value in config_defaults.items():
        # Copy so that later in-place edits of the parsed config (e.g.
        # ``config.policy.update(...)``) never leak into the shared defaults.
        default_copy = copy.deepcopy(default_value)
        parser.add_argument(
            f'--{arg_key}',
            default=default_copy,
            type=_converter_for(default_copy),
        )

    # Parse only once, after every option is registered, so that ``-h`` lists
    # the config-derived options as well.
    return parser.parse_args(argv)


def _unwrap_state_dict(checkpoint):
    """Return the weights in ``checkpoint``, unwrapping a 'module' or 'model_state_dict' entry."""
    for wrapper_key in ('module', 'model_state_dict'):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]
    return checkpoint


def _linear_warmup(warmup_steps):
    """Build an LR multiplier that ramps linearly from 0 to 1 over ``warmup_steps`` steps."""
    def lr_lambda(step):
        # No warm-up requested: use the full learning rate from the first step
        # instead of dividing by zero.
        if warmup_steps <= 0:
            return 1.0
        return min(step / warmup_steps, 1.0)

    return lr_lambda


def main(config=None):
    """Train a ``PolicyEmbedModel`` on the trajectory dataset.

    Args:
        config: Namespace of run options as returned by ``get_config``. When
            ``None`` (the default) the options are parsed from the command
            line at call time, so importing this module has no side effects.

    The function seeds the RNGs, builds the dataset and loader, builds the
    model (optionally restoring weights from ``config.checkpoint_path``), sets
    up AdamW with a linear warm-up schedule, creates the logger and runs
    ``Trainer.train`` for ``config.n_train_steps`` steps.
    """
    if config is None:
        config = get_config()

    # setup
    set_seed(config.seed)

    # dataset
    train_dataset = TrajDataset(
        config.train_data_path,
        config.seq_len,
        os.path.join(config.train_data_path, "train_indices.json"),
        "train",
        config.device,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=20
    )

    # Fill in the values that are only known at run time. New dicts are bound
    # on ``config`` (rather than updating in place) so that default dicts
    # shared with other code are not modified; the logged hyperparameters
    # below still contain the added keys.
    config.policy = {**config.policy, 'action_dim': train_dataset.action_dim}
    config.posterior = {
        **config.posterior,
        'cat_size': config.cat_size,
        'class_size': config.class_size,
    }

    embed_model = PolicyEmbedModel(
        DictToObject(config.posterior),
        DictToObject(config.prior),
        DictToObject(config.policy),
        DictToObject(config.kl),
        config.img_size, config.patch_size, config.seq_len,
    ).to(config.device)
    if config.checkpoint_path is not None:
        # NOTE(security): torch.load unpickles the file; only point
        # --checkpoint-path at checkpoints from a trusted source.
        checkpoint = torch.load(config.checkpoint_path, map_location=config.device)
        embed_model.load_state_dict(_unwrap_state_dict(checkpoint))

    # optimizer
    optim_params = get_optimizer_params(embed_model, weight_decay=config.weight_decay)
    # Request the fused AdamW implementation only when this torch version
    # exposes the ``fused`` argument and the target device is a CUDA device.
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and 'cuda' in str(config.device)
    extra_args = {'fused': True} if use_fused else {}
    optim = torch.optim.AdamW(optim_params, lr=config.lr, betas=config.betas, **extra_args)

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim, lr_lambda=_linear_warmup(config.warmup_steps)
    )

    # logger
    output_config = {
        "consoleout_backup": "stdout",
        "training_progress": "csv",
        "tb": "tensorboard"
    }
    log_path = make_log_dirs(config.log, config.model_name, config.seed, vars(config))
    logger = Logger(log_path, output_config)
    logger.log_hyperparameters(vars(config))

    trainer = Trainer(embed_model, optim, scheduler, train_loader, logger, config.device)
    trainer.train(train_steps=config.n_train_steps)


# Script entry point: start training only when this file is executed directly.
# main() is called without arguments, so it runs with its default
# configuration, which get_config() builds from the command-line arguments.
if __name__ == "__main__":
    main()