import gym
import datetime
import numpy as np
import ENV.env
# Seed passed to env.seed() by get_env_from_name(). None leaves seeding to each
# environment's own default behaviour (for gym envs this is presumably a random
# seed; confirm for the custom envs in the envs package).
SEED = None
  
# Top-level experiment configuration.
#
# Each key holds the active setting. The commented-out lines next to it are
# alternative values for the same key, swapped in by hand. The values of
# 'env_name', 'algorithm_name' and 'evaluation_form' are used later as keys
# into ENV_PARAMS, ALG_PARAMS and EVAL_PARAMS. An alternative therefore only
# works if the matching table has an entry for it.
VARIANT = {
    # environment selection
    # 'env_name': 'FetchReach-v1',
    # 'env_name': 'Antcost-v0',
    # 'env_name': 'oscillator',
    # 'env_name': 'MJS1',
    # 'env_name': 'pendulum',
    # 'env_name': 'pendulum_discrete',
    'env_name': 'cartpole_cost',
    # 'env_name': 'cartpole_discrete',
    # 'env_name': 'linear_sys',
    # 'env_name': 'oscillator_complicated',
    # 'env_name': 'HalfCheetahcost-v0',
    # 'env_name': 'cartpole_cost',
    # training params
    # 'algorithm_name': 'LAC',
    # 'algorithm_name': 'SPPO',
    # 'algorithm_name': 'L_REINFORCE_discrete',
    # 'algorithm_name': 'L_REINFORCE',
    # 'algorithm_name': 'CPO',
    'algorithm_name': 'SAC_cost',
    # Suffix appended to algorithm_name to form the run's directory name
    # (see 'log_path' below).
    # 'additional_description': '-1-1-1-1-1-abs-horizon-250-alpha3=-real-1-lya-in-loss-adjust-alpha',
    # 'additional_description': '-64-64',
    'additional_description': '',

    # 'additional_description': '-horizon=5-alpha3=.1',
    # 'additional_description': '-short-T',
    # 'additional_description': '-cost_as_lyapunov',
    # 'additional_description': '-1-10-10-reinfoce-with-baseline-pure-reward-alpha3=1',

    # 'additional_description': '-1-10-10-reinforce-100-step',
    # 'additional_description': '-1-10-10-discount_lyapunov-100-step',
    # 'evaluate': False,
    # NOTE(review): presumably switches between training and evaluation-only
    # runs; confirm in the runner script.
    'train': True,
    # 'train': False,

    'num_of_trials': 10,   # number of random seeds
    'store_last_n_paths': 10,  # number of trajectories for evaluation during training
    'start_of_trial': 0,  # presumably index of the first trial to run; confirm

    # evaluation params
    # 'evaluation_form': 'constant_impulse',
    'evaluation_form': 'dynamic',
    # 'evaluation_form': 'impulse',
    # 'evaluation_form': 'various_disturbance',
    # 'evaluation_form': 'param_variation',
    # 'evaluation_form': 'trained_disturber',
    # Experiment names to evaluate. The '#N' notes on the commented entries
    # appear to record trial indices of interest (assumption; confirm).
    # NOTE(review): 'SAC' is listed while the active algorithm_name is
    # 'SAC_cost'; confirm which run directory is intended.
    'eval_list': [
        # 'L_REINFORCE_discrete-reinforce-baseline',
        # 'L_REINFORCE_discrete-lyapunov-baseline',
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-10',
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-10-horizon-100-alpha3=.5', #1, 6, 7, 8
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-10-horizon-100-alpha3=.1', # 2, 5
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-10-horizon-100-alpha3=.05', # 1, 3, 5
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-10-horizon-100-retry', #1
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-2-10-horizon-100-alpha3=.05', #3
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-2-1-1-10-horizon-100-alpha3=1', #7, 3, 4, 1
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-2-10-horizon-100-alpha3=1', #9, 8, 7,  4, 2,
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-2-10-horizon-100-alpha3=.05', #1, 2, 3, 7
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-10-init--1-1-horizon-100-alpha3=.05', #9, 7*, 6, 5 1
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-10-10-init-[-1-1]-horizon-100-alpha3=.05', # 4, 7
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-10-10-large-init-horizon-100-alpha3=.05', # 7, 8
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-10-large-init-horizon-100-alpha3=.05-horizon=20', #4 2 6

        # 'L_REINFORCE_discrete-lyapunov-baseline-1-20-1-horizon-250-alpha3=1-die-minus-cost', # 0* 1 2 3 4 5
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-20-1-horizon-250-alpha3=1-die-minus-cost-10', #0 1 2 3 4
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-20-1-horizon-250-alpha3=.05-die-minus-cost', # 1

        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-1-horizon-250-alpha3=1-die-minus-cost-5', #9*
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-1-1-1-horizon-250-alpha3=1-die-minus-cost-10', # 1*
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-1-1-horizon-250-alpha3=1-die-minus-cost-10-adjust-alpha',# 0 3 4 5* 6 7*-500 ***
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-2-1-horizon-250-alpha3=1-die-minus-cost-10-adjust-alpha', #8, 7*,4,3,2
        # 'L_REINFORCE_discrete-lyapunov-baseline-1-5-1-horizon-250-alpha3=1-die-minus-cost-10-adjust-alpha', #5*-250
        # 'SAC_cost',
        'SAC',
    ],
    # Trial indices as strings; range(0, 1) selects only trial '0'.
    'trials_for_eval': [str(i) for i in range(0, 1)],

    'evaluation_frequency': 2048,  # units not visible here (presumably env steps); confirm
}
# NOTE(review): ITA is only defined when algorithm_name == 'RARL'. It does not
# exist for any other algorithm, so readers of it must guard accordingly.
if VARIANT['algorithm_name'] == 'RARL':
    ITA = 0
# Run directory: ./log/<env_name>/<algorithm_name><additional_description>
VARIANT['log_path']='/'.join(['./log', VARIANT['env_name'], VARIANT['algorithm_name'] + VARIANT['additional_description']])

# Per-environment settings, looked up with VARIANT['env_name'].
# NOTE(review): several env names offered (commented out) in VARIANT, e.g.
# 'oscillator', 'MJS1', 'HalfCheetahcost-v0' and 'FetchReach-v1', have no
# entry here. Selecting them raises KeyError when VARIANT['env_params'] is
# assigned.
# 'disturbance dim' (with a space in the key) is present only for some envs.
# It is presumably read by disturbance-based evaluation code; do not rename the
# key without checking its consumers.
ENV_PARAMS = {
    'cartpole_cost': {
        'max_ep_steps': 500,
        'max_global_steps': int(1e6),
        'max_episodes': int(1e6),
        'disturbance dim': 1,
        'eval_render': False,},
    'cartpole_discrete': {
        'max_ep_steps': 500,
        'max_global_steps': int(1e6),
        'max_episodes': int(1e6),
        'eval_render': True,},
    'pendulum': {
        'max_ep_steps': 100,
        'max_global_steps': int(1e6),
        'max_episodes': int(1e6),
        'eval_render': True,},
    'pendulum_discrete': {
        'max_ep_steps': 1000,
        'max_global_steps': int(1e6),
        'max_episodes': int(1e6),
        'eval_render': True,},

    'linear_sys': {
        'max_ep_steps': 100,
        'max_global_steps': int(5e5),
        'max_episodes': int(5e5),
        'disturbance dim': 1,
        'eval_render': False,},

}
# Per-algorithm hyperparameters, looked up with VARIANT['algorithm_name'].
# NOTE(review): 'LAC', 'CPO' and 'RARL' are referenced in VARIANT comments
# and/or in get_train/get_policy, but have no entry here. Selecting them raises
# KeyError when VARIANT['alg_params'] is assigned.
# The key 'labda' is presumably spelled this way to avoid the reserved word
# 'lambda'. Consumers read it by this exact name, so keep it.
ALG_PARAMS = {
    'MPC':{
        'horizon': 5,
    },

    'LQR':{
        'use_Kalman': False,
    },

    'L_REINFORCE': {
        'N_path_num': 50,
        'c_bar': 1000,
        'batch_size': 5000,
        'use_soft_clip': False,
        'weight_of_s_norm': 0.01,
        'train_per_cycle':1,
        'labda': 1.,
        'alpha3': 1,
        'gamma': 0.995,
        # NOTE: 'EPSILON' and 'epsilon' below are two distinct keys with
        # different values.
        'EPSILON':0.2,
        'tau': 5e-3,
        'lr_a': 1e-2,
        'lr_l': 3e-2,
        'epsilon': 1e-4,
        # 'gamma': 0.75,
        # 'use_lyapunov': True,
        # 'finite_horizon': False,
        'target_horizon': 10,
        'discounted_value_as_Lyapunov':True,
        'cost_as_Lyapunov':False,
        'history_horizon': 0,  # 0 means only the current state is used
    },

    'L_REINFORCE_discrete': {
        'N_path_num': 50,
        'evaluation_N':50,
        'batch_size': 5000,
        'c_bar': 1000,
        'use_soft_clip': False,
        'weight_of_s_norm': 0.01,
        'train_per_cycle':1,
        'gamma': 0.995,
        'labda': 1.,
        'alpha3': 1,
        'tau': 5e-2,
        'lr_a': 1e-2,
        'lr_l': 1e-2,
        'epsilon':1e-4,
        # 'gamma': 0.75,
        # 'use_lyapunov': True,
        # 'finite_horizon': False,
        'target_horizon': 20,
        'constant_baseline': 10,
        'cost_as_Lyapunov':False,
        'discounted_value_as_Lyapunov':True,
        'use_simple_l':False,
        'update_l_to_minimize_delta': False,
        'history_horizon': 0,  # 0 means only the current state is used
    },

    'DDPG': {
        'memory_capacity': int(1e6),
        'cons_memory_capacity': int(1e6),
        'min_memory_size': 1000,
        'batch_size': 256,
        'labda': 1.,
        'alpha3': 0.001,
        'tau': 5e-3,
        'noise': 1.,
        'lr_a': 3e-4,
        'lr_c': 3e-4,
        'gamma': 0.99,
        'steps_per_cycle': 100,
        'train_per_cycle': 80,
        'history_horizon': 0,  # 0 means only the current state is used
        },
    'SAC_cost': {
        'iter_of_actor_train_per_epoch': 50,
        'iter_of_disturber_train_per_epoch': 50,
        'memory_capacity': int(1e6),
        'cons_memory_capacity': int(1e6),
        'min_memory_size': 1000,
        'batch_size': 256,
        'labda': 1.,
        'alpha': 1.,
        'alpha3': 0.5,
        'tau': 5e-3,
        'lr_a': 1e-4,
        'lr_c': 3e-4,
        'lr_l': 3e-4,
        'gamma': 0.995,
        # 'gamma': 0.75,
        'steps_per_cycle': 100,
        'train_per_cycle': 50,
        'use_lyapunov': False,
        'adaptive_alpha': True,
        # None presumably lets algorithm.SAC_cost choose a default; confirm there.
        'target_entropy': None,

    },
    # Earlier SPPO configuration, kept for reference only (inactive; the
    # active 'SPPO' entry follows it).
    # 'SPPO': {
    #     'batch_size':10000,
    #     'output_format':['csv'],
    #     'gae_lamda':0.95,
    #     'safety_gae_lamda':0.5,
    #     'labda': 1.,
    #     'number_of_trajectory':10,
    #     'alpha3': 0.1,
    #     'lr_c': 1e-3,
    #     'lr_a': 1e-4,
    #     'gamma': 0.995,
    #     'cliprange':0.2,
    #     'delta':0.01,
    #     'd_0': 1,
    #     'form_of_lyapunov': 'l_reward',
    #     'safety_threshold': 0.,
    #     'use_lyapunov': False,
    #     'use_adaptive_alpha3': False,
    #     'use_baseline':False,
    #     },
    'SPPO': {
        'batch_size':2000,
        'output_format':['csv'],
        'gae_lamda':0.95,
        'safety_gae_lamda':0.95,
        'labda': 1.,
        'number_of_trajectory':50,
        'alpha3': 0.1,
        'lr_c': 3e-4,
        'lr_a': 1e-4,
        'lr_l': 1e-4,
        'gamma': 0.995,
        'cliprange':0.2,
        'delta':0.01,
        # 'd_0': 1,
        'finite_horizon':False,
        'horizon': 5,
        'form_of_lyapunov': 'l_reward',
        'safety_threshold': 10.,
        'use_lyapunov': False,
        'use_adaptive_alpha3': False,
        'use_baseline':False,
        },
}


# Evaluation settings keyed by evaluation form, looked up with
# VARIANT['evaluation_form'].
# Note: np.arange with a float step may include or drop the end point because
# of rounding (see the numpy.arange docs). np.linspace does not have this
# problem.
EVAL_PARAMS = {
    'param_variation': {
        'param_variables': {
            'mass_of_pole': np.arange(0.05, 0.55, 0.05),  # 0.1
            'length_of_pole': np.arange(0.1, 2.1, 0.1),  # 0.5
            'mass_of_cart': np.arange(0.1, 2.1, 0.1),    # 1.0
            # 'gravity': np.arange(9, 10.1, 0.1),  # 0.1

        },
        'grid_eval': True,
        # 'grid_eval': False,
        # Two of the names in 'param_variables' above, presumably the axes
        # of the evaluation grid when 'grid_eval' is True.
        'grid_eval_param': ['length_of_pole', 'mass_of_cart'],
        'num_of_paths': 100,   # number of path for evaluation
    },
    'impulse': {
        # 'magnitude_range': np.arange(150, 160, 5),
        'magnitude_range': np.arange(80, 155, 5),
        # 'magnitude_range': np.arange(80, 155, 10),
        # 'magnitude_range': np.arange(0.1, 1.1, .1),
        'num_of_paths': 100,   # number of path for evaluation
        'impulse_instant': 200,
    },
    # NOTE(review): every 'magnitude_range' option below is commented out, so
    # this entry has no such key. Confirm the constant_impulse evaluator does
    # not require it.
    'constant_impulse': {
        # 'magnitude_range': np.arange(120, 125, 5),
        # 'magnitude_range': np.arange(80, 155, 5),
        # 'magnitude_range': np.arange(80, 155, 5),
        # 'magnitude_range': np.arange(80, 155, 5),
        # 'magnitude_range': np.arange(0.2, 2.2, .2),
        # 'magnitude_range': np.arange(0.1, 1.0, .1),
        'num_of_paths': 20,   # number of path for evaluation
        'impulse_instant': 20,
    },
    'various_disturbance': {
        # Index 0 selects 'sin'; use [1] for 'tri_wave'.
        'form': ['sin', 'tri_wave'][0],
        'period_list': np.arange(2, 11, 1),
        # 'magnitude': np.array([1, 1, 1, 1, 1, 1]),
        'magnitude': np.array([80]),
        # 'grid_eval': False,
        'num_of_paths': 100,   # number of path for evaluation
    },
    'trained_disturber': {
        # 'magnitude_range': np.arange(80, 125, 5),
        # 'path': './log/cartpole_cost/RLAC-full-noise-v2/0/',
        'path': './log/HalfCheetahcost-v0/RLAC-horizon=inf-dis=.1/0/',
        'num_of_paths': 100,   # number of path for evaluation
    },
    'dynamic': {
        'additional_description': 'original',
        'num_of_paths': 100,   # number of path for evaluation
        # 10 evenly spaced values each; presumably initial cart position and
        # pole angle (assumption from the names; confirm in the evaluator).
        'init_x': np.linspace(-4, 4, 10),
        'init_theta': np.linspace(-.1, .1, 10),
        # 'plot_average': True,
        'plot_average': True,
        'directly_show': True,
    },
}
def _select_params(table, key, table_name, variant_key):
    """Return ``table[key]``, raising a descriptive KeyError when it is missing.

    Several alternatives left commented out in VARIANT (e.g. 'LAC', 'CPO',
    'oscillator') have no entry in the parameter tables. A bare KeyError
    would not say which table or which VARIANT setting is at fault.

    Args:
        table: the parameter table to index (ENV_PARAMS, ALG_PARAMS, ...).
        key: the VARIANT value used for the lookup.
        table_name: name of *table*, used only in the error message.
        variant_key: the VARIANT key *key* came from, used only in the error
            message.

    Returns:
        The parameter dict stored under *key*.

    Raises:
        KeyError: if *key* is not in *table*. This is the same exception type
            a plain ``table[key]`` raises, so existing handlers keep working.
    """
    try:
        return table[key]
    except KeyError:
        raise KeyError(
            "VARIANT[%r] = %r has no entry in %s; available keys: %s"
            % (variant_key, key, table_name, sorted(table))
        ) from None


# Attach the parameter sets selected by the active VARIANT choices.
VARIANT['env_params'] = _select_params(
    ENV_PARAMS, VARIANT['env_name'], 'ENV_PARAMS', 'env_name')
VARIANT['eval_params'] = _select_params(
    EVAL_PARAMS, VARIANT['evaluation_form'], 'EVAL_PARAMS', 'evaluation_form')
VARIANT['alg_params'] = _select_params(
    ALG_PARAMS, VARIANT['algorithm_name'], 'ALG_PARAMS', 'algorithm_name')

# NOTE(review): RENDER is not read by the code in this module; presumably
# imported by callers.
RENDER = True
def get_env_from_name(name):
    """Create the environment called *name*, unwrap it and seed it with SEED.

    Environments shipped in the local ``envs`` package are imported lazily,
    only for the branch that is taken. Any other name is resolved via
    ``gym.make``. Two gym-specific tweaks are then applied:

    * 'Quadrotorcons-v0' gets ``modify_action_scale = False`` unless the
      active algorithm name contains 'CPO'.
    * Names containing 'Fetch' or 'Hand' are switched to dense rewards.

    Returns:
        The unwrapped environment, after ``env.seed(SEED)`` has been called.
    """
    if name == 'cartpole_cost':
        from envs.ENV_V1 import CartPoleEnv_adv as env_cls
    elif name == 'cartpole_discrete':
        from envs.cartpole_discrete import CartPoleEnv as env_cls
    elif name == 'pendulum':
        from envs.pendulum import PendulumEnv as env_cls
    elif name == 'pendulum_discrete':
        from envs.pendulum_discrete import PendulumEnv as env_cls
    elif name == 'linear_sys':
        from envs.linear_sys import linear_sys as env_cls
    else:
        env_cls = None

    if env_cls is not None:
        env = env_cls().unwrapped
    else:
        # Not a local environment: fall back to the gym registry.
        env = gym.make(name).unwrapped
        if name == 'Quadrotorcons-v0' and 'CPO' not in VARIANT['algorithm_name']:
            env.modify_action_scale = False
        if 'Fetch' in name or 'Hand' in name:
            env.unwrapped.reward_type = 'dense'

    env.seed(SEED)
    return env

def get_train(name):
    """Return the ``train`` entry point for the algorithm called *name*.

    Substring tests are applied in this order, and the first match wins:
    'RARL', 'LAC', 'SPPO', 'DDPG', 'REINFORCE'. The REINFORCE variant further
    splits on 'discrete'. Any other name falls back to ``algorithm.SAC_cost``.
    Only the selected module is imported.
    """
    if 'RARL' in name:
        from algorithm.RARL import train
        return train
    if 'LAC' in name:
        from algorithm.LAC_V1 import train
        return train
    if 'SPPO' in name:
        from CPO.CPO2 import train
        return train
    if 'DDPG' in name:
        from algorithm.SDDPG_V8 import train
        return train
    if 'REINFORCE' in name:
        if 'discrete' in name:
            from algorithm.L_REINFORCE_discrete import train
        else:
            from algorithm.L_REINFORCE import train
        return train
    # There is no dedicated CPO branch, so a plain 'CPO' name also lands here.
    from algorithm.SAC_cost import train
    return train

def get_policy(name):
    """Return the callable used to build the agent for algorithm *name*.

    Substring tests are applied in this order, and the first match wins:
    'REINFORCE' (split on 'discrete'), 'LAC', 'LQR', 'MPC', 'SPPO', 'DDPG'.
    Any other name falls back to ``algorithm.SAC_cost.SAC_cost``. Each import
    happens only when its branch is taken.
    """
    if 'REINFORCE' in name:
        if 'discrete' in name:
            from algorithm.L_REINFORCE_discrete import L_REINFORCE as builder
        else:
            from algorithm.L_REINFORCE import L_REINFORCE as builder
        return builder
    if 'LAC' in name:
        from algorithm.LAC_V1 import LAC as builder
        return builder
    if 'LQR' in name:
        from algorithm.lqr import LQR as builder
        return builder
    if 'MPC' in name:
        from algorithm.MPC import MPC as builder
        return builder
    if 'SPPO' in name:
        from CPO.CPO2 import CPO as builder
        return builder
    if 'DDPG' in name:
        from algorithm.SDDPG_V8 import SDDPG as builder
        return builder
    # There is no dedicated CPO branch, so a plain 'CPO' name also lands here.
    from algorithm.SAC_cost import SAC_cost as builder
    return builder

def get_eval(name):
    """Return the evaluation routine for the algorithm called *name*.

    Only LAC and SAC_cost variants are supported. Both use
    ``algorithm.LAC_V1.eval``.

    Raises:
        ValueError: if *name* contains neither 'LAC' nor 'SAC_cost'.
    """
    if 'LAC' in name or 'SAC_cost' in name:
        from algorithm.LAC_V1 import eval as eval_fn
        return eval_fn
    # Without this explicit error, the bare name ``eval`` would resolve to
    # Python's builtin eval and be returned silently for unsupported names.
    raise ValueError(
        "No evaluation routine for algorithm %r; expected a name containing "
        "'LAC' or 'SAC_cost'." % (name,)
    )


