import numpy as np
from scipy.stats import sem
from settings import n, m, sample_size

# from utilities import tuple_update_by_index
# from itertools import product
# import distribution
# import random
# from pulp import (
#     LpVariable,
#     LpMaximize,
#     LpProblem,
#     lpSum,
#     GUROBI_CMD,
#     value,
# )
# import myerson


def mplus1_price_revenue(d, reserve=0, sample_size=sample_size):
    """Estimate the revenue of an (m+1)-th price auction with a reserve price.

    For every sampled bid profile, each of the ``m`` highest bids that is at
    least ``reserve`` wins and pays ``max((m+1)-th highest bid, reserve)``.
    If a profile has at most ``m`` bids there is no (m+1)-th bid, so the
    winners pay ``reserve`` (previously this case raised IndexError).

    Args:
        d: object exposing ``rejection_sampling(k)``, which returns an
            iterable of bid profiles (sequences of numbers).
        reserve: reserve price applied to every unit.
        sample_size: number of profiles requested from ``d``.

    Returns:
        ``(mean, standard error of the mean)`` of the per-profile revenue.
    """
    res = []
    for profile in d.rejection_sampling(sample_size):
        bids = sorted(profile, reverse=True)
        # Uniform price: highest losing bid (the (m+1)-th), floored at reserve.
        price = max(bids[m], reserve) if len(bids) > m else reserve
        # Each of the top-m bids that clears the reserve buys one unit.
        winners = sum(1 for bid in bids[:m] if bid >= reserve)
        res.append(winners * price)
    # print(f"reserve={reserve}, last={res[-10:]}, avg={np.average(res)}")
    return np.average(res), sem(res)


def search_mplus1_price_revenue(d):
    """Grid-search the reserve price, then re-estimate revenue at the best one.

    Reserves ``0, 0.001, ..., 0.999`` are each scored with
    ``mplus1_price_revenue`` on ``sample_size // 10`` profiles. The chosen
    reserve belongs to the lexicographically largest ``(mean, sem, reserve)``
    triple, starting from the seed ``(0, 0, 0)``. Ties on the mean are broken
    by the larger standard error, then the larger reserve. An earlier entry
    survives an exactly equal later one.

    Returns:
        The ``(mean, sem)`` pair from a new ``mplus1_price_revenue`` call at
        the chosen reserve with the full ``sample_size``. The search score is
        not returned directly; presumably this is because the maximum of many
        noisy estimates is optimistically biased.
    """
    probe_size = sample_size // 10

    def scored_reserves():
        # Seed: reserve 0 is kept unless some candidate strictly beats it.
        yield (0, 0, 0)
        for step in range(1000):
            candidate = step / 1000
            mean, err = mplus1_price_revenue(d, candidate, probe_size)
            yield (mean, err, candidate)

    # max() replaces its running best only on a strictly greater item,
    # exactly like an explicit "if new > best: best = new" scan.
    _, _, best_reserve = max(scored_reserves())
    return mplus1_price_revenue(d, best_reserve, sample_size)


def greedy(d):
    """Estimate revenue from summing the ``m`` largest per-agent scores.

    For each profile drawn by ``d.rejection_sampling(sample_size)``, the value
    ``d.marginal_rev_per_density(profile, agent)`` is computed for every agent
    ``0..n-1`` and the ``m`` largest values are summed. This presumably models
    greedily awarding the ``m`` units by marginal revenue; confirm against
    the distribution class.

    Returns:
        ``(mean, standard error of the mean)`` of the per-profile sums.
    """
    res = []
    for profile in d.rejection_sampling(sample_size):
        scores = sorted(
            d.marginal_rev_per_density(profile, agent) for agent in range(n)
        )
        # Take the m largest, summed in ascending order. Do not use
        # ``scores[-m:]``: when m == 0 it becomes ``scores[0:]`` (every agent).
        res.append(sum(scores[max(len(scores) - m, 0):]))
    return np.average(res), sem(res)


# def myerson_ignore_correlation(d):
#     cds = [[0 for _ in range(d.size)]] * n
#     for grid in product(*[range(d.size) for _ in range(n)]):
#         for agent in range(n):
#             cds[agent][grid[agent]] += d.retrieve_by_grid(grid)
#     hull_points = []
#     for agent in range(n):
#         cds[agent] = myerson.normalize(cds[agent])
#         # print(cds[agent])
#         hull_points.append(myerson.convex_hull_H(myerson.normalize(cds[agent])))

#     res = []
#     for profile in d.rejection_sampling(sample_size):
#         vals = [
#             myerson.ironed_virtual_valuations(
#                 hull_points[agent], cds[agent], profile[agent]
#             )
#             for agent in range(n)
#         ]
#         winner = vals.index(max(vals))
#         if vals[winner] <= 0:
#             res.append(0)
#             continue
#         tobeat = max(sorted(vals)[-2], 0)
#         low = 0
#         up = profile[winner]
#         while up - low > 0.001:
#             mid = (low + up) / 2
#             if (
#                 myerson.ironed_virtual_valuations(hull_points[winner], cds[winner], mid)
#                 > tobeat
#             ):
#                 up = mid
#             else:
#                 low = mid
#         # print(f"winner {winner} bids {profile[winner]} in {profile}, pays {low}")
#         res.append(low)
#     return np.average(res), sem(res)
