import os
import shutil

from data_modules.utils import load_dataset
from trainers import train

import argparse
import torch
import numpy as np
import random
from modules import *
import yaml

def seed_all(random_seed):
    """Seed every RNG used by the training stack and make cuDNN deterministic.

    Applies ``random_seed`` to PyTorch (CPU, current CUDA device and all CUDA
    devices), NumPy's global generator and Python's ``random`` module, then
    disables cuDNN autotuning in favour of deterministic kernels.

    Args:
        random_seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # multi-GPU setups
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(random_seed)

    # Prefer reproducible cuDNN kernels over benchmark-selected (faster) ones.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


def main():
    """Parse CLI arguments, build env/dataset/model, and launch training.

    Command-line options:
        --env     environment name (default ``'antmaze-medium-play-v2'``). The
                  text before the first ``'-'`` is the "domain" used for the
                  config, checkpoint and log paths. ``'calvin'`` runs without
                  an environment object.
        --algo    ``'dtamp'`` or ``'play_lmp'``; selects both the model class
                  and the config file ``configs/<domain>/<algo>.yaml``.
        --seed    global RNG seed; a negative value disables seeding and
                  omits the ``_s_<seed>`` suffix from the experiment name.
        --prefix  parsed but not used by this function.
        --resume  keep an existing log directory instead of deleting it, and
                  forward ``resume=True`` to the trainer config.

    Side effects:
        Creates ``checkpoints/<domain>/<exp_name>`` and
        ``logs/<domain>/<exp_name>``. Unless ``--resume`` is given, any
        existing log directory is deleted first.

    Raises:
        NotImplementedError: if ``--algo`` is not a supported algorithm.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='antmaze-medium-play-v2')
    parser.add_argument('--algo', type=str, default='dtamp')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--prefix', type=str, default=None)
    parser.add_argument('--resume', action='store_true', dest='resume', default=False)
    args = parser.parse_args()

    exp_name = f'{args.algo}_{args.env}'

    # Seed before anything that may consume randomness (env construction,
    # dataset loading, model weight initialisation). Seeding only after the
    # model was built left its initial parameters unseeded.
    if args.seed > -1:
        seed_all(args.seed)
        exp_name += f'_s_{args.seed}'

    domain = args.env.split('-')[0]
    # Use a context manager so the config file handle is closed promptly.
    with open(f'configs/{domain}/{args.algo}.yaml') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    if args.env == 'calvin':
        env = None
    else:
        # Imported lazily so CALVIN runs do not need the D4RL env module.
        from envs.d4rl_env import GoalReachingD4rlEnv
        env = GoalReachingD4rlEnv(args.env)
    # Return value is unused; presumably load_dataset prepares/caches data as
    # a side effect -- NOTE(review): confirm against data_modules.utils.
    load_dataset(args.env, config['dataset_cfg'], env=env)

    config['trainer_cfg']['resume'] = args.resume

    if args.algo == 'dtamp':
        model = DTAMP(**config['model_cfg'], dataset_cfg=config['dataset_cfg'])
    elif args.algo == 'play_lmp':
        model = PlayLMP(**config['model_cfg'], dataset_cfg=config['dataset_cfg'])
    else:
        raise NotImplementedError

    model_dir = os.path.join('checkpoints', domain, exp_name)
    log_dir = os.path.join('logs', domain, exp_name)
    os.makedirs(model_dir, exist_ok=True)
    # A fresh (non-resumed) run starts with an empty log directory.
    if not args.resume and os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    os.makedirs(log_dir, exist_ok=True)

    config['trainer_cfg']['model_dir'] = model_dir
    config['trainer_cfg']['log_dir'] = log_dir

    # Only non-CALVIN runs (env is not None) pass an environment to the trainer.
    if env is not None:
        config['trainer_cfg']['env'] = env
    train(model, model_cfg=config['model_cfg'], **config['trainer_cfg'])


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
