from copy import copy

import matplotlib.pyplot as plt
import numba
import numpy as np
from scipy.spatial import KDTree

from active_ranking import config
from active_ranking.base.ucb_lcb import ucb_index, lcb_index


# Node of a binary tree over a 1D integer interval: a node covers [L, R) with
# split point M; its children cover [L, M) and [M, R), each with its own
# split point at the integer midpoint.
@numba.experimental.jitclass
class Node:
    def __init__(self, index=None):
        # index = [L, M, R]: left bound, split point, right bound (exclusive,
        # since a node with R == L + 1 is a leaf of width 1).
        # NOTE(review): the default index=None is unusable (index[0] would
        # fail), so callers must always pass an index.
        # NOTE(review): no spec and no field annotations are given to
        # jitclass; confirm numba infers the field types as intended.
        self.L = index[0]
        self.M = index[1]
        self.R = index[2]
        self.width = index[2] - index[0]
        # A leaf covers exactly one integer position.
        self.is_leaf = (self.R == self.L + 1)

    def left_child(self):
        # Left half [L, M); a leaf is its own child.
        if self.is_leaf:
            return self
        else:
            return Node([self.L, (self.L + self.M) // 2, self.M])

    def right_child(self):
        # Right half [M, R); a leaf is its own child.
        if self.is_leaf:
            return self
        else:
            return Node([self.M, (self.M + self.R) // 2, self.R])

    def value(self):
        # The node's bounds as a list [L, M, R].
        return [self.L, self.M, self.R]


class Cell:
    """A cell of the partition, holding the labelled points that fall in it.

    A primitive cell stores its own labels (``x``, ``y``). Adding two cells
    with ``+`` produces a merged cell whose ``get_labels`` gathers the labels
    of the cells it was built from at call time. Labels added to those cells
    after the merge are therefore still returned.

    Args:
        id_: identifier of the cell; a merged cell gets ``"<left>+<right>"``.
        size: size (volume) of the cell; sizes are summed on merge.
        centers: centers of the primitive cells covered by this cell; expected
            to be a list, since ``+`` concatenates them on merge.
        t: stored as ``self.t``; not used within this class.
    """

    def __init__(self, id_: str, size, centers, t=1):
        self.id = id_
        self.y = np.array([])
        self.x = np.array([])
        self.size = size
        self.average_size = size
        self.n = 0
        self.centers = centers
        self.t = t
        self.__value = np.nan
        # Placeholder bounds until set_value computes them.
        self.__ucb = 1
        self.__lcb = 0
        # Cells this one was merged from, in the order their labels are
        # returned (empty for a primitive cell); see get_labels.
        self._parts = ()

    @property
    def value(self) -> float:
        return self.__value

    @property
    def lcb(self) -> float:
        return self.__lcb

    @property
    def ucb(self) -> float:
        return self.__ucb

    def set_value(self, fun, exploration_parameter):
        """Recompute the cell's value and confidence bounds from its labels.

        Args:
            fun: callable ``fun(x, y)`` returning the cell's value.
            exploration_parameter: forwarded to ``ucb_index`` / ``lcb_index``.

        Returns:
            The new value.
        """
        # Snapshot the (possibly merged) labels onto the cell itself.
        self.x, self.y = self.get_labels()
        self.__compute_nb_label()
        self.__value = fun(self.x, self.y)

        self.__ucb = ucb_index(self.__value, exploration_parameter, self.n)
        self.__lcb = lcb_index(self.__value, exploration_parameter, self.n)
        return self.value

    def add_labels(self, x: np.ndarray, y: np.ndarray):
        """Prepend the points ``x`` and their labels ``y`` to this cell."""
        if len(self.x) == 0:
            self.x = x
            self.y = y
        else:
            self.x = np.concatenate((x, self.x))
            self.y = np.concatenate((y, self.y))
        self.__compute_nb_label()

    def get_labels(self):
        """Return ``(x, y)`` for this cell.

        A primitive cell returns its own arrays. A merged cell returns the
        concatenation of its parts' labels. Parts without labels are skipped,
        because their empty 1-D placeholder cannot be concatenated with 2-D
        point arrays and would also change the label dtype.
        """
        if not self._parts:
            return self.x, self.y
        gathered = [part.get_labels() for part in self._parts]
        gathered = [(x, y) for x, y in gathered if len(x) > 0]
        if not gathered:
            return np.array([]), np.array([])
        xs, ys = zip(*gathered)
        return np.concatenate(xs), np.concatenate(ys)

    def __compute_nb_label(self):
        """Set ``n`` to the number of labels after checking x/y agree."""
        if len(self.y) != self.x.shape[0]:
            raise AssertionError(
                f"cell {self.id}: {len(self.y)} labels for "
                f"{self.x.shape[0]} points")
        self.n = len(self.y)

    def __add__(self, other):
        """Merge two cells into a new cell covering both.

        No labels are copied: the result's ``get_labels`` reads them from
        ``other`` and then ``self`` at call time. The result's own ``n`` stays
        0 until ``set_value`` is called on it.
        """
        cell = Cell(
            f"{self.id}+{other.id}",
            self.size + other.size,
            self.centers + other.centers,
        )
        # Order matches the returned labels: other's first, then self's.
        cell._parts = (other, self)
        # Validate that this (left) operand's x and y are consistent.
        self.__compute_nb_label()
        return cell

    # UTILS AND PLOTS
    def plot_rectangle(self, linewidth=1, edgecolor='r', facecolor='none'):
        """Draw a square around each center on the current matplotlib axes.

        Only the first two coordinates of each center are used (d = 2).
        """
        import matplotlib.patches as patches
        import matplotlib.pyplot as plt

        if len(self.centers) == 0:
            return
        d = 2
        # Size of one primitive cell. For a merged cell ``size`` is the sum of
        # its parts, so divide by the number of centers. This assumes the
        # parts share one size, which holds for cells of a single grid.
        unit_size = self.size / len(self.centers)
        j_max = - np.log2(unit_size) / d

        # Half the side length of a primitive cell at level j_max.
        a = 1 / 2 ** (j_max + 1)

        for c in self.centers:
            rect = patches.Rectangle(
                (c[0] - a, c[1] - a),
                2 * a, 2 * a, linewidth=linewidth, edgecolor=edgecolor,
                facecolor=facecolor)
            plt.gca().add_patch(rect)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return f"<Cell object " \
               f"(id:{self.id})" \
               f"(value:{np.round(self.value, 2)})" \
               f"(lcb:{np.round(self.__lcb, 2)})" \
               f"(ucb:{np.round(self.__ucb, 2)})" \
               f"(n:{self.n})>"


class Partition:
    """Partition of the unit cube into cells built from a regular grid.

    The finest level is a grid of ``2 ** j_max`` cells along each of the
    ``d`` axes: the primitive cells, ``p_cells``. Primitive cells can be
    merged into coarser cells and split back out. ``track_cells[step]``
    holds the current cells (id -> Cell) at each step.

    Args:
        j_max: refinement level of the finest grid.
        d: dimension of the space.
    """

    def __init__(self, j_max, d):
        self.j_max = j_max
        self.d = d
        ln_space = np.linspace(0, 1, 2 ** self.j_max, endpoint=False)
        # Cell centers along one axis: left edges shifted by half a width.
        self.grid = (1 / 2 ** (self.j_max + 1) + ln_space,) * d
        self._centroids = np.meshgrid(*self.grid)
        self.centroids = np.vstack(tuple(map(np.ravel, self._centroids))).T
        self._tree = KDTree(self.centroids)
        self.track_cells = {0: {}}
        self.step = 0
        self.min_partition()
        self.x = np.array([])
        self.y = np.array([])
        self.loc = np.array([])
        # Primitive cell id -> id of the current cell containing it. Use the
        # actual (zero-padded) cell ids, as __make_mapping does.
        self.cells_mapping = {id_: id_ for id_ in self.current_cells()}

    def min_partition(self):
        """
        Define first partition as the finest grid given by j_max parameter
        """
        n_cells = 2 ** (self.j_max * self.d)
        # Zero-pad ids to a common width so they sort in index order.
        width = len(str(n_cells))
        self.p_cells = {}
        for i, centroid in enumerate(self.centroids):
            cell = Cell(str(i).zfill(width), 1 / n_cells, [centroid])
            self.p_cells[cell.id] = cell

        self.track_cells[0] = self.p_cells
        self.id_p_cells = list(self.p_cells.keys())
        self.id_c_cells = list(self.track_cells[self.step].keys())

        # Per-axis grid position of each primitive cell's center (all axes
        # share the same grid).
        self.positions = {
            c.id: np.searchsorted(self.grid[0], c.centers[0])
            for c in self.p_cells.values()
        }

    def check_volume(self, step=0):
        """Assert that the sizes of the cells at ``step`` sum to 1."""
        total = sum(c.size for c in self.track_cells[step].values())
        assert np.isclose(total, 1), f"total cell size is {total}, not 1"

    def find_cells(self, X):
        """Return, for each row of ``X``, the index of its nearest centroid."""
        _, indexes = self._tree.query(X)
        return indexes

    def add_labels(self, X, y):
        """Route labelled samples to their primitive cells and record them.

        Args:
            X: sample points, one per row (presumably shape (n, d)).
            y: labels of the points in ``X``.
        """
        indexes = self.find_cells(X)
        for c in np.unique(indexes):
            mask = indexes == int(c)
            self.p_cells[self.id_p_cells[c]].add_labels(X[mask], y[mask])

        new_loc = np.array(self.id_p_cells)[indexes]
        new_step = np.array([self.step] * len(X))
        if len(self.x) == 0:
            self.x = X
            self.y = y
            self.loc = new_loc
            self.loc_step = new_step
        else:
            # New samples are prepended to x/y, so their locations must be
            # prepended too for loc/loc_step to stay row-aligned with x/y.
            self.x = np.concatenate((X, self.x))
            self.y = np.concatenate((y, self.y))
            self.loc = np.concatenate((new_loc, self.loc))
            self.loc_step = np.concatenate((new_step, self.loc_step))
        self.p = np.mean([c.value for c in self.p_cells.values()])

    def current_cells(self) -> dict:
        """Return the cells (id -> Cell) of the current step."""
        return self.track_cells[self.step]

    def __make_mapping(self):
        """Rebuild cells_mapping: primitive id -> id of its current cell."""
        self.cells_mapping = {}
        for i in self.current_cells().keys():
            for j in i.split("+"):
                self.cells_mapping[j] = i

    def split(self, str_id, str_rmv):
        """Detach primitive cell ``str_rmv`` from the current cell ``str_id``.

        ``str_id`` is a current cell id made of primitive ids joined by
        ``"+"``. The remaining primitive cells are re-merged into one cell,
        and ``str_rmv`` becomes a standalone cell again. Nothing happens if
        ``str_rmv`` is not one of the components of ``str_id``.
        """
        parts = str_id.split("+")
        # Compare whole ids, not substrings ("1" must not match "11+12").
        if str_rmv not in parts:
            return
        remaining = [p for p in parts if p != str_rmv]
        save_cells = self.track_cells[self.step]

        new_cells = {
            key: cell for key, cell in save_cells.items()
            if key != str_id and key not in parts
        }

        if remaining:
            # Components are primitive ids, so look them up in p_cells.
            cell = self.p_cells[remaining[0]]
            for j in remaining[1:]:
                cell += self.p_cells[j]
            new_cells[cell.id] = cell

        new_cells[str_rmv] = self.p_cells[str_rmv]

        self.track_cells[self.step] = {
            k: new_cells[k] for k in sorted(new_cells)}
        self.__make_mapping()

    def merge(self, list_of_tuples):
        """Merge each tuple of current cell ids into a single cell.

        Merged cells come first in the new mapping, followed by the untouched
        cells in their previous order.
        """
        save_cells = self.track_cells[self.step]
        new_cells = {}
        for items in list_of_tuples:
            cell = save_cells[items[0]]
            for j in items[1:]:
                cell += save_cells[j]
            new_cells[cell.id] = cell

        consumed = {key for items in list_of_tuples for key in items}
        for key, cell in save_cells.items():
            if key not in consumed:
                new_cells[key] = cell

        self.track_cells[self.step] = new_cells
        self.__make_mapping()

    def ini_step(self):
        """Start a new step whose cells are a shallow copy of the last one's."""
        self.step += 1
        self.track_cells[self.step] = copy(self.track_cells[self.step - 1])
        self.id_c_cells = list(self.track_cells[self.step].keys())

    def plot_cell(self, id_cell, eta=None, **args):
        """Draw primitive cell ``id_cell`` and scatter its labelled points.

        Points are coloured by ``eta(x)`` through ``config.eta_cmap`` when
        ``eta`` is given, otherwise in black. Extra keyword arguments go to
        ``plt.scatter``.
        """
        cell: Cell = self.p_cells[id_cell]
        cell.plot_rectangle()
        if len(cell.x) == 0:
            # No labelled points yet; x is still the empty 1-D placeholder.
            return
        if eta is not None:
            c = plt.get_cmap(config.eta_cmap)(eta(cell.x))
        else:
            c = "k"
        plt.scatter(cell.x[:, 0], cell.x[:, 1], c=c, **args)


def test_cells():
    """Merged cells expose the union of their parts' labels, including
    labels added to the parts after the merge."""
    left = Cell("1", 10, [0.1])

    right = Cell("2", 10, [0.2])
    merged = left + right
    right_points = np.array([[0, 0.2], [0, 0.4]])
    right.add_labels(right_points, np.array([True, False]))
    left.add_labels(np.array([[0, 1]]), np.array([True]))
    _, labels = merged.get_labels()
    assert len(labels) == 3
    assert merged.id == "1+2"


def test_partition():
    """Smoke-test a 3-D partition: grid shape, volume, labelling, merging."""
    from active_ranking.base.sampler import sampler_d
    from active_ranking.scenarios.eta import eta_j_3_labeler
    part = Partition(j_max=1, d=3)
    n_expected = 2 ** (part.d * part.j_max)
    assert part.centroids.shape == (n_expected, part.d)
    part.check_volume(0)

    samples = sampler_d(200, d=part.d)
    labels = eta_j_3_labeler(samples)
    part.add_labels(samples, labels)
    for cell in part.current_cells().values():
        print(cell.n)

    part.merge(list_of_tuples=[("1", "2", "3")])
    part.add_labels(samples, labels)
    part.check_volume(part.step)


def test_plot_partition():
    """Draw every primitive cell of a 2-D partition together with its center."""
    p = Partition(j_max=2, d=2)
    plt.figure()
    # Iterate the cells directly: dicts have no ``__keys`` attribute.
    for cell in p.p_cells.values():
        cell.plot_rectangle()
        center = cell.centers[0]
        plt.scatter(center[0], center[1])
    plt.xlim((0, 1))
    plt.ylim((0, 1))
