import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from active_ranking import config
from active_ranking.base.utils import FunctionAsLabeler, GridAsFunction
from active_ranking.scenarios.complexity import problem_complexity

# eta(x) = P(Y = 1 | X = x) is modelled as a piecewise-constant function on the
# unit hypercube [0, 1]^d, over a regular partition into 2**(j_max * d) cells
# (2**j_max cells along each axis).


def levels(d, j_max, l=None, p=None, deterministic=False):
    """
    Build a ``d``-dimensional grid of eta levels on a regular partition.

    The grid has ``2 ** j_max`` cells along each of the ``d`` axes, i.e.
    ``2 ** (j_max * d)`` cells in total.

    Parameters
    ----------
    d : int
        Dimension of the grid.
    j_max : int
        Refinement depth; each axis is split into ``2 ** j_max`` cells.
    l : array-like or callable
        Levels of eta. Its meaning depends on ``deterministic``:

        * ``deterministic=False``: 1-D collection of candidate levels; every
          cell independently draws one of them (with replacement).
        * ``deterministic=True`` and ``l`` callable: cell ``i`` (flat C-order
          index) receives ``l(i / 2 ** (j_max * d))``.
        * ``deterministic=True`` otherwise: exactly one value per cell, in
          C order.
    p : array-like, optional
        Probabilities attached to each entry of ``l`` for the random draws
        (uniform when ``None``). Ignored when ``deterministic=True``.
    deterministic : bool
        ``False`` (default): draw levels at random from ``l`` using the global
        NumPy random state. ``True``: take the levels from ``l`` directly.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2 ** j_max,) * d``.

    Raises
    ------
    ValueError
        If ``l`` is not provided, or cannot be reshaped to the grid shape.
    """
    if l is None:
        raise ValueError("`l` (the levels of eta) must be provided")

    shape = (2 ** j_max,) * d
    n_cells = 2 ** (j_max * d)

    if not deterministic:
        # One draw per cell keeps the sequence of global random-state calls
        # unchanged, so seeded outputs stay reproducible.
        values = [np.random.choice(l, replace=True, p=p)
                  for _ in range(n_cells)]
    elif callable(l):
        values = [l(i / n_cells) for i in range(n_cells)]
    else:
        values = l
    return np.array(values).reshape(shape)


def levels2(d, j_max, jumps, jump=0.1):
    """
    Build a random ``d``-dimensional grid of eta levels via a lattice walk.

    A walk starts at cell ``(0, ..., 0)`` and, one step at a time, advances
    along a randomly chosen axis that has not yet reached its last cell,
    until it ends at cell ``(2 ** j_max - 1,) * d``. At each visited cell:

    * if the current level is the highest one, the cell is set to it;
    * otherwise, with probability ``jump`` the cell value is redrawn
      uniformly from ``jumps``, and if the cell value then equals the next
      level ``jumps[level + 1]`` the current level is incremented.

    Unvisited cells keep the lowest level. The grid is then cumulatively
    summed along every axis and, if its maximum is positive, rescaled so
    that the maximum equals 0.99.

    All randomness comes from the global NumPy random state.

    Parameters
    ----------
    d : int
        Dimension of the grid (>= 1).
    j_max : int
        Refinement depth; each axis has ``2 ** j_max`` cells.
    jumps : array-like
        Candidate values; sorted ascending before use.
    jump : float
        Probability of redrawing the value of a visited cell.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2 ** j_max,) * d``.
    """
    n_side = 2 ** j_max
    jumps = np.sort(jumps)
    top_level = len(jumps) - 1

    ret = np.ones((n_side,) * d) * jumps[0]
    position = np.zeros(d, dtype=int)
    level = 0

    while True:
        cell = tuple(position)
        if level == top_level:
            ret[cell] = jumps[level]
        else:
            if np.random.uniform() < jump:
                ret[cell] = np.random.choice(jumps)
            if ret[cell] == jumps[level + 1]:
                level += 1

        # Step along a random axis that can still advance; once every axis
        # sits on its last cell, the far corner has just been filled.
        movable = np.flatnonzero(position < n_side - 1)
        if movable.size == 0:
            break
        position[np.random.choice(movable)] += 1

    for axis in range(d):
        ret = np.cumsum(ret, axis=axis)

    peak = np.max(ret)
    if peak > 0:
        ret = ret / peak * 0.99
    return ret


def plot_grid(grid, *args, **kwargs):
    """
    Draw ``grid`` as a seaborn heatmap in Cartesian orientation.

    The grid is transposed and flipped vertically. As a result, the first
    axis of ``grid`` runs left-to-right along x, and the second axis runs
    bottom-to-top along y. Tick labels are replaced by the coordinates of
    the cell centres, ``(k + 0.5) / n`` for cell ``k`` of an axis with ``n``
    cells.

    Parameters
    ----------
    grid : numpy.ndarray
        2-D array of cell values; the two axes may have different lengths.
    *args, **kwargs
        Forwarded to ``seaborn.heatmap`` (e.g. ``annot``, ``cmap``, ``ax``).
    """
    # Imported inside the function so seaborn is only needed when plotting.
    import seaborn as sns

    # Use the Axes seaborn drew on (honours an ``ax=`` keyword argument).
    ax = sns.heatmap(grid.T[::-1], *args, **kwargs)
    n_x, n_y = grid.shape[0], grid.shape[1]

    # Label only the ticks seaborn actually placed (it may thin them on
    # large grids); a tick sits at ``k + 0.5`` for heatmap column/row k.
    x_ticks = ax.get_xticks()
    ax.set_xticks(x_ticks)
    ax.set_xticklabels([(int(t) + 0.5) / n_x for t in x_ticks])

    # Heatmap row r (counted from the top) shows grid column n_y - 1 - r,
    # because of the [::-1] flip above.
    y_ticks = ax.get_yticks()
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([(n_y - int(t) - 0.5) / n_y for t in y_ticks])


def plot_eta_in_one_d(grid: np.ndarray, color=config.colors[0], order=True):
    """
    Plot the cell values of ``grid`` as a step function of the cell index.

    Cell ``k`` is drawn as a horizontal segment spanning ``[k - 0.5, k + 0.5]``
    at height equal to its value. The plot goes to the current matplotlib
    Axes.

    Parameters
    ----------
    grid : numpy.ndarray
        Grid of values (any shape); flattened in C order.
    color : str
        Passed positionally to ``plt.plot`` as its format argument; the
        default is ``config.colors[0]``.
    order : bool
        If True (default), the values are sorted increasingly before
        plotting; otherwise they keep their flattened order.
    """
    values = np.ravel(grid)
    if order:
        values = np.sort(values)

    # x = [-0.5, 0.5, 0.5, 1.5, 1.5, ..., n - 0.5]: every interior cell
    # boundary is repeated so the step rises vertically there.
    edges = np.repeat(np.arange(values.size + 1), 2)[1:-1] - 0.5
    heights = np.repeat(values, 2)

    plt.plot(edges, heights, color)

    plt.xlabel("Cell index")
    plt.ylabel(r"$\eta$")


def _draw_eta_grids():
    """
    Build the eight benchmark eta grids and return them in order.

    Several grids are drawn from the global NumPy random state, so their
    values depend on the caller's seed and on the exact order of the draws
    below.
    """
    eta_1_ = np.array(
        [[0.1, 0.43],
         [0.42, 0.9]])
    # The two grids below are not returned, but drawing them consumes random
    # numbers; they are kept so that the grids drawn afterwards are unchanged.
    levels2(
        d=2, j_max=3, jumps=np.array([0.1, 0.1, 0.2, 0]),
        jump=0.5)
    levels(d=3, j_max=3, l=np.array([0.1, 0.3, 0.5, 0.99]))

    eta_2_ = levels(
        d=2, j_max=2,
        l=np.array([0, 0.28, 0.3, 0.33, 0.38]),
        p=[0.81, 0.1, 0.05, 0.02, 0.02])
    eta_3_ = levels(
        d=2, j_max=3,
        l=lambda x: ((x ** 4) * 0.8 + 0.1),
        deterministic=True)
    eta_4_ = levels(
        d=2, j_max=3,
        l=np.linspace(0.1, 0.9, 100) ** 4)
    eta_5_ = levels(
        d=2, j_max=2,
        l=np.random.uniform(size=2 ** (2 * 2)),
        deterministic=True)
    eta_6_ = levels(
        d=2, j_max=2,
        l=np.random.normal(0.5, 0.05, size=2 ** (2 * 2)),
        deterministic=True)
    eta_7_ = 1 - eta_3_

    # Evenly spaced levels with the two middle cells nudged towards each other.
    l8 = np.linspace(0.1, 0.9, num=2 ** (2 * 2))
    l8[7] += 0.02
    l8[8] -= 0.02
    eta_8_ = levels(
        d=2, j_max=2,
        l=l8,
        deterministic=True)

    return [eta_1_, eta_2_, eta_3_, eta_4_, eta_5_, eta_6_, eta_7_, eta_8_]


def _save_eta_outputs(index, eta, cmap):
    """
    Save ``eta`` and its two figures under ``results/``.

    Writes the pickled DataFrame ``results/eta_{index}``, a heatmap
    ``./results/figures/eta/eta_{index}``, and a sorted-levels plot with the
    problem complexity on a twin axis, ``./results/figures/eta/levels_eta_{index}``.
    Each figure is closed after saving.
    """
    pd.DataFrame(eta).to_pickle(f"results/eta_{index}")

    fig = plt.figure()
    plot_grid(eta, annot=True, cmap=cmap, fmt='.3f',
              annot_kws=dict(fontsize=8))
    plt.savefig(f"./results/figures/eta/eta_{index}")
    plt.close(fig)

    fig = plt.figure(figsize=(3.5, 3.5), dpi=250)
    plot_eta_in_one_d(eta)
    c = problem_complexity(eta)
    ax2 = plt.gca().twinx()
    plt.sca(ax2)
    plot_eta_in_one_d(c, color=config.colors[2], order=False)
    ax2.grid(False)
    ax2.set_ylabel("Complexity", color=config.colors[2])
    plt.savefig(f"./results/figures/eta/levels_eta_{index}")
    plt.close(fig)


def generating_eta_functions():
    """
    Generate the benchmark eta grids, pickle them and save their figures.

    The global NumPy random state is seeded with 0 before drawing, so the
    grids are reproducible. It is re-seeded from fresh entropy afterwards,
    even if an error occurs.

    Side effects:

    * switches the matplotlib backend to "agg";
    * sets the xtick/ytick label size to 8 in matplotlib's rcParams;
    * creates ``results/`` and ``./results/figures/eta/`` if missing;
    * writes ``results/eta_{k}`` for k = 1..8, plus the figures
      ``./results/figures/eta/eta_{k}`` and ``levels_eta_{k}``.
    """
    import os

    import matplotlib

    matplotlib.use("agg")
    cmap = config.eta_cmap

    np.random.seed(0)
    try:
        etas = _draw_eta_grids()

        matplotlib.rc('xtick', labelsize=8)
        matplotlib.rc('ytick', labelsize=8)

        os.makedirs("results", exist_ok=True)
        os.makedirs("./results/figures/eta", exist_ok=True)

        for index, eta in enumerate(etas, start=1):
            _save_eta_outputs(index, eta, cmap)
    finally:
        # Do not leave the global generator pinned to seed 0 for later code.
        np.random.seed()


# Runs at import time: regenerates and pickles the eta grids (and their
# figures) under results/, which are loaded back further below.
generating_eta_functions()

# 1-D grid holding two values.
eta_0_grid = np.array([0.1, 0.3])
eta_0 = GridAsFunction(eta_0_grid)
eta_0_labeler = FunctionAsLabeler(eta_0)

# 2 x 2 grid; same values as the grid pickled to results/eta_1.
eta_d_2_grid_1 = np.array([[0.1, 0.43],
                           [0.42, 0.9]])
eta_1 = GridAsFunction(eta_d_2_grid_1)
eta_1_labeler = FunctionAsLabeler(eta_1)


def _grid_function_from_pickle(path):
    """Wrap the values of the DataFrame pickled at ``path`` in a GridAsFunction."""
    frame = pd.read_pickle(path)
    return GridAsFunction(frame.values)


# NOTE: eta_1 is rebound here to the grid reloaded from disk; eta_1_labeler
# above keeps wrapping the in-memory grid (identical values).
eta_1 = _grid_function_from_pickle("results/eta_1")
eta_2 = _grid_function_from_pickle("results/eta_2")
eta_3 = _grid_function_from_pickle("results/eta_3")
eta_4 = _grid_function_from_pickle("results/eta_4")
eta_5 = _grid_function_from_pickle("results/eta_5")
eta_6 = _grid_function_from_pickle("results/eta_6")
eta_7 = _grid_function_from_pickle("results/eta_7")
eta_8 = _grid_function_from_pickle("results/eta_8")
