from sac import SAC_Agent
from modified_envs import HalfCheetahEnv, Walker2dEnv
from model import EnsembleGymEnv

import torch
import gym
from gym.wrappers import TimeLimit
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import d4rl
import copy
import sklearn
import random
import sklearn.linear_model as linear_model
import argparse

from utils import check_or_make_folder

# (environment name, saved-model directory) pairs, in listing order.
_MODEL_DIRS = (
    # HalfCheetah tasks
    (
        'halfcheetah-medium-replay-v0',
        './checkpoints/model_saved_weights/Model_halfcheetah-medium-replay-v0_seed5',
    ),
    (
        'halfcheetah-medium-v0',
        './data/2021_01_21_05-58-45/checkpoints/model_saved_weights/Model_halfcheetah-medium-v0_seed0_2021_01_21_05-58-50',
    ),
    (
        'halfcheetah-random-v0',
        './data/2021_01_21_05-59-09/checkpoints/model_saved_weights/Model_halfcheetah-random-v0_seed0_2021_01_21_05-59-15',
    ),
    (
        'halfcheetah-medium-expert-v0',
        './data/2021_01_21_05-59-27/checkpoints/model_saved_weights/Model_halfcheetah-medium-expert-v0_seed0_2021_01_21_05-59-34',
    ),
    # Walker2d tasks
    (
        'walker2d-medium-replay-v0',
        './data/2021_01_21_05-58-30/checkpoints/model_saved_weights/Model_walker2d-medium-replay-v0_seed0_2021_01_21_05-58-34',
    ),
    (
        'walker2d-medium-v0',
        './data/2021_01_28_09-07-02/checkpoints/model_saved_weights/Model_walker2d-medium-v0_seed0_2021_01_28_09-07-09',
    ),
    (
        'walker2d-random-v0',
        './data/2021_01_28_09-07-09/checkpoints/model_saved_weights/Model_walker2d-random-v0_seed0_2021_01_28_09-07-14',
    ),
    (
        'walker2d-medium-expert-v0',
        './data/2021_01_21_05-59-29/checkpoints/model_saved_weights/Model_walker2d-medium-expert-v0_seed0_2021_01_21_05-59-35',
    ),
)

# Per-environment configuration, keyed by environment name. Each entry holds
# the name again under 'env_name' and the saved-model path under 'model_dir'.
# NOTE(review): this name shadows the builtin `dict` for the rest of the
# module; it is kept because the script entry point looks the table up under
# this name.
dict = {
    name: {'env_name': name, 'model_dir': model_dir}
    for name, model_dir in _MODEL_DIRS
}


class MujocoModelEnv(EnsembleGymEnv):
    """Ensemble model environment built from one environment config entry.

    Creates two gym environments for the configured name, assembles the
    parameter dictionary passed to ``EnsembleGymEnv``, and optionally loads
    saved model weights.
    """

    def __init__(self, env_dict):
        """Set up the environments and parameters, then load weights if given.

        Args:
            env_dict: Mapping with an 'env_name' entry (passed to
                ``gym.make``) and a 'model_dir' entry (passed to
                ``self.model.load_model``; an empty string skips loading).
        """
        task_name = env_dict['env_name']

        # Names containing 'medium-replay' use the smaller memory sizes;
        # every other name gets the larger ones.
        # NOTE(review): presumably this tracks dataset size -- confirm.
        if 'medium-replay' in task_name:
            train_memory, val_memory = 8e5, 2e5
        else:
            train_memory, val_memory = 2000000, 500000

        env = gym.make(task_name)
        eval_env = gym.make(task_name)

        obs_size = env.observation_space.shape[0]
        act_size = env.action_space.shape[0]

        params = {
            'seed': 0,
            'env_name': task_name,
            'num_models': 7,
            'num_elites': 5,
            'reward_head': True,
            'logvar_head': True,
            'is_done_func': None,
            'train_memory': train_memory,
            'val_memory': val_memory,
            'train_val_ratio': 0.2,
            'ob_dim': obs_size,
            'ac_dim': act_size,
        }

        # Set before the base initialiser runs, as in the original ordering.
        self.ob_dim = obs_size
        self.ac_dim = act_size

        super().__init__(params, env, eval_env)

        weights_dir = env_dict['model_dir']
        if weights_dir == '':
            return

        # NOTE(review): `self.model` is not created in this class; presumably
        # the base initialiser sets it -- confirm.
        self.model.load_model(weights_dir)


if __name__ == '__main__':
    # Script entry point: parse the environment name and seed, seed the
    # random number generators, build the model environment and print it.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--env_name',
        type=str,
        default='halfcheetah-medium-replay-v0',
        # `dict` is this module's environment config table (it shadows the
        # builtin). Restricting the choices to its keys makes an unknown name
        # fail at parse time with a usage message listing the valid names,
        # instead of a bare KeyError after the start-up banner.
        choices=list(dict),
        help='name of a configured environment',
    )
    parser.add_argument('--seed', type=int, default=15)
    args = parser.parse_args()

    print(f'Starting on env={args.env_name}.')

    # Torch RNGs (CPU and CUDA devices)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # NumPy and Python standard-library RNGs
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Guaranteed to exist: argparse validated the name against the table.
    e_dict = dict[args.env_name]

    model = MujocoModelEnv(e_dict)

    print(model)
