#!/usr/bin/env python
# coding: utf-8

import gurobipy as gp
from gurobipy import GRB
import numpy as np
import pyepo
from pyepo.model.grb import optGrbModel
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from pyepo.func.abcmodule import optModule
from torch.autograd import Function
from pyepo import EPO
from pyepo.func.abcmodule import optModule
from pyepo.func.utlis import _solve_or_cache
from pyepo.func.abcmodule import optModule
from pyepo.data.dataset import optDataset
from pyepo.func.utlis import _solveWithObj4Par, _solve_in_pass

from pyepo.model.opt import optModel
from copy import copy
import copy as cpy2
from torch.optim.lr_scheduler import StepLR

from sklearn.model_selection import train_test_split
import pandas as pd
import sys

def regret_func(predmodel, x, c, grid_size=5):
    """
    A function to evaluate model performance with normalized true regret

    Costs are predicted for ``x`` and the shortest path problem is solved for
    both the predicted and the true costs. Both decisions are then priced at
    the TRUE costs ``c``. The result is the total excess cost of the predicted
    decisions divided by the total true-optimal cost of the batch.

    Args:
        predmodel (nn.Module): a regression neural network for cost prediction
        x (torch.tensor): a batch of features fed to ``predmodel``
        c (torch.tensor): the matching batch of true cost vectors; must live
            on CPU because ``shortest_path_solver`` builds its buffers on CPU
        grid_size (int): grid size passed to ``shortest_path_solver``

    Returns:
        float: normalized true regret of the batch
    """
    # remember the caller's train/eval mode so evaluation has no lasting
    # side effect on dropout / batch-norm behaviour
    was_training = predmodel.training
    predmodel.eval()
    try:
        with torch.no_grad():  # no grad
            cp = predmodel(x).to("cpu").detach()
    finally:
        predmodel.train(was_training)

    sol_cp, _ = shortest_path_solver(cp, grid_size)
    sol_c, _ = shortest_path_solver(c, grid_size)
    # price both decisions with the true costs
    obj_cp = torch.sum(sol_cp * c)
    obj_c = torch.sum(sol_c * c)

    return (obj_cp - obj_c).item() / (abs(obj_c.item()) + 1e-7)

def genData(num_data, num_features, grid, ep_type, deg=1, noise_width=0, seed=135):
    """
    A function to generate synthetic data and features for shortest path

    Args:
        num_data (int): number of data points
        num_features (int): dimension of features
        grid (int, int): size of grid network
        ep_type (str): noise model; 'unif' multiplies each cost by a factor
            drawn from U[1 - noise_width, 1 + noise_width], 'normal' adds
            noise drawn from N(0, noise_width) (noise_width is the std)
        deg (int): data polynomial degree
        noise_width (float): half width of the uniform noise, or standard
            deviation of the normal noise
        seed (int): random seed for the features and the noise

    Returns:
        tuple: features (torch.FloatTensor, num_data x num_features),
            noiseless costs and noisy costs (torch.FloatTensor, num_data x d)
            where d is the number of arcs of the grid

    Raises:
        ValueError: if deg is not a positive int or ep_type is unknown
    """
    # positive integer parameter
    if type(deg) is not int:
        raise ValueError("deg = {} should be int.".format(deg))
    if deg <= 0:
        raise ValueError("deg = {} should be positive.".format(deg))
    # validate the noise model before consuming any randomness; an unknown
    # value used to leave ci_hat unbound inside the loop
    if ep_type not in ("unif", "normal"):
        raise ValueError(
            "ep_type = {} should be 'unif' or 'normal'.".format(ep_type))
    # number of data points
    n = num_data
    # dimension of features
    p = num_features
    # dimension of the cost vector
    d = (grid[0] - 1) * grid[1] + (grid[1] - 1) * grid[0]
    # random matrix parameter B: drawn from a fixed seed (1), so it does not
    # depend on `seed`
    rnd = np.random.RandomState(1)
    B = rnd.binomial(1, 0.5, (d, p))
    # feature vectors (and, below, the noise) come from `seed`
    rnd = np.random.RandomState(seed)
    x = rnd.normal(0, 1, (n, p))
    # cost vectors
    c = np.zeros((n, d))
    c_hat = np.zeros((n, d))
    for i in range(n):
        # cost without noise
        ci = (np.dot(B, x[i].reshape(p, 1)).T / np.sqrt(p) + 3) ** deg + 1
        # rescale
        ci /= 3.5 ** deg
        # noise (drawn per row, keeping the original RNG consumption order)
        if ep_type == "unif":
            ci_hat = ci * rnd.uniform(1 - noise_width, 1 + noise_width, d)
        else:
            ci_hat = ci + rnd.normal(0, noise_width, d)
        c[i, :] = ci
        c_hat[i, :] = ci_hat

    return torch.FloatTensor(x), torch.FloatTensor(c), torch.FloatTensor(c_hat)


def shortest_path_solver(costs, size, sens = 1e-4):
    """
    Solve a batch of layered shortest path problems by dynamic programming.

    Nodes are grouped into ``2 * size - 1`` layers. Layer ``i`` holds
    ``min(i + 1, 2 * size - 1 - i)`` nodes, i.e. the anti-diagonals of a
    ``size`` x ``size`` grid, and arcs only connect consecutive layers.

    The algorithm runs in two passes:
    - a forward pass computes the shortest distance from node 0 to every node;
    - a backward pass starts at the last node and marks every arc whose
      reduced cost is below ``sens``.

    Args:
        costs (torch.tensor): arc costs of shape
            (samples, 2 * size * (size - 1)). They are stored layer by layer;
            within a layer, even positions are "left" arcs and odd positions
            "right" arcs (see the slicing below). Expected on CPU, since the
            internal buffers are created on CPU.
        size (int): number of nodes along each side of the grid
        sens (float): tolerance for treating an arc as tight

    Returns:
        tuple: solutions (torch.tensor, same shape as ``costs``) and objective
            values (torch.tensor, shape (samples, 1))

    Raises:
        ValueError: if ``costs`` does not have 2 * size * (size - 1) columns

    Note:
        With ties (within ``sens``) every tight arc on an optimal path is
        marked and the activations are summed, so entries can exceed 1.
    """
    # number of layers; this used to be hard-coded as 9 (= 2 * 5 - 1), which
    # was only correct for size == 5
    n_layers = 2 * size - 1
    n_arcs_total = 2 * size * (size - 1)
    if costs.shape[1] != n_arcs_total:
        raise ValueError(
            "costs should have {} columns for size = {}, got {}.".format(
                n_arcs_total, size, costs.shape[1]))

    samples = costs.shape[0]

    # Forward Pass: V_arr[:, k] is the shortest distance from node 0 to node k,
    # with nodes stored layer by layer
    starting_ind = 0
    starting_ind_c = 0
    V_arr = torch.zeros(samples, size ** 2)
    for i in range(0, 2 * (size - 1)):
        num_nodes = min(i + 1, n_layers - i)
        num_nodes_next = min(i + 2, n_layers - i - 1)
        num_arcs = 2 * (max(num_nodes, num_nodes_next) - 1)
        V_1 = V_arr[:, starting_ind:starting_ind + num_nodes]
        layer_costs = costs[:, starting_ind_c:starting_ind_c + num_arcs]
        l_costs = layer_costs[:, 0::2]
        r_costs = layer_costs[:, 1::2]
        next_V_val_l = torch.ones(samples, num_nodes_next) * float('inf')
        next_V_val_r = torch.ones(samples, num_nodes_next) * float('inf')
        if num_nodes_next > num_nodes:
            # expanding layer: the outermost next nodes have a single parent
            next_V_val_l[:, :num_nodes_next - 1] = V_1 + l_costs
            next_V_val_r[:, 1:num_nodes_next] = V_1 + r_costs
        else:
            # contracting layer: every next node has two parents
            next_V_val_l = V_1[:, :num_nodes_next] + l_costs
            next_V_val_r = V_1[:, 1:num_nodes_next + 1] + r_costs
        next_V_val = torch.minimum(next_V_val_l, next_V_val_r)
        V_arr[:, starting_ind + num_nodes:starting_ind + num_nodes + num_nodes_next] = next_V_val

        starting_ind += num_nodes
        starting_ind_c += num_arcs

    # Backward Pass: walk from the last node back to node 0, activating tight
    # arcs that lead into already-active nodes
    starting_ind = size ** 2
    starting_ind_c = costs.shape[1]
    prev_act = torch.ones(samples, 1)
    sol = torch.zeros(costs.shape)
    for i in range(2 * (size - 1), 0, -1):
        num_nodes = min(i + 1, n_layers - i)
        num_nodes_next = min(i, n_layers - i + 1)
        V_1 = V_arr[:, starting_ind - num_nodes:starting_ind]
        V_2 = V_arr[:, starting_ind - num_nodes - num_nodes_next:starting_ind - num_nodes]

        num_arcs = 2 * (max(num_nodes, num_nodes_next) - 1)
        layer_costs = costs[:, starting_ind_c - num_arcs: starting_ind_c]

        if num_nodes < num_nodes_next:
            l_cs_res = ((V_2[:, :num_nodes_next - 1] - V_1 + layer_costs[:, ::2]) < sens) * prev_act
            r_cs_res = ((V_2[:, 1:num_nodes_next] - V_1 + layer_costs[:, 1::2]) < sens) * prev_act
            prev_act = torch.zeros(V_2.shape)
            prev_act[:, :num_nodes_next - 1] += l_cs_res
            prev_act[:, 1:num_nodes_next] += r_cs_res
        else:
            l_cs_res = ((V_2 - V_1[:, :num_nodes - 1] + layer_costs[:, ::2]) < sens) * prev_act[:, :num_nodes - 1]
            r_cs_res = ((V_2 - V_1[:, 1:num_nodes] + layer_costs[:, 1::2]) < sens) * prev_act[:, 1:num_nodes]
            prev_act = torch.zeros(V_2.shape)
            prev_act += l_cs_res
            prev_act += r_cs_res
        cs = torch.zeros(layer_costs.shape)
        cs[:, ::2] = l_cs_res
        cs[:, 1::2] = r_cs_res
        sol[:, starting_ind_c - num_arcs: starting_ind_c] = cs

        starting_ind = starting_ind - num_nodes
        starting_ind_c = starting_ind_c - num_arcs
    # Dimension (samples, num edges)
    obj = torch.sum(sol * costs, dim=1)
    # Dimension (samples, 1)
    return sol, obj.reshape(-1, 1)

# optimization model
class optGenModel(optModel):
    """
    This is an abstract class for a custom (solver-free) optimization model

    Subclasses provide ``_getModel`` and ``solve``. ``_getModel`` returns a
    model object exposing ``modelSense`` and a ``costvec`` slot for the
    objective.

    Attributes:
        _model: model object built by ``_getModel``; stores ``costvec`` and
            ``modelSense``
    """

    def __init__(self):
        """
        Initialise via ``optModel`` and mirror the inner model's sense.

        ``self._model`` is read right after ``super().__init__()``, so the
        parent constructor presumably builds it by calling ``_getModel``.
        """
        super().__init__()
        # init obj
        if self._model.modelSense == EPO.MINIMIZE:
            self.modelSense = EPO.MINIMIZE
        if self._model.modelSense == EPO.MAXIMIZE:
            self.modelSense = EPO.MAXIMIZE

    def __repr__(self):
        return "optGenModel " + self.__class__.__name__

    def setObj(self, c):
        """
        A method to set objective function

        Args:
            c (np.ndarray / torch.tensor): cost of objective function; stored
                as-is on the inner model
        """
        self._model.costvec = c

    def copy(self):
        """
        A method to copy model

        Returns:
            optModel: new copied model with its own inner model object, so
                ``setObj`` on the copy does not affect this model
        """
        new_model = copy(self)
        # a bare shallow copy would share `_model`, and setObj on the copy
        # would then overwrite this model's costvec
        new_model._model = cpy2.copy(self._model)
        return new_model

    def addConstr(self):
        """
        Return a copy of the model.

        No constraint is actually added: this model type does not support
        extra constraints.
        """
        new_model = self.copy()
        return new_model

class modelclass:
    """
    Minimal model container built by ``shortestPathModel._getModel``.

    Attributes:
        size (int): grid size
        costvec: objective cost vector; ``None`` until ``setObj`` assigns it
        modelSense: optimization sense, initialised to ``EPO.MINIMIZE``
        x (np.ndarray): all-ones placeholder with ``2 * size * (size - 1)``
            entries, one per arc
    """

    def __init__(self, size):
        n_arcs = 2 * size * (size - 1)
        self.size, self.costvec = size, None
        self.modelSense = EPO.MINIMIZE
        self.x = np.ones(n_arcs)
class shortestPathModel(optGenModel):
    """
    Shortest path model on a fixed 5 x 5 grid, solved by
    ``shortest_path_solver`` instead of an external solver.
    """

    def __init__(self):
        # grid must exist before the parent constructor runs, because
        # _getModel reads it
        self.grid = (5, 5)
        super().__init__()

    def _getModel(self):
        """
        Build the inner model container.

        Returns:
            tuple: model object and its placeholder variable vector
        """
        model = modelclass(self.grid[0])
        # sense
        model.modelSense = EPO.MINIMIZE
        return model, model.x

    def solve(self):
        """
        Solve for the stored cost vector(s).

        Returns:
            tuple: batch of solutions and objective values
        """
        n_vars = len(self.x)
        batch_costs = self._model.costvec.reshape(-1, n_vars)
        return shortest_path_solver(batch_costs, self._model.size)

##### Surrogates #######
class SPOPlus2(optModule):
    """
    An autograd module for SPO+ Loss, a convex surrogate of the SPO (decision)
    loss.

    The objective is linear and the constraints are known and fixed; only the
    cost vector is predicted from contextual data. The loss is computed by
    ``SPOPlusFunc``, which returns a subgradient, so it can be trained with
    stochastic gradient descent.

    Reference: <https://doi.org/10.1287/mnsc.2020.3922>
    """

    def __init__(self, optmodel, processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (None/optDataset): the training data
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        # autograd function that computes the per-sample loss
        self.spop = SPOPlusFunc()

    def forward(self, pred_cost, true_cost, true_sol, true_obj):
        """
        Forward pass: per-sample SPO+ loss reduced by ``self.reduction``.
        """
        per_sample = self.spop.apply(pred_cost, true_cost, true_sol, true_obj, self)
        mode = self.reduction
        if mode == "mean":
            return per_sample.mean()
        if mode == "sum":
            return per_sample.sum()
        if mode == "none":
            return per_sample
        raise ValueError("No reduction '{}'.".format(mode))
class SPOPlusFunc(Function):
    """
    An autograd function for SPO+ Loss
    """

    @staticmethod
    def forward(ctx, pred_cost, true_cost, true_sol, true_obj, module):
        """
        Forward pass for SPO+

        Args:
            pred_cost (torch.tensor): a batch of predicted values of the cost
            true_cost (torch.tensor): a batch of true values of the cost
            true_sol (torch.tensor): a batch of true optimal solutions
            true_obj (torch.tensor): a batch of true optimal objective values;
                assumed shape (batch, 1) to broadcast with the solver's
                objective — TODO confirm
            module (optModule): SPOPlus module holding ``optmodel``

        Returns:
            torch.tensor: SPO+ loss on the same device as ``pred_cost``
        """
        # get device
        device = pred_cost.device
        # the solver works on CPU tensors
        cp = pred_cost.detach().to("cpu")
        c = true_cost.detach().to("cpu")
        w = true_sol.detach().to("cpu")
        z = true_obj.detach().to("cpu")
        # solve with the SPO+ cost 2 * cp - c for the whole batch
        module.optmodel.setObj(2 * cp - c)
        sol, obj = module.optmodel.solve()
        # calculate loss
        loss = - obj + 2 * torch.sum(cp * w, dim=1).reshape(-1, 1) - z
        # sense
        if module.optmodel.modelSense == EPO.MAXIMIZE:
            loss = - loss
        # keep the solver's solution on the input device: backward combines
        # it with true_sol, and autograd expects the gradient on that device
        ctx.save_for_backward(true_sol, sol.to(device))
        ctx.optmodel = module.optmodel
        return loss.to(device)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for SPO+: the subgradient 2 * (w - wq), negated for
        maximisation.
        """
        w, wq = ctx.saved_tensors
        optmodel = ctx.optmodel
        if optmodel.modelSense == EPO.MINIMIZE:
            grad = 2 * (w - wq)
        if optmodel.modelSense == EPO.MAXIMIZE:
            grad = 2 * (wq - w)
        return grad_output * grad, None, None, None, None


class PG_Loss(optModule):
    """
    An autograd module for a finite-difference ("PG") surrogate loss.

    For each sample the optimization model is solved at two perturbations of
    ``pred_cost`` along the true cost direction (``+/- h * true_cost``; which
    ones depends on ``finite_diff_type``). The loss is the difference of the
    two optimal objective values divided by the perturbation width. The
    gradient is the matching difference of the two solutions (see
    ``PGLossFunc``).
    """

    def __init__(self, optmodel, h = 1, finite_diff_type='B', processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            h (float): finite-difference step size (must be non-zero)
            finite_diff_type (str): 'B' (backward), 'C' (central) or
                'F' (forward) difference
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (None/optDataset): the training data

        Raises:
            ValueError: if finite_diff_type is not 'B', 'C' or 'F'
        """
        # fail fast instead of erroring inside the autograd function
        if finite_diff_type not in ("B", "C", "F"):
            raise ValueError("No finite_diff_type '{}'.".format(finite_diff_type))
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        # build criterion
        self.spop = PGLossFunc()
        self.h = h
        self.finite_diff_type = finite_diff_type

    def forward(self, pred_cost, true_cost):
        """
        Forward pass: per-sample PG loss reduced by ``self.reduction``.
        """
        loss = self.spop.apply(pred_cost, true_cost, self.h, self.finite_diff_type, self)
        # reduction
        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        elif self.reduction == "none":
            pass
        else:
            raise ValueError("No reduction '{}'.".format(self.reduction))
        return loss
class PGLossFunc(Function):
    """
    An autograd function for the finite-difference (PG) loss
    """

    @staticmethod
    def forward(ctx, pred_cost, true_cost, h, finite_diff_type, module):
        """
        Forward pass for the PG loss

        Args:
            pred_cost (torch.tensor): a batch of predicted values of the cost
            true_cost (torch.tensor): a batch of true values of the cost
            h (float): finite-difference step size (must be non-zero)
            finite_diff_type (str): 'C' (central), 'B' (backward) or
                'F' (forward) difference
            module (optModule): PG_Loss module holding ``optmodel``

        Returns:
            torch.tensor: per-sample loss, shaped like the objective values
                returned by ``optmodel.solve()``, on ``pred_cost``'s device

        Raises:
            ValueError: if finite_diff_type is not 'B', 'C' or 'F'
        """
        # get device
        device = pred_cost.device
        # the solver works on CPU tensors
        cp = pred_cost.detach().to("cpu")
        c = true_cost.detach().to("cpu")

        # perturb the prediction along the true cost direction
        if finite_diff_type == 'C':
            cp_plus = cp + h * c
            cp_minus = cp - h * c
            step_size = 1 / (2 * h)
        elif finite_diff_type == 'B':
            cp_plus = cp
            cp_minus = cp - h * c
            step_size = 1 / h
        elif finite_diff_type == 'F':
            cp_plus = cp + h * c
            cp_minus = cp
            step_size = 1 / h
        else:
            raise ValueError("No finite_diff_type '{}'.".format(finite_diff_type))

        # solve at both perturbations
        module.optmodel.setObj(cp_plus)
        sol_plus, obj_plus = module.optmodel.solve()
        module.optmodel.setObj(cp_minus)
        sol_minus, obj_minus = module.optmodel.solve()
        # finite-difference quotient of the optimal value
        loss = (obj_plus - obj_minus) * step_size
        # sense
        if module.optmodel.modelSense == EPO.MAXIMIZE:
            loss = - loss
        # keep saved solutions on the input device so the gradient is too
        ctx.save_for_backward(sol_plus.to(device), sol_minus.to(device))
        ctx.optmodel = module.optmodel
        ctx.step_size = step_size
        return loss.to(device)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: the derivative of each optimal value w.r.t. the cost is
        its solution, so the gradient is step_size * (sol_plus - sol_minus).
        """
        sol_plus, sol_minus = ctx.saved_tensors
        grad = ctx.step_size * (sol_plus - sol_minus)
        # forward negates the loss for maximisation; mirror that here so the
        # gradient matches the returned loss
        if ctx.optmodel.modelSense == EPO.MAXIMIZE:
            grad = - grad
        return grad_output * grad, None, None, None, None


class DCA_PG_Loss(optModule):
    """
    An autograd module for a difference-of-convex variant of the
    finite-difference ("PG") surrogate loss.

    Like ``PG_Loss``, but the "+" perturbation is built from a reference
    prediction ``pred_cost_0`` instead of ``pred_cost``. Its solution is held
    fixed and evaluated linearly in ``pred_cost``. No gradient flows to
    ``pred_cost_0`` (see ``DCAPGLossFunc``).
    """

    def __init__(self, optmodel, h = 1, finite_diff_type='B', processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            h (float): finite-difference step size (must be non-zero)
            finite_diff_type (str): 'B' (backward), 'C' (central) or
                'F' (forward) difference
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (None/optDataset): the training data

        Raises:
            ValueError: if finite_diff_type is not 'B', 'C' or 'F'
        """
        # fail fast instead of erroring inside the autograd function
        if finite_diff_type not in ("B", "C", "F"):
            raise ValueError("No finite_diff_type '{}'.".format(finite_diff_type))
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        # build criterion
        self.spop = DCAPGLossFunc()
        self.h = h
        self.finite_diff_type = finite_diff_type

    def forward(self, pred_cost, pred_cost_0, true_cost):
        """
        Forward pass: per-sample DCA-PG loss reduced by ``self.reduction``.
        """
        loss = self.spop.apply(pred_cost, pred_cost_0, true_cost, self.h, self.finite_diff_type, self)
        # reduction
        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        elif self.reduction == "none":
            pass
        else:
            raise ValueError("No reduction '{}'.".format(self.reduction))
        return loss
class DCAPGLossFunc(Function):
    """
    An autograd function for the difference-of-convex PG loss
    """

    @staticmethod
    def forward(ctx, pred_cost, pred_cost_0, true_cost, h, finite_diff_type, module):
        """
        Forward pass for the DCA-PG loss

        The "+" problem is solved at a perturbation of the reference
        prediction ``pred_cost_0``. Its solution is then evaluated at the same
        perturbation of the current prediction ``pred_cost``, which
        linearises that term. The "-" problem is solved at ``pred_cost``
        directly.

        Args:
            pred_cost (torch.tensor): a batch of predicted values of the cost
            pred_cost_0 (torch.tensor): reference predictions (no gradient)
            true_cost (torch.tensor): a batch of true values of the cost
            h (float): finite-difference step size (must be non-zero)
            finite_diff_type (str): 'C' (central), 'B' (backward) or
                'F' (forward) difference
            module (optModule): DCA_PG_Loss module holding ``optmodel``

        Returns:
            torch.tensor: per-sample loss, shape (batch, 1), on
                ``pred_cost``'s device

        Raises:
            ValueError: if finite_diff_type is not 'B', 'C' or 'F'
        """
        # get device
        device = pred_cost.device
        # the solver works on CPU tensors
        cp = pred_cost.detach().to("cpu")
        cp_0 = pred_cost_0.detach().to("cpu")
        c = true_cost.detach().to("cpu")

        # `shift` is the perturbation applied on the "+" side; the linearised
        # term must be evaluated at cp + shift, not at cp
        if finite_diff_type == 'C':
            cp_plus = cp_0 + h * c
            cp_minus = cp - h * c
            shift = h * c
            step_size = 1 / (2 * h)
        elif finite_diff_type == 'B':
            cp_plus = cp_0
            cp_minus = cp - h * c
            shift = 0
            step_size = 1 / h
        elif finite_diff_type == 'F':
            cp_plus = cp_0 + h * c
            cp_minus = cp
            shift = h * c
            step_size = 1 / h
        else:
            raise ValueError("No finite_diff_type '{}'.".format(finite_diff_type))

        # solve at both perturbations
        module.optmodel.setObj(cp_plus)
        sol_plus, obj_plus = module.optmodel.solve()
        module.optmodel.setObj(cp_minus)
        sol_minus, obj_minus = module.optmodel.solve()
        # linearised "+" term: fixed solution priced at the current prediction
        # (previously the `shift` term was dropped for 'C'/'F', so the value
        # did not match the PG loss at cp == cp_0; the gradient is unchanged)
        obj_plus_0 = torch.sum(sol_plus * (cp + shift), dim=1).reshape(-1, 1)
        # calculate loss
        loss = (obj_plus_0 - obj_minus) * step_size
        # sense
        if module.optmodel.modelSense == EPO.MAXIMIZE:
            loss = - loss
        # keep saved solutions on the input device so the gradient is too
        ctx.save_for_backward(sol_plus.to(device), sol_minus.to(device))
        ctx.optmodel = module.optmodel
        ctx.step_size = step_size
        return loss.to(device)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: gradient w.r.t. ``pred_cost`` only; ``pred_cost_0``
        and the remaining inputs receive ``None``.
        """
        sol_plus, sol_minus = ctx.saved_tensors
        grad = ctx.step_size * (sol_plus - sol_minus)
        # forward negates the loss for maximisation; mirror that here
        if ctx.optmodel.modelSense == EPO.MAXIMIZE:
            grad = - grad
        return grad_output * grad, None, None, None, None, None

class listwiseLTR(optModule):
    """
    An autograd module for listwise learning to rank, where the goal is to learn
    an objective function that ranks a pool of feasible solutions correctly.

    The pool starts as the distinct training solutions. With probability
    ``solve_ratio`` per call, the solutions for the current predictions are
    passed to ``_update_solution_pool`` (inherited; presumably it adds them
    to the pool). The loss is a cross entropy between softmax distributions
    of the pool's objective values under the true and the predicted costs.

    Reference: <https://proceedings.mlr.press/v162/mandi22a.html>
    """

    def __init__(self, optmodel, processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (optDataset): the training data, usually this is simply the training set
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        if not isinstance(dataset, optDataset):
            raise TypeError("dataset is not an optDataset")
        # flatten every training solution and drop duplicates to form the pool
        sols = dataset.sols.copy()
        self.solpool = np.unique(sols.reshape(sols.shape[0], -1), axis=0)

    def forward(self, pred_cost, true_cost):
        """
        Forward pass: listwise cross-entropy loss reduced by ``self.reduction``.
        """
        device = pred_cost.device
        detached = pred_cost.detach().to("cpu")
        # occasionally solve for the predictions and extend the pool
        if np.random.uniform() <= self.solve_ratio:
            self.optmodel.setObj(detached)
            new_sols, _ = self.optmodel.solve()
            self._update_solution_pool(new_sols)
        pool = torch.from_numpy(self.solpool.astype(np.float32)).to(device)
        # objective value of every pool solution, one row per sample
        scores_true = true_cost @ pool.T
        scores_pred = pred_cost @ pool.T
        sense = self.optmodel.modelSense
        if sense == EPO.MINIMIZE:
            loss = - (F.log_softmax(scores_pred, dim=1) *
                      F.softmax(scores_true, dim=1))
        elif sense == EPO.MAXIMIZE:
            loss = - (F.log_softmax(- scores_pred, dim=1) *
                      F.softmax(- scores_true, dim=1))
        mode = self.reduction
        if mode == "mean":
            return loss.mean()
        if mode == "sum":
            return loss.sum()
        if mode == "none":
            return loss
        raise ValueError("No reduction '{}'.".format(mode))


class pairwiseLTR(optModule):
    """
    An autograd module for pairwise learning to rank, where the goal is to learn
    an objective function that ranks a pool of feasible solutions correctly.

    For every sample, the pool solution that is best under the true cost is
    compared with every other pool solution under the predicted cost. A hinge
    (ReLU) penalty is applied whenever the best one is not predicted to be
    strictly better.

    Reference: <https://proceedings.mlr.press/v162/mandi22a.html>
    """

    def __init__(self, optmodel, processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (optDataset): the training data
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        if not isinstance(dataset, optDataset):
            raise TypeError("dataset is not an optDataset")
        # flatten every training solution and drop duplicates to form the pool
        sols = dataset.sols.copy()
        self.solpool = np.unique(sols.reshape(sols.shape[0], -1), axis=0)

    def forward(self, pred_cost, true_cost):
        """
        Forward pass: pairwise hinge loss reduced by ``self.reduction``.
        """
        device = pred_cost.device
        detached = pred_cost.detach().to("cpu").numpy()
        # occasionally solve for the predictions and extend the pool
        if np.random.uniform() <= self.solve_ratio:
            self.optmodel.setObj(detached)
            new_sols, _ = self.optmodel.solve()
            self._update_solution_pool(new_sols)
        pool = torch.from_numpy(self.solpool.astype(np.float32)).to(device)
        # objective value of every pool solution, one row per sample
        scores_true = torch.einsum("bd,nd->bn", true_cost, pool)
        scores_pred = torch.einsum("bd,nd->bn", pred_cost, pool)
        relu = nn.ReLU()
        sense = self.optmodel.modelSense
        per_sample = []
        for idx, pred_row in enumerate(scores_pred):
            true_row = scores_true[idx]
            # index of the truly best pool solution
            if sense == EPO.MINIMIZE:
                winner = torch.argmin(true_row)
            if sense == EPO.MAXIMIZE:
                winner = torch.argmax(true_row)
            winner_score = pred_row[winner]
            # every other pool solution, in ascending index order
            others = torch.ones(pred_row.shape[0], dtype=torch.bool, device=pred_row.device)
            others[winner] = False
            other_scores = pred_row[others]
            if sense == EPO.MINIMIZE:
                per_sample.append(relu(winner_score - other_scores).mean())
            if sense == EPO.MAXIMIZE:
                per_sample.append(relu(other_scores - winner_score).mean())
        loss = torch.stack(per_sample)
        mode = self.reduction
        if mode == "mean":
            return loss.mean()
        if mode == "sum":
            return loss.sum()
        if mode == "none":
            return loss
        raise ValueError("No reduction '{}'.".format(mode))


class pointwiseLTR(optModule):
    """
    An autograd module for pointwise learning to rank, where the goal is to
    learn an objective function that ranks a pool of feasible solutions
    correctly.

    Each pool solution's objective value is used as its score. The loss is
    the squared difference between the scores under the true and the
    predicted costs, averaged over the pool.

    Reference: <https://proceedings.mlr.press/v162/mandi22a.html>
    """

    def __init__(self, optmodel, processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (optDataset): the training data
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        if not isinstance(dataset, optDataset):
            raise TypeError("dataset is not an optDataset")
        # flatten every training solution and drop duplicates to form the pool
        sols = dataset.sols.copy()
        self.solpool = np.unique(sols.reshape(sols.shape[0], -1), axis=0)

    def forward(self, pred_cost, true_cost):
        """
        Forward pass: pointwise squared loss reduced by ``self.reduction``.
        """
        device = pred_cost.device
        detached = pred_cost.detach().to("cpu").numpy()
        # occasionally solve for the predictions and extend the pool
        if np.random.uniform() <= self.solve_ratio:
            self.optmodel.setObj(detached)
            new_sols, _ = self.optmodel.solve()
            self._update_solution_pool(new_sols)
        pool = torch.from_numpy(self.solpool.astype(np.float32)).to(device)
        # scores = objective value of every pool solution, one row per sample
        scores_true = true_cost @ pool.T
        scores_pred = pred_cost @ pool.T
        loss = (scores_true - scores_pred).square().mean(axis=1)
        mode = self.reduction
        if mode == "mean":
            return loss.mean()
        if mode == "sum":
            return loss.sum()
        if mode == "none":
            return loss
        raise ValueError("No reduction '{}'.".format(mode))

class NCE(optModule):
    """
    An autograd module for noise contrastive estimation as surrogate loss
    functions, based on viewing suboptimal solutions as negative examples.

    For each sample, the loss is the average over the pool of the difference
    between the predicted objective of the true solution and that of each
    pool solution (sign flipped for maximisation).

    Reference: <https://www.ijcai.org/proceedings/2021/390>
    """

    def __init__(self, optmodel, processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (None/optDataset): the training data, usually this is simply the training set
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        if not isinstance(dataset, optDataset):
            raise TypeError("dataset is not an optDataset")
        # flatten every training solution and drop duplicates to form the pool
        sols = dataset.sols.copy()
        self.solpool = np.unique(sols.reshape(sols.shape[0], -1), axis=0)

    def forward(self, pred_cost, true_sol):
        """
        Forward pass: NCE loss reduced by ``self.reduction``.
        """
        device = pred_cost.device
        detached = pred_cost.detach().to("cpu").numpy()
        # occasionally solve for the predictions and extend the pool
        if np.random.uniform() <= self.solve_ratio:
            self.optmodel.setObj(detached)
            new_sols, _ = self.optmodel.solve()
            self._update_solution_pool(new_sols)
        pool = torch.from_numpy(self.solpool.astype(np.float32)).to(device)
        # predicted objective of the true solution, shape (batch, 1)
        true_score = torch.einsum("bd,bd->b", pred_cost, true_sol).unsqueeze(1)
        # predicted objective of every pool solution, shape (batch, pool)
        pool_scores = torch.einsum("bd,nd->bn", pred_cost, pool)
        sense = self.optmodel.modelSense
        if sense == EPO.MINIMIZE:
            loss = (true_score - pool_scores).mean(axis=1)
        elif sense == EPO.MAXIMIZE:
            loss = (pool_scores - true_score).mean(axis=1)
        mode = self.reduction
        if mode == "mean":
            return loss.mean()
        if mode == "sum":
            return loss.sum()
        if mode == "none":
            return loss
        raise ValueError("No reduction '{}'.".format(mode))


class contrastiveMAP(optModule):
    """
    An autograd module for Maximum A Posteriori (MAP) contrastive estimation as
    a surrogate loss function, which is an efficient self-contrastive algorithm.

    For the MAP, the cost vector needs to be predicted from contextual data. The
    loss is the largest objective margin between the true optimal solution and
    any solution in a pool of previously seen solutions, so minimizing it
    separates the true solution from its strongest competitor.

    Thus, it allows us to design an algorithm based on stochastic gradient descent.

    Reference: <https://www.ijcai.org/proceedings/2021/390>
    """

    def __init__(self, optmodel, processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (None/optDataset): the training data, usually this is simply the training set

        Raises:
            TypeError: if dataset is not an optDataset
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        # solution pool
        if not isinstance(dataset, optDataset): # type checking
            raise TypeError("dataset is not an optDataset")
        # flatten each stored solution to one row; np.unique returns a new
        # array, so no defensive copy of dataset.sols is needed
        sols = np.asarray(dataset.sols)
        self.solpool = np.unique(sols.reshape(sols.shape[0], -1), axis=0)  # remove duplicate

    def forward(self, pred_cost, true_sol):
        """
        Forward pass: worst-case objective margin between the true solution and
        the solution pool, under the predicted costs.

        Args:
            pred_cost (torch.tensor): predicted costs, indexed as (batch, dim)
                by the einsums below
            true_sol (torch.tensor): true optimal solutions, laid out like pred_cost

        Returns:
            torch.tensor: the reduced loss (per-sample losses when reduction is "none")

        Raises:
            ValueError: on an unknown model sense or reduction
        """
        # get device
        device = pred_cost.device
        # the solver works on a detached numpy copy of the predictions
        cp = pred_cost.detach().to("cpu").numpy()
        # with probability solve_ratio, solve for the predicted costs and grow the pool
        if np.random.uniform() <= self.solve_ratio:
            # NOTE(review): assumes optmodel.setObj/solve accept the whole batch at once
            self.optmodel.setObj(cp)
            sol, _ = self.optmodel.solve()
            self._update_solution_pool(sol)
        solpool = torch.from_numpy(self.solpool.astype(np.float32)).to(device)
        # objective of each true solution under its own predicted cost: (batch, 1)
        obj_cp = torch.einsum("bd,bd->b", pred_cost, true_sol).unsqueeze(1)
        # objective of every pooled solution under each predicted cost: (batch, pool)
        objpool_cp = torch.einsum("bd,nd->bn", pred_cost, solpool)
        # largest margin over the pool, oriented so that lower is better
        if self.optmodel.modelSense == EPO.MINIMIZE:
            loss, _ = (obj_cp - objpool_cp).max(axis=1)
        elif self.optmodel.modelSense == EPO.MAXIMIZE:
            loss, _ = (objpool_cp - obj_cp).max(axis=1)
        else:
            raise ValueError("Invalid modelSense '{}'.".format(self.optmodel.modelSense))
        # reduction ("none" keeps the per-sample losses)
        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        elif self.reduction != "none":
            raise ValueError("No reduction '{}'.".format(self.reduction))
        return loss

class perturbedFenchelYoung(optModule):
    """
    An autograd module for the Fenchel-Young loss using perturbation techniques.
    The gradient of this loss has a simple closed form (the difference between
    the true solution and the expected perturbed solution), which makes it
    cheap to differentiate.

    For the perturbed optimizer, the cost vector needs to be predicted from
    contextual data and is perturbed with Gaussian noise.

    The Fenchel-Young loss directly measures the discrepancy between predicted
    costs and true solutions with little computation, which allows us to
    design an algorithm based on stochastic gradient descent.

    Reference: <https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html>
    """

    def __init__(self, optmodel, n_samples=10, sigma=1.0, processes=1,
                 seed=135, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            n_samples (int): number of Monte-Carlo samples
            sigma (float): the amplitude of the perturbation
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            seed (int): random state seed
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (None/optDataset): the training data
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        # number of samples
        self.n_samples = n_samples
        # perturbation amplitude
        self.sigma = sigma
        # dedicated random state so perturbations are reproducible from `seed`
        self.rnd = np.random.RandomState(seed)
        # autograd Functions are used through their static `apply`; keep the
        # class itself, since instantiating a torch.autograd.Function is deprecated
        self.pfy = perturbedFenchelYoungFunc

    def forward(self, pred_cost, true_sol):
        """
        Forward pass

        Args:
            pred_cost (torch.tensor): a batch of predicted costs
            true_sol (torch.tensor): a batch of true optimal solutions

        Returns:
            torch.tensor: the reduced loss (per-sample losses when reduction is "none")

        Raises:
            ValueError: on an unknown reduction
        """
        loss = self.pfy.apply(pred_cost, true_sol, self)
        # reduction ("none" keeps the per-sample losses)
        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        elif self.reduction != "none":
            raise ValueError("No reduction '{}'.".format(self.reduction))
        return loss


class perturbedFenchelYoungFunc(Function):
    """
    An autograd function for the Fenchel-Young loss using perturbation techniques.
    """

    @staticmethod
    def forward(ctx, pred_cost, true_sol, module):
        """
        Forward pass for perturbed Fenchel-Young loss

        Args:
            pred_cost (torch.tensor): a batch of predicted values of the cost
            true_sol (torch.tensor): a batch of true optimal solutions
            module (optModule): perturbedFenchelYoung module

        Returns:
            torch.tensor: per-sample squared distance between the true solution
            and the Monte-Carlo expectation of the perturbed solutions (float32)

        Raises:
            ValueError: on an unknown model sense
        """
        # get device
        device = pred_cost.device
        # keep all solver-side arithmetic in numpy so dtypes stay consistent
        cp = pred_cost.detach().to("cpu").numpy()
        w = true_sol.detach().to("cpu").numpy()
        # sample Gaussian perturbations: (n_samples, *cp.shape)
        noises = module.rnd.normal(0, 1, size=(module.n_samples, *cp.shape))
        ptb_c = cp + module.sigma * noises
        # stack all perturbed cost vectors into one batch for a single solve call
        # (assumes cp is (batch, dim) — TODO confirm for other layouts)
        ptb_c = ptb_c.reshape(-1, noises.shape[2])
        module.optmodel.setObj(ptb_c)
        ptb_sols, _ = module.optmodel.solve()
        ptb_sols = np.asarray(ptb_sols)
        # regroup by sample: (n_samples, batch, dim)
        ptb_sols = ptb_sols.reshape(module.n_samples, -1, ptb_sols.shape[1])
        # Monte-Carlo estimate of the expected solution
        e_sol = ptb_sols.mean(axis=0)
        # difference, oriented by the optimization sense; it is also the gradient
        if module.optmodel.modelSense == EPO.MINIMIZE:
            diff = w - e_sol
        elif module.optmodel.modelSense == EPO.MAXIMIZE:
            diff = e_sol - w
        else:
            raise ValueError("Invalid modelSense '{}'.".format(module.optmodel.modelSense))
        # loss
        loss = np.sum(diff ** 2, axis=1)
        # convert to float32 tensors on the prediction's device
        diff = torch.as_tensor(diff, dtype=torch.float32, device=device)
        loss = torch.as_tensor(loss, dtype=torch.float32, device=device)
        # save the gradient for the backward pass
        ctx.save_for_backward(diff)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for perturbed Fenchel-Young loss: the saved difference
        scaled per sample by the incoming gradient; no gradients for
        true_sol or module.
        """
        grad, = ctx.saved_tensors
        grad_output = torch.unsqueeze(grad_output, dim=-1)
        return grad * grad_output, None, None



# prediction model
class LinearRegression(nn.Module):
    """
    Linear prediction model mapping a feature vector to a predicted cost vector.

    Args:
        num_features (int, optional): input feature dimension. When omitted, the
            module-level ``num_feat`` set by the experiment script is used, so
            ``LinearRegression()`` behaves exactly as before.
        num_outputs (int): output dimension (number of predicted costs); 40 by default
    """

    def __init__(self, num_features=None, num_outputs=40):
        super().__init__()
        if num_features is None:
            # backward compatibility: fall back to the script-level global
            num_features = num_feat
        self.linear = nn.Linear(num_features, num_outputs)

    def forward(self, x):
        """Return the affine map of ``x``."""
        return self.linear(x)

def weights_init(m):
    """
    Re-initialize the weights of ``nn.Conv2d`` layers with Xavier-uniform values.

    Intended for ``module.apply(weights_init)``; other layer types are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        # xavier_uniform_ is the in-place API; the non-underscore alias is deprecated
        torch.nn.init.xavier_uniform_(m.weight)


def trainModel(reg, loss_func, loss_name, optmodel, loader_train, val_x, val_c, test_x, test_c, trial, num_data, use_gpu=False, num_epochs=100, lr=1e-2,
               h_schedule=False, lr_schedule=False, early_stopping_cfg=None):
    """
    Train ``reg`` in place with Adam and log test/validation regret after every epoch.

    Args:
        reg (nn.Module): prediction model to train (modified in place)
        loss_func: loss module; its call signature is selected by ``loss_name``
        loss_name (str): one of 'SPO+', 'PGB', 'PGF', 'PGC', 'MSE', 'DBB',
            'LTR_pair', 'LTR_point', 'LTR_list', 'DCA', 'FYL'
        optmodel: optimization model (unused here; kept for interface compatibility)
        loader_train (DataLoader): yields (x, c, w, z) batches
        val_x, val_c: validation features and costs (converted with torch.FloatTensor)
        test_x, test_c: test features and costs (converted with torch.FloatTensor)
        trial (int): trial id, recorded in the log
        num_data (int): training-set size, recorded in the log and used to scale
            the summed epoch loss
        use_gpu (bool): move each batch to CUDA
        num_epochs (int): number of training epochs
        lr (float): Adam learning rate
        h_schedule (bool): unused; kept for interface compatibility
        lr_schedule (bool): multiply the learning rate by 0.1 every 20 epochs (StepLR)
        early_stopping_cfg: unused; kept for interface compatibility

    Returns:
        tuple: ``(loss_log, reg)``; ``loss_log`` is a list of rows whose first
        row is the column header.

    Raises:
        ValueError: if ``loss_name`` is not recognized
    """
    # set adam optimizer
    optimizer = torch.optim.Adam(reg.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.1) if lr_schedule else None

    # log header
    loss_log = [['trial', 'n', 'epoch', 'rand_start', 'h', 'loss_name', 'regret', 'val_regret']]
    # the smoothing parameter h only exists on the PG-family and DCA losses
    h = loss_func.h if ('PG' in loss_name or loss_name == "DCA") else 0
    num_rand_starts = 1
    # losses called as loss_func(cp, c)
    cost_losses = ('PGB', 'PGF', 'PGC', 'MSE', 'DBB', 'LTR_pair', 'LTR_point', 'LTR_list')
    # frozen copy of reg used by the DCA loss; refreshed every 30 epochs
    predmodel_0 = None

    for r in range(num_rand_starts):
        for epoch in range(num_epochs):
            # regret_func switches the model to eval mode; switch back every epoch
            reg.train()
            batch_loss = [0]
            if epoch % 30 == 0:
                # snapshot weights AND bias of the current model
                predmodel_0 = cpy2.deepcopy(reg)
                predmodel_0.eval()
                for param in predmodel_0.parameters():
                    param.requires_grad_(False)

            for x, c, w, z in loader_train:
                w = w.reshape(w.shape[0], -1)
                z = z.reshape(z.shape[0], -1)
                # cuda
                if use_gpu:
                    x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()
                # forward pass
                cp = reg(x)

                if loss_name == 'SPO+':
                    loss = loss_func(cp, c, w, z)
                elif loss_name in cost_losses:
                    loss = loss_func(cp, c)
                elif loss_name == 'DCA':
                    with torch.no_grad():
                        cp_0 = predmodel_0(x)
                    loss = loss_func(cp, cp_0, c)
                elif loss_name == 'FYL':
                    loss = loss_func(cp, w)
                else:
                    raise ValueError("Unknown loss_name '{}'.".format(loss_name))

                # backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())

            # summed per-batch training loss scaled by the training-set size
            # (printed as "Train_Regret" below)
            train_regret = sum(batch_loss) / num_data
            regret = regret_func(reg, torch.FloatTensor(test_x), torch.FloatTensor(test_c))
            val_regret = regret_func(reg, torch.FloatTensor(val_x), torch.FloatTensor(val_c))

            loss_log.append([trial, num_data, epoch, r, h, loss_name, regret, val_regret])

            if scheduler is not None:
                scheduler.step()

            print(
                "Epoch {:2},  Train_Regret: {:7.4f}%, Val_Regret: {:7.4f}%, Regret: {:7.4f}%".format(
                    epoch + 1, train_regret * 100, val_regret * 100, regret * 100))

    # return the model that was actually trained (previously a global leaked in here)
    return loss_log, reg


if __name__ == "__main__":
    torch.manual_seed(105)
    # per-trial seeds for the training and test data generators
    indices_arr = torch.randperm(100000)
    indices_arr_test = torch.randperm(100000)

    # experiment index from the command line; defaults to 0 when omitted
    sim = int(sys.argv[1]) if len(sys.argv) > 1 else 0

    n_arr = [200, 400, 800, 1600]
    ep_arr = ['unif', 'normal']
    trials = 100

    # one [training size, noise type, trial] entry per experiment
    exp_arr = [[n, ep, t] for n in n_arr for ep in ep_arr for t in range(trials)]
    num_data, ep_type, trial = exp_arr[sim]

    # generate data
    grid = (5, 5)  # grid size
    num_feat = 5  # size of feature (also read as a global by LinearRegression)
    deg = 6  # polynomial degree
    e = 0.3  # noise width
    feat, cost_true, cost = genData(num_data + 200, num_feat, grid, ep_type, deg, e, seed=indices_arr[trial])
    # hold out 200 points for validation
    x_train, x_val, c_train, c_val = train_test_split(feat, cost, test_size=200, random_state=42)
    x_test, c_test, c_hat_test = genData(10000, num_feat, grid, ep_type, deg, e, seed=indices_arr_test[trial])

    # init optimization model
    optmodel = shortestPathModel()
    # build dataset
    dataset = pyepo.data.dataset.optDataset(optmodel, x_train, c_train)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # positional data arguments and keyword settings shared by every training run
    data_args = (optmodel, dataloader, x_val, c_val, x_test, c_test, trial, num_data)
    train_kwargs = dict(use_gpu=False, lr=1e-2, h_schedule=False, lr_schedule=False, early_stopping_cfg=None)

    def to_df(log):
        """Convert a trainModel log (header row first) into a DataFrame."""
        return pd.DataFrame(columns=log[0], data=log[1:])

    # SPO+
    print("SPO+")
    spop_loss_func = SPOPlus2(optmodel)
    predmodel = LinearRegression()
    spop_out, spop_reg = trainModel(predmodel, spop_loss_func, 'SPO+', *data_args, num_epochs=100, **train_kwargs)
    # frozen snapshot of the trained SPO+ model; every PG/DCA run below warm-starts
    # from its OWN deep copy, so runs never share (and overwrite) parameters
    spop_reg_init = cpy2.deepcopy(spop_reg)
    spop_df = to_df(spop_out)

    # FYL
    print("FYL")
    fy_loss_func = perturbedFenchelYoung(optmodel)
    predmodel = LinearRegression()
    fy_out, _ = trainModel(predmodel, fy_loss_func, 'FYL', *data_args, num_epochs=100, **train_kwargs)
    fy_df = to_df(fy_out)

    # DBB
    print("DBB")
    dbb_loss_func = PG_Loss(optmodel, h = 10, finite_diff_type='F')
    predmodel = LinearRegression()
    dbb_out, _ = trainModel(predmodel, dbb_loss_func, 'DBB', *data_args, num_epochs=100, **train_kwargs)
    dbb_df = to_df(dbb_out)

    # LTR_list
    print("LTR_list")
    LTR_list_loss_func = listwiseLTR(optmodel, dataset = dataset)
    predmodel = LinearRegression()
    LTR_list_out, _ = trainModel(predmodel, LTR_list_loss_func, 'LTR_list', *data_args, num_epochs=100, **train_kwargs)
    LTR_list_df = to_df(LTR_list_out)

    # LTR_pair
    print("LTR_pair")
    LTR_pair_loss_func = pairwiseLTR(optmodel, dataset = dataset)
    predmodel = LinearRegression()
    LTR_pair_out, _ = trainModel(predmodel, LTR_pair_loss_func, 'LTR_pair', *data_args, num_epochs=100, **train_kwargs)
    LTR_pair_df = to_df(LTR_pair_out)

    # LTR_point
    print("LTR_point")
    LTR_point_loss_func = pointwiseLTR(optmodel, dataset = dataset)
    predmodel = LinearRegression()
    LTR_point_out, _ = trainModel(predmodel, LTR_point_loss_func, 'LTR_point', *data_args, num_epochs=100, **train_kwargs)
    LTR_point_df = to_df(LTR_point_out)

    # MSE
    print("MSE")
    mse_loss_func = nn.MSELoss()
    predmodel = LinearRegression()
    mse_out, _ = trainModel(predmodel, mse_loss_func, 'MSE', *data_args, num_epochs=100, **train_kwargs)
    mse_df = to_df(mse_out)

    df_arr = [spop_df, fy_df, mse_df, dbb_df, LTR_point_df, LTR_pair_df, LTR_list_df]

    # PGLoss: backward, forward and central finite differences for every h
    print("PG Loss")
    rand_starts = 1
    h_arr = [num_data**-.125, num_data**-.25, num_data**-.5, num_data**-1]
    for i in range(rand_starts):
        for h in h_arr:
            for fd_type in ('B', 'F', 'C'):
                # each finite-difference variant uses its own loss object
                pg_loss_func = PG_Loss(optmodel, h=h, finite_diff_type=fd_type)
                predmodel_pg = cpy2.deepcopy(spop_reg_init)
                pg_out, _ = trainModel(predmodel_pg, pg_loss_func, 'PG' + fd_type, *data_args,
                                       num_epochs=50, **train_kwargs)
                pg_df = to_df(pg_out)
                pg_df['rand_start'] = i
                df_arr.append(pg_df)

    # DCA
    print("DCA")
    for h in h_arr:
        pg_loss_func = DCA_PG_Loss(optmodel, h=h, finite_diff_type='B')
        predmodel = cpy2.deepcopy(spop_reg_init)
        dca_out, _ = trainModel(predmodel, pg_loss_func, 'DCA', *data_args, num_epochs=300, **train_kwargs)
        df_arr.append(to_df(dca_out))

    df_all = pd.concat(df_arr)
    df_all["ep_type"] = ep_type

    df_all.to_csv("sp_experiment_" + str(sim) + ".csv", index=False)