from settings import (
    n,
    m,
    device,
    supervised_steps,
    unsupervised_steps,
    fingerprint,
    target,
    sample_size,
    shape_n,
)
from utilities import get_timestamp, construct_relu_mlp_outputs, tuple_update_by_index
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import softmax
from torch.utils.data import TensorDataset, DataLoader
from pulp import (
    LpVariable,
    LpMaximize,
    LpProblem,
    GUROBI_CMD,
    value,
    LpStatus,
)
import numpy as np
from scipy.stats import sem
from itertools import combinations

# Every admissible winner set: sorted tuples of at most m distinct agent
# indices drawn from range(n), ordered by size, then lexicographically.
all_winner_sets = [
    subset for size in range(m + 1) for subset in combinations(range(n), size)
]
print("all_winner_sets", all_winner_sets)


# Weight of the monotonicity penalty term, which is currently commented out in train().
ratio = 1000
# Extra outer training steps beyond supervised_steps + unsupervised_steps.
additional_steps = 0
# Hidden-layer widths passed to A4 (two hidden layers of width shape_n).
network_shapes = [shape_n] * 2
batch_size = 16

# Set by init_dataloaders(); the shuffled DataLoader that train() draws from.
my_training_dataloader = None

training_sample_size = 100_000


def init_dataloaders(d):
    """Sample training profiles from ``d`` and build ``my_training_dataloader``.

    Draws ``training_sample_size`` profiles with ``d.rejection_sampling``.
    Each profile is labelled with ``d.marginal_rev_per_density(profile, agent)``
    for every agent, followed by a trailing 0 for the network's extra
    (n+1-th) output. The resulting shuffled DataLoader (``batch_size`` rows
    per batch) is stored in the module-level ``my_training_dataloader``.

    Raises:
        ValueError: if ``target`` is not ``"mv"``, the only supported target.
    """
    global my_training_dataloader

    # Fail fast, before the expensive sampling. Previously an unsupported
    # target silently produced None labels, and the run only crashed later
    # inside torch.tensor with an opaque dtype error.
    if target != "mv":
        raise ValueError(f"unsupported training target: {target!r}")

    profiles_raw = d.rejection_sampling(training_sample_size)
    profiles = torch.tensor(profiles_raw, device=device)
    value_func_res = [
        [d.marginal_rev_per_density(profile, agent) for agent in range(n)] + [0]
        for profile in profiles_raw
    ]
    target_values = torch.tensor(
        value_func_res,
        device=device,
    )
    my_training_dataset = TensorDataset(profiles, target_values)
    my_training_dataloader = DataLoader(
        my_training_dataset, batch_size=batch_size, shuffle=True
    )


class A4(nn.Module):
    """Fully connected ReLU network mapping ``n`` agent bids to ``n + 1`` scores.

    ``shapes`` is the list of hidden-layer widths; it is concatenated with
    lists, so it must itself be a list. Linear layers are separated by ReLUs,
    and no activation follows the final Linear. The outputs are therefore raw
    scores: ``train`` applies a softmax to them, and ``get_winner_set``
    compares them directly, treating the last entry as the reserve.
    """

    def __init__(self, shapes):
        super().__init__()
        widths = [n] + shapes + [n + 1]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if depth > 0:
                # A ReLU goes only *between* consecutive Linear layers.
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.linear_relu_stack = nn.Sequential(*layers)
        # Full width list, including the input and output widths.
        self.shapes = widths

    def forward(self, x):
        """Return the raw (pre-softmax) output scores for input ``x``."""
        return self.linear_relu_stack(x)


def _cycle_batches(loader):
    """Yield batches from ``loader`` forever, starting a fresh pass when one ends.

    Keeping a single iterator avoids rebuilding the DataLoader iterator for
    every batch. With ``shuffle=True``, each rebuild regenerates a full
    permutation of the dataset.

    Raises:
        ValueError: if ``loader`` yields no batches at all (otherwise this
            would spin forever).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("training dataloader yielded no batches")


def train(d):
    """Train an ``A4`` network on profiles sampled from ``d``.

    Runs ``supervised_steps + unsupervised_steps + additional_steps`` outer
    steps, each consisting of ``epochs`` Adam updates on one mini-batch each.

    The loss is the negated sum of allocation x target value over the batch.
    Allocations are the softmax of the network outputs, scaled by ``m`` and
    capped at 1. Target values come from ``init_dataloaders``.

    After every outer step the whole model is saved under ``saved_story4/``.
    On every 5th step, and on the final step, the model is checked with
    ``story4_mip_verifier``. If no violating profile is reported, it is also
    evaluated with ``story4_evaluate``.
    """
    import os

    epochs = 1000  # Adam updates per outer step (one mini-batch each)
    total_steps = supervised_steps + unsupervised_steps + additional_steps

    a4 = A4(network_shapes)
    a4.to(device)
    # optimizer = optim.SGD(a4.parameters(), lr=0.001)
    optimizer = optim.Adam(a4.parameters(), lr=0.001)
    # fix_profiles = []

    init_dataloaders(d)
    batches = _cycle_batches(my_training_dataloader)
    # torch.save below fails if the output directory does not exist yet.
    os.makedirs("saved_story4", exist_ok=True)

    for outer_index in range(total_steps):
        a4.train()

        # Summed on-device (detached) to avoid a host sync per step.
        overall_loss = 0
        for _ in range(epochs):
            profiles, marginal_revs = next(batches)
            allocations = softmax(a4(profiles), dim=1)
            # Softmax rows sum to 1; scale them to sum to m, then cap each entry at 1.
            allocations = torch.clamp(torch.mul(allocations, m), max=1)
            # tensordot with the default dims=2 contracts both the batch and output axes.
            loss = -torch.tensordot(allocations, marginal_revs)

            # if len(fix_profiles):
            #     for p1, movedup_bid, old_winner, new_winner in fix_profiles[-16:]:
            #         allocation1 = softmax(a4(torch.tensor(p1, device=device)))
            #         p2 = tuple_update_by_index(p1, old_winner, movedup_bid)
            #         allocation2 = softmax(a4(torch.tensor(p2, device=device)))
            #         loss += ratio * torch.relu(
            #             allocation1[old_winner] - allocation2[old_winner]
            #         )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            overall_loss += loss.detach()
        # Plain float, not a tensor, so the log line and the filename show a
        # number rather than "tensor(..., device=...)".
        loss_res = float(overall_loss / epochs / batch_size)
        print(outer_index, loss_res)

        fn = f"saved_story4/{get_timestamp()}-{fingerprint}-loss={loss_res}.saved"
        print(f"saving to {fn}")
        torch.save(a4, fn)

        a4.eval()
        if (outer_index + 1) % 5 == 0 or outer_index == total_steps - 1:
            new_fix_profiles = story4_mip_verifier(a4)
            print(new_fix_profiles)
            if not new_fix_profiles:
                story4_evaluate(d, a4)
            # fix_profiles.extend(new_fix_profiles)


# def get_winner(model, profile):
#     allocations = model(torch.tensor(profile, device=device))
#     allocations = [allocations[i].item() for i in range(n + 1)]
#     return allocations.index(max(allocations))


def get_winner_set(model, profile):
    """Return the winner set that ``model`` selects at bid ``profile``.

    The first ``n + 1`` model outputs are read as one score per agent plus
    a final reserve score. The ``m`` highest-scoring agents (equal scores are
    ordered toward the larger agent index) win when their score is at least
    the reserve score.

    Returns:
        The winning agent indices as a sorted tuple, asserted to be an
        element of ``all_winner_sets``.
    """
    # Inference only: don't build an autograd graph (this is called many
    # times per profile by the binary search). Copy the scores to the host
    # in one transfer instead of one .item() sync per entry.
    with torch.no_grad():
        scores = model(torch.tensor(profile, device=device))[: n + 1].tolist()
    reserve = scores[-1]
    # Descending by (score, index); tuples are distinct because indices are.
    ranked = sorted(((scores[i], i) for i in range(n)), reverse=True)
    # Slicing (not indexing) stays in bounds when m > n.
    res = tuple(sorted(i for score, i in ranked[:m] if score >= reserve))
    assert res in all_winner_sets
    return res


def binary_search_cutoff_price(d, model, profile, winner, sp_fix=False):
    """Compute the price charged to ``winner`` at ``profile``.

    Bisects ``winner``'s bid on ``[0, profile[winner]]`` until the bracket
    is no wider than 0.001. The upper end is kept at bids where ``model``
    still selects ``winner`` and the lower end at bids where it does not;
    the search presumes selection is monotone in the bid, and bid 0 is
    assumed to lose.

    The profile with the winner's bid set to the final lower end is then
    passed to ``d.marginal_rev_integration``. The second of its three return
    values is returned as the price.

    ``sp_fix`` must be False; the sp-fix adjustment is not implemented.
    """
    assert not sp_fix
    lo, hi = 0, profile[winner]
    while hi - lo > 0.001:
        probe = (lo + hi) / 2
        probe_profile = tuple_update_by_index(profile, winner, probe)
        if winner not in get_winner_set(model, probe_profile):
            lo = probe
        else:
            hi = probe
    floor_profile = tuple_update_by_index(profile, winner, lo)
    # Revenue fix: marginal_rev_integration may push the offer up when that is beneficial.
    _, price, _ = d.marginal_rev_integration(floor_profile, winner)
    return price


def story4_evaluate(d, a4, sp_fix=False):
    """Estimate the revenue of mechanism ``a4`` on fresh samples from ``d``.

    Draws ``sample_size`` profiles with ``d.rejection_sampling``. For each
    profile, it sums, over the model's winners, the price from
    ``binary_search_cutoff_price``; a winner's price counts only when its
    bid strictly exceeds that price. Prints the mean revenue and its
    standard error of the mean. Returns None.

    ``sp_fix`` must be False.
    """
    assert not sp_fix
    revenues = []
    for profile in d.rejection_sampling(sample_size):
        total = 0
        for agent in get_winner_set(a4, profile):
            price = binary_search_cutoff_price(d, a4, profile, agent, sp_fix)
            if not profile[agent] > price:
                continue
            total += price
        revenues.append(total)
    print(
        f"RESULT story4 evaluate sp_fix={sp_fix}",
        np.average(revenues),
        sem(revenues),
    )


def story4_mip_verifier(model):
    """Search for bid-monotonicity violations of ``model`` via MIPs.

    Considers each agent, each winner set that contains it ("before"), and
    each winner set that does not ("after"). For every such triple it solves
    ``story4_mip_verifier_helper``, which maximizes how far the agent can
    raise its bid while dropping out of the winner set.

    Returns:
        As soon as an objective exceeds 1e-6, a one-element list
        ``[[bids, raised_bid, agent, after_set]]``; this short-cut saves
        solver time. Otherwise it prints the largest objective seen and
        returns ``[]``.
    """
    best = 0
    for agent in range(n):
        before_sets = [ws for ws in all_winner_sets if agent in ws]
        after_sets = [ws for ws in all_winner_sets if agent not in ws]
        for before in before_sets:
            for after in after_sets:
                objective, bids, raised_bid = story4_mip_verifier_helper(
                    agent, before, after, model
                )
                if objective is None:
                    continue
                best = max(objective, best)
                if objective > 0.000001:
                    return [[bids, raised_bid, agent, after]]
    print("RESULT story4 mip verifier max", best)
    return []


def _add_winner_set_constraints(prob, outs, winner_set):
    """Add constraints to ``prob`` forcing network outputs ``outs`` to select ``winner_set``.

    ``outs[n]`` is the reserve output. All inequalities are non-strict, so
    the MIP also admits tied outputs, which ``get_winner_set`` resolves by
    index order.
    """
    if len(winner_set) < m:
        # Fewer than m winners: every non-winning agent is at or below the reserve.
        for i in range(n):
            if i not in winner_set:
                prob += outs[i] <= outs[n]
    # Each winner is at least every non-winning agent and at least the reserve.
    for i in winner_set:
        for j in range(n + 1):
            if j not in winner_set:
                prob += outs[i] >= outs[j]


def story4_mip_verifier_helper(oldwinner, winner_set1, winner_set2, model):
    """Solve one MIP for a bid increase by ``oldwinner`` that changes the winner set.

    Bids ``b0..b{n-1}`` and ``movedup_bid`` are continuous in [0, 1], with
    ``movedup_bid >= b{oldwinner}``. The network is encoded twice with
    ``construct_relu_mlp_outputs``: once on the bids, constrained to select
    ``winner_set1``, and once with ``oldwinner``'s bid replaced by
    ``movedup_bid``, constrained to select ``winner_set2``. The objective
    maximizes ``movedup_bid - b{oldwinner}``. When ``winner_set2`` excludes
    ``oldwinner`` (as ``story4_mip_verifier`` ensures), a positive value is a
    monotonicity violation.

    Returns:
        ``(objective, bids, movedup_bid)`` as solver values. Callers check
        for None, presumably returned when the solver finds no solution.
    """
    print(f"verifying {oldwinner} {winner_set1} => {winner_set2}")

    prob = LpProblem("story4_mip_verifier", LpMaximize)
    variables = {}
    for agent in range(n):
        variables[f"b{agent}"] = LpVariable(f"b{agent}", lowBound=0, upBound=1)
    movedup_bid = LpVariable("movedup_bid", lowBound=0, upBound=1)
    variables["movedup_bid"] = movedup_bid
    old_bid = variables[f"b{oldwinner}"]
    prob += movedup_bid >= old_bid
    # Adding a bare expression (not a constraint) sets the objective.
    prob += movedup_bid - old_bid

    ins = [variables[f"b{agent}"] for agent in range(n)]
    outs = construct_relu_mlp_outputs(model, ins, prob, variables, "beforemoveup")
    _add_winner_set_constraints(prob, outs, winner_set1)

    # Use a fresh list rather than mutating the one already handed to the encoder.
    ins_after = list(ins)
    ins_after[oldwinner] = movedup_bid
    outs_after_moveup = construct_relu_mlp_outputs(
        model, ins_after, prob, variables, "aftermoveup"
    )
    _add_winner_set_constraints(prob, outs_after_moveup, winner_set2)

    prob.solve(GUROBI_CMD(msg=False))
    objective = value(prob.objective)
    print(LpStatus[prob.status])
    print(objective)
    bids = [value(variables[f"b{agent}"]) for agent in range(n)]
    if objective is not None and objective > 0.01:
        print(bids)
        print([value(o) for o in outs])
        print(value(movedup_bid), oldwinner, winner_set1, winner_set2)
        print([value(o) for o in outs_after_moveup])
    return (
        objective,
        bids,
        value(movedup_bid),
    )


# def story4_sp_offer(profile, winner, model):
#     max_res = 0
#     for opponent in range(n + 1):
#         if winner == opponent:
#             continue
#         res = story4_sp_offer_helper(profile, winner, opponent, model)
#         if res is not None:
#             max_res = max(res, max_res)
#     return max_res


# def story4_sp_offer_helper(profile, winner, opponent, model):
#     assert winner != opponent
#     assert 0 <= winner <= n - 1
#     assert 0 <= opponent <= n
#     a4 = model

#     prob = LpProblem("story4_sp_offer", LpMaximize)
#     variables = {}

#     for agent in range(n):
#         variables[f"b{agent}"] = LpVariable(f"b{agent}", lowBound=0, upBound=1)
#     for agent in range(n):
#         if agent != winner:
#             prob += variables[f"b{agent}"] == profile[agent]
#     prob += variables[f"b{winner}"]

#     ins = [variables[f"b{agent}"] for agent in range(n)]
#     outs = construct_relu_mlp_outputs(
#         a4,
#         ins,
#         prob,
#         variables,
#         "spoffer",
#     )
#     prob += outs[opponent] >= outs[winner]
#     prob.solve(
#         GUROBI_CMD(
#             msg=False,
#         )
#     )
#     return value(prob.objective)
