import matplotlib.pyplot as plt
from math import floor
from scipy.integrate import quad
from functools import partial
from scipy.spatial import ConvexHull
from utilities import cap_up_and_down

# Number of grid steps used to sample [0, 1]: functions are evaluated at
# i / precision for i = 0..precision (see plot and convex_hull_H).
precision = 1000


def normalize(d):
    """Rescale strictly positive bin heights so that their mean is 1.

    With len(d) equal-width bins on [0, 1] (the layout f and F use), a mean
    height of 1 means the step function has total mass 1.

    Raises:
        AssertionError: if any entry of d is not strictly positive.
    """
    assert min(d) > 0
    mean_height = sum(d) / len(d)
    return list(map(lambda height: height / mean_height, d))


def f(d, x):
    """Density at x of the step function whose bin heights are d.

    x is mapped to bin floor(x * len(d)), passed through cap_up_and_down
    with an upper bound of the last bin index (lower-bound behaviour depends
    on that helper's defaults -- presumably 0; confirm).
    """
    return d[cap_up_and_down(floor(x * len(d)), upBound=len(d) - 1)]


def F(d, x):
    """CDF at x of the step density whose bin heights are d.

    Each of the len(d) bins has width 1 / len(d). The result is the mass of
    every complete bin to the left of x's bin plus the mass of x's own bin
    up to x. It is capped at 1 via cap_up_and_down (lower-bound behaviour of
    that helper is not visible here -- confirm).
    """
    n_bins = len(d)
    bin_idx = cap_up_and_down(floor(x * n_bins), upBound=n_bins - 1)
    # Complete bins strictly left of bin_idx.
    cdf = 0
    for left_bin in range(bin_idx):
        cdf += d[left_bin] / n_bins
    # Partial bin: x lies (x - bin_idx / n_bins) into bin bin_idx.
    cdf += (x - bin_idx / n_bins) * d[bin_idx]
    return cap_up_and_down(cdf, upBound=1)


def Finverse(d, x):
    """Inverse of F: return t in [0, 1] with F(d, t) == x.

    F is linear inside each bin, so the bin whose CDF range [a, b] contains
    x is located and t is obtained by linear interpolation within it.

    Args:
        d: bin heights of the step density (one per equal-width bin of [0, 1]).
        x: target CDF value (a quantile), expected in [0, 1].

    Returns:
        The smallest bin-interpolated t with F(d, t) == x. Returns 1 when x
        lies above every bin's range but within 1e-6 of 1 (floating-point
        slack in F(d, 1)).

    Raises:
        AssertionError: if x is not found and is not within 1e-6 of 1.
    """
    size = len(d)
    # b of the previous bin is a of the next one; reuse it instead of calling
    # F twice per bin (same argument value, so identical results).
    upper = F(d, 0 / size)
    for idx in range(size):
        lower = upper
        upper = F(d, (idx + 1) / size)
        if lower <= x <= upper:
            if upper == lower:
                # Zero-mass bin: F is flat here, so the left edge is the
                # smallest t with F(t) == x (avoids division by zero).
                return idx / size
            return idx / size + (x - lower) / (upper - lower) / size
    # Explicit raise (not assert) so the check survives `python -O`.
    if not abs(x - 1) < 0.000001:
        raise AssertionError(f"numerical issue at {x} in Finverse")
    return 1


def virtual_valuation(d, x):
    """Virtual value phi(x) = x - (1 - F(x)) / f(x) for the step density d.

    A zero density at x would raise ZeroDivisionError; normalize() guards
    against zero bins, presumably applied by callers first -- confirm.
    """
    survival = 1 - F(d, x)
    density = f(d, x)
    return x - survival / density


def h(d, q):
    """Virtual value in quantile space: phi(F^{-1}(q)) for the step density d."""
    value_at_quantile = Finverse(d, q)
    return virtual_valuation(d, value_at_quantile)


def H(d, q):
    """Integral of h(d, .) over [0, q], computed with scipy's adaptive quad.

    Kept for reference / cross-checking only: it has numerical precision
    problems, so convex_hull_H builds its own trapezoidal sums (H_values)
    instead of calling this.
    """
    integrand = partial(h, d)
    integral = quad(integrand, 0, q, limit=1000)[0]
    return integral


def plot(func, name="func"):
    """Plot func on [0, 1] with a dashed y = 0 line and save it as f"{name}.pdf".

    func is sampled at i / precision for i = 0..precision. For the preset
    names "mv", "vv" and "vvi", dashed vertical markers and custom x ticks are
    drawn. Selected samples are also replaced by None so matplotlib leaves a
    gap in the curve there (presumably at jumps of func -- confirm). For "vvi"
    the sampled values are printed as well.

    Args:
        func: callable mapping a float in [0, 1] to a number.
        name: selects the annotation preset and the output file stem.
    """
    plt.rcParams.update({"font.size": 14})
    x = [i / precision for i in range(precision + 1)]
    y = [func(i / precision) for i in range(precision + 1)]
    zero = [0 for _ in range(precision + 1)]

    def _break_at(*fractions):
        # Blank the sample nearest each fraction of [0, 1]; indexing relative
        # to precision (rather than hard-coded for precision == 1000) keeps
        # this valid if precision changes.
        for frac in fractions:
            y[round(frac * precision)] = None

    fig, ax = plt.subplots()
    try:
        ax.set_xlim(0, 1)
        ax.set_ylim(-5, 1)
        if name == "mv":
            for pos in (0.235, 0.364, 0.8):
                ax.vlines(pos, -5, 1, colors="black", linestyle="dashed")
            ax.set_xticks([0, 0.235, 0.364, 0.8, 1])
            _break_at(0.364, 0.799, 1.0)
        if name == "vv":
            for pos in (0.235, 0.4, 0.8):
                ax.vlines(pos, -5, 1, colors="black", linestyle="dashed")
            ax.set_xticks([0, 0.235, 0.4, 0.8, 1])
            _break_at(0.199, 0.399, 0.599, 0.799, 1.0)
        if name == "vvi":
            for idx in range(precision):
                print(idx, y[idx])
            for pos in (0.235, 0.8):
                ax.vlines(pos, -5, 1, colors="black", linestyle="dashed")
            ax.set_xticks([0, 0.235, 0.8, 1])
            _break_at(0.2, 0.8)
        ax.plot(x, y, linewidth=3)
        ax.plot(x, zero, "k--")
        fig.savefig(f"{name}.pdf")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def convex_hull_H(d):
    """Vertices of the lower convex hull of the sampled curve (q, H(q)).

    H(q) = integral of h(d, .) over [0, q] is approximated with the
    trapezoidal rule on the grid q = i / precision; H() itself is not used
    because of its numerical precision issues.

    Only the lower chain of the hull is returned. scipy's
    ConvexHull.vertices contains both the upper and lower chains, and
    consecutive sorted points are treated as segments of the convex minorant
    by callers (ironed_virtual_valuations), which is only valid for the lower
    chain.

    Returns:
        List of [q, H] points sorted by q.
    """
    h_values = [h(d, i / precision) for i in range(precision + 1)]
    H_values = [0] * (precision + 1)
    for i in range(1, precision + 1):
        H_values[i] = H_values[i - 1] + (h_values[i - 1] + h_values[i]) / 2 / precision

    points = [[i / precision, H_values[i]] for i in range(precision + 1)]

    hull = ConvexHull(points)
    # hull.equations rows are [normal_x, normal_y, offset] with outward
    # normals; a facet belongs to the lower chain iff its normal points down.
    lower_vertices = set()
    for simplex, equation in zip(hull.simplices, hull.equations):
        if equation[1] < 0:
            lower_vertices.update(int(v) for v in simplex)
    return sorted(points[v] for v in lower_vertices)


def ironed_virtual_valuations(hull_points, d, x):
    """Ironed virtual value at value x.

    Maps x to its quantile F(d, x) and returns the slope of the hull segment
    whose q-range contains it (the first such segment if q is a vertex).

    Args:
        hull_points: [q, H] points sorted by q, e.g. from convex_hull_H.
        d: bin heights of the step density.
        x: value at which to evaluate.

    Raises:
        AssertionError: if F(d, x) lies outside the hull's q-range. It is
            raised explicitly so it also fires under `python -O`, where the
            original `assert False` was stripped and None was returned.
    """
    q = F(d, x)  # loop-invariant: compute once, not per segment
    for left, right in zip(hull_points, hull_points[1:]):
        if left[0] <= q <= right[0]:
            return (right[1] - left[1]) / (right[0] - left[0])
    raise AssertionError(
        f"not supposed to reach, numerical error in ironed_virtual_valuations {x}"
    )
