import torch
torch.set_default_dtype(torch.float64)
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time, argparse
from pprint import pprint
from utils_file import dict_agg, set_seed
import default_args
from helper_acopf import NNPrimalACOPFSolver, load_acopf_data 
import json
from dataset import Dataset as D
import pandas as pd
from pathlib import Path
from save_all import save_data, save_lists, save_error_data
import matplotlib.pyplot as plt

# Directory containing this script; results and data paths are resolved from it.
CURRENT_PATH = Path(__file__).absolute().parent

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE", DEVICE, flush=True)

def main():
    """Parse CLI options, train the ACOPF network and record the results.

    Side effects:
        * appends one row of summary metrics to ``<probtype>_results.csv`` in
          the current working directory;
        * when ``--save`` is true, writes a checkpoint (network state dict,
          args and final test metrics) under ``results/<probtype>/``.
    """

    def str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string (including
        # "False" and "0") to True, so boolean options are parsed explicitly.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    parser = argparse.ArgumentParser(description='Baseline Unsupervised')
    parser.add_argument('--seed', type=int, default=1001, help='random seed')
    parser.add_argument('--probtype', type=str, default='acopf57', help='problem type')
    parser.add_argument('--use_sigmoid', type=str2bool, help='whether to apply a sigmoid to the last layer')
    parser.add_argument('--nHiddenUnit_SS', type=int, default=50, help='number of hidden units')
    parser.add_argument('--nHiddenUnit_NODE', type=int, default=200, help='number of hidden units')  # 50 for the old models
    parser.add_argument('--activation_NODE', type=str, default="RELU", help='activation_function')
    parser.add_argument('--activation_SS', type=str, default="RELU", help='activation_function')
    parser.add_argument('--optimizer', type=int, default=2, help='GD algorithm')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--batchsize', type=int, default=100, help='training batch size')  # 10
    parser.add_argument('--epochs', type=int, default=1, help='inner training passes per outer iteration')
    parser.add_argument('--T', type=int, default=3, help='time interval')
    parser.add_argument('--initSplit', type=float, default=.1)
    parser.add_argument('--normalize', type=str2bool, default=False)
    parser.add_argument('--plot', type=str2bool, default=False)
    parser.add_argument('--ldstepsize', type=float, default=.01)
    parser.add_argument('--activate_instability_computation_epoch', type=int, default=-1)
    parser.add_argument('--max_patience', type=int, default=10)
    parser.add_argument('--nLayer_NODE', type=int, default=5, help='number of layers')
    parser.add_argument('--nLayer_SS', type=int, default=5, help='number of layers')
    parser.add_argument('--id', type=int, default=999, help='run identifier')
    parser.add_argument('--lambda_zero_balance', type=float, default=.1, help='lambda(0)[0]')
    parser.add_argument('--lambda_zero_boundaries', type=float, default=.1, help='lambda(0)[1]')
    parser.add_argument('--lambda_zero_instability', type=float, default=1, help='lambda(0)[2]')

    # SL baseline parameters
    parser.add_argument('--maxouteriter', type=int, default=3000, help='maximum outer iterations')
    parser.add_argument('--losstype', type=str, default='ld', choices=['mae', 'mse', 'maep', 'msep', 'ld'], help='MAE or MSE')
    parser.add_argument('--ldupdatefreq', type=int, help='LD penalty coefficient update epoch frequency')
    parser.add_argument('--hiddensize', type=int, default=500, help='hidden layer size for neural network')
    parser.add_argument('--hiddenfrac', type=float, help='hidden layer node fraction (only used for ACOPF)')

    # JK removed default=1.2 from --hiddenfrac
    parser.add_argument('--nlayer', type=int, help='the number of layers')
    parser.add_argument('--lamg', type=float, help='penalty coefficient for inequality constraints')
    parser.add_argument('--lamh', type=float, help='penalty coefficient for equality constraints')
    parser.add_argument('--lamu', type=float, help='penalty coefficient for stability constraints')
    parser.add_argument('--save', type=str2bool, default=True, help='whether to save statistics')

    # JK
    parser.add_argument('--objscaler', type=str2bool, default=None, help='objective scaling factor')
    parser.add_argument('--index', type=int, help='index to keep track of different runs')

    ### VDVF
    parser.add_argument('--acopf_parameters', type=str, default="demands")
    parser.add_argument('--acopf_feature_mapping_type', type=str, default="synthetic_demands")

    args = vars(parser.parse_args())  # to dictionary
    # Options left unset on the command line (None) take the problem defaults.
    args_default = default_args.baseline_supervised_default_args(args['probtype'], args['losstype'])
    for k, v in args_default.items():
        args[k] = v if args[k] is None else args[k]
    pprint(args)

    set_seed(args['seed'], DEVICE)

    data, args = load_acopf_data(args, CURRENT_PATH, DEVICE)

    print("Loading Data Done Successfully:", str(data))

    tstart = time.time()

    out, net, best_results = train_net(data, args)

    # Kept under its own name: ``data`` (the dataset) is still needed below
    # to build the checkpoint file name.
    results = {
        'Probtype': [args['probtype']],
        'n': [args['nlayer']],
        'Mean opt.gap': [best_results[0]],
        'Max Opt.gap': [best_results[1]],
        'Mean eq. viol': [best_results[3]],
        'Max eq. viol': [best_results[2]],
        'T': [args['T']],
        'ID': [args['id']],
    }
    pd.DataFrame(results).to_csv('%s_results.csv' % args['probtype'], mode='a', header=False, index=False)

    print("Train is done, Elapsed Time %.2fs" % (time.time() - tstart), flush=True)

    if args['save']:
        save_dir = CURRENT_PATH / "results" / args['probtype']
        save_dir.mkdir(exist_ok=True, parents=True)
        save_name = "Supervised_loss%s_%s_s%d.chpt" % (args['losstype'], str(data), args['seed'])
        save_dict = {
            'net': net.to('cpu').state_dict(),
            'args': args,
            'out': out,
        }
        torch.save(save_dict, save_dir / save_name)

    return None

def feature_Generation(feature_Net, train_data, valid_data, test_data):
    """Map the train/valid/test inputs through ``feature_Net``.

    Each split is moved to DEVICE first. The splits are processed in the
    order train, valid, test, and the results are returned as a 3-tuple in
    that same order.
    """
    splits = (train_data, valid_data, test_data)
    return tuple(feature_Net(split.to(DEVICE)) for split in splits)

def train_net(data, args):
    """Train an ``NNPrimalACOPFSolver`` on ``data`` and record statistics.

    Each outer iteration does the following:

    * runs ``args['epochs']`` inner training passes;
    * evaluates the validation, train and test splits with ``eval_net``;
    * for ``losstype == 'ld'``, takes a dual step on the multipliers every
      ``args['ldupdatefreq']`` outer iterations;
    * early-stops on the mean test optimality gap.

    The module globals ``n_epochs`` (read by the instability residual) and
    ``min_opt_gap`` are (re)initialised here. ``n_epochs`` is incremented
    once per outer iteration.

    Args:
        data: dataset object from ``load_acopf_data``. It must expose
            ``trainX/trainY/validX/validY/testX/testY`` and the residual and
            objective helpers used by ``eval_net`` / ``total_loss``.
        args: run options dict (see ``main``).

    Returns:
        ``(out, net, best_results)``:

        * ``out`` is the test-metric dict of the last evaluated iteration.
        * ``best_results`` is ``[mean_opt_gap_pct, max_opt_gap_pct, eq_max,
          eq_mean]`` at the iteration with the lowest mean test optimality
          gap. It holds NaNs if no iteration improved.
    """
    global min_opt_gap
    global n_epochs

    train_dataset = BaselineDataSet(data.trainX, data.trainY)
    valid_dataset = BaselineDataSet(data.validX, data.validY)
    test_dataset = BaselineDataSet(data.testX, data.testY)

    train_loader = DataLoader(train_dataset, batch_size=args['batchsize'], shuffle=True)
    # Validation and test are evaluated as a single full-size batch.
    valid_loader = DataLoader(valid_dataset, batch_size=len(valid_dataset), shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

    net = NNPrimalACOPFSolver(data, args).to(DEVICE)
    net.train()
    optimizer = optim.Adam(net.parameters(), lr=args['lr'])

    # Multipliers: one per equality / inequality constraint, one for instability.
    dtype = torch.get_default_dtype()
    lamh = args['lamh'] * torch.ones(args['neq'], dtype=dtype).to(DEVICE)
    lamg = args['lamg'] * torch.ones(args['nineq'], dtype=dtype).to(DEVICE)
    lamu = args['lamu'] * torch.ones(1, dtype=dtype).to(DEVICE)

    # Placeholder; only consumed on the non-ACOPF 'predopt' feature path.
    feature_Generator = 0

    train_loss_list, val_loss_list, test_loss_list = [], [], []  ### TOTAL LOSS
    mse_loss_list_train, mse_loss_list_valid, mse_loss_list_test = [], [], []  ### MSE LOSS
    balance_loss_list_train, balance_loss_list_valid, balance_loss_list_test = [], [], []  ### BALANCE CONSTRAINT LOSS
    bc_loss_list_train, bc_loss_list_valid, bc_loss_list_test = [], [], []  ### BOUNDARIES CONSTRAINT LOSS
    unstable_loss_list_train, unstable_loss_list_valid, unstable_loss_list_test = [], [], []  ### INSTABILITY PENALTY LOSS
    c_0_list, c_1_list, c_2_list = [], [], []  ### LAGRANGIAN MULTIPLIER
    true_unstable_list_train, false_unstable_list_train, detected_unstable_list_train = [], [], []  ## UNSTABLE STATS
    true_unstable_list_valid, false_unstable_list_valid, detected_unstable_list_valid = [], [], []
    true_unstable_list_test, false_unstable_list_test, detected_unstable_list_test = [], [], []

    min_opt_gap = 1000
    n_epochs = 0
    patience = 0
    # NOTE(review): hard-coded; the --max_patience option parsed in main() is not used here.
    max_patience = 300
    out = {}
    best_results = [float('nan')] * 4

    for outer_iter in range(args['maxouteriter']):
        print("Epoch :", outer_iter + 1)
        t0 = t1 = time.time()
        batch_loss_train = []
        batch_mse_loss_list, batch_balance_loss_list, batch_bc_loss_list, batch_unstable_loss_list = [], [], [], []

        # The inner counter is deliberately not ``outer_iter``: the progress
        # print and the LD-update schedule below must see the outer index.
        for _ in range(args['epochs']):
            t0 = time.time()
            # Only the last inner pass is logged, so its buffers start empty.
            batch_loss_train = []
            batch_mse_loss_list, batch_balance_loss_list, batch_bc_loss_list, batch_unstable_loss_list = [], [], [], []
            net.train()
            for j, (Xtrain, Ytrain) in enumerate(train_loader):
                optimizer.zero_grad()
                Yhat_train = net(Xtrain)
                train_loss, loss_item = total_loss(data, Xtrain, Ytrain, Yhat_train, lamg, lamh, lamu, args, j)
                train_loss.mean().backward()
                _append_loss_components(loss_item, batch_mse_loss_list, batch_balance_loss_list,
                                        batch_bc_loss_list, batch_unstable_loss_list)
                optimizer.step()
                batch_loss_train.append(train_loss.mean().item())
                # NOTE(review): each pass stops after a single mini-batch step;
                # remove this break to train on the full epoch if unintended.
                break
            t1 = time.time()
            net.eval()

        # Mean over the mini-batches actually processed in the last pass.
        train_loss_ = np.mean(batch_loss_train)

        mse_loss_list_train.append(np.mean(batch_mse_loss_list))
        balance_loss_list_train.append(np.mean(batch_balance_loss_list))
        bc_loss_list_train.append(np.mean(batch_bc_loss_list))
        unstable_loss_list_train.append(np.mean(batch_unstable_loss_list))
        train_loss_list.append(np.mean(batch_loss_train))

        epoch_stats = {}
        batch_loss_val = []
        for Xvalid, Yvalid in valid_loader:
            epoch_stats, loss_item = eval_net(data, Xvalid, Yvalid, net, feature_Generator, -1, lamg, lamh, lamu, args, 'valid', epoch_stats)
            _append_loss_components(loss_item, mse_loss_list_valid, balance_loss_list_valid,
                                    bc_loss_list_valid, unstable_loss_list_valid)
            batch_loss_val.append(loss_item[7].mean().detach().cpu().item())
            break
        val_loss_list.append(np.mean(batch_loss_val))

        print("P Epoch:%05d | loss:%.4f | time:%.4fs" % (
            outer_iter, train_loss_, t1 - t0
        ), flush=True)
        print("        valid | loss:%.4f | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f | instab pen:%.4f" % (
            np.mean(epoch_stats['valid_primal_loss']), np.mean(epoch_stats['valid_eval']),
            np.mean(epoch_stats['valid_ineq_max']), np.mean(epoch_stats['valid_ineq_mean']),
            np.mean(epoch_stats['valid_eq_max']), np.mean(epoch_stats['valid_eq_mean']),
            np.mean(epoch_stats['valid_instab_pen'])
        ), flush=True)

        if args['losstype'] == 'ld' and outer_iter % args['ldupdatefreq'] == 0 and outer_iter > 0:
            lamg, lamh, lamu = _lagrangian_multiplier_step(train_loader, net, feature_Generator, data, lamg, lamh, lamu, args)

        for idx, (Xtrain, Ytrain) in enumerate(train_loader):
            epoch_stats, loss_item = eval_net(data, Xtrain, Ytrain, net, feature_Generator, idx, lamg, lamh, lamu, args, 'train', epoch_stats)

        batch_loss_test = []
        for X, Y in test_loader:
            epoch_stats, loss_item = eval_net(data, X, Y, net, feature_Generator, -10, lamg, lamh, lamu, args, 'test', epoch_stats)
            _append_loss_components(loss_item, mse_loss_list_test, balance_loss_list_test,
                                    bc_loss_list_test, unstable_loss_list_test)
            batch_loss_test.append(loss_item[7].mean().detach().cpu().item())
            epoch_stats, _ = eval_net(data, X, Y, net, feature_Generator, -100, lamg, lamh, lamu, args, 'test_gt', epoch_stats)
            break
        test_loss_list.append(np.mean(batch_loss_test))

        print("        train |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f | instab pen:%.4f " % (
            np.mean(epoch_stats['train_eval']),
            np.mean(epoch_stats['train_ineq_max']), np.mean(epoch_stats['train_ineq_mean']),
            np.mean(epoch_stats['train_eq_max']), np.mean(epoch_stats['train_eq_mean']),
            np.mean(epoch_stats['train_instab_pen'])
        ), flush=True)
        print("         test |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f | optgap max:%.4f mean:%.4f | instab pen:%.4f " % (
            np.mean(epoch_stats['test_eval']),
            np.mean(epoch_stats['test_ineq_max']), np.mean(epoch_stats['test_ineq_mean']),
            np.mean(epoch_stats['test_eq_max']), np.mean(epoch_stats['test_eq_mean']),
            100 * np.max(epoch_stats['test_opt_gap']), 100 * np.mean(epoch_stats['test_opt_gap']),
            np.mean(epoch_stats['test_instab_pen'])
        ), flush=True)
        print("      test gt |              | obj:%.4f | ineq max:%.4f mean:%.4f | eq max:%.4f mean:%.4f" % (
            np.mean(epoch_stats['test_gt_eval']),
            np.mean(epoch_stats['test_gt_ineq_max']), np.mean(epoch_stats['test_gt_ineq_mean']),
            np.mean(epoch_stats['test_gt_eq_max']), np.mean(epoch_stats['test_gt_eq_mean'])
        ), flush=True)
        out = {
            'obj': np.mean(epoch_stats['test_eval']),
            'eq_max': np.mean(epoch_stats['test_eq_max']),
            'ineq_max': np.mean(epoch_stats['test_ineq_max']),
            'eq_mean': np.mean(epoch_stats['test_eq_mean']),
            'ineq_mean': np.mean(epoch_stats['test_ineq_mean']),
            'opt_gap_max': 100 * np.max(epoch_stats['test_opt_gap']),
            'opt_gap_mean': 100 * np.mean(epoch_stats['test_opt_gap']),
            'insatb_pen': np.mean(epoch_stats['test_instab_pen']),  # legacy (misspelt) key, kept for readers
            'instab_pen': np.mean(epoch_stats['test_instab_pen']),
        }

        n_epochs += 1
        mean_opt_gap = np.mean(epoch_stats['test_opt_gap'])
        if mean_opt_gap < min_opt_gap:
            min_opt_gap = mean_opt_gap
            # Both optimality-gap entries are in percent, like ``out``.
            best_results = [out['opt_gap_mean'], out['opt_gap_max'], out['eq_max'], out['eq_mean']]
            patience = 0
        else:
            patience += 1
            if patience == max_patience:
                break

    (avg_pg_mse, std_dev_pg_mse, avg_v_mse, std_dev_v_mse,
     avg_theta_mse, std_dev_theta_mse) = _load_binned_errors(net, test_dataset, data, lamg, lamh, lamu, args)

    run_id = args['id']
    direc = 'LD_DYN'

    save_lists(direc, run_id, c_0_list, c_1_list, c_2_list, train_loss_list, mse_loss_list_train, balance_loss_list_train, bc_loss_list_train, unstable_loss_list_train,
               val_loss_list, mse_loss_list_valid, balance_loss_list_valid, bc_loss_list_valid, unstable_loss_list_valid,
               test_loss_list, mse_loss_list_test, balance_loss_list_test, bc_loss_list_test, unstable_loss_list_test,
               true_unstable_list_train, false_unstable_list_train, detected_unstable_list_train,
               true_unstable_list_valid, false_unstable_list_valid, detected_unstable_list_valid,
               true_unstable_list_test, false_unstable_list_test, detected_unstable_list_test, args['T'])

    save_error_data(direc, run_id, avg_pg_mse, std_dev_pg_mse, avg_v_mse, std_dev_v_mse, avg_theta_mse, std_dev_theta_mse, args['T'])

    return out, net, best_results


def _append_loss_components(loss_item, mse_list, balance_list, bc_list, unstable_list):
    """Append the logged components of a ``total_loss`` item list to four history lists."""
    mse_list.append(loss_item[0] + loss_item[1] + loss_item[2] + loss_item[3])
    balance_list.append(loss_item[4])
    bc_list.append(loss_item[5])
    unstable_list.append(loss_item[6])


def _lagrangian_multiplier_step(train_loader, net, feature_Generator, data, lamg, lamh, lamu, args):
    """Return ``(lamg, lamh, lamu)`` after one dual-ascent step on the training-set violations.

    Increments are ``args['ldstepsize']`` times the mean violation:

    * ``lamh`` is updated per equality constraint;
    * ``lamg`` is updated by the mean over all inequality constraints;
    * ``lamu`` is updated by the mean instability residual.

    ``net`` is left in train mode.
    """
    net.eval()
    eq_viols, ineq_viols, instab_viols = [], [], []
    with torch.no_grad():
        for idx, (Xtrain, _) in enumerate(train_loader):
            Xfeat = Xtrain
            if 'predopt' in args['probtype'] and 'acopf' not in args['probtype']:
                Xfeat = featGen(Xtrain, feature_Generator).to(DEVICE)
            Yhat_train = net(Xfeat)
            ineq_val = data.ineq_resid(Xtrain, Yhat_train)
            eq_val = data.eq_resid(Xtrain, Yhat_train)
            ineq_viols.append(torch.clamp(ineq_val, min=0.))
            eq_viols.append(eq_val.abs())
            # Flattened so scalar and per-sample residuals can both be concatenated.
            instab_viols.append(data.instab_resid(Xtrain, Yhat_train, n_epochs, args, idx).reshape(-1).to(lamu.dtype))
    net.train()

    eq_mean = torch.cat(eq_viols, dim=0).mean(dim=0)
    ineq_mean = torch.cat(ineq_viols, dim=0).mean(dim=0)
    instab_mean = torch.cat(instab_viols).mean()

    step = args['ldstepsize']
    lamg = lamg + step * ineq_mean.mean()
    lamh = lamh + step * eq_mean
    lamu = lamu + step * instab_mean
    print("Update lambdas: lamg=", lamg.max().item(), ", lamh=", lamh.max().item(), ", lamu=", lamu.max().item(), flush=True)
    return lamg, lamh, lamu


def _load_binned_errors(net, test_dataset, data, lamg, lamh, lamu, args,
                        n_bins=100, sd_interval_length=0.1, sd_max=.37, n_loads=42):
    """Bin per-sample test errors by average apparent load.

    For each test sample, the mean of sqrt(pd^2 + qd^2) over the first
    ``n_loads`` loads selects the bin
    ``int(n_bins / sd_interval_length * (sd_max - avg_load))``. The sample's
    pg / vm / va error components from ``total_loss`` are written into that
    bin; all other bins stay 0 for that sample.

    Returns the per-bin mean and standard deviation over samples:
    ``(avg_pg, std_pg, avg_v, std_v, avg_theta, std_theta)``.
    """
    loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    n_samples = len(loader)
    pg_err = np.zeros((n_samples, n_bins))
    v_err = np.zeros((n_samples, n_bins))
    theta_err = np.zeros((n_samples, n_bins))

    net.eval()
    with torch.no_grad():  # evaluation only; avoids building autograd graphs per sample
        for n_sample, (X, Y) in enumerate(loader):
            Yhat = net(X)
            _, loss_item = total_loss(data, X, Y, Yhat, lamg, lamh, lamu, args, -100)
            avg_load = sum(torch.sqrt(X["pd"][0][k] ** 2 + X["qd"][0][k] ** 2).item() for k in range(n_loads)) / n_loads
            # NOTE(review): loads outside roughly (sd_max - sd_interval_length, sd_max]
            # produce an out-of-range bin. Below that range this raises IndexError;
            # above it, a negative index silently wraps to the last bins.
            bin_idx = int(n_bins / sd_interval_length * (sd_max - avg_load))
            pg_err[n_sample, bin_idx] = loss_item[0]
            v_err[n_sample, bin_idx] = loss_item[2]
            theta_err[n_sample, bin_idx] = loss_item[3]

    return (pg_err.mean(axis=0), pg_err.std(axis=0),
            v_err.mean(axis=0), v_err.std(axis=0),
            theta_err.mean(axis=0), theta_err.std(axis=0))


def eval_net(data, X, Ygt, net, feature_Generator, batch_idx, lamg, lamh, lamu, args, prefix, stats):
    """Evaluate one batch and aggregate its metrics into ``stats``.

    Metrics are aggregated with ``dict_agg`` under ``"<prefix>_<name>"`` for:

    * every prefix: ``time``, ``eval``, ``primal_loss``, ``ineq_max``,
      ``ineq_mean``, ``eq_max`` and ``eq_mean``;
    * test prefixes other than ``'test_gt'``: ``opt_gap``;
    * every prefix except ``'test_gt'``: ``instab_pen``.

    Args:
        data: dataset object providing ``eq_resid``, ``ineq_dist``,
            ``obj_fn``, ``opt_gap``, ``instab_resid``, ``obj_scaler`` and
            ``testY``.
        X, Ygt: input batch and its ground-truth targets.
        net: model producing predictions from (possibly featurised) ``X``.
        feature_Generator: only used for non-ACOPF 'predopt' problems.
        batch_idx: forwarded to ``total_loss`` / ``data.instab_resid``.
            It is replaced by -1 / -10 / -100 for 'valid' / 'test' / 'test_gt'.
        lamg, lamh, lamu: current multipliers, forwarded to ``total_loss``.
        args: run options dict.
        prefix: 'train', 'valid', 'test' or 'test_gt'.

            * ``'test_gt'`` scores ``data.testY`` itself instead of the
              network output.
            * ``'test'`` replaces ``Ygt`` with ``data.testY``, which assumes
              ``X`` is the whole test set as one batch.
        stats: metric accumulator passed to ``dict_agg``.

    Returns:
        ``(stats, loss_item)``, where ``loss_item`` is the list returned by
        ``total_loss``.

    Gradient tracking is turned off for the call and turned back on
    afterwards, including when an exception propagates.
    """

    def key(name):
        return "{}_{}".format(prefix, name)

    torch.set_grad_enabled(False)
    try:
        start_time = time.time()

        if prefix == 'test_gt':
            Y = data.testY
        else:
            if prefix == 'test':
                Ygt = data.testY
            X_feat = X
            if 'predopt' in args['probtype'] and 'acopf' not in args['probtype']:
                X_feat = featGen(X, feature_Generator).to(DEVICE)
            Y = net(X_feat)

        # Computed once and reused by all constraint statistics below.
        eq_resid = data.eq_resid(X, Y)
        ineq_dist = data.ineq_dist(X, Y)
        end_time = time.time()

        dict_agg(stats, key('time'), end_time - start_time, op='sum')
        dict_agg(stats, key('eval'), data.obj_fn(X, Y).detach().cpu().numpy() * data.obj_scaler)

        # Fixed ids tell the instability residual which split is being evaluated.
        batch_idx = {'valid': -1, 'test': -10, 'test_gt': -100}.get(prefix, batch_idx)

        total_loss_value, loss_item = total_loss(data, X, Ygt, Y, lamg, lamh, lamu, args, batch_idx)
        dict_agg(stats, key('primal_loss'), total_loss_value.detach().cpu().numpy())

        dict_agg(stats, key('ineq_max'), torch.max(ineq_dist, dim=1)[0].detach().cpu().numpy())
        dict_agg(stats, key('ineq_mean'), torch.mean(ineq_dist, dim=1).detach().cpu().numpy())
        eq_abs = torch.abs(eq_resid)
        dict_agg(stats, key('eq_max'), torch.max(eq_abs, dim=1)[0].detach().cpu().numpy())
        dict_agg(stats, key('eq_mean'), torch.mean(eq_abs, dim=1).detach().cpu().numpy())

        if 'gt' not in prefix and 'test' in prefix:
            dict_agg(stats, key('opt_gap'), data.opt_gap(X, Y, Ygt).detach().cpu().numpy())
        if 'gt' not in prefix:
            dict_agg(stats, key('instab_pen'), data.instab_resid(X, Y, n_epochs, args, batch_idx).detach().cpu().numpy())

        return stats, loss_item
    finally:
        torch.set_grad_enabled(True)

def total_loss(data, X, Ygt, Y, lamg, lamh, lamu, args, batch_idx):
    """Supervised loss, plus Lagrangian/penalty terms for penalty loss types.

    The supervised part is 0.2 times the sum of the per-sample mean errors of
    ``pg``, ``qg``, ``vm``, ``dva`` and ``va``:

    * absolute errors for 'mae' / 'maep' / 'ld';
    * squared errors for 'mse' / 'msep'.

    For penalty loss types ('maep', 'msep', 'ld'), the loss also includes the
    multiplier-weighted mean equality and inequality violations and
    ``100 * lamu * instability``.

    Args:
        batch_idx: forwarded to ``data.instab_resid``. The value -100
            (ground-truth evaluation) skips the instability residual and uses
            zero instead.

    Returns:
        ``(loss, items)``. ``items`` is ``[pg, qg, vm, va, eq_term,
        ineq_term, instability_term[0], loss]``: the first seven are detached
        numpy values for logging, and the last is the non-detached total.
        The penalty components are reported for every loss type, even when
        they are not added to the loss.

    Raises:
        ValueError: unknown ``args['losstype']``.
        NotImplementedError: non-ACOPF ``args['probtype']``.
    """
    ineq_val = data.ineq_resid(X, Y)
    eq_val = data.eq_resid(X, Y)
    if batch_idx != -100:
        instability_val = data.instab_resid(X, Y, n_epochs, args, batch_idx)
    else:
        instability_val = torch.zeros(1, device=DEVICE)

    losstype = args['losstype']
    if 'mae' in losstype or losstype == 'ld':
        use_abs = True
    elif 'mse' in losstype:
        use_abs = False
    else:
        raise ValueError("unsupported losstype: %r" % (losstype,))
    if 'acopf' not in args['probtype']:
        raise NotImplementedError

    def component_error(name):
        # Per-sample mean error for one output quantity; every quantity
        # (va included) uses the same metric.
        diff = Ygt[name].to(DEVICE) - Y[name].to(DEVICE)
        return (diff.abs() if use_abs else diff.pow(2)).mean(dim=1)

    pg_loss = component_error('pg')
    qg_loss = component_error('qg')
    vm_loss = component_error('vm')
    dva_loss = component_error('dva')
    va_loss = component_error('va')
    loss = 0.2 * (pg_loss + qg_loss + vm_loss + dva_loss + va_loss)

    if use_abs:
        eq_viols = eq_val.abs()
        ineq_viols = torch.clamp(ineq_val, min=0.)
    else:
        eq_viols = eq_val.pow(2)
        ineq_viols = (torch.clamp(ineq_val, min=0.)).pow(2)

    # Always computed so they can be logged even when not part of the loss.
    eq_term = (lamh * eq_viols).mean(dim=1).mean()
    ineq_term = (lamg * ineq_viols).mean(dim=1).mean()
    instability_term = lamu * instability_val

    if 'p' in losstype or losstype == 'ld':
        loss_value = loss + eq_term + ineq_term + 100 * instability_term
    else:
        loss_value = loss

    return loss_value, [pg_loss.mean().detach().cpu().numpy(), qg_loss.mean().detach().cpu().numpy(),
                        vm_loss.mean().detach().cpu().numpy(), va_loss.mean().detach().cpu().numpy(),
                        eq_term.detach().cpu().numpy(), ineq_term.detach().cpu().numpy(),
                        instability_term[0].detach().cpu().numpy(), loss_value]

### VDVF
def featGen(x, feat_Gen_Net):
    """Apply ``feat_Gen_Net`` to ``x`` and return the output moved to DEVICE."""
    return feat_Gen_Net(x).to(DEVICE)

class BaselineDataSet(Dataset):
    """Map-style dataset over paired inputs and targets.

    ``X`` and ``Y`` are either indexable arrays/tensors or dicts of them.
    For dicts, indexing returns a pair of dicts with every entry sliced at
    ``idx``.
    """

    def __init__(self, X, Y):
        super().__init__()
        self.X = X
        self.Y = Y
        if isinstance(X, dict):
            # Length comes from the "pd" entry when present, otherwise from
            # any entry (all entries are expected to share the first dim).
            if "pd" in X:
                self.nex = X["pd"].shape[0]
            elif X:
                self.nex = next(iter(X.values())).shape[0]
            else:
                raise ValueError("BaselineDataSet: X is an empty dict")
        else:
            self.nex = X.shape[0]

    def __len__(self):
        return self.nex

    def __getitem__(self, idx):
        if isinstance(self.X, dict):
            x = {k: v[idx] for k, v in self.X.items()}
            y = {k: v[idx] for k, v in self.Y.items()}
            return x, y
        return self.X[idx], self.Y[idx]


def update_lamda(train_loader, net, feature_Generator, data, lamg, lamh, lamu, args):
    """Take one dual-ascent step on the Lagrangian multipliers.

    ``net`` is evaluated on every training batch without gradients. Each
    multiplier then grows by ``args['ldstepsize']`` times its mean violation:

    * ``lamh`` per equality constraint;
    * ``lamg`` by the mean over all inequality constraints;
    * ``lamu`` by the mean instability residual.

    Args:
        train_loader: iterable of ``(X, Y)`` batches; only ``X`` is used.
        net: model to evaluate; left in train mode on return.
        feature_Generator: only used for non-ACOPF 'predopt' problems.
        data: dataset object providing ``ineq_resid``, ``eq_resid`` and
            ``instab_resid``.
        lamg, lamh, lamu: current inequality / equality / instability multipliers.
        args: run options dict (``probtype``, ``ldstepsize``, ...).

    Returns:
        The updated ``(lamg, lamh, lamu)``.
    """
    net.eval()
    eq_viols, ineq_viols, instab_viols = [], [], []
    with torch.no_grad():
        for idx, (Xtrain, _) in enumerate(train_loader):
            Xfeat = Xtrain
            if 'predopt' in args['probtype'] and 'acopf' not in args['probtype']:
                Xfeat = featGen(Xtrain, feature_Generator).to(DEVICE)
            Yhat_train = net(Xfeat)
            ineq_val = data.ineq_resid(Xtrain, Yhat_train)
            eq_val = data.eq_resid(Xtrain, Yhat_train)
            ineq_viols.append(torch.clamp(ineq_val, min=0.))
            eq_viols.append(eq_val.abs())
            # Flattened so scalar and per-sample residuals can both be concatenated.
            instab_viols.append(data.instab_resid(Xtrain, Yhat_train, n_epochs, args, idx).reshape(-1).to(lamu.dtype))
    net.train()

    eq_mean = torch.cat(eq_viols, dim=0).mean(dim=0)
    ineq_mean = torch.cat(ineq_viols, dim=0).mean(dim=0)
    instab_mean = torch.cat(instab_viols).mean()

    step = args['ldstepsize']
    lamg = lamg + step * ineq_mean.mean()
    lamh = lamh + step * eq_mean
    lamu = lamu + step * instab_mean
    print("Update lambdas: lamg=", lamg.max().item(), ", lamh=", lamh.max().item(), ", lamu=", lamu.max().item(), flush=True)

    return lamg, lamh, lamu

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()