from __future__ import division
from __future__ import print_function

import os
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from random import random
from datetime import datetime



# Target device for the `device=` / `.to(device)` calls throughout this module;
# hard-coded to CPU.
device = t.device("cpu")

class PIDController:
    """Discrete PID-style controller that steers a measured value toward ``target``.

    The "integral" term fed into the output is the mean of the most recent
    ``window`` errors (a moving average whose buffer is initially filled with
    0.5), not the unbounded running sum. The running sum is still accumulated
    in ``self.integral`` but is not used when computing the output.
    """

    def __init__(self, Kp, Ki, Kd, target, window=24):
        """
        Args:
            Kp: proportional gain.
            Ki: gain applied to the moving average of recent errors.
            Kd: derivative gain (applied to the one-step error difference).
            target: set-point the controller steers toward.
            window: number of recent errors averaged for the integral term;
                defaults to 24, the previously hard-coded size.

        Raises:
            ValueError: if ``window`` is smaller than 1.
        """
        if window < 1:
            raise ValueError("window must be at least 1")
        self.Kp = Kp
        self.Ki = Ki
        self.Kd = Kd
        self.target = target
        # Sliding buffer of recent errors, seeded with 0.5.
        self.cumulate = np.ones(window) * 0.5
        self.integral = 0
        self.prev_error = 0

    def update(self, current_value):
        """Record ``current_value`` and return the controller output."""
        error = self.target - current_value
        self.integral = self.integral + error
        # Drop the oldest error and append the newest one at the end.
        self.cumulate = np.roll(self.cumulate, -1)
        self.cumulate[-1] = error
        derivative = error - self.prev_error
        output = self.Kp * error + self.Ki * np.mean(self.cumulate) + self.Kd * derivative
        self.prev_error = error
        return output

class Args():
    """Experiment configuration unpacked from a positional sequence.

    ``args`` must be indexable with at least eight entries, read in this order:
    num_agent, num_item, distribution_type, num_linear, num_max,
    num_sample_train, num_sample_test, seed_val. Extra entries are ignored.
    """

    def __init__(self, args):
        field_names = (
            "num_agent",
            "num_item",
            "distribution_type",
            "num_linear",
            "num_max",
            "num_sample_train",
            "num_sample_test",
            "seed_val",
        )
        # Assign in order so a short sequence fails at the same field as before.
        for idx, name in enumerate(field_names):
            setattr(self, name, args[idx])
        
# Module-level default configuration, positional per Args.__init__:
# num_agent=4, num_item=1, distribution_type="uniform", num_linear=20,
# num_max=20, num_sample_train=10000, num_sample_test=10000, seed_val=1.
args = Args((4,1,"uniform",20,20,10000,10000,1))

class Generator():
    """Draw synthetic samples for ``args.num_sample_train`` auctions.

    Only ``args.num_sample_train`` and ``args.num_agent`` are read. All draws
    use NumPy's global random state.
    """

    def __init__(self, args):
        self.args = args

    def generate_uniform(self, low, high):
        """Return a (num_sample_train, num_agent) array drawn i.i.d. from U(low, high)."""
        num_instances = self.args.num_sample_train
        num_agent = self.args.num_agent
        sample_val = np.random.uniform(low,high,[num_instances,num_agent])
        return sample_val

    def generate_ctr(self,low,high):
        """Return a (num_sample_train,) array drawn i.i.d. from U(low, high)."""
        num_instances = self.args.num_sample_train
        sample_val = np.random.uniform(low,high,[num_instances])
        return sample_val

    def generate_uniform2(self, low, high):
        """Like ``generate_uniform``, but column 0 is redrawn from U(0, 1)."""
        num_instances = self.args.num_sample_train
        num_agent = self.args.num_agent
        sample_val = np.random.uniform(low,high,[num_instances,num_agent])
        sample_val[:,0] = np.random.uniform(0,1,[num_instances])
        return sample_val

    def generate_asymmetry(self):
        """Return (num_sample_train, num_agent) samples: U(0, 1) rows, then U(0, 2) rows.

        The last ``num_sample_train // 2`` rows are drawn from U(0, 2) and the
        rest from U(0, 1). For an odd count, the extra row is in the U(0, 1) part.
        """
        num_instances = self.args.num_sample_train
        num_agent = self.args.num_agent
        sample_val = np.random.uniform(0,1,[num_instances,num_agent])
        # Integer half-size: the previous float slice (-num_instances/2) and
        # negative draw size raised on Python 3, and its ':-1' end skipped the
        # final row.
        half = num_instances // 2
        sample_val[num_instances - half:, :] = np.random.uniform(0, 2, [half, num_agent])
        return sample_val
    

def get_position_ctr(position,choice):
    """Map slot indices to their position click-through multipliers.

    Slot 0 -> 1.0, slot 1 -> 0.99, slot 2 -> 0.6, slot 3 -> 0.5; any other
    value maps to 0.0. Comparisons are by value, so float-typed indices
    (0., 1., ...) work.

    Args:
        position: array of slot indices (1-D when ``choice == 1``, 2-D when
            ``choice == 2``; the lookup itself is element-wise for any shape).
        choice: 1 or 2 to perform the lookup; any other value returns all zeros.

    Returns:
        float64 array with the same shape as ``position``.
    """
    position = np.asarray(position)
    ctr = np.zeros(position.shape)
    if choice not in (1, 2):
        return ctr
    # Vectorised lookup replaces the original per-element Python loops.
    for slot, rate in enumerate((1., 0.99, 0.6, 0.5)):
        ctr[position == slot] = rate
    return ctr


def get_pos(position,choice):
    """Return the click multiplier of one slot index (0-3), or None otherwise.

    ``choice`` is accepted for signature parity with ``get_position_ctr`` and
    is not used. Matching uses ``==``, so floats and 0-d tensors also work.
    """
    for slot, rate in ((0, 1.), (1, 0.99), (2, 0.6), (3, 0.5)):
        if position == slot:
            return rate
    return None

def torch_get_position_ctr(position,choice):
    """Torch counterpart of the slot-index -> click-multiplier lookup.

    Slot 0 -> 1.0, slot 1 -> 0.99, slot 2 -> 0.6, slot 3 -> 0.5; any other
    value maps to 0.0. Comparisons are by value, so float-typed indices work.

    Args:
        position: tensor (or array) of slot indices; 1-D when ``choice == 1``,
            2-D when ``choice == 2`` (the lookup is element-wise for any shape).
        choice: 1 or 2 to perform the lookup; any other value returns zeros.

    Returns:
        float32 tensor on ``device`` with the shape of ``position``. Its values
        are constants, so no gradient flows back to ``position``.
    """
    ctr = t.zeros(position.shape).float().to(device)
    if choice not in (1, 2):
        return ctr
    # Vectorised masks replace the per-element Python loops (one Python-level
    # comparison per entry on every call).
    for slot, rate in enumerate((1., 0.99, 0.6, 0.5)):
        mask = t.as_tensor(position == slot, device=ctr.device)
        ctr[mask] = rate
    return ctr


def get_sort(array):
    """For each row, return the column index holding the value 0.

    Applied to argsort output, this gives the rank of item 0 in each row. If a
    row contains several zeros, the last one wins; a row with no zero yields 0.

    Returns:
        float64 NumPy array of length ``array.shape[0]``.
    """
    result = np.zeros(array.shape[0])
    n_cols = array.shape[1]
    for row_idx in range(array.shape[0]):
        # Scan from the right so the first hit is the last matching column.
        for col in reversed(range(n_cols)):
            if array[row_idx, col] == 0:
                result[row_idx] = col
                break
    return result

def alpha_VCG_Mechanism(value_ads,ctr_ads,ctr_org,alpha):
    """Simulate the alpha-VCG ad/organic mixing auction; return (click, cost).

    Ads are scored ``(value + alpha) * ctr_ads`` and organic items
    ``alpha * ctr_org``. The top four entries of the merged ranking fill four
    slots: two for ads and two for organic items. Once one type holds two
    slots, the remaining slots go to the other type. Each of the two
    best-scoring ads pays a VCG-style per-click price, derived from the
    externality it imposes on lower-ranked ads.

    Args:
        value_ads: per-click ad values, indexed as 2-D (samples x ads). Needs at
            least 3 ad columns because ``Value_ads_sorted[:,2]`` is read.
        ctr_ads: ad CTRs; assumed same shape as ``value_ads``.
        ctr_org: organic-item CTRs; assumed same number of rows as ``value_ads``.
        alpha: weight on clicks in the ranking score.

    Returns:
        click: mean over samples of expected clicks on the four slots.
        cost: mean over samples of the expected ad payments, with negative
            per-click prices clipped to 0.
    """
    Value_ads = value_ads * ctr_ads + alpha * ctr_ads
    Value_org = alpha * ctr_org
    Value = np.hstack((Value_ads,Value_org))
    # Merged ranking (descending) over all ads and organic items.
    Value_sorted_indices = np.argsort(-Value, axis=1)
    Value_sorted = np.take_along_axis(Value, Value_sorted_indices, axis=1)
    # Separate descending rankings of ads and of organic items.
    Value_ads_sorted_indices = np.argsort(-Value_ads, axis=1)
    Value_ads_sorted = np.take_along_axis(Value_ads, Value_ads_sorted_indices, axis=1)
    Value_org_sorted_indices = np.argsort(-Value_org, axis=1)
    Value_org_sorted = np.take_along_axis(Value_org, Value_org_sorted_indices, axis=1)  # not used below
    # Ads_que[i, k] / Org_que[i, k]: slot (0-3) given to the k-th best ad /
    # organic item of sample i. A value of 0 in column -1 means "second slot not
    # yet filled"; this is safe because a second slot can never be slot 0.
    Ads_que = np.zeros([value_ads.shape[0],2])
    Org_que = np.zeros([value_ads.shape[0],2])
    #Pos_que = np.zeros([value_ads.shape[0],2])
    for i in range(Value.shape[0]):
        ads_po = 0
        org_po = 0
        for j in range(4):
            # Classify slot j's winner by value. NOTE(review): an exact tie
            # between an ad score and an organic score is treated as an ad.
            if Value_sorted[i,j] in Value_ads[i]:
                if Ads_que[i,-1] == 0:
                    Ads_que[i,ads_po] = j
                    ads_po = ads_po + 1
                else:
                    # Ads already hold two slots: give this slot to organic.
                    Org_que[i,org_po] = j
                    org_po = org_po + 1
            else:
                if Org_que[i,-1] == 0:
                    Org_que[i,org_po] = j
                    org_po = org_po + 1
                else:
                    # Organic items already hold two slots: give it to an ad.
                    Ads_que[i,ads_po] = j
                    ads_po = ads_po + 1

    Ctr_sorted = np.take_along_axis(ctr_ads, Value_ads_sorted_indices, axis=1)
    Ctr_sorted2 = np.take_along_axis(ctr_org, Value_org_sorted_indices, axis=1)
    # Expected clicks: CTR of the top-2 ads and top-2 organic items times the
    # multiplier of the slot each occupies.
    click = np.mean(np.sum(Ctr_sorted[:,0:2] * get_position_ctr(Ads_que[:,0:2],2), axis=1) + np.sum(Ctr_sorted2[:,0:2] * get_position_ctr(Org_que[:,0:2],2), axis=1))

    # Per-click prices of the two winning ads. Value_ads_sorted[:,1] / [:,2] are
    # the 2nd / 3rd best ad scores. Each price is the externality on lower ads
    # (each would move up one ad slot), minus the winner's alpha * ctr bonus,
    # divided by the winner's ctr * slot multiplier.
    Payment = np.zeros([value_ads.shape[0],2])
    Payment[:,0] = ((Value_ads_sorted[:,1] * get_position_ctr(Ads_que[:,0],1) +  Value_ads_sorted[:,2] * get_position_ctr(Ads_que[:,1],1) -  Value_ads_sorted[:,1] * get_position_ctr(Ads_que[:,1],1)) - alpha * Ctr_sorted[:,0] * get_position_ctr(Ads_que[:,0],1)) / (Ctr_sorted[:,0] * get_position_ctr(Ads_que[:,0],1))
    Payment[:,1] = ((Value_ads_sorted[:,2] * get_position_ctr(Ads_que[:,1],1)) - alpha * Ctr_sorted[:,1] * get_position_ctr(Ads_que[:,1],1)) / (Ctr_sorted[:,1] * get_position_ctr(Ads_que[:,1],1))

    #cost = np.mean(np.sum(Payment * Ctr_sorted[:,0:2] * get_position_ctr(Ads_que[:,0:2],2), axis=1))
    # Expected payment, with negative per-click prices clipped to 0.
    cost = np.mean(np.sum(np.maximum(Payment,0) * Ctr_sorted[:,0:2] * get_position_ctr(Ads_que[:,0:2],2), axis=1))
    return click, cost


def alpha_VCG_Mechanism_IC(value_ads,ctr_ads,ctr_org,alpha,bid):
    """Mean utility of ad 0 under alpha-VCG when it bids ``bid`` times its value.

    Column 0 of ``value_ads`` is multiplied by ``bid`` before ranking; all other
    inputs are used as-is. Allocation and per-click prices follow
    ``alpha_VCG_Mechanism``. In each sample where ad 0 ranks among the top two
    ads (and therefore gets a slot), its utility is
    ``(true value - per-click price) * ctr * multiplier of its slot``.
    Otherwise its utility is 0.

    Args:
        value_ads: true per-click ad values, 2-D (samples x ads), at least 3 ads.
        ctr_ads: ad CTRs; assumed same shape as ``value_ads``.
        ctr_org: organic-item CTRs; assumed same number of rows.
        alpha: weight on clicks in the ranking score.
        bid: multiplicative misreport applied to ad 0's value.

    Returns:
        Mean utility of ad 0 over all samples.
    """
    value_ads2 = value_ads.copy()
    value_ads2[:,0] = value_ads2[:,0] * bid
    Value_ads = value_ads2 * ctr_ads + alpha * ctr_ads
    Value_org = alpha * ctr_org
    Value = np.hstack((Value_ads,Value_org))
    Value_sorted_indices = np.argsort(-Value, axis=1)
    Value_sorted = np.take_along_axis(Value, Value_sorted_indices, axis=1)
    Value_ads_sorted_indices = np.argsort(-Value_ads, axis=1)
    Value_ads_sorted = np.take_along_axis(Value_ads, Value_ads_sorted_indices, axis=1)
    # Rank of ad 0 among the ads in each sample (0 = best).
    first_col_indices = get_sort(Value_ads_sorted_indices)
    # Ads_que[i, k] / Org_que[i, k]: slot (0-3) of the k-th best ad / organic
    # item; 0 in column -1 means "second slot not yet filled".
    Ads_que = np.zeros([value_ads.shape[0],2])
    Org_que = np.zeros([value_ads.shape[0],2])
    for i in range(Value.shape[0]):
        ads_po = 0
        org_po = 0
        for j in range(4):
            if Value_sorted[i,j] in Value_ads[i]:
                if Ads_que[i,-1] == 0:
                    Ads_que[i,ads_po] = j
                    ads_po = ads_po + 1
                else:
                    Org_que[i,org_po] = j
                    org_po = org_po + 1
            else:
                if Org_que[i,-1] == 0:
                    Org_que[i,org_po] = j
                    org_po = org_po + 1
                else:
                    Ads_que[i,ads_po] = j
                    ads_po = ads_po + 1

    Ctr_sorted = np.take_along_axis(ctr_ads, Value_ads_sorted_indices, axis=1)

    # Per-click prices of the two winning ads (same formula as alpha_VCG_Mechanism).
    Payment = np.zeros([value_ads.shape[0],2])
    Payment[:,0] = ((Value_ads_sorted[:,1] * get_position_ctr(Ads_que[:,0],1) +  Value_ads_sorted[:,2] * get_position_ctr(Ads_que[:,1],1) -  Value_ads_sorted[:,1] * get_position_ctr(Ads_que[:,1],1)) - alpha * Ctr_sorted[:,0] * get_position_ctr(Ads_que[:,0],1)) / (Ctr_sorted[:,0] * get_position_ctr(Ads_que[:,0],1))
    Payment[:,1] = ((Value_ads_sorted[:,2] * get_position_ctr(Ads_que[:,1],1)) - alpha * Ctr_sorted[:,1] * get_position_ctr(Ads_que[:,1],1)) / (Ctr_sorted[:,1] * get_position_ctr(Ads_que[:,1],1))

    first_pay = np.zeros(value_ads.shape[0])
    for kk in range(value_ads.shape[0]):
        rank = int(first_col_indices[kk])
        if rank < 2:
            # Scale by the slot ad 0 actually occupies (Ads_que), not by its rank
            # among ads; the two differ whenever organic items sit above it.
            slot_ctr = ctr_ads[kk,0] * get_pos(Ads_que[kk,rank],1)
            first_pay[kk] = value_ads[kk,0] * slot_ctr - Payment[kk,rank] * slot_ctr
    return np.mean(first_pay)


def alpha_GSP_Mechanism(value_ads,ctr_ads,ctr_org,alpha):
    """Simulate the alpha-GSP ad/organic mixing auction; return (click, cost).

    Ranking and slot allocation match ``alpha_VCG_Mechanism``: ads are scored
    ``(value + alpha) * ctr_ads``, organic items ``alpha * ctr_org``, and the
    top four entries of the merged ranking fill two ad slots and two organic
    slots. Pricing is generalized-second-price: each winning ad pays, per
    click, the next-ranked ad's score minus its own alpha * ctr bonus, divided
    by its own ctr. Slot multipliers do not enter the price.

    Args:
        value_ads: per-click ad values, 2-D (samples x ads), at least 3 ads.
        ctr_ads: ad CTRs; assumed same shape as ``value_ads``.
        ctr_org: organic-item CTRs; assumed same number of rows.
        alpha: weight on clicks in the ranking score.

    Returns:
        click: mean over samples of expected clicks on the four slots.
        cost: mean over samples of the expected ad payments, with negative
            per-click prices clipped to 0.
    """
    Value_ads = value_ads * ctr_ads + alpha * ctr_ads
    Value_org = alpha * ctr_org
    Value = np.hstack((Value_ads,Value_org))
    # Merged ranking (descending) over all ads and organic items.
    Value_sorted_indices = np.argsort(-Value, axis=1)
    Value_sorted = np.take_along_axis(Value, Value_sorted_indices, axis=1)
    Value_ads_sorted_indices = np.argsort(-Value_ads, axis=1)
    Value_ads_sorted = np.take_along_axis(Value_ads, Value_ads_sorted_indices, axis=1)
    Value_org_sorted_indices = np.argsort(-Value_org, axis=1)
    Value_org_sorted = np.take_along_axis(Value_org, Value_org_sorted_indices, axis=1)  # not used below
    # Ads_que[i, k] / Org_que[i, k]: slot (0-3) of the k-th best ad / organic
    # item of sample i; 0 in column -1 means "second slot not yet filled".
    Ads_que = np.zeros([value_ads.shape[0],2])
    Org_que = np.zeros([value_ads.shape[0],2])
    #Pos_que = np.zeros([value_ads.shape[0],2])
    for i in range(Value.shape[0]):
        ads_po = 0
        org_po = 0
        for j in range(4):
            # NOTE(review): membership is by value, so an exact ad/organic score
            # tie is treated as an ad.
            if Value_sorted[i,j] in Value_ads[i]:
                if Ads_que[i,-1] == 0:
                    Ads_que[i,ads_po] = j
                    ads_po = ads_po + 1
                else:
                    # Ads already hold two slots: give this slot to organic.
                    Org_que[i,org_po] = j
                    org_po = org_po + 1
            else:
                if Org_que[i,-1] == 0:
                    Org_que[i,org_po] = j
                    org_po = org_po + 1
                else:
                    # Organic items already hold two slots: give it to an ad.
                    Ads_que[i,ads_po] = j
                    ads_po = ads_po + 1


    Ctr_sorted = np.take_along_axis(ctr_ads, Value_ads_sorted_indices, axis=1)
    Ctr_sorted2 = np.take_along_axis(ctr_org, Value_org_sorted_indices, axis=1)
    # Expected clicks of the top-2 ads and top-2 organic items in their slots.
    click = np.mean(np.sum(Ctr_sorted[:,0:2] * get_position_ctr(Ads_que[:,0:2],2), axis=1) + np.sum(Ctr_sorted2[:,0:2] * get_position_ctr(Org_que[:,0:2],2), axis=1))

    # GSP per-click price: next ad's score minus own alpha * ctr, over own ctr.
    Payment = np.zeros([value_ads.shape[0],2])
    Payment[:,0] = ((Value_ads_sorted[:,1]) - alpha * Ctr_sorted[:,0]) / (Ctr_sorted[:,0])
    Payment[:,1] = ((Value_ads_sorted[:,2]) - alpha * Ctr_sorted[:,1]) / (Ctr_sorted[:,1])

    # Expected payment, with negative per-click prices clipped to 0.
    cost = np.mean(np.sum(np.maximum(Payment,0) * Ctr_sorted[:,0:2] * get_position_ctr(Ads_que[:,0:2],2), axis=1))
    return click, cost


def alpha_GSP_Mechanism_IC(value_ads,ctr_ads,ctr_org,alpha,bid):
    """Mean utility of ad 0 under alpha-GSP when it bids ``bid`` times its value.

    Column 0 of ``value_ads`` is multiplied by ``bid`` before ranking; all other
    inputs are used as-is. Allocation and GSP per-click prices follow
    ``alpha_GSP_Mechanism``. In each sample where ad 0 ranks among the top two
    ads, its utility is ``(true value - per-click price) * ctr * multiplier of
    its slot``. Otherwise its utility is 0.

    Args:
        value_ads: true per-click ad values, 2-D (samples x ads), at least 3 ads.
        ctr_ads: ad CTRs; assumed same shape as ``value_ads``.
        ctr_org: organic-item CTRs; assumed same number of rows.
        alpha: weight on clicks in the ranking score.
        bid: multiplicative misreport applied to ad 0's value.

    Returns:
        Mean utility of ad 0 over all samples.
    """
    value_ads2 = value_ads.copy()
    value_ads2[:,0] = value_ads2[:,0] * bid
    Value_ads = value_ads2 * ctr_ads + alpha * ctr_ads
    Value_org = alpha * ctr_org
    Value = np.hstack((Value_ads,Value_org))
    Value_sorted_indices = np.argsort(-Value, axis=1)
    Value_sorted = np.take_along_axis(Value, Value_sorted_indices, axis=1)
    Value_ads_sorted_indices = np.argsort(-Value_ads, axis=1)
    Value_ads_sorted = np.take_along_axis(Value_ads, Value_ads_sorted_indices, axis=1)
    # Rank of ad 0 among the ads in each sample (0 = best).
    first_col_indices = get_sort(Value_ads_sorted_indices)
    # Ads_que[i, k] / Org_que[i, k]: slot (0-3) of the k-th best ad / organic
    # item; 0 in column -1 means "second slot not yet filled".
    Ads_que = np.zeros([value_ads.shape[0],2])
    Org_que = np.zeros([value_ads.shape[0],2])
    for i in range(Value.shape[0]):
        ads_po = 0
        org_po = 0
        for j in range(4):
            if Value_sorted[i,j] in Value_ads[i]:
                if Ads_que[i,-1] == 0:
                    Ads_que[i,ads_po] = j
                    ads_po = ads_po + 1
                else:
                    Org_que[i,org_po] = j
                    org_po = org_po + 1
            else:
                if Org_que[i,-1] == 0:
                    Org_que[i,org_po] = j
                    org_po = org_po + 1
                else:
                    Ads_que[i,ads_po] = j
                    ads_po = ads_po + 1

    Ctr_sorted = np.take_along_axis(ctr_ads, Value_ads_sorted_indices, axis=1)

    # GSP per-click price: next ad's score minus own alpha * ctr, over own ctr.
    Payment = np.zeros([value_ads.shape[0],2])
    Payment[:,0] = ((Value_ads_sorted[:,1]) - alpha * Ctr_sorted[:,0]) / (Ctr_sorted[:,0])
    Payment[:,1] = ((Value_ads_sorted[:,2]) - alpha * Ctr_sorted[:,1]) / (Ctr_sorted[:,1])

    first_pay = np.zeros(value_ads.shape[0])
    for kk in range(value_ads.shape[0]):
        rank = int(first_col_indices[kk])
        if rank < 2:
            # Scale by the slot ad 0 actually occupies (Ads_que), not by its rank
            # among ads; the two differ whenever organic items sit above it.
            slot_ctr = ctr_ads[kk,0] * get_pos(Ads_que[kk,rank],1)
            first_pay[kk] = value_ads[kk,0] * slot_ctr - Payment[kk,rank] * slot_ctr
    return np.mean(first_pay)

def deterministic_NeuralSort(s, tau):
    """Relaxed sorting permutation (NeuralSort), highest score first.

    Args:
        s: scores; the callers in this file pass ``(batch, n)`` tensors
            unsqueezed to ``(batch, n, 1)``, and ``s.size()[1]`` is taken as n.
        tau: softmax temperature; smaller values give a harder permutation.

    Returns:
        ``(batch, n, n)`` matrix whose rows are softmax distributions;
        ``P @ s`` approximates ``s`` sorted in descending order.
    """
    n = s.size()[1]
    ones_col = t.ones((n, 1), dtype=t.float32).to(device)
    # |s_i - s_k| for every pair, then row-summed and broadcast across columns.
    pairwise_abs = t.abs(s - s.permute(0, 2, 1)).float()
    abs_row_sums = t.matmul(pairwise_abs, t.matmul(ones_col, t.transpose(ones_col, 0, 1)))
    # Rank weights (n + 1 - 2k) for k = 1..n.
    rank_weights = (n + 1 - 2 * (t.arange(n) + 1)).type(t.float32).to(device)
    weighted_scores = t.matmul(s.float(), rank_weights.unsqueeze(0))
    logits = (weighted_scores - abs_row_sums).permute(0, 2, 1)
    return nn.Softmax(-1)(logits / tau)




class Score_VCG(nn.Module):
    """Alpha-VCG auction with a learnable per-agent affine bid transform.

    Each ad's bid ``x`` is mapped to ``x * exp(seller_w) + seller_b`` before
    ranking. Per-click prices are computed on the transformed scores and mapped
    back through the inverse transform. ``seller_backward`` performs one
    gradient step that increases revenue.
    """
    def __init__(self,args):
        nn.Module.__init__(self)
        self.args = args
        #self.train_data = train_data
        #self.test_data = test_data


        num_func      = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent     = self.args.num_agent

        # One transform entry per agent: w ~ N(0, 1) / 5, b ~ -U[0, 1).
        # These are plain tensors with requires_grad, not nn.Parameter, so they
        # do not appear in self.parameters(); seller_backward updates them by
        # hand. seller_w2 is not used by any method in this class.
        self.seller_w_init = np.random.normal(size = (num_agent)) / 5
        self.seller_w2_init = np.random.normal(size = (num_agent)) / 5
        self.seller_b_init = -np.random.rand( num_agent) * 1.0
        self.seller_w = t.tensor(self.seller_w_init, device = device, requires_grad=True).to(device)
        self.seller_w2 = t.tensor(self.seller_w2_init, device = device, requires_grad=True).to(device)
        self.seller_b = t.tensor(self.seller_b_init, device = device, requires_grad=True).to(device)

    def forward(self,x, ctr_ads, ctr_og, alpha):
        """Run the differentiable (soft-sorted) auction.

        Args:
            x: ad bids, wrapped with ``t.tensor``; combined element-wise with
                ``[batch_size, num_agent]`` parameter tiles, so presumably
                (batch, num_agent).
            ctr_ads, ctr_og: ad / organic CTRs; assumed same shape as ``x``.
            alpha: weight on clicks.

        Returns:
            (revenue, click, cost), where revenue = cost + alpha * click, click
            is mean expected clicks and cost is mean expected ad payment.
        """

        num_func      = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent     = self.args.num_agent
        batch_size = t.tensor(x,device = device).size()[0]
        x = t.tensor(x,device = device)
        ctr_ads = t.tensor(ctr_ads,device = device).float()
        ctr_og = t.tensor(ctr_og,device = device).float()

        # Tile the per-agent transform parameters for every sample.
        w_copy = t.reshape(self.seller_w.repeat([batch_size,1]),[batch_size,num_agent])
        b_copy = t.reshape(self.seller_b.repeat([batch_size,1]),[batch_size,num_agent])
        #xx_copy = t.reshape(x.repeat([1]),[batch_size,num_agent])
        #vv_max_units = t.max(t.mul(xx_copy, t.exp(w_copy)) + b_copy,2).values
        # Transformed ("virtual") bids.
        vv = t.mul(x, t.exp(w_copy)) + b_copy

        Value_ads = vv * ctr_ads + alpha * ctr_ads
        Value_org = alpha * ctr_og
        Value = t.hstack((Value_ads,Value_org))  # only its row count is used below
        # Despite the names, the *_sorted_indices tensors are relaxed permutation
        # matrices from NeuralSort (tau=0.001). Multiplying by them soft-sorts
        # (highest first) while keeping the computation differentiable.
        Value_ads_sorted_indices = deterministic_NeuralSort(Value_ads.unsqueeze(-1), tau=0.001)
        Value_ads_sorted = (Value_ads_sorted_indices.float() @ Value_ads.float().unsqueeze(-1)).squeeze(-1)
        Value_org_sorted_indices = deterministic_NeuralSort(Value_org.unsqueeze(-1), tau=0.001)
        Value_org_sorted = (Value_org_sorted_indices.float() @ Value_org.float().unsqueeze(-1)).squeeze(-1)


        # Ads_que / Org_que: slot (0-3) of the k-th best ad / organic item.
        # Only the top-2 ads and top-2 organic items compete for the 4 slots.
        Ads_que = t.zeros([Value_ads.shape[0],2])
        Org_que = t.zeros([Value_ads.shape[0],2])
        Value_sorted = t.hstack((Value_ads_sorted[:,0:2],Value_org_sorted[:,0:2]))
        Value_sorted2 = -t.sort(-Value_sorted).values  # descending hard sort
        #Pos_que = np.zeros([value_ads.shape[0],2])
        for i in range(Value.shape[0]):
            ads_po = 0
            org_po = 0
            for j in range(4):
                # NOTE(review): classified by value; an exact ad/organic tie
                # counts as an ad.
                if Value_sorted2[i,j] in Value_ads_sorted[i]:
                    if Ads_que[i,-1] == 0:
                        Ads_que[i,ads_po] = j
                        ads_po = ads_po + 1
                    else:
                        Org_que[i,org_po] = j
                        org_po = org_po + 1
                else:
                    if Org_que[i,-1] == 0:
                        Org_que[i,org_po] = j
                        org_po = org_po + 1
                    else:
                        Ads_que[i,ads_po] = j
                        ads_po = ads_po + 1

        Ctr_sorted = (Value_ads_sorted_indices.float() @ ctr_ads.float().unsqueeze(-1)).squeeze(-1)
        Ctr_sorted2 = (Value_org_sorted_indices.float() @ ctr_og.float().unsqueeze(-1)).squeeze(-1)
        click = t.mean(t.sum(Ctr_sorted[:,0:2] * torch_get_position_ctr(Ads_que[:,0:2],2), axis=1) + t.sum(Ctr_sorted2[:,0:2] * torch_get_position_ctr(Org_que[:,0:2],2), axis=1))

        # VCG-style per-click prices of the two winning ads, in transformed-score space.
        Payment = t.zeros([Value_ads.shape[0],2]).to(device)
        Payment[:,0] = ((Value_ads_sorted[:,1] * torch_get_position_ctr(Ads_que[:,0],1) +  Value_ads_sorted[:,2] * torch_get_position_ctr(Ads_que[:,1],1) -  Value_ads_sorted[:,1] * torch_get_position_ctr(Ads_que[:,1],1)) - alpha * Ctr_sorted[:,0] * torch_get_position_ctr(Ads_que[:,0],1)) / (Ctr_sorted[:,0] * torch_get_position_ctr(Ads_que[:,0],1))
        Payment[:,1] = ((Value_ads_sorted[:,2] * torch_get_position_ctr(Ads_que[:,1],1)) - alpha * Ctr_sorted[:,1] * torch_get_position_ctr(Ads_que[:,1],1)) / (Ctr_sorted[:,1] * torch_get_position_ctr(Ads_que[:,1],1))

        # Transform parameters permuted into the same (soft) ad order.
        w_spa = (Value_ads_sorted_indices.float() @ w_copy.float().unsqueeze(-1)).squeeze(-1)
        #w_spa = t.reshape(w_spa,[batch_size,10,10,4])
        b_spa = (Value_ads_sorted_indices.float() @ b_copy.float().unsqueeze(-1)).squeeze(-1)
        #b_spa = t.reshape(w_spa,[batch_size,10,10,4])

        # Pad prices to 4 columns to line up with w_spa / b_spa (this assumes
        # num_agent == 4), then invert the transform: p = (price - b) / exp(w).
        pay = t.hstack((Payment,t.zeros(batch_size,2).to(device)))
        #p_spa_copy = t.reshape(pay.repeat( [1, num_func * num_max_units]),[batch_size, num_max_units, num_func, num_agent])
        #p_max_units = t.min(t.mul(p_spa_copy - b_spa,t.reciprocal(t.exp(w_spa))),2).values
        p = F.relu(t.mul(pay - b_spa,t.reciprocal(t.exp(w_spa))))[:,0:2]

        cost = t.mean(t.sum(p * Ctr_sorted[:,0:2] * torch_get_position_ctr(Ads_que[:,0:2],2), axis=1))

        revenue = cost + alpha * click

        return revenue, click, cost


    def ic_check(self,x, ctr_ads, ctr_og, alpha, bid):
        """Mean utility of agent 0 when it bids ``bid * x[:,0]`` (hard argsort, no relaxation).

        Returns:
            0-d NumPy array holding the mean utility over samples.
        """

        num_func      = self.args.num_linear
        num_max_units = self.args.num_max
        num_agent     = self.args.num_agent
        batch_size = t.tensor(x,device = device).size()[0]
        x = t.tensor(x,device = device)
        # x keeps the true values; x2 carries agent 0's misreport.
        x2 = x.clone()
        x2[:,0] = x2[:,0] * bid
        ctr_ads = t.tensor(ctr_ads,device = device).float()
        ctr_og = t.tensor(ctr_og,device = device).float()

        w_copy = t.reshape(self.seller_w.repeat([batch_size,1]),[batch_size,num_agent])
        b_copy = t.reshape(self.seller_b.repeat([batch_size,1]),[batch_size,num_agent])
        #xx_copy = t.reshape(x.repeat([1]),[batch_size,num_agent])
        #vv_max_units = t.max(t.mul(xx_copy, t.exp(w_copy)) + b_copy,2).values
        vv = t.mul(x2, t.exp(w_copy)) + b_copy

        Value_ads = vv * ctr_ads + alpha * ctr_ads
        Value_org = alpha * ctr_og
        Value = t.hstack((Value_ads,Value_org))
        Value_ads_sorted_indices = t.argsort(-Value_ads,dim=1)
        Value_ads_sorted = t.take_along_dim(Value_ads,Value_ads_sorted_indices,dim=1)
        Value_org_sorted_indices = t.argsort(-Value_org,dim=1)
        Value_org_sorted = t.take_along_dim(Value_org,Value_org_sorted_indices,dim=1)

        # Rank of agent 0 among the ads in each sample (0 = best).
        first_col_indices = get_sort(Value_ads_sorted_indices)

        Ads_que = t.zeros([Value_ads.shape[0],2])
        Org_que = t.zeros([Value_ads.shape[0],2])
        Value_sorted = t.hstack((Value_ads_sorted[:,0:2],Value_org_sorted[:,0:2]))
        Value_sorted2 = -t.sort(-Value_sorted).values
        #Pos_que = np.zeros([value_ads.shape[0],2])
        for i in range(Value.shape[0]):
            ads_po = 0
            org_po = 0
            for j in range(4):
                if Value_sorted2[i,j] in Value_ads_sorted[i]:
                    if Ads_que[i,-1] == 0:
                        Ads_que[i,ads_po] = j
                        ads_po = ads_po + 1
                    else:
                        Org_que[i,org_po] = j
                        org_po = org_po + 1
                else:
                    if Org_que[i,-1] == 0:
                        Org_que[i,org_po] = j
                        org_po = org_po + 1
                    else:
                        Ads_que[i,ads_po] = j
                        ads_po = ads_po + 1

        Ctr_sorted = t.take_along_dim(ctr_ads,Value_ads_sorted_indices,dim=1)
        Ctr_sorted2 = t.take_along_dim(ctr_og,Value_org_sorted_indices,dim=1)
        click = t.mean(t.sum(Ctr_sorted[:,0:2] * torch_get_position_ctr(Ads_que[:,0:2],2), axis=1) + t.sum(Ctr_sorted2[:,0:2] * torch_get_position_ctr(Org_que[:,0:2],2), axis=1))

        Payment = t.zeros([Value_ads.shape[0],2]).to(device)
        Payment[:,0] = ((Value_ads_sorted[:,1] * torch_get_position_ctr(Ads_que[:,0],1) +  Value_ads_sorted[:,2] * torch_get_position_ctr(Ads_que[:,1],1) -  Value_ads_sorted[:,1] * torch_get_position_ctr(Ads_que[:,1],1)) - alpha * Ctr_sorted[:,0] * torch_get_position_ctr(Ads_que[:,0],1)) / (Ctr_sorted[:,0] * torch_get_position_ctr(Ads_que[:,0],1))
        Payment[:,1] = ((Value_ads_sorted[:,2] * torch_get_position_ctr(Ads_que[:,1],1)) - alpha * Ctr_sorted[:,1] * torch_get_position_ctr(Ads_que[:,1],1)) / (Ctr_sorted[:,1] * torch_get_position_ctr(Ads_que[:,1],1))

        w_spa = t.take_along_dim(w_copy,Value_ads_sorted_indices,dim=1)
        #w_spa = t.reshape(w_spa,[batch_size,10,10,4])
        b_spa = t.take_along_dim(b_copy,Value_ads_sorted_indices,dim=1)
        #b_spa = t.reshape(w_spa,[batch_size,10,10,4])

        pay = t.hstack((Payment,t.zeros(batch_size,2).to(device)))
        #p_spa_copy = t.reshape(pay.repeat( [1, num_func * num_max_units]),[batch_size, num_max_units, num_func, num_agent])
        #p_max_units = t.min(t.mul(p_spa_copy - b_spa,t.reciprocal(t.exp(w_spa))),2).values
        # NOTE(review): p (bid-space price) is computed but unused; the utility
        # below subtracts Payment (transformed-score space) from the raw value
        # x. Confirm which price is intended.
        p = F.relu(t.mul(pay - b_spa,t.reciprocal(t.exp(w_spa))))[:,0:2]


        # Utility of agent 0, counted only when it ranks in the top two ads;
        # Ads_que gives the slot it actually occupies.
        first_pay = t.zeros(Value_ads.shape[0]).to(device)
        for kk in range(Value_ads.shape[0]):
            if first_col_indices[kk] < 2:
                first_pay[kk] = x[kk,0] * ctr_ads[kk,0] * get_pos(Ads_que[kk,int(first_col_indices[kk])],1) - Payment[kk,int(first_col_indices[kk])] * ctr_ads[kk,0] * get_pos(Ads_que[kk,int(first_col_indices[kk])],1)
        return t.mean(first_pay).cpu().detach().numpy()



    def seller_backward(self,args,x,ctr_ads,ctr_og,alpha):
        """One manual gradient step (lr 0.01) on seller_w and seller_b to increase revenue.

        The ``args`` parameter is not used.
        """

        input = x
        output = self.forward(input,ctr_ads,ctr_og,alpha)
        loss = -output[0]  # maximise revenue by minimising its negative
        loss.backward() # populates .grad on seller_w / seller_b

        self.seller_w.data.sub_(0.01 * self.seller_w.grad.data)
        #self.seller_w2.data.sub_(0.1 * self.seller_w2.grad.data)
        self.seller_b.data.sub_(0.01 * self.seller_b.grad.data)

        # Clear gradients so the next call does not accumulate them.
        self.seller_w.grad.data.zero_()
        #self.seller_w2.grad.data.zero_()
        self.seller_b.grad.data.zero_()
        #self.bidder_w.grad.data.zero_()
  


class MLP(nn.Module):
    """Feed-forward net: ``activation`` on hidden layers, ``5 * tanh`` on the output layer.

    Args:
        layers: layer widths ``[in, hidden..., out]``; consecutive pairs become
            ``nn.Linear`` layers.
        activation: callable applied after every layer except the last.
    """

    def __init__(self, layers, activation):
        super().__init__()
        width_pairs = zip(layers[:-1], layers[1:])
        self.layers_list = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in width_pairs])
        self.activation = activation

    def forward(self, x):
        out = x
        last_idx = len(self.layers_list) - 1
        for idx, linear in enumerate(self.layers_list):
            pre = linear(out)
            if idx < last_idx:
                out = self.activation(pre)
            else:
                # Output squashed into (-5, 5).
                out = 5 * t.tanh(pre)
        return out
    
class MLP2(nn.Module):
    """Feed-forward net: ``activation`` on hidden layers, ``5 * sigmoid`` on the output layer.

    Args:
        layers: layer widths ``[in, hidden..., out]``; consecutive pairs become
            ``nn.Linear`` layers.
        activation: callable applied after every layer except the last.
    """

    def __init__(self, layers, activation):
        super().__init__()
        width_pairs = zip(layers[:-1], layers[1:])
        self.layers_list = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in width_pairs])
        self.activation = activation

    def forward(self, x):
        out = x
        last_idx = len(self.layers_list) - 1
        for idx, linear in enumerate(self.layers_list):
            pre = linear(out)
            if idx < last_idx:
                out = self.activation(pre)
            else:
                # Output squashed into (0, 5).
                out = 5 * t.sigmoid(pre)
        return out

class Hyper_VCG(nn.Module):
    def __init__(self, args):
        """Store the experiment config; w1 / b1 are supplied on each call to ``forward``."""
        nn.Module.__init__(self)
        self.args = args

    def forward(self, inputs, ctr_ads, ctr_og, alpha, w1, b1):
        """Soft-sorted alpha-VCG auction over ads and organic items.

        Each ad bid ``x`` is mapped to ``x * exp(w1) + b1``. Ads score
        ``(transformed bid + alpha) * ctr_ads``; organic items score
        ``alpha * ctr_og``. All candidates are soft-sorted with NeuralSort
        (tau=0.01), and the first four sorted positions are slots with
        multipliers 1.0 / 0.99 / 0.6 / 0.5.

        Args:
            inputs: ad bids, wrapped with ``t.tensor``; presumably
                (batch, num_agent) — TODO confirm.
            ctr_ads, ctr_og: ad / organic CTRs; assumed same shape as ``inputs``.
            alpha: weight on clicks.
            w1, b1: per-agent transform parameters; ``reshape`` calls below
                require ``batch_size * num_agent`` elements after repeating.

        Returns:
            (revenue, cost, click, perct):
                revenue: cost + alpha * click.
                cost: mean expected payment over the four slots.
                click: mean expected clicks.
                perct: fraction of the first num_agent soft-sorted positions
                    whose soft-sorted bid is below 0.001.
        """
        pos_ratio = t.tensor([1.,0.99,0.6,0.5]).float().to(device)
        pos_rati = t.tensor([0.99,0.6,0.5]).float().to(device)
        pos_rat = t.tensor([0.6,0.5]).float().to(device)
        pos_ra = t.tensor([0.5]).float().to(device)

        num_agent = self.args.num_agent
        batch_size = t.tensor(inputs).size()[0]
        ctr_ads = t.tensor(ctr_ads).float().to(device)
        ctr_og = t.tensor(ctr_og).float().to(device)
        x = t.tensor(inputs).float().to(device)
        w_copy = w1.repeat([batch_size,1])
        b_copy = b1.repeat([batch_size,1])
        # Transformed ("virtual") bids.
        vv = t.mul(x, t.exp(w_copy)) + b_copy

        Value_ads = vv * ctr_ads + alpha * ctr_ads
        Value_org = alpha * ctr_og
        # Organic items carry a zero bid.
        Bid_org = t.zeros(x.shape).to(device)

        Value = t.hstack((Value_ads,Value_org))
        Bid = t.hstack((x,Bid_org))
        Ctr = t.hstack((ctr_ads,ctr_og))
        # Relaxed permutation; multiplying by it soft-sorts highest-first.
        Value_sorted_indices = deterministic_NeuralSort(Value.unsqueeze(-1), tau=0.01)
        Value_sorted = (Value_sorted_indices @ Value.unsqueeze(-1)).squeeze(-1)
        Bid_sorted = (Value_sorted_indices @ Bid.unsqueeze(-1)).squeeze(-1)
        Ctr_sorted = (Value_sorted_indices @ Ctr.unsqueeze(-1)).squeeze(-1)

        click = t.mean(t.sum(Ctr_sorted[:,0:4]*pos_ratio,1))

        # VCG per-click price for the entry in each of the four slots: the
        # externality on everyone below it (each moves up one slot), minus its
        # own alpha * ctr bonus, divided by its ctr * slot multiplier.
        cost = t.zeros(x.shape[0],4).to(device)
        cost[:,0] = (t.sum(Value_sorted[:,1:5]*pos_ratio,1) - t.sum(Value_sorted[:,1:4]*pos_rati,1) - alpha * 1. * Ctr_sorted[:,0]) / (1. * Ctr_sorted[:,0])
        cost[:,1] = (t.sum(Value_sorted[:,2:5]*pos_rati,1) - t.sum(Value_sorted[:,2:4]*pos_rat,1) - alpha * 0.99 * Ctr_sorted[:,1]) / (0.99 * Ctr_sorted[:,1])
        cost[:,2] = (t.sum(Value_sorted[:,3:5]*pos_rat,1) - t.sum(Value_sorted[:,3:4]*pos_ra,1) - alpha * 0.6 * Ctr_sorted[:,2]) / (0.6 * Ctr_sorted[:,2])
        cost[:,3] = (t.sum(Value_sorted[:,4:5]*pos_ra,1) - alpha * 0.5 * Ctr_sorted[:,3]) / (0.5 * Ctr_sorted[:,3])

        # Transform parameters soft-sorted into slot order. Organic items get
        # fixed w = 1, b = 0.1. NOTE(review): slots held by organic items still
        # contribute to the final cost sum below; confirm intended.
        ww = t.reshape(w1.repeat(batch_size),[batch_size,num_agent])
        ww2 = t.ones([batch_size,num_agent]).to(device)
        w_slot = ((Value_sorted_indices @ t.hstack((ww,ww2)).unsqueeze(-1)).squeeze(-1))[:,0:4]
        bb = t.reshape(b1.repeat(batch_size),[batch_size,num_agent])
        bb2 = 0.1 * t.ones([batch_size,num_agent]).to(device)
        b_slot = ((Value_sorted_indices @ t.hstack((bb,bb2)).unsqueeze(-1)).squeeze(-1))[:,0:4]
        # Invert the bid transform: price in bid space = (price - b) / exp(w).
        p_max_units = t.mul(cost - b_slot,t.reciprocal(t.exp(w_slot)))
        # F.relu's second positional argument is ``inplace``; the old
        # ``F.relu(p_max_units, 1)`` silently ran in place. Same values.
        cost = F.relu(p_max_units)

        # Vectorised replacement for the per-element Python loop.
        count = int((Bid_sorted[:, :num_agent] < 0.001).sum().item())
        perct = count / (x.shape[0] * num_agent)

        cost = t.mean(1.*Ctr_sorted[:,0]*cost[:,0]+0.99*Ctr_sorted[:,1]*cost[:,1]+0.6*Ctr_sorted[:,2]*cost[:,2]+0.5*Ctr_sorted[:,3]*cost[:,3])

        revenue = cost + alpha * click

        return revenue, cost, click, perct

    def ic_check(self,inputs, ctr_ads, ctr_og, alpha, w1, b1, bid):
        """Return agent 0's mean utility when it shades its value by ``bid``.

        Agent 0's column of ``inputs`` is multiplied by ``bid`` before the
        auction is run, while its utility is computed from the unshaded
        ``inputs[:, 0]``. Comparing the result for ``bid=1`` with smaller
        factors probes incentive compatibility.

        Args:
            inputs: NumPy array of agent values, indexed as ``inputs[:, 0]``
                (assumed shape ``(batch, num_agent)`` -- TODO confirm).
            ctr_ads: Click-through rates of the ad candidates, same layout as
                ``inputs``.
            ctr_og: Click-through rates of the organic candidates.
            alpha: Weight on clicks in the ranking value.
            w1, b1: Per-agent tensors; bids are mapped to ``x * exp(w1) + b1``
                (presumably 1-D of length ``num_agent`` -- TODO confirm).
            bid: Multiplicative factor applied to agent 0's reported value.

        Returns:
            0-d NumPy array: mean over samples of agent 0's utility
            ``value * ctr * get_pos(pos, 1) - payment * ctr * get_pos(pos, 1)``,
            where ``pos`` comes from ``get_sort`` (presumably agent 0's rank).
            Samples with ``pos >= 4`` contribute 0.
        """
        # Agent 0 reports a shaded value; utility below still uses `inputs`.
        inputs2 = inputs.copy()
        inputs2[:,0] = inputs2[:,0] * bid
        # Per-slot position weights of the four displayed slots; the shorter
        # vectors are the same weights starting at slots 2, 3 and 4.
        pos_ratio = t.tensor([1.,0.99,0.6,0.5]).float().to(device)
        pos_rati = pos_ratio[1:]
        pos_rat = pos_ratio[2:]
        pos_ra = pos_ratio[3:]

        num_agent = self.args.num_agent
        batch_size = inputs.shape[0]
        ctr_ads = t.tensor(ctr_ads).float().to(device)
        ctr_og = t.tensor(ctr_og).float().to(device)
        x = t.tensor(inputs2).float().to(device)

        # Affine transform of the (shaded) bids with the per-agent parameters.
        w_rows = w1.repeat([batch_size, 1])
        b_rows = b1.repeat([batch_size, 1])
        vv = t.mul(x, t.exp(w_rows)) + b_rows

        # Ranking values: ads combine transformed bid and click weight,
        # organic candidates only the click weight.
        Value_ads = vv * ctr_ads + alpha * ctr_ads
        Value_org = alpha * ctr_og

        Value = t.hstack((Value_ads, Value_org))
        Ctr = t.hstack((ctr_ads, ctr_og))

        # Rank every candidate (ads and organic) by value, descending.
        Value_sorted_indices = t.argsort(-Value, dim=1)
        Value_sorted = t.take_along_dim(Value, Value_sorted_indices, dim=1)
        Ctr_sorted = t.take_along_dim(Ctr, Value_sorted_indices, dim=1)

        # get_sort receives its own copy in case it modifies its argument.
        first_col_indices = get_sort(Value_sorted_indices.clone())

        # Slot payments: position-weighted value of the candidates below the
        # slot, minus what they would get one slot higher, minus the slot
        # occupant's own alpha * weighted CTR, divided by its weighted CTR.
        cost = t.zeros(x.shape[0], 4).to(device)
        cost[:,0] = (t.sum(Value_sorted[:,1:5]*pos_ratio,1) - t.sum(Value_sorted[:,1:4]*pos_rati,1) - alpha * 1. * Ctr_sorted[:,0]) / (1. * Ctr_sorted[:,0])
        cost[:,1] = (t.sum(Value_sorted[:,2:5]*pos_rati,1) - t.sum(Value_sorted[:,2:4]*pos_rat,1) - alpha * 0.99 * Ctr_sorted[:,1]) / (0.99 * Ctr_sorted[:,1])
        cost[:,2] = (t.sum(Value_sorted[:,3:5]*pos_rat,1) - t.sum(Value_sorted[:,3:4]*pos_ra,1) - alpha * 0.6 * Ctr_sorted[:,2]) / (0.6 * Ctr_sorted[:,2])
        cost[:,3] = (t.sum(Value_sorted[:,4:5]*pos_ra,1) - alpha * 0.5 * Ctr_sorted[:,3]) / (0.5 * Ctr_sorted[:,3])

        # Map slot payments back through the inverse affine transform of the
        # candidate occupying each slot (organic candidates use w=1, b=0.1),
        # clamped at zero.
        w_slot = t.take_along_dim(
            t.hstack((w_rows, t.ones([batch_size, num_agent]).to(device))),
            Value_sorted_indices, dim=1)[:, 0:4]
        b_slot = t.take_along_dim(
            t.hstack((b_rows, 0.1 * t.ones([batch_size, num_agent]).to(device))),
            Value_sorted_indices, dim=1)[:, 0:4]
        cost = F.relu(t.mul(cost - b_slot, t.reciprocal(t.exp(w_slot))))

        first_pay = t.zeros(Value_ads.shape[0]).to(device)
        for kk in range(Value_ads.shape[0]):
            slot = first_col_indices[kk]
            if slot < 4:
                # get_pos is assumed side-effect free; evaluated once per sample.
                pos_weight = get_pos(slot, 1)
                first_pay[kk] = inputs[kk,0] * ctr_ads[kk,0] * pos_weight - cost[kk,int(slot)] * ctr_ads[kk,0] * pos_weight
        return t.mean(first_pay).cpu().detach().numpy()

    


class Learner:
    """Learner for the hyper-network auction parameters.

    Holds a ``Hyper_VCG`` auction model plus two small networks: ``w_net``
    (an ``MLP2``) and ``b_net`` (an ``MLP``). Both map a 3-element context
    ``[low, high, alpha]`` to ``num_agent`` outputs, which are passed to the
    auction model as weights and biases. Each network has its own Adam
    optimizer.
    """

    def __init__(self, args):
        self.args = args
        n_agents = self.args.num_agent
        context_dim = 3  # [low, high, alpha]
        hidden = context_dim * 10

        # Construction order (auction model, w_net, b_net) fixes the order in
        # which parameter initialisation consumes the torch RNG.
        self.auct_model = Hyper_VCG(args)
        self.w_net = MLP2([context_dim, hidden, hidden, n_agents], t.tanh).to(device)
        self.b_net = MLP([context_dim, hidden, hidden, n_agents], t.tanh).to(device)
        self.optimizers_auct = t.optim.Adam(self.w_net.parameters(), lr=4e-4, betas=(0.9, 0.999))
        self.optimizers_auct2 = t.optim.Adam(self.b_net.parameters(), lr=4e-4, betas=(0.9, 0.999))

    def update_auct(self):
        """Take one Adam step on ``w_net`` and ``b_net`` to raise revenue.

        Samples a random context: value upper bound ``high`` in [0.5, 2.5)
        and click weight ``alpha`` in [0, 1.5). It then draws CTRs and values
        with ``Generator`` and evaluates the auction model. The loss is the
        negative mean of the model's first output (revenue).
        """
        low = 0.
        high = 2. * np.random.uniform() + 0.5
        alpha = 1.5 * np.random.uniform()
        context = t.tensor([low, high, alpha]).to(device)

        sampler = Generator(self.args)
        ctr_ads = sampler.generate_uniform(0, 1.)
        ctr_og = sampler.generate_uniform(0, 2.5)
        values = sampler.generate_uniform(low, high)

        weights = self.w_net(context)
        biases = self.b_net(context)
        revenue = self.auct_model(values, ctr_ads, ctr_og, alpha, weights, biases)[0]
        loss = -t.mean(revenue)

        optimizers = (self.optimizers_auct, self.optimizers_auct2)
        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()



def cumulative_average(arr):
    """Return the running mean of ``arr``.

    Element ``k`` of the result equals ``np.mean(arr[:k + 1])``. For
    multi-dimensional input that is the mean over every element of the first
    ``k + 1`` rows.

    Computed in O(n) from cumulative sums rather than re-averaging each
    prefix (which was O(n^2)).

    Args:
        arr: Array-like of numbers.

    Returns:
        float64 NumPy array of length ``len(arr)`` (empty for empty input).
    """
    values = np.asarray(arr, dtype=float)
    n = len(values)
    if n == 0:
        return np.zeros(0)
    # Collapse each row to its sum so N-D input matches np.mean(arr[:k]).
    rows = values.reshape(n, -1)
    counts = np.arange(1, n + 1) * rows.shape[1]
    return np.cumsum(rows.sum(axis=1)) / counts


def calculate_t(t):
    """Return a noisy, hour-of-day dependent upper bound for value sampling.

    With ``hour = t % 24`` the noise-free profile is piecewise linear:
    - 1.5 -> 2.0 over hours 0..5;
    - 2.0 -> 1.0 over hours 5..17;
    - 1.0 -> 1.5 over hours 17..24.

    One ``0.1 * np.random.normal()`` term is added, and the result is floored
    at 0.5.
    """
    hour = t % 24
    if hour <= 5:
        base = 1.5 + 0.5 * (hour / 5)
    elif hour <= 17:
        base = 2. - (hour - 5) / 12
    else:
        base = 1. + 0.5 * (hour - 17) / 6
    return np.max([base + 0.1 * np.random.normal(), 0.5])


def filter_pairs(x, y):
    """Keep the non-negative, non-dominated points of ``(x, y)``.

    Point ``i`` is dropped when either coordinate is negative, or when some
    other point is strictly larger in BOTH coordinates (strict Pareto
    dominance; a tie in either coordinate does not dominate).

    Args:
        x, y: 1-D array-likes of equal length (lists are accepted).

    Returns:
        ``(x_filtered, y_filtered)`` as NumPy arrays, in original order.

    The pairwise comparison is vectorized and uses O(n^2) memory.
    """
    xs = np.asarray(x)
    ys = np.asarray(y)
    # dominated[i] is True when some j has xs[j] > xs[i] and ys[j] > ys[i];
    # the strict '>' already excludes j == i.
    dominated = ((xs[None, :] > xs[:, None]) & (ys[None, :] > ys[:, None])).any(axis=1)
    keep = ~((xs < 0) | (ys < 0) | dominated)
    return xs[keep], ys[keep]


def calculate_distances(x, y, A, B, C, D):
    """Distances of points to the utopia point after max-normalisation.

    All points are stacked and each of the first two coordinates is divided
    by its maximum over every point, which maps the utopia point to (1, 1).
    Non-positive maxima are not handled (division yields inf/nan).

    Args:
        x, y: Single reference points such as ``[px, py]``.
        A, B, C, D: Point groups, each either a single point ``[px, py]`` or
            an array of shape ``(k, 2)``.

    Returns:
        ``(dist_x, dist_y, min_dist_A, min_dist_B, min_dist_C, min_dist_D)``:
        the Euclidean distance of ``x`` and ``y`` to (1, 1), and the smallest
        such distance within each group.

    Raises:
        ValueError: If a group contains no points.
    """
    blocks = [np.atleast_2d(np.asarray(p, dtype=float)) for p in (x, y, A, B, C, D)]
    all_points = np.vstack(blocks)
    coords = all_points[:, :2]
    normalized = coords / coords.max(axis=0)
    dists = np.linalg.norm(normalized - 1.0, axis=1)
    # Split by the row count of each 2-D block, not len() of the raw input:
    # len() of a bare 1-D point is its number of coordinates, not points.
    bounds = np.cumsum([b.shape[0] for b in blocks])[:-1]
    parts = np.split(dists, bounds)
    dist_x = parts[0][0]
    dist_y = parts[1][0]
    min_dists = [np.min(part) for part in parts[2:]]
    return (dist_x, dist_y, *min_dists)



def plot_interpolation2(points, step):
    """Plot the running utopia distance of six mechanisms and save it.

    ``points`` is copied into a NumPy array whose columns are, in order:
    AMMD (online), AMMD (offline), VCG, GSP, SW-VCG (offline), SW-VCG (online).

    Processing and output:
    - Values are scaled by 100 and capped at 100 (percent).
    - Each column is cut into consecutive chunks of 10 rows; per chunk the
      mean and the 20th/80th percentiles are taken.
    - The line is the running average of the chunk means.
    - The shaded band is the running average of mean -/+ one third of the
      gap to the 20th/80th percentile.
    - A marker is drawn every 10th chunk.
    - The figure is saved with ``save_plot`` and the last chunk mean of each
      column is printed.

    ``step`` is accepted but unused.
    """
    data = np.array(points)
    data[:, :] *= 100
    data[data > 100] = 100

    chunk = 10

    def summarize(column):
        # Per-chunk mean and 20th / 80th percentiles.
        means, lows, highs = [], [], []
        for start in range(0, len(column), chunk):
            window = column[start:start + chunk]
            means.append(np.mean(window))
            lows.append(np.percentile(window, 20))
            highs.append(np.percentile(window, 80))
        return means, lows, highs

    def running_mean(values):
        # Incremental running average: out[i] == mean(values[:i + 1]).
        out = np.zeros(len(values))
        for idx, val in enumerate(values):
            out[idx] = val if idx == 0 else (out[idx - 1] * idx + val) / (idx + 1)
        return out

    summaries = [summarize(data[:, col]) for col in range(data.shape[1])]
    xs = range(len(summaries[0][0]))

    labels = ['AMMD (online)', 'AMMD (offline)', 'VCG', 'GSP', 'SW-VCG (offline)', 'SW-VCG (online)']
    colors = ['brown', 'red', 'green', 'blue', 'gray', 'violet']
    markers = ['d', 'p', 'o', 's', 'v', 'x']

    plt.figure()
    for idx, (means, lows, highs) in enumerate(summaries):
        m = np.array(means)
        curve = running_mean(means)
        band_low = running_mean(m - (m - np.array(lows)) / 3)
        band_high = running_mean(m + (np.array(highs) - m) / 3)
        plt.plot(xs, curve, color=colors[idx], linestyle='-', label=labels[idx])
        plt.fill_between(xs, band_low, band_high, color=colors[idx], alpha=0.1)
        plt.scatter(xs[::10], curve[::10], marker=markers[idx], color=colors[idx])

    plt.xlabel('Traffic Samples')
    plt.ylabel('Utopia Distance (%)')
    plt.title('Utopia Distance of Different Mechanisms')
    plt.legend()
    save_plot()
    plt.show(block=False)
    plt.close()

    for idx, (means, _, _) in enumerate(summaries):
        print(f'{labels[idx]}  {means[-1]}')

def save_plot():
    """Save the current matplotlib figure as a PDF in the ``imgs`` directory.

    The file is named ``dynamic_utopia_plot_<YYYYmmdd_HHMMSS_ffffff>.pdf``.
    Microseconds are part of the timestamp so that two figures saved within
    the same second do not overwrite each other.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('imgs', exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    plt.savefig(os.path.join('imgs', f'dynamic_utopia_plot_{timestamp}.pdf'))

def train_linear(args, args2, net1, net2, net3, net4, net5, net12, net22, net32, net42, net52, rollouts):
    """Run the online auction experiment and periodically plot IC checks.

    For ``rollouts`` iterations ``i``:

    * ``learner.update_auct()`` updates the hyper-network ``Learner`` built
      from ``args2``.
    * For ``i < 100`` (warm-up) the learner is evaluated on one batch with
      context ``[0, 1.0, alpha]`` and its revenue is printed.
    * From ``i >= 100`` it is evaluated twice on the same batch, whose values
      are drawn with ``generate_uniform2(0, calculate_t(i))``:
        - "online": the context uses ``calculate_t(i)`` and ``alpha``;
        - "offline": the context uses a fixed 1.5 and ``alpha2``.
      Revenue, cost, click and ``perc`` are logged. Each click weight is
      updated as ``alpha * exp(PID.update(perc))``.
    * ``seller_backward`` is called on ``net1``..``net5`` (values ~ U(0, 1.5))
      and ``net12``..``net52`` (values ~ U(0, calculate_t(i))). The last
      argument is 0.01, 0.1, 0.2, 0.5 and 1.0 respectively (presumably a
      click weight -- TODO confirm against Score_VCG).
    * Every 10 rounds with ``i > 99``, an IC test runs on a 50-sample batch.
      For VCG, GSP, ``net1``, ``net12`` and the learner (at ``alpha`` and at
      ``alpha2``), the value returned for bid factors 1, 0.99, 0.95, 0.9 and
      0.8 is plotted as a percentage of the bid=1 value, then saved.

    Returns None; all logged histories are local to this function.
    """
    # Metric histories. Only losspr2/losscost2/lossclick2/perc2 (online) and
    # losspr4/losscost4/lossclick4/perc4 (offline) are appended to below; the
    # other lists are never written after initialisation.
    losspr2 = [0]
    losscost2 = [0]
    lossclick2 = [0]
    losspr4 = [0]
    losscost4 = [0]
    lossclick4 = [0]
    lossmy = []
    lossfp = []
    losssp = []  
    lossbid1 = []
    lossbid2 = []
    losspr12 = [0]
    losscost12 = [0]
    lossclick12 = [0]
    losspr22 = [0]
    losscost22 = [0]
    lossclick22 = [0]
    losspr32 = [0]
    losscost32 = [0]
    lossclick32 = [0]
    losspr42 = [0]
    losscost42 = [0]
    lossclick42 = [0]
    losspr52 = [0]
    losscost52 = [0]
    lossclick52 = [0]
    losspr62 = [0]
    losscost62 = [0]
    lossclick62 = [0]
    losspr72 = [0]
    losscost72 = [0]
    lossclick72 = [0]
    losspr82 = [0]
    losscost82 = [0]
    lossclick82 = [0]
    losspr122 = [0]
    losscost122 = [0]
    lossclick122 = [0]
    losspr222 = [0]
    losscost222 = [0]
    lossclick222 = [0]
    losspr322 = [0]
    losscost322 = [0]
    lossclick322 = [0]
    losspr422 = [0]
    losscost422 = [0]
    lossclick422 = [0]
    losspr522 = [0]
    losscost522 = [0]
    lossclick522 = [0]
    losspr622 = [0]
    losscost622 = [0]
    lossclick622 = [0]
    losspr722 = [0]
    losscost722 = [0]
    lossclick722 = [0]
    losspr822 = [0]
    losscost822 = [0]
    lossclick822 = [0]    
    perc1 = [0]
    perc2 = [0]
    perc3 = [0]
    perc4 = [0]
    # generate_train feeds the Score_VCG baselines (sizes from args);
    # generate_train2 feeds the learner evaluation (sizes from args2).
    generate_train = Generator(args)
    generate_train2 = Generator(args2)
    learner = Learner(args2)
    # Click weights for the online (alpha) and offline (alpha2) evaluations,
    # each adapted by its own PID controller with target 0.5.
    alpha = 0.5
    alpha2 = 0.5
    PID1 = PIDController(0.05, 0.002, 1., 0.5)
    PID2 = PIDController(0.05, 0.002, 0.1, 0.5)
    # NOTE(review): `point` is never used.
    point = np.zeros([1,6])
    for i in range(rollouts):
    
            
        if i < 100 :
            # Warm-up: update the learner, print its revenue on a fixed context
            # (values in [0, 1.0], alpha = initial 0.5), train all baselines.
            learner.update_auct()
            low = 0.
            high = 0.5 + 0.5 * 1
            distribution = t.tensor([low, high, 0.5]).float().to(device)  
            train_data2 = generate_train2.generate_uniform(low, high)
            ctr_ads2 = generate_train2.generate_uniform(0, 1.)
            ctr_og2 = generate_train2.generate_uniform(0, 2.5)
            revenue, cost, click, perc = learner.auct_model(train_data2, ctr_ads2, ctr_og2, alpha, learner.w_net(distribution), learner.b_net(distribution))
            rev = revenue.cpu().detach().numpy()
            print(rev)
            
            # Fixed-range baselines: values ~ U(0, 1.5). Only l = 0..4 train a
            # net; iterations l = 5..7 still draw (and discard) a batch.
            for l in range(8):
                train_data = generate_train.generate_uniform(0,1.5)
                ctr_ads = generate_train.generate_uniform(0, 1.)
                ctr_og = generate_train.generate_uniform(0, 2.5)
                if l == 0:
                    net1.seller_backward(args,train_data,ctr_ads,ctr_og,0.01)
                if l == 1:
                    net2.seller_backward(args,train_data,ctr_ads,ctr_og,0.1)
                if l == 2:
                    net3.seller_backward(args,train_data,ctr_ads,ctr_og,0.2)
                if l == 3:
                    net4.seller_backward(args,train_data,ctr_ads,ctr_og,0.5)
                if l == 4:
                    net5.seller_backward(args,train_data,ctr_ads,ctr_og,1.0)

            # Time-varying baselines: values ~ U(0, calculate_t(i)), with a
            # freshly sampled upper bound for every l.
            for l in range(8):
                high = calculate_t(i)
                train_data = generate_train.generate_uniform(0,high)
                ctr_ads = generate_train.generate_uniform(0, 1.)
                ctr_og = generate_train.generate_uniform(0, 2.5)
                if l == 0:
                    net12.seller_backward(args,train_data,ctr_ads,ctr_og,0.01)
                if l == 1:
                    net22.seller_backward(args,train_data,ctr_ads,ctr_og,0.1)
                if l == 2:
                    net32.seller_backward(args,train_data,ctr_ads,ctr_og,0.2)
                if l == 3:
                    net42.seller_backward(args,train_data,ctr_ads,ctr_og,0.5)
                if l == 4:
                    net52.seller_backward(args,train_data,ctr_ads,ctr_og,1.0)

        
        else:
            learner.update_auct()
            # Online evaluation: the context carries this round's value range
            # and the adaptive click weight `alpha`.
            low = 0.
            high = calculate_t(i)
            Alpha = alpha
            distribution = t.tensor([low, high, Alpha]).float().to(device)        
            train_data2 = generate_train2.generate_uniform2(low, high)
            ctr_ads2 = generate_train2.generate_uniform(0, 1.)
            ctr_og2 = generate_train2.generate_uniform(0, 2.5)
            revenue, cost, click, perc = learner.auct_model(train_data2, ctr_ads2, ctr_og2, Alpha, learner.w_net(distribution), learner.b_net(distribution))
            #ut = utility.cpu().detach().numpy()
            rev = revenue.cpu().detach().numpy()
            cost = cost.cpu().detach().numpy()
            click = click.cpu().detach().numpy()
            #perc = perc.cpu().detach().numpy()

            #lossbid2.append(np.mean(ut))               
            losspr2.append(rev)  
            losscost2.append(cost)
            lossclick2.append(click)
            perc2.append(perc)
            # Multiplicative update of the click weight from the PID output on perc.
            alpha = alpha * np.exp(PID1.update(perc))

            # Offline evaluation: same batch, but the context assumes a fixed
            # value range of 1.5 and uses alpha2.
            high = 1.5
            Alpha = alpha2
            distribution = t.tensor([low, high, Alpha]).float().to(device)        
            revenue, cost, click, perc = learner.auct_model(train_data2, ctr_ads2, ctr_og2, Alpha, learner.w_net(distribution), learner.b_net(distribution))
            #ut = utility.cpu().detach().numpy()
            rev = revenue.cpu().detach().numpy()
            cost = cost.cpu().detach().numpy()
            click = click.cpu().detach().numpy()
            #perc = perc.cpu().detach().numpy()

            #lossbid4.append(np.mean(ut))               
            losspr4.append(rev)  
            losscost4.append(cost)
            lossclick4.append(click)
            perc4.append(perc)
            alpha2 = alpha2 * np.exp(PID2.update(perc))

            # Fixed-range baselines. NOTE(review): `i % 1 == 0` is always True.
            # net5 is trained at l == 5 here but at l == 4 during warm-up; both
            # fall inside range(8), so it still trains once per round, just on
            # a different random draw.
            for l in range(8):      
                train_data = generate_train.generate_uniform(0,1.5)
                ctr_ads = generate_train.generate_uniform(0, 1.)
                ctr_og = generate_train.generate_uniform(0, 2.5)
                if l == 0:
                    if i % 1 == 0:
                        net1.seller_backward(args,train_data,ctr_ads,ctr_og,0.01)                 
                if l == 1:
                    if i % 1 == 0:
                        net2.seller_backward(args,train_data,ctr_ads,ctr_og,0.1)
                if l == 2:
                    if i % 1 == 0:
                        net3.seller_backward(args,train_data,ctr_ads,ctr_og,0.2)                 
                if l == 3:
                    if i % 1 == 0:
                        net4.seller_backward(args,train_data,ctr_ads,ctr_og,0.5)              
                if l == 5:
                    if i % 1 == 0:
                        net5.seller_backward(args,train_data,ctr_ads,ctr_og,1.0)

            
            # Time-varying baselines: unlike warm-up, the value batch is drawn
            # once per round and shared by all five nets; CTRs are redrawn per l.
            # NOTE(review): net52 is trained at l == 5 (warm-up used l == 4).
            high = calculate_t(i)
            train_data = generate_train.generate_uniform(0,high)
            for l in range(8):      
                #train_data = generate_train.generate_uniform(0,1.)
                ctr_ads = generate_train.generate_uniform(0, 1.)
                ctr_og = generate_train.generate_uniform(0, 2.5)
                if l == 0:
                    if i % 1 == 0:
                        net12.seller_backward(args,train_data,ctr_ads,ctr_og,0.01)
               
                if l == 1:
                    if i % 1 == 0:
                        net22.seller_backward(args,train_data,ctr_ads,ctr_og,0.1)
                if l == 2:
                    if i % 1 == 0:
                        net32.seller_backward(args,train_data,ctr_ads,ctr_og,0.2)
                
                if l == 3:
                    if i % 1 == 0:
                        net42.seller_backward(args,train_data,ctr_ads,ctr_og,0.5)
            
                if l == 5:
                    if i % 1 == 0:
                        net52.seller_backward(args,train_data,ctr_ads,ctr_og,1.0)

        



        if i%2 == 0 :
            print('i=',i)
        # NOTE(review): this plot is shown non-blocking and closed immediately
        # without being saved, so it is effectively discarded.
        if i%100 == 0 and i>100 :    
        #if i%20 == 0 and i>200  :
            plt.plot(losspr2[-20:], label='revenue')
            plt.plot(perc2[-20:], label='percentage')
            plt.show(block=False)       
            plt.close()
            
        # IC test every 10 rounds after warm-up on a small 50-sample batch.
        # `beta` is passed as the fourth argument to the VCG/GSP/baseline
        # checks; the learner uses its adapted alpha / alpha2 instead.
        if i%10 == 0 and i>99 :
            beta = 0.1
            args0 = Args((4,1,"uniform",10,10,50,50,1)) 
            generate_train0 = Generator(args0)
            train_data = generate_train0.generate_uniform(0,1.5)
            ctr_ads = generate_train0.generate_uniform(0, 1.)
            ctr_og = generate_train0.generate_uniform(0, 2.5)    

            # Each *_IC array is the value at each bid factor as % of bid=1.
            v0 = alpha_VCG_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=1.)     
            v1 = alpha_VCG_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=0.99)  
            v2 = alpha_VCG_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=0.95)  
            v3 = alpha_VCG_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=0.9)  
            v4 = alpha_VCG_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=0.8)  
            VCG_IC = 100*np.array([v0,v1,v2,v3,v4])/v0
            v0 = alpha_GSP_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=1.)     
            v1 = alpha_GSP_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=0.99)  
            v2 = alpha_GSP_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=0.95)  
            v3 = alpha_GSP_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=0.9)  
            v4 = alpha_GSP_Mechanism_IC(train_data,ctr_ads,ctr_og,beta,bid=0.8)  
            GSP_IC = 100*np.array([v0,v1,v2,v3,v4])/v0
            v0 = net1.ic_check(train_data,ctr_ads,ctr_og,beta,bid=1.)  
            v1 = net1.ic_check(train_data,ctr_ads,ctr_og,beta,bid=0.99)  
            v2 = net1.ic_check(train_data,ctr_ads,ctr_og,beta,bid=0.95)  
            v3 = net1.ic_check(train_data,ctr_ads,ctr_og,beta,bid=0.9)  
            v4 = net1.ic_check(train_data,ctr_ads,ctr_og,beta,bid=0.8)  
            SW_VCG_Static_IC = 100*np.array([v0,v1,v2,v3,v4])/v0
            v0 = net12.ic_check(train_data,ctr_ads,ctr_og,beta,bid=1.)  
            v1 = net12.ic_check(train_data,ctr_ads,ctr_og,beta,bid=0.99)  
            v2 = net12.ic_check(train_data,ctr_ads,ctr_og,beta,bid=0.95)  
            v3 = net12.ic_check(train_data,ctr_ads,ctr_og,beta,bid=0.9)  
            v4 = net12.ic_check(train_data,ctr_ads,ctr_og,beta,bid=0.8)  
            SW_VCG_Dynamic_IC = 100*np.array([v0,v1,v2,v3,v4])/v0
            # Learner with the online click weight (plotted as AMMD_Online).
            distribution = t.tensor([0, 1.5, alpha]).float().to(device)  
            v0 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha, learner.w_net(distribution), learner.b_net(distribution),bid=1.)
            v1 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha, learner.w_net(distribution), learner.b_net(distribution),bid=0.99)
            v2 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha, learner.w_net(distribution), learner.b_net(distribution),bid=0.95)
            v3 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha, learner.w_net(distribution), learner.b_net(distribution),bid=0.9)
            v4 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha, learner.w_net(distribution), learner.b_net(distribution),bid=0.8)
            SW_VCG_Online_IC = 100*np.array([v0,v1,v2,v3,v4])/v0
            # Learner with the offline click weight (plotted as AMMD_Offline).
            distribution = t.tensor([0, 1.5, alpha2]).float().to(device)  
            v0 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha2, learner.w_net(distribution), learner.b_net(distribution),bid=1.)
            v1 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha2, learner.w_net(distribution), learner.b_net(distribution),bid=0.99)
            v2 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha2, learner.w_net(distribution), learner.b_net(distribution),bid=0.95)
            v3 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha2, learner.w_net(distribution), learner.b_net(distribution)  ,bid=0.9)
            v4 = learner.auct_model.ic_check(train_data,ctr_ads,ctr_og,alpha2, learner.w_net(distribution), learner.b_net(distribution)  ,bid=0.8)
            SW_VCG_Offline_IC = 100*np.array([v0,v1,v2,v3,v4])/v0
            x = np.array([1.,0.99,0.95,0.9,0.8])
            plt.figure()
            plt.plot(x,VCG_IC,label='VCG')
            plt.plot(x,GSP_IC,label='GSP')
            plt.plot(x,SW_VCG_Static_IC,label='SW_VCG_Offline')
            plt.plot(x,SW_VCG_Dynamic_IC,label='SW_VCG_Online')
            plt.plot(x,SW_VCG_Online_IC,label='AMMD_Online')
            plt.plot(x,SW_VCG_Offline_IC,label='AMMD_Offline')
            plt.xlabel('Bid_ratio')
            plt.ylabel('Utility_ratio (%)')
            plt.title('IC test of Different Mechanisms')
            plt.legend()
            save_plot()
            plt.show(block=False)
            plt.close()
            


if __name__ == "__main__":
    # Args fields: (num_agent, num_item, distribution_type, num_linear,
    # num_max, num_sample_train, num_sample_test, seed_val).
    args = Args((4, 1, "uniform", 10, 10, 2000, 2000, 1))
    args2 = Args((4, 1, "uniform", 10, 10, 2000, 2000, 1))
    # Fixed-range baselines (net1..net5) and time-varying-range baselines
    # (net12..net52), constructed sequentially in exactly this order.
    net1, net2, net3, net4, net5 = (Score_VCG(args) for _ in range(5))
    net12, net22, net32, net42, net52 = (Score_VCG(args) for _ in range(5))
    train_linear(
        args, args2,
        net1, net2, net3, net4, net5,
        net12, net22, net32, net42, net52,
        rollouts=100001,
    )