import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch as t
from functions import CustomMLP  

# Compute device used by the models in this module: CUDA when torch reports it as available, otherwise CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

class Line_Net(nn.Module):
    """Min-of-max piecewise-linear value transform with a single-winner allocation.

    Each agent's input is mapped through ``args.num_linear`` groups of
    ``args.num_max`` units of the form ``x * exp(w) + b``; the max is taken
    inside a group and the min across groups.  The transformed values compete
    against each other and a zero-valued dummy slot, and the charge of each
    agent is obtained by pushing the best competing value back through the
    inverse transform.

    The weights are plain leaf tensors (not ``nn.Parameter``), so they do not
    appear in ``parameters()`` / ``state_dict()``; ``seller_backward`` updates
    ``seller_w`` and ``seller_b`` by hand.  ``seller_w2`` is created but not
    used by ``forward``.
    """

    def __init__(self, args, train_data, test_data):
        """Store the configuration/data and draw the random initial weights.

        Args:
            args: object exposing ``num_linear``, ``num_max`` and ``num_agent``.
            train_data: stored as ``self.train_data`` (not used in this class).
            test_data: stored as ``self.test_data`` (not used in this class).
        """
        super(Line_Net, self).__init__()
        self.args = args
        self.train_data = train_data
        self.test_data = test_data
        self.tau = 0.1  # softmax temperature used by deterministic_NeuralSort
        self.alpha = 0.05

        num_weights = self.args.num_max * self.args.num_linear * self.args.num_agent

        # Flat float64 leaf tensors; t_reshape() gives them their 3-D layout.
        self.seller_w = t.tensor(np.random.normal(size=num_weights) / 5, device=device, requires_grad=True)
        self.seller_w2 = t.tensor(np.random.normal(size=num_weights) / 5, device=device, requires_grad=True)
        self.seller_b = t.tensor(-np.random.rand(num_weights) * 1.0, device=device, requires_grad=True)

    def deterministic_NeuralSort(self, s):
        """Return a row-stochastic relaxation of the descending-sort permutation of ``s``.

        Args:
            s: tensor of shape ``[batch, n, 1]`` holding the scores to sort.

        Returns:
            Tensor of shape ``[batch, n, n]``; every row is a softmax (with
            temperature ``self.tau``) over the ``n`` input positions, and row 0
            concentrates on the largest score.
        """
        n = s.size()[1]
        # sum_k |s_i - s_k| for every i, kept as [batch, n, 1] so it broadcasts.
        abs_diff_sum = t.abs(s - s.permute(0, 2, 1)).sum(dim=2, keepdim=True)
        # n + 1 - 2 * rank for rank = 1..n
        scaling = (n + 1 - 2 * (t.arange(n, device=s.device) + 1)).to(s.dtype)
        # logits[b, i, j] = s_i * scaling_j - sum_k |s_i - s_k|
        logits = s * scaling - abs_diff_sum
        return F.softmax(logits.permute(0, 2, 1) / self.tau, dim=-1)

    def t_reshape(self, x):
        """Reshape a flat weight vector to ``[num_linear, num_max, num_agent]``.

        The flat vector is read as ``[num_agent, num_linear, num_max]`` and the
        agent axis is moved to the end.
        """
        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        return t.reshape(x, [num_agent, num_func, num_max_units]).permute(1, 2, 0)

    def forward(self, x, ctr, alpha, mode='train'):
        """Compute allocation, charges and summary statistics for one batch.

        Args:
            x: array-like of shape ``[batch, num_agent]``.
            ctr: array-like multiplied element-wise with the transformed values
                (assumed ``[batch, num_agent]`` -- confirm against the caller).
            alpha: scalar weight of the ``alpha * ctr`` term.
            mode: ``'train'`` (softmax over scores scaled by 1000), ``'test'``
                (first row of ``deterministic_NeuralSort``) or ``'true'``
                (hard one-hot arg-max).

        Returns:
            Tuple ``(revenue, a, vv, utility, cost, click)``:
            ``revenue``/``cost``/``click`` are scalar batch means, ``a`` is the
            ``[batch, num_agent]`` allocation, ``vv`` the transformed values and
            ``utility`` the per-agent batch mean of ``a * (x - p)``.

        Raises:
            ValueError: if ``mode`` is not one of the three supported values.
        """
        num_agent = self.args.num_agent
        x = t.tensor(x, device=device)
        ctr = t.tensor(ctr, device=device)

        # [1, num_func, num_max_units, num_agent]; broadcasting over the batch
        # axis replaces the explicit repeat/reshape copies.
        seller_w = self.t_reshape(self.seller_w).unsqueeze(0)
        seller_b = self.t_reshape(self.seller_b).unsqueeze(0)
        exp_w = t.exp(seller_w)

        # Max over the units of a group (axis 2), then min over groups (axis 1).
        xx = x[:, None, None, :]
        vv_max_units = t.max(xx * exp_w + seller_b, 2).values
        vv = t.min(vv_max_units, 1).values * ctr + alpha * ctr

        # Scores with a trailing zero column for the dummy slot: [batch, num_agent + 1].
        vv32 = vv.to(t.float32)
        scores = F.pad(vv32, (0, 1))

        if mode == 'train':
            a_dummy = F.softmax(scores * 1000, dim=1)
        elif mode == 'test':
            so = self.deterministic_NeuralSort(scores.unsqueeze(-1))
            a_dummy = so[:, 0]
        elif mode == 'true':
            a_dummy = F.one_hot(t.argmax(scores, 1), num_agent + 1)
        else:
            raise ValueError("mode must be 'train', 'test' or 'true', got {!r}".format(mode))

        a = a_dummy[:, 0:num_agent]

        # best_other[b, i] = max_k vv[b, k] * (k != i); the zeroed own entry
        # also acts as a floor of 0.
        not_self = 1.0 - t.eye(num_agent, dtype=t.float32, device=vv32.device)
        best_other = t.max(vv32.unsqueeze(1) * not_self, 2).values

        # Undo the ctr/alpha scaling, then invert the min-of-max transform
        # (min over units, max over groups).  Broadcasting keeps the axes
        # aligned with seller_b/exp_w even when num_linear != num_max.
        p_spa = (best_other - alpha * ctr) / ctr
        p_units = (p_spa[:, None, None, :] - seller_b) * t.reciprocal(exp_w)
        p_max_units = t.min(p_units, 2).values
        p = F.relu(t.max(p_max_units, 1).values)

        revenue = t.mean(t.sum(t.mul(a, ctr * p) + alpha * t.mul(a, ctr), 1))
        cost = t.mean(t.sum(t.mul(a, ctr * p), 1))
        click = t.mean(t.sum(t.mul(a, ctr), 1))

        # Vectorised a * (x - p).  The result is copied into a default
        # (dtype/device) tensor so `utility` has the same type as before.
        z = t.zeros(p.size()).copy_(a * (x - p))
        utility = t.mean(z, 0)

        return revenue, a, vv, utility, cost, click

    def seller_backward(self, args, x, ctr, alpha, mode, lr=0.01):
        """Take one gradient-ascent step on revenue for ``seller_w`` and ``seller_b``.

        Args:
            args: unused; kept for interface compatibility.
            x, ctr, alpha, mode: forwarded to ``forward``.
            lr: step size (defaults to the previously hard-coded 0.01).
        """
        revenue = self.forward(x, ctr, alpha, mode=mode)[0]
        (-revenue).backward()

        with t.no_grad():
            for weight in (self.seller_w, self.seller_b):
                weight.sub_(lr * weight.grad)
                weight.grad.zero_()

class Hyper_Myerson(nn.Module):
    """Single-winner allocation whose transform weights are supplied per call.

    Mirrors the min-of-max ``x * exp(w) + b`` transform, but ``forward``
    receives the weights ``w1``/``b1`` as arguments instead of owning them,
    and the score of each agent is shifted by ``alpha * (ctr_ads - ctr_og)``.
    """

    def __init__(self, args):
        """Store ``args`` and draw random initial weight arrays.

        Args:
            args: object exposing ``num_linear``, ``num_max`` and ``num_agent``.
        """
        super(Hyper_Myerson, self).__init__()
        self.args = args
        # Temperature read by deterministic_NeuralSort; it was never set here
        # before, so mode='test' raised AttributeError.
        self.tau = 0.1
        num_weights = self.args.num_max * self.args.num_linear * self.args.num_agent
        self.seller_w_init = np.random.normal(size=num_weights) / 5
        self.seller_b_init = -np.random.rand(num_weights) * 1.0

    def deterministic_NeuralSort(self, s):
        """Return a row-stochastic relaxation of the descending-sort permutation of ``s``.

        Args:
            s: tensor of shape ``[batch, n, 1]`` holding the scores to sort.

        Returns:
            Tensor of shape ``[batch, n, n]``; every row is a softmax (with
            temperature ``self.tau``) over the ``n`` input positions, and row 0
            concentrates on the largest score.
        """
        n = s.size()[1]
        # sum_k |s_i - s_k| for every i, kept as [batch, n, 1] so it broadcasts.
        abs_diff_sum = t.abs(s - s.permute(0, 2, 1)).sum(dim=2, keepdim=True)
        # n + 1 - 2 * rank for rank = 1..n
        scaling = (n + 1 - 2 * (t.arange(n, device=s.device) + 1)).to(s.dtype)
        # logits[b, i, j] = s_i * scaling_j - sum_k |s_i - s_k|
        logits = s * scaling - abs_diff_sum
        return F.softmax(logits.permute(0, 2, 1) / self.tau, dim=-1)

    def t_reshape(self, x):
        """Reshape a flat weight vector to ``[num_linear, num_max, num_agent]``.

        The flat vector is read as ``[num_agent, num_linear, num_max]`` and the
        agent axis is moved to the end.
        """
        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        return t.reshape(x, [num_agent, num_func, num_max_units]).permute(1, 2, 0)

    def forward(self, mode, inputs, ctr_ads, ctr_og, alpha, w1, b1):
        """Compute allocation, charges and summary statistics for one batch.

        Args:
            mode: ``'train'`` (softmax over scores scaled by 1000), ``'test'``
                (first row of ``deterministic_NeuralSort``) or ``'true'``
                (hard one-hot arg-max).
            inputs: array-like of shape ``[batch, num_agent]``.
            ctr_ads: array-like of shape ``[batch, num_agent]``.
            ctr_og: array-like with one value per sample (``batch`` elements).
            alpha: scalar weight; ``click`` divides by it, so it should be non-zero.
            w1, b1: tensors tiled to ``[num_linear, num_max, num_agent]``
                (presumably ``num_agent`` values each -- confirm against caller).

        Returns:
            Tuple ``(revenue, a, vv, utility, cost, click, percent)`` where
            ``percent`` is the batch mean of the dummy-slot allocation.

        Raises:
            ValueError: if ``mode`` is not one of the three supported values.
        """
        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        x = t.tensor(inputs).to(device)
        ctr_ads = t.tensor(ctr_ads).to(device)
        ctr_og = t.tensor(ctr_og).to(device)

        # Tile the supplied weights to [1, num_func, num_max_units, num_agent];
        # the leading axis broadcasts over the batch.
        seller_w = t.reshape(w1.repeat(num_func, num_max_units), [num_func, num_max_units, num_agent]).unsqueeze(0)
        seller_b = t.reshape(b1.repeat(num_func, num_max_units), [num_func, num_max_units, num_agent]).unsqueeze(0)
        exp_w = t.exp(seller_w)

        # One ctr_og value per sample, broadcast across agents.  The previous
        # repeat()+reshape tiled the vector and mixed values between samples.
        og_shift = alpha * (ctr_ads - ctr_og.reshape(-1, 1))

        # Max over the units of a group (axis 2), then min over groups (axis 1).
        xx = x[:, None, None, :]
        vv_max_units = t.max(xx * exp_w + seller_b, 2).values
        vv = t.min(vv_max_units, 1).values * ctr_ads + og_shift

        # Scores with a trailing zero column for the dummy slot: [batch, num_agent + 1].
        vv32 = vv.to(t.float32)
        scores = F.pad(vv32, (0, 1))

        if mode == 'train':
            a_dummy = F.softmax(scores * 1000, dim=1)
        elif mode == 'test':
            so = self.deterministic_NeuralSort(scores.unsqueeze(-1))
            a_dummy = so[:, 0]
        elif mode == 'true':
            a_dummy = F.one_hot(t.argmax(scores, 1), num_agent + 1)
        else:
            raise ValueError("mode must be 'train', 'test' or 'true', got {!r}".format(mode))

        a = a_dummy[:, 0:num_agent]

        # best_other[b, i] = max_k vv[b, k] * (k != i); the zeroed own entry
        # also acts as a floor of 0.
        not_self = 1.0 - t.eye(num_agent, dtype=t.float32, device=vv32.device)
        best_other = t.max(vv32.unsqueeze(1) * not_self, 2).values

        # Undo the score shift/scaling, then invert the min-of-max transform
        # (min over units, max over groups).  Broadcasting keeps the axes
        # aligned with seller_b/exp_w even when num_linear != num_max.
        p_spa = (best_other - og_shift) / ctr_ads
        p_units = (p_spa[:, None, None, :] - seller_b) * t.reciprocal(exp_w)
        p_max_units = t.min(p_units, 2).values
        p = F.relu(t.max(p_max_units, 1).values)

        revenue = t.mean(t.sum(t.mul(a, p * ctr_ads + alpha * ctr_ads), 1))
        revenue = revenue + alpha * t.mean(a_dummy[:, -1] * ctr_og)

        cost = t.mean(t.sum(t.mul(a, p * ctr_ads), 1))
        click = (revenue - cost) / alpha

        # Vectorised a * (x - p).  The result is copied into a default
        # (dtype/device) tensor so `utility` has the same type as before.
        z = t.zeros(p.size()).copy_(a * (x - p))
        utility = t.mean(z, 0)

        # Cast first: in 'true' mode a_dummy is an int64 one-hot and
        # torch.mean rejects integer tensors.
        percent = t.mean(a_dummy[:, -1].to(t.float32))

        return revenue, a, vv, utility, cost, click, percent

    def build_hyper_mlp_net_layer(self, inp_last_dim, units):
        """Return a ``CustomMLP`` constructed with ``(inp_last_dim, units)``."""
        return CustomMLP(inp_last_dim, units)