
import numpy as np
import torch
import torch.nn.functional as F

from gurobipy import GRB  # pylint: disable=no-name-in-module

from openpto.method.Models.abcOptModel import optModel
from openpto.method.utils_method import do_reduction, to_tensor

class SAApairwiseLTR(optModel):
    """
    Pairwise ranking loss evaluated over a cached pool of solutions.

    On every forward pass the true coefficients are added to a running sum,
    the solver is asked for a decision under that accumulated coefficient,
    and the resulting solution is merged into a de-duplicated solution pool.
    The loss is a hinge ("best vs. rest") term: the pool member that is best
    under the true coefficients should also score best under the predicted
    coefficients.

    NOTE(review): "SAA" presumably stands for sample average approximation
    (the running sum is an unnormalised average) -- confirm with the authors.

    Reference:
    """

    def __init__(self, ptoSolver, coef_dim, **kwargs):
        """
        Args:
            ptoSolver: solver object; ``num_vars`` sizes the empty solution
                pool and ``modelSense`` is read in ``forward``.
            coef_dim: width of the (initially empty) coefficient accumulator.
            **kwargs: accepted for interface compatibility; not used here.
        """
        super().__init__(ptoSolver)
        # solution pool
        n_vars = ptoSolver.num_vars
        self.solpool = np.empty((0, n_vars), dtype=np.float32)
        # Empty placeholder; replaced by a tensor holding the running sum of
        # the true coefficients on the first call to ``forward``.
        self.coef_pool = np.empty((0, coef_dim), dtype=np.float32)

    def forward(self, problem, coeff_hat, coeff_true, params, **hyperparams):
        """
        Forward pass.

        Args:
            problem: object providing ``get_decision``, ``get_train_data``,
                ``init_API`` and ``get_objective``.
            coeff_hat: predicted coefficients (torch tensor).
            coeff_true: true coefficients (torch tensor). It is NOT modified
                by this method.
            params: extra parameters forwarded to ``get_decision`` and
                ``get_objective``.
            **hyperparams: must contain ``"reduction"``, which is forwarded
                to ``do_reduction``.

        Returns:
            The reduced loss produced by ``do_reduction``.

        Raises:
            NotImplementedError: if ``ptoSolver.modelSense`` is neither
                ``GRB.MINIMIZE`` nor ``GRB.MAXIMIZE``.
            KeyError: if ``"reduction"`` is missing from ``hyperparams``.
        """
        # Accumulate the true coefficients seen so far. The accumulator must
        # own its storage: aliasing ``coeff_true`` and then adding in place
        # would double the caller's tensor and double-count the first sample.
        if len(self.coef_pool) == 0:
            self.coef_pool = coeff_true.detach().clone()
        else:
            self.coef_pool += coeff_true.detach()
        coef_pool = to_tensor(self.coef_pool).to(coeff_hat.device)
        saa_sol, _ = problem.get_decision(
            coef_pool,
            params=params,
            ptoSolver=self.ptoSolver,
            isTrain=True,
            **problem.init_API(),
        )
        # Seed the solution pool from the training data the first time.
        if len(self.solpool) == 0:
            _, Y_train, Y_train_aux = problem.get_train_data()
            self.solpool, _ = problem.get_decision(
                Y_train,
                params=Y_train_aux,
                ptoSolver=self.ptoSolver,
                isTrain=False,
                **problem.init_API(),
            )
        # add into solpool
        self.solpool = np.concatenate((self.solpool, saa_sol))
        # remove duplicate
        self.solpool = np.unique(self.solpool, axis=0)
        solpool = to_tensor(self.solpool).to(coeff_hat.device)
        # Repeat the coefficients once per pooled solution.
        expand_shape = torch.Size([solpool.shape[0]] + list(coeff_hat.shape[1:]))
        coeff_hat_pool = coeff_hat.expand(*expand_shape)
        coeff_true_pool = coeff_true.expand(*expand_shape)
        # obj for solpool
        objpool_c_true = problem.get_objective(coeff_true_pool, solpool, params)
        objpool_c_hat_pool = problem.get_objective(coeff_hat_pool, solpool, params)

        # TODO: currently, only support batch-1 training. The pairwise term
        # does not depend on the batch index, so it is computed once and
        # repeated for every entry of ``coeff_hat``.
        model_sense = self.ptoSolver.modelSense
        if model_sense == GRB.MINIMIZE:
            best_ind = torch.argmin(objpool_c_true).item()
        elif model_sense == GRB.MAXIMIZE:
            best_ind = torch.argmax(objpool_c_true).item()
        else:
            raise NotImplementedError
        # Predicted objective of the truly-best solution vs. all the others.
        objpool_cp_best = objpool_c_hat_pool[best_ind]
        rest_ind = [j for j in range(len(objpool_c_hat_pool)) if j != best_ind]
        objpool_cp_rest = objpool_c_hat_pool[rest_ind]
        # Hinge: penalise every pool member that the prediction ranks ahead
        # of the truly-best one.
        if model_sense == GRB.MINIMIZE:
            pair_loss = F.relu(objpool_cp_best - objpool_cp_rest)
        else:
            pair_loss = F.relu(objpool_cp_rest - objpool_cp_best)
        loss = torch.stack([pair_loss for _ in range(len(coeff_hat))])
        # reduction
        loss = do_reduction(loss, hyperparams["reduction"])
        return loss