from settings import n
from utilities import tuple_update_by_index, cap_up_and_down
import random
from math import floor

# from copy import deepcopy


def uniform_sampling(count, dim=None):
    """Draw ``count`` points uniformly at random from the unit cube [0, 1)^dim.

    Args:
        count: Number of points to draw.
        dim: Coordinates per point; defaults to ``n`` from ``settings``.

    Returns:
        A list of ``count`` tuples of ``dim`` floats. Draws come from the
        global ``random`` module, so ``random.seed`` makes them reproducible.
    """
    if dim is None:
        dim = n
    return [tuple(random.random() for _ in range(dim)) for _ in range(count)]


class GridDistribution:
    """Piecewise-constant probability density on the unit cube [0, 1]^n.

    Each of the ``n`` axes (``n`` comes from ``settings``) is split into
    ``size`` equal cells, and ``values`` is an ``n``-level nested list with one
    nonnegative weight per cell. The constructor rescales the weights so the
    density integrates to 1. Throughout, a *grid* is an ``n``-tuple of cell
    indices and a *point* is an ``n``-tuple of coordinates in [0, 1].
    """

    def __init__(self, name=None, values=None, size=None, seed=None, binary=False):
        """Build a distribution from explicit weights or generate random ones.

        Args:
            name: Required when ``values`` is given. For random distributions
                a descriptive name is generated and this argument is ignored.
            values: Nested list of nonnegative weights with ``len(values)``
                entries at every level. It is not modified: a normalized copy
                is stored in ``self.values``, and the object itself is kept as
                ``self.save`` (raw weights, used only for LaTeX rendering).
            size: Cells per axis when generating random weights.
            seed: Passed to ``random.seed`` before generating random weights;
                note that this reseeds the global ``random`` module.
            binary: When generating, draw each weight from {0.01, 1} instead of
                uniformly from [0, 1).

        Raises:
            AssertionError: ``name`` missing for explicit values, a level of
                the wrong length, a negative weight, or failed normalization.
        """
        if values is not None:
            assert name is not None
            self.name = name
            # Raw weights for LaTeX rendering; normalization below works on a
            # copy, so this object keeps the caller's original numbers.
            self.save = values
            raw_values = values
        else:
            if binary:
                # Emitted once per distribution rather than once per generated row.
                print(
                    "warning: setting a min of 0.01 to achieve a continous distribution"
                )

            def gen_values(level_left):
                # One nested list level per remaining dimension.
                if level_left > 1:
                    return [gen_values(level_left - 1) for _ in range(size)]
                if binary:
                    # A 0 draw is lifted to 0.01 so no cell has zero density.
                    return [max(random.randint(0, 1), 0.01) for _ in range(size)]
                return [random.random() for _ in range(size)]

            random.seed(seed)
            raw_values = gen_values(n)
            self.name = f"random-distribution-size{size}-seed{seed}-binary{binary}"

        self.size = len(raw_values)

        def verify_size_for_every_level(x):
            if isinstance(x, list):
                assert len(x) == self.size, f"{x} incorrect size {self.size}"
                for c in x:
                    verify_size_for_every_level(c)
            else:
                assert x >= 0, f"{x} needs to be nonnegative"

        verify_size_for_every_level(raw_values)

        # Every cell has volume 1 / size**n.
        cell_count = self.size**n

        def integrate(x):
            # Integral of the piecewise-constant density: sum of weight * cell volume.
            if isinstance(x, list):
                return sum(integrate(c) for c in x)
            return x / cell_count

        self.normalization_ratio = integrate(raw_values)

        def normalized_copy(x):
            # Builds new lists so the caller's ``values`` is never mutated.
            # An all-zero table becomes the uniform density (every weight 1).
            if isinstance(x, list):
                return [normalized_copy(c) for c in x]
            if self.normalization_ratio == 0:
                return 1
            return x / self.normalization_ratio

        self.values = normalized_copy(raw_values)
        assert abs(integrate(self.values) - 1) < 0.000001

    def retrieve_by_grid(self, grid, start=None):
        """Return the entry of the nested value table addressed by ``grid``.

        A full ``n``-tuple is resolved from ``self.values`` (yielding the
        cell's normalized density); a shorter tuple is resolved relative to
        the nested sub-list ``start``. An empty tuple returns ``start``.
        """
        node = self.values if len(grid) == n else start
        for index in grid:
            node = node[index]
        return node

    def AMD_grid_integration(self, AMD_grid, H):
        """Integrate the density over one cell of an ``H``-cells-per-axis grid.

        The density is read at the cell's lower corner ``AMD_grid / H`` and
        multiplied by the cell volume ``1 / H**n``. This is exact whenever
        each ``H``-grid cell lies inside a single distribution cell, i.e. when
        ``H % self.size == 0``; otherwise it is only an estimate.
        """
        lower_corner = [x / H for x in AMD_grid]
        return self.retrieve_by_point(lower_corner) / (H**n)

    def point_to_grid(self, point):
        """Map a point in [0, 1]^n to the index tuple of the cell containing it.

        Each coordinate is floored to ``floor(x * size)`` and passed through
        ``cap_up_and_down`` with bounds ``[0, size - 1]``, presumably clamping
        edge cases such as a coordinate of exactly 1.0 into a valid index.
        """
        return tuple(
            cap_up_and_down(floor(x * self.size), lowBound=0, upBound=self.size - 1)
            for x in point
        )

    def retrieve_by_point(self, point):
        """Return the normalized density of the cell containing ``point``."""
        return self.retrieve_by_grid(self.point_to_grid(point))

    def rejection_sampling(self, count):
        """Draw ``count`` points from this distribution by rejection sampling.

        Proposals are uniform points in [0, 1)^n; a proposal is accepted when
        a uniform draw in [0, cap) does not exceed its cell's density, where
        ``cap`` is the largest cell density. Uses the global ``random`` module.
        """

        def max_leaf(x):
            if isinstance(x, list):
                return max(max_leaf(c) for c in x)
            return x

        cap = max_leaf(self.values)
        samples = []
        while len(samples) < count:
            point = tuple(random.random() for _ in range(n))
            density = self.retrieve_by_grid(self.point_to_grid(point))
            threshold = random.random() * cap
            if density >= threshold:
                samples.append(point)
        return samples

    def conditional_distribution(self, grid, agent):
        """Return the densities of all cells along axis ``agent`` through ``grid``.

        Entry ``i`` is the density of ``grid`` with its ``agent`` index set
        to ``i``; the other indices stay fixed.
        """
        return [
            self.retrieve_by_grid(tuple_update_by_index(grid, agent, x))
            for x in range(self.size)
        ]

    def quadratic_formula_by_point(self, point, agent):
        """Return ``quadratic_formula`` for the cell containing ``point``."""
        return self.quadratic_formula(self.point_to_grid(point), agent)

    def quadratic_formula(self, grid, agent):
        """Return ``(a, b)`` with ``a*v**2 + b*v`` the max expression in this cell.

        Along axis ``agent`` (other indices fixed to ``grid``) the expression is
        ``v * (integration_of_higher_grids + (current_grid_upper - v) * current_grid_density)``:
        ``v`` times the density mass along that axis above ``v``. It is
        valid for ``v`` inside the ``agent`` cell of ``grid``.
        """
        column = self.conditional_distribution(grid, agent)
        # Mass of all cells strictly above this one along the agent's axis.
        integration_of_higher_grids = sum(
            column[x] / self.size for x in range(grid[agent] + 1, self.size)
        )
        current_grid_upper = (grid[agent] + 1) / self.size
        current_grid_density = self.retrieve_by_grid(grid)
        a = -current_grid_density
        b = integration_of_higher_grids + current_grid_upper * current_grid_density
        return a, b

    def marginal_rev_integration(self, point, agent):
        """Maximize ``a*v**2 + b*v`` over ``v`` at or above ``point[agent]``.

        This is the quantity before the negative partial derivative is taken.
        It scans the agent's cells from the point's cell upward. In each cell it
        evaluates that cell's quadratic (see ``quadratic_formula``) at the cell
        edges and, when ``a < 0``, at the vertex if it lies strictly inside the
        cell. Only candidates strictly above ``point[agent]`` count. The value
        at the point itself is compared separately afterwards.

        Returns:
            ``(value, point[agent], (a, b))`` when the quadratic of the point's
            own cell, evaluated at the point, is strictly larger than every
            candidate. Otherwise ``(best_value, best_v, None)``, where
            ``best_value``/``best_v`` default to 0 and 1 if no candidate is
            positive.
        """
        current_max = 0
        current_max_where = 1
        grid = self.point_to_grid(point)
        for cell in range(grid[agent], self.size):
            modified_grid = tuple_update_by_index(grid, agent, cell)
            a, b = self.quadratic_formula(modified_grid, agent)
            grid_lower = modified_grid[agent] / self.size
            grid_upper = (modified_grid[agent] + 1) / self.size
            candidates = [grid_lower, grid_upper]
            if a < 0:
                # Vertex of the concave parabola a*v^2 + b*v.
                mid = -b / a / 2
                if grid_lower < mid < grid_upper:
                    candidates.append(mid)
            for v in candidates:
                if v > point[agent]:
                    new_val = a * v * v + b * v
                    if new_val > current_max:
                        current_max = new_val
                        current_max_where = v
        a, b = self.quadratic_formula(grid, agent)
        y = a * point[agent] * point[agent] + b * point[agent]
        if y > current_max:
            return y, point[agent], (a, b)
        return current_max, current_max_where, None

    def marginal_rev(self, point, agent):
        """Return minus the derivative of ``a*v**2 + b*v`` at ``point[agent]``.

        Returns 0 when ``marginal_rev_integration`` does not report the point
        itself as the strict maximizer.
        """
        _, _, ab = self.marginal_rev_integration(point, agent)
        if ab is None:
            return 0
        a, b = ab
        # d/dx (a x^2 + b x) = 2ax + b, negated.
        return -(2 * a * point[agent] + b)

    def marginal_rev_per_density(self, point, agent):
        """Return ``marginal_rev`` divided by the density at ``point``.

        Raises ZeroDivisionError when the point's cell has zero density.
        """
        return self.marginal_rev(point, agent) / self.retrieve_by_point(point)

    def verify_greedy(self):
        """Empirically check that a unique marginal_rev winner keeps winning as it rises.

        Samples 10000 uniform profiles. For each profile with a unique maximum
        of ``marginal_rev`` across agents, the winner's coordinate is moved
        through 100 evenly spaced values from its current value towards 1.
        If the winner ever stops attaining the (possibly tied) maximum, the
        diagnostics are printed and False is returned; otherwise True.
        """

        def marginal_revs(profile):
            return [self.marginal_rev(profile, agent) for agent in range(n)]

        def still_win(profile, old_winner):
            mvs = marginal_revs(profile)
            return mvs[old_winner] == max(mvs)

        def get_unique_winner(profile):
            mvs = marginal_revs(profile)
            max_mv = max(mvs)
            if mvs.count(max_mv) == 1:
                return mvs.index(max_mv)
            return None

        for profile in uniform_sampling(10000):
            winner = get_unique_winner(profile)
            if winner is None:
                continue
            for r in range(100):
                new_bid = profile[winner] + (1 - profile[winner]) * r / 100
                modified_profile = tuple_update_by_index(profile, winner, new_bid)
                if not still_win(modified_profile, winner):
                    print(winner)
                    print(profile)
                    print(marginal_revs(profile))
                    print(modified_profile)
                    print(marginal_revs(modified_profile))
                    return False
        return True

    def get_marginal_value_winners(self, point):
        """Return ``(winners, max_mv)``: every agent whose marginal_rev at ``point`` equals the maximum."""
        mvs = [self.marginal_rev(point, agent) for agent in range(n)]
        max_mv = max(mvs)
        winners = [agent for agent, mv in enumerate(mvs) if mv == max_mv]
        return winners, max_mv


# Fixed reference distributions whose weight tables appear to be the
# traditional "planetary" magic squares, of order 3 (saturn) through 9 (luna):
# each row of a table sums to the same magic constant (15, 34, 65, 111, 175,
# 260, 369). Each table is 2-level nested, so these describe two-dimensional
# grids. NOTE(review): they are only usable when settings.n == 2, because
# GridDistribution.retrieve_by_grid indexes n levels deep.
saturn = GridDistribution("saturn", [[4, 9, 2], [3, 5, 7], [8, 1, 6]])

jupiter = GridDistribution(
    "jupiter", [[4, 14, 15, 1], [9, 7, 6, 12], [5, 11, 10, 8], [16, 2, 3, 13]]
)

mars = GridDistribution(
    "mars",
    [
        [11, 24, 7, 20, 3],
        [4, 12, 25, 8, 16],
        [17, 5, 13, 21, 9],
        [10, 18, 1, 14, 22],
        [23, 6, 19, 2, 15],
    ],
)

sol = GridDistribution(
    "sol",
    [
        [6, 32, 3, 34, 35, 1],
        [7, 11, 27, 28, 8, 30],
        [19, 14, 16, 15, 23, 24],
        [18, 20, 22, 21, 17, 13],
        [25, 29, 10, 9, 26, 12],
        [36, 5, 33, 4, 2, 31],
    ],
)

venus = GridDistribution(
    "venus",
    [
        [22, 47, 16, 41, 10, 35, 4],
        [5, 23, 48, 17, 42, 11, 29],
        [30, 6, 24, 49, 18, 36, 12],
        [13, 31, 7, 25, 43, 19, 37],
        [38, 14, 32, 1, 26, 44, 20],
        [21, 39, 8, 33, 2, 27, 45],
        [46, 15, 40, 9, 34, 3, 28],
    ],
)

mercury = GridDistribution(
    "mercury",
    [
        [8, 58, 59, 5, 4, 62, 63, 1],
        [49, 15, 14, 52, 53, 11, 10, 56],
        [41, 23, 22, 44, 45, 19, 18, 48],
        [32, 34, 35, 29, 28, 38, 39, 25],
        [40, 26, 27, 37, 36, 30, 31, 33],
        [17, 47, 46, 20, 21, 43, 42, 24],
        [9, 55, 54, 12, 13, 51, 50, 16],
        [64, 2, 3, 61, 60, 6, 7, 57],
    ],
)

luna = GridDistribution(
    "luna",
    [
        [37, 78, 29, 70, 21, 62, 13, 54, 5],
        [6, 38, 79, 30, 71, 22, 63, 14, 46],
        [47, 7, 39, 80, 31, 72, 23, 55, 15],
        [16, 48, 8, 40, 81, 32, 64, 24, 56],
        [57, 17, 49, 9, 41, 73, 33, 65, 25],
        [26, 58, 18, 50, 1, 42, 74, 34, 66],
        [67, 27, 59, 10, 51, 2, 43, 75, 35],
        [36, 68, 19, 60, 11, 52, 3, 44, 76],
        [77, 28, 69, 20, 61, 12, 53, 4, 45],
    ],
)

# Two hand-saved 5x5 two-dimensional weight tables. Their largest-to-smallest
# cell ratios are 100 and 10 respectively, matching the numeric suffixes.
# NOTE(review): how they were constructed and what they are a "worst" case for
# is not recorded here; presumably they are worst cases found by a search,
# so confirm before relying on this.
worst100 = GridDistribution(
    "worst100",
    [
        [1, 1.0, 1.0, 1.0, 1.0],
        [1.0, 100.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 48.5, 24.25, 72.75],
        [1.0, 14.25, 1.0, 1.0, 3.0],
        [1.0, 42.75, 1.0, 1.0, 1.0],
    ],
)

worst10 = GridDistribution(
    "worst10",
    [
        [1, 1.0, 1.0, 1.0, 1.0],
        [1.0, 10.0, 9.0, 1.0, 10.0],
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 2.5892857142845793, 1.0],
        [1.0, 1.0, 1.0, 7.7678571428537335, 1.0],
    ],
)

# Adversarial distributions obtained via evolutionary computation, saved for
# evaluation. Each entry is a 5x5 two-dimensional table of positive weights
# (GridDistribution normalizes them). EC is pre-sized to 10 slots and filled by
# index as EC[0] .. EC[9], named "ec0" .. "ec9".
EC = [None for _ in range(10)]

EC[0] = GridDistribution(
    "ec0",
    [
        [
            1.3967981564269358,
            1.1678827660825315,
            1.4276175512588902,
            1.0432138210131767,
            0.7508255291175423,
        ],
        [
            0.6944327431173817,
            1.0851228252508465,
            0.9693516771745261,
            1.0828926912877443,
            1.0341501890154445,
        ],
        [
            1.4026549909776473,
            0.6408647063463043,
            1.4241759971543715,
            1.2246214882841997,
            1.3275380876686826,
        ],
        [
            1.5822600137818754,
            0.7439098784010663,
            1.2975245321541804,
            0.15771882104783036,
            0.141572081262576,
        ],
        [
            1.2718215591076385,
            1.117268934262278,
            0.5568791249086729,
            0.5125865751293259,
            0.9463152597683312,
        ],
    ],
)

EC[1] = GridDistribution(
    "ec1",
    [
        [
            0.6840165976423755,
            1.1412860996246894,
            1.5095536639919533,
            0.7334173998361309,
            0.7397150781300039,
        ],
        [
            0.9673608012537767,
            1.3579639583494147,
            1.0810250272240387,
            0.9176110271952809,
            1.0446857664734228,
        ],
        [
            1.6034687993727603,
            0.4526829676783068,
            1.629655170621167,
            0.19162916936246088,
            1.2178374851419604,
        ],
        [
            0.9226718121117587,
            1.1333928422383366,
            0.8024513731888758,
            0.4768550787420708,
            0.9754740258078353,
        ],
        [
            1.3190308323772961,
            1.301252714030714,
            1.121335804836391,
            0.5883320807573571,
            1.0872944240116238,
        ],
    ],
)

EC[2] = GridDistribution(
    "ec2",
    [
        [
            0.35405521163734355,
            0.9512946659032573,
            1.4571531617485634,
            0.9488305377264566,
            0.5975634894134875,
        ],
        [
            1.3303470861291329,
            1.0023197652182305,
            1.376658647592819,
            1.0836460542783355,
            1.1835609382321859,
        ],
        [
            0.34520742686227734,
            1.1748700596945854,
            1.3614256475017708,
            0.9454104677515205,
            1.4077000277596667,
        ],
        [
            1.4456485222556528,
            0.7637456610649407,
            0.6158987334908685,
            0.23778966545002797,
            0.18296896800845447,
        ],
        [
            1.2629972730898222,
            1.067868157585239,
            1.6325134434025705,
            0.9427615682720647,
            1.3277648199307253,
        ],
    ],
)

EC[3] = GridDistribution(
    "ec3",
    [
        [
            0.79591048577358,
            0.8710840836079892,
            0.5248231659516458,
            1.3004986747663723,
            1.1940294572061942,
        ],
        [
            0.8990550213884985,
            1.6086251329321444,
            1.478420731879929,
            0.6204452501840629,
            0.569261251528724,
        ],
        [
            1.658411663193613,
            1.0579894835790158,
            1.3755721285567113,
            0.2887559324714196,
            1.496331599027354,
        ],
        [
            1.0534063995363576,
            1.4370403057359993,
            0.6015604895424471,
            0.42032445352123965,
            0.27987975340869253,
        ],
        [
            0.5632579975835212,
            1.2822776543916596,
            1.469913384812181,
            0.9756952509445659,
            1.1774302484760846,
        ],
    ],
)

EC[4] = GridDistribution(
    "ec4",
    [
        [
            0.4354050498183514,
            0.5983810792388418,
            1.2883489012203402,
            0.7941971259558351,
            1.2678196317143868,
        ],
        [
            1.3387198940658407,
            1.3899790338776772,
            0.29429848232134576,
            0.853148862767086,
            1.2922124682780933,
        ],
        [
            0.9732233180681471,
            1.1529484607470635,
            1.4176078702324042,
            0.6488030538862724,
            1.9395879197713517,
        ],
        [
            0.9735401059811337,
            1.4807803534070987,
            1.1573825579897625,
            0.39167229505528334,
            0.6303225580902491,
        ],
        [
            0.4774572221788472,
            1.1643701892339817,
            1.0685904906867596,
            0.8619329318177501,
            1.1092701435960963,
        ],
    ],
)

EC[5] = GridDistribution(
    "ec5",
    [
        [
            0.34021241031352045,
            1.4426356936424545,
            0.9596710599186852,
            0.8921261771054732,
            1.2369285873648146,
        ],
        [
            0.613698915265927,
            0.583232827058628,
            1.3962114826616534,
            0.5248622248153582,
            0.5051098533058331,
        ],
        [
            1.1748792408657631,
            0.8259797823083384,
            1.8246600353055547,
            0.7079864056441763,
            1.6121582047538259,
        ],
        [
            1.3192165904684907,
            1.4286390055917422,
            1.4521535924404918,
            0.34976866073222374,
            0.7638347841924702,
        ],
        [
            0.7669441396550428,
            0.8312997720406222,
            1.6979946957654906,
            0.456385142381482,
            1.2934107164019413,
        ],
    ],
)

EC[6] = GridDistribution(
    "ec6",
    [
        [
            0.93380511838244,
            0.764098787539554,
            1.480373072406561,
            0.8456377091687615,
            0.875934405930365,
        ],
        [
            0.7911622000708828,
            0.9854524053912372,
            1.0504804786001847,
            1.3302927841986827,
            0.9615520272943961,
        ],
        [
            0.5458570355980524,
            0.8158363326186051,
            1.1222972081678306,
            1.1732365607555642,
            1.4686102324954107,
        ],
        [
            1.2273429478306692,
            1.0149715554603302,
            0.41446243755652357,
            0.4808013573967449,
            0.25538075677247396,
        ],
        [
            0.98064110641479,
            1.6116416260111222,
            1.3797802211464125,
            0.9334452407660029,
            1.5569063920264035,
        ],
    ],
)

EC[7] = GridDistribution(
    "ec7",
    [
        [
            1.565050970151152,
            1.1827556789048863,
            1.319428312747822,
            0.7691909909752974,
            0.6930963681536647,
        ],
        [
            0.39233891628805356,
            0.9192135595798248,
            1.1509367998446052,
            0.7149315800252726,
            1.350402595958516,
        ],
        [
            1.0081872095450028,
            0.8476217713411734,
            1.1941005937837752,
            0.6786001725181682,
            1.4876314231214731,
        ],
        [
            0.9697023392829813,
            1.461974597525024,
            1.5289137499649068,
            0.6900127747814983,
            0.47897769909257204,
        ],
        [
            1.2222703230523402,
            1.7164684626443427,
            0.7985273839052088,
            0.28376133732565473,
            0.5759043894867818,
        ],
    ],
)

EC[8] = GridDistribution(
    "ec8",
    [
        [
            0.7097708389033067,
            1.01841884991948,
            0.8177649916339821,
            0.8801056817206476,
            0.7910111801480305,
        ],
        [
            1.020715310704338,
            0.7403891280723569,
            1.4363043037470409,
            1.3142417486280158,
            1.5640532071802324,
        ],
        [
            1.3537147004925791,
            0.35930093739422,
            1.1591796145559274,
            0.48321001312037515,
            1.1971354485455574,
        ],
        [
            0.8497289599874942,
            1.5672572859088807,
            0.22328067451959904,
            1.0918521206814287,
            1.3591408499177742,
        ],
        [
            1.0455111680214235,
            1.2653030909882974,
            1.1101997327442368,
            1.251612663384629,
            0.39079749908014383,
        ],
    ],
)

EC[9] = GridDistribution(
    "ec9",
    [
        [
            0.9324910124012167,
            0.8519891327693419,
            0.7384059469284192,
            1.0969815848900437,
            0.8188817019433081,
        ],
        [
            1.1854953658314558,
            0.9971621577635038,
            1.6649973687642987,
            1.508431339437923,
            1.4128546322278397,
        ],
        [
            0.8755737595657502,
            1.214897256330518,
            1.2127137295627137,
            0.1202593621158089,
            1.080953073912272,
        ],
        [
            0.723550264490681,
            1.188657225539934,
            0.30472853049521514,
            1.1659412034032681,
            0.5898765657035459,
        ],
        [
            0.9780523093532204,
            1.0921353503137046,
            1.4401880328083225,
            0.8236699864141144,
            0.9811131070335806,
        ],
    ],
)


def render(d):
    """Print distribution ``d``'s 2-D weight table as a LaTeX ``bmatrix`` in display math.

    Uses ``d.save`` (raw weights) when the distribution has a non-None one,
    and otherwise falls back to ``d.values``. Only the first two index levels
    are read, so this is intended for two-dimensional grids of side ``d.size``.
    Rows are printed on a single line separated by ``\\\\``.
    """
    table = getattr(d, "save", None)
    if table is None:
        # ``save`` is optional; without it, render the normalized values.
        table = d.values
    rows = [
        " & ".join(str(table[i][j]) for j in range(d.size)) for i in range(d.size)
    ]
    print(
        r"""\[
\begin{bmatrix}"""
    )
    print(r" \\ ".join(rows))
    # Close the matrix, then the display-math block opened above.
    print(
        r"""\end{bmatrix}
\]"""
    )
