import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import mean_squared_error, accuracy_score
from torch.utils.data import DataLoader

from compas_loader import *
from networks import net
from network_MLP import net2

# Command-line options. They are parsed once, at import time, into the
# module-level ``args`` namespace that the ``__main__`` block hands to train().
parser = argparse.ArgumentParser(description='Constrained_learning')
parser.add_argument('--batch_size', type=int, default=256,
                    help='batch_size')
parser.add_argument('--lr', type=float, default=5e-4,
                    help='learning rate')
parser.add_argument('--epochs', type=int, default=1000,
                    help='epoch')

args = parser.parse_args()

def compute_dual_loss(in_list, out_list, mu_lower, t_lower):
    """Compute the per-sample Lagrangian penalty for lower-bounded input gradients.

    For every pair ``(xx, yy)`` the gradient ``g = d sum(yy[:, 0]) / d xx`` is
    taken with ``create_graph=True``, so the returned penalty stays
    differentiable w.r.t. whatever produced ``yy``. The per-sample penalty is
    ``sum_k mu_lower[k] * max(t_lower[k] - g[:, k], 0)``, summed over all pairs.

    Args:
        in_list: list of input tensors that ``out_list`` was computed from (they
            must require grad); presumably shape ``(num, len(mu_lower))`` as
            produced by ``model.reg_forward`` -- TODO confirm.
        out_list: list of output tensors aligned with ``in_list``; only column 0
            is differentiated.
        mu_lower: array-like of dual variables, one per constrained feature.
        t_lower: array-like of gradient lower bounds, same length as ``mu_lower``.

    Returns:
        ``(lagrangian_loss, grad_mat)``. ``lagrangian_loss`` has shape ``(num,)``
        (a 0-d zero tensor when ``in_list`` is empty). ``grad_mat`` is the input
        gradient of the LAST pair only (a ``(128, len(mu_lower))`` zero tensor
        when ``in_list`` is empty).
    """
    # torch.tensor always copies, so later in-place updates of the caller's
    # numpy arrays can never alias tensors saved in the autograd graph.
    t_lower = torch.tensor(np.asarray(t_lower), dtype=torch.float32)
    mu_lower = torch.tensor(np.asarray(mu_lower), dtype=torch.float32)

    # A tensor (not the int 0) so callers can always call .mean() on it.
    lagrangian_loss = torch.zeros(())
    grad_mat = torch.zeros([128, mu_lower.shape[0]], dtype=torch.float32)
    for i, xx in enumerate(in_list):
        yy = out_list[i]
        grad_input = torch.autograd.grad(
            torch.sum(yy[:, 0]), xx, create_graph=True, allow_unused=True
        )[0]

        # Record the gradient for the dual update (only the last pair is kept).
        grad_mat = grad_input

        # Hinge on the constraint g >= t_lower. clamp(min=0) passes gradient
        # through wherever t_lower - g >= 0 and blocks it where it is negative.
        violation = torch.clamp(t_lower - grad_input, min=0)
        # mu_lower has shape (k,) and broadcasts over the sample dimension, so
        # any number of sampled points works (no hard-coded batch of 128).
        lagrangian_loss = lagrangian_loss + torch.sum(violation * mu_lower, dim=1)
    return lagrangian_loss, grad_mat

def _dual_update(mu_lower, t_lower, grad_mat, alpha, lr_mu):
    """Apply one projected gradient-ascent step to the dual variables, in place.

    Args:
        mu_lower: 1-D numpy array of dual variables; overwritten in place and
            clipped at zero.
        t_lower: 1-D numpy array of gradient lower bounds, same length as
            ``mu_lower``.
        grad_mat: numpy array of input gradients, shape
            ``(num_samples, len(mu_lower))``.
        alpha: allowed violation level; ``alpha * t_lower`` is subtracted from
            the expected violation.
        lr_mu: dual step size.
    """
    # Mean hinge violation max(t - g, 0) per constrained feature over the
    # sampled points, minus the allowed slack. Broadcasting replaces a
    # fixed-length zeros list, so any number of samples works.
    grad_mu = np.maximum(t_lower - grad_mat, 0).mean(axis=0) - alpha * t_lower
    # Ascent step followed by projection onto mu >= 0.
    mu_lower[:] = np.maximum(mu_lower + lr_mu * grad_mu, 0)


def train(args):
    """Train the COMPAS classifier with a primal-dual Lagrangian penalty.

    Each batch minimises the BCE loss plus the Lagrangian penalty from
    ``compute_dual_loss`` on points sampled by ``model.reg_forward``. It then
    updates the dual variables ``mu_lower`` by projected gradient ascent. After
    every epoch the model is evaluated on the full test set, and the model
    whenever test accuracy reaches a new best (ties included) is saved to
    ``./model.pth``.

    Args:
        args: namespace providing ``batch_size``, ``lr`` and ``epochs``.

    Returns:
        list of numpy arrays: a copy of ``mu_lower`` after every training batch.
    """
    num_mono = 4   # number of constrained (monotone) features, indices 0-3
    reg_num = 128  # points sampled per step by reg_forward; compute_dual_loss
                   # may assume 128, so keep the two in sync

    train_dataset = COMPASdataLoader(csv_path='train.csv')
    test_dataset = COMPASdataLoader(csv_path='test.csv')
    print("Training dataset size:", len(train_dataset))
    print("Testing dataset size:", len(test_dataset))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, drop_last=False)
    # The whole test set is evaluated as one batch.
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset),
                             shuffle=True, drop_last=False)

    model = net(input_size=13, mono_size=num_mono,
                mono_feature=np.asarray([0, 1, 2, 3]))
    print('The total param amount:', sum(p.numel() for p in model.parameters()))

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=float(args.lr))

    # Auxiliary (bound) and dual variables.
    mu_lower = np.zeros((num_mono,), dtype=np.float32)
    t_lower = np.ones((num_mono,), dtype=np.float32) * 0.0001
    alpha = 0.1  # the violation probability
    lr_mu = 5
    mu_lower_rec = []

    total_batch = len(train_loader)
    max_acc = 0.000001
    acc = 0.0        # defined even if the test loader yields nothing
    grad_mat = None  # last training-batch gradients, for the final summary

    for epoch in range(args.epochs):
        train_loss = 0.0
        model.train()
        for inputs, targets in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            in_list, out_list = model.reg_forward(feature_num=13, num=reg_num)
            lagrangian_loss, grad_mat = compute_dual_loss(in_list, out_list,
                                                          mu_lower, t_lower)
            grad_mat = grad_mat.detach().cpu().numpy()
            total_loss = loss + lagrangian_loss.mean()

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Dual update, then record the dual variables.
            _dual_update(mu_lower, t_lower, grad_mat, alpha, lr_mu)
            mu_lower_rec.append(mu_lower.copy())

        train_loss = train_loss / max(total_batch, 1)

        model.eval()
        with torch.no_grad():
            for x, y in test_loader:
                logits = model(x)
                true_y = y.detach().numpy()
                pred_y_zero_one = np.where(logits.detach().numpy() > 0, 1, 0)

                acc = accuracy_score(true_y, pred_y_zero_one)
                if acc >= max_acc:
                    torch.save(model, "./model.pth")
                max_acc = max(max_acc, acc)

        if (epoch + 1) % 10 == 0:
            print('Epoch:{}, Training Loss: {}, Testing acc:{}, Test max acc: {}'.format(
                epoch + 1, train_loss, acc, max_acc))

    # Sum of the negative gradients: on fresh samples from the final model,
    # and on the last training batch (NaN if no batch was ever run).
    in_list, out_list = model.reg_forward(feature_num=13, num=reg_num)
    _, grad_mat1 = compute_dual_loss(in_list, out_list, mu_lower, t_lower)
    neg_sum_final = torch.clamp(grad_mat1.detach(), max=0).sum().item()
    neg_sum_train = (float(np.minimum(grad_mat, 0).sum())
                     if grad_mat is not None else float('nan'))
    print('The sum of negative gradient: {} and in training: {}'.format(
        neg_sum_final, neg_sum_train))

    return mu_lower_rec

if __name__ == "__main__":
    mu_lower_rec = train(args)

    ## Plot the dual variables: one curve per constrained feature, one point
    ## per training batch (train() appends a snapshot after every batch).
    plt.figure(1)
    mu_lower_rec = np.asarray(mu_lower_rec)
    # With zero epochs the record is empty (1-D, shape (0,)); skip the curves
    # instead of crashing on shape[1].
    if mu_lower_rec.ndim == 2:
        steps = np.arange(mu_lower_rec.shape[0])
        for i in range(mu_lower_rec.shape[1]):
            plt.plot(steps, mu_lower_rec[:, i], label='mu_lower[{}]'.format(i))
        plt.legend()
    plt.xlabel('training iteration (batch)')
    plt.ylabel('dual variables')
    plt.grid(True)
    plt.show()
