import torch
torch.set_default_dtype(torch.float64)
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import operator
from functools import reduce
from sklearn.preprocessing import normalize
import numpy as np
import time, argparse
from pprint import pprint
from utils import set_seed, dict_agg
import default_args
import warnings
from dataset import Dataset as D
from helper_qp import NNPrimalSolver, NNDualSolver, load_qp_data
import json
# from helper_portfolio import NNPrimalPortfolioSolver, NNDualPortfolioSolver, load_portfolio_data
from helper_new_portfolio import NNPrimalSolver, NNDualSolver, load_portfolio_data, load_portfolio_dyn_data
from pathlib import Path
import pandas as pd


# Module-level configuration: script directory, compute device, warning filters.

# Directory containing this script; results/ is created underneath it.
CURRENT_PATH = Path(__file__).absolute().parent
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE", DEVICE, flush=True)
# NOTE(review): this silences *all* warnings process-wide (torch/numpy included).
warnings.filterwarnings("ignore")

def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so boolean options are parsed with this helper.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Command-line entry point: parse arguments, train PDL and record results.

    Unset CLI options fall back to ``default_args.pdl_default_args``. The
    dynamic portfolio data is loaded, ``train_net`` trains the primal/dual
    networks, one summary row is appended to ``pdl_dynamic_results.csv`` and,
    when ``--save`` is true, a checkpoint is written under
    ``results/<probtype>/``.
    """
    parser = argparse.ArgumentParser(description='Primal Dual Learning')
    parser.add_argument('--seed', type=int, default=1001, help='random seed')
    parser.add_argument('--probtype', type=str, default='portfolio', choices=['portfolio'], help='problem types')
    # QP cases specific parameters
    parser.add_argument('--nvar', type=int, help='the number of decision variables')
    parser.add_argument('--nineq', type=int, help='the number of inequality constraints')
    parser.add_argument('--neq', type=int, help='the number of equality constraints')
    parser.add_argument('--nex', type=int, help='total number of data instances')
    parser.add_argument('--LSTMmodelindex', type=int, default=29,
                        help='index (file-name prefix) of the LSTM-predicted asset price files used as network input')
    parser.add_argument('--lambda', type=float, default=2.,
                        help='regularization term amplitude of the opt. prob obj. function')
    # Related to training & neural nets
    parser.add_argument('--batchsize', type=int, help='training batch size')
    parser.add_argument('--nworkers', type=int, default=0, help='the number of workers for dataloader')
    parser.add_argument('--lr', type=float, help='neural network learning rate')
    parser.add_argument('--hiddensize', type=int, default=500, help='hidden layer size for neural network (used for QP and QCQP cases)')
    parser.add_argument('--hiddenfrac', type=float, default=None, help='hidden layer node fraction (only used for ACOPF)')
    parser.add_argument('--nlayer', type=int, help='the number of layers')
    parser.add_argument('--use_sigmoid', type=str2bool, help='whether to apply a sigmoid to the last layer')
    parser.add_argument('--normalize', type=str2bool, default=False, help='whether to apply a normalization to the neural network input')
    # PDL specific hyperparameters
    parser.add_argument('--rho', type=float, help='initial coefficient of the penalty terms')
    parser.add_argument('--rhomax', type=float, help='maximum rho limit')
    parser.add_argument('--tau', type=float, help='parameter for updating rho')
    parser.add_argument('--alpha', type=float, help='updating rate of rho')
    parser.add_argument('--objscaler', type=float, default=None, help='scaling objective value (only used for ACOPF)')
    parser.add_argument('--maxouteriter', type=int, help='maximum outer iterations')
    parser.add_argument('--epochs', type=int, help='maximum (inner) epochs')
    parser.add_argument('--maxinneriter', type=int, help='maximum inner iterations')
    parser.add_argument('--index', type=int, help='index to keep track of different runs')
    parser.add_argument('--save', type=str2bool, default=False, help='whether to save statistics')
    parser.add_argument('--useLSTM', type=str2bool, default=True,
                        help='whether to use the LSTM-predicted asset price files (see --LSTMmodelindex) as network input')

    args = vars(parser.parse_args())  # to dictionary
    args_default = default_args.pdl_default_args(args['probtype'])

    # Fill every option the user left unset; .get() tolerates default keys
    # that have no matching CLI option.
    for k, v in args_default.items():
        if args.get(k) is None:
            args[k] = v
    pprint(args)

    set_seed(args['seed'], DEVICE)

    # load data
    if 'portfolio' in args['probtype']:
        data, args = load_portfolio_dyn_data(args, CURRENT_PATH, DEVICE)
    else:
        raise NotImplementedError
    print("Loading Data Done Successfully:", str(data))

    t0 = time.time()
    out, pnet, dnet, best_results, best_results_batch = train_net(data, args)
    print("Train is done, Elapsed Time %.2fs" % (time.time() - t0), flush=True)

    # train_net reports [mean opt. gap (%), max eq. viol., mean eq. viol.] of
    # its best outer iteration.
    opt_gap, eq_max, eq_mean = best_results[:3]
    opt_gap_b, eq_max_b, eq_mean_b = best_results_batch[:3]
    # Named `results` (not `data`) so the problem data is still available
    # for naming the checkpoint below.
    results = {
        'Seed': [args['seed']],
        'Losstype': [args.get('losstype')],
        'Max eq. viol': [eq_max],
        'Mean eq. viol': [eq_mean],
        'Mean opt.gap': [opt_gap],
        'Max eq. viol batch': [eq_max_b],
        'Mean eq. viol batch': [eq_mean_b],
        'Mean opt.gap per batch': [opt_gap_b],
    }

    df = pd.DataFrame(results)
    results_csv = Path('pdl_dynamic_results.csv')
    # Append one row per run; write the header only when the file is created
    # so repeated runs do not interleave header lines with data rows.
    df.to_csv(results_csv, mode='a', header=not results_csv.exists(), index=False)
    print(df)

    if args['save']:
        save_dir = CURRENT_PATH / "results" / args['probtype']
        save_dir.mkdir(exist_ok=True, parents=True)
        save_name = "PDL_%s_s%d.chpt" % (str(data), args['seed'])
        save_dict = {
            'pnet': pnet.to('cpu').state_dict(),
            'dnet': dnet.to('cpu').state_dict(),
            'args': args,
            'out': out,
        }
        torch.save(save_dict, save_dir / save_name)

    return None

def feature_Generation(feature_Net, train_data, valid_data, test_data):
    """Apply ``feature_Net`` to the train, validation and test inputs.

    Each input is moved to DEVICE first. The outputs are returned as a
    3-tuple in (train, valid, test) order.
    """
    generated = []
    for split_data in (train_data, valid_data, test_data):
        generated.append(feature_Net(split_data.to(DEVICE)))
    return tuple(generated)


def train_net(data, args):
    """Run the outer Primal-Dual Learning loop for the portfolio problem.

    Each outer iteration k trains the primal net (``update_pnet_epoch``).
    Except on the last iteration, it then measures constraint violations to
    update rho (``update_rho_epoch``, capped at ``args['rhomax']``) and trains
    the dual net (``update_dnet_epoch``).

    Returns:
        out: test summary dict from the final primal update.
        pnet, dnet: the trained primal and dual networks.
        best_results: ``[opt_gap_mean, eq_max, eq_mean]`` of the outer
            iteration with the lowest mean test optimality gap (taken from
            the last iteration if no gap compared as smaller, e.g. NaN).
        best_results: returned a second time to match the caller's
            five-way unpacking.

    Raises:
        NotImplementedError: if ``args['probtype']`` is not a portfolio problem.
        ValueError: if ``args['maxouteriter']`` is less than 1.
    """
    global min_opt_gap

    if args['probtype'] not in ('portfolio', 'predopt_portfolio'):
        raise NotImplementedError
    if args['maxouteriter'] < 1:
        raise ValueError("maxouteriter must be >= 1, got %r" % (args['maxouteriter'],))

    train_dataset = PDLDataSet(data.trainX)
    valid_dataset = PDLDataSet(data.validX)
    test_dataset = PDLDataSet(data.testX)

    # shuffle must stay False: the update/eval helpers slice the precomputed
    # feature arrays by batch index, so loader order has to match row order.
    train_loader = DataLoader(train_dataset, batch_size=args['batchsize'], shuffle=False, num_workers=args['nworkers'])
    valid_loader = DataLoader(valid_dataset, batch_size=len(valid_dataset), shuffle=False, num_workers=args['nworkers'])
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=args['nworkers'])

    pnet = NNPrimalSolver(data, args).to(DEVICE)
    pnet.train()

    # dnet_k holds a frozen copy of the previous dual net (see update_dnet_epoch).
    dnet = NNDualSolver(data, args).to(DEVICE)
    dnet_k = NNDualSolver(data, args).to(DEVICE)
    dnet.train()
    dnet_k.eval()

    viol_train, viol_valid = None, None
    rho = args['rho']

    # Placeholders: the feature-generator and parameter-regressor networks are disabled.
    feature_Generator = 0
    parameter_regressor_Net = 0

    min_opt_gap = float('inf')
    best_results = None
    out = None

    for k in range(args['maxouteriter']):
        print("Outer iter: %3d" % (k), flush=True)
        out = update_pnet_epoch(pnet, dnet, feature_Generator, parameter_regressor_Net,
                                train_loader, valid_loader, test_loader, data, rho, args)
        if out['opt_gap_mean'] < min_opt_gap:
            min_opt_gap = out['opt_gap_mean']
            best_results = [min_opt_gap, out['eq_max'], out['eq_mean']]
        if k == args['maxouteriter'] - 1:
            break
        viol_train, viol_valid, rho_updated = update_rho_epoch(
            pnet, dnet, feature_Generator, k, train_loader, valid_loader, data, rho, args,
            prev_viol_train=viol_train, prev_viol_valid=viol_valid)
        update_dnet_epoch(pnet, dnet, dnet_k, feature_Generator, train_loader, valid_loader, data, rho, args)
        print("*** update rho:", rho, min(rho_updated, args['rhomax']), flush=True)
        rho = min(rho_updated, args['rhomax'])

    if best_results is None:
        # No iteration produced a comparable gap (e.g. NaN); report the last one.
        best_results = [out['opt_gap_mean'], out['eq_max'], out['eq_mean']]

    return out, pnet, dnet, best_results, best_results


def update_pnet_epoch(pnet, dnet, feature_Generator, parameter_regressor_Net, train_loader, valid_loader, test_loader, data, rho, args):
    """Train the primal network for ``args['epochs']`` epochs with the dual net fixed.

    Each training batch ``x`` (problem data from ``train_loader``) is paired
    with the same rows of a precomputed asset-price prediction array loaded
    from ``portfolio_data/``. Those rows, not ``x``, are fed to the networks.
    The per-instance primal loss (``total_loss_primal``) is minimised with
    Adam. The learning rate is multiplied by 0.99 after every epoch whose
    validation loss does not improve by more than 1e-3. After training, the
    net is evaluated on the train and test sets and summaries are printed.

    Returns:
        dict of test-set means: 'obj', 'eq_max', 'ineq_max', 'eq_mean',
        'ineq_mean', plus 'opt_gap_max' / 'opt_gap_mean' in percent.
    """
    popt = optim.Adam(pnet.parameters(), lr=args['lr'])
    scheduler = torch.optim.lr_scheduler.StepLR(popt, step_size=1, gamma=0.99)
    best_valid_loss = float("inf")
    batch_size = args['batchsize']

    if args['useLSTM']:
        index = args['LSTMmodelindex']
        ntrain = 1000
        init_X_train = np.load(f'portfolio_data/{index}_pred_asset_prices_training_{ntrain}.npy')
    else:
        init_X_train = np.load(f'portfolio_data/pred_asset_prices_training_{int(.8*args["nex"])}.npy')

    def to_features(arr):
        # numpy rows -> float64 tensor on DEVICE (same device as the nets),
        # optionally L2-normalised along dim 1.
        feat = torch.from_numpy(arr).double().to(DEVICE)
        if args['normalize']:
            feat = torch.nn.functional.normalize(feat)
        return feat

    epoch_stats = {}  # stays defined even when args['epochs'] == 0
    for i in range(args['epochs']):
        t0 = time.time()
        epoch_stats = {}
        train_ploss_ = 0.
        pnet.train()
        dnet.eval()
        for j, x in enumerate(train_loader):
            popt.zero_grad()
            # Rows aligned with this batch; relies on the loader being unshuffled.
            x_feat = to_features(init_X_train[j*batch_size:(j+1)*batch_size, :])
            y = pnet(x_feat)
            # total_loss_primal detaches the multipliers, so no graph is needed here.
            with torch.no_grad():
                lamda, mu = dnet(x_feat)
            if mu is not None:
                mu = torch.clamp(mu, min=0.)
            train_ploss = total_loss_primal(data, x, y, lamda, mu, rho, args)
            train_ploss.mean().backward()
            popt.step()
            train_ploss_ += train_ploss.mean().item()

        train_ploss_ /= len(train_loader)
        t1 = time.time()
        pnet.eval()
        dnet.eval()
        for x in valid_loader:
            epoch_stats = eval_net(data, x, pnet, dnet, feature_Generator, parameter_regressor_Net, rho, args, 'valid', epoch_stats, -1)

        # Plateau-style decay: only step the scheduler when validation stalls.
        if epoch_stats['valid_primal_loss'].mean() < best_valid_loss - 1e-3:
            best_valid_loss = epoch_stats['valid_primal_loss'].mean()
        else:
            scheduler.step()

        if i % 10 == 0:
            print("P Epoch:%05d | loss:%.4f | time:%.4fs" % (i, train_ploss_, t1 - t0), flush=True)
            print("        valid | loss:%.4f | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f" % (
                np.mean(epoch_stats['valid_primal_loss']), np.mean(epoch_stats['valid_eval']),
                np.mean(epoch_stats['valid_ineq_max']), np.mean(epoch_stats['valid_ineq_mean']),
                np.mean(epoch_stats['valid_eq_max']), np.mean(epoch_stats['valid_eq_mean'])
            ), flush=True)

    for j, x in enumerate(train_loader):
        epoch_stats = eval_net(data, x, pnet, dnet, feature_Generator, parameter_regressor_Net, rho, args, 'train', epoch_stats, j)

    for x in test_loader:
        epoch_stats = eval_net(data, x, pnet, dnet, feature_Generator, parameter_regressor_Net, rho, args, 'test', epoch_stats, -1)
        epoch_stats = eval_net(data, x, pnet, dnet, feature_Generator, parameter_regressor_Net, rho, args, 'test_gt', epoch_stats, -1)

    test_gaps = epoch_stats['test_opt_gap_mean']
    print("Quantile of opt. gap", np.quantile(test_gaps, [0.25, 0.5, 0.75, 1]))

    print("        train |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f" % (
        np.mean(epoch_stats['train_eval']),
        np.mean(epoch_stats['train_ineq_max']), np.mean(epoch_stats['train_ineq_mean']),
        np.mean(epoch_stats['train_eq_max']), np.mean(epoch_stats['train_eq_mean'])
    ), flush=True)

    print("         test |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f | optgap max:%.4f mean:%.4f" % (
        np.mean(epoch_stats['test_eval']),
        np.mean(epoch_stats['test_ineq_max']), np.mean(epoch_stats['test_ineq_mean']),
        np.mean(epoch_stats['test_eq_max']), np.mean(epoch_stats['test_eq_mean']),
        100 * np.max(test_gaps), 100 * np.mean(test_gaps)
    ), flush=True)

    print("      test gt |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f" % (
        np.mean(epoch_stats['test_gt_eval']),
        np.mean(epoch_stats['test_gt_ineq_max']), np.mean(epoch_stats['test_gt_ineq_mean']),
        np.mean(epoch_stats['test_gt_eq_max']), np.mean(epoch_stats['test_gt_eq_mean'])
    ), flush=True)

    return {
        'obj': np.mean(epoch_stats['test_eval']),
        'eq_max': np.mean(epoch_stats['test_eq_max']),
        'ineq_max': np.mean(epoch_stats['test_ineq_max']),
        'opt_gap_max': 100 * np.max(test_gaps),
        'eq_mean': np.mean(epoch_stats['test_eq_mean']),
        'ineq_mean': np.mean(epoch_stats['test_ineq_mean']),
        'opt_gap_mean': 100 * np.mean(test_gaps),
    }


def _rho_instance_violation(data, x, x_feat, pnet, dnet, rho, args):
    """Return, per instance, the max constraint violation used by the rho update.

    Equalities contribute ``|h(x, y)|``. Inequalities contribute
    ``max(g(x, y), -mu/rho)`` with ``mu`` clamped at 0. An all-zero vector is
    returned when the problem has no constraints.
    """
    y = pnet(x_feat)
    _, mu = dnet(x_feat)
    if mu is not None:
        mu = torch.clamp(mu, min=0.)
    viol = None
    if args['neq'] > 0:
        viol = data.eq_resid(x, y).abs().max(dim=1).values
    if args['nineq'] > 0:
        ineq_viol = torch.maximum(data.ineq_resid(x, y), -mu / rho).max(dim=1).values
        viol = ineq_viol if viol is None else torch.maximum(viol, ineq_viol)
    if viol is None:
        viol = y.new_zeros(y.shape[0])
    return viol


def update_rho_epoch(pnet, dnet, feature_Generator, k, train_loader, valid_loader, data, rho, args, prev_viol_train=None, prev_viol_valid=None):
    """Measure constraint violations and decide the next penalty coefficient.

    The train and validation violations are the max over all instances of
    ``_rho_instance_violation``. On the first outer iteration (k == 0) rho is
    kept unchanged. Afterwards rho is multiplied by ``args['alpha']`` when the
    train violation exceeds ``args['tau'] * prev_viol_train``. Capping at
    ``rhomax`` is left to the caller.

    Returns:
        (viol_train, viol_valid, rho_updated), with the violations as Python floats.
    """
    pnet.eval()
    dnet.eval()

    if args['useLSTM']:
        index = args['LSTMmodelindex']
        ntrain = 1000
        ntest = 100
        init_X_train = np.load(f'portfolio_data/{index}_pred_asset_prices_training_{ntrain}.npy')
        init_X_valid = np.load(f'portfolio_data/{index}_pred_asset_prices_validation_{ntest}.npy')
    else:
        init_X_train = np.load(f'portfolio_data/pred_asset_prices_training_{int(.8*args["nex"])}.npy')
        init_X_valid = np.load(f'portfolio_data/pred_asset_prices_validation_{int(.1*args["nex"])}.npy')

    def to_features(arr):
        # numpy rows -> float64 tensor on DEVICE, optionally L2-normalised along dim 1.
        feat = torch.from_numpy(arr).double().to(DEVICE)
        if args['normalize']:
            feat = torch.nn.functional.normalize(feat)
        return feat

    batch_size = args['batchsize']
    # no_grad restores the previous grad mode even if an exception is raised.
    with torch.no_grad():
        # Batch j of the (unshuffled) loader matches rows j*bs:(j+1)*bs of the array.
        viol_train = torch.cat([
            _rho_instance_violation(data, x, to_features(init_X_train[j*batch_size:(j+1)*batch_size, :]), pnet, dnet, rho, args)
            for j, x in enumerate(train_loader)
        ]).max().item()
        # The validation loader yields a single full-size batch.
        viol_valid = torch.cat([
            _rho_instance_violation(data, x, to_features(init_X_valid), pnet, dnet, rho, args)
            for x in valid_loader
        ]).max().item()

    if k == 0:
        print("viol_train:", viol_train, "viol_valid:", viol_valid)
        return viol_train, viol_valid, rho

    print("prev_viol_train:", prev_viol_train, "prev_viol_valid:", prev_viol_valid, flush=True)
    print("viol_train:", viol_train, "viol_valid:", viol_valid, flush=True)
    # Increase the penalty only when violations did not shrink enough.
    rho_updated = args['alpha'] * rho if viol_train > args['tau'] * prev_viol_train else rho
    return viol_train, viol_valid, rho_updated

def update_dnet_epoch(pnet, dnet, dnet_k, feature_Generator, train_loader, valid_loader, data, rho, args):
    """Train the dual network to reproduce one multiplier-update step.

    The current dual net is first copied into ``dnet_k``. ``dnet`` is then
    trained for ``args['epochs']`` epochs with Adam on ``total_loss_dual``,
    whose targets are ``lamda_k + rho*h`` and ``clamp(mu_k + rho*g, min=0)``
    for the current primal predictions. The learning rate is multiplied by
    0.99 after each epoch whose validation loss does not improve by more than
    1e-3. ``pnet`` is not updated.
    """
    dnet_k.load_state_dict(dnet.state_dict())
    dopt = optim.Adam(dnet.parameters(), lr=args['lr'])
    scheduler = torch.optim.lr_scheduler.StepLR(dopt, step_size=1, gamma=0.99)
    best_valid_loss = float("inf")

    if args['useLSTM']:
        index = args['LSTMmodelindex']
        ntrain = 1000
        ntest = 100
        init_X_train = np.load(f'portfolio_data/{index}_pred_asset_prices_training_{ntrain}.npy')
        init_X_valid = np.load(f'portfolio_data/{index}_pred_asset_prices_validation_{ntest}.npy')
    else:
        init_X_train = np.load(f'portfolio_data/pred_asset_prices_training_{int(.8*args["nex"])}.npy')
        init_X_valid = np.load(f'portfolio_data/pred_asset_prices_validation_{int(.1*args["nex"])}.npy')

    def to_features(arr):
        # numpy rows -> float64 tensor on DEVICE, optionally L2-normalised along dim 1.
        feat = torch.from_numpy(arr).double().to(DEVICE)
        if args['normalize']:
            feat = torch.nn.functional.normalize(feat)
        return feat

    batch_size = args['batchsize']
    for epoch in range(args['epochs']):
        t0 = time.time()
        pnet.eval()
        dnet.train()
        dnet_k.eval()
        train_dloss_ = 0.
        # `j` (not the epoch counter) indexes batches, so the epoch-based log gate below stays correct.
        for j, x in enumerate(train_loader):
            dopt.zero_grad()
            x_feat = to_features(init_X_train[j*batch_size:(j+1)*batch_size, :])
            # The dual loss detaches its targets, so pnet and dnet_k need no graph.
            with torch.no_grad():
                y = pnet(x_feat)
            lamda, mu = dnet(x_feat)
            with torch.no_grad():
                lamda_k, mu_k = dnet_k(x_feat)
            train_dloss = total_loss_dual(data, x, y, lamda, mu, lamda_k, mu_k, rho, args)
            train_dloss.mean().backward()
            dopt.step()
            train_dloss_ += train_dloss.mean().item()
        train_dloss_ /= len(train_loader)
        t1 = time.time()

        pnet.eval()
        dnet.eval()
        dnet_k.eval()
        valid_dloss_ = 0.
        with torch.no_grad():
            for x in valid_loader:
                x_feat = to_features(init_X_valid)
                y = pnet(x_feat)
                lamda, mu = dnet(x_feat)
                lamda_k, mu_k = dnet_k(x_feat)
                valid_dloss = total_loss_dual(data, x, y, lamda, mu, lamda_k, mu_k, rho, args)
                valid_dloss_ += valid_dloss.mean().item()

        if valid_dloss_ < best_valid_loss - 1e-3:
            best_valid_loss = valid_dloss_
        else:
            scheduler.step()
        if epoch % 10 == 0:
            print("D Epoch %d: train loss %.4f | valid loss: %.4f | time:%.4f" % (epoch, train_dloss_, valid_dloss_, t1 - t0), flush=True)


def eval_net(data, X, pnet, dnet, feature_Generator, parameter_regressor_Net, rho, args, prefix, stats, j):
    """Evaluate the networks on one batch and accumulate metrics into ``stats``.

    ``prefix`` selects the network input and the solution that is scored:
      * 'valid' / 'test': the full validation / test prediction array, scored with pnet;
      * 'test_gt': the ground truth ``data.testY``, with ``X`` itself fed to the dual net;
      * anything else ('train'): rows ``j*batchsize:(j+1)*batchsize`` of the
        training prediction array, scored with pnet.
    Metrics are added via ``dict_agg`` under ``f"{prefix}_<name>"`` for
    eval, primal_loss, eq_max, eq_mean, ineq_max and ineq_mean, plus
    opt_gap_mean for 'test'.

    Returns:
        The updated ``stats`` dict.
    """
    def make_prefix(name):
        return "{}_{}".format(prefix, name)

    def load_pred_prices(split):
        # split is 'training', 'validation' or 'test'; only the needed file is read.
        if args['useLSTM']:
            n_rows = 1000 if split == 'training' else 100
            return np.load(f'portfolio_data/{args["LSTMmodelindex"]}_pred_asset_prices_{split}_{n_rows}.npy')
        frac = .8 if split == 'training' else .1
        return np.load(f'portfolio_data/pred_asset_prices_{split}_{int(frac*args["nex"])}.npy')

    # no_grad restores the previous grad mode even if an exception is raised.
    with torch.no_grad():
        if prefix == 'test_gt':
            Y = data.testY
            X_feat = X
            if args['normalize']:
                X_feat = torch.nn.functional.normalize(X_feat)
        else:
            if prefix == 'test':
                raw = load_pred_prices('test')
            elif prefix == 'valid':
                raw = load_pred_prices('validation')
            else:
                batch_size = args['batchsize']
                raw = load_pred_prices('training')[j*batch_size:(j+1)*batch_size, :]
            X_feat = torch.from_numpy(raw).double().to(DEVICE)
            if args['normalize']:
                X_feat = torch.nn.functional.normalize(X_feat)
            Y = pnet(X_feat)

        lamda, mu = dnet(X_feat)
        if mu is not None:
            mu = torch.clamp(mu, min=0.)

        eqval = data.eq_resid(X, Y).float()
        ineqval = data.ineq_dist(X, Y)

        dict_agg(stats, make_prefix('eval'), data.obj_fn(X, Y).detach().cpu().numpy() * data.obj_scaler)
        dict_agg(stats, make_prefix('primal_loss'), total_loss_primal(data, X, Y, lamda, mu, rho, args).detach().cpu().numpy())
        dict_agg(stats, make_prefix('eq_max'), torch.max(torch.abs(eqval), dim=1)[0].detach().cpu().numpy())
        dict_agg(stats, make_prefix('eq_mean'), torch.mean(torch.abs(eqval), dim=1).detach().cpu().numpy())
        dict_agg(stats, make_prefix('ineq_max'), torch.max(ineqval, dim=1)[0].detach().cpu().numpy())
        dict_agg(stats, make_prefix('ineq_mean'), torch.mean(ineqval, dim=1).detach().cpu().numpy())

        if prefix == 'test':
            opt_gap = data.opt_gap(X, Y, data.testY).detach().cpu().numpy()
            dict_agg(stats, make_prefix('opt_gap_mean'), opt_gap)
            print("Before: ", np.mean(opt_gap))

    return stats


def total_loss_primal(data, x, y, lamda, mu, rho, args):
    """Per-instance augmented-Lagrangian loss for the primal network.

    loss = f(x, y)
           + sum(lamda * h) + rho/2 * sum(h**2)                 (if neq > 0)
           + sum(mu * g)    + rho/2 * sum(clamp(g, min=0)**2)   (if nineq > 0)

    ``h`` and ``g`` are ``data.eq_resid`` / ``data.ineq_resid``. The
    multipliers are detached (treated as constants) and moved to DEVICE.
    Sums run over dim 1, so one loss value is returned per instance.
    """
    half_rho = rho / 2.
    eq_penalty = 0.
    ineq_penalty = 0.

    if args['neq'] > 0:
        eq_mult = lamda.detach().to(DEVICE)
        h = data.eq_resid(x, y)
        eq_penalty = (eq_mult * h).sum(dim=1) + (half_rho * h.pow(2)).sum(dim=1)

    if args['nineq'] > 0:
        ineq_mult = mu.detach().to(DEVICE)
        g = data.ineq_resid(x, y)
        ineq_penalty = (ineq_mult * g).sum(dim=1) + (half_rho * torch.clamp(g, min=0.).pow(2)).sum(dim=1)

    objective = data.obj_fn(x, y)
    return objective + eq_penalty + ineq_penalty

def total_loss_dual(data, x, y, lamda, mu, lamda_k, mu_k, rho, args):
    """Per-instance regression loss that teaches the dual net one multiplier step.

    Targets (detached, so gradients reach only ``lamda`` / ``mu``):
      lamda_target = lamda_k + rho * h(x, y)
      mu_target    = clamp(mu_k + rho * g(x, y), min=0)
    The loss is the sum of the L2 distances (over dim 1) between the
    predictions and these targets. Each term is 0. when its constraint count
    is zero.
    """
    eq_loss = 0.
    ineq_loss = 0.

    if args['neq'] > 0:
        h = data.eq_resid(x, y).detach()
        lamda_target = (lamda_k + rho * h).detach()
        eq_loss = torch.norm(lamda_target - lamda, dim=1)

    if args['nineq'] > 0:
        g = data.ineq_resid(x, y)
        mu_target = torch.clamp(mu_k + rho * g, min=0.).detach()
        ineq_loss = torch.norm(mu_target - mu, dim=1)

    return eq_loss + ineq_loss


class PDLDataSet(Dataset):
    """Map-style dataset over an indexable array/tensor or a dict of them.

    For array-like ``X`` the length is ``X.shape[0]`` and items are
    ``X[idx]``. For a dict, the length is taken from ``X["pd"]`` and each
    item is a dict with the same keys, indexed at ``idx``.
    """

    def __init__(self, X):
        super().__init__()
        self.X = X
        try:
            self.nex = self.X.shape[0]
        except AttributeError:
            # Dict inputs have no .shape; size them by their "pd" entry.
            self.nex = self.X["pd"].shape[0]

    def __len__(self):
        return self.nex

    def __getitem__(self, idx):
        if isinstance(self.X, dict):
            return {k: v[idx] for k, v in self.X.items()}
        return self.X[idx]

### VDVF
def featGen(x, feat_Gen_Net):
    """Move ``x`` to DEVICE, run ``feat_Gen_Net`` on it and return the result on DEVICE."""
    x_on_device = x.to(DEVICE)
    generated = feat_Gen_Net(x_on_device)
    return generated.to(DEVICE)

if __name__ == "__main__":
    # Script entry point; importing this module does not start training.
    main()
