# Made by: Giuseppe PAOLO
# Date: 4/13/2022

import numpy as np
from scipy.spatial.distance import jensenshannon
from collections import deque


class CvgGrid:
    """
    Occupancy-count grid used to calculate the CVG (coverage) and UNIF (uniformity) metrics.

    The space described by ``dimensions`` is split into ``bins`` equal-width cells along each
    dimension, and every cell holds the number of points that have been stored in it.
    """
    def __init__(self, bins, dimensions):
        """
        Build an empty grid.
        :param bins: number of bins per dimension
        :param dimensions: [[minx, maxx], [miny, maxy], ...] inclusive limits of each dimension
        """
        self.bins = bins
        self.dimensions = dimensions

        self.grid = self.init_grid()
        self.filled_tracker = []  # Every time a new cell is filled, the eval step is added to the list
        # Cell edges along each dimension: bins + 1 evenly spaced values from min to max (both included)
        self.cell_lims = [np.linspace(dim[0], dim[1], num=self.bins + 1) for dim in self.dimensions]

    def reset(self):
        """
        Reset the grid to an empty state: every cell count goes back to zero
        and the fill tracker is cleared.
        """
        self.grid = self.init_grid()
        self.filled_tracker = []

    def __len__(self):
        """
        Returns the total number of points stored in the grid
        """
        return int(self.size)

    @property
    def size(self):
        """
        Total number of points stored in the grid (sum of all cell counts)
        """
        return np.sum(self.grid)

    @property
    def shape(self):
        """
        Shape of the grid: ``bins`` cells along each dimension
        """
        return self.grid.shape

    def init_grid(self):
        """
        Create an all-zero count grid with ``bins`` cells along each dimension.
        Counts start as unsigned 16-bit integers to save memory; ``store`` widens the
        dtype if a cell count would overflow.
        """
        # TODO look for a way to make it more memory efficient through indexing
        return np.full([self.bins] * len(self.dimensions), fill_value=0, dtype=np.ushort)

    def store(self, point):
        """
        Add one point to the grid by incrementing the count of the cell it falls in.

        :param point: sequence with one coordinate per dimension, each within the grid limits
        :raises AssertionError: if the point has the wrong length or lies outside the grid
        """
        assert len(point) == len(self.dimensions), \
            'Point of wrong shape. Given: {} - Expected: {}'.format(len(point), len(self.dimensions))
        cell = self._find_cell(point)
        # A full np.ushort cell would wrap around to 0 and then be counted as empty by the
        # coverage metric, so widen the count dtype before that can happen.
        if self.grid[cell] == np.iinfo(self.grid.dtype).max:
            self.grid = self.grid.astype(np.uint64)
        # Increase cell count by one
        self.grid[cell] += 1

    def _find_cell(self, bd):
        """
        Find the cell the given BD belongs to.

        Along each dimension cell k covers (lims[k], lims[k+1]]; cell 0 also includes the
        lower limit itself.
        :param bd: sequence with one coordinate per dimension
        :return: tuple with one cell index per dimension
        :raises AssertionError: if a coordinate lies outside the grid limits
        """
        cell_idx = []
        for dim_idx, dim in enumerate(self.dimensions):
            assert dim[1] >= bd[dim_idx] >= dim[0], \
                "BD outside of grid. BD: {} - Bottom Limits: {} - Upper Limits: {}".format(bd, dim[0], dim[1])

            # argmax gives the first cell edge >= the coordinate.
            # The max() is there so if we are at the bottom border the cell counts as the first
            # Remove 1 for indexing starting at 0
            cell_idx.append(max(np.argmax(self.cell_lims[dim_idx] >= bd[dim_idx]), 1) - 1)
        return tuple(cell_idx)


def calculate_coverage(occupied_grid):
    """
    Calculate the coverage: the fraction of grid cells that contain at least one point.

    :param occupied_grid: array-like of per-cell counts (e.g. the ``grid`` attribute of a CvgGrid)
    :return: number of non-zero cells divided by the total number of cells, in [0, 1]
    :raises ZeroDivisionError: if the grid has no cells
    """
    occupied_grid = np.asarray(occupied_grid)
    return np.count_nonzero(occupied_grid) / occupied_grid.size


def calculate_uniformity(grid, base=None):
    """
    Calculate the uniformity of the distribution of points over the grid.

    The cell counts are normalized into a histogram and compared with the uniform
    distribution over the same cells; the result is 1 minus their Jensen-Shannon distance.
    :param grid: array-like of per-cell counts (e.g. the ``grid`` attribute of a CvgGrid)
    :param base: logarithm base forwarded to ``scipy.spatial.distance.jensenshannon``.
        The default (natural log) keeps the distance below sqrt(ln 2), so the result lies in
        [1 - sqrt(ln 2), 1]; pass ``base=2`` to get a result in [0, 1].
    :return: uniformity score, 1 meaning perfectly uniform. An all-zero grid has no defined
        histogram and yields NaN.
    """
    grid = np.asarray(grid)
    normed_grid = grid / np.sum(grid)
    uniform_grid = np.ones_like(normed_grid) / normed_grid.size
    return 1 - jensenshannon(normed_grid.flatten(), uniform_grid.flatten(), base=base)


def fast_non_dominated_sort(values1, values2, best=False):
    """
    Sort the elements into Pareto fronts according to the values of 2 objectives, both maximized.
    Taken from https://github.com/haris989/NSGA-II

    Element p dominates q when p is at least as good as q on both objectives and strictly
    better on at least one.
    :param values1: Values of first obj
    :param values2: Values of second obj (same length as values1)
    :param best: if True only returns the best front without calculating the others
    :return: if best, the list of indexes of the non-dominated elements; otherwise the list of
        fronts (each a list of indexes) from best to worst. Empty input returns [].
    """
    num = len(values1)

    def dominates(a, b):
        # True when element a Pareto-dominates element b (maximization)
        return (values1[a] >= values1[b] and values2[a] >= values2[b]
                and (values1[a] > values1[b] or values2[a] > values2[b]))

    dominated = [[] for _ in range(num)]  # dominated[p]: indexes of the elements p dominates
    domination_count = [0] * num  # domination_count[p]: how many elements dominate p
    front = [[]]

    for p in range(num):
        for q in range(num):
            if dominates(p, q):
                dominated[p].append(q)
            elif dominates(q, p):
                domination_count[p] += 1
        if domination_count[p] == 0:
            front[0].append(p)

    if best:
        return front[0]

    i = 0
    while front[i]:
        next_front = []
        for p in front[i]:
            for q in dominated[p]:
                domination_count[q] -= 1
                # Each count only decreases, so it reaches 0 exactly once: no duplicates
                if domination_count[q] == 0:
                    next_front.append(q)
        i += 1
        front.append(next_front)
    del front[-1]  # The loop always ends by appending an empty front
    return front

if __name__ == "__main__":
    # Smoke test: a 10 x 10 x 10 grid spanning the unit cube
    unit_cube = [[0, 1] for _ in range(3)]
    grid = CvgGrid(bins=10, dimensions=unit_cube)

