from typing import Type

from time import time
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from src.methods.dfl_abstract import DFL
from src.solvers.solver import Solver
from src.methods.models.layers.feed_forward import FeedForwardLayer
from src.methods.models.layers.scalers.standardizer import DeStandardizer

from src.methods.models.layers.surrogates.lodl.lodl_surrogate import LODLSurrogate
from src.methods.models.layers.surrogates.lodl.weighted_mse_surrogate import WeightedMSELODLSurrogate
from src.methods.models.layers.surrogates.lodl.quadratic_surrogate import QuadraticLODLSurrogate
from src.methods.models.layers.surrogates.lodl.directed_quadratic_surrogate import DirectedQuadraticLODLSurrogate
from src.methods.models.layers.surrogates.lodl.directed_weighted_mse_surrogate import DirectedWeightedMSELODLSurrogate


class LODLDataset(Dataset):
    """Map-style dataset that yields each feature row together with its position.

    The position is returned alongside the sample so that a consumer iterating
    over shuffled batches can still tell which row of ``x`` each sample is.
    """

    def __init__(self, x: torch.Tensor):
        """Keep a reference to ``x`` and cache its length.

        :param x: indexable collection of samples; ``len(x)`` is taken once here.
        """
        self._x = x
        self._length = len(x)

    def __len__(self):
        """Return the number of samples cached at construction time."""
        return self._length

    def __getitem__(self, idx):
        """Return the pair ``(x[idx], idx)``."""
        sample = self._x[idx]
        return sample, idx


class LODL(DFL):
    """DFL method that trains the network against one fitted surrogate loss per training sample.

    Before gradient training starts, :meth:`_build_surrogates` creates and fits one
    ``LODLSurrogate`` for every training sample. Each optimisation step then evaluates,
    for every sample of the batch, the surrogate belonging to that sample on the network's
    prediction and minimises the batch mean of those values.
    """

    # Maps the ``surrogate_cls`` configuration string to the surrogate implementation.
    SURROGATES_CLASSES: dict[str, Type[LODLSurrogate]] = {
        "weighted mse": WeightedMSELODLSurrogate,
        "quadratic": QuadraticLODLSurrogate,
        "directed mse": DirectedWeightedMSELODLSurrogate,
        "directed quadratic": DirectedQuadraticLODLSurrogate
    }

    def __init__(self, network: torch.nn.Module, lr: float, surrogate_cls: Type[LODLSurrogate],
                 perturbation_sigma: float, perturbation_samples: int, solver: Solver | None = None,
                 problem_params: dict[str, np.ndarray] | None = None, apply_early_stopping: bool = True,
                 early_stopping_epochs: int = 10, destandardizer: DeStandardizer | None = None,
                 pretrain_epochs: int = 0, device: torch.device = torch.device("cpu"), name: str = "LODL"):
        """Create the learner.

        ``name``, ``network``, ``lr``, ``solver``, ``problem_params``, ``apply_early_stopping``,
        ``early_stopping_epochs``, ``destandardizer``, ``pretrain_epochs`` and ``device`` are
        forwarded unchanged to ``DFL.__init__``.

        :param surrogate_cls: class instantiated once per training sample (called with that
            sample's target) to obtain its surrogate loss.
        :param perturbation_sigma: strictly positive factor; multiplied by the per-output standard
            deviation of the training targets and passed as first argument to ``surrogate.sample``.
        :param perturbation_samples: strictly positive value passed as second argument to
            ``surrogate.sample`` (presumably the number of perturbed samples drawn per surrogate).
        :raises ValueError: if ``perturbation_samples`` or ``perturbation_sigma`` is not
            strictly positive.
        """

        super().__init__(name, network, lr, solver, problem_params, apply_early_stopping, early_stopping_epochs,
                         destandardizer, pretrain_epochs, device)

        # Explicit exceptions rather than ``assert`` so the checks are not stripped by ``python -O``.
        # ``not (v > 0)`` (instead of ``v <= 0``) also rejects NaN.
        if not (perturbation_samples > 0):
            raise ValueError(f"perturbation_samples must be > 0, got {perturbation_samples}")
        if not (perturbation_sigma > 0.0):
            raise ValueError(f"perturbation_sigma must be > 0, got {perturbation_sigma}")

        self._surrogate_cls = surrogate_cls
        self._perturbation_sigma = perturbation_sigma
        self._perturbation_samples = perturbation_samples

        # One fitted surrogate per training sample, in training-set order; filled by _build_surrogates.
        self._surrogates: list[LODLSurrogate] = []

    def _train_procedure(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray, x_val: np.ndarray,
                         y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray, epochs: int,
                         batch_size: int = 32, time_limit: float | None = None) -> int:
        """Fit the per-sample surrogates, then train the network on the surrogate losses.

        ``z`` and ``cost`` are accepted for interface compatibility but are not used here.
        The time spent building the surrogates is stored in ``self._pre_processing_runtime``.

        :param x: training features.
        :param y: training targets; used only to build the surrogates.
        :param x_val: validation features, passed to ``self._early_stopping_check``.
        :param y_val: validation targets, passed to ``self._early_stopping_check``.
        :param z_val: passed to ``self._early_stopping_check``.
        :param cost_val: passed to ``self._early_stopping_check``.
        :param epochs: maximum number of epochs.
        :param batch_size: mini-batch size of the shuffled training loader.
        :param time_limit: optional budget in seconds, measured from the start of the epoch loop
            (surrogate building excluded) and checked only at the end of each epoch.
        :return: number of epochs actually run.
        :raises ValueError: if the training set is empty.
        """

        # Fail early with a clear message; otherwise the epoch-loss average below would
        # divide by a zero-length loader.
        if len(x) == 0:
            raise ValueError("LODL cannot be trained on an empty training set")

        pre_train_start = time()
        self._build_surrogates(torch.Tensor(x), torch.Tensor(y))
        self._pre_processing_runtime = time() - pre_train_start

        # The dataset yields (sample, index) so each sample can be matched to its own surrogate
        # even though the loader shuffles.
        loader = DataLoader(LODLDataset(torch.Tensor(x)), batch_size=batch_size, shuffle=True)

        start_time = time()

        for epoch in range(epochs):

            # NOTE(review): re-asserted every epoch in case the validation pass inside
            # _early_stopping_check (not visible here) leaves the network in eval mode;
            # this is a no-op if the mode was not changed.
            self._network.train(True)

            epoch_loss = 0.0

            for batch_x, batch_indices in tqdm(loader):

                batch_x = batch_x.to(self._device)

                self._optimizer.zero_grad()

                y_hat = self._network(batch_x)
                if self._destandardizer is not None:
                    y_hat = self._destandardizer(y_hat)

                # Each prediction is scored by the surrogate fitted for that very sample;
                # the batch loss is the mean of the per-sample surrogate values.
                loss = sum(
                    self._surrogates[i](torch.unsqueeze(y_hat_i, dim=0)).sum()
                    for y_hat_i, i in zip(y_hat, batch_indices.tolist())
                ) / len(batch_x)

                loss.backward()

                self._optimizer.step()

                epoch_loss += loss.item()

            epoch_loss /= len(loader)

            stop = self._early_stopping_check(x_val, y_val, z_val, cost_val)
            print(f"Epoch {epoch + 1} - Train surrogate cost: {epoch_loss} - Val cost: {self._early_stopping_value}")

            # The time budget is only enforced between epochs: a started epoch always completes.
            out_of_time = time_limit is not None and time() - start_time > time_limit
            if stop or out_of_time:
                return epoch + 1

        return epochs

    def _build_surrogates(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Create and fit one surrogate per training sample, replacing any existing ones.

        For every pair ``(x_i, y_i)`` a ``self._surrogate_cls(y_i)`` is instantiated, samples are
        drawn through its ``sample`` method and the surrogate is fitted on them with
        ``train_surrogate``. The result is stored in ``self._surrogates`` in input order.

        :param x: training features; converted to a NumPy array before being handed to ``sample``.
        :param y: training targets; each row is passed to the surrogate constructor as a tensor.
        :raises ValueError: if ``x`` and ``y`` do not have the same number of samples.
        """

        # zip() would silently truncate to the shorter input and leave some samples without a
        # surrogate, which would only surface later as an IndexError during training.
        if len(x) != len(y):
            raise ValueError(f"x and y must contain the same number of samples, got {len(x)} and {len(y)}")

        x = x.detach().numpy()
        # Per-output standard deviation over the training set; scales the perturbation magnitude.
        y_std = np.std(y.detach().numpy(), axis=0)

        self._surrogates.clear()

        # total= lets tqdm display progress/ETA, which it cannot infer from a zip object.
        for x_true, y_true in tqdm(zip(x, y), total=len(x), desc="Training surrogates"):
            surrogate = self._surrogate_cls(y_true)
            y_samples, cost_samples = surrogate.sample(self._perturbation_sigma * y_std, self._perturbation_samples,
                                                       x_true, self._solver, self._problem_params)
            surrogate.train_surrogate(y_samples, cost_samples)
            self._surrogates.append(surrogate)

    @staticmethod
    def build_from_config(parameters: dict, input_dim: int, output_dim: int) -> DFL:
        """Build a ``LODL`` model from a configuration dictionary.

        Keys read from ``parameters``: ``hidden_units``, ``surrogate_cls`` (one of the keys of
        ``SURROGATES_CLASSES``), ``lr``, ``perturbation_sigma``, ``perturbation_samples``,
        ``pretrain_epochs`` and ``name``. All other constructor arguments keep their defaults.

        :param parameters: configuration dictionary.
        :param input_dim: first argument of the ``FeedForwardLayer`` constructor.
        :param output_dim: second argument of the ``FeedForwardLayer`` constructor.
        :return: the configured model.
        :raises ValueError: if ``parameters["surrogate_cls"]`` is not a known surrogate name.
        """

        network = FeedForwardLayer(input_dim, output_dim, parameters["hidden_units"])
        destandardizer = DeStandardizer()

        surrogate_name = parameters["surrogate_cls"]
        if surrogate_name not in LODL.SURROGATES_CLASSES:
            # ValueError is a subclass of Exception, so callers catching Exception still work.
            raise ValueError(f"Invalid Surrogate class {surrogate_name}")

        surrogate_cls = LODL.SURROGATES_CLASSES[surrogate_name]

        return LODL(network, parameters["lr"], surrogate_cls, parameters["perturbation_sigma"],
                    parameters["perturbation_samples"], pretrain_epochs=parameters["pretrain_epochs"],
                    destandardizer=destandardizer, name=parameters["name"])
