import os
import imageio
import copy
import gym
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common import logger
import numpy as np
from stable_baselines3.common import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, SubprocVecEnv
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.callbacks import BaseCallback
import optimal_agents
from optimal_agents.envs import GraphDummyVecEnv, GraphSubprocVecEnv
from optimal_agents.utils.loader import get_paths, get_env, get_alg, get_policy, get_morphology
from optimal_agents.utils.loader import load_from_name, ModelParams, FINETUNES
from optimal_agents.utils.tester import eval_policy

class TrainCallback(BaseCallback):
    """
    Callback that periodically measures the mean reward, saves the best model
    seen so far, and optionally writes numbered checkpoints.

    Every ``params['eval_freq']`` calls the mean reward is obtained either by
    rolling out the current model in a freshly created environment
    (``eval_env=True``) or by averaging the last 100 episode rewards found in
    the ``Monitor`` files under ``data_dir`` (``eval_env=False``). Whenever
    that value beats the best one seen so far, the model is saved to
    ``<data_dir>/best_model``.

    Every ``params['checkpoint_freq']`` calls the model is additionally saved
    to ``<data_dir>/checkpoint_<num_timesteps>``. A falsy frequency (``0`` or
    ``None``) disables the corresponding action.

    :param params: Parameter mapping. ``'eval_freq'`` and ``'checkpoint_freq'``
      are read here; the whole object is passed to ``get_env`` when
      ``eval_env`` is set.
    :param data_dir: (str) Folder where models are saved. When ``eval_env`` is
      False it must contain the files created by the ``Monitor`` wrapper.
    :param tb_dir: (str) Tensorboard directory. Stored but not used by this class.
    :param verbose: (int) Progress messages are printed when greater than 0.
    :param eval_env: (bool) Evaluate in a new environment instead of reading
      the training monitor logs.
    """
    def __init__(self, params, data_dir, tb_dir=None, verbose=1, eval_env=False):
        super(TrainCallback, self).__init__(verbose)
        self.params = params
        self.checkpoint_freq = params['checkpoint_freq']
        self.eval_env = eval_env
        self.eval_freq = params['eval_freq']
        self.data_dir = data_dir
        self.tb_dir = tb_dir
        self.best_mean_reward = -np.inf
        self.save_path = os.path.join(data_dir, 'best_model')

    def _current_mean_reward(self):
        """Return the mean reward used to decide whether the model improved."""
        if self.eval_env:
            # Evaluate in a throwaway environment built from the stored params.
            env = get_env(self.params)
            mean_reward, _ = eval_policy(self.model, env, num_ep=100, deterministic=True, verbose=0, gif=False)
            env.close()
            del env  # make sure we free the memory
            return mean_reward

        # Otherwise fall back to the training rewards logged by Monitor.
        x, y = ts2xy(load_results(self.data_dir), 'timesteps')
        if len(x) > 0:
            # Mean training reward over the last 100 episodes
            return np.mean(y[-100:])
        # No finished episode yet: can never count as an improvement.
        return -np.inf

    def _on_step(self) -> bool:
        """
        Run the periodic evaluation and checkpointing.

        :return: (bool) Always True.
        """
        # NOTE: can add custom tensorboard callbacks here if wanted to log extra materials

        # Guard against a falsy eval_freq so the modulo cannot fail on 0/None,
        # mirroring how checkpoint_freq is handled below.
        if self.eval_freq and self.n_calls % self.eval_freq == 0:
            mean_reward = self._current_mean_reward()
            if self.verbose > 0:
                print("Num timesteps: {}".format(self.num_timesteps))
                print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(self.best_mean_reward, mean_reward))

            # New best model: remember the score and overwrite the saved best model.
            if mean_reward > self.best_mean_reward:
                self.best_mean_reward = mean_reward
                if self.verbose > 0:
                    print("Saving new best model.")
                self.model.save(self.save_path)

        if self.checkpoint_freq and self.n_calls % self.checkpoint_freq == 0:
            if self.verbose > 0:
                print("Saving Checkpoint for timestep", self.num_timesteps)
            # Join with a path separator so the checkpoint lands inside data_dir
            # rather than next to it as "<data_dir>checkpoint_<n>".
            self.model.save(os.path.join(self.data_dir, 'checkpoint_' + str(self.num_timesteps)))

        return True

def run_train(params, model=None, env=None, morphology=None, path=None, verbose=1):
    """
    Train a model according to ``params`` and save the results to disk.

    The run directory and tensorboard path come from ``get_paths``. The
    parameters, the per-process morphologies, the final model and (when
    ``params['normalize']`` is set) the environment are all written into the
    run directory. A ``TrainCallback`` additionally saves the best model and
    periodic checkpoints during learning.

    :param params: Parameter object; must support item access and ``save(dir)``.
    :param model: Existing model to continue training. When None a new one is
      built from ``get_alg`` / ``get_policy``; otherwise ``set_env`` is called
      on the given model.
    :param env: Existing environment. When None, ``params['num_proc']``
      monitored environments are created and vectorized (subprocess-based when
      ``num_proc > 4``), then wrapped in ``VecNormalize`` if
      ``params['normalize']`` is set. A caller-supplied env is used as is, and
      must provide ``save`` when ``params['normalize']`` is set.
    :param morphology: A single morphology (replicated for every process), a
      list with one morphology per process, or None to obtain one from
      ``get_morphology(params)``.
    :param path: Forwarded to ``get_paths``.
    :param verbose: (int) Progress messages are printed when greater than 0.
    :return: Tuple ``(model, data_dir)``.
    :raises NotImplementedError: If ``params['noise']`` is set.
    """
    if verbose > 0:
        print("Training Parameters: ", params)

    data_dir, tb_path = get_paths(params, path=path)
    os.makedirs(data_dir, exist_ok=True)

    # Currently saving params immediately.
    # TODO: Figure out whether params should instead be saved at a later point.
    params.save(data_dir)

    # Normalize `morphology` to a list indexed by process number.
    if morphology is None:
        morphology = [get_morphology(params)] * params['num_proc']
    elif not isinstance(morphology, list):
        morphology = [morphology] * params['num_proc']

    # Create the environment if not given
    if env is None:
        def make_env(i):
            # Might be issues with same morphology in vec env but not sure.
            single_env = get_env(params, morphology=morphology[i])
            info_keywords = tuple()  # tuple(['success',])
            # Each process logs to its own Monitor file named after its index.
            return Monitor(single_env, os.path.join(data_dir, str(i)),
                           allow_early_resets=params['early_reset'], info_keywords=info_keywords)

        # Graph-based algorithms need the graph-aware vec envs; more than four
        # processes switches from the in-process to the subprocess variant.
        use_subproc = params['num_proc'] > 4
        if params['alg'] in ("GPPO", "ModelBased"):
            vec_env_cls = GraphSubprocVecEnv if use_subproc else GraphDummyVecEnv
        else:
            vec_env_cls = SubprocVecEnv if use_subproc else DummyVecEnv

        # The outer lambda binds the current index so every factory builds its own env.
        env_fns = [(lambda n: lambda: make_env(n))(i) for i in range(params['num_proc'])]
        env = vec_env_cls(env_fns)

        if params['normalize']:
            env = VecNormalize(env)

    for i, m in enumerate(morphology):
        if m is not None:
            m.save(os.path.join(data_dir, str(i) + '.morphology.pkl'))

    # Set the seeds
    # NOTE: a falsy seed (None or 0) skips seeding entirely.
    if params['seed']:
        seed = params['seed']
        set_random_seed(seed)
        params['alg_args']['seed'] = seed

    if 'noise' in params and params['noise']:
        raise NotImplementedError("Noise not yet implemented.")

    if model is None:
        alg = get_alg(params)
        policy = get_policy(params)
        model = alg(policy, env, verbose=verbose, tensorboard_log=tb_path,
                    policy_kwargs=params['policy_args'], **params['alg_args'])
    else:
        model.set_env(env)

    if verbose > 0:
        print("\n===============================\n")
        print("TENSORBOARD PATH:", tb_path)
        print("\n===============================\n")

    callback = TrainCallback(params, data_dir, tb_path, verbose=verbose)

    model.learn(total_timesteps=params['timesteps'], log_interval=params['log_interval'],
                callback=callback)

    model.save(os.path.join(data_dir, 'final_model'))

    if params['normalize']:
        env.save(os.path.join(data_dir, 'environment.pkl'))
    env.close()
    del env

    if params['save_gif']:
        import imageio
        # `morphology` is a per-process list at this point, while get_env is
        # given a single morphology everywhere else, so render with the first.
        gif_morphology = morphology[0] if morphology else None
        test_env = get_env(params, morphology=gif_morphology)
        _, frames = eval_policy(model, test_env, num_ep=1, deterministic=True, verbose=1, gif=True, render=True)
        render_path = os.path.join(data_dir, 'final_render.gif')
        print("Saving gif to", render_path)
        # Keep every third frame for the saved GIF.
        imageio.mimsave(render_path, frames[::3], subrectangles=True, duration=0.05)
        test_env.close()

    # Return the model
    return model, data_dir

def train(params, path=None):
    """
    Entry point that routes a run to evolutionary search or to ``run_train``.

    When ``params['evo_alg']`` is set, the class of that name is looked up in
    ``optimal_agents.algs``, constructed with ``params['evo_alg_args']``, and
    its ``learn`` method is called. Otherwise ``run_train`` is called,
    optionally starting from a model loaded via ``params['finetune']``.

    :param params: Parameter object supporting item access and ``in`` checks.
    :param path: Passed to the evolutionary algorithm's ``learn`` or to ``run_train``.
    """
    # To be used for advanced routing later if many variants of training arise.
    if params['evo_alg'] is not None:
        evo_cls = vars(optimal_agents.algs)[params['evo_alg']]
        evo_runner = evo_cls(params, **params['evo_alg_args'])
        evo_runner.learn(path, params['population_size'], params['num_generations'])
        return

    model = None
    if 'finetune' in params and params['finetune'] is not None:
        finetune_path = params['finetune']
        # Names that do not start with '/' are resolved against the FINETUNES directory.
        if not finetune_path.startswith('/'):
            finetune_path = os.path.join(FINETUNES, finetune_path)
        print("Loading Model from", finetune_path)
        model, _ = load_from_name(finetune_path, best=False, load_env=False, alg_args=params["alg_args"])

    run_train(params=params, path=path, model=model)
    