from settings import n
from utilities import tuple_update_by_index, cap_up_and_down
import random
from math import floor


def uniform_sampling(count):
    """Draw ``count`` points with independent uniform coordinates.

    Each point is a tuple of ``n`` values produced by ``random.random()``,
    where ``n`` is the dimension imported from ``settings``. The points are
    returned as a list in the order they were drawn.
    """
    samples = []
    for _ in range(count):
        coordinates = [random.random() for _ in range(n)]
        samples.append(tuple(coordinates))
    return samples


class GridDistribution:
    """Piecewise-constant density on a regular grid over ``n`` dimensions.

    Every axis is cut into ``size`` equal cells and each cell carries one
    density value, so one cell has volume ``1 / size**n`` (``n`` is the
    dimension imported from ``settings``).

    Attributes:
        name: label of the distribution.
        values: nested list, one nesting level per dimension, holding the
            normalised cell densities (they integrate to 1).
        size: number of cells along each axis.
        normalization_ratio: integral of the densities before normalisation.
    """

    def __init__(self, name=None, values=None, size=None, seed=None, binary=False):
        """Build a distribution from explicit cell values or from random ones.

        Args:
            name: required when ``values`` is given; otherwise ignored and a
                descriptive name is generated.
            values: nested list of nonnegative numbers in which every list,
                at every level, has the same length. It is not modified; the
                instance stores a normalised copy.
            size: cells per axis for a randomly generated distribution (only
                used when ``values`` is None).
            seed: passed to ``random.seed`` before generating random values.
                Note that this reseeds the module-level generator.
            binary: when generating, draw each cell from {0, 1} and raise
                zeros to 0.01 instead of drawing uniformly from [0, 1).

        Raises:
            AssertionError: if ``values`` is given without ``name``, if a
                nested list has the wrong length, if a cell is negative, or
                if the normalised densities do not integrate to 1.
        """
        if values is not None:
            assert name is not None
            self.name = name
            raw_values = values
        else:
            if binary:
                # Emitted once per distribution rather than once per innermost row.
                print(
                    "warning: setting a min of 0.01 to achieve a continous distribution"
                )
            random.seed(seed)
            raw_values = self._generate_values(n, size, binary)
            self.name = f"random-distribution-size{size}-seed{seed}-binary{binary}"

        self.size = len(raw_values)
        self._check_cells(raw_values)
        self.normalization_ratio = self._integrate(raw_values)
        # Normalise into a fresh structure: the caller's list stays untouched,
        # and aliased sub-lists (e.g. ``[[1, 2]] * 2``) are divided only once.
        self.values = self._normalized_copy(raw_values)
        assert abs(self._integrate(self.values) - 1) < 0.000001

    @staticmethod
    def _generate_values(depth, size, binary):
        """Return a random nested list, ``depth`` levels deep, ``size`` wide."""

        def build(levels_left):
            if levels_left > 1:
                return [build(levels_left - 1) for _ in range(size)]
            if binary:
                # Zeros are lifted to 0.01 so that no generated cell has zero density.
                return [max(random.randint(0, 1), 0.01) for _ in range(size)]
            return [random.random() for _ in range(size)]

        return build(depth)

    def _check_cells(self, node):
        """Assert every nested list has ``self.size`` items and every cell is >= 0."""
        if isinstance(node, list):
            assert len(node) == self.size, f"{node} incorrect size {self.size}"
            for child in node:
                self._check_cells(child)
        else:
            assert node >= 0, f"{node} needs to be nonnegative"

    def _integrate(self, node):
        """Integrate the nested densities, weighting each cell by ``1 / size**n``."""
        if isinstance(node, list):
            return sum(self._integrate(child) for child in node)
        return node / (self.size**n)

    def _normalized_copy(self, node):
        """Return a new nested list with every cell divided by the normalisation ratio."""
        if isinstance(node, list):
            return [self._normalized_copy(child) for child in node]
        if self.normalization_ratio == 0:
            # All cells are zero: fall back to density 1 everywhere.
            return 1
        return node / self.normalization_ratio

    def retrieve_by_grid(self, grid, start=None):
        """Follow the indices in ``grid`` through the nested values.

        Args:
            grid: sequence of cell indices, one per nesting level to descend.
            start: nested list to descend from. It is replaced by
                ``self.values`` when omitted or when ``grid`` has ``n`` entries
                (a full-length grid always addresses a cell from the root).

        Returns:
            The cell density for a full-length grid, or the nested sub-list
            reached when fewer indices are supplied.
        """
        node = self.values if (start is None or len(grid) == n) else start
        for index in grid:
            node = node[index]
        return node

    def point_to_grid(self, point):
        """Return the tuple of cell indices of the cell containing ``point``."""
        # floor(x * size) can fall outside 0..size-1 when a coordinate is out
        # of range (e.g. through numerical error); cap_up_and_down (from
        # utilities) is expected to clamp the index back into that range.
        return tuple(
            cap_up_and_down(floor(x * self.size), lowBound=0, upBound=self.size - 1)
            for x in point
        )

    def retrieve_by_point(self, point):
        """Return the density of the cell that contains ``point``."""
        return self.retrieve_by_grid(self.point_to_grid(point))

    def rejection_sampling(self, count):
        """Draw ``count`` points from this distribution by rejection sampling.

        A uniformly drawn candidate point is accepted when its cell density is
        at least a uniform draw from ``[0, cap)``, ``cap`` being the largest
        cell density. Returns the accepted points as a list of tuples.
        """

        def largest_density(node):
            if isinstance(node, list):
                return max(largest_density(child) for child in node)
            return node

        cap = largest_density(self.values)
        accepted = []
        while len(accepted) < count:
            candidate = tuple(random.random() for _ in range(n))
            density = self.retrieve_by_grid(self.point_to_grid(candidate))
            threshold = random.random() * cap
            if density >= threshold:
                accepted.append(candidate)
        return accepted

    def conditional_distribution(self, grid, agent):
        """Return the densities along ``agent``'s axis, other indices fixed by ``grid``.

        The result has ``self.size`` entries; entry ``i`` is the density of the
        cell obtained from ``grid`` by setting index ``agent`` to ``i`` (via
        ``tuple_update_by_index`` from utilities). The slice is not renormalised.
        """
        return [
            self.retrieve_by_grid(tuple_update_by_index(grid, agent, index))
            for index in range(self.size)
        ]

    def quadratic_formula_by_point(self, point, agent):
        """Return ``quadratic_formula`` for the cell that contains ``point``."""
        return self.quadratic_formula(self.point_to_grid(point), agent)

    # do the integration within a grid for a given agent's dimension, resulting in a quadratic formula, returns a,b in ax^2+bx
    def quadratic_formula(self, grid, agent):
        """Return ``(a, b)`` such that the objective inside ``grid`` is ``a*x^2 + b*x``.

        For ``x`` in the cell's range along ``agent``'s axis the expression is
        ``x * (integration_of_higher_grids + (current_grid_upper - x) * current_grid_density)``,
        which expands to ``a*x^2 + b*x`` with ``a = -density`` and
        ``b = integration_of_higher_grids + current_grid_upper * density``.
        """
        slice_densities = self.conditional_distribution(grid, agent)
        # Mass of the cells strictly above the current one along agent's axis.
        integration_of_higher_grids = sum(
            slice_densities[index] / self.size
            for index in range(grid[agent] + 1, self.size)
        )
        current_grid_upper = (grid[agent] + 1) / self.size
        current_grid_density = self.retrieve_by_grid(grid)
        a = -current_grid_density
        b = integration_of_higher_grids + current_grid_upper * current_grid_density
        return a, b

    # the marginal_rev before negative partial differentiation
    def marginal_rev_integration(self, point, agent):
        """Maximise the per-cell quadratic over values at or above ``point[agent]``.

        Every cell from the one containing ``point`` up to the last one along
        ``agent``'s axis is scanned. In each cell the candidates are its two
        edges plus the parabola's vertex when that lies strictly inside the
        cell; only candidates strictly greater than ``point[agent]`` count.

        Returns:
            ``(value, location, ab)``. When the quadratic evaluated at
            ``point[agent]`` itself strictly beats every candidate, ``ab`` is
            the ``(a, b)`` of the point's own cell and ``location`` is
            ``point[agent]``; otherwise ``ab`` is None and ``location`` is the
            best candidate (1 if no candidate had a positive value).
        """
        current_max = 0
        current_max_where = 1
        grid = self.point_to_grid(point)
        for cell_index in range(grid[agent], self.size):
            modified_grid = tuple_update_by_index(grid, agent, cell_index)
            a, b = self.quadratic_formula(modified_grid, agent)
            grid_lower = modified_grid[agent] / self.size
            grid_upper = (modified_grid[agent] + 1) / self.size
            candidates = [grid_lower, grid_upper]
            if a < 0:
                # Vertex of a*x^2 + b*x; it is a maximum because a < 0.
                mid = -b / a / 2
                if grid_lower < mid < grid_upper:
                    candidates.append(mid)
            for candidate in candidates:
                if candidate > point[agent]:
                    new_val = a * candidate * candidate + b * candidate
                    if new_val > current_max:
                        current_max = new_val
                        current_max_where = candidate
        a, b = self.quadratic_formula(grid, agent)
        value_at_point = a * point[agent] * point[agent] + b * point[agent]
        if value_at_point > current_max:
            return value_at_point, point[agent], (a, b)
        return current_max, current_max_where, None

    def marginal_rev(self, point, agent):
        """Return the negated derivative of the quadratic at ``point[agent]``.

        The result is 0 when ``marginal_rev_integration`` reports that the
        maximum is attained above ``point[agent]`` (its ``ab`` is None).
        """
        _, _, ab = self.marginal_rev_integration(point, agent)
        if ab is None:
            return 0
        a, b = ab
        # a x^2 + bx partial wrt x is 2ax + b, there is also a minus sign in front after partial
        return -(2 * a * point[agent] + b)

    def marginal_rev_per_density(self, point, agent):
        """Return ``marginal_rev`` divided by the density of ``point``'s cell.

        Raises:
            ZeroDivisionError: if the cell containing ``point`` has density 0.
        """
        return self.marginal_rev(point, agent) / self.retrieve_by_point(point)
