from tkinter.messagebox import NO
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
from abc import abstractmethod, ABC
import numpy as np
from util import tf_random_choice, limit_bijector_to_dim
from typing import Callable, List, Optional, Union
import deepxde as dde
import time

# Short aliases for the TensorFlow Probability sub-modules used throughout this file.
tfb = tfp.bijectors
mcmc = tfp.mcmc


@limit_bijector_to_dim(0)
class SigmoidDimFirstDim(tfp.bijectors.Sigmoid):
    """Sigmoid bijector wrapped by ``limit_bijector_to_dim(0)``.

    NOTE(review): presumably the decorator restricts the bijector to input
    dimension 0 -- confirm against ``util.limit_bijector_to_dim``.
    """
    pass


@limit_bijector_to_dim(1)
class SigmoidDimSecondDim(tfp.bijectors.Sigmoid):
    """Sigmoid bijector wrapped by ``limit_bijector_to_dim(1)``.

    NOTE(review): presumably the decorator restricts the bijector to input
    dimension 1 -- confirm against ``util.limit_bijector_to_dim``.
    """
    pass


class Sampler:
    """Base class holding the state shared by all point samplers.

    Stores the requested number of samples, the sample dimension, a call
    counter and a ``(n_samples, dim)`` placeholder tensor of ones.
    Subclasses provide ``sample``.
    """

    def __init__(self, n_samples: int, dim: int, sample_each: int = 1, **kwargs):
        """
        :param n_samples: number of points one ``sample`` call should produce.
        :param dim: dimension of each point.
        :param sample_each: stored on the instance; not read in this class.
        :param kwargs: accepted and ignored so subclasses can forward extra options.
        """
        self.n_samples = n_samples
        self.dim = dim
        self.sample_each = sample_each
        self.counter = 0
        # Placeholder for the current batch of samples, initialised to ones.
        self.cur_samples: tf.Tensor = tf.ones(
            (n_samples, dim),
            dtype="float32",
            name="cur_samples",
        )

    @abstractmethod
    def sample(self, init=None) -> tf.Tensor:
        """Draw samples; must be overridden by subclasses."""


class BoundedSampler(Sampler, ABC):
    """Sampler whose points live in an axis-aligned box given by ``extent``."""

    def __init__(self, n_samples: int, dim: int, extent: list[tuple], **kwargs):
        """
        :param n_samples: number of points per ``sample`` call.
        :param dim: dimension of each point; must equal ``len(extent)``.
        :param extent: one ``(lower, upper)`` pair per dimension with ``lower < upper``.
        :param kwargs: forwarded to ``Sampler``.
        """
        super().__init__(n_samples, dim, **kwargs)
        self.extent = extent

        lower_bounds = [bounds[0] for bounds in extent]
        upper_bounds = [bounds[1] for bounds in extent]

        # Row vectors of shape (1, len(extent)) so they broadcast over a batch of points.
        self.min = tf.reshape(tf.convert_to_tensor(lower_bounds, dtype="float32"), (1, -1))
        self.max = tf.reshape(tf.convert_to_tensor(upper_bounds, dtype="float32"), (1, -1))

        bounds_ordered = [lo < hi for lo, hi in zip(lower_bounds, upper_bounds)]
        assert np.all(bounds_ordered), "First value must be smaller than second one."
        assert len(self.extent) == self.dim, "dimension doesn't match extent."

    def transform_to_extent(self, x):
        """Affinely rescale ``x`` by the box width and shift it by the lower bounds."""
        width = self.max - self.min
        return x * width + self.min


class SobolSampler(BoundedSampler):
    """Bounded sampler based on a Sobol sequence."""

    @tf.function
    def sample(self):
        """Return ``n_samples`` Sobol points rescaled to the configured extent."""
        sobol_points = tf.math.sobol_sample(
            dim=self.dim,
            num_results=self.n_samples,
        )
        rescaled = self.transform_to_extent(sobol_points)
        return rescaled


class HaltonSampler(BoundedSampler):
    """Bounded sampler based on a randomized Halton sequence."""

    @tf.function
    def sample(self):
        """Return ``n_samples`` randomized Halton points rescaled to the configured extent."""
        halton_points = tfp.mcmc.sample_halton_sequence(
            dim=self.dim,
            num_results=self.n_samples,
            randomized=True,
        )
        rescaled = self.transform_to_extent(halton_points)
        return rescaled


class UniformSampler(BoundedSampler):
    """Bounded sampler drawing independent uniform random points."""

    @tf.function
    def sample(self) -> tf.Tensor:
        """Return ``n_samples`` uniform random points rescaled to the configured extent."""
        batch_shape = (self.n_samples, self.dim)
        uniform_points = tf.random.uniform(batch_shape, dtype="float32")
        return self.transform_to_extent(uniform_points)


class DynamicMonteCarloSampler(Sampler, ABC):
    """Base class for samplers driven by a target log-probability function."""

    def __init__(self, n_samples: int, dim: int, target_log_prob: Callable, **kwargs):
        """
        :param n_samples: number of points per ``sample`` call.
        :param dim: dimension of each point.
        :param target_log_prob: callable stored as ``log_prob``.
        :param kwargs: forwarded to ``Sampler``.
        """
        super().__init__(n_samples, dim, **kwargs)
        self.log_prob = target_log_prob


class RandomWalkMetropolis(DynamicMonteCarloSampler):
    """
    Random Walk Metropolis Hastings with a Replica Exchange Kernel.

    The replica-exchange kernel is wrapped in a ``TransformedTransitionKernel``
    with a sigmoid bijector bounded by the ``tmin``/``tmax`` keyword arguments
    (defaults 0 and 3).
    """

    def __init__(self, n_samples: int, dim: int, target_log_prob: Callable, **kwargs):
        """
        :param n_samples: target number of states collected per ``sample`` call.
        :param dim: dimension of the chain state.
        :param target_log_prob: log-probability function of the target.
        :param kwargs: optional ``scale`` (proposal scale), ``tmin`` and ``tmax``
            (sigmoid bounds); everything is also forwarded to the base class.
        """
        super().__init__(n_samples, dim, target_log_prob, **kwargs)

        # One-dimensional targets get a wider default proposal and more replicas.
        if dim == 1:
            default_scale, n_replicas = .2, 6
        else:
            default_scale, n_replicas = .05, 4

        self.state_fun = tfp.mcmc.random_walk_normal_fn(scale=kwargs.get("scale", default_scale))
        # Geometric temperature ladder: 1, 1/2, 1/4, ...
        inverse_temperatures = 0.5 ** tf.range(n_replicas, dtype="float32")

        replica_kernel = tfp.mcmc.ReplicaExchangeMC(
            target_log_prob_fn=self.log_prob,
            inverse_temperatures=inverse_temperatures,
            make_kernel_fn=lambda log_prob: tfp.mcmc.RandomWalkMetropolis(log_prob,
                                                                          new_state_fn=self.state_fun))
        self.kernel = mcmc.TransformedTransitionKernel(
            inner_kernel=replica_kernel,
            bijector=tfp.bijectors.Sigmoid(low=kwargs.get("tmin", 0), high=kwargs.get("tmax", 3),
                                           validate_args=False,
                                           name='sigmoid')
        )

    @tf.function(jit_compile=True, autograph=False)
    def sample(self, init, kernel_result) -> tuple:
        """Run the chains and return the collected states with the final kernel results.

        :param init: initial chain states; the leading axis is the number of chains.
        :param kernel_result: previous kernel results passed to ``sample_chain`` (may be None).
        :return: tuple ``(samples, kernel_results)`` where ``samples`` is reshaped to
            ``(-1, dim)`` with gradients stopped.
        """
        # At least two results per chain; otherwise roughly n_samples / number of chains.
        result = tfp.mcmc.sample_chain(
            num_results=tf.cast(tf.math.maximum(tf.round(self.n_samples / tf.shape(init)[0]), 2), dtype="int32"),
            num_steps_between_results=0,
            current_state=init,
            kernel=self.kernel,
            num_burnin_steps=2000,
            return_final_kernel_results=True,
            previous_kernel_results=kernel_result,
            trace_fn=None)
        samples = result.all_states
        kernel_results = result.final_kernel_results

        return tf.stop_gradient(tf.reshape(samples, (-1, self.dim))), kernel_results

    @tf.function(jit_compile=True, autograph=False)
    def benchmark_sample(self, init, n_samples) -> tf.Tensor:
        """Run the chains for benchmarking and return the states reshaped to ``(-1, dim)``.

        :param init: initial chain states; the leading axis is the number of chains.
        :param n_samples: target total number of collected states.
        """
        samples = tfp.mcmc.sample_chain(
            num_results=tf.cast(tf.math.maximum(tf.round(n_samples / tf.shape(init)[0]), 2), dtype="int32"),
            num_steps_between_results=0,
            current_state=init,
            kernel=self.kernel,
            num_burnin_steps=1000,
            trace_fn=None)
        return tf.stop_gradient(tf.reshape(samples, (-1, self.dim)))


class HamiltonianMC(DynamicMonteCarloSampler):
    """
        Hamiltonian Monte-Carlo with a Replica Exchange Kernel and dual averaging step size adaption.

        The kernel is wrapped in a ``TransformedTransitionKernel`` with a sigmoid
        bijector bounded by the ``tmin``/``tmax`` keyword arguments.
    """

    def __init__(self, n_samples: int, dim: int, target_log_prob: Callable, **kwargs):
        """
        :param n_samples: number of points returned by ``sample``.
        :param dim: dimension of the chain state.
        :param target_log_prob: log-probability function of the target.
        :param kwargs: optional ``tmin``/``tmax`` sigmoid bounds; everything is also
            forwarded to the base class.
        """
        super(HamiltonianMC, self).__init__(n_samples, dim, target_log_prob, **kwargs)
        self.num_burnin = 1000

        def make_kernel_fn(log_prob):
            # Step size is adapted during the first 80% of the burn-in steps.
            return mcmc.DualAveragingStepSizeAdaptation(
                inner_kernel=mcmc.HamiltonianMonteCarlo(
                    target_log_prob_fn=log_prob,
                    step_size=0.001,
                    num_leapfrog_steps=10),
                num_adaptation_steps=int(self.num_burnin * 0.8))

        # Geometric temperature ladder: 1, 1/2, 1/4, 1/8.
        inverse_temperatures = 0.5 ** tf.range(4, dtype="float32")
        self.kernel = tfp.mcmc.ReplicaExchangeMC(
            target_log_prob_fn=self.log_prob,
            inverse_temperatures=inverse_temperatures,
            make_kernel_fn=make_kernel_fn)
        bijectors = [tfp.bijectors.Sigmoid(low=kwargs.get("tmin"), high=kwargs.get("tmax"),
                                           validate_args=False,
                                           name='sigmoid')]

        self.kernel = mcmc.TransformedTransitionKernel(
            inner_kernel=self.kernel,
            bijector=tfb.Chain(bijectors)
        )

    @tf.function(jit_compile=True, autograph=False)
    def sample(self, init) -> tf.Tensor:
        """Run the chains from ``init`` and return ``self.n_samples`` selected states.

        :param init: initial chain states; the leading axis is the number of chains.
        :return: states flattened to ``(-1, dim)`` and subsampled with
            ``tf_random_choice``, gradients stopped.
        """
        samples = tfp.mcmc.sample_chain(
            num_results=tf.cast(tf.math.maximum(tf.round(self.n_samples / tf.shape(init)[0]), 2), dtype="int32"),
            num_steps_between_results=0,
            current_state=init,
            kernel=self.kernel,
            num_burnin_steps=self.num_burnin,
            trace_fn=None)
        return tf.stop_gradient(tf_random_choice(tf.reshape(samples, (-1, self.dim)), self.n_samples))

    @tf.function(jit_compile=True, autograph=False)
    def benchmark_sample(self, n_samples, init) -> tf.Tensor:
        """Like ``sample`` but sized by the ``n_samples`` argument.

        :param n_samples: number of states to collect and return.
        :param init: initial chain states; the leading axis is the number of chains.
        """
        samples = tfp.mcmc.sample_chain(
            num_results=tf.cast(tf.math.maximum(tf.round(n_samples / tf.shape(init)[0]), 2), dtype="int32"),
            num_steps_between_results=0,
            current_state=init,
            kernel=self.kernel,
            num_burnin_steps=self.num_burnin,
            trace_fn=None)
        # Subsample with the requested size, not the size configured at construction.
        return tf.stop_gradient(tf_random_choice(tf.reshape(samples, (-1, self.dim)), n_samples))


class NUTS(DynamicMonteCarloSampler):
    """
        No-U-Turn sampler wrapped in a sigmoid-transformed transition kernel.

        Unlike ``HamiltonianMC`` in this file, no replica exchange and no step size
        adaptation is configured here.
    """

    def __init__(self, n_samples: int, dim: int, target_log_prob: Callable, **kwargs):
        """
        :param n_samples: number of points returned by ``sample``.
        :param dim: dimension of the chain state.
        :param target_log_prob: log-probability function of the target.
        :param kwargs: optional ``tmin``/``tmax`` sigmoid bounds; everything is also
            forwarded to the base class.
        """
        super(NUTS, self).__init__(n_samples, dim, target_log_prob, **kwargs)
        self.num_burnin = 10

        self.kernel = tfp.mcmc.NoUTurnSampler(
            target_log_prob_fn=target_log_prob,
            step_size=0.001,
        )

        bijectors = [tfp.bijectors.Sigmoid(low=kwargs.get("tmin"), high=kwargs.get("tmax"),
                                           validate_args=False,
                                           name='sigmoid')]

        self.kernel = mcmc.TransformedTransitionKernel(
            inner_kernel=self.kernel,
            bijector=tfb.Chain(bijectors)
        )

    @tf.function(jit_compile=True, autograph=False)
    def sample(self, init) -> tf.Tensor:
        """Run the chains from ``init`` and return ``self.n_samples`` selected states.

        :param init: initial chain states; the leading axis is the number of chains.
        :return: states flattened to ``(-1, dim)`` and subsampled with
            ``tf_random_choice``, gradients stopped.
        """
        samples = tfp.mcmc.sample_chain(
            num_results=tf.cast(tf.math.maximum(tf.round(self.n_samples / tf.shape(init)[0]), 2), dtype="int32"),
            num_steps_between_results=0,
            current_state=init,
            kernel=self.kernel,
            num_burnin_steps=self.num_burnin,
            trace_fn=None)
        return tf.stop_gradient(tf_random_choice(tf.reshape(samples, (-1, self.dim)), self.n_samples))

    @tf.function(jit_compile=True, autograph=False)
    def benchmark_sample(self, n_samples, init) -> tf.Tensor:
        """Like ``sample`` but sized by the ``n_samples`` argument.

        :param n_samples: number of states to collect and return.
        :param init: initial chain states; the leading axis is the number of chains.
        """
        samples = tfp.mcmc.sample_chain(
            num_results=tf.cast(tf.math.maximum(tf.round(n_samples / tf.shape(init)[0]), 2), dtype="int32"),
            num_steps_between_results=0,
            current_state=init,
            kernel=self.kernel,
            num_burnin_steps=self.num_burnin,
            trace_fn=None)
        # Subsample with the requested size, not the size configured at construction.
        return tf.stop_gradient(tf_random_choice(tf.reshape(samples, (-1, self.dim)), n_samples))


# Register the sampler classes by name in the Keras custom-object registry.
# NOTE(review): the NUTS sampler defined above is not registered here --
# confirm whether that omission is intentional.
tf.keras.utils.get_custom_objects().update({
    'SobolSampler': SobolSampler,
    'HaltonSampler': HaltonSampler,
    'UniformSampler': UniformSampler,
    'RandomWalkMetropolis': RandomWalkMetropolis,
    'HamiltonianMC': HamiltonianMC
})
import optuna


class CollocationPointResampler(dde.callbacks.Callback):
    """Callback that periodically replaces the model's collocation points.

    Every ``period`` epochs the training anchors are rebuilt from three parts:
    ``n_samples`` points from ``custom_sample`` (pseudo-uniform points while
    ``epoch < start_after_nepochs``), ``n_background`` pseudo-uniform background
    points and freshly drawn initial-condition points. Subclasses implement
    ``custom_sample`` to define how the first part is drawn.
    """

    def __init__(self, n_samples, n_background, n_iterations, dim, display_every=200, period=100, n_proposals=50_000,
                 model=None, plot_samples=False, optuna_trial: optuna.Trial = None):
        """
        :param n_samples: number of points requested from ``custom_sample`` per resampling.
        :param n_background: number of pseudo-uniform background points per resampling.
        :param n_iterations: compared against ``epoch + 1`` to trigger a final time log.
        :param dim: dimension stored on the instance for use by subclasses.
        :param display_every: elapsed time is logged when the train step is a multiple of this.
        :param period: number of epochs between two resamplings.
        :param n_proposals: number of proposal points for the proposal-based samplers.
        :param model: stored as ``_model``; the methods of this class read ``self.model``.
        :param plot_samples: stored on the instance; not read in this class.
        :param optuna_trial: optional Optuna trial used for reporting and pruning.
        """
        super().__init__()
        self.period = period

        self.num_bcs_initial = None
        self.epochs_since_last_resample = 0
        self.n_samples = n_samples
        self.n_background = n_background
        self.n_proposals = n_proposals
        self.epoch = 0
        # Until this epoch the density part is drawn pseudo-uniformly instead of via custom_sample.
        self.start_after_nepochs = 1_000
        self.plot_samples = plot_samples

        self.dim = dim
        self._model = model

        self.times = []
        self.full_times = []
        self.start_time = None

        self.display_every = display_every
        self.n_iterations = n_iterations
        self.kernel_result = None

        self.optuna_trial = optuna_trial
        self.prev_samples = None

        # Drives the seed of the stateless uniform generator in ``uniform``.
        self.counter = tf.Variable(0, dtype="int32")

    def uniform(self, n_samples):
        """Return ``n_samples`` stateless uniform numbers, advancing the seed counter by one."""
        self.counter.assign_add(1)
        return tf.random.stateless_uniform([n_samples], (self.counter, self.counter+1))

    def log_cur_time(self):
        """Record the elapsed time since training start and the current wall-clock time."""
        self.times.append(time.perf_counter() - self.start_time)
        self.full_times.append(time.time())

    def _inverse_transform_sample(self, proposal_points, n_samples):
        """Pick ``n_samples`` of ``proposal_points`` by inverse transform sampling.

        The exponentiated network output at the proposals is used as unnormalised
        weight; indices are found by searching uniform draws in the normalised
        cumulative sum.
        """
        proposals_density = tf.exp(self.model.net(proposal_points))
        # Empirical CDF over the proposals.
        cum_dist = tf.squeeze(tf.math.cumsum(proposals_density))
        cum_dist /= cum_dist[-1]
        unif_samp = self.uniform(n_samples)
        idxs = tf.searchsorted(cum_dist, unif_samp)
        return tf.gather(proposal_points, idxs)

    def it_sample(self, n_samples):
        """Inverse transform sample from ``n_proposals`` random points of the geometry."""
        proposal_points = self.model.data.geom.random_points(self.n_proposals,
                                                             random=self.model.data.train_distribution)
        return self._inverse_transform_sample(proposal_points, n_samples)

    def it_sample_t0(self, n_samples):
        """Inverse transform sample from ``n_proposals`` random initial points of the geometry."""
        proposal_points = self.model.data.geom.random_initial_points(self.n_proposals,
                                                                     random=self.model.data.train_distribution)
        return self._inverse_transform_sample(proposal_points, n_samples)

    def on_train_begin(self):
        """Remember the initial ``num_bcs`` and start the timers."""
        self.num_bcs_initial = self.model.data.num_bcs

        self.start_time = time.perf_counter()
        self.times.append(0)
        self.full_times.append(time.time())

    def sample_pseudo_uniform(self, n_samples):
        """Return ``n_samples`` random points of the geometry using the training distribution."""
        return self.model.data.geom.random_points(n_samples,
                                                  random=self.model.data.train_distribution)

    def on_epoch_end(self):
        """Resample the anchors every ``period`` epochs and report to Optuna if configured."""
        self.epochs_since_last_resample += 1
        self.epoch += 1
        if self.epochs_since_last_resample < self.period:
            return
        self.epochs_since_last_resample = 0

        # some uniform background collocation points
        background_points = self.model.data.geom.random_points(
            self.n_background, random=self.model.data.train_distribution
        )
        # the initial condition points
        ic_points = self.model.data.geom.random_initial_points(
            self.model.data.num_initial, random=self.model.data.train_distribution
        )

        sample_weights = np.ones(self.n_samples)
        if self.epoch < self.start_after_nepochs:
            samples_density = self.sample_pseudo_uniform(n_samples=self.n_samples)
        else:
            # Start chains from the previous samples when available, otherwise
            # from inverse-transform samples of the initial points.
            if self.prev_samples is not None:
                mcmc_init = tf_random_choice(self.prev_samples, 10)
            else:
                mcmc_init = self.it_sample_t0(10)
            samples_density = self.custom_sample(n_samples=self.n_samples,
                                                 mcmc_init=mcmc_init,
                                                 inplace_weight=sample_weights)

        concat_samples = np.concatenate([samples_density, background_points, ic_points], 0)
        # Per-point weights; the first n_samples entries belong to the density samples.
        loss_weights = np.ones(concat_samples.shape[0])
        loss_weights[:self.n_samples] = sample_weights
        self.model.loss_weights = loss_weights

        self.model.data.replace_with_anchors(concat_samples)

        if self.model.train_state.step % self.display_every == 0 or self.epoch + 1 == self.n_iterations:
            self.log_cur_time()

        if self.optuna_trial:
            self.optuna_trial.report(value=self.model.losshistory.metrics_test[-1][0],
                                     step=self.model.losshistory.steps[-1])
            if self.optuna_trial.should_prune():
                raise optuna.TrialPruned()

    @abstractmethod
    def custom_sample(self, n_samples: int, mcmc_init=None, **kwargs) -> tf.Tensor:
        """Return the density-driven sample points; implemented by subclasses."""
        ...


class UniformResampler(CollocationPointResampler):
    """
    Resampler that always draws (pseudo) uniform points from the geometry.
    """

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs) -> tf.Tensor:
        """Return ``n_samples`` pseudo-uniform points; ``mcmc_init`` and ``kwargs`` are ignored."""
        uniform_points = self.sample_pseudo_uniform(n_samples)
        return uniform_points

    def benchmark_sample(self, n_samples):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class TrueDensityResampler(CollocationPointResampler):
    """Resampler that draws points via ``model.sample_true_dist``."""

    def init(self, **kwargs):
        """Run the parent ``init`` and enable custom sampling from epoch 0.

        NOTE(review): ``init`` is presumably the deepxde callback hook invoked
        once the model is attached -- confirm.
        """
        super().init(**kwargs)
        self.start_after_nepochs = 0

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs) -> tf.Tensor:
        """Sample from the true distribution with a horizon growing with training progress."""
        progress = self.epoch / self.n_iterations
        # Grows linearly from 0.2 and is capped at 2.
        dt = min(2 * progress + 0.2, 2)
        return self.model.sample_true_dist(self.model.data, n_samples, dt)

    def benchmark_sample(self, n_samples):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class FlowResampler(CollocationPointResampler):
    """Resampler that draws points from ``model.net.net.sample``."""

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs) -> tf.Tensor:
        """Return ``n_samples`` points sampled from the inner network, converted via ``.numpy()``."""
        inner_net = self.model.net.net
        drawn = inner_net.sample(n_samples)
        return drawn.numpy()

    def benchmark_sample(self, n_samples):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class InverseTransformResampler(CollocationPointResampler):
    """
    Inverse transform sampling over random proposal points, weighted by the
    exponentiated network output.
    Compare https://en.wikipedia.org/wiki/Inverse_transform_sampling

    """

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs):
        """Return ``n_samples`` proposals chosen by inverse transform sampling, as a NumPy array."""
        geometry = self.model.data.geom
        candidates = geometry.random_points(self.n_proposals,
                                            random=self.model.data.train_distribution)
        weights = tf.exp(self.model.net(candidates))

        # Normalised cumulative sum acts as the empirical CDF over the candidates.
        empirical_cdf = tf.squeeze(tf.math.cumsum(weights))
        empirical_cdf /= empirical_cdf[-1]

        uniform_draws = self.uniform(n_samples)
        chosen = tf.searchsorted(empirical_cdf, uniform_draws)

        return tf.gather(candidates, chosen).numpy()

    def benchmark_sample(self, n_samples):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class InverseTransformImportanceResampler(CollocationPointResampler):
    """
    Inverse transform sampling over random proposal points, weighted by the
    squared value of ``model.evaluate_pde_loss``.
    Compare https://en.wikipedia.org/wiki/Inverse_transform_sampling

    """

    def init(self, **kwargs):
        """Run the parent ``init`` and start custom sampling after 2000 epochs."""
        super().init(**kwargs)
        self.start_after_nepochs = 2000

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs):
        """Return ``n_samples`` proposals chosen by inverse transform sampling, as a NumPy array."""
        geometry = self.model.data.geom
        candidates = geometry.random_points(self.n_proposals,
                                            random=self.model.data.train_distribution)
        squared_residual = np.square(self.model.evaluate_pde_loss(candidates))

        # Normalised cumulative sum acts as the empirical CDF over the candidates.
        empirical_cdf = tf.squeeze(tf.math.cumsum(squared_residual))
        empirical_cdf /= empirical_cdf[-1]

        uniform_draws = self.uniform(n_samples)
        chosen = tf.searchsorted(empirical_cdf, uniform_draws)

        return tf.gather(candidates, chosen).numpy()

    def benchmark_sample(self, n_samples):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class RarResampler(CollocationPointResampler):
    """
    Keeps the random proposal points with the largest absolute value of
    ``model.evaluate_pde_loss``.

    """

    def init(self, **kwargs):
        """Run the parent ``init``, delay custom sampling and build an inner inverse-transform sampler."""
        super().init(**kwargs)
        self.start_after_nepochs = 5000

        # Not read by custom_sample below; kept as an attribute of the instance.
        inner = InverseTransformResampler(n_samples=self.n_background, n_background=0., n_iterations=0.,
                                          dim=self.dim)
        inner.model = self.model
        self.pdpinn_sampler = inner

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs):
        """Return the ``n_samples`` proposals with the largest absolute PDE loss."""
        candidates = self.model.data.geom.random_points(self.n_proposals,
                                                        random=self.model.data.train_distribution)
        residual_magnitude = np.abs(self.model.evaluate_pde_loss(candidates))

        # Ascending sort of the negated values puts the largest residuals first.
        ranking = np.argsort(np.squeeze(-residual_magnitude))
        keep = ranking[:n_samples]
        return candidates[keep, ...]

    def benchmark_sample(self, n_samples, **kwargs):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class PdpinnRarResampler(CollocationPointResampler):
    """
    Draws proposals with an inner ``InverseTransformResampler`` and orders them
    by descending absolute value of ``model.evaluate_pde_loss``.

    """

    def init(self, **kwargs):
        """Run the parent ``init`` and build the inner inverse-transform sampler."""
        super().init(**kwargs)

        inner = InverseTransformResampler(n_samples=self.n_background, n_background=0., n_iterations=0.,
                                          dim=self.dim)
        inner.model = self.model
        self.pdpinn_sampler = inner

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs):
        """Return up to ``n_samples`` inner-sampler proposals, largest absolute PDE loss first.

        NOTE(review): only ``n_samples`` proposals are requested, so the top-k
        step reorders rather than filters -- confirm this is intended.
        """
        candidates = self.pdpinn_sampler.custom_sample(n_samples=n_samples)
        residual_magnitude = np.abs(self.model.evaluate_pde_loss(candidates))

        # Ascending sort of the negated values puts the largest residuals first.
        ranking = np.argsort(np.squeeze(-residual_magnitude))
        keep = ranking[:n_samples]
        return candidates[keep, ...]

    def benchmark_sample(self, n_samples, **kwargs):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class TrueRarResampler(CollocationPointResampler):
    """
    Draws proposals with an inner ``TrueDensityResampler`` and orders them by
    descending absolute value of ``model.evaluate_pde_loss``.

    """

    def init(self, **kwargs):
        """Run the parent ``init`` and build the inner true-density sampler."""
        super(TrueRarResampler, self).init(**kwargs)

        # The inner sampler divides its epoch by n_iterations, so it needs the real
        # iteration budget; a value of 0 raises ZeroDivisionError on the first call.
        self.true_sampler = TrueDensityResampler(n_samples=self.n_background, n_background=0.,
                                                 n_iterations=self.n_iterations,
                                                 dim=self.dim)
        self.true_sampler.model = self.model
        self.true_sampler.start_after_nepochs = 0

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs):
        """Return up to ``n_samples`` true-density proposals, largest absolute PDE loss first.

        NOTE(review): only ``n_samples`` proposals are requested, so the top-k
        step reorders rather than filters -- confirm this is intended.
        """
        # The inner sampler is not registered as a callback, so its training
        # progress has to be mirrored from this instance before every draw.
        self.true_sampler.epoch = self.epoch
        self.true_sampler.n_iterations = self.n_iterations

        proposal_points = self.true_sampler.custom_sample(n_samples=n_samples)
        proposals_pdeloss = np.abs(self.model.evaluate_pde_loss(proposal_points))

        # Ascending sort of the negated values puts the largest residuals first.
        idxs = np.squeeze(-proposals_pdeloss).argsort()[:n_samples]
        top_samples = proposal_points[idxs, ...]
        return top_samples

    def benchmark_sample(self, n_samples, **kwargs):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class HMCPdpinnRarResampler(CollocationPointResampler):
    """
    Draws proposals with an inner ``HamiltonianMCResampler`` and keeps those
    with the largest absolute value of ``model.evaluate_pde_loss``.

    """

    def init(self, **kwargs):
        """Run the parent ``init`` and build the inner HMC resampler."""
        super().init(**kwargs)

        inner = HamiltonianMCResampler(n_samples=self.n_background, n_background=0., n_iterations=0.,
                                       dim=self.dim)
        inner.model = self.model
        self.pdpinn_sampler = inner

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs):
        """Return up to ``n_samples`` HMC proposals with the largest absolute PDE loss.

        The chains of the inner sampler are started from ``it_sample(10)``; the
        ``mcmc_init`` argument of this method is not used.
        """
        chain_start = self.it_sample(10)
        candidates = self.pdpinn_sampler.custom_sample(n_samples=self.n_proposals,
                                                       mcmc_init=chain_start)
        residual_magnitude = np.abs(self.model.evaluate_pde_loss(candidates))

        # Ascending sort of the negated values puts the largest residuals first.
        ranking = np.argsort(np.squeeze(-residual_magnitude))
        keep = ranking[:n_samples]
        return candidates[keep, ...]

    def benchmark_sample(self, n_samples, **kwargs):
        """Benchmark entry point; same as ``custom_sample``."""
        return self.custom_sample(n_samples)


class HamiltonianMCResampler(CollocationPointResampler):
    """
    Resampler that draws points with the ``HamiltonianMC`` sampler, using the
    network output as target log-probability.
    """

    def __init__(self, **kwargs):
        """Build the default chain start and the HMC sampler; ``kwargs`` go to the base class."""
        super().__init__(**kwargs)

        # Ten default chain states: zeros in the first ``dim`` columns, -1 in the last one.
        leading_columns = np.zeros((10, self.dim), dtype="float32")
        last_column = -1. * np.ones((10, 1), dtype="float32")
        self.mcmc_init = np.concatenate((leading_columns, last_column), -1)

        self.start_after_nepochs = 1000
        assert self.n_proposals > self.n_samples

        def batched_log_prob(x):
            """
            Flatten the chain axes before calling the network and restore them afterwards.
            :param x: chain states whose last axis has size ``dim + 1``.
            :return: network output reshaped to the leading axes of ``x``.
            """
            flat_states = tf.reshape(x, (-1, self.dim + 1))
            log_density = self.model.net(flat_states)
            return tf.reshape(log_density, tf.shape(x)[:-1])

        self.sampler = HamiltonianMC(n_samples=self.n_samples, dim=self.dim + 1,
                                     target_log_prob=batched_log_prob, scale=.07,
                                     tmin=-1.05, tmax=1.05)
        self.prev_samples = None

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs) -> tf.Tensor:
        """Run the sampler and cache the result in ``prev_samples``.

        ``n_samples`` is not used here; the sampler was sized at construction.
        """
        start_state = self.mcmc_init if mcmc_init is None else mcmc_init
        drawn = self.sampler.sample(start_state)
        self.prev_samples = drawn.numpy()
        return self.prev_samples

    def benchmark_sample(self, n_samples):
        """Benchmark entry point starting from the default chain states."""
        return self.sampler.benchmark_sample(init=self.mcmc_init, n_samples=n_samples)


class NUTSResampler(CollocationPointResampler):
    """
    Resampler that draws points with the ``NUTS`` sampler, using the network
    output as target log-probability.
    """

    def __init__(self, **kwargs):
        """Build the default chain start and the NUTS sampler; ``kwargs`` go to the base class."""
        super().__init__(**kwargs)

        # Ten default chain states: zeros in the first ``dim`` columns, -1 in the last one.
        leading_columns = np.zeros((10, self.dim), dtype="float32")
        last_column = -1. * np.ones((10, 1), dtype="float32")
        self.mcmc_init = np.concatenate((leading_columns, last_column), -1)

        self.start_after_nepochs = 2000

        def batched_log_prob(x):
            """
            Flatten the chain axes before calling the network and restore them afterwards.
            :param x: chain states whose last axis has size ``dim + 1``.
            :return: network output reshaped to the leading axes of ``x``.
            """
            flat_states = tf.reshape(x, (-1, self.dim + 1))
            log_density = self.model.net(flat_states)
            return tf.reshape(log_density, tf.shape(x)[:-1])

        self.sampler = NUTS(n_samples=self.n_samples, dim=self.dim + 1,
                            target_log_prob=batched_log_prob, scale=.07,
                            tmin=-1.05, tmax=1.05)

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs) -> tf.Tensor:
        """Run the sampler and return its result as a NumPy array.

        ``n_samples`` is not used here; the sampler was sized at construction.
        """
        start_state = self.mcmc_init if mcmc_init is None else mcmc_init
        drawn = self.sampler.sample(start_state)
        return drawn.numpy()

    def benchmark_sample(self, n_samples):
        """Benchmark entry point starting from the default chain states."""
        return self.sampler.benchmark_sample(init=self.mcmc_init, n_samples=n_samples)


class RandomWalkMHResampler(CollocationPointResampler):
    """
    Resampler that draws points with the ``RandomWalkMetropolis`` sampler, using
    the network output as target log-probability.
    """

    def __init__(self, **kwargs):
        """Build the default chain start and the sampler; ``kwargs`` go to the base class."""
        super().__init__(**kwargs)

        # Ten default chain states: zeros in the first ``dim`` columns, -1 in the last one.
        leading_columns = np.zeros((10, self.dim), dtype="float32")
        last_column = -1. * np.ones((10, 1), dtype="float32")
        self.mcmc_init = np.concatenate((leading_columns, last_column), -1)

        self.kernel_result = None
        self.start_after_nepochs = 1000

        def batched_log_prob(x):
            """
            Flatten the chain axes before calling the network and restore them afterwards.
            :param x: chain states whose last axis has size ``dim + 1``.
            :return: network output reshaped to the leading axes of ``x``.
            """
            flat_states = tf.reshape(x, (-1, self.dim + 1))
            log_density = self.model.net(flat_states)
            return tf.reshape(log_density, tf.shape(x)[:-1])

        self.sampler = RandomWalkMetropolis(n_samples=self.n_samples, dim=self.dim + 1,
                                            target_log_prob=batched_log_prob, scale=.3,
                                            tmin=-1.05, tmax=1.05)

    def custom_sample(self, n_samples, mcmc_init=None, **kwargs) -> tf.Tensor:
        """Draw a batch of MCMC samples, retrying once if it drifts from the previous batch.

        From epoch 3000 on, a batch whose per-column mean or variance deviates
        on average by more than 20% (relative) from the previous batch is
        redrawn once. The final batch is cached in ``prev_samples`` and
        returned. ``n_samples`` is not used here.
        """
        start_state = self.mcmc_init if mcmc_init is None else mcmc_init

        for _attempt in range(2):
            # The returned kernel state is discarded; ``self.kernel_result`` is not updated here.
            drawn, _kernel_state = self.sampler.sample(start_state,
                                                       kernel_result=self.kernel_result)
            new_samples = drawn.numpy()

            if self.prev_samples is None:
                rel_diff_mean = 0
                rel_diff_var = 0
            else:
                prev_mean = np.mean(self.prev_samples, 0)
                prev_var = np.var(self.prev_samples, 0)
                abs_diff_mean = np.abs(prev_mean - np.mean(new_samples, 0))
                rel_diff_mean = np.mean(abs_diff_mean / np.abs(prev_mean))
                abs_diff_var = np.abs(prev_var - np.var(new_samples, 0))
                rel_diff_var = np.mean(abs_diff_var / np.abs(prev_var))

            # Early in training every batch is accepted.
            if self.epoch < 3_000:
                rel_diff_mean = 0
                rel_diff_var = 0

            # Written as a negated "or" so NaN differences end the loop as well.
            if not (rel_diff_var > 0.2 or rel_diff_mean > 0.2):
                break

        self.prev_samples = new_samples
        return self.prev_samples

    def benchmark_sample(self, n_samples):
        """Benchmark entry point starting from the default chain states."""
        return self.sampler.benchmark_sample(init=self.mcmc_init, n_samples=n_samples)
