from time import time
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm import tqdm

from src.methods.dfl_abstract import DFL
from src.solvers.solver import Solver
from src.methods.models.layers.feed_forward import FeedForwardLayer
from src.methods.models.layers.stochastic.stochastic import StochasticLayer
from src.methods.models.layers.stochastic.gaussian import GaussianLayer
from src.methods.models.layers.scalers.standardizer import Standardizer, DeStandardizer


class SFGE(DFL):
    """Decision-focused learning with the score-function gradient estimator (SFGE).

    The wrapped ``StochasticLayer`` maps a batch of features to a
    ``torch.distributions.Distribution`` over standardized predictions. During training,
    predictions are sampled from that distribution (without gradient), each sample's regret
    is computed with ``self._compute_regret`` (which therefore does not need to be
    differentiable), and the network is updated by minimizing ``mean(regret * log_prob)``.
    The gradient of that objective is the score-function (REINFORCE) estimate of the gradient
    of the expected regret.
    """

    def __init__(self, network: StochasticLayer, lr: float, solver: Solver | None = None,
                 problem_params: dict[str, np.ndarray] | None = None, apply_early_stopping: bool = True,
                 early_stopping_epochs: int = 10, destandardizer: DeStandardizer | None = None,
                 standardize_regrets: bool = True, pretrain_epochs: int = 0,
                 device: torch.device = torch.device("cpu"), name: str = "SFGE"):
        """Create an SFGE trainer.

        Args:
            network: stochastic model whose ``build_distribution(x)`` returns a distribution
                over standardized predictions.
            lr: learning rate, forwarded to the ``DFL`` base class.
            solver: forwarded to the ``DFL`` base class.
            problem_params: forwarded to the ``DFL`` base class.
            apply_early_stopping: forwarded to the ``DFL`` base class.
            early_stopping_epochs: forwarded to the ``DFL`` base class (presumably the
                early-stopping patience — confirm in ``DFL``).
            destandardizer: optional callable applied to sampled predictions before the regret
                is computed; when ``None`` the samples are used as-is.
            standardize_regrets: if True, each batch of regrets is passed through
                ``Standardizer.standardize`` before being used as score-function weights.
            pretrain_epochs: forwarded to the ``DFL`` base class.
            device: device that training batches are moved to.
            name: display name of the method.
        """
        super().__init__(name, network, lr, solver, problem_params, apply_early_stopping, early_stopping_epochs,
                         destandardizer, pretrain_epochs, device)

        self._standardize_regrets = standardize_regrets

    def _train_procedure(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray,
                         x_val: np.ndarray, y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray,
                         epochs: int, batch_size: int = 32, time_limit: float | None = None) -> int:
        """Train the stochastic network with the score-function gradient estimator.

        Args:
            x, y, cost: training features, targets and per-sample costs, wrapped into a shuffled
                ``DataLoader``. ``z`` is accepted for interface compatibility but unused here.
            x_val, y_val, z_val, cost_val: validation data passed to the early-stopping check.
            epochs: maximum number of epochs.
            batch_size: mini-batch size.
            time_limit: optional wall-clock budget in seconds, checked after each epoch.

        Returns:
            The number of epochs actually run; early stopping or the time limit may end
            training before ``epochs``.
        """
        dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y), torch.Tensor(cost))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        start_time = time()

        for epoch in range(epochs):
            # Re-enter training mode every epoch: the validation check at the end of the
            # previous epoch may have switched the network to eval mode (TODO confirm in DFL).
            self._network.train(True)

            epoch_loss = 0.0

            for _x, _y, _optimal_cost in tqdm(loader):
                _x = _x.to(self._device)
                _y = _y.to(self._device)
                _optimal_cost = _optimal_cost.to(self._device)

                self._optimizer.zero_grad()

                distribution: torch.distributions.Distribution = self._network.build_distribution(_x)

                # Sampling and regret evaluation are treated as constants w.r.t. the parameters.
                with torch.no_grad():
                    y_hat_normalized = distribution.sample()
                    regrets = self._sample_regrets(_x, y_hat_normalized, _y, _optimal_cost)

                # Track the raw (unstandardized) regret for reporting.
                epoch_loss += float(np.sum(regrets))

                if self._standardize_regrets:
                    regrets = Standardizer.standardize(regrets)

                log_prob = self._per_sample_log_prob(distribution, y_hat_normalized)

                # Regrets come from NumPy (float64, CPU); align them with log_prob's device and
                # dtype, otherwise the product below fails on non-CPU devices.
                weights = torch.as_tensor(regrets, dtype=log_prob.dtype, device=log_prob.device)

                # Score-function estimator: grad E[regret] = E[regret * grad log p(y_hat)].
                loss = torch.mean(log_prob * weights)

                loss.backward()

                self._optimizer.step()

            epoch_loss /= len(dataset)

            stop = self._early_stopping_check(x_val, y_val, z_val, cost_val)
            print("Epoch {} - Train regret: {} - Val regret: {}".format(epoch + 1, epoch_loss,
                                                                        self._early_stopping_value))

            if stop or (time_limit is not None and time() - start_time > time_limit):
                return epoch + 1

        return epochs

    def _sample_regrets(self, x: torch.Tensor, y_hat_normalized: torch.Tensor, y: torch.Tensor,
                        optimal_cost: torch.Tensor) -> np.ndarray:
        """Return the regret of each sampled prediction in a batch as a NumPy array.

        Samples are mapped through the destandardizer (when one is configured) and then
        evaluated one batch element at a time with ``self._compute_regret``.
        """
        y_hat = self._destandardizer(y_hat_normalized) if self._destandardizer is not None else y_hat_normalized
        return np.array([
            self._compute_regret(x_i, y_hat_i, y_i, cost_i)
            for x_i, y_hat_i, y_i, cost_i in zip(x, y_hat, y, optimal_cost)
        ])

    @staticmethod
    def _per_sample_log_prob(distribution: torch.distributions.Distribution,
                             sample: torch.Tensor) -> torch.Tensor:
        """Return one log-probability per batch element.

        If the distribution reports log-probabilities per output dimension (e.g. a factorized
        distribution without an ``Independent`` wrapper), the trailing dimensions are summed so
        the result has shape ``(batch,)`` and lines up with the per-sample regrets. Without
        this, ``(batch, d) * (batch,)`` would broadcast incorrectly or raise.
        """
        log_prob = distribution.log_prob(sample)
        if log_prob.dim() > 1:
            log_prob = log_prob.flatten(start_dim=1).sum(dim=1)
        return log_prob

    @staticmethod
    def build_from_config(parameters: dict, input_dim: int, output_dim: int) -> DFL:
        """Build an SFGE model from a configuration dictionary.

        Args:
            parameters: configuration with keys ``"hidden_units"`` (for ``FeedForwardLayer``),
                ``"mode"`` and ``"std"`` (for ``GaussianLayer``), ``"lr"``,
                ``"standardize_regrets"``, ``"pretrain_epochs"`` and ``"name"``.
            input_dim: input dimension passed to both layers.
            output_dim: output dimension passed to both layers.

        Returns:
            An ``SFGE`` instance wrapping a ``GaussianLayer`` around a ``FeedForwardLayer`` and
            using a fresh ``DeStandardizer``. All other constructor arguments keep their defaults.
        """
        network = FeedForwardLayer(input_dim, output_dim, parameters["hidden_units"])
        stochastic_network = GaussianLayer(network, input_dim, output_dim, parameters["mode"], parameters["std"])
        destandardizer = DeStandardizer()

        model = SFGE(stochastic_network, parameters["lr"], destandardizer=destandardizer,
                     standardize_regrets=parameters["standardize_regrets"],
                     pretrain_epochs=parameters["pretrain_epochs"], name=parameters["name"])

        return model
