#!/usr/bin/env python
# coding: utf-8
"""
Surrogate Loss function
"""

import numpy as np
import torch
from torch.autograd import Function

from pyepo import EPO
from pyepo.func.abcmodule import optModule
# from pyepo.func.utlis import _solve_or_cache
from pyepo.utlis import getArgs

def _solve_or_cache(cp, module):
    """
    Obtain solutions for a batch of costs, either by solving or from the pool

    A uniform random number is drawn on every call; when it does not exceed
    ``module.solve_ratio`` the batch is solved, otherwise the best solutions
    already stored in ``module.solpool`` are returned.

    Args:
        cp: a batch of cost vectors
        module: module providing ``solve_ratio``, ``optmodel``, ``processes``,
            ``pool``, ``solpool`` and ``_update_solution_pool``

    Returns:
        tuple: solutions and objective values
    """
    solve_fresh = np.random.uniform() <= module.solve_ratio
    # fall back to the cached pool
    if not solve_fresh:
        sol, obj = _cache_in_pass(cp, module.optmodel, module.solpool)
        return sol, obj
    # solve the optimization problems
    sol, obj = _solve_in_pass(cp, module.optmodel, module.processes, module.pool)
    # the pool is only maintained when caching can actually be used
    if module.solve_ratio < 1:
        module._update_solution_pool(sol)
    return sol, obj


def _solve_in_pass(cp, optmodel, processes, pool):
    """
    A function to solve optimization in the forward/backward pass

    Args:
        cp: a batch of cost vectors, one instance per entry along the first axis
        optmodel: optimization model providing ``setObj`` and ``solve``
        processes (int): 1 for the single-core loop, anything else uses ``pool``
        pool: worker pool exposing ``amap(...).get()``; only used when
            ``processes != 1`` (presumably a pathos pool - confirm with caller)

    Returns:
        tuple: solutions and objective values, both as numpy arrays
    """
    # number of instance
    ins_num = len(cp)
    # single-core
    if processes == 1:
        sol = []
        obj = []
        for cost in cp:
            # solve one instance at a time on the shared model
            optmodel.setObj(cost)
            solp, objp = optmodel.solve()
            sol.append(solp)
            obj.append(objp)
        # to numpy
        sol = np.array(sol)
        obj = np.array(obj)
    # multi-core
    else:
        # the worker is not defined in this module, so it is imported here
        # to keep the parallel path from failing with a NameError
        from pyepo.func.utlis import _solveWithObj4Par
        # class and constructor arguments passed along to every worker call
        model_type = type(optmodel)
        args = getArgs(optmodel)
        # parallel computing
        res = pool.amap(_solveWithObj4Par, cp, [args] * ins_num,
                        [model_type] * ins_num).get()
        # each result is a (solution, objective) pair
        sol = np.array([r[0] for r in res])
        obj = np.array([r[1] for r in res])
    return sol, obj


def _cache_in_pass(cp, optmodel, solpool):
    """
    A function to use solution pool in the forward/backward pass

    Args:
        cp (np.ndarray): a batch of cost vectors, one instance per row
        optmodel: optimization model; only its ``modelSense`` is read
        solpool (np.ndarray): cached solutions, one solution per row

    Returns:
        tuple: for every instance, the best cached solution and its objective

    Raises:
        ValueError: if ``modelSense`` is neither EPO.MINIMIZE nor EPO.MAXIMIZE
    """
    # objective value of every cached solution under every cost vector
    solpool_obj = cp @ solpool.T
    # best solution in pool
    if optmodel.modelSense == EPO.MINIMIZE:
        ind = np.argmin(solpool_obj, axis=1)
    elif optmodel.modelSense == EPO.MAXIMIZE:
        ind = np.argmax(solpool_obj, axis=1)
    else:
        # fail explicitly instead of hitting an unbound ``ind`` below
        raise ValueError("Invalid modelSense. Must be EPO.MINIMIZE or EPO.MAXIMIZE.")
    obj = np.take_along_axis(solpool_obj, ind.reshape(-1, 1), axis=1).reshape(-1)
    sol = solpool[ind]
    return sol, obj

class SPOPlus(optModule):
    """
    An autograd module for SPO+ Loss, a surrogate for the SPO (regret) Loss
    that measures the decision error of the optimization problem.

    For SPO/SPO+ Loss, the objective function is linear and the constraints
    are known and fixed, while the cost vector has to be predicted from
    contextual data.

    The SPO+ Loss is convex and admits a subgradient, so it can be minimized
    with stochastic gradient descent.

    Reference: <https://doi.org/10.1287/mnsc.2020.3922>
    """

    def __init__(self, optmodel, processes=1, solve_ratio=1, reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (None/optDataset): the training data
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        # build criterion
        self.spop = SPOPlusFunc()

    def forward(self, pred_cost, true_cost, true_sol, true_obj):
        """
        Forward pass: evaluate the SPO+ loss and apply the configured reduction

        Raises:
            ValueError: if the reduction is not "mean", "sum" or "none"
        """
        per_instance = self.spop.apply(pred_cost, true_cost, true_sol, true_obj, self)
        # reduction
        mode = self.reduction
        if mode == "mean":
            return torch.mean(per_instance)
        if mode == "sum":
            return torch.sum(per_instance)
        if mode == "none":
            return per_instance
        raise ValueError("No reduction '{}'.".format(mode))


class SPOPlusFunc(Function):
    """
    A autograd function for SPO+ Loss
    """

    @staticmethod
    def forward(ctx, pred_cost, true_cost, true_sol, true_obj, module):
        """
        Forward pass for SPO+

        Args:
            pred_cost (torch.tensor): a batch of predicted values of the cost
            true_cost (torch.tensor): a batch of true values of the cost
            true_sol (torch.tensor): a batch of true optimal solutions
            true_obj (torch.tensor): a batch of true optimal objective values
            module (optModule): SPOPlus module

        Returns:
            torch.tensor: SPO+ loss, one entry per instance

        Raises:
            ValueError: if the model sense is neither EPO.MINIMIZE nor EPO.MAXIMIZE
        """
        # get device
        device = pred_cost.device
        # convert tensors to numpy
        cp = pred_cost.detach().to("cpu").numpy()
        c = true_cost.detach().to("cpu").numpy()
        w = true_sol.detach().to("cpu").numpy()
        z = true_obj.detach().to("cpu").numpy()
        # solve with the SPO+ cost 2 * cp - c
        sol, obj = _solve_or_cache(2 * cp - c, module)
        # calculate loss per instance; kept per-row so the output shape keeps
        # following z[i] (scalar entries give [bs], 1-element rows give [bs, 1])
        loss = np.array([- obj[i] + 2 * np.dot(cp[i], w[i]) - z[i]
                         for i in range(len(cp))])
        # sense
        if module.optmodel.modelSense == EPO.MINIMIZE:
            pass
        elif module.optmodel.modelSense == EPO.MAXIMIZE:
            loss = - loss
        else:
            raise ValueError("Invalid modelSense. Must be EPO.MINIMIZE or EPO.MAXIMIZE.")
        # convert to tensor
        loss = torch.tensor(loss, dtype=torch.float, device=device)
        sol = torch.tensor(sol, dtype=torch.float, device=device)
        # save solutions
        ctx.save_for_backward(true_sol, sol)
        # add other objects to ctx
        ctx.optmodel = module.optmodel
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for SPO+

        Args:
            grad_output (torch.tensor): gradient w.r.t. the loss from forward

        Returns:
            tuple: gradient w.r.t. pred_cost, and None for the other inputs

        Raises:
            ValueError: if the model sense is neither EPO.MINIMIZE nor EPO.MAXIMIZE
        """
        w, wq = ctx.saved_tensors
        optmodel = ctx.optmodel
        if optmodel.modelSense == EPO.MINIMIZE:
            grad = 2 * (w - wq)
        elif optmodel.modelSense == EPO.MAXIMIZE:
            grad = 2 * (wq - w)
        else:
            raise ValueError("Invalid modelSense. Must be EPO.MINIMIZE or EPO.MAXIMIZE.")
        # one weight per instance as a column, so [bs] and [bs, 1] losses
        # both scale the rows of the [bs, d] gradient
        return grad_output.reshape(-1, 1) * grad, None, None, None, None


class perturbationGradient(optModule):
    """
    An autograd module for PG Loss, a surrogate for the objective value that
    measures the decision quality of the optimization problem.

    For PG Loss, the objective function is linear and the constraints are
    known and fixed, while the cost vector has to be predicted from
    contextual data.

    According to Danskin's Theorem, the PG Loss is derived from different
    zeroth order approximations and has an informative gradient, so it can be
    minimized with stochastic gradient descent.

    Reference: <https://arxiv.org/abs/2402.03256>
    """
    def __init__(self, optmodel, sigma=0.1, two_sides=False, processes=1, solve_ratio=1,
                 reduction="mean", dataset=None):
        """
        Args:
            optmodel (optModel): an PyEPO optimization model
            sigma (float): the amplitude of the finite difference width used for loss approximation
            two_sides (bool): approximate gradient by two-sided perturbation or not
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            reduction (str): the reduction to apply to the output
            dataset (None/optDataset): the training data
        """
        super().__init__(optmodel, processes, solve_ratio, reduction, dataset)
        # symmetric perturbation
        self.two_sides = two_sides
        # finite difference width
        self.sigma = sigma

    def forward(self, pred_cost, true_cost):
        """
        Forward pass: evaluate the PG loss and apply the configured reduction

        Raises:
            ValueError: if the reduction is not "mean", "sum" or "none"
        """
        per_instance = self._finiteDifference(pred_cost, true_cost)
        # reduction
        mode = self.reduction
        if mode == "mean":
            return torch.mean(per_instance)
        if mode == "sum":
            return torch.sum(per_instance)
        if mode == "none":
            return per_instance
        raise ValueError("No reduction '{}'.".format(mode))

    def _finiteDifference(self, pred_cost, true_cost):
        """
        Zeroth order approximations for surrogate objective value

        With ``two_sides`` the costs ``pred_cost + sigma * true_cost`` and
        ``pred_cost - sigma * true_cost`` are compared (width ``2 * sigma``);
        otherwise ``pred_cost`` and ``pred_cost - sigma * true_cost`` are
        compared (width ``sigma``).

        Raises:
            ValueError: if the model sense is neither EPO.MINIMIZE nor EPO.MAXIMIZE
        """
        # get device
        device = pred_cost.device
        # numpy copies used only for solving
        cp = pred_cost.detach().to("cpu").numpy()
        c = true_cost.detach().to("cpu").numpy()
        # solve for the upper point first, then for the lower point
        if self.two_sides:
            # central differencing
            upper_np, _ = _solve_or_cache(cp + self.sigma * c, self)
        else:
            # back differencing
            upper_np, _ = _solve_or_cache(cp, self)
        lower_np, _ = _solve_or_cache(cp - self.sigma * c, self)
        # solutions as constant tensors
        upper_sol = torch.tensor(upper_np, dtype=torch.float, device=device)
        lower_sol = torch.tensor(lower_np, dtype=torch.float, device=device)
        # differentiable objective values
        if self.two_sides:
            upper_cost = pred_cost + self.sigma * true_cost
        else:
            upper_cost = pred_cost
        upper_obj = torch.einsum("bi,bi->b", upper_cost, upper_sol)
        lower_obj = torch.einsum("bi,bi->b", pred_cost - self.sigma * true_cost, lower_sol)
        # loss, oriented by the model sense
        sense = self.optmodel.modelSense
        if sense == EPO.MINIMIZE:
            diff = upper_obj - lower_obj
        elif sense == EPO.MAXIMIZE:
            diff = lower_obj - upper_obj
        else:
            raise ValueError("Invalid modelSense. Must be EPO.MINIMIZE or EPO.MAXIMIZE.")
        if self.two_sides:
            return diff / (2 * self.sigma)
        return diff / self.sigma