import numpy as np
from matplotlib.ticker import FixedLocator, FixedFormatter
from scipy.stats import norm

from experiments.grid_runs import GridPoint, Task
from experiments.grid_runs_utils import get_axes_from_grid_config, prepare_grid_pairs_to_iterate

# Constant factor inside the eps_JL expression below (2 is the alternative
# value that was left commented out next to it).
C_JL = 4


def eps_JL(m, v):
    """Return ``sqrt(C_JL * ln(v) / m)``.

    Works elementwise on NumPy arrays as well as on scalars, since it only
    uses NumPy ufuncs and arithmetic.

    NOTE(review): the name suggests a Johnson-Lindenstrauss distortion bound
    with ``m`` the projection dimension and ``v`` the number of points --
    confirm against the accompanying derivation.
    """
    return np.sqrt(C_JL * np.log(v) / m)


def p_normal_greater_than_zero(mu, var):
    """Return ``norm.cdf(mu / sqrt(var))``, i.e. P(X > 0) for X ~ Normal(mu, var).

    Evaluated elementwise, so ``mu`` and ``var`` may be scalars or
    broadcastable arrays. (The previous ``if var == 0`` test raised
    ``ValueError`` for any array ``var`` with more than one element.)

    Args:
        mu: Mean(s) of the normal distribution.
        var: Variance(s) of the normal distribution.

    Returns:
        A NumPy float scalar for scalar inputs, otherwise a float array with
        the broadcast shape of ``mu`` and ``var``. Entries where ``var == 0``
        are defined to be 1 regardless of ``mu`` (the convention this function
        has always used). NaN variances yield NaN.
    """
    mu = np.asarray(mu, dtype=float)
    var = np.asarray(var, dtype=float)

    degenerate = var == 0
    # Swap in a dummy variance of 1 where var == 0 so the division below never
    # divides by zero; those entries are overwritten with 1.0 afterwards.
    sigma = np.sqrt(np.where(degenerate, 1.0, var))
    p = np.where(degenerate, 1.0, norm.cdf(mu / sigma))

    # Indexing with an empty tuple unwraps a 0-d result to a NumPy scalar and
    # leaves n-d arrays untouched.
    return p[()]


# N_f_factor = 1.75, sigma_factor = 0.275
# N_f_factor = 2, sigma_factor = 0.275

def compute_theoretical_accuracy__JL_probabilistic(D, N, V, N_f, L, N_f_factor=0.4, sigma_factor=0.3):
    """Return the theoretical success probability used for the paper.

    The value is ``P(X > 0) ** (V // 2)`` with ``X ~ Normal(1, var)`` and
    ``var = 1.5 * 2 * N_f / (D * N)``.

    Args:
        D, N, V, N_f, L: Problem dimensions. The grid driver in this file
            passes ``dims.D, dims.N, dims.V, dims.N_facts, dims.L``.
            NOTE(review): the meaning of each dimension is defined by the
            experiment code, not here -- confirm before relying on it.
            ``L`` is accepted for call compatibility but does not enter the
            formula.
        N_f_factor, sigma_factor: Only used by the superseded detailed model
            (see ``_jl_detailed_success_probability``); they do not affect
            the returned value.

    Returns:
        The success probability, with the same scalar/array shape that the
        arithmetic on the inputs produces. When ``N`` is a NumPy array,
        entries with ``N < 2`` produce NaN.
    """
    # The original check was `type(N) is np.array`, which is never true
    # (np.array is a function, not a type), so the masking never ran. Build a
    # masked float copy instead of writing NaN into the caller's array, which
    # would also fail for integer arrays.
    if isinstance(N, np.ndarray):
        N = np.where(N < 2, np.nan, N)

    # TODO - THIS WAS USED FOR THE PAPER, NEVER DELETE
    var = (1.5 * 2 * N_f / (D * N))
    # Alternative that was left commented out:
    # var = 2 * (3./N + 2./D + 2 * N_f / (D * N))
    return p_normal_greater_than_zero(mu=1, var=var) ** (V // 2)


def _jl_detailed_success_probability(D, N, V, N_f, N_f_factor=0.4, sigma_factor=0.3):
    """Superseded per-entry variance model, kept for reference only.

    The public function used to compute this value and then overwrite it with
    the paper formula; nothing calls this helper at present. Its alternative
    defaults are listed in the comment above the public function.
    """
    V_v = V // 2

    N_f_tilde = N_f_factor * (N_f - 1)

    sigma_squared_k = sigma_factor / N
    sigma_squared_v = sigma_factor / D
    sigma_squared_f = sigma_squared_k * sigma_squared_v

    # The two variances differ only in the weight on sigma_squared_k (2 vs 1).
    sigma_squared_wrong = 2 * (sigma_squared_v + 2 * sigma_squared_k) + 2 * N_f_tilde * sigma_squared_f
    sigma_squared_noise = 2 * (sigma_squared_v + 1 * sigma_squared_k) + 2 * N_f_tilde * sigma_squared_f

    # Both margins use a mean of 1.
    p_greater_than_wrong = p_normal_greater_than_zero(mu=1, var=sigma_squared_wrong)
    p_greater_than_noise = p_normal_greater_than_zero(mu=1, var=sigma_squared_noise)

    N_wrong_entries = N_f - 1
    N_noise_entries = V_v - N_f

    return (p_greater_than_wrong ** N_wrong_entries) * (p_greater_than_noise ** N_noise_entries)


def compute_theoretical_accuracy_grid(run_config):
    """Evaluate the theoretical accuracy at every point of the configured grid.

    Args:
        run_config: Mapping whose ``"grid"`` entry is the grid configuration.
            That configuration may carry ``scale_N_with_D`` (bool) and, if it
            is truthy, ``D_to_N_ratio``.

    Returns:
        A float array of shape ``(len(x_axis), len(y_axis))``. Cells stay NaN
        for grid points that are skipped: ``D < 4``, ``N < 2`` or ``N > D``.
    """
    grid_config = run_config["grid"]
    grid_axes, grid_constants = get_axes_from_grid_config(grid_config)
    # Only the (x, y) pairs are needed here; the second return value is
    # ignored (it used to be bound to `dims` and then shadowed in the loop).
    xy_pairs, _ = prepare_grid_pairs_to_iterate(grid_axes, grid_constants, grid_options={})

    print(f"{grid_constants = }")

    x_axis = grid_axes.x.axis.astype(int)
    y_axis = grid_axes.y.axis.astype(int)

    theoretical_acc_grid = np.full((len(x_axis), len(y_axis)), np.nan)

    for (x, y) in xy_pairs:

        x_pt = GridPoint(name=grid_axes.x.name, value=int(x))
        y_pt = GridPoint(name=grid_axes.y.name, value=int(y))

        task = Task(x=x_pt, y=y_pt, constants=grid_constants, seed=0)

        if task.dims.D < 4:
            continue

        # optionally scale N with D:
        if grid_config.get('scale_N_with_D', False):
            task.dims.N = int(task.dims.D / float(grid_config['D_to_N_ratio']))

        if task.dims.N < 2:
            continue

        dims = task.dims

        if dims.N > dims.D:
            continue

        # Locate the cell(s) whose integer axis values match this grid point.
        i = np.where(x_axis == x)[0]
        j = np.where(y_axis == y)[0]

        theoretical_acc_grid[i, j] = compute_theoretical_accuracy__JL_probabilistic(
            dims.D, dims.N, dims.V, dims.N_facts, dims.L
        )

    return theoretical_acc_grid


# plot helpers

# Tick positions as fractions of the axis span: 0, 1/4, 1/2, 3/4 and 1.
FRACTIONS = np.linspace(0.0, 1.0, num=5)

def _closest(values, targets):
    """Return, for each target, the entry of ``values`` nearest to it.

    NaN entries of ``values`` are ignored when measuring distance
    (``np.nanargmin``); on ties the first nearest entry wins. The result has
    one element per target, in target order.
    """
    candidates = np.asarray(values)
    picked = []
    for target in targets:
        distance = np.abs(candidates - target)
        picked.append(np.nanargmin(distance))
    return candidates[np.array(picked, dtype=int)]

def _fmt_ints(values, thousand_sep=False, signed=False):
    """Round ``values`` to the nearest integer and format each as a string.

    Args:
        values: Iterable of numbers; rounded with ``np.rint`` and cast to
            int64 before formatting.
        thousand_sep: If truthy, group digits with ``,``.
        signed: If truthy, always show the sign (``+`` for non-negatives).

    Returns:
        A list of strings, one per input value.
    """
    rounded = np.rint(values).astype(np.int64)
    # Assemble the format spec piecewise: optional '+', optional ',', then 'd'.
    sign_flag = "+" if signed else ""
    group_flag = "," if thousand_sep else ""
    spec = f"{sign_flag}{group_flag}d"
    return [format(number, spec) for number in rounded]

def set_quarter_ticks_from_arrays(ax, x_vals, y_vals, thousand_sep=False):
    """Put major ticks at roughly 0/25/50/75/100 % along both axes of ``ax``.

    x ticks are taken by position: the elements of ``x_vals`` at the indices
    nearest to each fraction of ``len(x_vals) - 1``. y ticks are taken by
    value: the elements of ``y_vals`` closest to each fraction of the
    ``[min, max]`` range, ignoring NaNs. Labels are the tick values rounded to
    integers.

    An axis is left untouched when there is nothing sensible to place:
    ``x_vals`` has fewer than two elements, or ``y_vals`` is empty, all-NaN,
    or constant. (Empty or all-NaN ``y_vals`` previously raised ``ValueError``
    from the NaN-aware reductions.)

    Args:
        ax: Object exposing ``xaxis`` / ``yaxis`` with ``set_major_locator``
            and ``set_major_formatter`` (presumably a matplotlib Axes).
        x_vals: Array-like of x coordinates.
        y_vals: Array-like of y coordinates; may contain NaN.
        thousand_sep: If truthy, tick labels use ``,`` digit grouping.
    """
    x_vals = np.asarray(x_vals)
    y_vals = np.asarray(y_vals)

    if x_vals.size > 1:
        x_idx = np.rint(FRACTIONS * (x_vals.size - 1)).astype(int)
        x_pos = x_vals[x_idx]
        ax.xaxis.set_major_locator(FixedLocator(x_pos))
        ax.xaxis.set_major_formatter(FixedFormatter(_fmt_ints(x_pos, thousand_sep)))

    # np.nanmin raises on empty input and np.nanargmin raises on all-NaN
    # input, so bail out early, mirroring the quiet skip on the x side.
    if y_vals.size == 0 or np.all(np.isnan(y_vals)):
        return

    y_min, y_max = float(np.nanmin(y_vals)), float(np.nanmax(y_vals))
    if y_max != y_min:
        targets = y_min + FRACTIONS * (y_max - y_min)
        # snap to nearest array values, then format as ints
        y_pos = _closest(y_vals, targets)
        ax.yaxis.set_major_locator(FixedLocator(y_pos))
        ax.yaxis.set_major_formatter(FixedFormatter(_fmt_ints(y_pos, thousand_sep)))