import os
import argparse
import torch
import time
from torch.utils.data import DataLoader,random_split
from networks import net
from network_MLP import net2
from utils import *
from sklearn.metrics import mean_squared_error,accuracy_score
import matplotlib.pyplot as plt


# Command-line hyper-parameters for this constrained-learning training script.
# NOTE: parsing happens at import time, so importing this module reads sys.argv.
parser = argparse.ArgumentParser(description='Constrained_learning')
parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--epochs', type=int, default=3000, help='epoch')

args = parser.parse_args()

def compute_dual_loss(in_list, out_list, mu_lower, t_lower):
    """Compute the Lagrangian penalty for a lower bound on input gradients.

    For every pair ``(xx, yy)`` taken from ``zip(in_list, out_list)``, the
    gradient of ``sum(yy[:, 0])`` with respect to ``xx`` is computed with
    ``create_graph=True``, so the returned penalty can itself be
    back-propagated. Gradient entries below ``t_lower`` are penalised by
    ``mu_lower * (t_lower - grad)``. Entries at or above the bound contribute
    nothing.

    Args:
        in_list: tensors that the outputs were computed from. Each must be part
            of the autograd graph. They are indexed as 2-D
            ``(num_points, num_constraints)`` by the caller; TODO confirm
            against ``reg_forward``.
        out_list: outputs paired element-wise with ``in_list``. Only column 0
            is differentiated.
        mu_lower: array-like of dual variables, one per constraint column.
        t_lower: array-like of lower bounds, one per constraint column.

    Returns:
        A tuple ``(Lagrangian_loss, grad_mat)``:
        - ``Lagrangian_loss``: the per-point penalty summed over all pairs,
          shape ``(num_points,)``. It is a 0-d zero tensor if ``in_list`` is
          empty.
        - ``grad_mat``: the input gradient of the *last* pair, still attached
          to the graph. It is an empty ``(0, num_constraints)`` tensor if
          ``in_list`` is empty.
    """
    # torch.tensor copies the data, so in-place updates the caller makes to the
    # numpy arrays afterwards cannot alias into this graph.
    t_lower_t = torch.tensor(t_lower, dtype=torch.float32)
    mu_lower_t = torch.tensor(mu_lower, dtype=torch.float32)

    lagrangian_loss = torch.zeros(())
    grad_mat = torch.zeros((0, t_lower_t.numel()), dtype=torch.float32)
    for xx, yy in zip(in_list, out_list):
        grad_input = torch.autograd.grad(
            torch.sum(yy[:, 0]), xx, create_graph=True, allow_unused=True
        )[0]
        if grad_input is None:
            # allow_unused=True yields None when yy does not depend on xx;
            # the gradient is then identically zero.
            grad_input = torch.zeros_like(xx)

        # Record the gradient for the dual update performed by the caller.
        grad_mat = grad_input

        # Broadcast the bounds/duals over the rows instead of tiling them to a
        # fixed batch size, so any number of regularisation points works.
        # clamp(min=0) keeps only violations (grad below the bound).
        indicator = torch.clamp(t_lower_t.to(grad_input) - grad_input, min=0)
        loss_dim = torch.sum(indicator * mu_lower_t.to(grad_input), dim=1)
        lagrangian_loss = lagrangian_loss + loss_dim
    return lagrangian_loss, grad_mat




def train(args):
    """Train ``net`` on ``synthetic_111.csv`` while tracking a dual variable.

    The data set is split 70/30 into train/test subsets. For every training
    batch:

    * the MSE loss between ``model(inputs)`` and ``targets`` is minimised
      with Adam;
    * ``model.reg_forward`` provides regularisation inputs/outputs whose
      input gradients ``compute_dual_loss`` compares against ``t_lower``;
    * the dual variable ``mu_lower`` takes a projected (clipped at 0)
      gradient-ascent step.

    The test loss is printed every 10 epochs. After training, the model is
    saved to ``model_pth/model.pth`` (the directory is created if needed).
    ``contour_2d`` and ``contour_2d_original`` are then called; presumably
    these plot the learned surface (see utils). Finally, the sum of negative
    input gradients is reported.

    Args:
        args: parsed command-line namespace. Reads ``batch_size``, ``lr`` and
            ``epochs``.

    Returns:
        list of np.ndarray: a copy of ``mu_lower`` recorded after every batch.
    """
    data_path = 'synthetic_111.csv'
    dataset = CustomDataset(csv_path=data_path)
    dataset_size = len(dataset)
    train_size = int(dataset_size * 0.7)
    test_size = dataset_size - train_size

    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)
    # The whole test split is evaluated as one batch.
    test_loader = DataLoader(test_dataset, batch_size=test_size, shuffle=True, drop_last=False)

    model = net(input_size=2, mono_size=1, mono_feature=np.asarray([1]))

    param_amount = sum(p.numel() for p in model.parameters())
    print('The total param amount:', param_amount)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=float(args.lr))

    num_epochs = args.epochs
    total_batch = len(train_loader)

    ## Set the auxiliary variable and dual variables
    mu_lower = np.zeros((1,), dtype=np.float32)         # dual variable(s), kept >= 0
    t_lower = np.ones((1,), dtype=np.float32) * 0.0001  # lower bound on the gradient
    alpha = 0.1  # The violation probability
    lr_mu = 5    # step size of the dual ascent
    mu_lower_rec = []
    grad_mat = None  # last training-batch gradient, reused in the final report

    start = time.time()
    for epoch in range(num_epochs):
        train_loss = 0.0
        model.train()
        for inputs, targets in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            in_list, out_list = model.reg_forward(feature_num=2, num=128)
            Lagrangian_loss, grad_mat = compute_dual_loss(in_list, out_list, mu_lower, t_lower)
            Lagrangian_loss = Lagrangian_loss.mean()
            grad_mat = grad_mat.detach().cpu().numpy()
            # NOTE(review): the Lagrangian penalty is computed but NOT added to
            # the primal objective, so mu_lower does not influence the model
            # update. Confirm this is intentional (e.g. an unconstrained baseline).
            total_loss = loss  # + Lagrangian_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            train_loss += loss.item()

            ## The dual update: mean constraint violation minus the allowed
            ## violation budget alpha * t, then projection onto mu >= 0.
            for j in range(mu_lower.shape[0]):
                violation = np.maximum(t_lower[j] - grad_mat[:, j], 0.0)
                grad_mu_lower = np.mean(violation) - alpha * t_lower[j]
                mu_lower[j] = max(mu_lower[j] + lr_mu * grad_mu_lower, 0)

            # record the dual variables
            mu_lower_rec.append(mu_lower.copy())

        # Guard against an empty training split.
        train_loss = train_loss / max(total_batch, 1)

        if (epoch + 1) % 10 == 0:
            model.eval()
            with torch.no_grad():
                for x, y in test_loader:
                    pred_y = model(x)
                    test_loss = criterion(pred_y, y)
            elapsed = time.time() - start

            print('Epoch:[{}/{}], Training Loss: {:.5f}, Testing loss:{:.5f}, Elapsed time (s): {:.4f}'.format(epoch+1,num_epochs,train_loss,test_loss.item(),elapsed))

    # Create the output directory first; otherwise torch.save raises
    # FileNotFoundError only after the whole training run has finished.
    os.makedirs("model_pth", exist_ok=True)
    torch.save(model, "model_pth/model.pth")
    contour_2d(data=dataset.inp, feature_idx=np.asarray([0, 1]), model=model)
    contour_2d_original(data=dataset.inp, feature_idx=np.asarray([0, 1]), model=model)

    # summation of negative gradients: after training vs. the last training batch
    in_list, out_list = model.reg_forward(feature_num=2, num=128)
    _, grad_mat1 = compute_dual_loss(in_list, out_list, mu_lower, t_lower)
    # Detach before the in-place edit so the autograd graph is not modified.
    grad_mat1 = grad_mat1.detach()
    grad_mat1[grad_mat1 > 0] = 0
    train_neg_sum = np.sum(np.minimum(grad_mat, 0)) if grad_mat is not None else float('nan')
    print('The sum of negative gradient: {} and in training: {}'.format(torch.sum(grad_mat1), train_neg_sum))

    return mu_lower_rec






if __name__ == "__main__":
    # Seeding is disabled; uncomment for reproducible runs.
    # seed = 15
    # torch.manual_seed(seed)

    dual_history = train(args)

    # Plot the trajectory of every dual variable, one line per variable,
    # against the index of the recorded update.
    plt.figure(1)
    dual_history = np.asarray(dual_history)
    steps = list(range(dual_history.shape[0]))
    num_duals = dual_history.shape[1]
    for col in range(num_duals):
        plt.plot(steps, dual_history[:, col])
    plt.xlabel('training episode')
    plt.ylabel('dual variables')
    plt.grid(True)
    plt.show()
