from time import time
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm

from src.methods.dfl_abstract import DFL
from src.solvers.solver import Solver
from src.methods.models.layers.feed_forward import FeedForwardLayer
from src.methods.models.layers.scalers.standardizer import DeStandardizer


class SPO(DFL):
    """Decision-focused learner trained with an SPO-style surrogate loss.

    For each training batch the network predicts costs ``y_hat`` (optionally
    de-standardized). The solver is then called on the perturbed costs
    ``alpha * y_hat - y``, and the loss is the mean of
    ``(alpha * y_hat - y) * (z - z_hat) * self._mm``. Here ``z`` is the given
    solution and ``z_hat`` is the solver's solution for the perturbed costs.

    NOTE(review): with ``alpha == 2`` this appears to correspond to the SPO+
    loss (Elmachtoub & Grigas); confirm the intended formulation.
    """

    def __init__(self, network: torch.nn.Module, lr: float, alpha: float, solver: Solver | None = None,
                 problem_params: dict[str, np.ndarray] | None = None, apply_early_stopping: bool = True,
                 early_stopping_epochs: int = 10, destandardizer: DeStandardizer | None = None,
                 pretrain_epochs: int = 0, device: torch.device = torch.device("cpu"), name: str = "SFGE"):
        """Create the SPO learner.

        Args:
            network: model mapping features to (possibly standardized) cost predictions.
            lr: learning rate, forwarded to the ``DFL`` base class.
            alpha: scale applied to the predictions before subtracting the true
                costs when building the solver input ``alpha * y_hat - y``.
            solver: optimization solver; forwarded to ``DFL``.
            problem_params: extra problem data passed to ``solver.solve``.
            apply_early_stopping: forwarded to ``DFL``.
            early_stopping_epochs: forwarded to ``DFL``.
            destandardizer: if given, applied to the network output before use.
            pretrain_epochs: forwarded to ``DFL``.
            device: device that batches are moved to during training.
            name: identifier forwarded to ``DFL``.
        """
        # NOTE(review): the default name "SFGE" looks copied from another method;
        # it is kept unchanged for backward compatibility. Consider "SPO".
        super().__init__(name, network, lr, solver, problem_params, apply_early_stopping, early_stopping_epochs,
                         destandardizer, pretrain_epochs, device)

        self._alpha = alpha

    def _train_procedure(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray,
                         x_val: np.ndarray, y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray,
                         epochs: int, batch_size: int = 32, time_limit: float | None = None) -> int:
        """Train the network with the SPO-style loss.

        Args:
            x, y, z: training features, true costs and true solutions.
            cost: unused by this procedure (kept for the ``DFL`` interface).
            x_val, y_val, z_val, cost_val: validation data passed to the
                early-stopping check after every epoch.
            epochs: maximum number of epochs.
            batch_size: mini-batch size.
            time_limit: wall-clock budget in seconds, checked after each epoch;
                ``None`` disables it.

        Returns:
            The number of epochs actually run. This is fewer than ``epochs``
            when early stopping triggers or the time limit is exceeded.
        """
        dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y), torch.Tensor(z))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        start_time = time()

        for epoch in range(epochs):
            # Re-enable training mode every epoch. The validation check at the
            # end of the previous epoch may have put the network into eval mode
            # (its implementation is not visible here).
            self._network.train(True)

            epoch_loss = 0.0

            for _x, _y, _z in tqdm(loader):
                _x = _x.to(self._device)
                _y = _y.to(self._device)
                _z = _z.to(self._device)

                self._optimizer.zero_grad()

                y_hat_normalized = self._network(_x)
                y_hat = self._destandardizer(y_hat_normalized) if self._destandardizer else y_hat_normalized

                y_hat_noise = (self._alpha * y_hat) - _y

                with torch.no_grad():
                    # The solver consumes NumPy data. Copy to host first,
                    # because `.numpy()` raises on non-CPU tensors.
                    x_np = _x.detach().cpu().numpy()
                    noise_np = y_hat_noise.detach().cpu().numpy()
                    solutions = []
                    for x_row, noise_row in zip(x_np, noise_np):
                        # Only the first feature entry of each sample is passed to the solver.
                        z_hat_i, _ = self._solver.solve(x_row[0], noise_row, self._problem_params)
                        solutions.append(z_hat_i)
                    # Put the solutions on the training device so they can be
                    # combined with `_z` and `y_hat_noise` below.
                    z_hat = torch.as_tensor(np.array(solutions), dtype=torch.float32, device=self._device)

                # NOTE(review): `self._mm` is set by the DFL base class, presumably
                # a +1/-1 sign for min vs. max problems; confirm.
                loss = torch.mean(y_hat_noise * (_z - z_hat) * self._mm)

                loss.backward()

                self._optimizer.step()

                # `loss` is a per-batch mean. Weight it by the batch size so
                # that dividing by the dataset size gives a per-sample mean.
                epoch_loss += loss.item() * _x.shape[0]

            epoch_loss /= max(len(dataset), 1)

            stop = self._early_stopping_check(x_val, y_val, z_val, cost_val)
            print("Epoch {} - Train loss: {} - Val regret: {}".format(epoch + 1, epoch_loss,
                                                                      self._early_stopping_value))

            if stop or (time_limit is not None and time() - start_time > time_limit):
                return epoch + 1

        return epochs

    @staticmethod
    def build_from_config(parameters: dict, input_dim: int, output_dim: int) -> DFL:
        """Build an ``SPO`` learner from a configuration dictionary.

        Args:
            parameters: must contain the keys ``"hidden_units"``, ``"lr"``,
                ``"alpha"``, ``"pretrain_epochs"`` and ``"name"``.
            input_dim: number of input features of the network.
            output_dim: number of network outputs.

        Returns:
            A new ``SPO`` instance with a ``FeedForwardLayer`` network and a
            fresh ``DeStandardizer``. The solver and device use the
            constructor defaults.
        """
        network = FeedForwardLayer(input_dim, output_dim, parameters["hidden_units"])
        destandardizer = DeStandardizer()

        model = SPO(network, parameters["lr"], parameters["alpha"], destandardizer=destandardizer,
                    pretrain_epochs=parameters["pretrain_epochs"], name=parameters["name"])

        return model
