import os
import optimal_agents
import json
from datetime import date
import stable_baselines3

# Artifact directories that sit next to the installed ``optimal_agents`` package
# directory (the parent of the package dir + a fixed sub-directory):
#   BASE      - default root for saved runs / params (used by ModelParams.load,
#               get_paths and load for relative paths)
#   LOGS      - root for TensorBoard log directories (see get_paths)
#   FINETUNES - root searched for saved morphologies (see get_morphology)
BASE, LOGS, FINETUNES = (
    os.path.dirname(os.path.dirname(optimal_agents.__file__)) + sub_dir
    for sub_dir in ('/data', '/tb_logs', '/finetunes')
)

class ModelParams(dict):
    """Dictionary holding every setting needed to build, train and reload a run.

    Because it is a plain ``dict`` subclass it serialises directly to JSON
    (``save``) and is restored with ``load``; keys missing from an older saved
    file keep the defaults set in ``__init__``.
    """

    def __init__(self, env : str, alg : str, wrapper=None):
        super().__init__()
        # Construction Specification
        self['alg'] = alg
        self['env'] = env
        self['env_wrapper'] = wrapper
        self['morphology'] = None
        self['arena'] = None
        self['policy'] = None
        self['use_her'] = False
        # Arg Dicts
        self['env_args'] = dict()
        self['env_wrapper_args'] = dict()
        self['alg_args'] = dict()
        self['policy_args'] = dict()
        self['morphology_args'] = dict()
        self['node_args'] = dict()
        self['mutation_args'] = dict()
        self['arena_args'] = dict()
        # Env Wrapper Arguments
        self['early_reset'] = True
        self['normalize'] = False
        # Training Args
        self['seed'] = None
        self['timesteps'] = 250000
        # Logistical Args
        self['log_interval'] = 20
        self['name'] = None
        self['tensorboard'] = None
        self['num_proc'] = 1 # Default to single process
        self['eval_freq'] = 100000
        self['checkpoint_freq'] = None
        self['save_gif'] = False
        # Evolution Args
        self['evo_alg'] = None
        self['evo_alg_args'] = dict()

    def get_save_name(self) -> str:
        """Return the base name used for this run's directory.

        Uses ``name`` when it is set, otherwise ``<env>[_<env_wrapper>]_<alg>``.
        In both cases ``_s<seed>`` is appended when a seed is configured.
        """
        if self['name']:
            name = self['name']
        else:
            name = self['env'] + ('_' + self['env_wrapper'] if self['env_wrapper'] else "") + '_' + self['alg']
        if self['seed'] is not None:
            name += '_s' + str(self['seed'])
        return name

    def save(self, path : str):
        """Write the parameters to disk as indented JSON.

        ``path`` is used as the file name when it ends in ``.json``; otherwise it
        is treated as a directory and ``params.json`` is written inside it.
        """
        if not path.endswith(".json"):
            path = os.path.join(path, 'params.json')
        with open(path, 'w') as fp:
            json.dump(self, fp, indent=4)

    @classmethod
    def load(cls, path):
        """Read parameters previously written by ``save``.

        Args:
            path: a params JSON file, or a directory containing ``params.json``.
                Relative paths are resolved against ``BASE``.

        Returns:
            A new instance with defaults overridden by the saved values.

        Raises:
            ValueError: if no params file exists at ``path`` (including a
                directory that lacks ``params.json``).
        """
        # os.path.isabs instead of startswith('/') so this also works on Windows.
        if not os.path.isabs(path):
            path = os.path.join(BASE, path)
        if os.path.isdir(path):
            # A run directory: its parameters must live in params.json. Without
            # this check a directory would slip through to open() and fail with
            # IsADirectoryError instead of the documented ValueError.
            params_file = os.path.join(path, 'params.json')
            if not os.path.isfile(params_file):
                raise ValueError("Params file not found in specified save directory:" + str(path))
            path = params_file
        elif not os.path.isfile(path):
            raise ValueError("Params file not found in specified save directory:" + str(path))
        with open(path, 'r') as fp:
            data = json.load(fp)
        params = cls(data['env'], data['alg'])
        params.update(data)
        return params

def get_alg(params: ModelParams):
    """Resolve the algorithm class named by ``params['alg']``.

    Project algorithms in ``optimal_agents.algs`` take precedence; if the name
    is not found there it is looked up in the ``stable_baselines3`` namespace.

    Raises:
        KeyError: if neither namespace defines the name.
    """
    alg_name = params['alg']
    try:
        return vars(optimal_agents.algs)[alg_name]
    except (KeyError, AttributeError):
        # KeyError: not a project algorithm. AttributeError: the ``algs``
        # submodule is not loaded on the package. Anything else (including
        # KeyboardInterrupt, which the old bare ``except:`` swallowed) propagates.
        return vars(stable_baselines3)[alg_name]

def get_env(params: ModelParams, morphology=None, apply_wrapper=True):
    """Build the environment described by ``params``.

    If ``params['env']`` names a class in ``optimal_agents.envs``:
      * the class is instantiated with ``morphology``, the arena from
        ``get_arena`` and ``params['env_args']``;
      * when ``apply_wrapper`` is true and ``params['env_wrapper']`` is set, the
        result is wrapped by the same-named class in ``optimal_agents.envs``.

    Otherwise:
      * the name is passed to ``gym.make`` (``morphology`` is ignored);
      * the optional wrapper is looked up in ``gym.wrappers``.

    Only a failed *lookup* of the env class triggers the gym fallback. Errors
    raised while constructing or wrapping a project env now propagate, instead
    of being masked by a confusing ``gym.make`` failure.
    """
    env_name = params['env']
    wrapper_name = params['env_wrapper'] if apply_wrapper else None
    try:
        env_cls = vars(optimal_agents.envs)[env_name]
    except (KeyError, AttributeError):
        env_cls = None

    if env_cls is None:
        # If we don't find the env in the project, we assume it is a gym environment.
        import gym
        env = gym.make(env_name)
        if wrapper_name:
            env = vars(gym.wrappers)[wrapper_name](env, **params['env_wrapper_args'])
        return env

    arena = get_arena(params)
    env = env_cls(morphology, arena=arena, **params['env_args'])
    if wrapper_name:
        env = vars(optimal_agents.envs)[wrapper_name](env, **params['env_wrapper_args'])
    return env

def get_morphology(params: ModelParams):
    """Return the morphology configured in ``params``, or ``None`` if unset.

    ``params['morphology']`` is first tried as a saved morphology under
    ``FINETUNES`` via ``Morphology.load``. If that fails, it is treated as the
    name of a class in ``optimal_agents.morphology``. That class is
    instantiated with ``params['morphology_args']`` plus the ``mutation_args``
    and ``node_args`` dicts.
    """
    morphology_name = params['morphology']
    if morphology_name is None:
        return None
    try:
        # Try to load a previously saved morphology.
        morphology = optimal_agents.morphology.Morphology.load(os.path.join(FINETUNES, morphology_name))
    except Exception:
        # The exceptions Morphology.load raises aren't visible here, so any
        # ordinary failure means "not a saved morphology". Unlike the previous
        # bare ``except:``, KeyboardInterrupt / SystemExit are no longer swallowed.
        morphology = vars(optimal_agents.morphology)[morphology_name](
            **params['morphology_args'],
            mutation_kwargs=params['mutation_args'],
            node_kwargs=params['node_args'],
        )
    return morphology

def get_arena(params: ModelParams):
    """Instantiate the arena named by ``params['arena']``.

    Returns ``None`` when no arena is configured. Otherwise the class of that
    name in ``optimal_agents.morphology.arenas`` is called with
    ``params['arena_args']``.
    """
    requested = params['arena']
    if requested is not None:
        arena_cls = vars(optimal_agents.morphology.arenas)[requested]
        return arena_cls(**params['arena_args'])
    return None

def get_policy(params: ModelParams):
    """Resolve the policy class named by ``params['policy']`` (default ``MlpPolicy``).

    Project policies in ``optimal_agents.policies`` take precedence. Otherwise
    the name is looked up in the stable-baselines3 policies module that matches
    the algorithm name in ``params['alg']``. ``stable_baselines3.common.policies``
    is used when no algorithm matches.

    Raises:
        KeyError: if the selected module does not define the policy name.
    """
    policy_name = params['policy']
    if policy_name is None:
        policy_name = 'MlpPolicy'
    try:
        return vars(optimal_agents.policies)[policy_name]
    except (KeyError, AttributeError):
        # Not a project policy (or the policies submodule isn't loaded):
        # fall through to stable-baselines3.
        pass

    alg_name = params['alg']
    if 'SAC' in alg_name:
        search_location = stable_baselines3.sac.policies
    elif 'DDPG' in alg_name:
        search_location = stable_baselines3.ddpg.policies
    elif 'DQN' in alg_name:
        # SB3 ships DQN as ``stable_baselines3.dqn``; ``deepq`` was the
        # stable-baselines (v2) name and raised AttributeError here.
        search_location = stable_baselines3.dqn.policies
    elif 'TD3' in alg_name:
        search_location = stable_baselines3.td3.policies
    elif 'PPO' in alg_name:
        search_location = stable_baselines3.ppo.policies
    else:
        search_location = stable_baselines3.common.policies
    return vars(search_location)[policy_name]
    
def get_paths(params: ModelParams, path=None):
    """Choose a fresh run directory and the matching TensorBoard log directory.

    The run directory is ``<root>/<save_name>_<k>``:
      * ``root`` is ``path``, defaulting to ``BASE``;
      * ``save_name`` comes from ``params.get_save_name()``;
      * ``k`` is one more than the largest integer run index already present
        in ``root`` for that save name, or 0 if there is none.

    Nothing is created on disk.

    Returns:
        ``(save_path, tb_path)``. ``tb_path`` is ``LOGS/<save_name>_<k>`` when
        ``params['tensorboard']`` is truthy, else ``None``.
    """
    date_dir = BASE if path is None else path
    save_name = params.get_save_name()

    run_indices = []
    if os.path.isdir(date_dir):
        for entry in os.listdir(date_dir):
            prefix, sep, suffix = entry.rpartition('_')
            # Only "<save_name>_<integer>" entries are earlier runs.
            # - Parse the whole suffix: the old ``int(name[-1])`` read only
            #   the last digit, so after run 10 the index wrapped and an
            #   existing directory was reused.
            # - Skip non-numeric suffixes such as "_s0" entirely.
            if sep and prefix == save_name and suffix.isdecimal():
                run_indices.append(int(suffix))
    next_index = max(run_indices) + 1 if run_indices else 0
    save_name += '_' + str(next_index)

    save_path = os.path.join(date_dir, save_name)
    tb_path = os.path.join(LOGS, save_name) if params['tensorboard'] else None
    return save_path, tb_path

def load_from_name(path, best=False, load_env=True, ret_params=False, alg_args=None, morphology_index=0):
    """Load a saved run's model (and optionally its env) from its directory.

    Relative ``path`` values are resolved against ``BASE``. The run's
    ``ModelParams`` are read from that directory and passed on to ``load``.

    Returns:
        ``(model, env)`` from ``load``, extended to ``(model, env, params)``
        when ``ret_params`` is true.
    """
    run_dir = path if path.startswith('/') else os.path.join(BASE, path)
    params = ModelParams.load(run_dir)
    loaded = load(run_dir, params, best=best, load_env=load_env,
                  alg_args=alg_args, morphology_index=morphology_index)
    return loaded + (params,) if ret_params else loaded

def load(path: str, params : ModelParams, best=False, load_env=True, alg_args=None, morphology_index=0):
    """Load a trained model, and optionally its environment, from a run directory.

    Model choice:
      * ``best_model.zip`` is used when ``best`` is true, or when it is the only
        checkpoint present;
      * otherwise ``final_model.zip`` is used.

    ``alg_args`` defaults to ``params['alg_args']`` and is forwarded to the
    algorithm's ``load``. When ``params['use_her']`` is set,
    ``optimal_agents.algs.HER`` is used instead of ``get_alg``.

    When ``load_env`` is true:
      * ``<morphology_index>.morphology.pkl`` is loaded from the run directory
        if present, and a warning is printed otherwise;
      * the env is rebuilt via ``get_env``.

    Returns:
        ``(model, env)``, with ``env`` being ``None`` when ``load_env`` is false.

    Raises:
        ValueError: if the directory holds neither checkpoint.
    """
    if not path.startswith('/'):
        path = os.path.join(BASE, path)
    files = os.listdir(path)

    has_best = 'best_model.zip' in files
    has_final = 'final_model.zip' in files
    if has_best and (best or not has_final):
        model_path = path + '/best_model.zip'
    elif has_final:
        model_path = path + '/final_model.zip'
    else:
        raise ValueError("Cannot find a model for name: " + path)

    alg = optimal_agents.algs.HER if params['use_her'] else get_alg(params)
    model = alg.load(model_path, **(params['alg_args'] if alg_args is None else alg_args))

    if not load_env:
        return model, None

    morphology_file_name = str(morphology_index) + ".morphology.pkl"
    morphology = None
    if morphology_file_name in files:
        from optimal_agents.morphology import Morphology
        morphology = Morphology.load(os.path.join(path, morphology_file_name))
        print("Loaded", morphology_file_name)
    else:
        print("WARNING: Could not find morphology.")
    return model, get_env(params, morphology=morphology)
