from sac import SAC_Agent
from modified_envs import HalfCheetahEnv, Walker2dEnv, HopperEnv
from model import EnsembleGymEnv

import torch
import gym
from gym.wrappers import TimeLimit
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import d4rl
import copy
import sklearn
import random
import sklearn.linear_model as linear_model
import argparse

from utils import check_or_make_folder

# Per-dataset evaluation configuration, looked up by `--env_name` in the script entry point.
# NOTE: the name `dict` shadows the Python builtin for the rest of this module.
# Each entry holds:
#   'env_name'            - environment id handed to gym.make() by MujocoModelEnv.
#   'base_mopo_policy'    - checkpoint paths loaded with torch.load() into SAC agents whose
#                           input size is the plain observation dimension.
#   'context_mopo_policy' - checkpoint paths loaded into SAC agents whose input size is twice
#                           the observation dimension (observation concatenated with a context).
#   'model_dir'           - path passed to the dynamics model's load_model() by MujocoModelEnv
#                           (loading is skipped when this is the empty string).
dict = {
    'halfcheetah-medium-replay-v0': {
        'env_name': 'halfcheetah-medium-replay-v0',
        'base_mopo_policy': [
            './data/checkpoints/model_saved_weights/Model_halfcheetah-medium-replay-v0_seed0_2020_12_31_21-30-54/torch_policy_weights_250_epochs.pt',
            './data/old/2021_02_01_06-48-56/model_saved_weights_seed10/torch_policy_weights_260_epochs_6259.pt',
            './data/old/2021_02_02_06-48-50/model_saved_weights_seed14/torch_policy_weights_240_epochs_6320.pt',
            './data/old/2021_02_02_06-48-54/model_saved_weights_seed12/torch_policy_weights_240_epochs_6687.pt',
        ],
        'context_mopo_policy': [
            './data/old/2021_01_23_03-12-08/model_saved_weights_seed1/torch_policy_weights_2440_epochs.pt',
            './data/old/2021_02_01_06-48-27/model_saved_weights_seed10/torch_policy_weights_900_epochs_6423.pt',
            './data/old/2021_02_02_06-47-35/model_saved_weights_seed12/torch_policy_weights_900_epochs_6417.pt',
        ],
        'model_dir': './checkpoints/model_saved_weights/Model_halfcheetah-medium-replay-v0_seed5',
    },
    'halfcheetah-medium-v0': {
        'env_name': 'halfcheetah-medium-v0',
        'base_mopo_policy': [
            './data/old/2021_01_28_09-12-10/model_saved_weights_seed0/torch_policy_weights_400_epochs_5537.pt',
            './data/old/2021_02_01_06-49-41/model_saved_weights_seed10/torch_policy_weights_400_epochs_5624.pt',
            './data/old/2021_02_02_06-49-40/model_saved_weights_seed14/torch_policy_weights_400_epochs_5216.pt',
            './data/old/2021_02_02_06-49-59/model_saved_weights_seed12/torch_policy_weights_400_epochs_5493.pt',
        ],
        'context_mopo_policy': [
            './data/old/2021_01_21_05-58-45/model_saved_weights_seed0/torch_policy_weights_900_epochs.pt',
            './data/old/2021_01_30_07-19-37/model_saved_weights_seed5/torch_policy_weights_900_epochs_5156.pt',
            './data/old/2021_02_01_06-49-17/model_saved_weights_seed10/torch_policy_weights_900_epochs_4864.pt',
            './data/old/2021_02_02_06-48-11/model_saved_weights_seed12/torch_policy_weights_900_epochs_5001.pt',
        ],
        'model_dir': './data/old/2021_01_21_05-58-45/checkpoints/model_saved_weights/Model_halfcheetah-medium-v0_seed0_2021_01_21_05-58-50',
    },
    'halfcheetah-random-v0': {
        'env_name': 'halfcheetah-random-v0',
        'base_mopo_policy': [
            './data/old/2021_01_28_09-12-20/model_saved_weights_seed0/torch_policy_weights_400_epochs_3594.pt',
            './data/old/2021_02_01_06-50-07/model_saved_weights_seed10/torch_policy_weights_400_epochs_3509.pt',
            './data/old/2021_02_02_06-49-18/model_saved_weights_seed12/torch_policy_weights_400_epochs_3327.pt',
            './data/old/2021_02_02_06-49-54/model_saved_weights_seed14/torch_policy_weights_400_epochs_3748.pt',
        ],
        'context_mopo_policy': [
            './data/old/2021_01_21_05-59-09/model_saved_weights_seed0/torch_policy_weights_900_epochs.pt',
            './data/old/2021_01_30_07-20-11/model_saved_weights_seed5/torch_policy_weights_900_epochs_4360.pt',
            './data/old/2021_02_01_06-49-54/model_saved_weights_seed10/torch_policy_weights_900_epochs_4408.pt',
            './data/old/2021_02_02_06-49-16/model_saved_weights_seed12/torch_policy_weights_900_epochs_3765.pt',
        ],
        'model_dir': './data/old/2021_01_21_05-59-09/checkpoints/model_saved_weights/Model_halfcheetah-random-v0_seed0_2021_01_21_05-59-15',
    },
    'halfcheetah-medium-expert-v0': {
        'env_name': 'halfcheetah-medium-expert-v0',
        'base_mopo_policy': [
            # './data/old/2021_01_28_09-12-56/model_saved_weights_seed0/torch_policy_weights_340_epochs_3161.pt',
            # './data/old/2021_02_01_06-50-56/model_saved_weights_seed10/torch_policy_weights_320_epochs_2338.pt',
            # './data/old/2021_02_02_06-50-20/model_saved_weights_seed14/torch_policy_weights_320_epochs_5464.pt',
            # './data/old/2021_02_02_06-50-56/model_saved_weights_seed12/torch_policy_weights_320_epochs_5807.pt',
            #
            './data/hc_medexp_newmodels_modelseed100_mopo/model_saved_weights_seed75/torch_policy_weights_700_epochs_11306.pt',
            './data/hc_medexp_newmodels_modelseed100_mopo/model_saved_weights_seed76/torch_policy_weights_700_epochs_11435.pt',
            './data/hc_medexp_newmodels_modelseed100_mopo/model_saved_weights_seed77/torch_policy_weights_700_epochs_12982.pt',
            './data/hc_medexp_newmodels_modelseed100_mopo/model_saved_weights_seed78/torch_policy_weights_700_epochs_1320.pt',
        ],
        'context_mopo_policy': [
            # './data/old/2021_01_21_05-59-27/model_saved_weights_seed0/torch_policy_weights_900_epochs.pt',
            # './data/old/2021_01_30_07-20-16/model_saved_weights_seed5/torch_policy_weights_160_epochs_5261.pt',
            # './data/old/2021_02_02_06-49-08/model_saved_weights_seed12/torch_policy_weights_320_epochs_6207.pt',
            #
            './data/hc_medexp_newmodels_modelseed100_awm25/model_saved_weights_seed75/torch_policy_weights_700_epochs_8401.pt',
            './data/hc_medexp_newmodels_modelseed100_awm25/model_saved_weights_seed76/torch_policy_weights_700_epochs_9993.pt',
            './data/hc_medexp_newmodels_modelseed100_awm25/model_saved_weights_seed77/torch_policy_weights_700_epochs_9963.pt',
            './data/hc_medexp_newmodels_modelseed100_awm25/model_saved_weights_seed78/torch_policy_weights_700_epochs_8621.pt',
        ],
        # 'model_dir': './Model_halfcheetah-medium-expert-v0_seed100_2021_03_18_20-56-11',

        'model_dir': './data/old/2021_01_21_05-59-27/checkpoints/model_saved_weights/Model_halfcheetah-medium-expert-v0_seed0_2021_01_21_05-59-34',
    },
    ## WALKER FROM HERE ON
    'walker2d-medium-replay-v0': {
        'env_name': 'walker2d-medium-replay-v0',
        'base_mopo_policy': [
            './data/old/2021_01_28_09-12-17/model_saved_weights_seed0/torch_policy_weights_260_epochs_1982.pt',
            './data/old/2021_02_01_06-49-05/model_saved_weights_seed10/torch_policy_weights_120_epochs_2884.pt',
            './data/old/2021_02_02_06-48-51/model_saved_weights_seed12/torch_policy_weights_260_epochs_1834.pt',
            './data/old/2021_02_02_06-48-56/model_saved_weights_seed14/torch_policy_weights_260_epochs_2477.pt',
        ],
        'context_mopo_policy': [
            './data/old/2021_01_21_05-58-30/model_saved_weights_seed0/torch_policy_weights_900_epochs.pt',
            './data/old/2021_01_30_07-19-06/model_saved_weights_seed5/torch_policy_weights_900_epochs_2602.pt',
            './data/old/2021_02_01_06-48-44/model_saved_weights_seed10/torch_policy_weights_1800_epochs_3563.pt',
            './data/old/2021_02_02_06-47-38/model_saved_weights_seed12/torch_policy_weights_1740_epochs_3261.pt',
        ],
        'model_dir': './data/old/2021_01_21_05-58-30/checkpoints/model_saved_weights/Model_walker2d-medium-replay-v0_seed0_2021_01_21_05-58-34',
    },
    'walker2d-medium-v0': {
        'env_name': 'walker2d-medium-v0',
        'base_mopo_policy': [
            './data/old/2021_01_30_07-39-20/model_saved_weights_seed0/torch_policy_weights_460_epochs_3571.pt',
            './data/old/2021_02_01_06-49-55/model_saved_weights_seed10/torch_policy_weights_460_epochs_2698.pt',
            './data/old/2021_02_02_06-49-34/model_saved_weights_seed14/torch_policy_weights_460_epochs_3571.pt',
            './data/old/2021_02_02_06-49-54/model_saved_weights_seed12/torch_policy_weights_460_epochs_3830.pt',
        ],
        'context_mopo_policy': [
            './data/old/2021_02_01_06-49-38/model_saved_weights_seed10/torch_policy_weights_1200_epochs_3418.pt',
            './data/old/2021_02_02_06-48-03/model_saved_weights_seed12/torch_policy_weights_900_epochs_3552.pt',
        ],
        'model_dir': './data/old/2021_01_28_09-07-02/checkpoints/model_saved_weights/Model_walker2d-medium-v0_seed0_2021_01_28_09-07-09',
    },
    'walker2d-random-v0': {
        'env_name': 'walker2d-random-v0',
        'base_mopo_policy': [
            './data/old/2021_01_28_09-13-04/model_saved_weights_seed0/torch_policy_weights_360_epochs_1007.pt',
            './data/old/2021_02_01_06-49-52/model_saved_weights_seed10/torch_policy_weights_360_epochs_1013.pt',
            './data/old/2021_02_02_06-49-18/model_saved_weights_seed12/torch_policy_weights_380_epochs_996.pt',
            './data/old/2021_02_02_06-49-31/model_saved_weights_seed14/torch_policy_weights_360_epochs_551.pt',
        ],
        'context_mopo_policy': [
            './data/old/2021_01_28_09-07-09/model_saved_weights_seed0/torch_policy_weights_440_epochs_925.pt',
            './data/old/2021_01_30_07-20-03/model_saved_weights_seed5/torch_policy_weights_880_epochs_1021.pt',
            './data/old/2021_02_01_06-49-45/model_saved_weights_seed10/torch_policy_weights_960_epochs_1028.pt',
            './data/old/2021_02_02_06-48-45/model_saved_weights_seed12/torch_policy_weights_980_epochs_1010.pt',
        ],
        'model_dir': './data/old/2021_01_28_09-07-09/checkpoints/model_saved_weights/Model_walker2d-random-v0_seed0_2021_01_28_09-07-14',
    },
    'walker2d-medium-expert-v0': {
        'env_name': 'walker2d-medium-expert-v0',
        'base_mopo_policy': [
            './data/old/2021_01_28_09-13-48/model_saved_weights_seed0/torch_policy_weights_80_epochs_378.pt',
            './data/old/2021_02_02_06-50-18/model_saved_weights_seed14/torch_policy_weights_360_epochs_2611.pt',
        ],
        'context_mopo_policy': [
            './data/old/2021_01_21_05-59-29/model_saved_weights_seed0/torch_policy_weights_900_epochs.pt',
            './data/old/2021_01_30_07-20-31/model_saved_weights_seed5/torch_policy_weights_900_epochs_4651.pt',
            './data/old/2021_02_02_06-48-49/model_saved_weights_seed12/torch_policy_weights_1000_epochs_4987.pt',
            './data/old/2021_02_01_06-50-40/model_saved_weights_seed10/torch_policy_weights_920_epochs_4721.pt',
        ],
        'model_dir': './data/old/2021_01_21_05-59-29/checkpoints/model_saved_weights/Model_walker2d-medium-expert-v0_seed0_2021_01_21_05-59-35',
    },

    # HOPPER NOW
    'hopper-medium-replay-v0': {
        'env_name': 'hopper-medium-replay-v0',
        'base_mopo_policy': [
            './data/hopper_mixed_mopo_save/model_saved_weights_seed73/torch_policy_weights_960_epochs_3185.pt',
            './data/hopper_mixed_mopo_save/model_saved_weights_seed0/torch_policy_weights_960_epochs_2837.pt',
        ],
        'context_mopo_policy': [
            './data/hopper_mixed_mopo_csac_save/model_saved_weights_seed73/torch_policy_weights_960_epochs_3165.pt',
            './data/hopper_mixed_mopo_csac_save/model_saved_weights_seed0/torch_policy_weights_960_epochs_3134.pt',
        ],
        'model_dir': './data/hopper_mixed/checkpoints/model_saved_weights/Model_hopper-medium-replay-v0_seed0_2021_03_18_02-05-16',
    },
    'hopper-medium-v0': {
        'env_name': 'hopper-medium-v0',
        'base_mopo_policy': [
            './data/hopper_med_mopo_save/model_saved_weights_seed0/torch_policy_weights_900_epochs_269.pt',
            './data/hopper_med_mopo_save/model_saved_weights_seed73/torch_policy_weights_900_epochs_718.pt',
        ],
        'context_mopo_policy': [
            './data/hopper_med_mopo_csac_save/model_saved_weights_seed0/torch_policy_weights_900_epochs_683.pt',
            './data/hopper_med_mopo_csac_save/model_saved_weights_seed73/torch_policy_weights_900_epochs_899.pt',
        ],
        'model_dir': './data/hopper_med/checkpoints/model_saved_weights/Model_hopper-medium-v0_seed0_2021_03_18_02-05-08',
    },
    'hopper-random-v0': {
        'env_name': 'hopper-random-v0',
        'base_mopo_policy': [
            './data/hopper_rand_mopo_save/model_saved_weights_seed0/torch_policy_weights_980_epochs_181.pt',
            './data/hopper_rand_mopo_save/model_saved_weights_seed73/torch_policy_weights_980_epochs_347.pt',
        ],
        'context_mopo_policy': [
            './data/hopper_rand_mopo_csac_save/model_saved_weights_seed0/torch_policy_weights_980_epochs_328.pt',
            './data/hopper_rand_mopo_csac_save/model_saved_weights_seed73/torch_policy_weights_980_epochs_295.pt',
        ],
        'model_dir': './data/hopper_rand/checkpoints/model_saved_weights/Model_hopper-random-v0_seed0_2021_03_18_02-05-14',
    },
    'hopper-medium-expert-v0': {
        'env_name': 'hopper-medium-expert-v0',
        'base_mopo_policy': [
            './data/hopper_medexp_mopo_save/model_saved_weights_seed0/torch_policy_weights_920_epochs_853.pt',
            './data/hopper_medexp_mopo_save/model_saved_weights_seed73/torch_policy_weights_920_epochs_923.pt',
        ],
        'context_mopo_policy': [
            './data/hopper_medexp_mopo_csac_save/model_saved_weights_seed0/torch_policy_weights_920_epochs_1659.pt',
            './data/hopper_medexp_mopo_csac_save/model_saved_weights_seed73/torch_policy_weights_920_epochs_342.pt',
        ],
        'model_dir': './data/hopper_medexp/checkpoints/model_saved_weights/Model_hopper-medium-expert-v0_seed0_2021_03_18_02-05-09',
    },
}


# 'walker2d-expert-v0': {
#     'env_name': 'walker2d-expert-v0',
#     'base_mopo_policy': [
#         './data/old/2021_01_28_09-13-07/model_saved_weights_seed0/torch_policy_weights_480_epochs_3103.pt',
#     ],
#     'context_mopo_policy': [
#         './data/old/2021_01_28_09-07-12/model_saved_weights_seed0/torch_policy_weights_300_epochs_3265.pt',
#     ],
#     'model_dir': './data/old/2021_01_28_09-07-12/checkpoints/model_saved_weights/Model_walker2d-expert-v0_seed0_2021_01_28_09-07-18',
# },

# 'halfcheetah-expert-v0': {
#     'env_name': 'halfcheetah-expert-v0',
#     'base_mopo_policy': [
#         './data/old/2021_01_28_09-12-11/model_saved_weights_seed0/torch_policy_weights_500_epochs_14757.pt',
#     ],
#     'context_mopo_policy': [
#         './data/old/2021_01_21_05-58-55/model_saved_weights_seed0/torch_policy_weights_900_epochs.pt',
#     ],
#     'model_dir': './data/old/2021_01_21_05-58-55/checkpoints/model_saved_weights/Model_halfcheetah-expert-v0_seed0_2021_01_21_05-59-00',
# },

# MED-EXP is too low
# MED needs to be 1,1

# 01-28 has some base mopo
# 01-30 is context
# 02-01 is all
# 02-02 is all


class MujocoModelEnv(EnsembleGymEnv):
    """Ensemble model environment built from one entry of the checkpoint table.

    The constructor assembles a fixed hyper-parameter dictionary, creates a
    training and an evaluation copy of the gym environment named by
    ``env_dict['env_name']``, hands all three to ``EnsembleGymEnv`` and then
    restores saved model weights from ``env_dict['model_dir']``.
    """

    def __init__(self, env_dict):
        """Create the environments and load the pretrained dynamics model.

        Args:
            env_dict: mapping with at least the keys ``'env_name'`` and
                ``'model_dir'``. An empty-string ``'model_dir'`` skips the
                weight loading step.
        """
        env_name = env_dict['env_name']
        # Only the medium-replay datasets keep the smaller buffer sizes.
        is_replay = 'medium-replay' in env_name

        params = {
            'seed': 0,
            'env_name': env_name,
            'num_models': 7,
            'num_elites': 5,
            'reward_head': True,
            'logvar_head': True,
            'is_done_func': None,
            'train_memory': 8e5 if is_replay else 2000000,
            'val_memory': 2e5 if is_replay else 500000,
            'train_val_ratio': 0.2,
            'tune_mopo_lam': False,
            'mopo': True,
            'mopo_lam': 1,
            'mopo_penalty_type': 'mopo_default',
        }

        # Two independent instances: one for training, one for evaluation.
        env = gym.make(env_name)
        eval_env = gym.make(env_name)

        # Dimensions are read off the environment's spaces and mirrored on self.
        ob_dim = env.observation_space.shape[0]
        ac_dim = env.action_space.shape[0]
        params['ob_dim'] = ob_dim
        params['ac_dim'] = ac_dim
        self.ob_dim = ob_dim
        self.ac_dim = ac_dim

        super().__init__(params, env, eval_env)

        model_dir = env_dict['model_dir']
        if model_dir == '':
            return
        self.model.load_model(model_dir)


def evaluate_agent(env, agent, n_starts=1, pad_state=False, clip_lb=0.95, clip_ub=1.05):
    """Roll out one or several agents in ``env`` and report their episode returns.

    Every agent is run for ``n_starts`` full episodes using deterministic
    actions; rewards are summed without discounting.

    Args:
        env: environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        agent: a single agent or a list of agents, each with
            ``get_action(state, deterministic=...)``.
        n_starts: number of episodes to run per agent.
        pad_state: when true, the observation is concatenated with a vector of
            ones of the same shape before it is passed to the agent.
        clip_lb: unused by this function.
        clip_ub: unused by this function.

    Returns:
        A tuple ``(mean_return, per_agent_returns)`` where ``mean_return`` is
        the average episode return over all agents and episodes, and
        ``per_agent_returns`` lists each agent's average episode return.
    """
    agent_list = agent if isinstance(agent, list) else [agent]

    total_reward = 0
    per_agent_returns = []
    for policy in agent_list:
        # Per-agent return is the growth of the shared running total.
        total_before = total_reward
        for _ in range(n_starts):
            finished = False
            obs = env.reset()
            while not finished:
                policy_input = np.concatenate((obs, np.ones_like(obs))) if pad_state else obs
                chosen_action = policy.get_action(policy_input, deterministic=True)
                obs, step_reward, finished, _ = env.step(chosen_action)
                total_reward += step_reward
        per_agent_returns.append((total_reward - total_before) / n_starts)

    return total_reward / (n_starts * len(agent_list)), per_agent_returns


class LinearNextStateDeltaPredictor:
    """Linear regressor that is refit online from recorded (input, delta) pairs.

    Samples are accumulated with :meth:`record`; every ``update_interval``
    samples a fresh ``LinearRegression`` is fitted on everything recorded so
    far. An optional held-out set can be supplied to print the R^2 score after
    each refit.
    """

    def __init__(self, k, test_X=None, test_y=None):
        """Create an empty predictor.

        Args:
            k: minimum number of recorded samples before :meth:`predict` may
                be called.
            test_X: optional held-out inputs used only for the progress print.
            test_y: optional held-out targets used only for the progress print.
        """
        self.update_interval = 10  # refit after every this many new samples
        self.k = k
        self.X = []
        self.y = []
        self.num_samples = 0
        self.reg = None  # stays None until the first refit

        self.test_X = test_X
        self.test_y = test_y

    def _require_fitted(self):
        """Raise a descriptive error while no regression has been fitted yet."""
        if self.reg is None:
            raise RuntimeError(
                f'No regression fitted yet: {self.num_samples} samples recorded, '
                f'the first fit happens after {self.update_interval}.'
            )

    def record(self, state, delta):
        """Store one sample and refit the regression on every interval boundary.

        Args:
            state: regression input for this sample.
            delta: regression target for this sample.
        """
        self.X.append(state)
        self.y.append(delta)
        self.num_samples += 1

        if self.num_samples % self.update_interval == 0:
            X = np.array(self.X)
            y = np.array(self.y)
            self.reg = linear_model.LinearRegression().fit(X, y)
            # The held-out set is optional (defaults to None); only report a
            # score when one was actually provided.
            if self.test_X is not None and self.test_y is not None:
                print(f'Test R2 @ {self.num_samples} samples: {self.reg.score(self.test_X, self.test_y)}')

    def predict(self, state):
        """Predict the delta for a single input.

        Args:
            state: array that is reshaped to one row before prediction.

        Returns:
            The regression output for the single-row input.

        Raises:
            RuntimeError: if fewer than ``k`` samples have been recorded, or
                if no regression has been fitted yet.
        """
        if self.num_samples < self.k:
            raise RuntimeError(
                f'predict() needs at least k={self.k} recorded samples, '
                f'got {self.num_samples}.'
            )
        self._require_fitted()
        return self.reg.predict(state.reshape(1, -1))

    def test_r2(self, X, y):
        """Return the current regression's ``score`` on ``(X, y)``.

        Raises:
            RuntimeError: if no regression has been fitted yet.
        """
        self._require_fitted()
        return self.reg.score(X, y)


def _collect_context_test_set(env, agent, n_episodes=2):
    """Roll out ``agent`` with an all-ones context and return (states, deltas) lists."""
    features = []
    targets = []
    for _ in range(n_episodes):
        done = False
        state = env.reset()
        while not done:
            action = agent.get_action(np.concatenate((state, np.ones_like(state))), deterministic=False)
            nextstate, _, done, _ = env.step(action)
            delta = nextstate - state
            # Drop a leading axis of length one before storing the sample.
            if len(state) == 1:
                state = state[0]
            if len(delta) == 1:
                delta = delta[0]
            features.append(state)
            targets.append(delta)
            state = nextstate
    return features, targets


def _clip_division_factor(division_factor, clip_type, clip_lb, clip_ub, soft_clip_scale):
    """Bound the ratio to [clip_lb, clip_ub], by hard clipping or by a linear rescale."""
    if clip_type == 'hard_clip':
        return np.clip(division_factor, clip_lb, clip_ub)
    # Soft clip: clip to [1 - scale, 1 + scale], then map that interval
    # linearly onto [clip_lb, clip_ub].
    division_factor = np.clip(division_factor, 1 - soft_clip_scale, 1 + soft_clip_scale)
    return (clip_ub - clip_lb) / (
            2 * soft_clip_scale) * division_factor + clip_lb - 0.5 * (clip_ub - clip_lb) * (
                   1 - soft_clip_scale) / soft_clip_scale


def evaluate_agent_linear(env, agent, k, n_starts=1, pad_state=False, clip_lb=0.95, clip_ub=1.05, model=None,
                          linear_model_type='state-action', clip_type='hard_clip',
                          soft_clip_scale=0.5):
    """Evaluate context-conditioned agents whose context is estimated online.

    For every agent, two stochastic episodes with an all-ones context are
    first collected as a held-out set for the linear predictor. Then
    ``n_starts`` evaluation episodes are run. During an episode a
    ``LinearNextStateDeltaPredictor`` is fed the observed transitions; once
    more than ``k`` steps have elapsed and the predictor has been fitted, the
    context becomes the clipped element-wise ratio between the learned model's
    predicted state delta and the linear predictor's delta. Before that the
    context is a vector of ones.

    Args:
        env: environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        agent: a single agent or a list of agents with
            ``get_action(state, deterministic=...)``; each receives the
            observation concatenated with the context.
        k: number of warm-up steps per episode that use the all-ones context.
        n_starts: number of evaluation episodes per agent.
        pad_state: must be true; the agents always receive a context vector.
        clip_lb: lower bound of the context values after clipping.
        clip_ub: upper bound of the context values after clipping.
        model: object whose ``model.predict_state(states, actions)`` returns a
            pair whose first item is used as the predicted next state
            (assumed to have a leading batch axis of one; confirm against the
            model implementation).
        linear_model_type: only ``'state'`` style predictors are implemented;
            ``'state-action'`` is rejected.
        clip_type: ``'hard_clip'`` for plain clipping; any other value selects
            the soft (linear rescale) clip.
        soft_clip_scale: half-width around 1 used by the soft clip.

    Returns:
        A tuple ``(mean_return, per_agent_returns)`` with the average episode
        return over all agents and episodes and each agent's average return.

    Raises:
        ValueError: if ``pad_state`` is false.
        NotImplementedError: if ``linear_model_type`` is ``'state-action'``.
    """
    # Fail fast on unsupported configurations instead of crashing mid-rollout
    # (previously an unbound ``context`` or a bare ``Exception`` inside the loop).
    if not pad_state:
        raise ValueError('evaluate_agent_linear requires pad_state=True: the agent input is always '
                         'the state concatenated with a context vector.')
    if linear_model_type == 'state-action':
        raise NotImplementedError("linear_model_type='state-action' is not supported for online "
                                  "context estimation.")

    agents = agent if isinstance(agent, list) else [agent]
    reward_sum = 0
    agent_indivs = []
    for current_agent in agents:
        reward_before = reward_sum

        # Held-out transitions used by the predictor for its R^2 progress print.
        Xt, yt = _collect_context_test_set(env, current_agent)

        # Now the evaluation starts.
        for _ in range(n_starts):
            pred = LinearNextStateDeltaPredictor(test_X=Xt, test_y=yt, k=k)
            done = False
            state = env.reset()
            t = 0
            while not done:
                # Also wait for the first fit so a small k cannot hit an
                # unfitted predictor.
                if t > k and pred.reg is not None:
                    delta = pred.predict(state)
                    # Probe action chosen under the neutral (all-ones) context.
                    action = current_agent.get_action(np.concatenate((state, np.ones_like(state))),
                                                      deterministic=True)

                    nsm, _ = model.model.predict_state(state.reshape(1, -1), action.reshape(1, -1))
                    model_delta = nsm - state
                    # NOTE(review): a zero predicted delta gives inf/NaN here; inf is
                    # bounded by the clip below, NaN would pass through to the policy.
                    division_factor = model_delta / delta

                    if len(division_factor) == 1:
                        division_factor = division_factor[0]

                    context = _clip_division_factor(division_factor, clip_type, clip_lb, clip_ub,
                                                    soft_clip_scale)
                else:
                    context = np.ones_like(state)
                if len(context) == 1:
                    context = context[0]

                t += 1

                action = current_agent.get_action(np.concatenate((state, context)), deterministic=True)
                nextstate, reward, done, _ = env.step(action)
                reward_sum += reward

                pred.record(state, nextstate - state)

                state = nextstate

        agent_indivs.append((reward_sum - reward_before) / n_starts)

    return reward_sum / (n_starts * len(agents)), agent_indivs


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='hopper-medium-replay-v0')
    parser.add_argument('--seed', type=int, default=15)
    args = parser.parse_args()

    print(f'Starting on env={args.env_name}.')

    # Torch RNG
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Python RNG
    np.random.seed(args.seed)
    random.seed(args.seed)

    save_path = 'image_icml_rebuttal_hopper_500/'
    check_or_make_folder(save_path)
    save_path = save_path + args.env_name + '/'
    check_or_make_folder(save_path)

    e_dict = dict[args.env_name]

    model = MujocoModelEnv(e_dict)
    model.convert_filter_to_torch()

    do_base = False
    do_context1s = False
    do_linear_adapt = True
    
    k_val = 500

    # BASE
    base_policies = []
    for path in e_dict['base_mopo_policy']:
        policy_dict = torch.load(path)
        agent_HC_MOPO = SAC_Agent(args.seed, model.real_env.observation_space.shape[0],
                                  model.real_env.action_space.shape[0])
        agent_HC_MOPO.log_alpha = policy_dict['log_alpha_state_dict']
        agent_HC_MOPO.policy.load_state_dict(policy_dict['policy_state_dict'])
        agent_HC_MOPO.target_q_funcs.load_state_dict(policy_dict['target_double_q_state_dict'])
        agent_HC_MOPO.q_funcs.load_state_dict(policy_dict['double_q_state_dict'])
        base_policies.append(agent_HC_MOPO)

    # CONTEXT
    context_policies = []
    for path in e_dict['context_mopo_policy']:
        policy_dict = torch.load(path)
        agent_HC_MOPO = SAC_Agent(args.seed, model.real_env.observation_space.shape[0] * 2,
                                  model.real_env.action_space.shape[0])
        agent_HC_MOPO.log_alpha = policy_dict['log_alpha_state_dict']
        agent_HC_MOPO.policy.load_state_dict(policy_dict['policy_state_dict'])
        agent_HC_MOPO.target_q_funcs.load_state_dict(policy_dict['target_double_q_state_dict'])
        agent_HC_MOPO.q_funcs.load_state_dict(policy_dict['double_q_state_dict'])
        context_policies.append(agent_HC_MOPO)

    if 'halfcheetah' in args.env_name:
        settings = [0.25, 0.5, 0.75, 1.0, 1.25, 1.50, 1.75]
    elif 'walker' in args.env_name:
        settings = [0.5, 0.75, 1.0, 1.25, 1.50]
    elif 'hopper' in args.env_name:
        settings = [0.5, 0.75, 1.0, 1.25, 1.50]
    else:
        raise Exception

    n_starts = 1

    if do_base:
        # BASE RUNS
        for idx, bp in enumerate(base_policies):
            results = np.zeros((len(settings), len(settings)))
            for i, s1 in enumerate(settings):
                for j, s2 in enumerate(settings):
                    if 'halfcheetah' in args.env_name:
                        env = TimeLimit(HalfCheetahEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    elif 'walker' in args.env_name:
                        env = TimeLimit(Walker2dEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    elif 'hopper' in args.env_name:
                        env = TimeLimit(HopperEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    else:
                        raise Exception
                    p = evaluate_agent(env, [bp], n_starts=n_starts)[0]
                    results[i, j] = p

            results = pd.DataFrame(results, columns=settings, index=settings)
            fig = plt.figure(figsize=(11, 10))
            ax = plt.axes()
            sns.heatmap(results, annot=True, ax=ax)
            ax.set_title('Performance')
            ax.set_ylabel('Mass Scale')
            ax.set_xlabel('Damping Scale')
            plt.savefig(save_path + f'base_mopo_{int(results.mean().mean())}_{idx}.png')
            results.to_csv(save_path + f'base_mopo_{int(results.mean().mean())}_{idx}.csv')

    if do_context1s:
        # CONTEXT OF 1s
        for idx, cp in enumerate(context_policies):
            results_RAD_delta_base = np.zeros((len(settings), len(settings)))
            for i, s1 in enumerate(settings):
                for j, s2 in enumerate(settings):
                    if 'halfcheetah' in args.env_name:
                        env = TimeLimit(HalfCheetahEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    elif 'walker' in args.env_name:
                        env = TimeLimit(Walker2dEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    elif 'hopper' in args.env_name:
                        env = TimeLimit(HopperEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    else:
                        raise Exception
                    p = evaluate_agent(env, [cp], n_starts=n_starts, pad_state=True)[0]
                    results_RAD_delta_base[i, j] = p

            results_RAD_delta_base = pd.DataFrame(results_RAD_delta_base, columns=settings, index=settings)
            fig = plt.figure(figsize=(11, 10))
            ax = plt.axes()
            sns.heatmap(results_RAD_delta_base, annot=True, ax=ax)
            ax.set_title('Performance')
            ax.set_ylabel('Mass Scale')
            ax.set_xlabel('Damping Scale')
            plt.savefig(save_path + f'context_1s_{int(results_RAD_delta_base.mean().mean())}_{idx}.png')
            results_RAD_delta_base.to_csv(
                save_path + f'context_1s_{int(results_RAD_delta_base.mean().mean())}_{idx}.csv')

    if do_linear_adapt:
        # LINEAR FUN (1: HARD 07)
        for idx, cp in enumerate(context_policies):
            results_RAD_delta_linear07_hard = np.zeros((len(settings), len(settings)))
            for i, s1 in enumerate(settings):
                for j, s2 in enumerate(settings):
                    print(s1, s2)
                    if 'halfcheetah' in args.env_name:
                        env = TimeLimit(HalfCheetahEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    elif 'walker' in args.env_name:
                        env = TimeLimit(Walker2dEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    elif 'hopper' in args.env_name:
                        env = TimeLimit(HopperEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    else:
                        raise Exception
                    p = evaluate_agent_linear(env,
                                              [cp],
                                              n_starts=n_starts,
                                              pad_state=True,
                                              clip_lb=0.93,
                                              clip_ub=1.07,
                                              model=model,
                                              linear_model_type='state',
                                              clip_type='hard_clip',
                                              k=k_val,
                                              )[0]
                    results_RAD_delta_linear07_hard[i, j] = p

            results_RAD_delta_linear07_hard = pd.DataFrame(results_RAD_delta_linear07_hard, columns=settings,
                                                           index=settings)
            fig = plt.figure(figsize=(11, 10))
            ax = plt.axes()
            sns.heatmap(results_RAD_delta_linear07_hard, annot=True, ax=ax)
            ax.set_title('Performance')
            ax.set_ylabel('Mass Scale')
            ax.set_xlabel('Damping Scale')
            plt.savefig(save_path + f'linear_hard_07_{int(results_RAD_delta_linear07_hard.mean().mean())}_{idx}.png')
            results_RAD_delta_linear07_hard.to_csv(
                save_path + f'linear_hard_07_{int(results_RAD_delta_linear07_hard.mean().mean())}_{idx}.csv')
        1/0
        # LINEAR FUN (2: SOFT 07)
        for idx, cp in enumerate(context_policies):
            results_RAD_delta_linear07_soft = np.zeros((len(settings), len(settings)))
            for i, s1 in enumerate(settings):
                for j, s2 in enumerate(settings):
                    if 'halfcheetah' in args.env_name:
                        env = TimeLimit(HalfCheetahEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    elif 'walker' in args.env_name:
                        env = TimeLimit(Walker2dEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    elif 'hopper' in args.env_name:
                        env = TimeLimit(HopperEnv(mass_scale_set=[s1], damping_scale_set=[s2]),
                                        max_episode_steps=1000)
                    else:
                        raise Exception
                    p = evaluate_agent_linear(env,
                                              [cp],
                                              n_starts=n_starts,
                                              pad_state=True,
                                              clip_lb=0.93,
                                              clip_ub=1.07,
                                              model=model,
                                              linear_model_type='state',
                                              clip_type='soft_clip',
                                              k=k_val,
                                              )[0]
                    results_RAD_delta_linear07_soft[i, j] = p

            results_RAD_delta_linear07_soft = pd.DataFrame(results_RAD_delta_linear07_soft, columns=settings,
                                                           index=settings)
            fig = plt.figure(figsize=(11, 10))
            ax = plt.axes()
            sns.heatmap(results_RAD_delta_linear07_soft, annot=True, ax=ax)
            ax.set_title('Performance')
            ax.set_ylabel('Mass Scale')
            ax.set_xlabel('Damping Scale')
            plt.savefig(save_path + f'linear_soft_07_{int(results_RAD_delta_linear07_soft.mean().mean())}_{idx}.png')
            results_RAD_delta_linear07_soft.to_csv(
                save_path + f'linear_soft_07_{int(results_RAD_delta_linear07_soft.mean().mean())}_{idx}.csv')

        ###

        # LINEAR FUN (3: HARD 05)
        # For every context policy, sweep the (mass scale, damping scale) grid defined by
        # `settings`, score the policy with evaluate_agent_linear using clip_type='hard_clip'
        # and clip_lb=0.95 / clip_ub=1.05, then save the grid as a heatmap PNG and a CSV.
        for idx, cp in enumerate(context_policies):
            # Rows are indexed by the mass scale (s1), columns by the damping scale (s2).
            results_RAD_delta_linear05_hard = np.zeros((len(settings), len(settings)))
            for i, s1 in enumerate(settings):
                for j, s2 in enumerate(settings):
                    # Select the modified environment class that matches args.env_name.
                    if 'halfcheetah' in args.env_name:
                        env_cls = HalfCheetahEnv
                    elif 'walker' in args.env_name:
                        env_cls = Walker2dEnv
                    elif 'hopper' in args.env_name:
                        env_cls = HopperEnv
                    else:
                        raise ValueError(f'Unsupported environment: {args.env_name}')
                    env = TimeLimit(env_cls(mass_scale_set=[s1], damping_scale_set=[s2]),
                                    max_episode_steps=1000)
                    # Only the first element of the returned value is stored as the score.
                    p = evaluate_agent_linear(env,
                                              [cp],
                                              n_starts=n_starts,
                                              pad_state=True,
                                              clip_lb=0.95,
                                              clip_ub=1.05,
                                              model=model,
                                              linear_model_type='state',
                                              clip_type='hard_clip',
                                              k=k_val,
                                              )[0]
                    results_RAD_delta_linear05_hard[i, j] = p

            results_RAD_delta_linear05_hard = pd.DataFrame(results_RAD_delta_linear05_hard, columns=settings,
                                                           index=settings)
            # The truncated grid mean is embedded in both output filenames; compute it once.
            mean_tag = int(results_RAD_delta_linear05_hard.mean().mean())
            fig = plt.figure(figsize=(11, 10))
            ax = plt.axes()
            sns.heatmap(results_RAD_delta_linear05_hard, annot=True, ax=ax)
            ax.set_title('Performance')
            ax.set_ylabel('Mass Scale')
            ax.set_xlabel('Damping Scale')
            fig.savefig(save_path + f'linear_hard_05_{mean_tag}_{idx}.png')
            # Release the figure; otherwise one figure per policy stays registered in pyplot.
            plt.close(fig)
            results_RAD_delta_linear05_hard.to_csv(
                save_path + f'linear_hard_05_{mean_tag}_{idx}.csv')

        # LINEAR FUN (4: SOFT 05)
        # For every context policy, sweep the (mass scale, damping scale) grid defined by
        # `settings`, score the policy with evaluate_agent_linear using clip_type='soft_clip'
        # and clip_lb=0.95 / clip_ub=1.05, then save the grid as a heatmap PNG and a CSV.
        for idx, cp in enumerate(context_policies):
            # Rows are indexed by the mass scale (s1), columns by the damping scale (s2).
            results_RAD_delta_linear05_soft = np.zeros((len(settings), len(settings)))
            for i, s1 in enumerate(settings):
                for j, s2 in enumerate(settings):
                    # Select the modified environment class that matches args.env_name.
                    if 'halfcheetah' in args.env_name:
                        env_cls = HalfCheetahEnv
                    elif 'walker' in args.env_name:
                        env_cls = Walker2dEnv
                    elif 'hopper' in args.env_name:
                        env_cls = HopperEnv
                    else:
                        raise ValueError(f'Unsupported environment: {args.env_name}')
                    env = TimeLimit(env_cls(mass_scale_set=[s1], damping_scale_set=[s2]),
                                    max_episode_steps=1000)
                    # Only the first element of the returned value is stored as the score.
                    p = evaluate_agent_linear(env,
                                              [cp],
                                              n_starts=n_starts,
                                              pad_state=True,
                                              clip_lb=0.95,
                                              clip_ub=1.05,
                                              model=model,
                                              linear_model_type='state',
                                              clip_type='soft_clip',
                                              k=k_val,
                                              )[0]
                    results_RAD_delta_linear05_soft[i, j] = p

            results_RAD_delta_linear05_soft = pd.DataFrame(results_RAD_delta_linear05_soft, columns=settings,
                                                           index=settings)
            # The truncated grid mean is embedded in both output filenames; compute it once.
            mean_tag = int(results_RAD_delta_linear05_soft.mean().mean())
            fig = plt.figure(figsize=(11, 10))
            ax = plt.axes()
            sns.heatmap(results_RAD_delta_linear05_soft, annot=True, ax=ax)
            ax.set_title('Performance')
            ax.set_ylabel('Mass Scale')
            ax.set_xlabel('Damping Scale')
            fig.savefig(save_path + f'linear_soft_05_{mean_tag}_{idx}.png')
            # Release the figure; otherwise one figure per policy stays registered in pyplot.
            plt.close(fig)
            results_RAD_delta_linear05_soft.to_csv(
                save_path + f'linear_soft_05_{mean_tag}_{idx}.csv')
