import numpy as np
from scipy.stats import sem
from settings import n, sample_size
from utilities import tuple_update_by_index
from itertools import product
import distribution
import random
from pulp import (
    LpVariable,
    LpMaximize,
    LpProblem,
    lpSum,
    GUROBI_CMD,
    value,
)
import myerson


def AMD(d, H, require_ex_post_sp=True, visualize=False):
    """Build and solve the grid-discretised mechanism-design program, then audit it.

    One allocation variable ``a{agent}-{profile}`` and one payment variable
    ``p{agent}-{profile}`` are created for every agent and every bid profile
    on the ``H``-point grid ``{0, ..., H-1}^n``.  The program maximises the
    sum of payments weighted by ``d.AMD_grid_integration(profile, H)``
    (presumably the probability mass of the grid cell -- confirm against
    ``distribution``), subject to allocation feasibility, monotonicity in the
    agent's own bid, and the payment identity.

    After solving, every profile with a winner is compared against
    ``d.marginal_rev_per_density``: the profile counts as "wrong" when some
    other agent has a strictly larger marginal revenue than the winner.

    Args:
        d: distribution object; must provide ``size``,
            ``AMD_grid_integration(profile, H)`` and
            ``marginal_rev_per_density(point, agent)``.
        H: number of grid points per agent.
        require_ex_post_sp: when true the allocation variables are binary,
            otherwise they are continuous in ``[0, 1]``.
        visualize: currently unused by the executed code; kept so existing
            callers keep working.

    Returns:
        ``(objective_value, wrong)`` where ``wrong`` is the number of audited
        profiles that disagree with the marginal-revenue comparison.
    """
    if H % d.size != 0:
        print(
            f"warning: AMD grid size {H} is not a multiple of distribution size {d.size}"
        )

    # Every loop below walks the same grid in the same order; build it once.
    profiles = list(product(range(H), repeat=n))

    prob = LpProblem("AMD", LpMaximize)
    variables = {}
    for agent in range(n):
        for profile in profiles:
            if require_ex_post_sp:
                variables[f"a{agent}-{profile}"] = LpVariable(
                    f"a{agent}-{profile}", cat="Binary"
                )
            else:
                variables[f"a{agent}-{profile}"] = LpVariable(
                    f"a{agent}-{profile}", lowBound=0, upBound=1
                )
            variables[f"p{agent}-{profile}"] = LpVariable(
                f"p{agent}-{profile}", lowBound=0
            )

    for profile in profiles:
        # allocation feasibility: at most one unit is handed out per profile
        prob += lpSum([variables[f"a{agent}-{profile}"] for agent in range(n)]) <= 1

    for agent in range(n):
        for profile in profiles:
            if profile[agent] == 0:
                continue
            # monotonicity: raising one's own bid never lowers the allocation
            prob += (
                variables[f"a{agent}-{profile}"]
                >= variables[
                    f"a{agent}-{tuple_update_by_index(profile, agent, profile[agent]-1)}"
                ]
            )

    for agent in range(n):
        for profile in profiles:
            if profile[agent] == 0:
                # bid 0 then pays 0
                prob += variables[f"p{agent}-{profile}"] == 0
            else:
                # bid other values then pays the marginal gain: each allocation
                # increment between consecutive own bids t-1 and t is charged t / H
                prob += variables[f"p{agent}-{profile}"] == lpSum(
                    [
                        t
                        / H
                        * (
                            variables[
                                f"a{agent}-{tuple_update_by_index(profile, agent, t)}"
                            ]
                            - variables[
                                f"a{agent}-{tuple_update_by_index(profile, agent, t-1)}"
                            ]
                        )
                        for t in range(1, profile[agent] + 1)
                    ]
                )

    # objective: weighted sum of all payments
    prob += lpSum(
        [
            d.AMD_grid_integration(profile, H) * variables[f"p{agent}-{profile}"]
            for agent in range(n)
            for profile in profiles
        ]
    )

    prob.solve(
        GUROBI_CMD(
            msg=False,
        )
    )

    # Solver values are floats: a variable that is "1" may come back as
    # 0.9999999 or 1.0000001, so an exact ``== 1`` test can miss winners.
    one_tol = 1e-5

    def is_allocated(agent, profile):
        """Return True when the solved allocation variable is (numerically) 1."""
        solved = value(variables[f"a{agent}-{profile}"])
        # ``value`` yields None when the variable carries no solution value.
        return solved is not None and abs(solved - 1) <= one_tol

    total = len(profiles)
    wrong = 0
    for profile in profiles:
        # lowest-indexed agent that receives the item, if any
        winner = next(
            (agent for agent in range(n) if is_allocated(agent, profile)), None
        )
        if winner is None:
            continue
        point = tuple(coordinate / H for coordinate in profile)
        revenues = [d.marginal_rev_per_density(point, agent) for agent in range(n)]
        # wrong when somebody else has a strictly larger marginal revenue
        if any(revenue > revenues[winner] for revenue in revenues):
            wrong += 1
    print(f"wrong {wrong} over {total}")

    return value(prob.objective), wrong

    # visualize 2d
    # if visualize:
    #     assert require_ex_post_sp
    #     assert n == 2
    #     for profile in product(*[range(H) for _ in range(n)]):
    #         x, y = profile
    #         x /= H
    #         y /= H
    #         if value(variables[f"a0-{profile}"]) == 1:
    #             plt.scatter(x, y, c="blue")
    #         elif value(variables[f"a1-{profile}"]) == 1:
    #             plt.scatter(x, y, c="red")
    #         else:
    #             plt.scatter(x, y, c="black")
    #     plt.savefig(f"amd-{d.name}.pdf")


def second_price_revenue(d, reserve=0, sample_size=sample_size):
    """Estimate the revenue of a second-price auction with a reserve by sampling.

    Args:
        d: distribution object providing ``rejection_sampling(sample_size)``,
            which yields bid profiles.
        reserve: reserve price; a profile whose highest bid is below it
            contributes zero revenue.
        sample_size: number of profiles requested from ``d``; defaults to the
            module-level ``sample_size`` as it was when this function was defined.

    Returns:
        ``(mean revenue, standard error of the mean)`` over the sampled profiles.
    """

    def revenue_of(bids):
        ranked = sorted(bids)
        # The winner pays the larger of the runner-up bid and the reserve.
        return max(ranked[-2], reserve) if ranked[-1] >= reserve else 0

    revenues = [revenue_of(bids) for bids in d.rejection_sampling(sample_size)]
    return np.average(revenues), sem(revenues)


def search_second_price_revenue(d):
    """Grid-search the reserve price, then re-estimate revenue at the best one.

    Reserves ``0/1000, 1/1000, ..., 999/1000`` are each scored with a tenth of
    the module-level ``sample_size``.  Candidates are compared as
    ``(mean, standard error, reserve)`` tuples, and a candidate is adopted only
    if it strictly exceeds both the current best and the initial ``(0, 0, 0)``.

    Returns:
        ``second_price_revenue(d, best_reserve, sample_size)``, i.e. a fresh
        ``(mean, standard error)`` estimate rather than the search-time value
        that selected the reserve.
    """
    coarse_size = sample_size // 10
    # Seeding with (0, 0, 0) keeps reserve 0 as the fallback; ``max`` returns
    # the first maximal entry, matching a strict ``>`` update rule.
    scored = [(0, 0, 0)]
    for step in range(1000):
        candidate_reserve = step / 1000
        scored.append(
            (
                *second_price_revenue(d, candidate_reserve, coarse_size),
                candidate_reserve,
            )
        )
    best_reserve = max(scored)[-1]
    # The search-time estimate is not returned directly; revenue is
    # re-estimated at the chosen reserve with the full sample size.
    return second_price_revenue(d, best_reserve, sample_size)


def greedy(d):
    """Average the best per-agent marginal revenue over sampled profiles.

    For each profile drawn by ``d.rejection_sampling(sample_size)`` the largest
    ``d.marginal_rev_per_density(profile, agent)`` across all ``n`` agents is
    recorded.

    Returns:
        ``(mean, standard error of the mean)`` of the recorded values.
    """
    best_per_profile = [
        max(d.marginal_rev_per_density(bids, bidder) for bidder in range(n))
        for bids in d.rejection_sampling(sample_size)
    ]
    return np.average(best_per_profile), sem(best_per_profile)


def myerson_ignore_correlation(d):
    """Estimate revenue of a virtual-value auction built from per-agent marginals.

    Each agent's marginal is obtained by summing ``d.retrieve_by_grid(grid)``
    over all grid cells that share the agent's own grid index; any dependence
    between agents is therefore discarded.  The marginals are passed through
    ``myerson.normalize`` and ``myerson.convex_hull_H`` to obtain the data
    needed by ``myerson.ironed_virtual_valuations``.

    For every sampled profile the agent with the highest ironed virtual value
    wins, provided that value is positive.  The payment is located by bisection
    on the winner's bid (to within 0.001): the recorded price is the lower end
    of the final bracket around the point where the winner's ironed virtual
    value stops exceeding ``max(second-highest virtual value, 0)``.
    NOTE(review): the bisection presumably relies on the ironed virtual value
    being non-decreasing in the bid -- confirm against ``myerson``.

    Returns:
        ``(mean revenue, standard error of the mean)`` over the sampled profiles.
    """
    # One independent accumulator per agent.  ``[[0, ...]] * n`` would repeat a
    # single list object n times, so every agent would accumulate into (and
    # later read back) the sum of all agents' marginals.
    cds = [[0] * d.size for _ in range(n)]
    for grid in product(range(d.size), repeat=n):
        mass = d.retrieve_by_grid(grid)
        for agent in range(n):
            cds[agent][grid[agent]] += mass

    hull_points = []
    for agent in range(n):
        cds[agent] = myerson.normalize(cds[agent])
        # normalize is applied a second time here, exactly as before; it is
        # kept because its idempotence cannot be verified from this module.
        hull_points.append(myerson.convex_hull_H(myerson.normalize(cds[agent])))

    res = []
    for profile in d.rejection_sampling(sample_size):
        vals = [
            myerson.ironed_virtual_valuations(
                hull_points[agent], cds[agent], profile[agent]
            )
            for agent in range(n)
        ]
        winner = vals.index(max(vals))
        if vals[winner] <= 0:
            # nobody has a positive virtual value: the item stays unsold
            res.append(0)
            continue
        tobeat = max(sorted(vals)[-2], 0)
        # bisection on the winner's bid for the threshold against ``tobeat``
        low = 0
        up = profile[winner]
        while up - low > 0.001:
            mid = (low + up) / 2
            if (
                myerson.ironed_virtual_valuations(hull_points[winner], cds[winner], mid)
                > tobeat
            ):
                up = mid
            else:
                low = mid
        res.append(low)
    return np.average(res), sem(res)


def mutate(d, mu=0.1):
    """Return a new grid distribution that blends ``d`` with random noise.

    A fresh ``distribution.GridDistribution`` of the same size is generated
    with a random seed, and every cell of the result is
    ``d.values[i][j] * (1 - mu) + noise[i][j] * mu``.

    Args:
        d: two-agent grid distribution exposing ``size`` and ``values``
            (indexed as ``values[i][j]``).
        mu: mixing weight of the random template; ``0`` reproduces ``d``'s
            values, ``1`` replaces them with the template's.

    Returns:
        ``distribution.GridDistribution(name="mutate", values=...)`` built from
        the blended values.
    """
    assert n == 2
    size = d.size
    mutate_template = distribution.GridDistribution(
        size=size, seed=random.random(), binary=False
    )
    # ``values`` is the template's own storage, reused as the output buffer.
    # Each cell is overwritten with the blend; accumulating with ``+=`` would
    # count the template twice and give d * (1 - mu) + template * (1 + mu).
    values = mutate_template.values
    for i in range(size):
        for j in range(size):
            values[i][j] = d.values[i][j] * (1 - mu) + values[i][j] * mu
    return distribution.GridDistribution(name="mutate", values=values)


def find_worst_distribution(d):
    """Hill-climb over mutated distributions to maximise AMD's ``wrong`` count.

    For 100 rounds a mutation of the current distribution is solved with
    ``AMD(candidate, 20)``.  The candidate replaces the current distribution
    only when its ``wrong`` count strictly exceeds the best seen so far (which
    starts at 0).  Progress is printed each round, and the accepted
    distribution's values are printed on every improvement.

    Returns:
        The last accepted distribution, or the input ``d`` if no candidate was
        ever accepted.
    """
    worst_so_far = 0
    for evolution in range(100):
        candidate = mutate(d)
        wrong = AMD(candidate, 20)[1]
        print(f"evolution {evolution} wrong {wrong} max_wrong {worst_so_far}")
        if wrong <= worst_so_far:
            continue
        # strictly worse audit result: adopt the candidate as the new base
        d, worst_so_far = candidate, wrong
        print(evolution, worst_so_far)
        print(d.values)
    return d
