import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import os
from copy import deepcopy
from datetime import datetime
from scipy import optimize

from settings import device
import numpy as np

class RegretNet(nn.Module):
    """Auction network with an allocation head and a payment head.

    Both heads are fully connected ReLU networks that read the same bid
    vector. The allocation head emits ``num_bidders + 1`` softmax
    probabilities per sample (the loss code in this file drops the last
    column; presumably it is the "item stays unsold" outcome -- TODO confirm).
    The payment head emits ``num_bidders`` sigmoid values in (0, 1), which
    the loss code multiplies by the bid, i.e. it treats them as the fraction
    of the bid that is charged.
    """

    # Column sums at or below this value are treated as "already pruned".
    _EPS = 0.000001
    # Sentinel returned by ``non_zero_proportion`` when no column is active.
    _NO_ACTIVE_NODE = 1000000

    def __init__(self, num_bidders, alloc_hidden_sizes, pay_hidden_sizes, initial_state_dict=None):
        """Build both heads and optionally load initial weights.

        Args:
            num_bidders: number of bidders (input width of both heads).
            alloc_hidden_sizes: hidden layer widths of the allocation head.
            pay_hidden_sizes: hidden layer widths of the payment head.
            initial_state_dict: optional state dict loaded right after
                construction.
        """
        super().__init__()

        self.num_bidders = num_bidders

        # The allocation head is created first so that the order in which
        # random initial weights are drawn stays the same as before.
        alloc_sizes = [self.num_bidders] + list(alloc_hidden_sizes) + [self.num_bidders + 1]
        self.alloc_network = self._build_mlp(alloc_sizes)

        pay_sizes = [self.num_bidders] + list(pay_hidden_sizes) + [self.num_bidders]
        self.pay_network = self._build_mlp(pay_sizes)

        # Not used by forward() (which calls F.softmax with an explicit dim);
        # kept so that external references to ``model.softmax`` keep working.
        self.softmax = nn.Softmax()

        if initial_state_dict is not None:
            self.load_state_dict(initial_state_dict)

        # Snapshot of the starting weights, plus the full layer-width lists.
        self.initial_state_dict = deepcopy(self.state_dict())
        self.pay_net_shapes = pay_sizes
        self.alloc_net_shapes = alloc_sizes

    @staticmethod
    def _build_mlp(layer_sizes):
        """Return Linear/ReLU layers alternating, with no ReLU after the last Linear."""
        layers = []
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Linear layers therefore always sit at the even indices 0, 2, 4, ...
        return nn.Sequential(*layers[:-1])

    def forward(self, bids_vec):
        """Run both heads on a batch of bids.

        Args:
            bids_vec: tensor of shape ``(batch, num_bidders)``.

        Returns:
            ``(allocation_probs, payment_outputs)`` with shapes
            ``(batch, num_bidders + 1)`` and ``(batch, num_bidders)``.
        """
        # Softmax temperature; a large beta pushes the allocation towards one-hot.
        beta = 1
        alloc_logits = self.alloc_network(bids_vec)
        allocation_probs = F.softmax(alloc_logits * beta, dim=1)

        # The sigmoid keeps each payment fraction inside (0, 1).
        pay_logits = self.pay_network(bids_vec)
        payment_outputs = torch.sigmoid(pay_logits)

        return allocation_probs, payment_outputs

    def load_params(self, stored_state_dict):
        """Load a previously stored state dict into this model."""
        self.load_state_dict(stored_state_dict)

    def store_params(self, path=None):
        """Save the state dict to ``path``.

        When ``path`` is None a timestamped file name under ``/stored_model/``
        is used. NOTE(review): that default is an absolute path at the
        filesystem root -- confirm this is intended.
        """
        if path is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            path = f'/stored_model/model_{timestamp}.pth'
        torch.save(self.state_dict(), path)

    def non_zero_proportion(self, current_layer_abs_sum):
        """Return the smallest share of the total held by a non-zero entry.

        Args:
            current_layer_abs_sum: 1-D tensor of per-node absolute weight sums.

        Returns:
            A 0-dim tensor with ``min(v / (sum + eps))`` over entries
            ``v > eps``, or the integer sentinel 1000000 when every entry is
            (numerically) zero.
        """
        active = current_layer_abs_sum[current_layer_abs_sum > self._EPS]
        if active.numel() == 0:
            return self._NO_ACTIVE_NODE
        # The epsilon in the denominator guards against division by zero.
        total = torch.sum(current_layer_abs_sum) + self._EPS
        return (active / total).min()

    def prune_node(self, prune_alloc, verbose=True):
        """Mask out the least important hidden node of one head.

        For every Linear layer except the first, the absolute weights are
        summed per input column (one column per node of the preceding layer).
        The node whose column holds the smallest non-zero share of its layer's
        total is selected, and all of its outgoing weights are masked with
        ``torch.nn.utils.prune.custom_from_mask``. Columns that are already
        zero are never selected again, so repeated calls prune distinct nodes.

        Args:
            prune_alloc: True to prune the allocation network, False to prune
                the payment network.
            verbose: when True (default) the weights of all Linear layers but
                the first are printed before and after pruning.

        Returns:
            None. Nothing is pruned when the head has no hidden layer or when
            every candidate column is already zero.
        """
        curr_net = self.alloc_network if prune_alloc else self.pay_network
        linear_layers = [curr_net[i] for i in range(0, len(curr_net), 2)]

        if verbose:
            print()
            print("before")
            for layer in linear_layers[1:]:
                print(layer.weight.data.cpu().numpy())

        layer_to_prune = None
        node_to_prune = None
        curr_all_layer_prop_min = 999

        # Start at index 2: the columns of the first Linear layer are the
        # input nodes, which must not be pruned.
        for i in range(2, len(curr_net), 2):
            node_weights = torch.abs(curr_net[i].weight.detach()).sum(dim=0)
            layer_i_prop_min = self.non_zero_proportion(node_weights)
            # The all-zero sentinel (1000000) never beats the initial 999.
            if not layer_i_prop_min < curr_all_layer_prop_min:
                continue
            # Pick the smallest *active* column so the node choice matches the
            # proportion computed above; masked columns are pushed to +inf.
            active = node_weights > self._EPS
            ranked = torch.where(
                active, node_weights, torch.full_like(node_weights, float("inf"))
            )
            curr_all_layer_prop_min = layer_i_prop_min
            layer_to_prune = i
            node_to_prune = int(torch.argmin(ranked))

        if layer_to_prune is None:
            # No hidden node left to prune.
            return

        target_layer = curr_net[layer_to_prune]
        # The mask must be built from the weight tensor, not from the module.
        mask = torch.ones_like(target_layer.weight)
        mask[:, node_to_prune] = 0

        prune.custom_from_mask(
            target_layer,
            name="weight",
            mask=mask,
        )

        if verbose:
            print("after")
            print("isjfisfjdsifsf", curr_net[layer_to_prune])
            for layer in linear_layers[1:]:
                print(layer.weight.data.cpu().numpy())

    def count_one_network_edges(self, net):
        """Return the number of weight entries in all Linear layers of ``net``.

        Masked (pruned) entries are still counted, because the count is based
        on the shape of each weight matrix.
        """
        return sum(
            layer.weight.shape[0] * layer.weight.shape[1]
            for layer in net
            if isinstance(layer, nn.Linear)
        )

    def count_total_edges(self, alloc=-1, bidder_id_list=None):
        """Count edges of the allocation network, the payment network, or both.

        Args:
            alloc: 1 counts only the allocation network, 0 counts only the
                payment network, any negative value counts both.
            bidder_id_list: accepted for interface compatibility only. This
                model has a single payment network shared by all bidders, so
                the argument does not influence the result.

        Returns:
            The requested edge count, or None for any other value of ``alloc``.
        """
        if alloc == 1:
            return self.count_one_network_edges(self.alloc_network)
        if alloc == 0:
            return self.count_one_network_edges(self.pay_network)
        if alloc < 0:
            return (
                self.count_one_network_edges(self.alloc_network)
                + self.count_one_network_edges(self.pay_network)
            )
        return None


# In this function, valuation is truthful valuation
def optim_deviate(model, valuation, gamma=0.0001, T=1000):
    """Grid-search each bidder's best unilateral misreport.

    For every bidder ``idx`` the other bids are fixed at ``valuation`` and the
    bidder's own bid is swept over a grid starting at 0 with step 1/1000 up to
    about 1. The bid that maximises
    ``alloc[idx] * valuation[idx] - pay[idx] * bid * alloc[idx]`` is kept.

    Args:
        model: callable returning ``(allocation_probs, payment_outputs)`` for
            a ``(batch, n)`` bid tensor; ``model.eval()`` is called and the
            model is left in eval mode.
        valuation: 1-D tensor of the n truthful valuations.
        gamma: unused; kept for interface compatibility.
        T: unused; kept for interface compatibility.

    Returns:
        An ``(n, n)`` tensor on ``valuation``'s device. Row ``idx`` equals
        ``valuation`` with entry ``idx`` replaced by that bidder's best bid.
    """
    model.eval()
    n = valuation.shape[0]  # number of bidders

    # Allocate next to the input so the result can be fed back to the model.
    bids_deviated = torch.zeros(n, n, device=valuation.device)

    for idx in range(n):
        bids = valuation.clone()

        def negative_utility(x):
            # brute() hands over either a scalar or a 1-element array.
            candidate = bids.clone()
            candidate[idx] = float(np.ravel(x)[0])
            # Only the value is needed, so skip building an autograd graph.
            with torch.no_grad():
                allocation_probs, payment_outputs = model(candidate.unsqueeze(0))
                win_prob = allocation_probs[0][:-1][idx]
                utility = (
                    win_prob * valuation[idx]
                    - payment_outputs[0][idx] * candidate[idx] * win_prob
                )
            # .item() copies to host first, so this also works for GPU tensors.
            return -utility.item()

        # finish=None: return the best grid point without local polishing.
        best_bid = optimize.brute(
            negative_utility,
            (slice(0, 1 + 1/1000, 1/1000),),
            finish=None,
        )

        bids[idx] = float(best_bid)
        bids_deviated[idx] = bids

    return bids_deviated


def loss_function(model, alloc_probs, payment_outputs, bids_vec, lmd, rho):
    """Compute the per-bidder regret and the Lagrangian training loss.

    The loss is ``-revenue + sum(lmd * regret) + rho / 2 * sum(regret ** 2)``,
    where the best misreports come from ``optim_deviate`` and the bids in
    ``bids_vec`` are used as the truthful valuations.

    NOTE(review): the ``rho`` argument is overwritten with 0 below, so the
    quadratic term is always zero -- confirm this is intended.

    Args:
        model: network returning ``(allocation_probs, payment_outputs)``.
        alloc_probs: allocation output for the batch; its last column is
            dropped when computing revenue.
        payment_outputs: payment output for the batch.
        bids_vec: batch of bid vectors, one row per sample.
        lmd: per-bidder multipliers; ``lmd.shape[0]`` is the bidder count.
        rho: ignored (see note above).

    Returns:
        ``(regret, total_loss)``: the batch-averaged regret per bidder and the
        scalar loss.
    """
    num_bidders = lmd.shape[0]
    num_samples = alloc_probs.shape[0]
    rho = 0  # discards the caller's value

    # Revenue term: payment fraction * bid * allocation probability.
    revenue = (payment_outputs * bids_vec * alloc_probs[:, :-1]).sum()
    loss_1 = -revenue / num_samples

    regret = torch.zeros(num_bidders, device=device)

    for sample_idx in range(num_samples):
        truthful_bids = bids_vec[sample_idx]

        bids_deviated = optim_deviate(model, truthful_bids, gamma=0.01, T=1000)
        truthful_alloc, truthful_pay = model(truthful_bids.unsqueeze(0))
        print("bids_deviated =", bids_deviated)

        # Utility of every bidder when all bidders report truthfully.
        win_prob = truthful_alloc[0][:-1]
        utility_org = win_prob * truthful_bids - truthful_pay[0] * truthful_bids * win_prob
        print("utility_org =", utility_org)

        # Row k of bids_deviated is the profile in which only bidder k deviates.
        deviated_alloc, deviated_pay = model(bids_deviated)
        deviated_win_prob = deviated_alloc[:, :-1]
        utility_deviated = (
            deviated_win_prob * truthful_bids
            - deviated_pay * bids_deviated * deviated_win_prob
        )

        print("utility_deviated =", utility_deviated)

        # Only the diagonal matters: bidder k's utility under its own deviation.
        gains = utility_deviated.diagonal() - utility_org
        for bidder in range(num_bidders):
            regret[bidder] += gains[bidder]

    regret /= num_samples
    print("regret =", regret)
    loss_2 = (lmd * regret).sum()
    loss_3 = rho * (regret.pow(2).sum() / 2.0)

    total_loss = loss_1 + loss_2 + loss_3
    print("losses:", loss_1, loss_2, loss_3)

    return regret, total_loss


def optim_b_deviate(model, valuation, gamma=0.0001, T=1000):
    """Coarse grid-search of each bidder's best unilateral misreport.

    Same procedure as a fine grid search, but over a coarse grid:
    ``np.array([0, 1])`` is interpreted by ``scipy.optimize.brute`` as a
    ``(low, high)`` range and sampled at its default ``Ns=20`` evenly spaced
    points in [0, 1] -- it is NOT restricted to the two bids 0 and 1.

    Args:
        model: callable returning ``(allocation_probs, payment_outputs)`` for
            a ``(batch, n)`` bid tensor; ``model.eval()`` is called and the
            model is left in eval mode.
        valuation: 1-D tensor of the n truthful valuations.
        gamma: unused; kept for interface compatibility.
        T: unused; kept for interface compatibility.

    Returns:
        An ``(n, n)`` tensor on ``valuation``'s device. Row ``idx`` equals
        ``valuation`` with entry ``idx`` replaced by that bidder's best bid.
    """
    model.eval()
    n = valuation.shape[0]  # number of bidders

    # Allocate next to the input so the result can be fed back to the model.
    bids_deviated = torch.zeros(n, n, device=valuation.device)

    for idx in range(n):
        bids = valuation.clone()

        def negative_utility(x):
            # brute() hands over either a scalar or a 1-element array.
            candidate = bids.clone()
            candidate[idx] = float(np.ravel(x)[0])
            # Only the value is needed, so skip building an autograd graph.
            with torch.no_grad():
                allocation_probs, payment_outputs = model(candidate.unsqueeze(0))
                win_prob = allocation_probs[0][:-1][idx]
                utility = (
                    win_prob * valuation[idx]
                    - payment_outputs[0][idx] * candidate[idx] * win_prob
                )
            # .item() copies to host first, so this also works for GPU tensors.
            return -utility.item()

        # finish=None: return the best grid point without local polishing.
        best_bid = optimize.brute(
            negative_utility,
            (np.array([0,1]),),
            finish=None,
        )

        bids[idx] = float(best_bid)
        bids_deviated[idx] = bids

    return bids_deviated


def loss_function_b_deviate(model, alloc_probs, payment_outputs, bids_vec, lmd, rho):
    """Compute regret and Lagrangian loss using the coarse deviation search.

    The loss is ``-revenue + sum(lmd * regret) + rho / 2 * sum(regret ** 2)``,
    where the best misreports come from ``optim_b_deviate`` and the bids in
    ``bids_vec`` are used as the truthful valuations.

    NOTE(review): the ``rho`` argument is overwritten with 0 below, so the
    quadratic term is always zero -- confirm this is intended.

    Args:
        model: network returning ``(allocation_probs, payment_outputs)``.
        alloc_probs: allocation output for the batch; its last column is
            dropped when computing revenue.
        payment_outputs: payment output for the batch.
        bids_vec: batch of bid vectors, one row per sample.
        lmd: per-bidder multipliers; ``lmd.shape[0]`` is the bidder count.
        rho: ignored (see note above).

    Returns:
        ``(regret, total_loss)``: the batch-averaged regret per bidder and the
        scalar loss.
    """
    num_bidders = lmd.shape[0]
    num_samples = alloc_probs.shape[0]
    rho = 0  # discards the caller's value

    # Revenue term: payment fraction * bid * allocation probability.
    revenue = (payment_outputs * bids_vec * alloc_probs[:, :-1]).sum()
    loss_1 = -revenue / num_samples

    regret = torch.zeros(num_bidders, device=device)

    for sample_idx in range(num_samples):
        truthful_bids = bids_vec[sample_idx]

        bids_deviated = optim_b_deviate(model, truthful_bids, gamma=0.01, T=1000)
        truthful_alloc, truthful_pay = model(truthful_bids.unsqueeze(0))

        # Utility of every bidder when all bidders report truthfully.
        win_prob = truthful_alloc[0][:-1]
        utility_org = win_prob * truthful_bids - truthful_pay[0] * truthful_bids * win_prob

        # Row k of bids_deviated is the profile in which only bidder k deviates.
        deviated_alloc, deviated_pay = model(bids_deviated)
        deviated_win_prob = deviated_alloc[:, :-1]
        utility_deviated = (
            deviated_win_prob * truthful_bids
            - deviated_pay * bids_deviated * deviated_win_prob
        )

        # Only the diagonal matters: bidder k's utility under its own deviation.
        gains = utility_deviated.diagonal() - utility_org
        for bidder in range(num_bidders):
            regret[bidder] += gains[bidder]

    regret /= num_samples
    loss_2 = (lmd * regret).sum()
    loss_3 = rho * (regret.pow(2).sum() / 2.0)

    total_loss = loss_1 + loss_2 + loss_3
    return regret, total_loss