import math
import multiprocessing
from multiprocessing.pool import ThreadPool
import numpy as np
import torch
import gpytorch.constraints
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
from gpytorch.priors import LogNormalPrior
from scipy.stats import qmc
from tqdm import tqdm

from src.solvers.solver import Solver
from src.methods.models.layers.stochastic.stochastic import StochasticLayer
from src.methods.dfl_abstract import DFL
from src.methods.models.layers.surrogates.gp.gaussian_process import GaussianProcess
from src.methods.models.layers.surrogates.baselines import Baseline, NullBaseline


class GaussianProcessesHandler:

    def __init__(self, solver: Solver | None = None, problem_params: dict[str, np.ndarray] | None = None,
                 gp_training_delta: int = 50, baseline: Baseline = NullBaseline(), likelihood_ub: float = 0.01,
                 training_multiplier: float = 1.0, reset_kernel: bool = True, default_train_epochs: int = 50,
                 epochs_decay: float = 1.0, smooth_gp: bool = True, pool_size: int | None = None,
                 shared_samples: bool = True, stochastic_network: StochasticLayer | None = None,
                 max_noise_threshold: float = 0.05, set_regrets_stats: bool = False,
                 scale_correlation_by_distance: bool = True, device: torch.device = torch.device("cpu"),
                 parallelize: bool = False):
        """Store the configuration used to create, pretrain and train the Gaussian processes.

        Args:
            solver: Solver handed to ``DFL.compute_regret`` during pretraining; its
                ``is_minimization_problem`` flag decides the sign multiplier. May be supplied later
                through ``set_solver``.
            problem_params: Problem parameters handed to ``DFL.compute_regret``. May be supplied later
                through ``set_problem_params``.
            gp_training_delta: Forwarded to each ``GaussianProcess`` as ``auto_training_delta``.
            baseline: Callable evaluated as ``baseline(sample, y_i)``; its value is subtracted from the
                regret before the sample is stored during pretraining.
            likelihood_ub: Upper bound of the interval constraint placed on the likelihood noise.
            training_multiplier: Forwarded to each ``GaussianProcess``.
            reset_kernel: Forwarded to each ``GaussianProcess``.
            default_train_epochs: Forwarded to each ``GaussianProcess``.
            epochs_decay: Forwarded to each ``GaussianProcess``.
            smooth_gp: Forwarded to each ``GaussianProcess`` as ``apply_smoothing``; pretraining only adds
                smoothed samples when a ``stochastic_network`` is also available.
            pool_size: Forwarded to each ``GaussianProcess``.
            shared_samples: When true, pretraining builds the distance matrix that lets a sample be
                added to several GPs.
            stochastic_network: Layer whose ``build_distribution`` is used for smoothed pretraining samples.
            max_noise_threshold: Largest distance-matrix entry for which a sample is still shared
                with another GP.
            set_regrets_stats: When true, pretraining sets the pooled regret mean and standard
                deviation on every GP.
            scale_correlation_by_distance: Only effective together with ``shared_samples``; enables the
                per-GP distance noise matrix and is forwarded as ``scale_correlation_by_noise``.
            device: Forwarded to each ``GaussianProcess``.
            parallelize: When true, ``auto_train`` dispatches the queued GPs to a thread pool.
        """
        # NOTE(review): the default ``baseline`` instance is created once at definition time and is
        # therefore shared by every handler built without an explicit baseline.

        if smooth_gp and stochastic_network is None:
            print("WARNING: Smoothing requires a stochastic network")

        self._solver = solver
        self._problem_params = problem_params
        self._gp_training_delta = gp_training_delta
        self._baseline = baseline
        self._likelihood_ub = likelihood_ub
        self._training_multiplier = training_multiplier
        self._reset_kernel = reset_kernel
        self._default_train_epochs = default_train_epochs
        self._epochs_decay = epochs_decay
        self._smooth_gp = smooth_gp
        self._pool_size = pool_size
        self._max_noise_threshold = max_noise_threshold
        self._stochastic_network = stochastic_network
        self._shared_samples = shared_samples
        self._set_regrets_stats = set_regrets_stats
        # Distance-based correlation scaling needs the shared-samples distance matrix.
        self._scale_correlation_by_distance = scale_correlation_by_distance and shared_samples
        self._device = device
        self._parallelize = parallelize

        # Always define the sign multiplier so the attribute exists even before a solver is attached:
        # 1 for minimisation, -1 otherwise, None while no solver is known.
        self._mm: int | None = None
        if solver is not None:
            self._mm = 1 if solver.is_minimization_problem else -1

        self._gaussian_processes: list[GaussianProcess] = []
        # Indices of the GPs queued for the next ``auto_train`` call.
        self._train_list: list[int] = []
        self._distance_matrix: np.ndarray | None = None
        # For every GP, the index of the GP each of its stored samples originated from.
        self._gp_samples_indices: list[list[int]] = []

    @property
    def gaussian_processes(self) -> list[GaussianProcess]:
        """Return the handler's internal list of Gaussian processes (the list itself, not a copy)."""
        return self._gaussian_processes

    @property
    def len(self) -> int:
        return len(self._gaussian_processes)

    def set_solver(self, solver: Solver) -> None:

        self._solver = solver
        self._mm = 1 if self._solver.is_minimization_problem else -1

    def set_problem_params(self, problem_params: dict[str, np.ndarray]) -> None:
        """Replace the problem parameters that pretraining forwards to ``DFL.compute_regret``."""
        self._problem_params = problem_params

    def gaussian_process(self, i: int) -> GaussianProcess:
        return self._gaussian_processes[i]

    def reset(self) -> None:
        self._gaussian_processes = []
        self._train_list = []
        self._distance_matrix = None
        self._gp_samples_indices = []

    def add_to_train_list(self, i: int) -> None:
        """Queue GP index ``i`` for the next ``auto_train`` call.

        No de-duplication is performed: an index appended twice is processed twice.
        """
        self._train_list.append(i)

    def auto_train(self) -> None:

        if not self._parallelize:
            for index in tqdm(self._train_list, desc="Training GPs"):
                gp = self._gaussian_processes[index]
                if self._scale_correlation_by_distance and gp.is_trainable:
                    distance_correlation_matrix = self._build_distance_correlation_matrix(index)
                    gp.set_distance_noise_matrix(distance_correlation_matrix)
                gp.auto_train()

        else:
            with ThreadPool(processes=multiprocessing.cpu_count()) as pool:
                pool.map(self._multi_processing_auto_train, self._train_list)

        self._train_list.clear()

    def initialize_gaussian_processes(self, y: np.ndarray) -> None:
        """Discard the current state and build one Gaussian process per row of ``y``.

        Every GP starts with a single training point: its own row of ``y`` (with a leading axis of
        size one) and a target of ``0.0``, followed by an ``add_noise(0.0)`` call. The column-wise
        minimum and maximum of ``y`` are handed to every GP as ``input_min`` / ``input_max``.

        Args:
            y: Array iterated row by row; each row becomes the initial input of one GP.
        """
        self.reset()

        lower = np.min(y, axis=0)
        upper = np.max(y, axis=0)
        n_gps = len(y)

        # Parameters of the log-normal length-scale prior; the location is shifted by log(dim) / 2.
        prior_mu, prior_sigma = 0.0, 1.0

        for idx, row in enumerate(y):
            length_scale_prior = LogNormalPrior(prior_mu + np.log(len(row)) / 2, prior_sigma)

            # Zero fixed noise, with a learnable additional noise kept inside [0, likelihood_ub].
            likelihood = FixedNoiseGaussianLikelihood(
                noise=torch.zeros(1),
                learn_additional_noise=True,
                noise_constraint=gpytorch.constraints.Interval(lower_bound=0.0,
                                                               upper_bound=self._likelihood_ub))

            gp = GaussianProcess(likelihood,
                                 np.expand_dims(row, axis=0),
                                 np.array([0.0]),
                                 auto_training_delta=self._gp_training_delta,
                                 length_scale_prior=length_scale_prior,
                                 apply_smoothing=self._smooth_gp,
                                 training_multiplier=self._training_multiplier,
                                 reset_kernel=self._reset_kernel,
                                 default_train_epochs=self._default_train_epochs,
                                 epochs_decay=self._epochs_decay,
                                 input_min=lower,
                                 input_max=upper,
                                 device=self._device,
                                 pool_size=self._pool_size,
                                 scale_correlation_by_noise=self._scale_correlation_by_distance,
                                 name=f"{idx + 1}/{n_gps}")
            gp.add_noise(0.0)
            self._gaussian_processes.append(gp)

            # The initial point of GP ``idx`` originates from GP ``idx`` itself.
            if self._scale_correlation_by_distance:
                self._gp_samples_indices.append([idx])

    def pre_train_gaussian_processes(self, x: np.ndarray, y: np.ndarray,
                                     cost: np.ndarray, n_samples_base: int = 5) -> None:
        """Seed every GP with regrets measured on a common set of probe points, then train all GPs.

        Probe points are drawn inside the column-wise [min, max] box of ``y``: a Latin hypercube of
        ``int(n_samples_base * log2(d))`` points when ``y`` has ``d > 1`` columns, otherwise
        ``n_samples_base`` evenly spaced points. For every GP and probe point the regret returned by
        ``DFL.compute_regret`` is recorded, and the regret minus the baseline value is stored in the
        GP as a zero-noise sample (smoothed when smoothing is enabled and a stochastic network is
        available). Afterwards the regret statistics and the distance matrix are computed if
        configured, and every GP is trained.

        The solver's ``freeze_calls_count`` is invoked before the work starts and
        ``unfreeze_calls_count`` is always invoked afterwards, even when pretraining raises.

        Args:
            x: Per-GP inputs, consumed in lockstep with the GPs (presumably the features - confirm
                against the caller).
            y: Per-GP vectors that define the sampling box and are passed to the baseline and to
                ``DFL.compute_regret``.
            cost: Per-GP value passed to ``DFL.compute_regret`` as the optimal cost.
            n_samples_base: Base number of probe points (see above).
        """
        self._solver.freeze_calls_count()

        try:
            l_bounds = np.min(y, axis=0).tolist()
            u_bounds = np.max(y, axis=0).tolist()
            d = y.shape[1]

            if d > 1:
                n_samples = int(n_samples_base * math.log2(d))
                sampler = qmc.LatinHypercube(d=d, optimization="random-cd")
                samples = qmc.scale(sampler.random(n=n_samples), l_bounds, u_bounds)
                samples = torch.Tensor(samples)

            else:
                n_samples = n_samples_base
                samples = torch.linspace(start=l_bounds[0], end=u_bounds[0], steps=n_samples)
                samples = torch.unsqueeze(samples, dim=-1)

            regrets: list[np.ndarray] = []
            use_smoothing = self._smooth_gp and self._stochastic_network is not None
            n_gps = len(self._gaussian_processes)

            # ``zip`` stops at the shortest input, so that is the real number of iterations.
            n_instances = min(n_gps, len(x), len(y), len(cost))
            instances = zip(self._gaussian_processes, x, y, cost)

            # ``enumerate`` provides the GP position directly instead of an O(n) ``list.index`` lookup.
            for i, (gp_i, x_i, y_i, optimal_cost_i) in tqdm(enumerate(instances), total=n_instances,
                                                            desc="Pretraining GPs"):

                x_i = torch.unsqueeze(torch.Tensor(x_i), dim=0)
                y_i = torch.Tensor(y_i)
                optimal_cost_i = torch.Tensor([optimal_cost_i])[0]
                regrets_i = []

                random_state = torch.get_rng_state()

                for sample in samples:

                    sample = torch.unsqueeze(sample, dim=0)
                    baseline_i = self._baseline(sample, y_i)
                    test_regret = DFL.compute_regret(self._solver, self._problem_params, self._mm,
                                                     x_i, sample[0], y_i, optimal_cost_i)
                    baseline_correction_i = test_regret - baseline_i
                    regrets_i.append(test_regret)

                    if use_smoothing:
                        distribution: torch.distributions.Distribution = \
                            self._stochastic_network.build_distribution(x_i, y_mean=sample)
                        torch.set_rng_state(random_state)  # Force to sample in the same point for all GPs
                        sample_sample = distribution.sample()
                        gp_i.add_smoothed_samples(sample, np.array([baseline_correction_i]), [sample_sample],
                                                  [distribution], noises=[0.0])
                    else:
                        gp_i.add_samples(sample, np.array([baseline_correction_i]), noises=[0.0])

                    # Record that this sample of GP ``i`` originates from GP ``i`` itself.
                    if self._scale_correlation_by_distance:
                        self._gp_samples_indices[i].append(i)

                regrets.append(np.array(regrets_i))

            if self._set_regrets_stats:
                self._adjust_regrets_stats(regrets)

            if self._shared_samples:
                self._compute_distance_matrix(regrets)

            # Train every GP with the same per-GP routine used by ``auto_train``.
            for index in tqdm(range(n_gps), desc="Pretraining GPs"):
                self._multi_processing_auto_train(index)

        finally:
            # Never leave the solver frozen, even if pretraining failed midway.
            self._solver.unfreeze_calls_count()

    def add_smoothed_samples(self, i: int, sample: torch.Tensor, targets: torch.Tensor | np.ndarray,
                             sample_x: list[torch.Tensor | np.ndarray],
                             distributions: list[torch.distributions.Distribution]) -> None:

        if self._distance_matrix is None:
            noises = [0.0] * len(sample)
            self._gaussian_processes[i].add_smoothed_samples(sample, targets, sample_x, distributions,
                                                             noises=noises)

        else:
            for j, gp in enumerate(self._gaussian_processes):
                noise = self._distance_matrix[i][j]
                if noise <= self._max_noise_threshold:
                    noises = [noise] * len(sample)
                    gp.add_smoothed_samples(sample, targets, sample_x, distributions, noises=noises)
                    if self._scale_correlation_by_distance:
                        self._gp_samples_indices[j].append(i)

    def add_samples(self, i: int, sample: torch.Tensor, targets: torch.Tensor | np.ndarray) -> None:

        if self._distance_matrix is None:
            noises = [0.0] * len(sample)
            self._gaussian_processes[i].add_samples(sample, targets, noises=noises)

        else:
            for j, gp in enumerate(self._gaussian_processes):
                noise = self._distance_matrix[i][j]
                if noise <= self._max_noise_threshold:
                    noises = [noise] * len(sample)
                    gp.add_samples(sample, targets, noises=noises)
                    if self._scale_correlation_by_distance:
                        self._gp_samples_indices[j].append(i)

    def _adjust_regrets_stats(self, regrets: list[np.ndarray]) -> None:

        np_regrets = np.array(regrets)
        target_mean = float(np.mean(np_regrets))
        target_std = float(np.std(np_regrets))

        for gp in self._gaussian_processes:
            gp.set_target_mean(target_mean)
            gp.set_target_std(target_std)

    def _compute_distance_matrix(self, regrets: list[np.ndarray]) -> None:

        std_regret = np.std(np.array(regrets))

        self._distance_matrix = np.zeros(shape=(len(regrets), len(regrets)))

        for i, regret_i in tqdm(enumerate(regrets), desc="Computing distance matrix"):
            for j, regret_j in enumerate(regrets):
                distance = np.sqrt(np.mean((regret_i - regret_j) ** 2))
                scaled_distance = distance / std_regret
                self._distance_matrix[i][j] = scaled_distance

        avg_sharing = np.mean([(row < self._max_noise_threshold).mean() for row in self._distance_matrix])
        print("Average GP sharing: {}%".format(round(avg_sharing * 100, 2)))

    def _build_distance_correlation_matrix(self, index: int) -> torch.Tensor:

        indices = self._gp_samples_indices[index]
        correlation_matrix = []

        for i in indices:
            row = []
            for j in indices:
                noise = self._distance_matrix[i][j]
                row.append(noise)
            correlation_matrix.append(row)

        correlation_matrix = torch.Tensor(np.array(correlation_matrix))

        return correlation_matrix

    def _multi_processing_auto_train(self, index: int):

        gp = self._gaussian_processes[index]

        if self._scale_correlation_by_distance and gp.is_trainable:
            distance_correlation_matrix = self._build_distance_correlation_matrix(index)
            gp.set_distance_noise_matrix(distance_correlation_matrix)

        gp.auto_train()
