import copy

import numpy as np

from active_ranking import config, utils
from active_ranking.base import ucb_lcb
from active_ranking.base.abstract import LearnerAnalyser
from active_ranking.base.sampler import sample_around_center
from active_ranking.base.sampler import sampler_d
from active_ranking.base.utils import _compute_active_set, \
    create_tuple_of_disjoint_cells, get_tuple_ij, UMessyRank


class ActiveLearnerOld(LearnerAnalyser):
    """Legacy active learner working on pairs of partition cells.

    Each round it picks a pair of cells from the current bounds
    (``compute_active_set``), samples inside one of them (``sample``),
    optionally merges the two composite cells the pair belongs to
    (``merge_cells``) and stops once all tracked pairs are separated
    (``stopping``).  Attributes such as ``partition``, ``ucb``, ``lcb`` and
    ``bounds`` are not defined here; presumably they come from
    ``LearnerAnalyser`` -- TODO confirm.
    """
    name = "active-rank-old"

    def compute_w(self):
        """Rebuild ``self.w`` from the ids of the current cells.

        The pairs come from ``create_tuple_of_disjoint_cells`` and are
        stored as an array of ``(int, int)`` tuples.
        """
        li = list(self.partition.current_cells().keys())
        self.w = create_tuple_of_disjoint_cells(np.array(li))
        # Normalise every pair to plain Python ints before storing it.
        self.w = np.array([(int(e1), int(e2)) for e1, e2 in self.w])

    def compute_u(self, i, j):
        """Store in ``self.u`` the result of ``get_tuple_ij`` for ``i``, ``j``.

        The helper receives the array of current cell ids together with the
        two ids; its output is later iterated as pairs in ``merge_cells``.
        """
        arr1 = np.array(list(self.partition.current_cells().keys()))
        self.u = get_tuple_ij(arr1, i, j)

    def merge_cells(self):
        """Merge the composite cells holding ``self.ij`` when they are close.

        Reads the pair set by ``compute_active_set``, looks up the composite
        cell containing each endpoint and merges them when the largest
        ``|ucb_a - lcb_b|`` over the pairs in ``self.u`` is below a threshold.
        ``prune`` is called afterwards whether or not a merge happened.

        Raises:
            AssertionError: if both endpoints are in the same composite cell.
            IndexError: if an endpoint is found in no composite cell.
        """
        part_p = self.partition.p_cells
        i, j = self.ij

        self.compute_u(i, j)
        # The comprehension's own i, j do not leak: the outer i, j keep the
        # values unpacked from self.ij.
        self.diff_ucb_lcb = [
            abs(part_p[i].ucb - part_p[j].lcb) for i, j in self.u
        ]
        # First composite id that contains each endpoint.
        # NOTE(review): the ids are split on "+" below, so they are strings
        # and `i in c` is a substring test -- confirm that an id such as "1"
        # cannot wrongly match a composite id such as "11+3".
        self.cells1 = [c for c in self.partition.id_c_cells if i in c][0]
        self.cells2 = [c for c in self.partition.id_c_cells if j in c][0]

        assert self.cells1 != self.cells2

        # Composite ids are "+"-joined lists of elementary cell ids.
        self.list_cells1 = np.array(self.cells1.split("+"))
        self.list_cells2 = np.array(self.cells2.split("+"))

        # NOTE(review): `union` and `kappa` are computed but never used
        # below; they are leftovers of the threshold sketched in the TODO.
        union = self.partition.current_cells()[self.cells1] + \
                self.partition.current_cells()[self.cells2]
        kappa = np.sum([part_p[c].value for c in
                        (*self.list_cells1, *self.list_cells2)]) / len(
            (*self.list_cells1, *self.list_cells2))
        right_inequality = self.partition.p * config.c * config.epsilon
        # TODO: the intended threshold was to be scaled by `/ union.size`
        # and `* (1 - kappa)`.  Until then the value computed above is
        # discarded and a hard-coded constant is used instead.
        right_inequality = 0.02
        left_inequality = max(self.diff_ucb_lcb)
        cond = left_inequality < right_inequality

        if cond:
            self.partition.merge(list_of_tuples=[(self.cells1, self.cells2)])
        # NOTE(review): prune() still uses the pre-merge ids stored in
        # self.cells1 / self.cells2 -- confirm they remain valid after merge.
        self.prune()

    def prune(self):
        """Split badly matched members out of the two composite cells.

        Does nothing when either composite cell has a single member;
        otherwise runs ``prune_helper`` in both directions.  Relies on the
        attributes set by the last ``merge_cells`` call.
        """
        if len(self.list_cells1) == 1 or len(self.list_cells2) == 1:
            return None
        else:
            self.prune_helper(self.list_cells1, self.list_cells2, self.cells1)
            self.prune_helper(self.list_cells2, self.list_cells1, self.cells2)

    def prune_helper(self, lists_cell1, list_cells2, cell1):
        """Split from ``cell1`` the members whose bounds miss the other group.

        Args:
            lists_cell1: array of member ids of the composite cell ``cell1``.
            list_cells2: array of member ids of the other composite cell.
            cell1: id of the composite cell that ``lists_cell1`` belongs to.
        """
        # Member ids are cast to int and used as row indices of self.bounds.
        ensemble1 = self.bounds[lists_cell1.astype(int)]
        ensemble2 = self.bounds[list_cells2.astype(int)]

        # Going by the variable names, column 0 holds upper bounds and
        # column 1 lower bounds -- TODO confirm against self.bounds.
        ucb_max = ensemble2[:, 0].max()
        lcb_min = ensemble2[:, 1].min()

        # Select members lying entirely above (l1) or entirely below (l2)
        # the range covered by the other group.
        l1 = ensemble1[:, 1] > ucb_max
        l2 = ensemble1[:, 0] < lcb_min
        cell_to_remove = lists_cell1[l1 | l2]

        for elt in cell_to_remove:
            # Debug trace of every split that is performed.
            print(elt, cell1)
            self.partition.split(elt, cell1)

    def compute_active_set(self):
        """Choose the next pair of cells and the cell to sample from.

        Refreshes ``self.w``, asks ``_compute_active_set`` for two positions
        in the ucb/lcb value arrays, maps them back to cell ids and stores
        them in ``self.ij``.  ``self.active_cell`` becomes a one-element
        tuple holding one of the two composite cells.
        """
        self.compute_w()
        p_cells = self.partition.p_cells
        i, j = _compute_active_set(
            np.array(list(self.ucb.values())),
            np.array(list(self.lcb.values())), self.w)
        # Debug trace of the selected positions.
        print(i, j)
        # i and j are positions in the key order of self.ucb.
        cell_i, cell_j = list(self.ucb.keys())[i], list(self.ucb.keys())[j]
        cells1 = [c for c in self.partition.id_c_cells if cell_i in c][0]
        cells2 = [c for c in self.partition.id_c_cells if cell_j in c][0]

        # Keep in cells1 the composite cell of the endpoint with the smaller
        # `.n` (presumably its sample count -- TODO confirm).
        if p_cells[cell_i].n > p_cells[cell_j].n:
            cells1, cells2 = cells2, cells1

        self.ij = cell_i, cell_j
        self.active_cell = (self.partition.current_cells()[cells1],)

    def sample(self, n=1):
        """Return ``n`` distinct rows drawn around the active cell's centers.

        Args:
            n: number of rows to return; also forwarded to
                ``sample_around_center``.
        """
        centers = np.array([c.centers for c in self.active_cell])
        # Flatten to one center per row with partition.d columns.
        centers = centers.reshape(-1, self.partition.d)
        samples = sample_around_center(centers, j_max=self.partition.j_max,
                                       d=self.partition.d, n=n)
        # Sub-select n rows without replacement; numpy raises ValueError if
        # fewer than n rows were generated.
        return samples[np.random.choice(
            range(len(samples)), n, replace=False), :]

    def stopping(self):
        """Set ``self.stop`` according to the separation of the pairs in w.

        Stops when a single current cell is left, or when for every pair
        ``(i, j)`` in ``self.w`` the lcb of cell i exceeds the ucb of cell j.
        Uses ``self.w`` as left by the last ``compute_w`` call.
        """
        if len(self.partition.current_cells()) == 1:
            self.stop = True
            return None
        crit = [self.lcb[self.partition.id_p_cells[i]] > self.ucb[
            self.partition.id_p_cells[j]] for i, j in self.w]
        # True only when every comparison holds (also true for an empty w).
        self.stop = sum(crit) == len(crit)


class ActiveLearner(LearnerAnalyser):
    """Active learner that keeps a shrinking set of cells to sample from.

    Cells leave the active set when their bound-intersection count is small
    compared with an ``epsilon``-scaled threshold.  When the set becomes
    empty it is reset to every cell and ``epsilon`` is halved.
    """
    name = "active-rank-old-2"

    def __init__(self, j_max, d, eta, prioritize="n_sample"):
        """Initialise the learner.

        Args:
            j_max, d, eta: forwarded unchanged to ``LearnerAnalyser``.
            prioritize: how ``sample`` chooses rows; one of ``"n_sample"``,
                ``"margin"`` or ``None`` (uniformly at random).
        """
        super().__init__(j_max, d, eta)

        self.bi = None
        self.active_set = list(self.partition.p_cells.keys())
        self.epsilon = 1
        self.prioritize = prioritize

    @utils.execution_time
    def compute_active_set(self):
        """Remove from the active set the cells that pass the threshold test.

        A cell is removed when ``bi.card <= len(bi) / dist * p * epsilon``,
        where ``bi`` is the ``BoundsIntersection`` of the active cells and
        ``p`` is the mean value over all cells of the partition.
        """
        cells = self.partition.p_cells
        ucb = [cells[c].ucb for c in self.active_set]
        lcb = [cells[c].lcb for c in self.active_set]

        self.bi = ucb_lcb.BoundsIntersection(np.array(ucb), np.array(lcb))
        p = np.mean([c.value for c in cells.values()])
        dist = np.array(self.bi.dist)
        # Non-positive distances are clamped to avoid dividing by zero.
        dist = np.where(dist <= 0, 0.0001, dist)
        right = len(self.bi) / dist * p * self.epsilon
        left = self.bi.card
        cond = left <= right
        rmv_keys = np.array(self.active_set)[cond]
        for k in rmv_keys:
            self.active_set.remove(k)

        if not self.active_set:
            # Everything was eliminated: restart with a tighter epsilon.
            self.active_set = list(cells.keys())
            self.epsilon /= 2

    @utils.execution_time
    def sample(self, n):
        """Return ``n`` rows generated around the active cells' centers.

        Args:
            n: number of rows to return; also forwarded to
                ``sample_around_center``.

        Raises:
            ValueError: if ``self.prioritize`` is not ``"n_sample"``,
                ``"margin"`` or ``None``.
        """
        p_cells = self.partition.p_cells
        centers = np.array([p_cells[c].centers
                            for c in self.active_set])[:, 0, :]
        samples = sample_around_center(centers, j_max=self.partition.j_max,
                                       d=self.partition.d, n=n)
        if self.prioritize is None:
            return samples[np.random.choice(
                range(len(samples)), n, replace=False), :]

        if self.prioritize == "n_sample":
            order = [-p_cells[c].n for c in self.active_set]
        elif self.prioritize == "margin":
            order = [p_cells[c].ucb - p_cells[c].lcb for c in self.active_set]
        else:
            # Previously an unknown mode left `order` unbound and failed
            # with an UnboundLocalError at the argsort below.
            raise ValueError(
                "prioritize must be 'n_sample', 'margin' or None, got "
                f"{self.prioritize!r}")

        # Keep the rows with the n largest `order` values; assumes the rows
        # of `samples` line up with `self.active_set` -- TODO confirm.
        rank = np.argsort(order)
        return samples[rank][-n:, :]


class DTrackingLearner(LearnerAnalyser):
    """Learner that allocates samples to arms with a D-tracking style rule.

    Every cell of the partition is treated as an arm.  ``active_set`` holds
    the arms still in play; ``stopping`` removes the arm with the largest
    value each time a phase ends.
    """
    name = "D-tracking"

    def __init__(self, j_max, d, eta):
        """Set up the arms and the per-arm bookkeeping dictionaries."""
        super().__init__(j_max, d, eta)
        self.arms = list(self.partition.p_cells.keys())
        self.active_set = list(self.partition.p_cells.keys())
        self.delta = dict.fromkeys(self.arms, 0)
        self.w = dict.fromkeys(self.arms, 0)
        self.step = 1
        # Number of times each arm has been picked by `sample`.
        self.track = {"sample": dict(self.delta)}

        self.k = 0
        self.count = 0

    def update_deltas(self):
        """Recompute ``self.delta`` for the active arms and return the values.

        Returns:
            The list of deltas in the order of ``self.active_set``.
        """
        cells = self.partition.p_cells
        self.etas = [cells[arm].value for arm in self.active_set]
        self.max_mu = np.max(self.etas)
        best = np.argmax(self.etas)
        eps = config.epsilon_d_tracking

        for pos, arm in enumerate(self.active_set):
            eta = self.etas[pos]
            target = np.array([max((eta + eps), self.max_mu)])
            kl = ucb_lcb.kl_bernoulli(eta, target)[0]
            if best == pos:
                # The best arm is compared with the runner-up instead.
                others = copy.copy(self.etas)
                del others[best]
                runner_up = max(others)
                target = np.array([min((eta - eps), runner_up)])
                kl = ucb_lcb.kl_bernoulli(eta, target)[0]
            self.delta[arm] = kl
        return [self.delta[arm] for arm in self.active_set]

    def get_active_cell(self, delta):
        """Return the active arm maximising ``w * step - n``.

        Args:
            delta: per-arm values in the order of ``self.active_set``.
        """
        weights = ucb_lcb.compute_w_d_tracking_algorithm(np.array(delta))
        counts = np.array(
            [self.partition.p_cells[arm].n for arm in self.active_set])
        return self.active_set[np.argmax(weights * self.step - counts)]

    def sample(self, n):
        """Pick an arm and return samples generated around its centers.

        An arm whose ``n`` is at most ``sqrt(step) - len(arms) / 2`` triggers
        the forced branch, which picks the arm with the smallest ``n``.
        NOTE(review): ``self.step`` only advances in that branch -- confirm
        this is intended.
        """
        counts = np.array([self.partition.p_cells[arm].n
                           for arm in self.arms])
        threshold = np.sqrt(self.step) - len(self.arms) / 2
        if np.any(counts <= threshold):
            self.step += 1
            chosen = self.arms[np.argmin(counts)]
        else:
            chosen = self.get_active_cell(self.update_deltas())
        self.track["sample"][chosen] += 1
        return sample_around_center(
            np.array(self.partition.p_cells[chosen].centers),
            self.partition.j_max, self.partition.d, n)

    def stopping(self):
        """Advance the phase counters and drop an arm when a phase ends.

        ``self.stop`` is always set to ``False``.  A phase ends when the
        stopping rule fires or when ``count`` exceeds the budget of the
        current phase; the active arm with the largest value is then removed.
        """
        self.stop = False
        cells = self.partition.p_cells
        n_active = len(self.active_set)
        beta = 2 * (n_active - 1) * self.step / config.delta_d_tracking
        mu = np.array([cells[arm].value for arm in self.active_set])
        t = np.array([cells[arm].n for arm in self.active_set])
        # Per-phase budgets, rebuilt from n_max and K on every call.
        self.t1 = 2 * self.n_max / (self.K + 1) / self.K
        self.ti = {i: self.t1 * (self.K - i) for i in range(self.K)}
        t_max = int(self.ti[self.k])
        self.count += 1
        crit = ucb_lcb.compute_stopping_rule(mu, t, beta)
        if not (crit or (self.count > t_max)):
            return
        self.k += 1
        self.count = 0
        del self.active_set[np.argmax(mu)]


class PassiveLearner(LearnerAnalyser):
    """Baseline learner whose sampling ignores the current estimates."""
    name = "passive-rank"

    def sample(self, n):
        """Return ``sampler_d(n, d)`` where ``d`` is the partition's ``d``."""
        return sampler_d(n, self.partition.d)


class ActiveNaiveLeaner(LearnerAnalyser):
    """Learner that samples in the cell with the widest ucb - lcb gap."""
    name = "naive-active-rank"

    def sample(self, n):
        """Return ``n`` distinct rows generated around the widest cell.

        The first cell (in the key order of ``self.ucb``) with the largest
        ``ucb - lcb`` is selected.
        """
        cell_ids = list(self.ucb.keys())
        widths = np.array([self.ucb[c] - self.lcb[c] for c in cell_ids])
        widest = cell_ids[np.argmax(widths)]
        centers = np.array(self.partition.p_cells[widest].centers)
        candidates = sample_around_center(
            centers, j_max=self.partition.j_max, d=self.partition.d, n=n)
        # Sub-select n rows without replacement.
        picked = np.random.choice(range(len(candidates)), n, replace=False)
        return candidates[picked, :]


class ActiveClassificationLeaner(LearnerAnalyser):
    """Learner that samples cells not yet separated from a fixed level."""
    name = "active-classification"

    def __init__(self, j_max, d, eta, level=0.5):
        """Store ``level``; the other arguments go to ``LearnerAnalyser``."""
        super().__init__(j_max, d, eta)
        self.level = level

    def sample(self, n):
        """Return ``n`` distinct rows generated around the candidate cells.

        A cell is a candidate when its estimate is on one side of the level
        and the bound on that side is still on the other.  When there is no
        such cell, the cells whose estimate is closest to the level are used.
        """
        threshold = self.level
        estimates = vta(self.estimates)
        above = (vta(self.lcb) < threshold) & (estimates > threshold)
        below = (vta(self.ucb) > threshold) & (estimates < threshold)
        mask = above | below

        if not mask.any():
            gap = np.abs(estimates - threshold)
            mask = gap == np.min(gap)

        chosen = np.array(list(self.estimates))[mask]
        centers = np.array(
            [self.partition.p_cells[c].centers for c in chosen])[:, 0, :]
        candidates = sample_around_center(
            centers, j_max=self.partition.j_max, d=self.partition.d, n=n)
        # Sub-select n rows without replacement.
        picked = np.random.choice(range(len(candidates)), n, replace=False)
        return candidates[picked, :]


class MessyRank(LearnerAnalyser):
    """Active-ranking learner that eliminates cells with ``UMessyRank``.

    ``active_set`` holds the cells still being sampled.  Once at most one
    cell is left, the set is reset to every cell and ``epsilon`` is divided
    by 1.1.
    """
    name = 'active-rank'

    def __init__(self, j_max, d, eta, prioritize=None,
                 alternative_p_estimate=False):
        """Initialise the learner.

        Args:
            j_max, d, eta: forwarded unchanged to ``LearnerAnalyser``.
            prioritize: ``None``, ``"n_sample"``, ``"margin"`` or
                ``"random"``; selects how ``sample`` picks rows.
            alternative_p_estimate: when true, ``p`` is estimated from
                dedicated ``sampler_d`` draws instead of the cell values.
        """
        super().__init__(j_max, d, eta)
        self.active_set = list(self.partition.p_cells.keys())
        self.prioritize = prioritize
        self.memory = 0
        self.epsilon = config.epsilon_messy_rank
        self.alternative_p_estimate = alternative_p_estimate
        dim = self.partition.d
        self.x_p_estimate = np.array([]).reshape((0, dim))
        # Indices into partition.y of the draws used to estimate p.
        self.y_p_estimate = []
        # Buffer of generated rows not yet handed out by `sample`.
        self.active_set_samples = np.array([]).reshape((0, dim))

    def compute_active_set(self):
        """Remove from the active set the cells selected by ``UMessyRank.q``.

        Nothing is removed unless the estimate ``p`` exceeds ``umr.delta``.
        """
        cells = self.partition.p_cells
        active = [cells[c] for c in self.active_set]
        ucb = np.array([cell.ucb for cell in active])
        lcb = np.array([cell.lcb for cell in active])
        mu = np.array([cell.value for cell in active])

        p = (np.mean(self.partition.y[self.y_p_estimate])
             if self.alternative_p_estimate else np.mean(mu))

        self.umr = UMessyRank(
            empirical_mean=mu,
            parameter=config.parameter_messy_rank,
            ucb=ucb, lcb=lcb, k=self.K)
        if not (p > self.umr.delta):
            return

        selected = self.umr.q(p=p, epsilon=self.epsilon)
        dropped = list(np.array(self.active_set)[selected])
        self.active_set = [s for s in self.active_set if s not in dropped]
        if len(self.active_set) > 1:
            return

        # At most one cell left: log progress, tighten epsilon, start over.
        print("step:", self.partition.step, "  eps:",
              np.round(self.epsilon, decimals=2), "t :", self.t,
              "  n_sample:", len(self.partition.x))
        self.epsilon /= 1.1
        self.active_set = list(self.partition.p_cells.keys())

    def __sample_from_save(self, n):
        """Pop and return the first ``n`` rows of the sample buffer."""
        head = self.active_set_samples[:n]
        self.active_set_samples = self.active_set_samples[n:]
        return head

    def sample(self, n):
        """Return ``n`` rows to query next.

        Order of precedence: a ``sampler_d`` draw for the ``p`` estimate
        (only with ``alternative_p_estimate``), then rows left in the
        buffer, then freshly generated rows chosen per ``prioritize``.

        Raises:
            ValueError: if ``self.prioritize`` is not a supported value.
        """
        if (self.alternative_p_estimate
                and len(self.partition.x) % len(self.active_set) == 0):
            x_new = sampler_d(n, self.partition.d)
            self.x_p_estimate = np.concatenate((self.x_p_estimate, x_new))
            self.y_p_estimate.append(len(self.partition.x))
            return x_new

        buffered = len(self.active_set_samples)
        if buffered != 0 and n <= buffered:
            return self.__sample_from_save(n)

        p_cells = self.partition.p_cells
        centers = np.array(
            [p_cells[c].centers for c in self.active_set])[:, 0, :]
        self.active_set_samples = sample_around_center(
            centers, j_max=self.partition.j_max,
            d=self.partition.d, n=n)
        self.t += 1

        mode = self.prioritize
        if mode is None:
            return self.__sample_from_save(n)
        if mode == "random":
            picked = np.random.choice(
                range(len(self.active_set_samples)), n, replace=False)
            return self.active_set_samples[picked, :]
        if mode == "n_sample":
            order = [-p_cells[c].n for c in self.active_set]
        elif mode == "margin":
            order = [p_cells[c].ucb - p_cells[c].lcb for c in self.active_set]
        else:
            raise ValueError
        # Keep the rows with the n largest `order` values.
        return self.active_set_samples[np.argsort(order)][-n:, :]

    def update_t(self):
        """Do nothing; ``t`` is advanced inside ``sample`` instead."""


def vta(d: dict):
    """Return the values of ``d`` as a NumPy array, in iteration order."""
    values = [*d.values()]
    return np.array(values)


def kta(d: dict):
    """Return the keys of ``d`` as a NumPy array, in iteration order."""
    keys = [*d.keys()]
    return np.array(keys)
