import gym
import torch
import time
import numpy as np
from abc import ABC, abstractmethod
from mpi4py import MPI

from core.archives.archive import Archive
from core.archives.map_elites_archive import MapElitesArchive
from core.observation_wrapper import create_wrapper, behavior_func_point_maze, behavior_func_point_maze_inertia, behavior_func_ant, behavior_func_humanoid
from core.novelty.euclidian_novelty import EuclidianNovelty
from core.utils.replay_buffer import ReplayBuffer
from core.rl_agents.td3 import TD3
from core.selection.map_elites_selection import MapElitesSelection

# Process-wide MPI handles. `num_workers` is the number of MPI processes and
# `rank` is this process's index; rank 0 is used as the master in this module.
comm = MPI.COMM_WORLD
num_workers, rank = comm.Get_size(), comm.Get_rank()


class QDStrategy(ABC):
    """Base class for MPI-distributed quality-diversity (QD) strategies.

    Bundles the components shared by concrete strategies: the environment,
    the solution archive, the RL agent factory, the selection mechanism
    (created on the master, rank 0, only) and the novelty estimator.
    Subclasses implement ``train_one_generation``.
    """

    def __init__(self, config):
        """Build every component from ``config``.

        Reads, among others, ``config.xp_params`` (``device``, ``num_workers``,
        ``novelty_size``, ``seed``), ``config.buffer_size``,
        ``config.batch_size``, ``config.env_name`` and
        ``config.behavior_extraction_method``. On rank 0 an initial population
        of ``num_workers`` actor parameter sets is created; other ranks start
        with ``None``.
        """
        self.config = config
        self.device = config.xp_params.device

        # How to characterize a solution in the behavior space
        self.behavior_extraction_method = config.behavior_extraction_method

        # Initialize the replay buffers (shared one and this worker's own)
        self.buffer = ReplayBuffer(config.buffer_size)
        self.buffer_per_worker = ReplayBuffer(config.buffer_size)

        self.batch_size = config.batch_size
        # One population member per MPI worker
        self.pop_size = config.xp_params.num_workers
        self.novelty_size = config.xp_params.novelty_size

        # Initialize fitnesses
        self.fitness_pop_novelty = None
        self.fitness_pop_quality = None
        self.fitness_pop = None
        self.novelty_score = None

        # Create env
        self.env = self.create_env(config.env_name)
        # Seed env, torch and numpy from the same experiment seed
        self.env.seed(config.xp_params.seed)
        torch.manual_seed(config.xp_params.seed)
        np.random.seed(config.xp_params.seed)

        self._init_archive(config)
        self._init_agent(config)
        self._init_selection_method(config)
        self._init_novelty_method(config)

        # Initialization of the population (master owns the parameters)
        if rank == 0:
            self.pop_actors_params = [self.create_agent().actor.get_params() for _ in range(self.pop_size)]
        else:
            self.pop_actors_params = None
            # NOTE(review): only set on non-master ranks here; presumably
            # assigned by subclasses on rank 0 — confirm.
            self.critics_pop_params = None

        self.gradient_steps = 0
        self.total_steps = 0
        self.log_steps = 0
        self.save_model_steps = 0
        self.gen = 0

    def log_fitness(self):
        """Return ``(novelty, quality, combined)`` population fitnesses (master only)."""
        assert rank == 0, "has to be used with the master"
        return self.fitness_pop_novelty, self.fitness_pop_quality, self.fitness_pop

    @abstractmethod
    def train_one_generation(self):
        """Train the selected population, usually called at each generation"""
        pass

    def evaluate_one_actor(self, fill_memory=True, n_episodes=1, deterministic=False, render=False):
        """Play episode(s) with ``self.agent.actor``, filling the replay buffer if needed.

        Args:
            fill_memory: if True, each transition is added to
                ``self.buffer_per_worker`` as
                ``(obs, next_obs, action, reward, done_bool, behavior)``.
            n_episodes: number of episodes to play.
            deterministic: forwarded to ``actor.select_action``.
            render: render every step (only on rank 0).

        Returns:
            ``(mean_score, total_steps, behaviors)``. ``total_steps`` is summed
            over all episodes and ``behaviors`` stacks the per-step behaviors
            of every episode, in order.
        """
        scores = []
        steps = 0
        device = next(self.agent.actor.parameters()).device
        behaviors = []

        for _ in range(n_episodes):
            score = 0
            episode_steps = 0
            obs_raw = self.env.reset()
            obs = obs_raw["obs"] if isinstance(obs_raw, dict) else obs_raw
            done = False
            while not done:
                # get next action and act
                observation = torch.tensor(obs.reshape(1, -1), dtype=torch.float32, device=device)
                action = self.agent.actor.select_action(observation, deterministic=deterministic)
                n_obs_raw, reward, done, info = self.env.step(action)
                # The behavior entry is read unconditionally, so the env must
                # return dict observations carrying a "behavior" key.
                behavior = n_obs_raw["behavior"]
                behaviors.append(behavior)
                n_obs = n_obs_raw["obs"] if isinstance(n_obs_raw, dict) else n_obs_raw
                score += reward
                steps += 1
                episode_steps += 1
                # Reaching the episode time limit is not stored as a terminal
                # transition. Use the per-episode counter: the cumulative
                # `steps` only matched the limit during the first episode.
                done_bool = 0.0 if episode_steps == self.env._max_episode_steps else float(done)
                # adding in memory
                if fill_memory:
                    self.buffer_per_worker.add((obs, n_obs, action, reward, done_bool, behavior))
                obs = n_obs

                # render if needed
                if render and rank == 0:
                    self.env.render()

            scores.append(score)
        return np.mean(scores), steps, np.array(behaviors)

    def extract_behavior(self, behaviors):
        """Reduce a trajectory of state behaviors to the solution behavior.

        Dispatches on ``self.behavior_extraction_method``
        (``"last_behavior"`` or ``"best_behavior"``).

        Raises:
            NotImplementedError: for any other extraction method.
        """
        if self.behavior_extraction_method == "last_behavior":
            return self.extract_last_behavior(behaviors)
        if self.behavior_extraction_method == "best_behavior":
            return self.extract_best_behavior(behaviors)
        raise NotImplementedError(
            "Unknown behavior extraction method: {}".format(self.behavior_extraction_method)
        )

    def extract_best_behavior(self, behaviors):
        """Return the trajectory behavior with the highest novelty score."""
        novelty_scores = self.novelty_method.compute_novelty_score(behaviors)
        return behaviors[np.argmax(novelty_scores)]

    def extract_last_behavior(self, behaviors):
        """Return the last behavior of the trajectory."""
        return behaviors[-1]

    def create_env(self, name):
        """Create the environment ``name``.

        Known QD environments are wrapped with their behavior function via
        ``create_wrapper``; any other name falls back to ``gym.make``.
        """
        # Matched by substring, in this order.
        substring_matched = (
            ("PointMaze-v0", behavior_func_point_maze),
            ("PointMazeMujoco-v0", behavior_func_point_maze),
            ("PointMazeInertiaMujoco-v0", behavior_func_point_maze_inertia),
        )
        # Matched by exact name.
        exact_matched = {
            "AntMaze-v0": behavior_func_ant,
            "AntTrap-v0": behavior_func_ant,
            # NOTE(review): Hopper reuses the ant behavior function — confirm intended.
            "HopperQD-v0": behavior_func_ant,
            "HumanoidTrapMAES-v1": behavior_func_humanoid,
        }
        for env_id, behavior_func in substring_matched:
            if env_id in name:
                return create_wrapper(behavior_func)(env_id)
        if name in exact_matched:
            return create_wrapper(exact_matched[name])(name)
        return gym.make(name)

    def _init_archive(self, config):
        """Create the archive containing solutions"""
        archive_params = config.archive_params
        if archive_params.archive == "map_elites":
            self.archive = MapElitesArchive(config.map_elites_params, capacity=archive_params.capacity, k=archive_params.knn, threshold=archive_params.threshold)
        else:
            self.archive = Archive(capacity=archive_params.capacity, k=archive_params.knn)

    def _init_selection_method(self, config):
        """Initialize the selection mechanism (master only)."""
        if rank == 0:
            if "MapElites" == config.evolution_strategy_name:
                self.evolution_strategy = MapElitesSelection(self.archive, config.map_elites_params)
            else:
                raise NotImplementedError

    def _init_agent(self, config):
        """Initialize the RL agent and the ``self.create_agent`` factory."""
        assert config.agent_name in ["TD3", "SAC", "DiscreteSAC"], "{} agent not implemented yet.".format(config.agent_name)
        # NOTE(review): SAC and DiscreteSAC are not in the visible imports;
        # selecting them raises NameError unless defined elsewhere — confirm.
        if config.agent_name == "TD3":
            agent = TD3
        elif config.agent_name == "SAC":
            agent = SAC
        elif config.agent_name == "DiscreteSAC":
            agent = DiscreteSAC
        config.rl_agent_params.device = self.device
        self.create_agent = lambda: agent(self.env.observation_space, self.env.action_space, config.rl_agent_params)
        self.agent = self.create_agent()

    def _init_novelty_method(self, config):
        """Initialize the method to compute novelty scores"""
        self.novelty_method_name = config.novelty_method_params.name
        assert self.novelty_method_name in ["euclidian_distance", "RND"], "{} novelty method not implemented yet.".format(self.novelty_method_name)
        if "RND" in self.novelty_method_name:
            # NOTE(review): RndNovelty is not in the visible imports — confirm.
            config.obs_space = self.env.observation_space
            self.novelty_method = RndNovelty(self, config)
        elif "euclidian_distance" in self.novelty_method_name:
            self.novelty_method = EuclidianNovelty(self, config)
