from typing import Type

from time import time
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm

from src.methods.dfl_abstract import DFL
from src.solvers.solver import Solver
from src.methods.models.layers.feed_forward import FeedForwardLayer
from src.methods.models.layers.scalers.standardizer import DeStandardizer

from src.methods.models.layers.surrogates.egl.egl_surrogate import EGLSurrogate
from src.methods.models.layers.surrogates.egl.weighted_mse_surrogate import WeightedMSEEGLSurrogate
from src.methods.models.layers.surrogates.egl.quadratic_surrogate import QuadraticEGLSurrogate
from src.methods.models.layers.surrogates.egl.directed_quadratic_surrogate import DirectedQuadraticEGLSurrogate
from src.methods.models.layers.surrogates.egl.directed_weighted_mse_surrogate import DirectedWeightedMSEEGLSurrogate


class EGLDataset(Dataset):
    """Dataset over a feature tensor that also yields each sample's position.

    Items are ``(x[idx], idx)``. Returning the index lets a training loop map
    every sample of a shuffled batch back to per-instance state stored in a
    list aligned with ``x`` (EGL uses it to pick each instance's surrogate).
    """

    def __init__(self, x: torch.Tensor):
        self._x = x
        # Length is fixed at construction time.
        self._length = len(x)

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        row = self._x[idx]
        return row, idx


class PretrainEGLDataset(Dataset):
    """Dataset zipping three row-aligned tensors together with the row index.

    Items are ``(x[idx], y[idx], cost[idx], idx)``. EGL uses it to pair the
    features of each instance with its sampled predictions and their costs.
    The reported length comes from ``x`` alone; ``y`` and ``cost`` must have
    at least as many rows.
    """

    def __init__(self, x: torch.Tensor, y: torch.Tensor, cost: torch.Tensor):
        self._x, self._y, self._cost = x, y, cost
        self._length = len(self._x)

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        columns = (self._x, self._y, self._cost)
        return tuple(column[idx] for column in columns) + (idx,)


class EGL(DFL):
    """EGL decision-focused learning method.

    Training runs in two phases (see ``_train_procedure``):

    1. Surrogate fitting. One ``EGLSurrogate`` is created per training target
       (``_build_surrogates``). Candidate predictions are collected from
       checkpoints of ``n_models`` small models trained by MSE (``_sample``).
       Their regrets are computed with ``_compute_regret``. A shared
       feed-forward "pre-loss" network is then fitted so that, given an
       instance's features, it outputs surrogate parameters under which that
       instance's surrogate reproduces the sampled regrets
       (``_train_surrogate``).
    2. Model training. The predictive ``network`` is trained to minimise the
       learned surrogate losses while the pre-loss network stays frozen.
    """

    # Maps the config name of a surrogate to its class (used by ``build_from_config``).
    SURROGATES_CLASSES: dict[str, Type[EGLSurrogate]] = {
        "weighted mse": WeightedMSEEGLSurrogate,
        "quadratic": QuadraticEGLSurrogate,
        "directed mse": DirectedWeightedMSEEGLSurrogate,
        "directed quadratic": DirectedQuadraticEGLSurrogate
    }

    def __init__(self, network: torch.nn.Module, lr: float, surrogate_cls: Type[EGLSurrogate],
                 hidden_units: list[int], n_samples: int = 32, n_models: int = 4,
                 outer_lr: float = 0.001, outer_train_epochs: int = 100,
                 inner_lr: float = 0.1, inner_train_epochs: int = 100, solver: Solver | None = None,
                 problem_params: dict[str, np.ndarray] | None = None, apply_early_stopping: bool = True,
                 early_stopping_epochs: int = 10, destandardizer: DeStandardizer | None = None,
                 pretrain_epochs: int = 0, device: torch.device = torch.device("cpu"), name: str = "EGL"):
        """Store the EGL hyper-parameters and forward the rest to ``DFL``.

        Args:
            network: Predictive model trained in the second phase.
            lr: Forwarded to ``DFL`` (presumably the learning rate of ``network``).
            surrogate_cls: Surrogate class instantiated once per training target.
            hidden_units: Hidden layer sizes of the pre-loss network.
            n_samples: Total number of candidate predictions per instance. It is
                split across the ``n_models`` sampling models by floor division.
            n_models: Number of independently trained sampling models. Must be positive.
            outer_lr: Adam learning rate of the pre-loss network.
            outer_train_epochs: Epochs used to fit the pre-loss network.
            inner_lr: Adam learning rate of each sampling model.
            inner_train_epochs: Epochs each sampling model is trained for.
            solver, problem_params, apply_early_stopping, early_stopping_epochs,
            destandardizer, pretrain_epochs, device, name: Forwarded to ``DFL``.

        Raises:
            ValueError: If ``n_models`` is not positive.
        """
        if n_models <= 0:
            raise ValueError("n_models must be positive, got {}".format(n_models))

        super().__init__(name, network, lr, solver, problem_params, apply_early_stopping, early_stopping_epochs,
                         destandardizer, pretrain_epochs, device)

        self._surrogate_cls = surrogate_cls
        self._hidden_units = hidden_units
        self._n_samples = n_samples
        self._n_models = n_models
        self._outer_lr = outer_lr
        self._outer_train_epochs = outer_train_epochs
        self._inner_lr = inner_lr
        self._inner_train_epochs = inner_train_epochs

        # One surrogate per training instance, indexed like the training set.
        self._surrogates: list[EGLSurrogate] = []

        # Maps features to surrogate parameters. Built in ``_build_surrogates``.
        self._pre_loss_layer = None

    def _train_procedure(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray, x_val: np.ndarray,
                         y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray, epochs: int,
                         batch_size: int = 32, time_limit: float | None = None) -> int:
        """Fit the surrogate losses, then train the predictive network against them.

        The time spent fitting the surrogates is stored in
        ``self._pre_processing_runtime``. ``z`` is not used here. The validation
        arrays are passed to ``_early_stopping_check`` after every epoch.

        Returns:
            The number of epochs actually run. This is fewer than ``epochs`` when
            early stopping triggers or ``time_limit`` (seconds, measured from the
            start of network training) is exceeded.
        """
        x_tensor = torch.Tensor(x)

        pre_train_start = time()
        self._build_surrogates(x_tensor, torch.Tensor(y))
        self._train_surrogate(x_tensor, torch.Tensor(y), torch.Tensor(cost))
        self._pre_processing_runtime = time() - pre_train_start

        # From here on the learned losses are fixed. The pre-loss network is only
        # evaluated (deterministically), never updated.
        self._pre_loss_layer.eval()

        loader = DataLoader(EGLDataset(x_tensor), batch_size=batch_size, shuffle=True)

        start_time = time()

        for epoch in range(epochs):

            # Re-enable training mode every epoch. The validation check at the end of
            # the previous epoch may have switched the network to eval mode.
            self._network.train(True)

            epoch_loss = 0.0

            for _x, indices in tqdm(loader):

                _x = _x.to(self._device)

                self._optimizer.zero_grad()

                y_hat = self._network(_x)
                if self._destandardizer is not None:
                    y_hat = self._destandardizer(y_hat)

                # Surrogate parameters are constants for this phase. No gradient is needed.
                with torch.no_grad():
                    params = self._pre_loss_layer(_x)

                # Each sample is scored by its own instance's surrogate. The indices map the
                # shuffled batch back to the per-instance surrogate list.
                per_sample_losses = [
                    self._surrogates[i](torch.unsqueeze(y_hat_i, dim=0), params_i).sum()
                    for params_i, y_hat_i, i in zip(params, y_hat, indices)
                ]
                loss = sum(per_sample_losses) / len(_x)

                loss.backward()

                self._optimizer.step()

                epoch_loss += loss.item()

            epoch_loss /= len(loader)

            stop = self._early_stopping_check(x_val, y_val, z_val, cost_val)
            print("Epoch {} - Train surrogate cost: {} - Val cost: {}".format(epoch + 1, epoch_loss,
                                                                              self._early_stopping_value))

            if stop or (time_limit is not None and time() - start_time > time_limit):
                return epoch + 1

        return epochs

    def _build_surrogates(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Create one surrogate per target row of ``y`` and a fresh pre-loss network.

        Raises:
            ValueError: If ``y`` is empty (no surrogate can size the pre-loss network).
        """
        if len(y) == 0:
            raise ValueError("Cannot build EGL surrogates from an empty training set")

        self._surrogates.clear()
        self._surrogates.extend(self._surrogate_cls(y_true) for y_true in y)

        # The pre-loss output width is taken from the first surrogate, as before.
        # This assumes every surrogate of the class has the same params_dim.
        params_dim = self._surrogates[0].params_dim
        self._pre_loss_layer = FeedForwardLayer(x.shape[1], params_dim, self._hidden_units).to(self._device)

    def _train_surrogate(self, x: torch.Tensor, y: torch.Tensor, optimal_cost: torch.Tensor,
                         batch_size: int = 32) -> None:
        """Fit the pre-loss network so each surrogate reproduces sampled regrets.

        Candidate predictions come from ``self._n_models`` sampling models, with
        ``self._n_samples // self._n_models`` checkpoints each. Each candidate's
        regret (from ``_compute_regret``, using ``optimal_cost``) is the
        regression target of the instance's surrogate. The per-instance MSE is
        averaged over each mini-batch.

        Raises:
            ValueError: If ``n_samples < n_models`` (no samples per model).
        """
        local_n_samples = self._n_samples // self._n_models
        if local_n_samples == 0:
            raise ValueError("n_samples ({}) must be at least n_models ({})".format(self._n_samples,
                                                                                   self._n_models))

        # Samples from every model are concatenated along dim 1, so row k holds all
        # candidate predictions for instance k.
        y_samples = torch.concat(
            [self._sample("{}/{}".format(i + 1, self._n_models), x, y, local_n_samples)
             for i in range(self._n_models)],
            dim=1)

        cost_samples = []
        for x_i, y_i, optimal_cost_i, candidates_i in tqdm(zip(x, y, optimal_cost, y_samples),
                                                           desc="Computing regrets"):
            regrets = [self._compute_regret(x_i, y_hat_i, y_i, optimal_cost_i) for y_hat_i in candidates_i]
            cost_samples.append(torch.Tensor(np.array(regrets)))
        cost_samples = torch.stack(cost_samples, dim=0)

        dataset = PretrainEGLDataset(x, y_samples, cost_samples)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        mse_loss = torch.nn.MSELoss()
        outer_optimizer = torch.optim.Adam(self._pre_loss_layer.parameters(), lr=self._outer_lr)

        self._pre_loss_layer.train(True)

        for epoch in range(self._outer_train_epochs):

            epoch_loss = 0.0

            for _x, _y, _cost, indices in tqdm(loader):
                _x = _x.to(self._device)
                _y = _y.to(self._device)
                _cost = _cost.to(self._device)

                outer_optimizer.zero_grad()

                params = self._pre_loss_layer(_x)

                instance_losses = [
                    mse_loss(instance_costs, self._surrogates[i](instance_preds, params_i))
                    for params_i, instance_preds, instance_costs, i in zip(params, _y, _cost, indices)
                ]
                # Average over the batch actually processed (previously divided by the
                # full dataset size, which misreported the MSE).
                loss = sum(instance_losses) / len(_x)

                loss.backward()

                outer_optimizer.step()

                epoch_loss += loss.item()

            epoch_loss /= len(loader)

            print("Epoch {} - Train surrogate MSE: {}".format(epoch + 1, epoch_loss))

    def _sample(self, name: str, x: torch.Tensor, y: torch.Tensor, n_samples: int, batch_size: int = 32) -> torch.Tensor:
        """Collect candidate predictions from checkpoints of a freshly trained model.

        A ``FeedForwardLayer`` with no hidden units is fitted to ``(x, y)`` by MSE
        for ``self._inner_train_epochs`` epochs. At evenly spaced epochs, its
        (destandardized) predictions on the whole, unshuffled ``x`` are recorded,
        so rows stay aligned with ``x``/``y``.

        Args:
            name: Label shown in the progress bar.
            n_samples: Number of checkpoints to record. Fewer are recorded if it
                exceeds the number of epochs.

        Returns:
            CPU tensor with ``len(x)`` rows and one checkpoint per slice along dim 1.

        Raises:
            ValueError: If ``n_samples`` is not positive or no epoch is run.
        """
        if n_samples <= 0:
            raise ValueError("n_samples must be positive, got {}".format(n_samples))

        # At least 1, so the modulo below is defined when n_samples > epochs.
        checkpoint_space = max(1, self._inner_train_epochs // n_samples)
        y_samples = []

        loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)

        inner_model = FeedForwardLayer(x.shape[1], y.shape[1], hidden_units=[]).to(self._device)
        mse_loss = torch.nn.MSELoss()
        inner_optimizer = torch.optim.Adam(inner_model.parameters(), lr=self._inner_lr)
        inner_model.train(True)

        for epoch in tqdm(range(self._inner_train_epochs), desc="Training sampling model " + name):

            for _x, _y in loader:
                _x = _x.to(self._device)
                _y = _y.to(self._device)

                inner_optimizer.zero_grad()

                y_hat = inner_model(_x)
                if self._destandardizer is not None:
                    y_hat = self._destandardizer(y_hat)

                loss = mse_loss(_y, y_hat)

                loss.backward()

                inner_optimizer.step()

            # Cap at n_samples: when inner_train_epochs is not a multiple of n_samples
            # the spacing would otherwise yield extra checkpoints.
            if (epoch + 1) % checkpoint_space == 0 and len(y_samples) < n_samples:
                inner_model.eval()
                with torch.no_grad():
                    predictions = inner_model(x.to(self._device))
                    if self._destandardizer is not None:
                        predictions = self._destandardizer(predictions)
                predictions = torch.unsqueeze(predictions, dim=1)
                y_samples.append(predictions.to(device="cpu", dtype=torch.get_default_dtype()))
                inner_model.train(True)

        if not y_samples:
            raise ValueError("inner_train_epochs must be positive to collect samples")

        return torch.concat(y_samples, dim=1)

    @staticmethod
    def build_from_config(parameters: dict, input_dim: int, output_dim: int) -> DFL:
        """Build an ``EGL`` model from a configuration dictionary.

        Reads the keys ``hidden_units``, ``surrogate_cls`` (a key of
        ``SURROGATES_CLASSES``), ``lr``, ``egl_hidden_units``, ``n_samples``,
        ``n_models``, ``pretrain_epochs`` and ``name``. A ``DeStandardizer`` is
        always attached.

        Raises:
            ValueError: If ``surrogate_cls`` names an unknown surrogate.
        """
        network = FeedForwardLayer(input_dim, output_dim, parameters["hidden_units"])
        destandardizer = DeStandardizer()

        surrogate_name = parameters["surrogate_cls"]
        if surrogate_name not in EGL.SURROGATES_CLASSES:
            raise ValueError("Invalid Surrogate class {}".format(surrogate_name))

        surrogate_cls = EGL.SURROGATES_CLASSES[surrogate_name]

        return EGL(network, parameters["lr"], surrogate_cls, parameters["egl_hidden_units"],
                   parameters["n_samples"], parameters["n_models"], pretrain_epochs=parameters["pretrain_epochs"],
                   destandardizer=destandardizer, name=parameters["name"])
