import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import os
from copy import deepcopy
from datetime import datetime
from pulp import (
    LpVariable,
    LpMaximize,
    LpProblem,
    lpSum,
    GUROBI_CMD,
    value,
    LpStatus,
    LpMinimize,
)
import matplotlib.pyplot as plt
import numpy as np
from settings import device
import regretNet
import distribution


def train_regretnet(num_epochs, num_bidders, alloc_hidden_sizes, pay_hidden_sizes, learning_rate, batch_size, checkpoint_path="checkpoint/", load_path=None, distri=None):
    """Train a ``regretNet.RegretNet`` model and periodically checkpoint it.

    Each epoch draws one fresh batch of bids, evaluates
    ``regretNet.loss_function_b_deviate`` and takes a single Adam step.
    Every 20 epochs the per-bidder multiplier vector ``lmd`` is scaled up by
    ``exp(0.2)`` where the regret exceeds 1e-4 and scaled down by the same
    factor where it is below 1e-4. Every 1000 epochs a checkpoint is written.

    Args:
        num_epochs: Number of optimisation steps (one sampled batch each).
        num_bidders: Number of bidders; also the length of ``lmd``.
        alloc_hidden_sizes: Forwarded to ``regretNet.RegretNet``.
        pay_hidden_sizes: Forwarded to ``regretNet.RegretNet``.
        learning_rate: Learning rate of the Adam optimiser.
        batch_size: Number of bid profiles sampled per epoch.
        checkpoint_path: Directory that receives the ``.pt`` checkpoints;
            created if missing.
        load_path: Optional checkpoint file to resume from. Model weights,
            optimiser state and ``lmd`` are restored from it.
        distri: Optional sampler exposing ``rejection_sampling(batch_size)``.
            When ``None``, ``distribution.uniform_sampling`` is used.

    Returns:
        list[float]: The loss value of every epoch, in order.
    """
    lmd_update_interval = 20
    lmd = 50 * torch.ones(num_bidders, device=device)
    rho = 0.05

    model = regretNet.RegretNet(num_bidders, alloc_hidden_sizes, pay_hidden_sizes)
    # Always move the model: `lmd` and every batch live on `device`, so a
    # CUDA-only move would leave the model behind on any other accelerator.
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # For continued training
    if load_path is not None:
        print("Loading model from path:", load_path)
        checkpoint = load_checkpoint(load_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.train()
        lmd = checkpoint['lmd']
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    os.makedirs(checkpoint_path, exist_ok=True)

    epoch_loss_list = []

    for epoch in range(num_epochs):
        if distri is None:
            batch = distribution.uniform_sampling(batch_size, num_bidders)
        else:
            batch = distri.rejection_sampling(batch_size)
        bids_vec = torch.tensor(batch, device=device)

        alloc_probs, payment_outputs = model(bids_vec)
        rgt_vec, loss = regretNet.loss_function_b_deviate(model, alloc_probs, payment_outputs, bids_vec, lmd, rho)

        optimizer.zero_grad()
        # The graph is backpropagated exactly once, so it need not be retained.
        loss.backward()
        optimizer.step()

        # `detach()` is not in-place: its result has to be kept to obtain a
        # graph-free snapshot for the multiplier update and the logging.
        rgt_vec_copy = rgt_vec.detach().clone()
        epoch_loss_list.append(loss.item())

        if (epoch + 1) % lmd_update_interval == 0:
            lmd *= torch.exp(0.2 * (rgt_vec_copy > 0.0001))
            lmd /= torch.exp(0.2 * (rgt_vec_copy < 0.0001))
            print("After update, lmd =", lmd)

        if (epoch + 1) % 200 == 0:
            # Only needed for this log line, so compute it here without autograd.
            with torch.no_grad():
                utility_org = alloc_probs[:, :-1] * bids_vec - payment_outputs * bids_vec * alloc_probs[:, :-1]
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}, regret_vec: {rgt_vec_copy}')
            print("utility =", utility_org)
            print("alloc_probs =", alloc_probs)
            print("payment_outputs =", payment_outputs)

        if (epoch + 1) % 1000 == 0:
            # Re-create the directory in case it was removed during training.
            os.makedirs(checkpoint_path, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            save_checkpoint({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'lmd': lmd,
                'rho': rho
            }, filename=os.path.join(checkpoint_path, f'worst10_numbidders_{num_bidders}_epoch_{epoch+1}_{timestamp}.pt'))

    return epoch_loss_list


def save_checkpoint(state, filename):
    """Serialize ``state`` to the file ``filename`` using ``torch.save``.

    Args:
        state: Object to persist (the training loop passes a dict).
        filename: Destination path of the checkpoint file.
    """
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint_path):
    """Read a checkpoint file and return the deserialized object.

    Stored tensors are remapped onto the project-wide ``device``.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints that
    come from a trusted source.

    Args:
        checkpoint_path: Path of the checkpoint file to read.

    Returns:
        Whatever object was stored in the file.
    """
    target_device = torch.device(device)
    return torch.load(checkpoint_path, map_location=target_device)


def eval_regretnet_bs_plot(num_bidders, alloc_hidden_sizes, pay_hidden_sizes, learning_rate, batch_size, checkpoint_path="checkpoint/", distri=None):
    """Evaluate a checkpoint by bisection and plot who wins each sample.

    For every sampled bid profile the bidder with the highest allocation
    probability is taken as the winner. Holding the other bids fixed, a
    bisection on [0, 1] (tolerance 1e-4) searches for the bid at which that
    bidder stops being the argmax; this threshold is recorded as the
    payment. Samples whose argmax is index ``num_bidders`` contribute a
    payment of 0. The mean and standard deviation of the payments are
    printed, and a scatter plot of the first two bid coordinates, coloured
    by winner index (0, 1 or 2), is saved under ``figures/`` and shown.

    NOTE(review): the bisection assumes bids lie in [0, 1] and that winning
    is monotone in the winner's own bid - confirm against the model.

    Args:
        num_bidders: Number of bidders the model was built for.
        alloc_hidden_sizes: Forwarded to ``regretNet.RegretNet``.
        pay_hidden_sizes: Forwarded to ``regretNet.RegretNet``.
        learning_rate: Unused; kept for signature compatibility.
        batch_size: Ignored; the evaluation always uses 10000 samples.
        checkpoint_path: Path of the checkpoint file to evaluate.
        distri: Optional sampler exposing ``rejection_sampling(batch_size)``.
            When ``None``, ``distribution.uniform_sampling`` is used.
    """
    checkpoint = load_checkpoint(checkpoint_path)
    model = regretNet.RegretNet(num_bidders, alloc_hidden_sizes, pay_hidden_sizes)

    model.load_state_dict(checkpoint['model_state_dict'])
    # The batch is created on `device`, so the model has to live there too.
    model.to(device)
    model.eval()

    # The caller-supplied value is deliberately overridden (existing behaviour).
    batch_size = 10000

    if distri is None:
        batch = distribution.uniform_sampling(batch_size, num_bidders)
    else:
        batch = distri.rejection_sampling(batch_size)

    batch = torch.tensor(batch, device=device)
    print("input =", batch)

    # Pure inference: no autograd graph is needed anywhere below.
    with torch.no_grad():
        alloc_probs, payment_outputs = model(batch)
    print("alloc_probs =", alloc_probs)
    print("payment_outputs =", payment_outputs)

    max_indices = torch.argmax(alloc_probs, dim=1)
    print("max_indices =", max_indices)

    # Host-side copies: matplotlib and numpy cannot consume GPU tensors.
    winners = max_indices.tolist()
    winners_np = max_indices.cpu().numpy()
    points = batch.cpu().numpy()

    plt.figure()

    # One scatter call per winner class (instead of one per point) so each
    # class carries its own legend label.
    for idx, color in enumerate(('red', 'yellow', 'green')):
        label = 'not allocate' if idx == num_bidders else f'bidder {idx}'
        mask = winners_np == idx
        plt.scatter(points[mask, 0], points[mask, 1], color=color, label=label)

    payment = []
    with torch.no_grad():
        for i, winner in enumerate(winners):
            if winner == num_bidders:
                # Argmax is the extra last index: nothing is charged.
                payment.append(0)
                continue
            bids = batch[i]
            min_bid = 0
            max_bid = 1
            curr_bid = 0.5
            # Invariant: curr_bid is the midpoint of [min_bid, max_bid].
            while curr_bid - min_bid > 0.0001:
                bids_cloned = bids.clone()
                bids_cloned[winner] = curr_bid
                alloc_probs_i, payment_outputs_i = model(bids_cloned.unsqueeze(0))
                if torch.argmax(alloc_probs_i).item() == winner:
                    # Still wins: the threshold is at or below curr_bid.
                    max_bid = curr_bid
                    curr_bid = min_bid / 2 + curr_bid / 2
                else:
                    # Loses: the threshold is above curr_bid.
                    min_bid = curr_bid
                    curr_bid = max_bid / 2 + curr_bid / 2
            payment.append(curr_bid)
            print(bids_cloned.cpu().numpy(), payment_outputs_i[0, winner].cpu().numpy(), bids[winner].cpu().numpy(), curr_bid)

    payment = np.array(payment)
    avg_tot_payment = np.average(payment)

    print("avg_payment =", avg_tot_payment, "std =", np.std(payment))

    plt.legend()
    # savefig does not create missing directories.
    os.makedirs("figures", exist_ok=True)
    # The [30:-3] slice drops a fixed-length path prefix and the ".pt" suffix.
    plt.savefig(f"figures/utility{avg_tot_payment}_{checkpoint_path[30:-3]}.png")
    plt.show()


def eval_regretnet_bs(num_bidders, alloc_hidden_sizes, pay_hidden_sizes, learning_rate, batch_size, checkpoint_path="checkpoint/", distri=None):
    """Evaluate a checkpoint by bisection and print payment statistics.

    For every sampled bid profile the bidder with the highest allocation
    probability is taken as the winner. Holding the other bids fixed, a
    bisection on [0, 1] (tolerance 1e-4) searches for the bid at which that
    bidder stops being the argmax; this threshold is recorded as the
    payment. Samples whose argmax is index ``num_bidders`` contribute a
    payment of 0. The mean and standard deviation of the payments are
    printed.

    NOTE(review): the bisection assumes bids lie in [0, 1] and that winning
    is monotone in the winner's own bid - confirm against the model.

    Args:
        num_bidders: Number of bidders the model was built for.
        alloc_hidden_sizes: Forwarded to ``regretNet.RegretNet``.
        pay_hidden_sizes: Forwarded to ``regretNet.RegretNet``.
        learning_rate: Unused; kept for signature compatibility.
        batch_size: Ignored; the evaluation always uses 100000 samples.
        checkpoint_path: Path of the checkpoint file to evaluate.
        distri: Optional sampler exposing ``rejection_sampling(batch_size)``.
            When ``None``, ``distribution.uniform_sampling`` is used.
    """
    checkpoint = load_checkpoint(checkpoint_path)
    model = regretNet.RegretNet(num_bidders, alloc_hidden_sizes, pay_hidden_sizes)

    model.load_state_dict(checkpoint['model_state_dict'])
    # The batch is created on `device`, so the model has to live there too.
    model.to(device)
    model.eval()

    # The caller-supplied value is deliberately overridden (existing behaviour).
    batch_size = 100000

    if distri is None:
        batch = distribution.uniform_sampling(batch_size, num_bidders)
    else:
        batch = distri.rejection_sampling(batch_size)

    batch = torch.tensor(batch, device=device)
    print("input =", batch)

    # Pure inference: no autograd graph is needed anywhere below.
    with torch.no_grad():
        alloc_probs, payment_outputs = model(batch)
    print("alloc_probs =", alloc_probs)
    print("payment_outputs =", payment_outputs)

    max_indices = torch.argmax(alloc_probs, dim=1)
    print("max_indices =", max_indices)

    # Plain Python ints avoid a device round-trip on every comparison.
    winners = max_indices.tolist()
    payment = []

    with torch.no_grad():
        for i, winner in enumerate(winners):
            if i % 1000 == 0:
                print(i)
            if winner == num_bidders:
                # Argmax is the extra last index: nothing is charged.
                payment.append(0)
                continue
            bids = batch[i]
            min_bid = 0
            max_bid = 1
            curr_bid = 0.5
            # Invariant: curr_bid is the midpoint of [min_bid, max_bid].
            while curr_bid - min_bid > 0.0001:
                bids_cloned = bids.clone()
                bids_cloned[winner] = curr_bid
                alloc_probs_i, payment_outputs_i = model(bids_cloned.unsqueeze(0))
                if torch.argmax(alloc_probs_i).item() == winner:
                    # Still wins: the threshold is at or below curr_bid.
                    max_bid = curr_bid
                    curr_bid = min_bid / 2 + curr_bid / 2
                else:
                    # Loses: the threshold is above curr_bid.
                    min_bid = curr_bid
                    curr_bid = max_bid / 2 + curr_bid / 2
            payment.append(curr_bid)

    payment = np.array(payment)
    avg_tot_payment = np.average(payment)
    print("avg_payment =", avg_tot_payment, "std =", np.std(payment))


# Entry point: guarded so that importing this module (e.g. to reuse the
# evaluation helpers) does not start a 20000-epoch training run.
if __name__ == "__main__":
    d = distribution.GridDistribution(size=5, seed=0, binary=False)   # seed = 0 ~ 9
    # d = distribution.worst100

    train_regretnet(num_epochs=20000, num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], learning_rate=0.001, batch_size=16, checkpoint_path="checkpoint/", load_path=None, distri=d)
    # eval_regretnet_bs(num_bidders=5, alloc_hidden_sizes=[50,50], pay_hidden_sizes=[50,50], learning_rate=0.001, batch_size=256, distri=d, checkpoint_path = '')  # Add checkpoint path here
