from abc import ABC, abstractmethod
from typing import Callable

import numba
import numpy as np
import scipy.stats
from scipy.stats import uniform

from typing import Optional


class MetropolisHastings(ABC):
    """Metropolis-Hastings sampler that advances ``num_chains`` chains in lock-step.

    Subclasses define the proposal mechanism by implementing :meth:`proposal`
    (draw a candidate for every chain) and :meth:`proposal_logp` (log-density of
    that proposal distribution).

    Attributes:
        log_p_unnormalized: Callable returning the unnormalised target log-density
            of a batch of points.  NOTE(review): assumed to return an array of
            shape ``(num_chains, 1)`` so it adds cleanly to ``proposal_logp``.
        num_chains: Number of chains run in parallel.
        data_dim: Dimensionality of a single sample.
        samples: Every state produced so far, shape ``(steps, num_chains, data_dim)``.
        count: Number of accepted proposals in the most recent ``sample_chains`` call.
        counter: Set to 0 at construction and not updated by this class.
        acceptance_rate: Fraction of accepted proposals in the most recent call.
    """

    def __init__(self, log_p_unnormalized: Callable, num_chains, data_dim):
        self.log_p_unnormalized = log_p_unnormalized
        self.num_chains = num_chains
        self.data_dim = data_dim

        self.samples = np.empty((0, self.num_chains, self.data_dim))
        self.counter = 0
        # Initialised here so the attribute exists before the first sampling run.
        self.count = 0
        self.acceptance_rate = np.nan

    def callback(self, i, n_samples, counter, burnin):
        """Hook invoked after every iteration of ``sample_chains``; no-op by default."""
        return

    def sample_chains(self, n_samples, burnin=1000, x_init: Optional[np.ndarray] = None):
        """Run every chain for ``n_samples`` states and return the post-burn-in draws.

        Args:
            n_samples: Number of states per chain, including the starting state.
            burnin: Number of leading states dropped from the returned array.
            x_init: Starting state, broadcastable to ``(num_chains, data_dim)``.
                Only used when no samples have been generated yet; otherwise the
                chains resume from the last stored state.

        Returns:
            Array of shape ``(max(n_samples - burnin, 0) * num_chains, data_dim)``.

        Raises:
            ValueError: If ``n_samples < 1``, or if there is neither an ``x_init``
                nor a previously stored state to start from.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")

        samples = np.empty((n_samples, self.num_chains, self.data_dim))
        has_history = self.samples.shape[0] > 0
        if x_init is not None and not has_history:
            samples[0, ...] = x_init
        elif has_history:
            samples[0, ...] = self.samples[-1, ...]
        else:
            raise ValueError("x_init is required when no previous samples exist")

        # One uniform per (chain, step); column 0 is never consumed.
        proposal_u = np.random.rand(self.num_chains, n_samples)

        self.count = 0
        for i in range(1, n_samples):
            prev_sample = samples[i - 1, ...]
            proposal = self.proposal(prev_sample, i)

            # log[ p(x') q(x | x') ]  and  log[ p(x) q(x' | x) ]
            log_numerator = (self.log_p_unnormalized(proposal)
                             + self.proposal_logp(x=prev_sample, cond=proposal, i=i))
            log_denominator = (self.log_p_unnormalized(prev_sample)
                               + self.proposal_logp(x=proposal, cond=prev_sample, i=i))

            ratio = np.exp(log_numerator - log_denominator)
            accepted = proposal_u[:, [i]] < ratio

            # np.where selects element-wise; multiplying by the boolean mask would
            # turn a rejected proposal containing inf/NaN into NaN (0 * inf).
            samples[i, ...] = np.where(accepted, proposal, prev_sample)

            self.count += np.sum(accepted)

            self.callback(i, n_samples, self.count, burnin)

        # Only n_samples - 1 proposals are made per chain (state 0 is the start).
        n_proposals = (n_samples - 1) * self.num_chains
        self.acceptance_rate = self.count / n_proposals if n_proposals > 0 else np.nan

        self.samples = np.concatenate([self.samples, samples], 0)
        return samples[burnin:, ...].reshape(-1, self.data_dim)

    @staticmethod
    def gaussian_logpdf(x, mean, cov):
        """Row-wise log-density of a multivariate Gaussian.

        Pure-NumPy, vectorised over rows, so no JIT compilation is required.
        Derivation: https://gregorygundersen.com/blog/2019/10/30/scipy-multivariate/

        Args:
            x: Points of shape ``(n, d)``.
            mean: Mean(s), broadcastable against ``x``.
            cov: Symmetric positive-definite covariance of shape ``(d, d)``.

        Returns:
            Array of shape ``(n, 1)`` holding ``log N(x[k]; mean[k], cov)``.
        """
        x = np.asarray(x, dtype=float)
        mean = np.asarray(mean, dtype=float)

        # `eigh` assumes the matrix is Hermitian.
        vals, vecs = np.linalg.eigh(cov)
        logdet = np.sum(np.log(vals))

        # Scaling the eigenvectors by 1/sqrt(eigenvalue) whitens the deviations,
        # so the squared norm of `dev @ U` is the squared Mahalanobis distance.
        U = vecs * np.sqrt(1.0 / vals)
        rank = len(vals)
        dev = x - mean

        maha = np.square(dev @ U).sum(axis=1, keepdims=True)
        log2pi = np.log(2 * np.pi)
        return -0.5 * (rank * log2pi + maha + logdet)

    @abstractmethod
    def proposal(self, previous_sample, i):
        """Draw a candidate state for every chain given ``previous_sample`` at step ``i``."""
        ...

    @abstractmethod
    def proposal_logp(self, cond, x, i):
        """Return the log-density of proposing ``x`` from state ``cond`` at step ``i``."""
        ...


class LocalSampler(MetropolisHastings):
    """Random-walk sampler whose proposal is a Gaussian centred on the current state.

    The proposal covariance is ``eps * I``; ``L`` is the Cholesky factor of that
    matrix (plus a small diagonal jitter) and is used to colour standard-normal
    noise.
    """

    def __init__(self, log_p_unnormalized: Callable,
                 data_dim,
                 num_chains: int = 500, eps=(2 ** (2 / 3)) / 5, **kwargs):
        """Set up the base sampler and the Gaussian step covariance.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(
            log_p_unnormalized=log_p_unnormalized,
            num_chains=num_chains,
            data_dim=data_dim,
        )

        identity = np.eye(self.data_dim)
        self.eps = eps
        self.proposal_scale = self.eps * identity
        self.L = np.linalg.cholesky(self.proposal_scale + 1e-9 * identity)

    def proposal(self, previous_sample, i):
        """Return ``previous_sample`` perturbed by Gaussian noise, one row per chain."""
        standard_normal = np.random.normal(size=(self.data_dim, self.num_chains))
        step = standard_normal.T @ self.L
        return previous_sample + step

    def proposal_logp(self, cond, x, i):
        """Log-density of ``x`` under a Gaussian with mean ``cond`` and covariance ``proposal_scale``."""
        covariance = self.proposal_scale
        return MetropolisHastings.gaussian_logpdf(x, cond, covariance)


class MALA(LocalSampler):
    """Metropolis-adjusted Langevin sampler mixed with plain random-walk moves.

    For the first 100 iterations every proposal is a wide random-walk step.
    Afterwards each iteration uses a random-walk step with probability
    ``proportion_random`` and a gradient-drifted (Langevin) step otherwise.
    """

    def __init__(self, log_p_unnormalized: Callable,
                 log_dp_unnormalized: Callable,
                 data_dim,
                 num_chains: int = 500, eps=(2 ** (2 / 3)) / 4):
        """Configure the sampler.

        Args:
            log_p_unnormalized: Unnormalised target log-density.
            log_dp_unnormalized: Gradient of the target log-density.
            data_dim: Dimensionality of a single sample.
            num_chains: Number of chains run in parallel.
            eps: Step-size parameter; the drift is ``eps / 2`` times the gradient.
        """
        # The parent builds the Cholesky factor ``L`` from ``eps / 2``.
        super().__init__(log_p_unnormalized=log_p_unnormalized, num_chains=num_chains,
                         data_dim=data_dim, eps=eps / 2)
        self.log_dp_unnormalized = log_dp_unnormalized
        print(f"eps: {eps:.2f}")
        self.eps = eps
        self.proposal_scale = self.eps * np.eye(data_dim)
        self.proportion_random = 0.2

    def proposal(self, previous_sample, i):
        """Draw a candidate: a wide random-walk step or a Langevin step."""
        if np.random.random() < self.proportion_random or i < 100:
            drift = 0
            noise = np.random.normal(size=(self.data_dim, self.num_chains)).T @ self.L * 2

        else:
            drift = (self.eps / 2) * self.log_dp_unnormalized(previous_sample)
            noise = np.random.normal(size=(self.data_dim, self.num_chains)).T @ self.L

        return previous_sample + drift + noise

    def proposal_logp(self, cond, x, i):
        """Log-density of proposing ``x`` from ``cond`` under what ``proposal`` samples.

        The covariances are derived from ``self.L`` because that is what the
        noise in ``proposal`` is actually multiplied by: ``z @ L`` has covariance
        ``L.T @ L``, and the random-walk branch doubles the noise (covariance x4).
        After the warm-up the proposal is a two-component mixture, so the
        component densities are combined with a log-sum-exp rather than by
        averaging their logs.

        Returns:
            Array of shape ``(n, 1)``.
        """
        langevin_cov = self.L.T @ self.L
        random_walk_cov = 4.0 * langevin_cov

        logp_random = MetropolisHastings.gaussian_logpdf(x=x, mean=cond, cov=random_walk_cov)
        if i < 100:
            # Warm-up: only the random-walk component is ever used, so the
            # gradient does not need to be evaluated.
            return logp_random

        drift = (self.eps / 2) * self.log_dp_unnormalized(cond)
        logp_langevin = MetropolisHastings.gaussian_logpdf(x=x, mean=cond + drift, cov=langevin_cov)

        # log( w * N_random + (1 - w) * N_langevin ), evaluated stably.
        weight = self.proportion_random
        return np.logaddexp(np.log(weight) + logp_random,
                            np.log1p(-weight) + logp_langevin)


class IndependentSampler(MetropolisHastings):
    """Independence sampler: candidates come from a fixed distribution, ignoring the current state."""

    def __init__(self, log_p_unnormalized: Callable,
                 data_dim,
                 prior_sample: Callable,
                 prior_logprob: Callable,
                 num_chains: int = 500,
                 **kwargs):
        """Store the proposal sampler and its log-density.

        Args:
            log_p_unnormalized: Unnormalised target log-density.
            data_dim: Dimensionality of a single sample.
            prior_sample: Callable taking a count and returning that many draws.
            prior_logprob: Callable returning the log-density of a batch of points.
            num_chains: Number of chains run in parallel.
            **kwargs: Accepted and ignored.
        """
        super().__init__(
            log_p_unnormalized=log_p_unnormalized,
            num_chains=num_chains,
            data_dim=data_dim,
        )

        self.proposal_dist = prior_sample
        self.proposal_logprob = prior_logprob

    def proposal(self, previous_sample, i):
        """Draw one fresh candidate per chain; ``previous_sample`` is not used."""
        return self.proposal_dist(self.num_chains)

    def proposal_logp(self, cond, x, i):
        """Log-density of ``x`` under the proposal distribution as a column vector; ``cond`` is not used."""
        log_density = self.proposal_logprob(x)
        return log_density.reshape(-1, 1)

class LocalMixSampler(MetropolisHastings):
    """Sampler whose proposal is an equal-weight mixture of two moves.

    Each chain independently uses, with probability one half each, either a
    Gaussian random-walk step (covariance ``eps * I``) or a fresh draw from
    ``prior_sample``.
    """

    def __init__(self, log_p_unnormalized: Callable,
                 prior_sample: Callable,
                 prior_logprob: Callable,
                 data_dim,
                 num_chains: int = 500, eps=(2 ** (2 / 3)) / 5):
        """Configure the sampler.

        Args:
            log_p_unnormalized: Unnormalised target log-density.
            prior_sample: Callable taking a count and returning that many draws.
            prior_logprob: Callable returning the log-density of a batch of points.
            data_dim: Dimensionality of a single sample.
            num_chains: Number of chains run in parallel.
            eps: Variance of the random-walk component.
        """
        super().__init__(log_p_unnormalized=log_p_unnormalized, num_chains=num_chains,
                         data_dim=data_dim)

        self.eps = eps
        self.proposal_scale = self.eps * np.eye(self.data_dim)
        self.L = np.linalg.cholesky(self.proposal_scale + 1e-9 * np.eye(self.data_dim))

        self.proposal_dist = prior_sample
        self.proposal_logprob = prior_logprob

    def proposal(self, previous_sample, i):
        """Draw one candidate per chain from the local/independent mixture."""
        independent_sample = np.random.rand(self.num_chains) > 0.5

        noise = np.random.normal(size=(self.data_dim, self.num_chains)).T @ self.L
        local_proposal = previous_sample + noise

        independent_proposal = self.proposal_dist(self.num_chains)

        # Element-wise selection; multiplying by the mask would let an inf/NaN in
        # the unselected candidate leak through as NaN (0 * inf).
        return np.where(independent_sample[:, np.newaxis], independent_proposal, local_proposal)

    def proposal_logp(self, cond, x, i):
        """Log-density of ``x`` under the mixture proposal conditioned on ``cond``.

        Returns:
            Array of shape ``(n, 1)`` equal to
            ``log(0.5 * N(x; cond, eps * I) + 0.5 * prior(x))``.
        """
        logp_local = MetropolisHastings.gaussian_logpdf(x, cond, self.proposal_scale)
        logp_independent = self.proposal_logprob(x).reshape(-1, 1)

        # logaddexp avoids the underflow to log(0) = -inf that exponentiating
        # very negative log-densities would cause.
        return np.logaddexp(logp_local, logp_independent) + np.log(0.5)

