"""
This code is based on the official DC3 implementation in https://github.com/locuslab/DC3
"""

import torch
import torch.nn as nn
import torch.optim as optim
torch.set_default_dtype(torch.float64)
import torchsde
import operator
from functools import reduce
from helper_new_portfolio import NNPrimalSolver, NNDualSolver, load_portfolio_data, load_portfolio_dyn_data
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import numpy as np
import time
import argparse
from pprint import pprint
from NSDE_training import NeuralSDE
from utils import str_to_bool, dict_agg, set_seed
import default_args
from helper_qp import load_qp_data
import time
from pathlib import Path
CURRENT_PATH = Path(__file__).absolute().parent

# VDVF
from helper_new_portfolio import NNPrimalSolver, NNDualSolver, featureNet, load_portfolio_data

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE", DEVICE, flush=True)

def main():
    """Parse the command line, train the solver network and record the results.

    Steps performed:
      1. Parse CLI flags and fill every unset option from
         ``default_args.dc3_default_args`` for the chosen problem type.
      2. Seed the RNGs and load the problem data.
      3. Train via ``train_net``.
      4. Append one row of summary metrics to ``dc3_static_results.csv``
         (always, independently of ``--save``).
      5. When ``--save`` is true, store the network weights, the arguments and
         the final statistics under ``results/<probtype>/``.

    Returns:
        None
    """
    parser = argparse.ArgumentParser(description='DC3')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--probtype', type=str, default='portfolio', choices=['convexqp', 'nonconvexqp', 'portfolio','predopt_portfolio'], help='problem type')
    parser.add_argument('--nvar', type=int, help='number of decision variables')
    parser.add_argument('--nineq', type=int, help='number of inequality constraints')
    parser.add_argument('--neq', type=int, help='number of equality constraints')
    parser.add_argument('--nlayer', type=int, help='the number of layers')
    parser.add_argument('--nex', type=int, help='total number of datapoints')
    parser.add_argument('--epochs', type=int, help='number of neural network epochs')
    parser.add_argument('--batchsize', type=int, help='training batch size')
    parser.add_argument('--maxouteriter', type=int, help='maximum outer iterations')
    parser.add_argument('--lr', type=float, help='neural network learning rate')
    parser.add_argument('--hiddensize', type=int, help='hidden layer size for neural network')
    parser.add_argument('--softweight', type=float, help='total weight given to constraint violations in loss')
    parser.add_argument('--softweighteqfrac', type=float, help='fraction of weight given to equality constraints (vs. inequality constraints) in loss')
    parser.add_argument('--usecompl', type=str_to_bool, help='whether to use completion')
    parser.add_argument('--usetraincorr', type=str_to_bool, help='whether to use correction during training')
    parser.add_argument('--usetestcorr', type=str_to_bool, help='whether to use correction during testing')
    parser.add_argument('--corrmode', choices=['partial', 'full'], help='employ DC3 correction (partial) or naive correction (full)')
    parser.add_argument('--corrtrainsteps', type=int, help='number of correction steps during training')
    parser.add_argument('--corrtestmaxsteps', type=int, help='max number of correction steps during testing')
    parser.add_argument('--correps', type=float, help='correction procedure tolerance')
    parser.add_argument('--corrlr', type=float, help='learning rate for correction procedure')
    parser.add_argument('--corrmomentum', type=float, help='momentum for correction procedure')
    parser.add_argument('--lambda', type=float, default=2., help='regularization term amplitude of the opt. prob obj. function')
    parser.add_argument('--save', type=str_to_bool, default=False, help='whether to save statistics')
    # ``type=bool`` would turn every non-empty string (including "False") into
    # True, so use the same string-to-bool converter as the other flags.
    parser.add_argument('--use_sigmoid', type=str_to_bool, default=False, help='whether to apply a sigmoid to the last layer')

    args = vars(parser.parse_args())  # work with a plain dictionary from here on
    defaults = default_args.dc3_default_args(args['probtype'])
    for key, value in defaults.items():
        # ``get`` also covers defaults that have no matching CLI flag.
        if args.get(key) is None:
            args[key] = value
    pprint(args)
    print(args['use_sigmoid'])
    set_seed(args['seed'], DEVICE)

    if args['probtype'] in ['convexqp','nonconvexqp']:
        data = load_qp_data(args, CURRENT_PATH, DEVICE)
    elif 'portfolio' in args['probtype']:
        data, args = load_portfolio_dyn_data(args, CURRENT_PATH, DEVICE)
    else:
        raise NotImplementedError

    # Run method
    tstart = time.time()
    out, net, best_results, best_results_batch = train_net(data, args)

    # Both lists are laid out as
    # [max eq, mean eq, max ineq, mean ineq, opt. gap, batch opt. gap].
    # The record gets its own name so that ``data`` keeps referring to the
    # problem object, which is needed for the checkpoint file name below.
    record = {
            'Seed': [args['seed']],
            'Max eq. viol': [best_results[0]],
            'Mean eq. viol': [best_results[1]],
            'Max ineq. viol.': [best_results[2]],
            'Mean ineq. viol.': [best_results[3]],
            'Mean opt.gap': [best_results[4]],
            # Taken from ``best_results`` like the rest of this group; reading
            # ``best_results_batch`` here duplicated the last column.
            'Mean opt.gap batch': [best_results[5]],

            'Max eq. viol batch': [best_results_batch[0]],
            'Mean eq. viol batch': [best_results_batch[1]],
            'Max ineq. viol. batch': [best_results_batch[2]],
            'Mean ineq. viol. batch': [best_results_batch[3]],
            'Mean opt.gap per batch': [best_results_batch[4]],
            'Mean opt.gap batch per batch': [best_results_batch[5]],
    }
    # Rows are appended without a header; the column order is the key order above.
    pd.DataFrame(record).to_csv('dc3_static_results.csv', mode='a', header=False, index=False)
    print("Record saved.")

    print("Train is done, Elapsed Time %.2fs"%(time.time()-tstart), flush=True)

    if args['save']:
        save_dir = CURRENT_PATH/"results"/args['probtype']
        save_dir.mkdir(exist_ok=True, parents=True)
        save_name = "DC3_%s_s%d.chpt"%(str(data),args['seed'])
        save_fn = save_dir/save_name
        save_dict = {
            'net': net.to('cpu').state_dict(),
            'args': args,
            'out': out
        }
        torch.save(save_dict, save_fn)

    return None


def _test_opt_gap(data, solver_net, args):
    """Return the (mean, max) optimality gap in percent on ``data.testX``."""
    test_X = data.testX
    # NOTE(review): training feeds the network the initial asset prices, while
    # this check feeds it ``data.testX`` directly - confirm this is intended.
    Y = solver_net(test_X)
    Ycorr, _ = grad_steps_all(data, test_X, Y, args)
    optgap = data.opt_gap(test_X, Ycorr, data.testY).detach().cpu().numpy()
    return 100*np.mean(optgap), 100*np.max(optgap)


def train_net(data, args):
    """Train an ``NNSolver`` and track the best test-set results.

    Every outer iteration runs ``args['epochs']`` training epochs; each epoch
    also evaluates the validation and test splits through ``eval_net``. After
    the epochs of an outer iteration the test optimality gap is measured and
    the two "best so far" records are updated.

    Args:
        data: problem object providing ``trainX``/``validX``/``testX``/``testY``,
            ``opt_gap`` and the objective/constraint functions used by the loss.
        args: dictionary of hyper-parameters (``lr``, ``epochs``, ``batchsize``,
            ``maxouteriter``, ``nex`` and the correction settings).

    Returns:
        Tuple ``(out, solver_net, best_results, best_results_batch)``.
        ``out`` holds the statistics of the last outer iteration (``None`` when
        no iteration ran). Both ``best_*`` lists are laid out as
        ``[max eq, mean eq, max ineq, mean ineq, opt. gap, batch opt. gap]``:
        ``best_results`` belongs to the iteration with the smallest mean
        optimality gap, ``best_results_batch`` to the one with the smallest
        batch-level gap. They stay filled with NaN if no iteration qualified.

    Side effects:
        Sets the module-level globals ``min_opt_gap`` and ``min_opt_gap_batch``
        and reads ``portfolio_data/init_asset_prices_training_*.npy`` relative
        to the current working directory.
    """
    solver_step = args['lr']
    nepochs = args['epochs']
    batch_size = args['batchsize']

    train_dataset = TensorDataset(data.trainX)
    valid_dataset = TensorDataset(data.validX)
    test_dataset = TensorDataset(data.testX)

    # No shuffling: training batches are matched with rows of ``init_X_train``
    # purely by position.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle = False)
    valid_loader = DataLoader(valid_dataset, batch_size=len(valid_dataset), shuffle = False)
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle = False)

    solver_net = NNSolver(data, args)
    solver_net.to(DEVICE)
    solver_opt = optim.Adam(solver_net.parameters(), lr=solver_step)
    feature_Generator = 0  # placeholder; only forwarded to eval_net
    global min_opt_gap, min_opt_gap_batch
    min_opt_gap, min_opt_gap_batch = 100000, 100000

    # Hard-coded reference value the mean test objective is compared against.
    # NOTE(review): presumably the average ground-truth test objective - confirm.
    avg_test_gt = -1.1378

    init_X_train = np.load(f'portfolio_data/init_asset_prices_training_{int(.8*args["nex"])}.npy').T

    # Defined up front so the return statement is valid even when no outer
    # iteration runs or no iteration ever improves on the initial thresholds.
    out = None
    best_results = [float('nan')] * 6
    best_results_batch = [float('nan')] * 6

    for _ in range(args['maxouteriter']):
        t0 = time.time()
        for epoch in range(nepochs):
            epoch_stats = {}

            # Get train loss
            solver_net.train()
            for j, Xtrain in enumerate(train_loader):
                Xtrain = Xtrain[0].to(DEVICE)
                # Network input: initial asset prices of the same rows, moved to
                # the network's device (a CPU tensor would fail on CUDA).
                Xfeat = torch.from_numpy(init_X_train[j*batch_size:(j+1)*batch_size, :]).double().to(DEVICE)
                start_time = time.time()
                solver_opt.zero_grad()
                Yhat_train = solver_net(Xfeat)
                Ynew_train = grad_steps(data, Xtrain, Yhat_train, args)
                train_loss = total_loss(data, Xtrain, Ynew_train, args)
                train_loss.sum().backward()
                solver_opt.step()
                train_time = time.time() - start_time
                dict_agg(epoch_stats, 'train_loss', train_loss.detach().cpu().numpy())
                dict_agg(epoch_stats, 'train_time', train_time, op='sum')

            # Get valid loss
            solver_net.eval()
            for Xvalid in valid_loader:
                Xvalid = Xvalid[0].to(DEVICE)
                eval_net(data, Xvalid, solver_net, feature_Generator, args, 'valid', epoch_stats)

            # Get test loss
            solver_net.eval()
            for Xtest in test_loader:
                Xtest = Xtest[0].to(DEVICE)
                eval_net(data, Xtest, solver_net, feature_Generator, args, 'test', epoch_stats)
            print('Epoch {}: train loss {:.4f}, obj. {:.4f}, dist {:.4f}, ineq max {:.4f}, ineq mean {:.4f}, ineq num viol {:.4f}, eq max {:.4f}, steps {}, time {:.4f}'.format(
                epoch, np.mean(epoch_stats['train_loss']), np.mean(epoch_stats['test_eval']),
                np.mean(epoch_stats['test_dist']), np.mean(epoch_stats['test_ineq_max']),
                np.mean(epoch_stats['test_ineq_mean']), np.mean(epoch_stats['test_ineq_num_viol_0']),
                np.mean(epoch_stats['test_eq_max']), np.mean(epoch_stats['test_steps']), np.mean(epoch_stats['test_time'])), flush=True)

            mean_optgap_val, max_optgap_val = _test_opt_gap(data, solver_net, args)
            print("mean opt. gap: %.4f"%(mean_optgap_val))
            print("max opt. gap: %.4f"% (max_optgap_val))

        print("Training time: ",time.time()-t0)
        # check optgap
        solver_net.eval()
        mean_optgap_val, max_optgap_val = _test_opt_gap(data, solver_net, args)
        print("mean opt. gap: %.4f"%(mean_optgap_val))
        print("max opt. gap: %.4f"% (max_optgap_val))

        out = {
            'obj': np.mean(epoch_stats['test_eval']),
            'eq_max': np.mean(epoch_stats['test_eq_max']),
            'ineq_max': np.mean(epoch_stats['test_ineq_max']),
            'eq_mean': np.mean(epoch_stats['test_eq_mean']),
            'ineq_mean': np.mean(epoch_stats['test_ineq_mean']),
            'opt_gap_mean': mean_optgap_val,
            'opt_gap_max': max_optgap_val,
        }

        # Relative distance (in percent) of the mean test objective from the
        # reference value.
        batch_opt_gap = 100*abs(avg_test_gt - epoch_stats['test_eval'].mean())/abs(avg_test_gt)
        violations = [out['eq_max'], out['eq_mean'], out['ineq_max'], out['ineq_mean']]

        if mean_optgap_val < min_opt_gap:
            min_opt_gap = mean_optgap_val
            best_results = violations + [min_opt_gap, batch_opt_gap]

        # Track the batch-level optimum with its own threshold. It must not
        # touch ``min_opt_gap``, otherwise the tracker above would compare
        # against the previous iteration instead of the best one.
        if batch_opt_gap < min_opt_gap_batch:
            min_opt_gap_batch = batch_opt_gap
            best_results_batch = violations + [mean_optgap_val, min_opt_gap_batch]

    return out, solver_net, best_results, best_results_batch

# Modifies stats in place
def eval_net(data, X, solver_net, feature_Generator, args, prefix, stats):
    """Evaluate ``solver_net`` on one split and aggregate metrics into ``stats``.

    The network is fed the initial asset prices loaded from ``portfolio_data/``
    (test file when ``'test'`` occurs in ``prefix``, validation file otherwise),
    while the objective and constraint metrics are computed against ``X``.

    Args:
        data: problem object providing ``obj_fn``, ``ineq_dist``, ``eq_resid``
            and the gradient functions used by the correction steps.
        X: tensor with the problem parameters of the evaluated split.
        solver_net: network mapping the initial prices to a solution.
        feature_Generator: unused; kept for interface compatibility.
        args: dictionary of hyper-parameters (uses ``correps``, ``nex``,
            ``probtype`` and the correction settings).
        prefix: split name; every key written to ``stats`` is ``<prefix>_<name>``.
        stats: dictionary of aggregated statistics, modified in place.

    Returns:
        The same ``stats`` dictionary.
    """
    eps_converge = args['correps']
    mse = torch.nn.MSELoss()
    make_prefix = lambda x: "{}_{}".format(prefix, x)

    # Load only the file this split needs.
    is_test = 'test' in prefix
    if is_test:
        init_X = np.load(f'portfolio_data/init_asset_prices_test_{int(.1*args["nex"])}.npy')
    else:
        init_X = np.load(f'portfolio_data/init_asset_prices_validation_{int(.1*args["nex"])}.npy')

    start_time = time.time()

    # Keep the network input on the same device as X (and hence the network);
    # a CPU tensor would raise a device-mismatch error when running on CUDA.
    X_hat = torch.from_numpy(init_X).double().T.to(X.device)
    if is_test:
        print("MSE test: ", mse(X_hat,X))

    Y = solver_net(X_hat)

    Ycorr, steps = grad_steps_all(data, X_hat, Y, args)
    end_time = time.time()

    Ynew = grad_steps(data, X_hat, Y, args)

    if 'portfolio' == args['probtype'] and prefix == 'test':
        # Post-process the corrected test solution: clamp negative entries when
        # any inequality residual is non-zero, then rescale each row by its sum
        # when any equality residual is non-zero.
        ineqval = data.ineq_dist(X, Ycorr)
        if torch.count_nonzero(torch.abs(ineqval)).item() > 0 :
            Ycorr = torch.clamp(Ycorr, min = 0)
            ineqval = data.ineq_dist(X, Ycorr)
        print("Inequality violations: ", torch.count_nonzero(torch.abs(ineqval)).item())
        eqval = data.eq_resid(X, Ycorr).float()
        if torch.count_nonzero(torch.abs(eqval)).item() > 0 :
            Ycorr = Ycorr/Ycorr.sum(dim=1, keepdim=True)
            eqval = data.eq_resid(X, Ycorr).float()
        print("Equality violations: ", torch.count_nonzero(torch.abs(eqval)).item())

    # Residuals of the final solution, computed once and shared by the
    # statistics below.
    ineq_final = data.ineq_dist(X, Ycorr)
    eq_abs = torch.abs(data.eq_resid(X, Ycorr))

    dict_agg(stats, make_prefix('time'), end_time - start_time, op='sum')
    dict_agg(stats, make_prefix('steps'), np.array([steps]))
    dict_agg(stats, make_prefix('loss'), total_loss(data, X, Ynew, args).detach().cpu().numpy())
    dict_agg(stats, make_prefix('eval'), data.obj_fn(X,Ycorr).detach().cpu().numpy())
    dict_agg(stats, make_prefix('dist'), torch.norm(Ycorr - Y, dim=1).detach().cpu().numpy())
    dict_agg(stats, make_prefix('ineq_max'), torch.max(ineq_final, dim=1)[0].detach().cpu().numpy())
    dict_agg(stats, make_prefix('ineq_mean'), torch.mean(ineq_final, dim=1).detach().cpu().numpy())
    dict_agg(stats, make_prefix('ineq_num_viol_0'),torch.sum(ineq_final > eps_converge, dim=1).detach().cpu().numpy())
    dict_agg(stats, make_prefix('eq_max'),torch.max(eq_abs, dim=1)[0].detach().cpu().numpy())
    dict_agg(stats, make_prefix('eq_mean'), torch.mean(eq_abs, dim=1).detach().cpu().numpy())

    return stats

def total_loss(data, X, Y, args):
    """Return the per-sample soft loss: objective plus weighted violations.

    The penalty budget ``args['softweight']`` is split between the equality
    and inequality terms according to ``args['softweighteqfrac']``; each term
    is the norm (over dim 1) of the respective residual.
    """
    penalty = args['softweight']
    eq_frac = args['softweighteqfrac']

    objective = data.obj_fn(X,Y)
    ineq_norm = torch.norm(data.ineq_dist(X, Y), dim=1)
    eq_norm = torch.norm(data.eq_resid(X, Y), dim=1)

    return objective + penalty * (1 - eq_frac) * ineq_norm + penalty * eq_frac * eq_norm

def grad_steps(data, X, Y, args):
    """Apply the training-time correction steps to ``Y``.

    Returns ``Y`` untouched when ``args['usetraincorr']`` is false. Otherwise
    runs ``args['corrtrainsteps']`` momentum gradient steps, using
    ``data.ineq_partial_grad`` in 'partial' mode or a weighted mix of
    ``data.ineq_grad`` and ``data.eq_grad`` in 'full' mode. The steps are not
    wrapped in ``no_grad``, so gradients can flow through the result.
    """
    if not args['usetraincorr']:
        return Y

    step_size = args['corrlr']
    n_steps = args['corrtrainsteps']
    beta = args['corrmomentum']
    has_completion = args['usecompl']
    use_partial = args['corrmode'] == 'partial'
    assert has_completion or not use_partial, "Partial correction not available without completion."

    current = Y
    prev_update = 0
    for _ in range(n_steps):
        if use_partial:
            direction = data.ineq_partial_grad(X, current)
        else:
            direction = ((1 - args['softweighteqfrac']) * data.ineq_grad(X, current)
                         + args['softweighteqfrac'] * data.eq_grad(X, current))

        # Momentum update: blend the fresh step with the previous one.
        prev_update = step_size * direction + beta * prev_update
        current = current - prev_update

    return current

# Used only at test time, so let PyTorch avoid building the computational graph
def grad_steps_all(data, X, Y, args):
    """Run the test-time correction until convergence or the step limit.

    Returns ``(Y, 0)`` when ``args['usetestcorr']`` is false. Otherwise takes
    momentum gradient steps under ``torch.no_grad()`` until both the largest
    absolute equality residual and the largest inequality distance are at most
    ``args['correps']``, or ``args['corrtestmaxsteps']`` steps were taken. At
    least one step is attempted whenever the limit allows it.

    Returns:
        Tuple ``(corrected_Y, number_of_steps_taken)``.
    """
    if not args['usetestcorr']:
        return Y, 0

    step_size = args['corrlr']
    tol = args['correps']
    step_limit = args['corrtestmaxsteps']
    beta = args['corrmomentum']
    has_completion = args['usecompl']
    use_partial = args['corrmode'] == 'partial'
    assert has_completion or not use_partial, "Partial correction not available without completion."

    def keep_going(n_done, candidate):
        # The first step is unconditional; afterwards continue only while a
        # constraint is still violated beyond the tolerance.
        if n_done == 0:
            return True
        if torch.max(torch.abs(data.eq_resid(X, candidate))) > tol:
            return True
        return torch.max(data.ineq_dist(X, candidate)) > tol

    current = Y
    n_done = 0
    prev_update = 0
    with torch.no_grad():
        while keep_going(n_done, current) and n_done < step_limit:
            if use_partial:
                direction = data.ineq_partial_grad(X, current)
            else:
                direction = ((1 - args['softweighteqfrac']) * data.ineq_grad(X, current)
                             + args['softweighteqfrac'] * data.eq_grad(X, current))

            prev_update = step_size * direction + beta * prev_update
            current = current - prev_update
            n_done += 1

    return current, n_done


######### Models

class NNSolver(nn.Module):
    """Fully connected ReLU network that predicts a solution from the input.

    The input layer has ``data.pdim`` units, followed by ``2 + args['nlayer']``
    hidden layers of ``args['hiddensize']`` units each. With
    ``args['usecompl']`` the network outputs only
    ``data.ydim - data.nknowns - data.neq`` values and the remaining ones are
    filled in by ``data.complete_partial``; otherwise it outputs
    ``data.ydim - data.nknowns`` values that go through ``data.process_output``.

    Raises:
        NotImplementedError: if ``args['probtype']`` is not ``'portfolio'``;
            no input layout is defined for other problem types.
    """

    def __init__(self, data, args):
        super().__init__()
        self._data = data
        self._args = args

        if 'portfolio' == args['probtype']:
            layer_sizes = [data.pdim, self._args['hiddensize'], self._args['hiddensize']]
        else:
            # Fail with a clear message instead of an UnboundLocalError on the
            # undefined layer list.
            raise NotImplementedError(
                "NNSolver only supports probtype 'portfolio', got %r" % (args['probtype'],))

        layer_sizes += (args['nlayer'])*[self._args['hiddensize']]

        # All layers are created first and re-initialised afterwards, in order,
        # so that the random-number stream is consumed in a fixed sequence and
        # seeded runs stay reproducible.
        layers = []
        for n_in, n_out in zip(layer_sizes[0:-1], layer_sizes[1:]):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]

        output_dim = data.ydim - data.nknowns
        if self._args['usecompl']:
            # The equality constraints determine ``neq`` outputs, so the
            # network predicts only the remaining ones.
            layers += [nn.Linear(layer_sizes[-1], output_dim - data.neq)]
        else:
            layers += [nn.Linear(layer_sizes[-1], output_dim)]

        for layer in layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight)

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map the input batch ``x`` to a full solution vector."""
        out = self.net(x)

        if self._args['usecompl']:
            if 'acopf' in self._args['probtype']:
                out = nn.Sigmoid()(out)   # used to interpolate between max and min values
            return self._data.complete_partial(x, out)
        else:
            return self._data.process_output(x, out)

# Run the command-line entry point only when the file is executed as a script.
if __name__=='__main__':
    main()
