#!/usr/bin/env python
# coding: utf-8
"""
CSPO+ Loss function
"""

import numpy as np
import torch
from torch.autograd import Function
import random
from multiprocessing import Pool

from pyepo import EPO
from pyepo.func.abcmodule import cspo_optModule
from pyepo.func.utlis import _cspo_cache_in_pass,_cache_in_pass, _cspo_solve_in_pass
import time, os

def init_worker():
    """
    Seed the random number generators of a freshly started worker process.

    Python's ``random`` module and NumPy's global generator are both seeded
    with 42 (in that order) so every worker starts from a reproducible state.
    """
    seed = 42
    for rng in (random, np.random):
        rng.seed(seed)

class CSPOPlus(cspo_optModule):
    """
    An autograd module for the CSPO+ loss, a surrogate of the CSPO loss that
    measures the decision error of the optimization problem.

    As with SPO/SPO+, the objective function is linear and the constraints
    are known and fixed, but the cost vector is predicted from contextual
    data. The SPO+-style surrogate is convex with a subgradient, so it can be
    trained with stochastic gradient descent.

    When more than one process is used, a ``multiprocessing.Pool`` is created;
    call ``close()`` to release it once training is done.
    """

    def __init__(self, optmodel_list, processes=1, solve_ratio=1, dataset=None, warm_start=False):
        """
        Args:
            optmodel_list (list of optModel): PyEPO optimization models
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            solve_ratio (float): the ratio of new solutions computed during training
            dataset (None/optDataset): the training data
            warm_start (bool): whether to use warm start for optimization
        """
        super().__init__(optmodel_list, processes, solve_ratio, dataset)
        # build criterion
        self.spop = CSPOPlusFunc()
        self.warm_start = warm_start
        # Decide on the pool from self.processes -- the value forward() hands
        # to the solver -- so the process count and the pool cannot disagree.
        # NOTE(review): if the parent maps processes=0 to the CPU count, a pool
        # is now created for 0 as well; otherwise this matches the raw argument.
        if self.processes > 1:
            # seed every worker reproducibly
            self.pool = Pool(processes=self.processes, initializer=init_worker)
        else:
            self.pool = None

    def close(self):
        """
        Shut down the worker pool, if one was created.

        Safe to call more than once; afterwards the module runs single-process
        bookkeeping with ``pool=None``.
        """
        if self.pool is not None:
            self.pool.close()
            self.pool.join()
            self.pool = None

    def forward(self, selected_models, pred_cost, true_cost, true_sol, true_obj, reduction="mean"):
        """
        Compute the CSPO+ loss for a batch.

        Args:
            selected_models (list): optimization models for this batch; the
                sense of ``selected_models[0]`` decides the sign of the loss
            pred_cost (torch.tensor): a batch of predicted cost vectors
            true_cost (torch.tensor): a batch of true cost vectors
            true_sol (torch.tensor): a batch of true optimal solutions
            true_obj (torch.tensor): a batch of true optimal objective values
            reduction (str): "mean", "sum" or "none"

        Returns:
            torch.tensor: the reduced loss, or the per-sample loss for "none"

        Raises:
            ValueError: if ``reduction`` is not a supported value
        """
        loss = self.spop.apply(pred_cost, true_cost, true_sol, true_obj,
                               selected_models, self.processes, self.pool,
                               self.solve_ratio, self, self.warm_start)
        # reduction
        if reduction == "mean":
            return torch.mean(loss)
        if reduction == "sum":
            return torch.sum(loss)
        if reduction == "none":
            return loss
        raise ValueError("No reduction '{}'.".format(reduction))


class CSPOPlusFunc(Function):
    """
    An autograd function for the CSPO+ loss.

    ``forward`` obtains solutions for the cost vector
    ``2 * pred_cost - true_cost`` (from the solver or from the cached solution
    pool) and returns the per-sample surrogate loss. ``backward`` returns the
    subgradient ``2 * (w - w_q)`` (sign flipped for maximization) with respect
    to ``pred_cost`` only.
    """

    @staticmethod
    def forward(ctx, pred_cost, true_cost, true_sol, true_obj,
                selected_models, processes, pool, solve_ratio, module, warm_start):
        """
        Forward pass for CSPO+

        Args:
            pred_cost (torch.tensor): a batch of predicted values of the cost
            true_cost (torch.tensor): a batch of true values of the cost
            true_sol (torch.tensor): a batch of true optimal solutions
            true_obj (torch.tensor): a batch of true optimal objective values
            selected_models (list): list of optimization models; the sense of
                the first one decides the sign of the loss
            processes (int): number of processors, 1 for single-core, 0 for all of cores
            pool (ProcessPool): process pool object
            solve_ratio (float): the ratio of new solutions computed during training
            module (optModule): CSPOPlus module; its ``solpool`` is read when the
                cache is used and extended when ``solve_ratio < 1``
            warm_start (bool): whether to use warm start

        Returns:
            torch.tensor: per-sample CSPO+ loss on ``pred_cost``'s device
        """
        # get device
        device = pred_cost.device
        # convert tensors to CPU numpy arrays for the solver
        cp = pred_cost.detach().to("cpu").numpy()
        c = true_cost.detach().to("cpu").numpy()
        w = true_sol.detach().to("cpu").numpy()
        z = true_obj.detach().to("cpu").numpy()
        # cost vector of the inner problem of the surrogate
        q = 2 * cp - c

        # with probability solve_ratio call the solver, otherwise use the pool
        if np.random.uniform() <= solve_ratio:
            sol, obj = _cspo_solve_in_pass(q, selected_models, processes, pool, warm_start)
            if solve_ratio < 1:
                # add the new solutions to the pool and drop duplicate rows
                module.solpool = np.unique(
                    np.concatenate((module.solpool, sol)), axis=0)
        else:
            sol, obj = _cache_in_pass(q, selected_models, module.solpool)

        # per-sample loss: -obj + 2 * <cp, w> - z
        loss = np.array([-o + 2 * np.dot(p, s) - t
                         for o, p, s, t in zip(obj, cp, w, z)])
        # flip the sign for maximization problems
        if selected_models[0].modelSense == EPO.MAXIMIZE:
            loss = -loss

        # convert to tensors on the original device
        loss = torch.as_tensor(loss, dtype=torch.float32, device=device)
        sol = torch.as_tensor(np.array(sol), dtype=torch.float32, device=device)

        # save solutions for the subgradient
        ctx.save_for_backward(true_sol, sol)
        # add other objects to ctx
        ctx.selected_models = selected_models
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for CSPO+: subgradient with respect to ``pred_cost``.

        Raises:
            ValueError: if the model sense is neither minimize nor maximize
        """
        w, wq = ctx.saved_tensors
        sense = ctx.selected_models[0].modelSense
        if sense == EPO.MINIMIZE:
            grad = 2 * (w - wq)
        elif sense == EPO.MAXIMIZE:
            grad = 2 * (wq - w)
        else:
            raise ValueError("Unknown model sense: {}".format(sense))
        # grad_output has the shape of the loss. A 1-D (batch,) loss must be
        # viewed as (batch, 1) so it scales each row of grad; left as-is it
        # would broadcast along the solution axis instead.
        if grad_output.dim() == 1 and grad.dim() == 2:
            grad_output = grad_output.unsqueeze(-1)
        return grad_output * grad, None, None, None, None, None, None, None, None, None
