import numpy as np
import time
from mpi4py import MPI
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import distance
from core.novelty.novelty import Novelty
from core.utils.miscellaneous import euclidian_distances, filter_close_points


# MPI context shared by everything in this module.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # index of this process within COMM_WORLD
num_workers = comm.Get_size()  # total number of processes in COMM_WORLD


class EuclidianNovelty(Novelty):
    """Archive-based novelty: store behaviors and score novelty from k-NN distances.

    Behaviors are stored in ``self.container["behaviors"]`` (one flat row each).
    Before the knn model has been fitted, every behavior is stored. Afterwards, a
    behavior is stored only if its nearest stored neighbor is farther than
    ``self.threshold`` (Euclidean distance).

    The novelty score is the mean *squared* Euclidean distance to the ``k`` nearest
    stored behaviors, multiplied by ``novelty_scale_ratio``.
    """

    def __init__(self, qd_strategy, config):
        super().__init__(qd_strategy, config)
        # Stored behaviors, one 1-D row per entry.
        self.container = {"behaviors": []}
        self.k = config.archive_params.knn
        self.neigh = NearestNeighbors(n_neighbors=self.k)
        # A new behavior is kept only if its closest stored neighbor is farther than this.
        self.threshold = config.archive_params.threshold
        # 0 until the knn model has been successfully fitted, then 1.
        self.init = 0
        # Incremented on every call to update_nearest_neighbors.
        self.gen = 1
        self.novelty_scale_ratio = config.novelty_method_params.novelty_scale_ratio
        # Fraction of the last batch given to add_vectorized that was actually stored.
        self.percentage_added_behaviors = 0

    def add(self, behavior):
        """Add a single behavior to the storage.

        Currently unused to the benefit of :meth:`add_vectorized`. Before the knn model
        is fitted the behavior is always stored. Afterwards it is stored only if its
        nearest stored neighbor is farther than ``self.threshold``.

        :param behavior: one behavior vector (a scalar is treated as a 1-D behavior).
        """
        # Store a flat row so entries have the same shape as those from add_vectorized.
        row = np.asarray(behavior).reshape(-1)
        if not self.init:
            self.container["behaviors"].append(row)
            return
        neigh_dist, _ = self.neigh.kneighbors(X=row.reshape(1, -1), n_neighbors=1, return_distance=True)
        # Index the (1, 1) result explicitly: float() on an ndim > 0 array is deprecated in NumPy.
        # Far enough from its closest neighbor
        if neigh_dist[0, 0] > self.threshold:
            self.container["behaviors"].append(row)

    def add_vectorized(self, behavior):
        """Add a batch of behaviors to the storage in an efficient, vectorized manner.

        Also sets ``self.percentage_added_behaviors`` to the fraction of the batch stored.

        :param behavior: array of shape (n_behaviors, bd_dim). A 1-D array is treated
            as n_behaviors scalar behaviors (bd_dim == 1).
        """
        behavior = np.asarray(behavior)
        # Normalize to 2-D *before* branching so stored rows always have shape (bd_dim,).
        # Mixing plain scalars with (1,) rows would make np.array() in
        # update_nearest_neighbors fail on an inhomogeneous list.
        if behavior.ndim == 1:
            behavior = behavior.reshape(-1, 1)
        n_candidates = behavior.shape[0]
        if n_candidates == 0:
            # Nothing to add. This also avoids 0/0 and sklearn rejecting an empty query.
            self.percentage_added_behaviors = 0.
            return
        if not self.init:
            # No fitted model to compare against yet: keep the whole batch.
            kept_behaviors = behavior
        else:
            neigh_dist, _ = self.neigh.kneighbors(X=behavior, n_neighbors=1, return_distance=True)
            # Far enough from closest stored neighbor
            kept_behaviors = behavior[neigh_dist.ravel() > self.threshold]
        self.percentage_added_behaviors = kept_behaviors.shape[0] / n_candidates
        self.container["behaviors"].extend(kept_behaviors)

    def update_nearest_neighbors(self):
        """Refit the knn model on all stored behaviors and advance ``self.gen``.

        When nothing has been stored yet, the model is left unfitted and ``self.init``
        is unchanged, since sklearn cannot fit on zero samples.
        """
        behaviors = np.array(self.container["behaviors"])
        if behaviors.ndim == 1:
            behaviors = behaviors.reshape(-1, 1)
        if behaviors.shape[0] > 0:
            self.neigh.fit(behaviors)
            # Mark as initialized only once the fit has actually succeeded.
            self.init = 1
        self.gen += 1

    def compute_novelty_score(self, behavior, *args):
        """Compute the novelty score of each behavior in a batch.

        The score is the mean *squared* Euclidean distance to the k nearest stored
        behaviors (fewer if the archive holds fewer than k), scaled by
        ``self.novelty_scale_ratio``. Before the knn model is fitted, all scores are 0.

        :param behavior: array of shape (n_behaviors, bd_dim). A 1-D array is treated
            as n_behaviors scalar behaviors.
        :param args: ignored (presumably kept for compatibility with the Novelty interface).
        :return: array of shape (n_behaviors,).
        """
        if not self.init:
            return np.zeros(behavior.shape[0])
        if behavior.ndim == 1:
            behavior = behavior.reshape(-1, 1)
        if behavior.shape[0] == 0:
            # sklearn rejects queries with zero samples.
            return np.zeros(0)
        # Never ask for more neighbors than the model was fitted on.
        n_neighbors = min(self.neigh.n_samples_fit_, self.k)
        neigh_dist, _ = self.neigh.kneighbors(X=behavior, n_neighbors=n_neighbors, return_distance=True)
        # Square distance used in QD-ES and Lehman/Stanley
        novelty = np.mean(neigh_dist ** 2, axis=1)
        return novelty * self.novelty_scale_ratio

    def update(self):
        """Filter and store the behaviors visited by all MPI workers, then refit the knn models.

        This is a collective call (it uses ``gather`` and ``bcast`` on ``comm``), so every
        rank must call it. Rank 0 additionally refits the solution archive's neighbors.
        """
        worker_behaviors = self.qd_strategy.state_BDs
        # First filter per worker to shrink what is communicated
        worker_behaviors = filter_close_points(np.float32(worker_behaviors), self.threshold)
        # Only rank 0 uses the aggregated behaviors, so gather there instead of allgather.
        gathered = comm.gather(worker_behaviors, root=0)
        filtered_behaviors = None
        if rank == 0:
            all_behaviors = np.concatenate(gathered, axis=0)
            # Then filter the aggregated result to reduce computation time
            filtered_behaviors = filter_close_points(np.float32(all_behaviors), self.threshold)
        # Every rank adds the same filtered behaviors.
        filtered_behaviors = comm.bcast(filtered_behaviors, root=0)
        self.add_vectorized(filtered_behaviors)
        self.update_nearest_neighbors()
        # Update nearest neighbors of the solution archive too
        if rank == 0:
            self.qd_strategy.archive.update_nearest_neighbors()
