"""SDVT
Based on https://github.com/suyoung-lee/SDVT (which in turn was based on https://github.com/lmzintgraf/varibad)
Main script to start experiments.
Takes a flag --env-type (see below for choices) and loads the parameters from the respective config file.

REMINDER - to add a new config, do the following:
1. Add it to the imports.
2. Add its env-type to the membership list in the first IF statement of main().
3. Add a branch for it to the config-loading elif chain below that.
4. Add it to the learner-selection chain in the training loop at the end of main().

Yes, it is very labour-intensive, but don't fix what ain't broke.
"""
import argparse
import warnings

import numpy as np
import torch
import json

# get configs

# NEW CONFIGS - UPDATE THESE
from config_dme.ml1 import \
    args_ml1_reach_DME, args_ml1_push_DME, args_ml1_pickplace_DME, args_ml1_reachwall_DME, \
    args_ml1_reach_DME_test

from config_dme.ml10 import \
    args_ml10_DME, \
    args_ml10_DME_test

from config_dme.ml45 import \
    args_ml45_DME, \
    args_ml45_DME_test

# DEFAULT IMPORTS
from config.ml1 import \
    args_ml1_reach_SDVT, args_ml1_reach_SDVT_LW, args_ml1_reach_SD, args_ml1_reach_SD_LW, \
    args_ml1_push_SDVT, args_ml1_push_SDVT_LW, args_ml1_push_SD, args_ml1_push_SD_LW, \
    args_ml1_pickplace_SDVT, args_ml1_pickplace_SDVT_LW, args_ml1_pickplace_SD, args_ml1_pickplace_SD_LW, \
    args_ml1_reachwall_SDVT, args_ml1_reachwall_SDVT_LW, args_ml1_reachwall_SD, args_ml1_reachwall_SD_LW

from config.ml10 import \
    args_ml10_SDVT, args_ml10_SDVT_LW, args_ml10_SD

from config.ml45 import \
    args_ml45_SDVT, args_ml45_SDVT_LW, args_ml45_SD, args_ml45_SD_LW, args_ml45_VariBAD

from config.mujoco import \
    args_cheetah_vel_VariBAD

from environments.parallel_envs import make_vec_envs
from learner import Learner
from learner_ppo import LearnerPPO
from metalearner import MetaLearner
from metalearner_VariBAD import MetaLearnerVariBAD
from metalearner_SDVT import MetaLearnerSDVT
from metalearner_ml10_SDVT import MetaLearnerML10SDVT
from metalearner_ml10_DME import MetaLearnerML10DME
from metalearner_ml10_VariBAD import MetaLearnerML10VariBAD
from metalearner_ml45_SDVT import MetaLearnerML45SDVT
from metalearner_ml45_DME import MetaLearnerML45DME
from metalearner_ml45_VariBAD import MetaLearnerML45VariBAD
from metalearner_DME import MetaLearnerDME
from metaeval_ml10 import MetaEvalML10
from metaeval_ml45 import MetaEvalML45
from utils.helpers import set_device, get_device


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='ml10-SDVT')
    parser.add_argument('--gpu', type=int, default=0, help='GPU device ID to use. -1 for cpu')
    parser.add_argument('--load-dir', default=None)
    parser.add_argument('--load-iter', default=None)
    parser.add_argument('--render', default=False)
    args, rest_args = parser.parse_known_args()
    env = args.env_type
    
    if torch.cuda.is_available() and args.gpu >= 0:
        set_device(args.gpu)
    else:
        set_device(-1)  # This will set the device to CPU

    OLD_ENVS = ['ml10-VariBAD', 'ml45-VariBAD', 'cheetah-vel-VariBAD', 
               'ml1-reach-VariBAD', 'ml1-reachwall-VariBAD', 'ml1-push-VariBAD', 'ml1-pickplace-VariBAD', 'ml1-sweepinto-VariBAD',
               'ml1-buttonpress-VariBAD', 'ml1-plateslide-VariBAD',
               'ml1-reach-SD_LW', 'ml1-reachwall-SD_LW', 'ml1-push-SD_LW', 'ml1-pickplace-SD_LW']
    
    # ml10
    if env in [ # ML1 BASELINE
                'ml1-reach-SDVT', 'ml1-reach-SDVT_LW', 'ml1-reach-SD', 'ml1-reach-SD_LW',
                'ml1-push-SDVT', 'ml1-push-SDVT_LW', 'ml1-push-SD', 'ml1-push-SD_LW',
                'ml1-pickplace-SDVT', 'ml1-pickplace-SDVT_LW', 'ml1-pickplace-SD', 'ml1-pickplace-SD_LW',
                'ml1-reachwall-SDVT', 'ml1-reachwall-SDVT_LW', 'ml1-reachwall-SD', 'ml1-reachwall-SD_LW',

               # ML10/45 BASELINE
               'ml10-SDVT', 'ml10-SDVT_LW', 'ml10-SD', 'ml10-SD_LW', 
               'ml45-SDVT', 'ml45-SDVT_LW', 'ml45-SD', 'ml45-SD_LW',

               # ML1 DME variants
               'ml1-reach-DME', 'ml1-push-DME', 'ml1-pickplace-DME', 'ml1-reachwall-DME',
               'ml1-reach-DME-test',

               # ML10 DME variants
                'ml10-DME', 'ml10-DME-test', 

               # ML45 DME variants
                'ml45-DME', 'ml45-DME-test',

               # misc test stuff

                ] + OLD_ENVS : # streamlining code
        if args.load_dir is None:
            # ML1 Baselines
            if env=='ml1-reach-SDVT':
                args = args_ml1_reach_SDVT.get_args(rest_args)
            elif env=='ml1-reach-SDVT_LW':
                args = args_ml1_reach_SDVT_LW.get_args(rest_args)
            elif env=='ml1-reach-SD':
                args = args_ml1_reach_SD.get_args(rest_args)
            elif env=='ml1-reach-SD_LW':
                args = args_ml1_reach_SD_LW.get_args(rest_args)
            # ML1 Push variants  
            elif env=='ml1-push-SDVT':
                args = args_ml1_push_SDVT.get_args(rest_args)
            elif env=='ml1-push-SDVT_LW':
                args = args_ml1_push_SDVT_LW.get_args(rest_args)
            elif env=='ml1-push-SD':
                args = args_ml1_push_SD.get_args(rest_args)
            elif env=='ml1-push-SD_LW':
                args = args_ml1_push_SD_LW.get_args(rest_args)
            # ML1 Pickplace variants
            elif env=='ml1-pickplace-SDVT':
                args = args_ml1_pickplace_SDVT.get_args(rest_args)
            elif env=='ml1-pickplace-SDVT_LW':
                args = args_ml1_pickplace_SDVT_LW.get_args(rest_args)
            elif env=='ml1-pickplace-SD':
                args = args_ml1_pickplace_SD.get_args(rest_args)
            elif env=='ml1-pickplace-SD_LW':
                args = args_ml1_pickplace_SD_LW.get_args(rest_args)
            # ML1 Reachwall variants
            elif env=='ml1-reachwall-SDVT':
                args = args_ml1_reachwall_SDVT.get_args(rest_args)
            elif env=='ml1-reachwall-SDVT_LW':
                args = args_ml1_reachwall_SDVT_LW.get_args(rest_args)
            elif env=='ml1-reachwall-SD':
                args = args_ml1_reachwall_SD.get_args(rest_args)
            elif env=='ml1-reachwall-SD_LW':
                args = args_ml1_reachwall_SD_LW.get_args(rest_args)
            # ML10 
            elif env == 'ml10-SDVT':
                args = args_ml10_SDVT.get_args(rest_args)
            elif env == 'ml10-SDVT_LW':
                args = args_ml10_SDVT_LW.get_args(rest_args)
            elif env == 'ml10-SD':
                args = args_ml10_SD.get_args(rest_args)
            # ML45
            elif env == 'ml45-SDVT':
                args = args_ml45_SDVT.get_args(rest_args)
            elif env == 'ml45-SDVT_LW':
                args = args_ml45_SDVT_LW.get_args(rest_args)
            elif env == 'ml45-SD':
                args = args_ml45_SD.get_args(rest_args)
            elif env == 'ml45-SD_LW':
                args = args_ml45_SD_LW.get_args(rest_args)
            
            # DME variants (use dedicated configs)
            elif env == 'ml1-reach-DME':
                args = args_ml1_reach_DME.get_args(rest_args)
            elif env == 'ml1-push-DME':
                args = args_ml1_push_DME.get_args(rest_args)  # Note: using reach config for now since push/pickplace configs don't exist
            elif env == 'ml1-pickplace-DME':
                args = args_ml1_pickplace_DME.get_args(rest_args)  # Note: using reach config for now since push/pickplace configs don't exist
            elif env == 'ml1-reachwall-DME':
                args = args_ml1_reachwall_DME.get_args(rest_args)  # Note: using reach config for now since push/pickplace configs don't exist

            elif env == 'ml10-DME':
                args = args_ml10_DME.get_args(rest_args)
            elif env == 'ml45-DME':
                args = args_ml45_DME.get_args(rest_args)
            elif env == 'ml1-reach-DME-test':
                args = args_ml1_reach_DME_test.get_args(rest_args)
            elif env == 'ml10-DME-test':
                args = args_ml10_DME_test.get_args(rest_args)
            elif env == 'ml45-DME-test':
                args = args_ml45_DME_test.get_args(rest_args)
            
            # OLD
            elif env == 'ml45-VariBAD':
                args = args_ml45_VariBAD.get_args(rest_args)
            elif env == 'cheetah-vel-VariBAD':
                args = args_cheetah_vel_VariBAD.get_args(rest_args)
            args.load_dir = None
            args.load_iter = None
            args.env_type = env
        else:
            load_dir = args.load_dir
            load_iter = args.load_iter
            with open(load_dir + '/config.json', 'r') as f:
                args.__dict__ = json.load(f)
            args.load_dir = load_dir
            args.load_iter = load_iter
            args.env_type = env
        print(args)
    elif env in ['ml10-eval', 'ml45-eval']:
        load_dir = args.load_dir
        load_iter = args.load_iter
        render = args.render
        with open(load_dir + '/config.json', 'r') as f:
            args.__dict__ = json.load(f)
        args.load_dir = load_dir
        args.load_iter = load_iter
        args.render = render
        args.env_type = env
        print(args)
    else:
        raise Exception("Invalid Environment")

    # Update: log what types of (non-ml10/ml45) envs have distinct train/test sets
    if env in ['cheetah-vel-VariBAD']:
        args.has_test_tasks = False
    else:
        args.has_test_tasks = True
    
    
    # warning for deterministic execution
    if args.deterministic_execution:
        print('Envoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError('If you want fully deterministic code, run it with num_processes=1.'
                               'Warning: This will slow things down and might break A2C if '
                               'policy_num_steps < env._max_episode_steps.')

    # if we're normalising the actions, we have to make sure that the env expects actions within [-1, 1]
    if args.norm_actions_pre_sampling or args.norm_actions_post_sampling:
        envs = make_vec_envs(env_name=args.env_name, seed=0, num_processes=args.num_processes,
                             gamma=args.policy_gamma, device='cpu',
                             episodes_per_task=args.max_rollouts_per_task,
                             normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                             tasks=None,
                             )
        assert np.unique(envs.action_space.low) == [-1]
        assert np.unique(envs.action_space.high) == [1]

    # clean up arguments
    if args.disable_metalearner or args.disable_decoder:
        args.decode_reward = False
        args.decode_state = False
        args.decode_task = False

    if hasattr(args, 'decode_only_past') and args.decode_only_past:
        args.split_batches_by_elbo = True
    # if hasattr(args, 'vae_subsample_decodes') and args.vae_subsample_decodes:
    #     args.split_batches_by_elbo = True

    # begin training (loop through all passed seeds)
    seed_list = [args.seed] if isinstance(args.seed, int) else args.seed
    for seed in seed_list:
        print('training', seed)
        args.seed = seed
        args.action_space = None
        if env in ['ml10-SDVT', 'ml10-SDVT_LW', 'ml10-SD', 'ml10-SD_LW']:
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerML10SDVT(args)
        elif env in ['ml45-SDVT', 'ml45-SDVT_LW', 'ml45-SD', 'ml45-SD_LW']:
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerML45SDVT(args)
        elif env in ['ml1-reach-SDVT', 'ml1-reach-SDVT_LW', 'ml1-reach-SD', 'ml1-reach-SD_LW', 
                     'ml1-push-SDVT', 'ml1-push-SDVT_LW', 'ml1-push-SD', 'ml1-push-SD_LW', 
                     'ml1-pickplace-SDVT', 'ml1-pickplace-SDVT_LW', 'ml1-pickplace-SD', 'ml1-pickplace-SD_LW', 
                     'ml1-reachwall-SDVT', 'ml1-reachwall-SDVT_LW', 'ml1-reachwall-SD', 'ml1-reachwall-SD_LW', 
                     ]:
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerSDVT(args)
        
        # DME
        elif env in ['ml10-DME',
                     'ml10-DME-test']:
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerML10DME(args)
        elif env in ['ml45-DME',
                     'ml45-DME-test']:
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerML45DME(args)
        elif env in ['ml1-reach-DME', 'ml1-push-DME', 
                     'ml1-pickplace-DME', 'ml1-reachwall-DME', 'ml1-reach-DME-test']:
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerDME(args)
        
        
        # OLD
        elif env in ['ml10-VariBAD']:
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerML10VariBAD(args)
        elif env == 'ml45-VariBAD':
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerML45VariBAD(args)
        elif env == 'ml10-eval':
            args.results_log_dir = args.results_log_dir + '_eval'
            learner = MetaEvalML10(args)
        elif env == 'ml45-eval':
            args.results_log_dir = args.results_log_dir + '_eval'
            learner = MetaEvalML45(args)
        elif env in ['cheetah-vel-VariBAD']:
            args.results_log_dir = args.results_log_dir
            learner = MetaLearnerVariBAD(args)

        elif args.disable_metalearner:
            # If `disable_metalearner` is true, the file `learner.py` will be used instead of `metalearner.py`.
            # This is a stripped down version without encoder, decoder, stochastic latent variables, etc.
            learner = Learner(args)
        else:
            raise Exception("Invalid Environment")
            #learner = MetaLearner(args)
        learner.train()


# Entry point: run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
