import numpy as np
import torch
import time
from mpi4py import MPI

from core.qd_strategies.qd_strategy import QDStrategy
from core.utils.mpi_utils import sync_buffers, sync_networks, sync_archive, gather_global_data

# Module-wide MPI context: the world communicator, this process's rank in it,
# and the total number of processes taking part in the run.
comm = MPI.COMM_WORLD
rank, num_workers = comm.Get_rank(), comm.Get_size()


class MapElitesRL(QDStrategy):
    """MAP-Elites quality-diversity strategy with RL-trained actors over MPI.

    Each MPI process trains and evaluates one actor per generation. Ranks in
    ``[0, novelty_size)`` train their critic on novelty scores (with
    ``novelty_lr``); the remaining ranks train on the rewards stored in the
    replay buffer (with ``quality_lr``). Rank 0 owns the archive and the
    logging state.

    Attributes such as ``agent``, ``archive``, ``buffer``, ``novelty_method``,
    ``novelty_size``, ``gen`` and ``gradient_steps`` are not defined here and
    are presumably provided by ``QDStrategy.__init__`` -- TODO confirm.
    """

    def __init__(self, config):
        """Build the strategy from ``config`` and split the MPI communicator.

        Reads ``config.qd_strategy_params`` (``gradient_steps_ratio``,
        ``policy_delay``, ``buffer_32``, ``novelty_lr``, ``quality_lr``) and
        ``config.xp_params.eval_freq``.
        """
        super().__init__(config)

        # Novelty workers get color 0, quality workers color 1, so that each
        # group ends up with its own communicator.
        color = 1 if rank >= self.novelty_size else 0
        self.splitted_comm = comm.Split(color, rank)

        self.gradient_steps_ratio = config.qd_strategy_params.gradient_steps_ratio
        self.policy_delay = config.qd_strategy_params.policy_delay

        # Buffer 32/64 bits
        self.buffer_32 = config.qd_strategy_params.buffer_32

        # Sync the critic and its target over this worker's group communicator
        # (not over the whole world communicator).
        sync_networks(self.agent.critic, self.splitted_comm)
        sync_networks(self.agent.critic_target, self.splitted_comm)

        # Learning rates of the two worker groups
        self.novelty_lr = config.qd_strategy_params.novelty_lr
        self.quality_lr = config.qd_strategy_params.quality_lr

        # Grid evaluation frequency, compared against the generation counter
        self.grid_evaluation_frequency = config.xp_params.eval_freq
        self.map_elites_eval_fitness = 0

        # Per-generation data gathered on rank 0 (None on the other ranks)
        self.global_behavior = None
        self.global_behaviors = None
        self.global_from_novelty = None
        self.state_BDs = None

        self.global_q_quality = 0
        self.global_q_novelty = 0

        # Most recent archive bests (newest first) and their last evaluation
        self.best_solutions = []
        self.current_best_fitness = 0
        self.current_best_behavior = 0

    def train_one_generation(self):
        """Run one generation: train, evaluate, share results, update archive.

        Every rank must call this method, since it issues collective MPI
        operations in a fixed order. The best archived solutions are
        re-evaluated every ``grid_evaluation_frequency`` generations (the
        frequency is interpreted as a number of generations -- TODO confirm
        that ``eval_freq`` is not expressed in environment steps).
        """
        # Get one set of actor params from pop and set this worker actor params
        params_per_worker = comm.scatter(self.pop_actors_params, root=0)

        self.agent.actor.set_params(params_per_worker)
        if hasattr(self.agent, 'actor_target'):
            self.agent.actor_target.set_params(params_per_worker)

        # Train neural nets
        self._train_agent()

        # Evaluate actor and fill replay buffer
        fitness, steps, behaviors = self.evaluate_one_actor()
        self.state_BDs = behaviors  # Keep last trajectory of state BDs
        behavior = self.extract_behavior(behaviors)

        # Local results of this worker
        reward_novelty = self.novelty_method.compute_novelty_score(behavior=behavior.reshape(1, -1))
        from_novelty = rank < self.novelty_size

        # Share data among workers
        global_fitness, global_steps, global_novelty_score = gather_global_data(comm, fitness, steps, reward_novelty)
        self.global_behaviors = comm.gather(behaviors, root=0)
        self.global_behavior = comm.gather(behavior, root=0)
        self.global_from_novelty = comm.gather(from_novelty, root=0)

        # Update counters; the gradient budget computed here is the one used
        # by the next call to _train_agent.
        self.gradient_steps = int(np.max(global_steps) * self.gradient_steps_ratio)
        global_steps = int(np.sum(global_steps))
        self.total_steps += global_steps
        self.log_steps += global_steps
        self.save_model_steps += global_steps

        # Update novelty / quality fitnesses
        global_fitness = global_fitness.squeeze()
        # NOTE(review): the novelty guard is `> 1` while the quality guard below
        # keeps a single-worker group; with novelty_size == 1 the novelty
        # fitness is reported as [0]. Possibly an off-by-one -- confirm intent.
        self.fitness_pop_novelty = global_fitness[:self.novelty_size] if self.novelty_size > 1 else [0]
        self.fitness_pop_quality = global_fitness[self.novelty_size:] if self.novelty_size < num_workers else [0]
        self.fitness_pop = global_fitness
        self.novelty_score = global_novelty_score.squeeze()

        # Archive treatment: everyone shares its result, rank 0 stores them
        all_params = comm.allgather(self.agent.actor.get_params())
        all_behaviors = comm.allgather(behavior)
        all_fitnesses = comm.allgather(fitness)
        all_from_novelties = comm.allgather(from_novelty)
        if rank == 0:
            for entry in zip(all_params, all_behaviors, all_fitnesses, all_from_novelties):
                self.archive.add(*entry)

        # Novelty method update
        self.novelty_method.update()

        # Synchronize replay buffers between workers
        sync_buffers(self.buffer_per_worker, self.buffer, self.buffer_32)

        self.gen += 1

        # Keep track of the best and latest solutions from the grid (at most
        # one per worker, newest first) and evaluate them.
        if rank == 0:
            self.best_solutions.insert(0, self.archive.get_best())
            del self.best_solutions[num_workers:]

        # Evaluate once every `grid_evaluation_frequency` generations. The
        # operands used to be swapped (`frequency % gen`), which stopped all
        # evaluations as soon as gen exceeded the frequency. A zero frequency
        # keeps its previous effective meaning: evaluate every generation.
        frequency = self.grid_evaluation_frequency
        if not frequency or self.gen % frequency == 0:
            self._best_solutions_evaluation()

        if rank == 0:
            print('-----------------------------')
            print("Steps {} Max Fitness {}".format(self.total_steps, np.max(self.fitness_pop)), flush=True)
            print('-----------------------------')

    def _best_solutions_evaluation(self):
        """Evaluate the tracked best solutions, one per worker.

        Collective call: rank 0 broadcasts ``best_solutions``; worker ``r``
        evaluates entry ``r`` (or entry 0 when there are fewer entries than
        workers) without filling the replay buffer. Rank 0 then records the
        highest fitness and its behavior in ``current_best_fitness`` and
        ``current_best_behavior``. The worker's actor params are restored
        before returning.
        """
        candidates = comm.bcast(self.best_solutions, root=0)
        saved_params = self.agent.actor.get_params()

        # Workers without a dedicated candidate fall back to the newest one
        index = rank if rank < len(candidates) else 0
        self.agent.actor.set_params(candidates[index])

        fitness, steps, behavior = self.evaluate_one_actor(fill_memory=False, deterministic=True)
        behavior = self.extract_behavior(behavior)
        behaviors_eval = comm.gather(behavior, root=0)
        fitnesses_eval = comm.gather(fitness, root=0)
        if rank == 0:
            best_ind = np.argmax(fitnesses_eval)
            self.current_best_fitness = fitnesses_eval[best_ind]
            self.current_best_behavior = behaviors_eval[best_ind]

        self.agent.actor.set_params(saved_params)  # restore initial params

    def _train_agent(self):
        """Run ``gradient_steps`` critic updates with delayed actor updates.

        Quality workers (``rank >= novelty_size``) train with ``quality_lr``
        on the sampled batch as is. Novelty workers train with ``novelty_lr``
        and overwrite the batch rewards with novelty scores computed from the
        batch behaviors. Critic updates pass the group communicator with
        ``sync_grads_bool=True``; the actor is updated every ``policy_delay``
        steps.
        """
        quality_worker = rank >= self.novelty_size

        # Reset optimizer with the learning rate of this worker's group
        self.agent.reset_optimizer(lr=self.quality_lr if quality_worker else self.novelty_lr)

        for step in range(self.gradient_steps):
            batch = self.buffer.sample(batch_size=self.batch_size)

            # Novelty workers replace the environment reward with novelty rewards
            if not quality_worker and batch:
                batch["rewards"] = self.novelty_method.compute_novelty_score(batch["behaviors"])

            self.agent.update_critic(batch, sync_grads_bool=True, comm=self.splitted_comm)
            if step % self.policy_delay == 0:
                self.agent.update_actor(batch)

    def draw_generation(self):
        """Draw ``pop_size`` actor parameter sets uniformly for the next generation."""
        self.pop_actors_params = self.evolution_strategy.select_random_uniform(self.pop_size)

    def update(self):
        """No-op for this strategy."""
        pass

    def log_behavior(self):
        """Build the per-generation log matrix (rank 0 only).

        Returns:
            A 2-D array with one row per gathered worker result plus a final
            row for the last evaluated best solution. Columns, in order:
            generation, total steps, behavior, fitness, from-novelty flag and
            the fraction of non-empty archive cells.

        Raises:
            AssertionError: if called on a rank other than 0.

        NOTE(review): this appends the best-solution entry to
        ``global_behavior``, ``fitness_pop`` and ``global_from_novelty`` in
        place, so calling it twice within a generation duplicates that row.
        """
        assert rank == 0, "has to be used with the master"

        # Add the evaluated best solution as an extra, non-novelty entry
        self.global_behavior.append(np.array(self.current_best_behavior))
        self.fitness_pop = np.append(self.fitness_pop, self.current_best_fitness)
        self.global_from_novelty.append(False)

        n_rows = len(self.global_behavior)

        filled_cells = sum(cell is not None for cell in self.archive.cells)
        percentage_filled_grid = filled_cells / len(self.archive.cells)

        # Scalars are repeated so that every row carries them
        columns = (
            np.vstack([self.gen] * n_rows),
            np.vstack([self.total_steps] * n_rows),
            np.vstack(self.global_behavior),
            np.vstack(self.fitness_pop),
            np.vstack(self.global_from_novelty),
            np.vstack([percentage_filled_grid] * n_rows),
        )
        return np.hstack(columns)
