import argparse
import os
import numpy as np
import torch
import inspect

import sys
sys.path.append('.')

from src.configs import reward_model_config
from src.models import RewardModel
from src.utils.data import TrajDataset
from src.utils.logger import Logger, make_log_dirs
from src.utils.utils import set_seed, DictToObject, get_optimizer_params
from src.utils.trainer import Trainer


def _cli_type_for(default_value):
    """Return a callable that converts a CLI string to the default's type.

    Calling ``type(default_value)`` directly on the raw string is wrong for
    several types: ``bool("False")`` is ``True``, ``tuple("0.9,0.999")`` is a
    tuple of characters, ``dict("...")`` raises, and ``NoneType`` cannot be
    called with an argument at all.
    """
    import ast

    if isinstance(default_value, bool):
        truthy = ('true', '1', 'yes', 'y', 'on')
        falsy = ('false', '0', 'no', 'n', 'off')

        def parse_bool(text):
            lowered = text.strip().lower()
            if lowered in truthy:
                return True
            if lowered in falsy:
                return False
            raise argparse.ArgumentTypeError(f'expected a boolean, got {text!r}')

        return parse_bool

    if isinstance(default_value, (list, tuple, dict)):
        container_type = type(default_value)

        def parse_container(text):
            # literal_eval only accepts Python literals, so it is safe to run
            # on command-line input (unlike eval).
            try:
                value = ast.literal_eval(text)
            except (ValueError, SyntaxError) as err:
                raise argparse.ArgumentTypeError(
                    f'expected a {container_type.__name__} literal, got {text!r}'
                ) from err
            if container_type is dict:
                if not isinstance(value, dict):
                    raise argparse.ArgumentTypeError(f'expected a dict literal, got {text!r}')
                return value
            if not isinstance(value, (list, tuple)):
                raise argparse.ArgumentTypeError(
                    f'expected a {container_type.__name__} literal, got {text!r}'
                )
            return container_type(value)

        return parse_container

    if default_value is None:
        # No type information is available; keep the raw string.
        return str

    return type(default_value)


def get_config():
    """Parse the command line into a configuration namespace.

    The fixed options (model name, logging, task, data path, checkpoint, seed,
    device, info) are registered first. Every entry of ``reward_model_config``
    is then exposed as ``--<key>`` with the entry's value as default, and a
    converter chosen from the type of that default.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='RewardModel')
    parser.add_argument('--log', type=str, default='MetaWorld')
    parser.add_argument('--task', type=str, default='button-press-wall-v2')
    parser.add_argument('--train-data-path', type=str, default='/infinite/common/metaworld-64/train')
    parser.add_argument('--ckpt-path', type=str, default=None)

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--info', type=str, default=None)

    # All options are registered before the single parse below, so that
    # ``--help`` lists the model hyperparameters as well.
    for arg_key, default_value in reward_model_config.items():
        parser.add_argument(
            f'--{arg_key}',
            default=default_value,
            type=_cli_type_for(default_value),
        )

    return parser.parse_args()


def main(config=None):
    """Build the reward model, its data pipeline and optimizer, then train it.

    Args:
        config: Namespace of options as returned by ``get_config``. When
            omitted, the command line is parsed at call time.
    """
    # Resolve the configuration lazily. A ``config=get_config()`` default is
    # evaluated once when the function is defined, which would parse sys.argv
    # as a side effect of merely importing this module.
    if config is None:
        config = get_config()

    # setup
    set_seed(config.seed)

    # logger
    output_config = {
        "consoleout_backup": "stdout",
        "training_progress": "csv",
        "tb": "tensorboard"
    }
    log_path = make_log_dirs(config.log, f"{config.model_name}/{config.task}", config.seed, vars(config))
    logger = Logger(log_path, output_config)
    logger.log_hyperparameters(vars(config))

    # dataset
    # NOTE(review): the positional arguments 8 and None are passed through
    # unchanged; their meaning is defined by TrajDataset — confirm there.
    traindataset = TrajDataset(config.train_data_path, 8, None, mode=config.task, device=config.device)
    trainloader = torch.utils.data.DataLoader(traindataset, batch_size=config.batch_size, shuffle=True, num_workers=20)

    reward_model = RewardModel(
        DictToObject(config.reward),
        config.img_size, config.patch_size, config.seq_len
    ).to(config.device)

    if config.ckpt_path is not None:
        # NOTE: torch.load can unpickle arbitrary objects depending on the
        # installed torch version's defaults; only load trusted checkpoints.
        state_dict = torch.load(config.ckpt_path, map_location=config.device)
        # Checkpoints may wrap the weights under one of these keys; otherwise
        # the loaded object is used as the state dict directly.
        if 'module' in state_dict:
            state_dict = state_dict['module']
        elif 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        reward_model.load_state_dict(state_dict)

    optim_params = get_optimizer_params(reward_model, weight_decay=config.weight_decay)
    # Use the fused AdamW implementation only when this torch build exposes
    # the ``fused`` argument and the device string names a CUDA device.
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and 'cuda' in config.device
    extra_args = dict(fused=True) if use_fused else dict()
    optim = torch.optim.AdamW(optim_params, lr=config.lr, betas=config.betas, **extra_args)

    warmup_steps = config.warmup_steps

    def warmup_factor(step):
        """Linear warm-up multiplier, capped at 1.0."""
        # A non-positive warm-up length means "no warm-up"; dividing by it
        # would raise ZeroDivisionError (or yield a negative multiplier).
        if warmup_steps <= 0:
            return 1.0
        return min(step / warmup_steps, 1.0)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_factor)

    trainer = Trainer(reward_model, optim, scheduler, trainloader, logger, config.device)
    trainer.train(train_steps=10000)


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()