from settings import (
    n,
    device,
    supervised_steps,
    unsupervised_steps,
    fingerprint,
    target,
    sample_size,
    shape_n,
)
from utilities import get_timestamp, construct_relu_mlp_outputs, tuple_update_by_index
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import softmax, mse_loss
from torch.utils.data import TensorDataset, DataLoader
from pulp import (
    LpVariable,
    LpMaximize,
    LpProblem,
    GUROBI_CMD,
    value,
    LpStatus,
)
import numpy as np
from scipy.stats import sem
import myerson


# Weight multiplying each monotonicity-violation penalty term added to the
# loss in train().
ratio = 1000
# Extra outer training rounds run on top of supervised_steps + unsupervised_steps.
additional_steps = 10
# Hidden-layer widths passed to A4 (two hidden layers, each shape_n wide).
network_shapes = [shape_n, shape_n]
# Mini-batch size of the training DataLoader; also used when normalising the
# loss value reported by train().
batch_size = 16

# Training DataLoader; stays None until init_dataloaders() assigns it.
my_training_dataloader = None

# Number of profiles requested from d.rejection_sampling() for the training set.
training_sample_size = 100000

# Cache for the "ivv" target: maps tuple(normalised conditional distribution)
# to the hull points returned by myerson.convex_hull_H.
convex_hull_dict = {}


def init_dataloaders(d):
    """Build the training DataLoader and store it in `my_training_dataloader`.

    Draws `training_sample_size` profiles from `d.rejection_sampling` and, for
    every profile, computes one value per agent (selected by the `target`
    setting) plus a trailing 0 that stands for "no allocation".  Each dataset
    item is a triple ``(profile, target_values, greedy_allocation)`` where the
    greedy allocation puts 1.0 on every entry equal to the row maximum (ties
    therefore yield several 1.0 entries) and 0 elsewhere.

    Args:
        d: distribution object; must provide `rejection_sampling` and the
            methods used for the configured `target`
            (`marginal_rev_per_density`, or `point_to_grid` and
            `conditional_distribution`).

    Raises:
        ValueError: if `target` is not one of "mv", "vv" or "ivv".
    """
    # Fail fast, before the expensive sampling, instead of letting value_func
    # silently return None and torch.tensor blow up with an opaque error.
    if target not in ("mv", "vv", "ivv"):
        raise ValueError(
            f"unsupported target {target!r}; expected 'mv', 'vv' or 'ivv'"
        )

    def value_func(profile, agent):
        """Return the training signal of `agent` at `profile` for `target`."""
        if target == "mv":
            return d.marginal_rev_per_density(profile, agent)

        grid = d.point_to_grid(profile)
        cd = myerson.normalize(d.conditional_distribution(grid, agent))
        if target == "vv":
            return myerson.virtual_valuation(cd, profile[agent])

        # target == "ivv" (validated above): reuse hull points per distribution.
        cache_key = tuple(cd)
        if cache_key not in convex_hull_dict:
            convex_hull_dict[cache_key] = myerson.convex_hull_H(cd)
        hull_points = convex_hull_dict[cache_key]
        return myerson.ironed_virtual_valuations(hull_points, cd, profile[agent])

    global my_training_dataloader

    profiles_raw = d.rejection_sampling(training_sample_size)
    profiles = torch.tensor(profiles_raw, device=device)
    value_func_res = [
        [value_func(profile, agent) for agent in range(n)] + [0]
        for profile in profiles_raw
    ]
    target_values = torch.tensor(value_func_res, device=device)

    # Compute each row maximum once rather than once per element.
    greedy_rows = []
    for line in value_func_res:
        best = max(line)
        greedy_rows.append([1.0 if v == best else 0 for v in line])
    greedy_allocations = torch.tensor(greedy_rows, device=device)

    my_training_dataset = TensorDataset(profiles, target_values, greedy_allocations)
    my_training_dataloader = DataLoader(
        my_training_dataset, batch_size=batch_size, shuffle=True
    )


class A4(nn.Module):
    """ReLU multilayer perceptron mapping an n-bid profile to n + 1 scores.

    The extra output is the "no allocation" option.  Every Linear layer except
    the final one is followed by a ReLU.
    """

    def __init__(self, shapes):
        """Build the network; `shapes` lists the hidden-layer widths."""
        super().__init__()
        widths = [n] + shapes + [n + 1]
        final_layer = len(widths) - 2
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # No activation after the output layer.
            if position != final_layer:
                layers.append(nn.ReLU())
        self.linear_relu_stack = nn.Sequential(*layers)
        # Full list of layer widths, input and output included.
        self.shapes = widths

    def forward(self, x):
        """Return the raw (pre-softmax) output scores for `x`."""
        return self.linear_relu_stack(x)


def train(d):
    """Train an A4 network on distribution `d`, checkpointing every round.

    Runs ``supervised_steps + unsupervised_steps + additional_steps`` outer
    rounds of 1000 Adam steps each.  During the first `supervised_steps`
    rounds the softmax allocation is regressed (MSE) onto the greedy
    allocation; afterwards the loss is the negated sum of
    allocation * target value over the batch.  On top of that, each of the 16
    most recent counterexamples found by the MIP verifier adds a penalty of
    ``ratio * relu(allocation_before - allocation_after)`` for the old winner,
    where "after" is the profile with the old winner's bid raised.

    After every round the whole model is saved under ``saved_story4/``.  Every
    20th round, and on the last round, the MIP verifier is run, its
    counterexamples are added to the penalty set, and the model is evaluated
    (with ``sp_fix=True`` iff the verifier found new counterexamples).

    Args:
        d: distribution object forwarded to `init_dataloaders` and
            `story4_evaluate`.
    """
    import os  # local import: only needed to create the checkpoint directory

    steps_per_round = 1000
    total_rounds = supervised_steps + unsupervised_steps + additional_steps
    a4 = A4(network_shapes)
    a4.to(device)
    optimizer = optim.Adam(a4.parameters(), lr=0.001)
    fix_profiles = []

    init_dataloaders(d)
    # Keep one iterator alive across steps; building a fresh one per step
    # reshuffles the whole dataset just to draw a single batch.
    batch_iter = iter(my_training_dataloader)

    # Make sure a missing directory does not discard a finished round.
    os.makedirs("saved_story4", exist_ok=True)

    for outer_index in range(total_rounds):
        a4.train()

        overall_loss = 0
        for _ in range(steps_per_round):
            try:
                batch = next(batch_iter)
            except StopIteration:
                # Dataset exhausted: start a new (reshuffled) pass.
                batch_iter = iter(my_training_dataloader)
                batch = next(batch_iter)
            profiles, marginal_revs, greedy_allocations = batch
            allocations = softmax(a4(profiles), dim=1)

            if outer_index < supervised_steps:
                loss = mse_loss(allocations, greedy_allocations)
            else:
                loss = -torch.tensordot(allocations, marginal_revs)

            # Penalise the most recent verifier counterexamples: raising the
            # old winner's bid must not lower that agent's allocation.
            for p1, movedup_bid, old_winner, _new_winner in fix_profiles[-16:]:
                # Single profiles are 1-D, so the softmax axis is 0.
                allocation1 = softmax(a4(torch.tensor(p1, device=device)), dim=0)
                p2 = tuple_update_by_index(p1, old_winner, movedup_bid)
                allocation2 = softmax(a4(torch.tensor(p2, device=device)), dim=0)
                loss += ratio * torch.relu(
                    allocation1[old_winner] - allocation2[old_winner]
                )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # detach() keeps the running total out of the autograd graph.
            overall_loss += loss.detach()
        loss_res = overall_loss / steps_per_round / batch_size
        print(outer_index, loss_res)

        fn = f"saved_story4/{get_timestamp()}-{fingerprint}-loss={loss_res}.saved"
        print(f"saving to {fn}")
        torch.save(a4, fn)

        a4.eval()
        is_last_round = outer_index == total_rounds - 1
        if (outer_index + 1) % 20 == 0 or is_last_round:
            new_fix_profiles = story4_mip_verifier(a4)
            fix_profiles.extend(new_fix_profiles)
            story4_evaluate(d, a4, sp_fix=bool(new_fix_profiles))


def get_winner(model, profile):
    """Return the index of the largest of the first n + 1 model outputs.

    Indices 0..n-1 are agents; index n means "no allocation".  On ties the
    lowest index wins.

    Args:
        model: callable taking a 1-D tensor of bids and returning a 1-D tensor
            of scores.
        profile: sequence of n bids.
    """
    # Inference only: skip autograd bookkeeping for this forward pass.
    with torch.no_grad():
        scores = model(torch.tensor(profile, device=device))
    # One transfer to Python floats instead of n + 1 separate .item() calls.
    scores = scores[: n + 1].tolist()
    return scores.index(max(scores))


def binary_search_cutoff_price(d, model, profile, winner, sp_fix=False):
    """Return the price offered to `winner` at `profile`.

    Bisects the winner's own bid on [0, profile[winner]] (to a width of 0.001)
    for the point where the model stops picking `winner`, then asks
    `d.marginal_rev_integration` for the winning bid at the lower end of that
    bracket.  With `sp_fix` set, the result is raised to `story4_sp_offer`
    whenever that offer is strictly larger.
    """
    floor, ceiling = 0, profile[winner]
    while ceiling - floor > 0.001:
        mid = (floor + ceiling) / 2
        probe = tuple_update_by_index(profile, winner, mid)
        if get_winner(model, probe) == winner:
            # Still winning at `mid`: the cutoff is at or below it.
            ceiling = mid
        else:
            floor = mid

    cutoff_profile = tuple_update_by_index(profile, winner, floor)
    # revenue fix, pushing up offer when it is beneficial to do so
    _, correct_winning_bid, _ = d.marginal_rev_integration(cutoff_profile, winner)
    if not sp_fix:
        return correct_winning_bid

    sp_offer = story4_sp_offer(profile, winner, model)
    return sp_offer if sp_offer > correct_winning_bid else correct_winning_bid


def story4_evaluate(d, a4, sp_fix=False):
    """Print the average revenue (and its standard error) of model `a4`.

    Samples profiles from `d`; each profile contributes the winner's price
    when the winner's bid strictly exceeds it, and 0 otherwise (including the
    "no allocation" outcome).  With `sp_fix` only a tenth of `sample_size` is
    used.
    """
    adjusted_sample_size = sample_size // 10 if sp_fix else sample_size

    def revenue_for(profile):
        # Revenue collected from a single sampled profile.
        winner = get_winner(a4, profile)
        if winner == n:
            return 0
        winning_bid = binary_search_cutoff_price(d, a4, profile, winner, sp_fix)
        return winning_bid if profile[winner] > winning_bid else 0

    revenues = [
        revenue_for(profile) for profile in d.rejection_sampling(adjusted_sample_size)
    ]
    print(
        f"RESULT story4 evaluate sp_fix={sp_fix}",
        np.average(revenues),
        sem(revenues),
    )


def story4_mip_verifier(model):
    """Run the MIP check for every (old winner, new winner) pair.

    The old winner ranges over the n agents; the new winner additionally
    includes index n ("no allocation").  Prints the largest objective found
    and returns one ``[profile, movedup_bid, old_winner, new_winner]`` entry
    for each pair whose objective exceeds 0.000001.
    """
    worst = 0
    violations = []
    # n means no allocation
    pairs = [
        (old_winner, new_winner)
        for old_winner in range(n)
        for new_winner in range(n + 1)
        if old_winner != new_winner
    ]
    for old_winner, new_winner in pairs:
        res, p1, movedup_bid = story4_mip_verifier_helper(
            old_winner, new_winner, model
        )
        if res is None:
            continue
        worst = max(res, worst)
        if res > 0.000001:
            violations.append([p1, movedup_bid, old_winner, new_winner])
    print("RESULT story4 mip verifier max", worst)
    return violations


def story4_mip_verifier_helper(oldwinner, newwinner, model):
    """Search for a profile where raising `oldwinner`'s bid hands the win to `newwinner`.

    Builds a maximisation problem over bids b0..b{n-1} in [0, 1] and a raised
    bid `movedup_bid` >= b{oldwinner}, requiring that `oldwinner` has the
    largest model output at the original profile and `newwinner` has the
    largest output once b{oldwinner} is replaced by `movedup_bid`.  The
    objective is the size of the raise, ``movedup_bid - b{oldwinner}``.

    Args:
        oldwinner: agent index in 0..n-1.
        newwinner: index in 0..n, different from `oldwinner`; n is the
            "no allocation" output.
        model: network handed to `construct_relu_mlp_outputs`.

    Returns:
        Tuple ``(objective value, [b0..b{n-1}] values, movedup_bid value)`` as
        reported by `value(...)`; the objective may be None (callers check
        for it).
    """
    print(f"verifying {oldwinner} => {newwinner}")
    assert oldwinner != newwinner
    assert 0 <= oldwinner <= n - 1
    a4 = model

    prob = LpProblem("story4_mip_verifier", LpMaximize)
    variables = {}

    # One bid variable per agent, each bounded to [0, 1].
    for agent in range(n):
        variables[f"b{agent}"] = LpVariable(f"b{agent}", lowBound=0, upBound=1)
    movedup_bid = LpVariable("movedup_bid", lowBound=0, upBound=1)
    variables["movedup_bid"] = movedup_bid
    # The old winner's bid may only move up.
    prob += movedup_bid >= variables[f"b{oldwinner}"]
    # Objective (a bare expression, not a constraint): maximise the raise.
    prob += movedup_bid - variables[f"b{oldwinner}"]
    ins = [variables[f"b{agent}"] for agent in range(n)]
    # NOTE(review): construct_relu_mlp_outputs is defined in utilities;
    # presumably it encodes the network's forward pass on `ins` into `prob`
    # (registering helper variables in `variables` under the given tag) and
    # returns the n + 1 output expressions - confirm against utilities.
    outs = construct_relu_mlp_outputs(
        a4,
        ins,
        prob,
        variables,
        "beforemoveup",
    )
    # Before the raise, oldwinner's output is at least every other output.
    for other in range(n + 1):
        if other == oldwinner:
            continue
        prob += outs[oldwinner] >= outs[other]
    # Reuse the same input list with only the old winner's bid swapped for the
    # raised bid; all other bids are shared between the two network copies.
    ins[oldwinner] = movedup_bid
    outs_after_moveup = construct_relu_mlp_outputs(
        a4,
        ins,
        prob,
        variables,
        "aftermoveup",
    )
    # After the raise, newwinner's output is at least every other output.
    for other in range(n + 1):
        if other == newwinner:
            continue
        prob += outs_after_moveup[newwinner] >= outs_after_moveup[other]
    prob.solve(
        GUROBI_CMD(
            msg=False,
        )
    )
    print(LpStatus[prob.status])
    print(value(prob.objective))
    # The 0.01 threshold only gates this debug dump; story4_mip_verifier
    # applies its own, smaller threshold to the returned objective.
    if value(prob.objective) is not None and value(prob.objective) > 0.01:
        print([value(variables[f"b{agent}"]) for agent in range(n)])
        print([value(o) for o in outs])
        print(value(movedup_bid), oldwinner, newwinner)
        print([value(o) for o in outs_after_moveup])
    return (
        value(prob.objective),
        [value(variables[f"b{agent}"]) for agent in range(n)],
        value(movedup_bid),
    )


def story4_sp_offer(profile, winner, model):
    """Return the largest value `story4_sp_offer_helper` reports over all rivals.

    Rivals are every index in 0..n other than `winner` (n being the
    "no allocation" output).  Rivals for which the helper returns None are
    ignored; the result is never below 0.
    """
    best_offer = 0
    rivals = (index for index in range(n + 1) if index != winner)
    for rival in rivals:
        offer = story4_sp_offer_helper(profile, winner, rival, model)
        if offer is None:
            continue
        best_offer = max(offer, best_offer)
    return best_offer


def story4_sp_offer_helper(profile, winner, opponent, model):
    """Maximise `winner`'s bid subject to `opponent`'s output not being lower.

    All other agents' bids are pinned to their values in `profile`; the
    winner's bid ranges over [0, 1].  Returns the objective value reported by
    the solver via `value(...)`, which callers treat as possibly None.
    """
    assert winner != opponent
    assert 0 <= winner <= n - 1
    assert 0 <= opponent <= n

    prob = LpProblem("story4_sp_offer", LpMaximize)
    variables = {
        f"b{agent}": LpVariable(f"b{agent}", lowBound=0, upBound=1)
        for agent in range(n)
    }
    bids = [variables[f"b{agent}"] for agent in range(n)]

    # Everyone but the winner keeps the bid from the given profile.
    for agent, bid in enumerate(bids):
        if agent != winner:
            prob += bid == profile[agent]
    # Objective: the winner's own bid.
    prob += bids[winner]

    outs = construct_relu_mlp_outputs(model, bids, prob, variables, "spoffer")
    prob += outs[opponent] >= outs[winner]
    prob.solve(GUROBI_CMD(msg=False))
    return value(prob.objective)
