from time import time
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm

from src.methods.dfl_abstract import DFL
from src.solvers.solver import Solver
from src.methods.models.layers.feed_forward import FeedForwardLayer
from src.methods.models.layers.stochastic.stochastic import StochasticLayer
from src.methods.models.layers.stochastic.gaussian import GaussianLayer
from src.methods.models.layers.scalers.standardizer import Standardizer, DeStandardizer
from src.methods.models.layers.surrogates.baselines import Baseline, NullBaseline, QuadraticBaseline
from src.methods.models.layers.surrogates.gp.gaussian_process import GaussianProcess
from src.methods.models.layers.surrogates.gp.gaussian_processes_handler import GaussianProcessesHandler

from src.utils.probabilities import bernoulli


class GPDataset(Dataset):
    """Map-style dataset over (input, target, optimal cost) triples.

    Each item also carries its own position, so the training loop can map a
    sample back to the Gaussian process that belongs to it.
    """

    def __init__(self, x: torch.Tensor, y: torch.Tensor, cost: torch.Tensor):
        n_samples = len(x)
        # Every tensor must describe the same number of samples.
        assert n_samples == len(y) == len(cost)

        self._x, self._y, self._cost = x, y, cost
        self._length = n_samples

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        # The index is echoed back so the caller knows which sample this is.
        sample = (self._x[idx], self._y[idx], self._cost[idx])
        return (*sample, idx)


class GaussianProcessSFGE(DFL):
    """Score-function gradient estimator (SFGE) assisted by per-sample Gaussian processes.

    For each training sample the model decides (``_check_apply_gp``) whether
    the sample's GP surrogate is confident enough to stand in for the exact
    regret. When it is not, the exact regret from ``_compute_regret`` is used
    and the observed correction (regret minus baseline) is added to that GP's
    training data.
    """

    def __init__(self, network: StochasticLayer, lr: float, gp_handler: GaussianProcessesHandler,
                 solver: Solver | None = None, gp_std_threshold: float = 0.1, epsilon: float = 0.0,
                 baseline: Baseline | None = None, problem_params: dict[str, np.ndarray] | None = None,
                 apply_early_stopping: bool = True, early_stopping_epochs: int = 10,
                 destandardizer: DeStandardizer | None = None, standardize_regrets: bool = True,
                 differentiate_gp: bool = False, smooth_gp: bool = True, pre_train_gp: bool = True,
                 use_lcb: bool = True, pretrain_epochs: int = 0, device: torch.device = torch.device("cpu"),
                 name: str = "GP SFGE"):
        """
        Args:
            network: stochastic predictor; must provide ``build_distribution``.
            lr: learning rate, forwarded to the ``DFL`` base class.
            gp_handler: owner of the per-sample Gaussian processes.
            solver: optional solver, forwarded to the ``DFL`` base class.
            gp_std_threshold: the GP is used only when its predictive std is below this value.
            epsilon: probability in [0, 1] of forcing an exact regret evaluation
                even when the GP would be usable.
            baseline: baseline added to the GP-predicted correction. ``None``
                (default) means ``NullBaseline()``. A baseline *class* is also
                accepted and instantiated with no arguments.
            problem_params, apply_early_stopping, early_stopping_epochs,
            destandardizer, pretrain_epochs, device: forwarded to ``DFL``.
            standardize_regrets: standardize SFGE regrets within each batch.
            differentiate_gp: when the GP is used, backpropagate through its
                prediction instead of taking an SFGE step.
            smooth_gp: add samples through ``add_smoothed_samples`` rather than ``add_samples``.
            pre_train_gp: pre-train the GPs before the epoch loop.
            use_lcb: subtract the GP confidence bound from the predicted correction.
            name: display name, forwarded to ``DFL``.

        Raises:
            AssertionError: if ``epsilon`` is outside [0, 1].
        """

        super().__init__(name, network, lr, solver, problem_params, apply_early_stopping, early_stopping_epochs,
                         destandardizer, pretrain_epochs, device)

        assert 0.0 <= epsilon <= 1.0

        # The baseline is *called* on (prediction, target), so it must be an instance.
        # The old default was the class itself; instantiate it here.
        if baseline is None:
            baseline = NullBaseline()
        elif isinstance(baseline, type):
            baseline = baseline()

        self._gp_handler = gp_handler

        self._std_threshold = gp_std_threshold
        self._epsilon = epsilon
        self._baseline = baseline
        self._standardize_regrets = standardize_regrets
        self._differentiate_gp = differentiate_gp
        self._smooth_gp = smooth_gp
        self._use_lcb = use_lcb
        self._pre_train_gp = pre_train_gp

    def set_solver(self, solver: Solver) -> None:
        """Set the solver on this model and forward it to the GP handler."""

        super().set_solver(solver)
        self._gp_handler.set_solver(solver)

    def set_problem_params(self, problem_params: dict[str, np.ndarray]) -> None:
        """Set the problem parameters on this model and forward them to the GP handler."""

        super().set_problem_params(problem_params)
        self._gp_handler.set_problem_params(problem_params)

    def _train_procedure(self, x: np.ndarray, y: np.ndarray, z: np.ndarray, cost: np.ndarray,
                         x_val: np.ndarray, y_val: np.ndarray, z_val: np.ndarray, cost_val: np.ndarray,
                         epochs: int, batch_size: int = 32, time_limit: float | None = None) -> int:
        """Train the network with GP-assisted SFGE.

        Args:
            x, y, cost: training inputs, targets and per-sample optimal costs.
            z: not used by this method.
            x_val, y_val, z_val, cost_val: validation data for the early-stopping check.
            epochs: maximum number of epochs.
            batch_size: mini-batch size.
            time_limit: optional wall-clock budget in seconds for the epoch loop.
                GP pre-training is not counted.

        Returns:
            The number of epochs actually run.
        """

        self._gp_handler.initialize_gaussian_processes(y)

        if self._pre_train_gp:
            pre_train_start = time()
            self._gp_handler.pre_train_gaussian_processes(x, y, cost)
            self._pre_processing_runtime = time() - pre_train_start

        dataset = GPDataset(torch.Tensor(x), torch.Tensor(y), torch.Tensor(cost))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self._network.train(True)

        start_time = time()

        for epoch in range(epochs):

            epoch_loss = 0.0
            gp_count = 0

            self._gp_handler.auto_train()

            for _x, _y, _optimal_cost, indices in tqdm(loader):
                _x = _x.to(self._device)
                _y = _y.to(self._device)
                _optimal_cost = _optimal_cost.to(self._device)

                self._optimizer.zero_grad()

                gp_regrets = []
                rl_regrets = []
                rl_log_probs = []

                for x_row, y_i, optimal_cost_i, index in zip(_x, _y, _optimal_cost, indices):
                    x_i = torch.unsqueeze(x_row, dim=0)
                    # The DataLoader collates indices into a tensor. Use a plain int:
                    # ``_black_box_step`` declares ``i: int``, and tensors hash by identity.
                    gp_index = int(index)
                    gp_i = self._gp_handler.gaussian_process(gp_index)

                    apply_gp = self._check_apply_gp(x_i, gp_i)
                    if apply_gp:
                        gp_count += 1

                    if apply_gp and self._differentiate_gp:
                        # Backpropagate directly through the GP surrogate.
                        gp_regrets.append(self._train_gp_step(x_i, y_i, gp_i))
                        continue

                    # SFGE step: sample, score (GP or exact), keep the log-prob for the gradient.
                    distribution: torch.distributions.Distribution = self._network.build_distribution(x_i)
                    with torch.no_grad():
                        regret_i, y_hat_i_norm = self._black_box_step(gp_index, x_i, y_i, optimal_cost_i, gp_i,
                                                                      distribution, apply_gp)
                        rl_regrets.append(regret_i)

                        if not apply_gp:
                            self._gp_handler.add_to_train_list(gp_index)

                    # Computed outside ``no_grad`` so the score-function gradient can flow.
                    rl_log_probs.append(distribution.log_prob(y_hat_i_norm))

                loss, batch_regret = self._compute_loss(rl_regrets, gp_regrets, rl_log_probs)
                epoch_loss += batch_regret

                loss.backward()
                self._optimizer.step()

            epoch_loss /= len(dataset)
            gp_usage = round(100 * (gp_count / len(x)), 2)

            stop = self._early_stopping_check(x_val, y_val, z_val, cost_val)
            print("Epoch {} - Train regret: {} - Val regret: {} - GP usage: {}%".format(epoch + 1,
                                                                                        epoch_loss,
                                                                                        self._early_stopping_value,
                                                                                        gp_usage))
            if stop or (time_limit is not None and time() - start_time > time_limit):
                return epoch + 1

        return epochs

    def _check_apply_gp(self, x_i: torch.Tensor, gp_i: GaussianProcess) -> bool:
        """Decide whether the GP surrogate may replace the exact regret for one sample.

        Returns ``False`` in three cases:
        - ``bernoulli(epsilon)`` fires (forced exploration);
        - the GP is empty;
        - the GP is not trained.

        Otherwise, returns whether the GP's predictive std at the network's
        output for ``x_i`` is below ``gp_std_threshold``.
        """

        with torch.no_grad():

            if self._epsilon > 0.0 and bernoulli(self._epsilon):
                return False

            if gp_i.is_empty or not gp_i.is_trained:
                return False

            y_mean = self._network(x_i)
            y_mean = self._destandardizer(y_mean) if self._destandardizer else y_mean
            _, std, _ = gp_i.predict(y_mean)

            return bool(std < self._std_threshold)

    def _train_gp_step(self, x_i: torch.Tensor, y_i: torch.Tensor, gp_i: GaussianProcess) -> torch.Tensor:
        """Return a differentiable regret estimate: baseline + GP-predicted correction.

        Nothing here is detached, so backward can reach the network
        (provided ``gp_i.predict`` is differentiable). With ``use_lcb`` the
        confidence bound is subtracted from the predicted correction.
        """

        y_hat_i_norm = self._network(x_i)
        y_hat_i = self._destandardizer(y_hat_i_norm) if self._destandardizer else y_hat_i_norm

        baseline = self._baseline(y_hat_i, y_i)
        baseline_correction, _, confidence_bound = gp_i.predict(y_hat_i)
        if self._use_lcb:
            # Out-of-place: an in-place update on a tensor in the autograd graph
            # (or one the GP may hold on to) can break backward or corrupt GP state.
            baseline_correction = baseline_correction - confidence_bound

        return baseline + baseline_correction

    def _black_box_step(self, i: int, x_i: torch.Tensor, y_i: torch.Tensor, optimal_cost_i: torch.Tensor,
                        gp_i: GaussianProcess, distribution: torch.distributions.Distribution,
                        apply_gp: bool) -> tuple[float, torch.Tensor]:
        """Sample a prediction and score it, either with the GP or exactly.

        With ``apply_gp`` the regret is ``baseline + GP correction`` (minus the
        confidence bound when ``use_lcb``).

        Otherwise the regret comes from ``_compute_regret``, and the observed
        correction ``regret - baseline`` is added to GP ``i``'s samples.

        Args:
            i: index of the sample / its Gaussian process.
            x_i: input with a leading batch dimension of 1.
            y_i, optimal_cost_i: target and optimal cost of the sample.
            gp_i: the sample's Gaussian process.
            distribution: predictive distribution built from ``x_i``.
            apply_gp: whether to use the GP instead of the exact regret.

        Returns:
            ``(regret, normalized_sample)``.
        """

        y_hat_i_norm = distribution.sample()
        y_hat_i = self._destandardizer(y_hat_i_norm) if self._destandardizer else y_hat_i_norm
        baseline_i = self._baseline(y_hat_i, y_i)

        if apply_gp:
            baseline_correction, _, confidence_bound = gp_i.predict(y_hat_i)
            if self._use_lcb:
                # Out-of-place so a tensor returned by the GP is never mutated.
                baseline_correction = baseline_correction - confidence_bound
            return float(baseline_i + baseline_correction), y_hat_i_norm

        regret_i = self._compute_regret(x_i[0], y_hat_i[0], y_i, optimal_cost_i)
        # Feed the observed correction back into this sample's GP training set.
        baseline_correction_i = regret_i - baseline_i
        if self._smooth_gp:
            y_mean = self._destandardizer(distribution.mean) if self._destandardizer else distribution.mean
            self._gp_handler.add_smoothed_samples(i, y_mean, np.array([baseline_correction_i]),
                                                  [y_hat_i_norm], [distribution])
        else:
            self._gp_handler.add_samples(i, y_hat_i, np.array([baseline_correction_i]))

        return regret_i, y_hat_i_norm

    def _compute_loss(self, rl_regrets: list[float], gp_regrets: list[torch.Tensor],
                      rl_log_probs: list[torch.Tensor]) -> tuple[torch.Tensor, float]:
        """Build the batch loss and the summed raw regret used for logging.

        - SFGE terms are ``log_prob * regret`` per sample, with regrets
          optionally standardized within the batch. Minimizing their mean
          lowers the likelihood of high-regret samples.
        - GP-differentiated regrets are used directly as loss terms.
        - Both parts are concatenated (SFGE first) and averaged.

        Returns:
            ``(loss, summed_regret)``, where ``summed_regret`` is the
            unstandardized sum of all regrets in the batch, as a float.

        Raises:
            ValueError: if both regret lists are empty.
        """

        summed_regret = 0.0
        loss_terms = []

        if rl_regrets:
            regrets = np.asarray(rl_regrets)
            summed_regret += float(np.sum(regrets))
            if self._standardize_regrets:
                regrets = Standardizer.standardize(regrets)
            log_probs = torch.stack(rl_log_probs, dim=0)
            # The regrets come from numpy (CPU, float64). Align them with the
            # log-probabilities so the product works on any device.
            weights = torch.as_tensor(regrets).to(device=log_probs.device, dtype=log_probs.dtype)
            loss_terms.append(log_probs * torch.unsqueeze(weights, dim=1))

        if gp_regrets:
            # Detach before reading values: numpy cannot convert tensors that
            # require grad or that live on an accelerator.
            summed_regret += sum(float(regret.detach().sum()) for regret in gp_regrets)
            loss_terms.append(torch.stack(gp_regrets, dim=0))

        if not loss_terms:
            raise ValueError("Cannot compute a loss from an empty batch")

        loss = loss_terms[0] if len(loss_terms) == 1 else torch.cat(loss_terms, dim=0)

        return torch.mean(loss), summed_regret

    @staticmethod
    def build_from_config(parameters: dict, input_dim: int, output_dim: int) -> DFL:
        """Build the model, its Gaussian network, baseline and GP handler from a config dict.

        ``parameters["baseline"]`` selects the baseline:
        - ``None`` gives ``NullBaseline``;
        - ``"quadratic"`` gives ``QuadraticBaseline``.

        The same baseline instance is shared by the GP handler and the model.

        Raises:
            ValueError: if ``parameters["baseline"]`` is not one of the supported values.
            KeyError: if a required configuration key is missing.
        """

        network = FeedForwardLayer(input_dim, output_dim, parameters["hidden_units"])
        stochastic_network = GaussianLayer(network, input_dim, output_dim, parameters["mode"], parameters["std"])
        destandardizer = DeStandardizer()

        baseline_name = parameters["baseline"]
        if baseline_name is None:
            baseline = NullBaseline()
        elif baseline_name == "quadratic":
            baseline = QuadraticBaseline()
        else:
            raise ValueError("Invalid baseline parameter: {}".format(baseline_name))

        gp_handler = GaussianProcessesHandler(gp_training_delta=parameters["gp_training_delta"],
                                              baseline=baseline,
                                              likelihood_ub=parameters["likelihood_ub"],
                                              training_multiplier=parameters["training_delta_multiplier"],
                                              reset_kernel=parameters["reset_kernel"],
                                              default_train_epochs=parameters["default_train_epochs"],
                                              epochs_decay=parameters["epochs_decay"],
                                              smooth_gp=parameters["smooth_gp"],
                                              pool_size=parameters["pool_size"],
                                              shared_samples=parameters["shared_samples"],
                                              stochastic_network=stochastic_network,
                                              max_noise_threshold=parameters["max_noise_threshold"],
                                              set_regrets_stats=parameters["set_regrets_stats"],
                                              scale_correlation_by_distance=parameters["scale_correlation_by_distance"],
                                              parallelize=parameters["parallelize"])

        return GaussianProcessSFGE(stochastic_network, parameters["lr"], gp_handler,
                                   gp_std_threshold=parameters["gp_std_threshold"],
                                   epsilon=parameters["epsilon"],
                                   baseline=baseline,
                                   differentiate_gp=parameters["differentiate_gp"],
                                   smooth_gp=parameters["smooth_gp"],
                                   use_lcb=parameters["use_lcb"],
                                   pre_train_gp=parameters["pre_train_gp"],
                                   destandardizer=destandardizer,
                                   standardize_regrets=parameters["standardize_regrets"],
                                   pretrain_epochs=parameters["pretrain_epochs"],
                                   name=parameters["name"])
