
import numpy as np
import torch
import torch.nn.functional as F

from gurobipy import GRB  # pylint: disable=no-name-in-module

from openpto.method.Models.abcOptModel import optModel
from openpto.method.utils_method import do_reduction, to_tensor

class SAAlistwiseLTR(optModel):
    """
    Listwise learning-to-rank loss whose solution pool is grown with
    sample-average-approximation (SAA) decisions.

    On every forward pass the true cost coefficients seen so far are summed
    into a running total (``self.coef_pool``). The problem is solved for that
    aggregate, and the resulting decision is added to a de-duplicated solution
    pool. The loss is the cross entropy between the softmax distributions of
    the pool objectives under the true and the predicted coefficients, both
    scaled by the temperature ``tau``.

    NOTE(review): the original "Reference" / "Code from" fields were empty.
    The structure appears to follow the listwise LTR loss of Mandi et al.
    (ICML 2022) as implemented in PyEPO's ``listwiseLTR`` -- confirm.
    """

    def __init__(self, ptoSolver, coef_dim, tau=1.0, **kwargs):
        """
        Args:
            ptoSolver: solver wrapper; ``num_vars`` sizes the solution pool and
                ``modelSense`` selects the loss sign in ``forward``.
            coef_dim: number of cost coefficients per sample.
            tau: softmax temperature; must be positive.
            **kwargs: accepted for interface compatibility and ignored.

        Raises:
            ValueError: if ``tau`` is not positive.
        """
        super().__init__(ptoSolver)

        if tau <= 0:
            raise ValueError("tau is not positive.")
        self.tau = tau
        # pool of distinct decisions that forms the list being ranked
        n_vars = ptoSolver.num_vars
        self.solpool = np.empty((0, n_vars), dtype=np.float32)
        # running sum of every true cost seen so far; stays empty until the
        # first forward pass replaces it with a tensor
        self.coef_pool = np.empty((0, coef_dim), dtype=np.float32)

    def forward(self, problem, coeff_hat, coeff_true, params, **hyperparams):
        """
        Compute the listwise LTR loss for one batch.

        Args:
            problem: PTO problem providing ``get_train_data``, ``get_decision``,
                ``get_objective`` and ``init_API``.
            coeff_hat: predicted cost coefficients (torch tensor).
            coeff_true: true cost coefficients (torch tensor). It is only read,
                never modified in place.
            params: per-batch problem parameters, forwarded to the solver and
                to ``get_objective``.
            **hyperparams: must contain ``"reduction"``, which is passed to
                ``do_reduction``.

        Returns:
            The reduced loss tensor.

        Raises:
            NotImplementedError: if the solver's ``modelSense`` is neither
                ``GRB.MINIMIZE`` nor ``GRB.MAXIMIZE``.
        """
        # lazily seed the solution pool with the decisions for the training set
        if len(self.solpool) == 0:
            _, Y_train, Y_train_aux = problem.get_train_data()
            self.solpool, _ = problem.get_decision(
                Y_train,
                params=Y_train_aux,
                ptoSolver=self.ptoSolver,
                isTrain=False,
                **problem.init_API(),
            )

        # Accumulate the true costs into the running SAA total. Work on a
        # detached copy so the caller's tensor is never mutated in place, and
        # so the first batch is counted exactly once.
        # TODO: if sol pool reasonable?
        coeff_true_sample = coeff_true.detach()
        if len(self.coef_pool) == 0:
            self.coef_pool = coeff_true_sample.clone()
        else:
            self.coef_pool = self.coef_pool + coeff_true_sample
        coef_pool = to_tensor(self.coef_pool).to(coeff_hat.device)

        # solve for the aggregated costs (the SAA decision)
        saa_sol, _ = problem.get_decision(
            coef_pool,
            params=params,
            ptoSolver=self.ptoSolver,
            isTrain=True,
            **problem.init_API(),
        )
        # add into solpool and drop duplicate decisions
        self.solpool = np.concatenate((self.solpool, saa_sol))
        self.solpool = np.unique(self.solpool, axis=0)

        # Broadcast the coefficients across the pool. `expand` only succeeds
        # when the leading dim is 1 (or already equals the pool size), so this
        # presumably assumes batch size 1 -- confirm with callers.
        solpool = to_tensor(self.solpool).to(coeff_hat.device)
        expand_shape = torch.Size([solpool.shape[0]] + list(coeff_hat.shape[1:]))
        coeff_hat = coeff_hat.expand(*expand_shape)
        coeff_true = coeff_true.expand(*expand_shape)

        # objective of every pooled decision under true / predicted costs
        objpool_c = problem.get_objective(coeff_true, solpool, params)
        objpool_c_hat = problem.get_objective(coeff_hat, solpool, params)

        # Ranking scores: for minimisation a lower objective ranks higher, so
        # negate. The temperature tau is applied for both model senses.
        if self.ptoSolver.modelSense == GRB.MINIMIZE:
            scores_hat = -objpool_c_hat / self.tau
            scores_true = -objpool_c / self.tau
        elif self.ptoSolver.modelSense == GRB.MAXIMIZE:
            scores_hat = objpool_c_hat / self.tau
            scores_true = objpool_c / self.tau
        else:
            raise NotImplementedError

        # cross entropy between the true and predicted ranking distributions
        loss = -(F.log_softmax(scores_hat, dim=0) * F.softmax(scores_true, dim=0))
        # reduction
        loss = do_reduction(loss, hyperparams["reduction"])
        return loss