from time import time
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from src.methods.dfl_abstract import DFL
from src.solvers.solver import Solver
from src.methods.models.layers.feed_forward import FeedForwardLayer
from src.methods.models.layers.scalers.standardizer import DeStandardizer


class PFL(DFL):
    """Baseline that fits the network's predictions to ``y`` with a plain MSE loss.

    The training loop below never calls the solver: it minimizes the mean
    squared error between the (optionally de-standardized) network output and
    the targets ``y``, and uses validation MSE for patience-based early stopping.
    (The name presumably stands for "prediction-focused learning" — confirm.)
    """

    def __init__(self, network: torch.nn.Module, lr: float, solver: Solver | None = None,
                 problem_params: dict[str, np.ndarray] | None = None, apply_early_stopping: bool = True,
                 early_stopping_epochs: int = 10, destandardizer: DeStandardizer | None = None,
                 device: torch.device = torch.device("cpu"), name: str = "PFL"):
        """Create a PFL model.

        Args:
            network: Model mapping features to predictions of ``y``.
            lr: Learning rate, forwarded to DFL (which presumably builds
                ``self._optimizer`` from it — not visible here).
            solver: Optional solver, forwarded to DFL; not used by this class's
                training loop.
            problem_params: Optional problem parameters, forwarded to DFL.
            apply_early_stopping: If False, ``_early_stopping_check`` never stops
                training and skips the validation pass.
            early_stopping_epochs: Number of consecutive non-improving validation
                checks after which training stops.
            destandardizer: If given, applied to network outputs before the loss.
            device: Device that training batches are moved to.
            name: Model name, forwarded to DFL.
        """
        # NOTE(review): the literal 0 fills a positional DFL constructor
        # argument whose meaning is defined in DFL — confirm there.
        super().__init__(name, network, lr, solver, problem_params, apply_early_stopping, early_stopping_epochs,
                         destandardizer, 0, device)

        # Default reduction="mean", so every call returns a scalar tensor.
        self._mse_loss = torch.nn.MSELoss()
        # Validation MSE from the most recent early-stopping check (None if none
        # has run yet, or early stopping is disabled); used for logging only.
        self._last_val_mse: float | None = None

    def _train_procedure(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray,
                         x_val: np.ndarray, y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray,
                         epochs: int, batch_size: int = 32, time_limit: float | None = None) -> int:
        """Train the network on (x, y) by minimizing MSE with mini-batch updates.

        ``z`` and ``cost`` are accepted for interface compatibility but are not
        used by the MSE objective. The validation arrays are handed to
        ``_early_stopping_check`` after every epoch.

        Args:
            x, y: Training features and targets.
            z, cost: Unused here.
            x_val, y_val, z_val, cost_val: Validation data for early stopping.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size (batches are reshuffled every epoch).
            time_limit: Optional wall-clock budget in seconds, checked after
                each epoch.

        Returns:
            The number of epochs actually run.
        """
        dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Reset per run so the log never shows a value from a previous run.
        self._last_val_mse = None

        start_time = time()

        for epoch in range(epochs):
            # Re-enter training mode every epoch: the validation pass at the end
            # of the previous epoch goes through _network_inference, which may
            # have switched the network to eval mode (not visible here).
            self._network.train(True)

            epoch_loss = 0.0

            for batch_x, batch_y in tqdm(loader):
                batch_x = batch_x.to(self._device)
                batch_y = batch_y.to(self._device)

                self._optimizer.zero_grad()

                y_hat = self._network(batch_x)
                if self._destandardizer is not None:
                    # Compare predictions with y on y's original scale.
                    y_hat = self._destandardizer(y_hat)

                # MSELoss already averages (reduction="mean"), so loss is a scalar.
                loss = self._mse_loss(batch_y, y_hat)
                loss.backward()
                self._optimizer.step()

                epoch_loss += loss.item()

            # Mean of the per-batch losses.
            epoch_loss /= len(loader)

            stop = self._early_stopping_check(x_val, y_val, z_val, cost_val)
            # Log this epoch's validation MSE, not the best-so-far value.
            print("Epoch {} - Train MSE: {} - Val MSE: {}".format(epoch + 1, epoch_loss,
                                                                  self._last_val_mse))

            if stop or (time_limit is not None and time() - start_time > time_limit):
                return epoch + 1

        return epochs

    def _early_stopping_check(self, x_val: np.ndarray, y_val: np.ndarray,
                              z_val: np.ndarray, cost_val: np.ndarray) -> bool:
        """Run a validation pass and update patience-based early stopping.

        The validation MSE is compared with the best value seen so far
        (``self._early_stopping_value``). An improvement resets the patience
        counter. Otherwise the counter grows, and training should stop once it
        reaches ``self._early_stopping_epochs``. ``z_val`` and ``cost_val`` are
        unused.

        NOTE(review): the best weights are not checkpointed here. Whether DFL
        restores them after stopping is not visible in this class.

        Returns:
            True if training should stop, False otherwise (always False when
            early stopping is disabled).
        """
        if not self._apply_early_stopping:
            return False

        with torch.no_grad():
            y_hat_val = torch.Tensor(self._network_inference(x_val))
            mse = self._mse_loss(torch.Tensor(y_val), y_hat_val).item()

        self._last_val_mse = mse

        if self._early_stopping_value is None:
            # First check: just record the baseline.
            self._early_stopping_value = mse
            return False

        if mse < self._early_stopping_value:
            self._early_stopping_value = mse
            self._early_stopping_count = 0
            return False

        self._early_stopping_count += 1
        return self._early_stopping_count >= self._early_stopping_epochs

    @staticmethod
    def build_from_config(parameters: dict, input_dim: int, output_dim: int) -> DFL:
        """Build a PFL model from a configuration dictionary.

        Reads the keys "hidden_units", "lr" and "name" from ``parameters``. It
        builds a FeedForwardLayer network and attaches a fresh DeStandardizer.
        All other constructor arguments keep their defaults.

        Args:
            parameters: Configuration dictionary.
            input_dim: Input dimension of the network.
            output_dim: Output dimension of the network.

        Returns:
            The constructed PFL instance.
        """
        network = FeedForwardLayer(input_dim, output_dim, parameters["hidden_units"])
        destandardizer = DeStandardizer()

        return PFL(network, parameters["lr"], destandardizer=destandardizer, name=parameters["name"])
