#!/usr/bin/env python
# coding: utf-8
"""
CSPO+ Loss function
"""

import numpy as np
import torch
from torch.autograd import Function

from pyepo import EPO
from pyepo.func.abcmodule import optModule, cspo_optModule
from pyepo.func.utlis import _solveWithObj4Par, _solve_in_pass, _cache_in_pass, _cspo_solve_in_pass

class CSPOPlus(cspo_optModule):
    """
    An autograd module for the CSPO+ loss, a surrogate of the CSPO loss that
    measures the decision error of an optimization problem.

    As with SPO/SPO+, the objective function is linear and the constraints are
    known and fixed, while the cost vector is predicted from contextual data.

    The SPO+ loss is convex with a subgradient, which allows training with
    stochastic gradient descent.
    """

    def __init__(self, optmodel_list, processes=1, solve_ratio=1, dataset=None, warm_start=False):
        """
        Args:
            optmodel_list (list): a list of PyEPO optimization models
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            dataset (None/optDataset): the training data
            warm_start (bool): flag forwarded to ``CSPOPlusFunc`` on every forward pass
        """
        super().__init__(optmodel_list, processes, solve_ratio, dataset)
        # build criterion: autograd Functions only have static methods and
        # PyTorch deprecates instantiating them, so keep the class itself;
        # ``self.spop.apply(...)`` behaves the same on the class.
        self.spop = CSPOPlusFunc
        self.warm_start = warm_start

    def forward(self, selected_models, pred_cost, true_cost, true_sol, true_obj, reduction="mean"):
        """
        Compute the CSPO+ loss for a batch.

        Args:
            selected_models (list): optimization models used to solve this batch
            pred_cost (torch.tensor): a batch of predicted costs
            true_cost (torch.tensor): a batch of true costs
            true_sol (torch.tensor): a batch of true optimal solutions
            true_obj (torch.tensor): a batch of true optimal objective values
            reduction (str): "mean", "sum", or "none" (per-sample losses)

        Returns:
            torch.tensor: the reduced loss, or the per-sample losses for "none"

        Raises:
            ValueError: if ``reduction`` is not one of the supported values
        """
        loss = self.spop.apply(pred_cost, true_cost, true_sol, true_obj,
                               selected_models, self.processes, self.pool,
                               self.solve_ratio, self, self.warm_start)
        # reduction
        if reduction == "mean":
            return torch.mean(loss)
        if reduction == "sum":
            return torch.sum(loss)
        if reduction == "none":
            return loss
        raise ValueError("No reduction '{}'.".format(reduction))


class CSPOPlusFunc(Function):
    """
    An autograd function for the CSPO+ loss.

    The forward pass solves the selected optimization models for the cost
    ``2 * pred_cost - true_cost`` and returns one surrogate loss per sample;
    the backward pass returns a subgradient with respect to ``pred_cost`` only.
    """

    @staticmethod
    def forward(ctx, pred_cost, true_cost, true_sol, true_obj,
                selected_models, processes, pool, solve_ratio, module, warm_start):
        """
        Forward pass for CSPO+

        Args:
            pred_cost (torch.tensor): a batch of predicted values of the cost
            true_cost (torch.tensor): a batch of true values of the cost
            true_sol (torch.tensor): a batch of true optimal solutions
            true_obj (torch.tensor): a batch of true optimal objective values
            selected_models (list): PyEPO optimization models used to solve the
                batch; the objective sense is read from the first one
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            pool (ProcessPool): process pool object
            solve_ratio (float): the ratio of new solutions computed during
                training; only values >= 1 are currently supported
            module (optModule): the owning CSPOPlus module (currently unused,
                reserved for solution caching)
            warm_start (bool): flag forwarded to ``_cspo_solve_in_pass``

        Returns:
            torch.tensor: CSPO+ loss for each sample in the batch

        Raises:
            ValueError: if ``solve_ratio < 1`` (solution caching is not
                implemented) or the model sense is neither MINIMIZE nor MAXIMIZE
        """
        # Solution caching (solve_ratio < 1) is not implemented. Fail before
        # spending time in the solver; ``not >= 1`` also rejects NaN.
        if not solve_ratio >= 1:
            raise ValueError("Not implemented")
        # the objective sense fixes the sign of the loss and of the gradient
        sense = selected_models[0].modelSense
        if sense not in (EPO.MINIMIZE, EPO.MAXIMIZE):
            raise ValueError("Unknown model sense '{}'.".format(sense))
        # get device
        device = pred_cost.device
        # convert tensors to numpy for the solver
        cp = pred_cost.detach().to("cpu").numpy()
        c = true_cost.detach().to("cpu").numpy()
        w = true_sol.detach().to("cpu").numpy()
        z = true_obj.detach().to("cpu").numpy()
        # solve with the surrogate cost 2 * pred - true
        sol, obj = _cspo_solve_in_pass(2 * cp - c, selected_models, processes, pool, warm_start)
        # per-sample loss, written in minimization form
        loss = np.array([-obj[i] + 2 * np.dot(cp_i, w[i]) - z[i]
                         for i, cp_i in enumerate(cp)])
        if sense == EPO.MAXIMIZE:
            loss = -loss
        # convert to tensor
        loss = torch.FloatTensor(loss).to(device)
        sol = torch.FloatTensor(np.array(sol)).to(device)
        # save the true and surrogate solutions for the subgradient
        ctx.save_for_backward(true_sol, sol)
        ctx.selected_models = selected_models
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for CSPO+.

        Differentiating the forward loss ``-obj + 2 * pred^T w - z`` (where
        ``obj`` is the optimal value for cost ``2 * pred - true``) gives the
        subgradient ``2 * (w - w_q)`` for minimization and ``2 * (w_q - w)``
        for maximization. Here ``w`` is the true solution and ``w_q`` is the
        surrogate solution. Only ``pred_cost`` receives a gradient.

        Raises:
            ValueError: if the model sense is neither MINIMIZE nor MAXIMIZE
        """
        w, wq = ctx.saved_tensors
        sense = ctx.selected_models[0].modelSense
        if sense == EPO.MINIMIZE:
            grad = 2 * (w - wq)
        elif sense == EPO.MAXIMIZE:
            grad = 2 * (wq - w)
        else:
            raise ValueError("Unknown model sense '{}'.".format(sense))
        # One loss value per sample: align grad_output as a (batch, 1) column so
        # it scales each row of grad even when the loss was returned 1-D.
        return grad_output.reshape(-1, 1) * grad, None, None, None, None, None, None, None, None, None
