import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from copy import deepcopy
from datetime import datetime
from pulp import (
    LpVariable,
    LpMaximize,
    LpProblem,
    lpSum,
    GUROBI_CMD,
    value,
    LpStatus,
    LpMinimize,
)
import matplotlib.pyplot as plt
import numpy as np
import argparse
import regretNet

# Device that checkpoint tensors are mapped onto when a checkpoint is loaded.
device = 'cpu'

# Command-line interface; both options are mandatory when run as a script.
parser = argparse.ArgumentParser(description='Run allocNet_mip_verifier with checkpoint path')
parser.add_argument('--checkpoint_path', type=str, required=True, help='Path to the checkpoint file')
parser.add_argument('--num_bidders', type=int, required=True, help='Number of bidders')
# Parse the real command line only when executed as a script. Parsing at
# import time would make a plain `import` of this module terminate with an
# argparse usage error, because the importing process's argv does not carry
# the required options. When imported, `args` is None.
args = parser.parse_args() if __name__ == '__main__' else None

def allocNet_mip_verifier(num_bidders, alloc_hidden_sizes, pay_hidden_sizes, checkpoint_path=None):
    """Run the monotonicity MIP for every (oldwinner, newwinner) pair.

    A ``regretNet.RegretNet`` is built from the given sizes and restored from
    the ``'model_state_dict'`` entry of the checkpoint.  One MIP is solved per
    ordered pair through ``allocNet_mip_verifier_segment``.  A pair is
    reported as a violation when its MIP status is not ``"Optimal"`` or its
    objective value exceeds 0.0001.  All results are printed; nothing is
    returned.

    Args:
        num_bidders: Forwarded to the ``RegretNet`` constructor.
        alloc_hidden_sizes: Forwarded to the ``RegretNet`` constructor.
        pay_hidden_sizes: Forwarded to the ``RegretNet`` constructor.
        checkpoint_path: Path handed to ``load_checkpoint``.
    """
    restored = load_checkpoint(checkpoint_path)
    model = regretNet.RegretNet(num_bidders, alloc_hidden_sizes, pay_hidden_sizes)

    model.load_state_dict(restored['model_state_dict'])
    model.eval()

    all_pairs_ok = True
    for oldwinner in range(model.num_bidders):
        # The new winner ranges over n + 1 indices (not n).
        for newwinner in range(model.num_bidders + 1):
            if newwinner != oldwinner:
                print("oldwinner, newwinner =", oldwinner, newwinner)
                status, obj_value = allocNet_mip_verifier_segment(model, oldwinner, newwinner)
                violated = status != "Optimal" or obj_value > 0.0001
                if violated:
                    print("NOT MONOTONE when oldwinner, newwinner =", oldwinner, newwinner)
                    all_pairs_ok = False

    if all_pairs_ok:
        print("Verified Monotone!", checkpoint_path)


def allocNet_mip_verifier_segment(model, oldwinner, newwinner):
    """Solve the monotonicity MIP for a single (oldwinner, newwinner) pair.

    The allocation network is encoded twice in one program: once on the bids
    ``b0 .. b{n-1}`` (each bounded to [0, 1]) and once on the same bids with
    ``oldwinner``'s bid replaced by ``movedup_bid`` (also in [0, 1] and at
    least as large as the original bid).  The program requires

    * before the raise, ``oldwinner``'s allocation output to be at least as
      large as each of the other outputs ``0 .. n``;
    * after the raise, ``newwinner``'s allocation output to be at least as
      large as the output of every other bidder ``0 .. n-1``;

    and maximizes ``movedup_bid - b{oldwinner}``.

    Args:
        model: Network exposing ``num_bidders``, ``alloc_net_shapes`` and
            ``alloc_network``.
        oldwinner: Output index in ``[0, n - 1]``.
        newwinner: Output index in ``[0, n]``; ``None`` is treated as ``n``
            (presumably the "nobody wins" output -- confirm against the
            network definition).

    Returns:
        Tuple ``(status, objective)``: the PuLP status string and the value of
        the objective after solving.

    Raises:
        AssertionError: If ``oldwinner`` or ``newwinner`` is out of range.
    """
    model.eval()

    prob = LpProblem(f"Max_Error_(MIP_Verifier)_Story_1_oldwinner_{oldwinner}_newwinner_{newwinner}", LpMaximize)
    variables = {}

    n = model.num_bidders
    assert 0 <= oldwinner <= n - 1
    if newwinner is None:
        newwinner = n
    assert 0 <= newwinner <= n

    bid_vars = []
    for i in range(n):
        var_name = f"b{i}"
        var = LpVariable(var_name, lowBound=0, upBound=1)
        variables[var_name] = var
        bid_vars.append(var)

    # Only the allocation network is encoded; the payment networks are not
    # part of this program.
    alloc_vars = construct_relu_mlp_outputs(
        model, bid_vars, prob, variables, "allocNetwork", extra_inputs=True,
        shapes=model.alloc_net_shapes, network_layers=model.alloc_network
    )

    movedup_bid = LpVariable("movedup_bid", lowBound=0, upBound=1)
    variables["movedup_bid"] = movedup_bid
    original_bid = bid_vars[oldwinner]

    # oldwinner has the top allocation output on the original bids.
    # The range is n + 1 (not n) so that output index n is included.
    for bidder in range(n + 1):
        if bidder != oldwinner:
            prob += alloc_vars[oldwinner] >= alloc_vars[bidder]

    # The bid may only move up.
    prob += movedup_bid >= original_bid

    # Encode the network a second time on the raised bid profile.  This is
    # done exactly once: building it inside the loop above would add one
    # redundant copy of the whole network (binaries included) per bidder.
    # A copy of the bid list is used so the original profile stays intact.
    bid_vars_prime = list(bid_vars)
    bid_vars_prime[oldwinner] = movedup_bid
    alloc_vars_prime = construct_relu_mlp_outputs(
        model, bid_vars_prime, prob, variables, "aftermoveup", extra_inputs=True,
        shapes=model.alloc_net_shapes, network_layers=model.alloc_network
    )

    # Objective: the size of the bid increase.
    prob += movedup_bid - original_bid

    # newwinner has an output at least as large as every bidder 0 .. n-1
    # after the raise.
    for bidder in range(n):
        if bidder != newwinner:
            prob += alloc_vars_prime[newwinner] >= alloc_vars_prime[bidder]

    prob.solve(GUROBI_CMD(msg=False))

    status = LpStatus[prob.status]
    objective = value(prob.objective)
    print(status)
    print(objective)
    return status, objective



def construct_relu_mlp_outputs(mlp, inputs, prob, variables, unique_label, extra_inputs=False, shapes=None, network_layers=None):
    """Add a big-M encoding of a ReLU MLP to ``prob`` and return its outputs.

    For every layer ``i`` and node ``j`` the following entries are stored in
    ``variables`` (keys are tuples starting with ``unique_label``):

    * ``(label, i, j)``            -- value before ReLU (all layers but the first);
    * ``(label, i, j, "nonneg")``  -- value after ReLU (all layers but the last);
    * ``(label, i, j, "binary")``  -- ReLU switch (hidden layers only).

    Args:
        mlp: Source of ``shapes`` / ``linear_relu_stack`` when
            ``extra_inputs`` is ``False``; otherwise unused here.
        inputs: One expression/variable per input node.
        prob: Problem that receives the constraints.
        variables: Dict that receives the created variables.
        unique_label: Prefix that keeps keys and variable names distinct
            between several encodings in the same problem.
        extra_inputs: When true, ``shapes`` and ``network_layers`` are taken
            from the arguments instead of from ``mlp``.
        shapes: Layer widths, input layer first.
        network_layers: Indexable layer container; ``weight`` and ``bias``
            are read from every second entry starting at index 0.

    Returns:
        List of the before-ReLU variables of the last layer.
    """
    M = 1000  # big constant for MIP -- its value seems to have implication on accuracy

    if extra_inputs == False:
        shapes, network_layers = mlp.shapes, mlp.linear_relu_stack

    assert len(inputs) == shapes[0]

    # Pull the affine parameters out as numpy arrays.
    affine_positions = range(0, len(network_layers), 2)
    weights = [network_layers[pos].weight.data.cpu().numpy() for pos in affine_positions]
    biases = [network_layers[pos].bias.data.cpu().numpy() for pos in affine_positions]

    # Parameter shapes must agree with the declared layer widths.
    for idx, w_matrix in enumerate(weights):
        assert w_matrix.shape[0] == shapes[idx + 1]
        assert w_matrix.shape[1] == shapes[idx]
    for idx, b_vector in enumerate(biases):
        assert b_vector.shape[0] == shapes[idx + 1]

    num_layers = len(shapes)
    last_layer = num_layers - 1

    # Create the variables node by node.
    for layer, width in enumerate(shapes):
        is_first = layer == 0
        is_last = layer == last_layer
        for node in range(width):
            base_name = f"{unique_label},{layer},{node}"

            # Every layer except the last has a nonnegative (after-ReLU) value.
            if not is_last:
                variables[(unique_label, layer, node, "nonneg")] = LpVariable(
                    f"{base_name},nonneg"
                )

            # The first layer is tied directly to the inputs.
            if is_first:
                prob += inputs[node] == variables[(unique_label, layer, node, "nonneg")]

            # Hidden layers need a binary switch for the ReLU.
            if not is_first and not is_last:
                variables[(unique_label, layer, node, "binary")] = LpVariable(
                    f"{base_name},binary", cat="Binary"
                )

            # Past the first layer, the before-ReLU value may be negative.
            if not is_first:
                variables[(unique_label, layer, node)] = LpVariable(base_name)

    # ReLU on the hidden layers: switch = 1 forces the output to 0,
    # switch = 0 forces the output to equal the before-ReLU value.
    for layer in range(1, last_layer):
        for node in range(shapes[layer]):
            before = variables[(unique_label, layer, node)]
            after = variables[(unique_label, layer, node, "nonneg")]
            switch = variables[(unique_label, layer, node, "binary")]
            prob += after >= before
            prob += after >= 0
            prob += after <= before + M * switch
            prob += after <= M * (1 - switch)

    # Affine maps between consecutive layers; weights with magnitude at most
    # 0.000001 are left out of the sum.
    for layer in range(1, num_layers):
        prev = layer - 1
        for node in range(shapes[layer]):
            row = weights[prev][node]
            terms = [
                variables[(unique_label, prev, k, "nonneg")] * row[k]
                for k in range(shapes[prev])
                if abs(row[k]) > 0.000001
            ]
            prob += (
                variables[(unique_label, layer, node)]
                == lpSum(terms) + biases[prev][node]
            )

    outputs = []
    for node in range(shapes[-1]):
        outputs.append(variables[(unique_label, last_layer, node)])
    return outputs


def load_checkpoint(checkpoint_path):
    """Load the checkpoint at ``checkpoint_path`` onto the module-level ``device``.

    The path is printed before loading and the loaded object is returned as is.

    NOTE(review): ``torch.load`` unpickles the file by default, so only
    checkpoints from a trusted source should be passed here.
    """
    print("Loading model from:", checkpoint_path)
    target_device = torch.device(device)
    return torch.load(checkpoint_path, map_location=target_device)


if __name__ == '__main__':
    # Script entry point: verify the checkpoint named on the command line.
    # Both the allocation and the payment network are built with hidden
    # sizes [50, 50].
    allocNet_mip_verifier(
        args.num_bidders,
        [50, 50],
        [50, 50],
        checkpoint_path=args.checkpoint_path,
    )

# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed1.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed2.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed3.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed4.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed5.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed6.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed7.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed8.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed9.pt')
# allocNet_mip_verifier(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], checkpoint_path = 'checkpoint_numbidders_5/seed0.pt')