#!/usr/bin/env python
# coding: utf-8
"""
"""

from gurobipy import GRB
import torch
import torch.nn as nn
import torch.optim as optim
from openpto.method.Models.abcOptModel import optModel
from openpto.method.utils_method import do_reduction, to_tensor,to_device,to_array
import numpy as np

class SAApointLTR(optModel):
    """Pointwise squared-error loss on objective values over a solution pool.

    State kept across ``forward`` calls:

    * ``coef_pool`` -- running element-wise sum of the ``coeff_true`` batches
      seen so far; it is the coefficient input given to
      ``problem.get_decision`` to obtain the pooled ("SAA") solution(s).
    * ``solpool`` -- array of distinct candidate solutions (one per row),
      seeded from the training data on first use and extended on every call.
    """

    def __init__(self, ptoSolver, coef_dim, **kwargs):
        """Create empty coefficient and solution pools.

        Args:
            ptoSolver: Solver object forwarded to the base class; must expose
                ``num_vars`` (used as the column count of the solution pool).
            coef_dim: Column count of the initially empty coefficient pool.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        super().__init__(ptoSolver)
        n_vars = ptoSolver.num_vars
        self.coef_pool = np.empty((0, coef_dim), dtype=np.float32)
        self.solpool = np.empty((0, n_vars), dtype=np.float32)

    def forward(
        self,
        problem,
        coeff_hat,
        coeff_true,
        params,
        **hyperparams,
    ):
        """Compute the pooled squared objective-gap loss.

        Args:
            problem: Problem object providing ``get_decision``, ``init_API``,
                ``get_train_data`` and ``get_objective``.
            coeff_hat: Predicted coefficients (tensor); its device is used for
                the pooled tensors.
            coeff_true: True coefficients (tensor). It is not modified.
            params: Problem parameters passed through to ``get_decision`` and
                ``get_objective``.
            **hyperparams: Must contain ``"reduction"``, which is handed to
                ``do_reduction``.

        Returns:
            The loss after ``do_reduction``.
        """
        # Accumulate the true coefficients into the pool. The pool must own
        # its storage: aliasing ``coeff_true`` and then adding in place would
        # overwrite the caller's tensor (and the values used for the loss
        # below) and would count the first batch twice.
        true_batch = coeff_true.detach()
        # NOTE: keep the explicit length test; truthiness of arrays/tensors
        # is ambiguous.
        if len(self.coef_pool) == 0:
            self.coef_pool = true_batch.clone()
        else:
            self.coef_pool = self.coef_pool + true_batch
        coef_pool = to_tensor(self.coef_pool).to(coeff_hat.device)

        # Solve for the decision(s) induced by the pooled coefficients.
        saa_sol, _ = problem.get_decision(
            coef_pool,
            params=params,
            ptoSolver=self.ptoSolver,
            isTrain=True,
            **problem.init_API(),
        )

        # Lazily seed the solution pool with decisions for the training data.
        if len(self.solpool) == 0:
            _, Y_train, Y_train_aux = problem.get_train_data()
            self.solpool, _ = problem.get_decision(
                Y_train,
                params=Y_train_aux,
                ptoSolver=self.ptoSolver,
                isTrain=False,
                **problem.init_API(),
            )

        # Add the new solution(s) and drop duplicate rows.
        self.solpool = np.concatenate((self.solpool, saa_sol))
        self.solpool = np.unique(self.solpool, axis=0)
        solpool = to_tensor(self.solpool).to(coeff_hat.device)

        # Broadcast both coefficient tensors along the pool dimension.
        # NOTE(review): ``expand`` only succeeds when their leading dimension
        # is 1 or already equals the pool size -- presumably this is called
        # with one instance at a time; confirm against the caller.
        expand_shape = torch.Size([solpool.shape[0]] + list(coeff_hat.shape[1:]))
        coeff_hat = coeff_hat.expand(*expand_shape)
        coeff_true = coeff_true.expand(*expand_shape)

        # Objective value of every pooled solution under true / predicted
        # coefficients serves as its score.
        objpool_c = problem.get_objective(coeff_true, solpool, params)
        objpool_c_hat = problem.get_objective(coeff_hat, solpool, params)

        # Squared gap between the two scores, averaged over axis 0.
        loss = (objpool_c - objpool_c_hat).square().mean(axis=0)
        loss = do_reduction(loss, hyperparams["reduction"])
        return loss


