from settings import n
from utilities import tuple_update_by_index, cap_up_and_down
import random
from math import floor
from scipy import optimize, integrate


def uniform_sampling(count):
    """Draw ``count`` points uniformly at random from the unit cube [0, 1)^n.

    Each point is an ``n``-tuple of floats produced by ``random.random()``
    (``n`` comes from ``settings``).  Uses the global ``random`` generator.
    """
    points = []
    for _ in range(count):
        points.append(tuple(random.random() for _ in range(n)))
    return points


class GridDistribution:
    """Piecewise-constant probability density on the unit cube [0, 1]^n.

    Each of the ``n`` axes (``n`` comes from ``settings``) is split into
    ``size`` equal cells, giving ``size**n`` cells of volume ``size**-n``.
    ``self.values`` is a nested list ``n`` levels deep holding the density of
    each cell.  It is rescaled at construction so the density integrates
    to 1.  Methods taking an ``agent`` argument operate along axis ``agent``
    of the cube (``point[agent]`` / ``grid[agent]``).

    Attributes:
        name: Label of the distribution.
        values: Normalized nested list of cell densities.
        size: Number of cells per axis.
        normalization_ratio: Integral of the raw weights before rescaling
            (i.e. their mean).
    """

    def __init__(self, name=None, values=None, size=None, seed=None, binary=False):
        """Build a distribution from explicit cell weights or at random.

        Args:
            name: Label for the distribution.  Required when ``values`` is
                given; generated automatically otherwise.
            values: Nested list, ``n`` levels deep, of non-negative cell
                weights with the same length at every level.  The list is
                copied during normalization; the caller's list is not
                modified.
            size: Cells per axis when generating random weights (used only
                when ``values`` is None).
            seed: Passed to ``random.seed`` before generating random weights.
            binary: If True, random weights are 0/1 integers instead of
                floats in [0, 1).

        Raises:
            AssertionError: ``values`` given without ``name``, ragged nesting,
                or a negative weight.
            TypeError: neither ``values`` nor ``size`` given.
        """
        if values is not None:
            assert name is not None
            self.name = name
            raw_values = values
        else:
            if size is None:
                raise TypeError("GridDistribution needs either values or size")

            def gen_values(level_left):
                # Build one nesting level per remaining axis.
                if level_left > 1:
                    return [gen_values(level_left - 1) for _ in range(size)]
                if binary:
                    return [random.randint(0, 1) for _ in range(size)]
                return [random.random() for _ in range(size)]

            # Seeding the global RNG makes the generated grid reproducible
            # (and also reseeds `random` for any later caller).
            random.seed(seed)
            raw_values = gen_values(n)
            self.name = f"random-distribution-size{size}-seed{seed}-binary{binary}"

        self.size = len(raw_values)

        def verify_size_for_every_level(x):
            if isinstance(x, list):
                assert len(x) == self.size, f"{x} incorrect size {self.size}"
                for c in x:
                    verify_size_for_every_level(c)
            else:
                assert x >= 0, f"{x} needs to be nonnegative"

        verify_size_for_every_level(raw_values)

        def integrate_for_normalization(x):
            # Each cell contributes density * cell volume (size**-n).
            if isinstance(x, list):
                return sum(integrate_for_normalization(c) for c in x)
            return x / (self.size**n)

        self.normalization_ratio = integrate_for_normalization(raw_values)

        def normalize(x):
            # Returns a new nested list so the caller's `values` is untouched.
            if isinstance(x, list):
                return [normalize(c) for c in x]
            if self.normalization_ratio == 0:
                # All-zero weights: fall back to the uniform density.
                return 1
            return x / self.normalization_ratio

        self.values = normalize(raw_values)
        assert abs(integrate_for_normalization(self.values) - 1) < 0.000001

    def retrieve_by_grid(self, grid, start=None):
        """Return the density stored at cell index ``grid``.

        Args:
            grid: Tuple of cell indices, one per axis.
            start: Sub-list to walk from when ``grid`` is shorter than ``n``.
                Ignored for full-length indices, which always start at
                ``self.values``.
        """
        node = self.values if len(grid) == n else start
        for index in grid:
            node = node[index]
        return node

    def AMD_grid_integration(self, AMD_grid, H):
        """Approximate the probability mass of one cell of a finer grid.

        ``AMD_grid`` is presumably a tuple of integer coordinates on a grid
        with ``H`` cells per axis (TODO confirm against callers).  The density
        is sampled at the point ``AMD_grid / H`` and multiplied by that cell's
        volume ``H**-n``.
        """
        estimated_point = [x / H for x in AMD_grid]
        return self.retrieve_by_point(estimated_point) / (H**n)

    def point_to_grid(self, point):
        """Map a point of [0, 1]^n to the index tuple of the cell containing it.

        Each coordinate is scaled by ``size`` and floored.  ``cap_up_and_down``
        (presumably a clamp) keeps every index within ``[0, size - 1]``.
        """
        return tuple(
            cap_up_and_down(floor(x * self.size), lowBound=0, upBound=self.size - 1)
            for x in point
        )

    def retrieve_by_point(self, point):
        """Return the density at ``point`` (a coordinate tuple in [0, 1]^n)."""
        return self.retrieve_by_grid(self.point_to_grid(point))

    def rejection_sampling(self, count):
        """Draw ``count`` points from this density by rejection sampling.

        Candidates are uniform on [0, 1)^n.  A candidate in a cell of density
        ``x`` is kept with probability ``x / cap``, where ``cap`` is the
        largest cell density.  Uses the global ``random`` generator.
        """

        def find_cap(x):
            if isinstance(x, list):
                return max(find_cap(c) for c in x)
            return x

        cap = find_cap(self.values)
        res = []
        while len(res) < count:
            point = tuple(random.random() for _ in range(n))
            x = self.retrieve_by_point(point)
            y = random.random() * cap
            # Strict comparison: y can be exactly 0.0, and a zero-density
            # cell must never be accepted.
            if y < x:
                res.append(point)
        return res

    def conditional_distribution(self, grid, agent):
        """Return the densities of all cells along axis ``agent`` through ``grid``.

        Entry ``i`` is the density of the cell obtained from ``grid`` via
        ``tuple_update_by_index(grid, agent, i)`` (assumed to replace position
        ``agent`` with ``i``).  Values are raw cell densities; they are not
        rescaled into a conditional pdf.
        """
        return [
            self.retrieve_by_grid(tuple_update_by_index(grid, agent, i))
            for i in range(self.size)
        ]

    def quadratic_formula_by_point(self, point, agent):
        """``quadratic_formula`` for the cell containing ``point``."""
        return self.quadratic_formula(self.point_to_grid(point), agent)

    def quadratic_formula(self, grid, agent):
        """Return ``(a, b)`` with ``a*v**2 + b*v`` equal to the objective in a cell.

        For ``v`` inside cell ``grid[agent]`` along axis ``agent`` (other
        coordinates fixed), the objective
        ``v * (mass_of_higher_cells + (cell_upper - v) * cell_density)``
        expands to ``a*v**2 + b*v``.  Here ``mass_of_higher_cells`` integrates
        the raw densities of the cells above this one along the axis.
        """
        cd = self.conditional_distribution(grid, agent)
        integration_of_higher_grids = sum(
            cd[x] / self.size for x in range(grid[agent] + 1, self.size)
        )
        current_grid_upper = (grid[agent] + 1) / self.size
        current_grid_density = self.retrieve_by_grid(grid)
        a = -current_grid_density
        b = integration_of_higher_grids + current_grid_upper * current_grid_density
        return a, b

    def marginal_rev_integration(self, point, agent):
        """Maximize the cell-wise objective ``a*v**2 + b*v`` over ``v >= point[agent]``.

        Scans cells along axis ``agent`` from the point's cell upward.  In
        each cell the candidates are the two cell edges and, when the
        parabola opens downward (``a < 0``), its vertex if strictly inside the
        cell.  Only candidates strictly above ``point[agent]`` count, and the
        running maximum starts at 0.  The value at ``point[agent]`` itself is
        compared last.

        Returns:
            ``(max_value, argmax, ab)``.  ``ab`` is the point's own ``(a, b)``
            when the value at ``point[agent]`` strictly beats every other
            candidate (and 0); otherwise it is None.  If nothing exceeds 0,
            the result is ``(0, 1, None)``.
        """
        current_max = 0
        current_max_where = 1
        grid = self.point_to_grid(point)
        own = point[agent]
        for cell in range(grid[agent], self.size):
            modified_grid = tuple_update_by_index(grid, agent, cell)
            a, b = self.quadratic_formula(modified_grid, agent)
            grid_lower = modified_grid[agent] / self.size
            grid_upper = (modified_grid[agent] + 1) / self.size
            candidates = [grid_lower, grid_upper]
            if a < 0:
                # Vertex of a*v^2 + b*v is at -b / (2a).
                vertex = -b / a / 2
                if grid_lower < vertex < grid_upper:
                    candidates.append(vertex)
            for v in candidates:
                if v > own:
                    new_val = a * v * v + b * v
                    if new_val > current_max:
                        current_max = new_val
                        current_max_where = v
        a, b = self.quadratic_formula(grid, agent)
        y = a * own * own + b * own
        if y > current_max:
            return y, own, (a, b)
        return current_max, current_max_where, None

    def marginal_rev(self, point, agent):
        """Return ``-(d/dv)(a*v**2 + b*v)`` at ``v = point[agent]``, or 0.

        Returns 0 whenever ``marginal_rev_integration`` reports no local
        ``(a, b)``, i.e. the point's own value does not strictly beat every
        candidate above it.
        """
        _, _, ab = self.marginal_rev_integration(point, agent)
        if ab is None:
            return 0
        a, b = ab
        # d/dv (a v^2 + b v) = 2 a v + b, negated.
        return -(2 * a * point[agent] + b)

    def marginal_rev_per_density(self, point, agent):
        """``marginal_rev`` divided by the density at ``point``.

        Raises:
            ZeroDivisionError: ``point`` lies in a zero-density cell.
        """
        return self.marginal_rev(point, agent) / self.retrieve_by_point(point)

    def verify_greedy(self):
        """Empirically check that raising the unique winner's coordinate keeps it winning.

        Samples 10000 uniform profiles.  For each profile whose ``marginal_rev``
        has a unique maximizing agent, that agent's coordinate is raised
        towards 1 in 100 steps, and the agent must remain among the
        maximizers at each step.  On the first counterexample, diagnostics are
        printed and False is returned; otherwise returns True.
        """
        for profile in uniform_sampling(10000):
            winners, _ = self.get_marginal_value_winners(profile)
            if len(winners) != 1:
                continue
            winner = winners[0]
            for r in range(100):
                new_bid = profile[winner] + (1 - profile[winner]) * r / 100
                modified_profile = tuple_update_by_index(profile, winner, new_bid)
                still_winners, _ = self.get_marginal_value_winners(modified_profile)
                if winner not in still_winners:
                    print(winner)
                    print(profile)
                    print([self.marginal_rev(profile, agent) for agent in range(n)])
                    print(modified_profile)
                    print(
                        [
                            self.marginal_rev(modified_profile, agent)
                            for agent in range(n)
                        ]
                    )
                    return False
        return True

    def get_marginal_value_winners(self, point):
        """Return ``(agents attaining the max marginal_rev, that max value)``."""
        mvs = [self.marginal_rev(point, agent) for agent in range(n)]
        max_mv = max(mvs)
        winners = [agent for agent, mv in enumerate(mvs) if mv == max_mv]
        return winners, max_mv


# Named example distributions, built at import time.  All grids below are
# 2-D (lists of rows), so they are presumably intended for settings.n == 2.
#
# saturn .. luna: square grids of orders 3-9 whose rows all sum to the same
# constant; they appear to be the traditional planetary magic squares.
# Each holds a permutation of 1..k^2, rescaled by GridDistribution so the
# density integrates to 1.
saturn = GridDistribution("saturn", [[4, 9, 2], [3, 5, 7], [8, 1, 6]])

jupiter = GridDistribution(
    "jupiter", [[4, 14, 15, 1], [9, 7, 6, 12], [5, 11, 10, 8], [16, 2, 3, 13]]
)

mars = GridDistribution(
    "mars",
    [
        [11, 24, 7, 20, 3],
        [4, 12, 25, 8, 16],
        [17, 5, 13, 21, 9],
        [10, 18, 1, 14, 22],
        [23, 6, 19, 2, 15],
    ],
)

sol = GridDistribution(
    "sol",
    [
        [6, 32, 3, 34, 35, 1],
        [7, 11, 27, 28, 8, 30],
        [19, 14, 16, 15, 23, 24],
        [18, 20, 22, 21, 17, 13],
        [25, 29, 10, 9, 26, 12],
        [36, 5, 33, 4, 2, 31],
    ],
)

venus = GridDistribution(
    "venus",
    [
        [22, 47, 16, 41, 10, 35, 4],
        [5, 23, 48, 17, 42, 11, 29],
        [30, 6, 24, 49, 18, 36, 12],
        [13, 31, 7, 25, 43, 19, 37],
        [38, 14, 32, 1, 26, 44, 20],
        [21, 39, 8, 33, 2, 27, 45],
        [46, 15, 40, 9, 34, 3, 28],
    ],
)

mercury = GridDistribution(
    "mercury",
    [
        [8, 58, 59, 5, 4, 62, 63, 1],
        [49, 15, 14, 52, 53, 11, 10, 56],
        [41, 23, 22, 44, 45, 19, 18, 48],
        [32, 34, 35, 29, 28, 38, 39, 25],
        [40, 26, 27, 37, 36, 30, 31, 33],
        [17, 47, 46, 20, 21, 43, 42, 24],
        [9, 55, 54, 12, 13, 51, 50, 16],
        [64, 2, 3, 61, 60, 6, 7, 57],
    ],
)

luna = GridDistribution(
    "luna",
    [
        [37, 78, 29, 70, 21, 62, 13, 54, 5],
        [6, 38, 79, 30, 71, 22, 63, 14, 46],
        [47, 7, 39, 80, 31, 72, 23, 55, 15],
        [16, 48, 8, 40, 81, 32, 64, 24, 56],
        [57, 17, 49, 9, 41, 73, 33, 65, 25],
        [26, 58, 18, 50, 1, 42, 74, 34, 66],
        [67, 27, 59, 10, 51, 2, 43, 75, 35],
        [36, 68, 19, 60, 11, 52, 3, 44, 76],
        [77, 28, 69, 20, 61, 12, 53, 4, 45],
    ],
)

# worst100 / worst10: hand-tuned 5x5 weight grids.  Their origin is not
# visible here; the names suggest worst-case instances (with a bound of
# about 100 and 10 respectively) -- TODO confirm.  Keep the float
# literals exact.
worst100 = GridDistribution(
    "worst100",
    [
        [1, 1.0, 1.0, 1.0, 1.0],
        [1.0, 100.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 48.5, 24.25, 72.75],
        [1.0, 14.25, 1.0, 1.0, 3.0],
        [1.0, 42.75, 1.0, 1.0, 1.0],
    ],
)

worst10 = GridDistribution(
    "worst10",
    [
        [1, 1.0, 1.0, 1.0, 1.0],
        [1.0, 10.0, 9.0, 1.0, 10.0],
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 2.5892857142845793, 1.0],
        [1.0, 1.0, 1.0, 7.7678571428537335, 1.0],
    ],
)


if __name__ == "__main__":
    # Ad-hoc self-checks.  The explicit grids below are 2-D, so these checks
    # assume settings.n == 2.
    # Grids must be passed together with a name: the first positional
    # parameter of GridDistribution is `name`, not `values`.
    d = GridDistribution("test-single-cell", [[100, 0], [0, 0]])
    print(d.values)
    d = GridDistribution(size=2, seed=0)
    print(d.values)
    d = GridDistribution(size=2, seed=0, binary=True)
    print(d.values)

    # The weights 1..4 have mean 2.5, so each cell is divided by 2.5.
    d = GridDistribution("test-1234", [[1, 2], [3, 4]])
    assert d.retrieve_by_grid((0, 0)) == 0.4
    assert d.retrieve_by_grid((0, 1)) == 0.8
    assert d.retrieve_by_grid((1, 0)) == 1.2
    assert d.retrieve_by_grid((1, 1)) == 1.6

    assert d.point_to_grid((0.49, 0)) == (0, 0)
    assert d.point_to_grid((0.49, 0.5)) == (0, 1)
    assert d.point_to_grid((0.99, 0.49)) == (1, 0)
    assert d.point_to_grid((0.87, 0.64)) == (1, 1)

    assert tuple(d.conditional_distribution((0, 0), 0)) == (0.4, 1.2)
    assert tuple(d.conditional_distribution((0, 0), 1)) == (0.4, 0.8)

    # All mass is in the top-right cell, so every sample must land there.
    d = GridDistribution("test-corner", [[0, 0], [0, 1]])
    for x, y in d.rejection_sampling(100):
        assert x >= 0.5
        assert y >= 0.5

    def marginal_rev_integration_numerical_version(d, point, agent):
        """Brute-force max over v' in [point[agent], 1) of v' * integral_{v'}^1 density."""

        def func(x):
            modified_point = tuple_update_by_index(point, agent, x)
            return d.retrieve_by_point(modified_point)

        def integration_from_vprime(vprime):
            # Negated because optimize.brute minimizes.
            return -vprime * integrate.quad(func, vprime, 1)[0]

        return -optimize.brute(
            integration_from_vprime,
            (slice(point[agent], 1, (1 - point[agent]) / 1000),),
            full_output=True,
            finish=None,
        )[1]

    # Cross-check the closed-form maximization against the numerical one.
    d = GridDistribution(size=10, seed=1)
    for point in d.rejection_sampling(1000):
        try:
            n0 = marginal_rev_integration_numerical_version(d, point, 0)
            a0 = d.marginal_rev_integration(point, 0)
            n1 = marginal_rev_integration_numerical_version(d, point, 1)
            a1 = d.marginal_rev_integration(point, 1)
        except Exception as err:
            # Report and skip this point; otherwise n0/a0/n1/a1 would be
            # unbound or left over from the previous iteration.
            print(point, repr(err))
            continue
        if abs(n0 - a0[0]) > 0.001 or abs(n1 - a1[0]) > 0.001:
            print(point, n0, a0, n1, a1)