from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch as t
from matplotlib import pyplot as plt
from random import random

# Torch device used for every tensor in this module: CUDA when available, else CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

class PIDController:
    """Discrete PID controller whose "integral" term is a 20-step moving average.

    ``update`` returns ``Kp * e + Ki * mean(last 20 errors) + Kd * (e - e_prev)``
    where ``e = target - current_value``.  The error window starts filled with
    0.5.  ``self.integral`` accumulates the plain running sum of errors but is
    not used in the output.
    """

    def __init__(self, Kp, Ki, Kd, target):
        self.Kp, self.Ki, self.Kd = Kp, Ki, Kd
        self.target = target
        # Sliding window of the 20 most recent errors (oldest first).
        self.cumulate = np.full(20, 0.5)
        self.integral = 0
        self.prev_error = 0

    def update(self, current_value):
        """Feed one measurement and return the controller output."""
        error = self.target - current_value
        self.integral = self.integral + error

        # Shift the window left by one and store the newest error at the end.
        self.cumulate = np.roll(self.cumulate, -1)
        self.cumulate[-1] = error

        derivative = error - self.prev_error
        self.prev_error = error

        proportional = self.Kp * error
        averaged = self.Ki * np.mean(self.cumulate)
        return proportional + averaged + self.Kd * derivative


class Generator():
    """Draw i.i.d. uniform samples sized by ``args.num_sample_train``."""

    def __init__(self, args):
        self.args = args

    def generate_uniform(self, low, high):
        """Return a (num_sample_train, num_agent) array of U[low, high) draws."""
        shape = (self.args.num_sample_train, self.args.num_agent)
        return np.random.uniform(low, high, shape)

    def generate_ctr(self, low, high):
        """Return a (num_sample_train,) vector of U[low, high) draws."""
        return np.random.uniform(low, high, (self.args.num_sample_train,))

class Generator2():
    """Like ``Generator`` but ``generate_ctr`` also returns one column per agent."""

    def __init__(self, args):
        self.args = args

    def generate_uniform(self, low, high):
        """Return a (num_sample_train, num_agent) array of U[low, high) draws."""
        rows = self.args.num_sample_train
        cols = self.args.num_agent
        return np.random.uniform(low, high, (rows, cols))

    def generate_ctr(self, low, high):
        """Return a (num_sample_train, num_agent) array of U[low, high) draws."""
        # Same draw as generate_uniform; kept as a separate, named entry point.
        return self.generate_uniform(low, high)

class Args():
    """Attribute container built from a positional configuration sequence.

    ``args`` is read by position as (num_agent, num_item, distribution_type,
    num_linear, num_max, num_sample_train, num_sample_test, seed_val); any
    entries after the eighth are ignored.
    """

    _FIELDS = ("num_agent", "num_item", "distribution_type", "num_linear",
               "num_max", "num_sample_train", "num_sample_test", "seed_val")

    def __init__(self, args):
        for position, field_name in enumerate(self._FIELDS):
            setattr(self, field_name, args[position])
        
# Module-level default configuration.
args = Args((
    2,          # num_agent
    1,          # num_item
    "uniform",  # distribution_type
    20,         # num_linear
    20,         # num_max
    10000,      # num_sample_train
    10000,      # num_sample_test
    1,          # seed_val
))


class Line_Net(nn.Module):
    """Baseline auction that learns its own monotone virtual-value functions.

    For each agent the score is ``min_f max_u (exp(w[f, u]) * x + b[f, u])``
    over ``args.num_linear`` groups of ``args.num_max`` affine units, multiplied
    by the agent's CTR and shifted by ``alpha * ctr``.  A dummy option with
    score 0 is appended; the highest score wins.  Each agent's price is the
    value at which its score would equal max(0, best competing score).
    ``seller_w`` / ``seller_b`` are plain tensors updated by ``seller_backward``.
    """

    def __init__(self, args, train_data, test_data):
        nn.Module.__init__(self)
        self.args = args
        self.train_data = train_data
        self.test_data = test_data
        # Temperature of the NeuralSort relaxation used in 'test' mode.
        self.tau = 0.1
        # Not read by any method of this class.
        self.lola = 0.1
        self.alpha = 0.05

        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        num_params = num_max_units * num_func * num_agent

        # Flat parameter vectors; ``t_reshape`` views them as
        # (num_func, num_max_units, num_agent).  The draw order (w, w2, b)
        # decides which global-NumPy-RNG values each tensor receives.
        self.seller_w_init = np.random.normal(size=(num_params)) / 5
        self.seller_w2_init = np.random.normal(size=(num_params)) / 5
        self.seller_b_init = -np.random.rand(num_params) * 1.0

        self.seller_w = t.tensor(self.seller_w_init, device=device, requires_grad=True)
        # seller_w2 is not used by ``forward``; it is kept for compatibility.
        self.seller_w2 = t.tensor(self.seller_w2_init, device=device, requires_grad=True)
        self.seller_b = t.tensor(self.seller_b_init, device=device, requires_grad=True)

    def deterministic_NeuralSort(self, s):
        """Relaxed sorting matrix (NeuralSort form) for scores ``s``.

        ``s`` is (batch, n, 1).  Returns (batch, n, n) where row ``i`` is a
        softmax (temperature ``self.tau``) over ``(n + 1 - 2(i + 1)) * s_j -
        sum_k |s_j - s_k|``; row 0 peaks at the largest score.
        """
        n = s.size()[1]
        one = t.ones((n, 1), dtype=t.float32, device=device)

        A_s = t.abs(s - s.permute(0, 2, 1))
        B = t.matmul(A_s, t.matmul(one, t.transpose(one, 0, 1)))
        scaling = (n + 1 - 2 * (t.arange(n, device=device) + 1)).type(t.float32)
        C = t.matmul(s, scaling.unsqueeze(0))

        P_max = (C - B).permute(0, 2, 1)
        return F.softmax(P_max / self.tau, dim=-1)

    def t_reshape(self, x):
        """View a flat parameter vector as (num_func, num_max_units, num_agent)."""
        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        x = t.reshape(x, [num_agent, num_func, num_max_units])
        x = t.transpose(x, 0, 1)
        x = t.transpose(x, 1, 2)
        return x

    @staticmethod
    def _exclusion_matrix(num_agent):
        """Return W with W[i, j, k] = 1 iff j == k and k != i, else 0.

        ``vv @ W[i]`` copies every agent's score except agent ``i``'s, which
        becomes 0.
        """
        eye = np.identity(num_agent)
        return eye[np.newaxis, :, :] * (1.0 - eye)[:, np.newaxis, :]

    def forward(self, x, ctr, alpha, str='train'):
        """Run the auction on one batch.

        Args:
            x: agent values, array-like of shape (batch, num_agent).
            ctr: ad click-through rates, broadcastable to (batch, num_agent).
            alpha: weight on clicks added to each agent's score.
            str: allocation mode - 'train' (softmax of scores scaled by
                1000), 'test' (NeuralSort relaxation) or 'true' (hard argmax
                one-hot).  The name shadows the builtin but callers pass it by
                keyword, so it is kept.

        Returns:
            ``(revenue, a, vv, utility, cost, click)``: batch-mean revenue
            ``sum(a * ctr * p + alpha * a * ctr)``, allocation ``a`` of shape
            (batch, num_agent), scores ``vv``, per-agent mean utility
            ``a * (x - p)`` (float32), batch-mean cost ``sum(a * ctr * p)`` and
            batch-mean clicks ``sum(a * ctr)``.

        Raises:
            ValueError: if ``str`` is not one of the three modes.
        """
        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        x = t.tensor(x, device=device)
        ctr = t.tensor(ctr, device=device)
        batch_size = x.size()[0]

        # (num_agent, num_agent + 1): identity plus a zero column for the dummy option.
        append_dummy_mat = t.tensor(
            np.float32(np.append(np.identity(num_agent), np.zeros([num_agent, 1]), 1)),
            device=device)

        full_shape = [batch_size, num_func, num_max_units, num_agent]
        seller_w = self.t_reshape(self.seller_w)
        seller_b = self.t_reshape(self.seller_b)
        w_copy = t.reshape(seller_w.repeat([batch_size, 1, 1]), full_shape)
        b_copy = t.reshape(seller_b.repeat([batch_size, 1, 1]), full_shape)
        xx_copy = t.reshape(x.repeat([1, num_func * num_max_units]), full_shape)

        # Score: max over units, then min over groups, scaled by CTR.
        vv_max_units = t.max(t.mul(xx_copy, t.exp(w_copy)) + b_copy, 2).values
        vv = t.min(vv_max_units, 1).values * ctr + alpha * ctr
        scores = t.matmul(vv.to(t.float32), append_dummy_mat)

        if str == 'train':
            w_a = t.tensor(np.float32(np.identity(num_agent + 1) * 1000), device=device)
            a_dummy = F.softmax(t.matmul(scores, w_a), dim=1)
        elif str == 'test':
            k = t.reshape(scores, [batch_size, num_agent + 1, 1])
            a_dummy = self.deterministic_NeuralSort(k)[:, 0]
        elif str == 'true':
            a_dummy = F.one_hot(t.argmax(scores, 1), num_agent + 1)
        else:
            raise ValueError(
                "unknown mode {!r}; expected 'train', 'test' or 'true'".format(str))

        a = a_dummy[:, 0:num_agent]

        # For agent i: max(0, highest competing score).  The 0 comes from the
        # agent's own (zeroed) slot in the exclusion matrix.
        w_p = t.tensor(self._exclusion_matrix(num_agent), dtype=t.float32, device=device)
        spa_tensor1 = t.reshape(t.reshape(vv, [-1]).repeat([num_agent]), [num_agent, -1, num_agent])
        spa_tensor2 = t.matmul(spa_tensor1.to(t.float32), w_p)
        p_spa = (t.transpose(t.max(spa_tensor2, 2).values, 0, 1) - alpha * ctr) / ctr

        # Invert the min-max affine score: price = max_f min_u (y - b) / exp(w).
        p_spa_copy = t.reshape(p_spa.repeat([1, num_func * num_max_units]), full_shape)
        p_max_units = t.min(t.mul(p_spa_copy - b_copy, t.reciprocal(t.exp(w_copy))), 2).values
        p = t.max(p_max_units, 1).values

        revenue = t.mean(t.sum(t.mul(a, ctr * p) + alpha * t.mul(a, ctr), 1))
        cost = t.mean(t.sum(t.mul(a, ctr * p), 1))
        click = t.mean(t.sum(t.mul(a, ctr), 1))
        # Vectorised per-agent utility (was an element-by-element Python loop).
        utility = t.mean(a * (x - p), 0).to(t.float32)

        return revenue, a, vv, utility, cost, click

    def seller_backward(self, args, x, ctr, alpha, key, lr=0.1):
        """Take one plain gradient-ascent step on revenue for seller_w/seller_b.

        Args:
            args: unused; kept for call-site compatibility.
            x, ctr, alpha: forwarded to ``forward``.
            key: allocation mode passed to ``forward`` as ``str``.
            lr: step size (default 0.1).

        ``seller_w2`` is not updated.
        """
        loss = -self.forward(x, ctr, alpha, str=key)[0]
        loss.backward()

        with t.no_grad():
            for param in (self.seller_w, self.seller_b):
                param.sub_(lr * param.grad)
                param.grad.zero_()
  
        


class MLP(nn.Module):
    """Fully connected network with a bounded output layer.

    ``layers`` gives every layer width, input first.  Hidden layers apply
    ``activation``; the final linear layer's output is squashed with
    ``5 * tanh``, so outputs lie in (-5, 5).
    """

    def __init__(self, layers, activation):
        super(MLP, self).__init__()
        shapes = zip(layers[:-1], layers[1:])
        self.layers_list = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in shapes])
        self.activation = activation

    def forward(self, x):
        depth = len(self.layers_list)
        for hidden_layer in self.layers_list[:depth - 1]:
            x = self.activation(hidden_layer(x))
        if depth:
            x = 5 * t.tanh(self.layers_list[depth - 1](x))
        return x

class CustomMLP(nn.Module):
    """Affine layer whose weight and bias are supplied at call time.

    ``self.fc`` is created (so its parameters are registered) but ``forward``
    does not use it; it computes ``x @ given_weights.view(-1, out) + given_biases``
    where ``out = given_biases.shape[0]``.
    """

    def __init__(self, input_size, output_size):
        super(CustomMLP, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x, given_weights, given_biases):
        out_features = given_biases.shape[0]
        weight_matrix = given_weights.view(-1, out_features)
        return t.matmul(x, weight_matrix) + given_biases

class Hyper_Myerson(nn.Module):
    """Auction whose virtual-value weights/biases are supplied per call.

    ``forward`` receives per-agent weights ``w1`` and biases ``b1`` and tiles
    them over every (linear-function, max-unit) slot.  Each agent's score is
    ``min_f max_u (exp(w) * x + b) * ctr_ads + alpha * (ctr_ads - ctr_og)``.
    A dummy option with score 0 is appended; when it wins, the revenue term
    credits ``alpha * ctr_og`` for that sample.
    """

    def __init__(self, args):
        super(Hyper_Myerson, self).__init__()
        self.layers_list = []
        self.args = args
        # Temperature of the NeuralSort relaxation used in 'test' mode
        # (deterministic_NeuralSort reads it).
        self.tau = 0.1
        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        # Not used by forward, but drawing them advances the global NumPy RNG.
        self.seller_w_init = np.random.normal(size=(num_max_units * num_func * num_agent)) / 5
        self.seller_b_init = -np.random.rand(num_max_units * num_func * num_agent) * 1.0

    def deterministic_NeuralSort(self, s):
        """Relaxed sorting matrix (NeuralSort form) for scores ``s``.

        ``s`` is (batch, n, 1).  Returns (batch, n, n) where row ``i`` is a
        softmax (temperature ``self.tau``) over ``(n + 1 - 2(i + 1)) * s_j -
        sum_k |s_j - s_k|``; row 0 peaks at the largest score.
        """
        n = s.size()[1]
        one = t.ones((n, 1), dtype=t.float32, device=device)

        A_s = t.abs(s - s.permute(0, 2, 1))
        B = t.matmul(A_s, t.matmul(one, t.transpose(one, 0, 1)))
        scaling = (n + 1 - 2 * (t.arange(n, device=device) + 1)).type(t.float32)
        C = t.matmul(s, scaling.unsqueeze(0))

        P_max = (C - B).permute(0, 2, 1)
        return F.softmax(P_max / self.tau, dim=-1)

    def t_reshape(self, x):
        """View a flat parameter vector as (num_func, num_max_units, num_agent)."""
        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        x = t.reshape(x, [num_agent, num_func, num_max_units])
        x = t.transpose(x, 0, 1)
        x = t.transpose(x, 1, 2)
        return x

    @staticmethod
    def _exclusion_matrix(num_agent):
        """Return W with W[i, j, k] = 1 iff j == k and k != i, else 0."""
        eye = np.identity(num_agent)
        return eye[np.newaxis, :, :] * (1.0 - eye)[:, np.newaxis, :]

    def forward(self, str, inputs, ctr_ads, ctr_og, alpha, w1, b1):
        """Run the auction on one batch and report its metrics.

        Args:
            str: allocation mode - 'train' (softmax of scores scaled by 1000),
                'test' (NeuralSort relaxation) or 'true' (hard argmax one-hot).
            inputs: agent values, array-like of shape (batch, num_agent).
            ctr_ads: ad click-through rates, shape (batch, num_agent).
            ctr_og: one organic click-through rate per sample, shape (batch,).
            alpha: weight on clicks in the objective.
            w1, b1: per-agent weight/bias tensors with ``num_agent`` elements.

        Returns:
            ``(revenue, a, vv, utility, cost, click, percent)``: batch-mean
            revenue (payments plus ``alpha`` times ad and organic clicks),
            allocation ``a`` (batch, num_agent), scores ``vv``, per-agent mean
            utility ``a * (x - p)`` (float32), batch-mean cost, batch-mean
            clicks (ads + organic) and the mean weight on the dummy option.

        Raises:
            ValueError: if ``str`` is not one of the three modes.
        """
        num_func = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent = self.args.num_agent
        x = t.tensor(inputs).to(device)
        batch_size = x.size()[0]
        ctr_ads = t.tensor(ctr_ads).to(device)
        ctr_og = t.tensor(ctr_og).to(device)
        # (num_agent, num_agent + 1): identity plus a zero column for the dummy option.
        append_dummy_mat = t.tensor(
            np.float32(np.append(np.identity(num_agent), np.zeros([num_agent, 1]), 1))).to(device)

        full_shape = [batch_size, num_func, num_max_units, num_agent]
        seller_w = t.reshape(w1.repeat(num_func, num_max_units), [num_func, num_max_units, num_agent])
        seller_b = t.reshape(b1.repeat(num_func, num_max_units), [num_func, num_max_units, num_agent])
        w_copy = t.reshape(seller_w.repeat([batch_size, 1, 1]), full_shape)
        b_copy = t.reshape(seller_b.repeat([batch_size, 1, 1]), full_shape)
        xx_copy = t.reshape(x.repeat([1, num_func * num_max_units]), full_shape)

        vv_max_units = t.max(t.mul(xx_copy, t.exp(w_copy)) + b_copy, 2).values

        # Broadcast each sample's organic CTR across that sample's agents.  A
        # flat ``reshape(ctr_og.repeat(num_agent), ...)`` would instead pair
        # agent j of sample b with ctr_og[(b * num_agent + j) % batch].
        og_per_agent = ctr_og.unsqueeze(1).expand_as(ctr_ads)
        click_offset = alpha * (ctr_ads - og_per_agent)
        vv = t.min(vv_max_units, 1).values * ctr_ads + click_offset
        scores = t.matmul(vv.to(t.float32), append_dummy_mat)

        if str == 'train':
            w_a = t.tensor(np.float32(np.identity(num_agent + 1) * 1000)).to(device)
            a_dummy = F.softmax(t.matmul(scores, w_a), dim=1)
        elif str == 'test':
            k = t.reshape(scores, [batch_size, num_agent + 1, 1])
            a_dummy = self.deterministic_NeuralSort(k)[:, 0]
        elif str == 'true':
            a_dummy = F.one_hot(t.argmax(scores, 1), num_agent + 1)
        else:
            raise ValueError(
                "unknown mode {!r}; expected 'train', 'test' or 'true'".format(str))

        a = a_dummy[:, 0:num_agent]

        # For agent i: max(0, highest competing score); the 0 is the agent's
        # own zeroed slot and matches the dummy option's score.
        w_p = t.tensor(self._exclusion_matrix(num_agent), dtype=t.float32).to(device)
        spa_tensor1 = t.reshape(t.reshape(vv, [-1]).repeat([num_agent]), [num_agent, -1, num_agent])
        spa_tensor2 = t.matmul(spa_tensor1.to(t.float32), w_p)
        p_spa = (t.transpose(t.max(spa_tensor2, 2).values, 0, 1) - click_offset) / ctr_ads

        # Invert the min-max affine score: price = max_f min_u (y - b) / exp(w).
        p_spa_copy = t.reshape(p_spa.repeat([1, num_func * num_max_units]), full_shape)
        p_max_units = t.min(t.mul(p_spa_copy - b_copy, t.reciprocal(t.exp(w_copy))), 2).values
        p = t.max(p_max_units, 1).values

        organic_clicks = t.mean(a_dummy[:, -1] * ctr_og)
        revenue = t.mean(t.sum(t.mul(a, p * ctr_ads + alpha * ctr_ads), 1))
        revenue = revenue + alpha * organic_clicks
        cost = t.mean(t.sum(t.mul(a, p * ctr_ads), 1))
        # Equals (revenue - cost) / alpha, computed directly so alpha -> 0 is safe.
        click = t.mean(t.sum(t.mul(a, ctr_ads), 1)) + organic_clicks
        utility = t.mean(a * (x - p), 0).to(t.float32)
        # Cast so 'true' mode (integer one-hot) can be averaged.
        percent = t.mean(a_dummy[:, -1].to(t.float32))

        return revenue, a, vv, utility, cost, click, percent

    def build_hyper_mlp_net_layer(self, inp_last_dim, units):
        """Return a ``CustomMLP(inp_last_dim, units)`` layer."""
        return CustomMLP(inp_last_dim, units)

class Learner:
    """Two Player Auction Learner.

    Trains two MLP hyper-networks (``w_net`` and ``b_net``).  They map a
    3-element descriptor ``[low, high, alpha]`` to the weights and biases fed
    into a ``Hyper_Myerson`` auction.  Their output width is hard-coded to 2,
    matching the 2-agent configurations used in this file.
    """

    def __init__(self, args):
        self.args = args
        self.auct_model = Hyper_Myerson(args)

        generator_vector = 3  # length of the [low, high, alpha] descriptor
        hidden = generator_vector * 10
        self.weight = 1.
        self.w_net = MLP([generator_vector, hidden, hidden, 2], t.tanh).to(device)
        self.b_net = MLP([generator_vector, hidden, hidden, 2], t.tanh).to(device)
        # One Adam optimiser plus a per-step exponential LR decay for each network.
        self.optimizers_auct = t.optim.Adam(self.w_net.parameters(), lr=4e-4, betas=(0.9, 0.999))
        self.lrupdate = t.optim.lr_scheduler.StepLR(self.optimizers_auct, 1, gamma=0.9999, last_epoch=-1)
        self.optimizers_auct2 = t.optim.Adam(self.b_net.parameters(), lr=4e-4, betas=(0.9, 0.999))
        self.lrupdate2 = t.optim.lr_scheduler.StepLR(self.optimizers_auct2, 1, gamma=0.9999, last_epoch=-1)

    def update_auct(self):
        """Take one optimiser step on both hyper-networks to maximise revenue.

        Each call draws a fresh training distribution:
        - values U[0, high) with high = 2 * U[0, 1) + 0.3;
        - alpha ~ U[0, 1);
        - ad CTRs U[0, 1) per agent;
        - an organic CTR U[0, 2) per sample.

        Sample sizes come from ``self.args``, not from a module-level
        configuration, so they always match this learner's model.
        """
        low = 0.
        high = 2 * np.random.uniform() + 0.3
        alpha = np.random.uniform()
        distribution = t.tensor([low, high, alpha]).to(device)
        generate_train = Generator(self.args)
        ctr_ads = generate_train.generate_uniform(0, 1)
        ctr_og = generate_train.generate_ctr(0, 2)
        train_data = generate_train.generate_uniform(low, high)

        outputs = self.auct_model('train', train_data, ctr_ads, ctr_og, alpha,
                                  self.w_net(distribution), self.b_net(distribution))
        loss = -t.mean(outputs[0])

        self.optimizers_auct.zero_grad()
        self.optimizers_auct2.zero_grad()
        loss.backward()
        self.optimizers_auct.step()
        self.optimizers_auct2.step()
        self.lrupdate.step()
        self.lrupdate2.step()

def cumulative_average(arr):
    """Return the running mean of ``arr``.

    Element ``k`` of the result is the mean of ``arr[:k + 1]``; when items are
    themselves arrays, that mean covers all of their elements.  It is computed
    with one cumulative sum, so it runs in O(n) instead of re-averaging every
    prefix (O(n**2)).  That matters because the training loop calls this on
    histories that grow by one entry per rollout.

    Args:
        arr: sequence of numbers (Python scalars, NumPy scalars or 0-d arrays),
            or a sequence of equal-length rows.

    Returns:
        A float64 ndarray of length ``len(arr)`` (empty for empty input).
    """
    values = np.asarray(arr, dtype=np.float64)
    n = len(values)
    if n == 0:
        return np.zeros(0)
    rows = values.reshape(n, -1)
    running_totals = np.cumsum(rows.sum(axis=1))
    element_counts = np.arange(1, n + 1, dtype=np.float64) * rows.shape[1]
    return running_totals / element_counts

def average_by_tens(x):
    """Average consecutive, non-overlapping groups of 10 entries.

    A trailing group with fewer than 10 entries is dropped.  Returns a list
    containing one ``np.mean`` value per full group.
    """
    values = np.array(x)
    full_groups = len(values) // 10
    starts = range(0, full_groups * 10, 10)
    return [np.mean(values[start:start + 10]) for start in starts]

def train_linear(args, args2, net1, net2, net3, net4, rollouts):
    """Train the hyper-network auction and four Line_Net baselines side by side.

    For each of ``rollouts`` iterations ``i``:

    * ``i < 100`` (warm-up):
      - one ``learner.update_auct()`` step;
      - evaluate the hyper-network auction on values U[0, 1) with ``alpha2``
        and print its revenue;
      - one ``seller_backward`` step (alpha 0.3) for each baseline ``net{l+1}``
        on values U[0, 0.5 + 0.5 * l).
    * otherwise:
      - ten ``update_auct`` steps;
      - for each of four value distributions U[0, high), with
        high = 0.5 + 0.5 * l (plus 0.2 for l = 0), evaluate the hyper-network
        auction with that distribution's current alpha and record
        utility/revenue/cost/click/organic percentage;
      - multiply that alpha by ``exp(PID.update(percentage))``;
      - give each baseline one ``seller_backward`` step and, with probability
        0.5, evaluate it.

    The index is printed every 2 iterations.  After iteration 100, the recent
    percentages and alphas are plotted every 20 iterations.  Every 200
    iterations a 2x2 figure of running-average cost/click and the percentage
    is shown.  All plots call ``plt.show()``, which may block depending on
    the matplotlib backend.

    Args:
        args: configuration for the ``Learner`` and ``generate_train``.
        args2: configuration for ``generate_train2`` (baseline training data).
        net1, net2, net3, net4: ``Line_Net`` baselines, one per distribution.
        rollouts: number of outer iterations.
    """


    #generate_train2 = Generator2(args)
    #train_data2 = generate_train2.generate_uniform(0,1)
    
    # First set of metric lists.  The *12/*22/*32/*42 lists record baselines
    # net1..net4.  losspr2, losscost2, lossclick2, lossmy, lossfp, losssp,
    # lossbid1 and lossbid2 are re-initialised to the same values further down.
    losspr2 = [0]
    losscost2 = [0]
    lossclick2 = [0]
    lossmy = []
    lossfp = []
    losssp = []  
    lossbid1 = []
    lossbid2 = []
    losspr12 = [0]
    losscost12 = [0]
    lossclick12 = [0]
    losspr22 = [0]
    losscost22 = [0]
    lossclick22 = [0]
    losspr32 = [0]
    losscost32 = [0]
    lossclick32 = [0]
    losspr42 = [0]
    losscost42 = [0]
    lossclick42 = [0]
    key = 'train'
    
    # generate_train samples data for the hyper-network auction (args);
    # generate_train2 samples baseline training data (args2).
    generate_train = Generator(args)
    generate_train2 = Generator(args2)
    
    # Metrics for the hyper-network auction, one list per distribution
    # (suffix 1-4).  losspr and the lossmy*/lossfp*/losssp* lists are never
    # used below.
    losspr = []
    losspr1 = [0]
    losspr2 = [0]
    losspr3 = [0]
    losspr4 = [0]
    lossmy = []
    lossfp = []
    losssp = []  
    lossbid1 = []
    lossbid2 = []
    lossbid3 = []
    lossbid4 = []
    lossmy1 = []
    lossfp1 = []
    losssp1 = [] 
    lossmy2 = []
    lossfp2 = []
    losssp2 = []
    lossmy3 = []
    lossfp3 = []
    losssp3 = [] 
    lossmy4 = []
    lossfp4 = []
    losssp4 = [] 
    losscost1 = [0]
    losscost2 = [0]
    losscost3 = [0]
    losscost4 = [0]
    lossclick1 = [0]
    lossclick2 = [0]
    lossclick3 = [0]
    lossclick4 = [0]
    perc1 = [0]
    perc2 = [0]
    perc3 = [0]
    perc4 = [0]
    # History of the alpha used for each distribution.
    alpha11 = []
    alpha22 = []
    alpha33 = []
    alpha44 = []
    key = 'train'
    #Myerson_auction = OptRevOneItem(args,train_data)
    
    learner = Learner(args)
    alpha1 = 0.5
    alpha2 = 0.5 
    alpha3 = 0.5
    alpha4 = 0.5
    # One controller per distribution, all with target 0.5.  Each alpha is
    # rescaled by exp(controller output) computed from the measured percentage.
    PID1 = PIDController(0.05, 0.01, 0.1, 0.5)
    PID2 = PIDController(0.01, 0.001, 0.1, 0.5)
    PID3 = PIDController(0.01, 0.001, 0.05, 0.5)
    PID4 = PIDController(0.01, 0.001, 0.05, 0.5)

    
    for i in range(rollouts):
    
            
        if i < 100 :
            # Warm-up: one hyper-network step, a revenue print, one step per baseline.
            learner.update_auct()
            low = 0.
            high = 0.5 + 0.5 * 1
            distribution = t.tensor([low, high, 0.5]).float().to(device)     
            train_data = generate_train.generate_uniform(low, high)
            ctr_ads = generate_train.generate_uniform(0, 1)
            ctr_og = generate_train.generate_ctr(0, 2)
            revenue, _, _, utility, cost, click, perc = learner.auct_model('train', train_data, ctr_ads, ctr_og, alpha2, learner.w_net(distribution), learner.b_net(distribution))
            ut = utility.cpu().detach().numpy()
            rev = revenue.cpu().detach().numpy()
            print(rev)

            for l in range(4):
                
                key = 'train'
                train_data2 = generate_train2.generate_uniform(0,0.5+0.5*l)
                ctr = generate_train2.generate_uniform(0,1)
                if l == 0:
                    net1.seller_backward(args,train_data2,ctr,0.3,key)
                if l == 1:
                    net2.seller_backward(args,train_data2,ctr,0.3,key)
                if l == 2:
                    net3.seller_backward(args,train_data2,ctr,0.3,key)
                if l == 3:
                    net4.seller_backward(args,train_data2,ctr,0.3,key)  

            
        else : 
            for tt in range(10):
                learner.update_auct()
        
            # Evaluate the hyper-network auction on each distribution and
            # update that distribution's alpha with its PID controller.
            for l in range(4):  
                low = 0.
                high = 0.5 + 0.5 * l
                if l == 0:
                    high = high + 0.2
                if l == 0:
                    Alpha = alpha1
                    alpha11.append(alpha1)
                if l == 1:
                    Alpha = alpha2
                    alpha22.append(alpha2)
                if l == 2:
                    Alpha = alpha3
                    alpha33.append(alpha3)
                if l == 3:
                    Alpha = alpha4
                    alpha44.append(alpha4)
                distribution = t.tensor([low, high, Alpha]).float().to(device)                    
                train_data = generate_train.generate_uniform(low, high)
                ctr_ads = generate_train.generate_uniform(0, 1)
                ctr_og = generate_train.generate_ctr(0, 2)
                revenue, _, _, utility, cost, click, perc = learner.auct_model('train', train_data, ctr_ads, ctr_og, Alpha, learner.w_net(distribution), learner.b_net(distribution))
                ut = utility.cpu().detach().numpy()
                rev = revenue.cpu().detach().numpy()
                cost = cost.cpu().detach().numpy()
                click = click.cpu().detach().numpy()
                perc = perc.cpu().detach().numpy()
                if l == 0:
                    lossbid1.append(np.mean(ut))               
                    losspr1.append(rev)
                    losscost1.append(cost)
                    lossclick1.append(click)
                    perc1.append(perc)
                    alpha1 = alpha1 * np.exp(PID1.update(perc))
                if l == 1:
                    lossbid2.append(np.mean(ut))               
                    losspr2.append(rev)  
                    losscost2.append(cost)
                    lossclick2.append(click)
                    perc2.append(perc)
                    alpha2 = alpha2 * np.exp(PID2.update(perc)) 
                if l == 2:
                    lossbid3.append(np.mean(ut))               
                    losspr3.append(rev) 
                    losscost3.append(cost)
                    lossclick3.append(click)        
                    perc3.append(perc)
                    alpha3 = alpha3 * np.exp(PID3.update(perc))
                if l == 3:
                    lossbid4.append(np.mean(ut))               
                    losspr4.append(rev) 
                    losscost4.append(cost)
                    lossclick4.append(click)
                    perc4.append(perc)
                    alpha4 = alpha4 * np.exp(PID4.update(perc)) 
            #net.bidder_w.data.sub_(grad_w)
            
            # Baselines: one training step each, then (with probability 0.5)
            # an evaluation.
            # NOTE(review): when the evaluation is skipped, the fixed values
            # 0.5 / 0 / 1. are appended instead, and they enter the running
            # averages plotted below.  Confirm this is intended.
            for l in range(4):      
                train_data2 = generate_train2.generate_uniform(0,0.5+0.5*l)
                ctr = generate_train2.generate_uniform(0,1)
                if l == 0:
                    net1.seller_backward(args,train_data2,ctr,0.3,key)
                    if np.random.uniform() > 0.5:
                        revenue, _, _, utility, cost, click = net1(train_data2,ctr,0.3)
                        ut = utility.cpu().detach().numpy()
                        losspr12.append(revenue.cpu().detach().numpy())
                        losscost12.append(cost.cpu().detach().numpy())
                        lossclick12.append(click.cpu().detach().numpy())
                    else:
                        losspr12.append(0.5)
                        losscost12.append(0)
                        lossclick12.append(1.)                     
                if l == 1:
                    net2.seller_backward(args,train_data2,ctr,0.3,key)
                    if np.random.uniform() > 0.5:
                        revenue, _, _, utility, cost, click = net2(train_data2,ctr,0.3)
                        ut = utility.cpu().detach().numpy()
                        losspr22.append(revenue.cpu().detach().numpy())
                        losscost22.append(cost.cpu().detach().numpy())
                        lossclick22.append(click.cpu().detach().numpy())
                    else:
                        losspr22.append(0.5)
                        losscost22.append(0)
                        lossclick22.append(1.) 
                if l == 2:
                    net3.seller_backward(args,train_data2,ctr,0.3,key)
                    if np.random.uniform() > 0.5:
                        revenue, _, _, utility, cost, click = net3(train_data2,ctr,0.3)
                        ut = utility.cpu().detach().numpy()
                        losspr32.append(revenue.cpu().detach().numpy())
                        losscost32.append(cost.cpu().detach().numpy())
                        lossclick32.append(click.cpu().detach().numpy())
                    else:
                        losspr32.append(0.5)
                        losscost32.append(0)
                        lossclick32.append(1.) 
                if l == 3:
                    net4.seller_backward(args,train_data2,ctr,0.3,key)      
                    if np.random.uniform() > 0.5:
                        revenue, _, _, utility, cost, click = net4(train_data2,ctr,0.3)
                        ut = utility.cpu().detach().numpy()
                        losspr42.append(revenue.cpu().detach().numpy())
                        losscost42.append(cost.cpu().detach().numpy())
                        lossclick42.append(click.cpu().detach().numpy())
                    else:
                        losspr42.append(0.5)
                        losscost42.append(0)
                        lossclick42.append(1.) 
           
        
        if i%2 == 0 :
            print('i=',i)
        # Last 20 organic percentages and alphas per distribution.
        if i%20 == 0 and i>100 :    

            plt.plot(perc1[-20:], label='percentage')
            plt.plot(perc2[-20:], label='percentage')
            plt.plot(perc3[-20:], label='percentage')
            plt.plot(perc4[-20:], label='percentage')
            plt.show()

            plt.plot(alpha11[-20:],label='1')
            plt.plot(alpha22[-20:],label='2')
            plt.plot(alpha33[-20:],label='3')
            plt.plot(alpha44[-20:],label='4')
            plt.show()
        # Summary figure: running-average cost and clicks (baseline vs
        # hyper-network) plus the organic percentage, grouped in tens.
        if i%200 == 0 :
            fig, axs = plt.subplots(2, 2, dpi=600, figsize=(10, 6))

            axs[0, 0].plot(average_by_tens(cumulative_average(losscost12)), 'orange', marker='o', markevery=10, label='cost-RegretNet')
            axs[0, 0].plot(average_by_tens(cumulative_average(losscost1)), 'red', marker='s', markevery=10, label='cost-AMMD')
            axs[0, 0].plot(average_by_tens(cumulative_average(lossclick12)), 'cyan', marker='v', markevery=10, label='click-RegretNet')
            axs[0, 0].plot(average_by_tens(cumulative_average(lossclick1)), 'blue', marker='x', markevery=10, label='click-AMMD')
            axs[0, 0].plot(average_by_tens(perc1), 'green', marker='p', markevery=10, label='orgs percentage for AMMD')
            #axs[0, 0].legend()
            axs[0, 0].set_xlabel('rollouts',size=13)
            axs[0, 0].set_ylabel('score',size=13)
            axs[0, 0].set_title('Distribution 1',size=13)

            axs[0, 1].plot(average_by_tens(cumulative_average(losscost22)), 'orange', marker='o', markevery=10, label='cost-RegretNet')
            axs[0, 1].plot(average_by_tens(cumulative_average(losscost2)), 'red', marker='s', markevery=10, label='cost-AMMD')
            axs[0, 1].plot(average_by_tens(cumulative_average(lossclick22)), 'cyan', marker='v', markevery=10, label='click-RegretNet')
            axs[0, 1].plot(average_by_tens(cumulative_average(lossclick2)), 'blue', marker='x', markevery=10, label='click-AMMD')
            axs[0, 1].plot(average_by_tens(perc2), 'green', markevery=10, marker='p', label='orgs percentage for AMMD')
            #axs[0, 1].legend()
            axs[0, 1].set_xlabel('rollouts',size=13)
            axs[0, 1].set_ylabel('score',size=13)
            axs[0, 1].set_title('Distribution 2',size=13)

            axs[1, 0].plot(average_by_tens(cumulative_average(losscost32)), 'orange', marker='o', markevery=10, label='cost-RegretNet')
            axs[1, 0].plot(average_by_tens(cumulative_average(losscost3)), 'red', marker='s', markevery=10, label='cost-AMMD')
            axs[1, 0].plot(average_by_tens(cumulative_average(lossclick32)), 'cyan', marker='v', markevery=10, label='click-RegretNet')
            axs[1, 0].plot(average_by_tens(cumulative_average(lossclick3)), 'blue', marker='x', markevery=10, label='click-AMMD')
            axs[1, 0].plot(average_by_tens(perc3), 'green', marker='p', markevery=10, label='orgs percentage for AMMD')
            axs[1, 0].set_xlabel('rollouts',size=13)
            axs[1, 0].set_ylabel('score',size=13)
            axs[1, 0].set_title('Distribution 3',size=13)

            axs[1, 1].plot(average_by_tens(cumulative_average(losscost42)), 'orange', marker='o', markevery=10, label='cost-RegretNet')
            axs[1, 1].plot(average_by_tens(cumulative_average(losscost4)), 'red', marker='s', markevery=10, label='cost-AMMD')
            axs[1, 1].plot(average_by_tens(cumulative_average(lossclick42)), 'cyan', marker='v', markevery=10, label='click-RegretNet')
            axs[1, 1].plot(average_by_tens(cumulative_average(lossclick4)), 'blue', marker='x', markevery=10, label='click-AMMD')
            axs[1, 1].plot(average_by_tens(perc4), 'green', marker='p', markevery=10, label='orgs percentage for AMMD')
            #axs[1, 1].legend()
            axs[1, 1].set_xlabel('rollouts',size=13)
            axs[1, 1].set_ylabel('score',size=13)
            axs[1, 1].set_title('Distribution 4',size=13)
            
            # One shared legend built from the last subplot's handles.
            lines, labels = fig.axes[-1].get_legend_handles_labels()

            fig.legend( lines, labels, bbox_to_anchor=(0.97, 0.),          
                ncol=5, framealpha=1, prop = {'size':10})

            plt.tight_layout()

            plt.show()   





if __name__ == "__main__":
    # Script entry point.  Both configurations use 2 agents and 1 item.
    # ``args`` (10 linear functions x 10 max units, 20000 samples) builds the
    # Line_Net baselines and is train_linear's first config.  ``args2``
    # (10000 samples) is its second config.  Assigning ``args`` here rebinds
    # the module-level default.
    args2 = Args((2, 1, "uniform", 10, 10, 10000, 10000, 1))
    args = Args((2, 1, "uniform", 10, 10, 20000, 20000, 1))

    # Fixed U[0, 1) datasets stored on every Line_Net.
    generate_train2, generate_test2 = Generator2(args), Generator2(args)
    train_data = generate_train2.generate_uniform(0, 1)
    test_data = generate_test2.generate_uniform(0, 1)

    # Four independently initialised baselines, constructed one after another.
    net1, net2, net3, net4 = [Line_Net(args, train_data, test_data) for _ in range(4)]

    generate_train, generate_test = Generator(args), Generator(args)
    train_linear(args, args2, net1, net2, net3, net4, 200001)