import torch
import numpy as np
from tqdm import tqdm
from .proj_model import ProjectedModel
from .qmc_wrapper import QmcWrapper


class ImportanceSampler(torch.nn.Module):
    """Importance sampler (plus a NUTS wrapper) for subspace inference.

    Model parameters are represented by low-dimensional coordinates ``z`` in
    the span of ``subspace.cov_factor``; ``ProjectedModel`` (defined elsewhere)
    is used to evaluate ``base`` at given coordinates.

    :param base: model class, instantiated as ``base(*args, **kwargs)``.
    :param subspace: object exposing ``cov_factor`` of shape (k, d) and
        ``mean`` of shape (d,).
    :param criterion: callable ``criterion(model, data, target)`` returning a
        3-tuple whose first element is the batch loss.
    :param proposal_var: variance of the proposal distribution.
    :param temperature: posterior temperature ``T``; the log-likelihood term
        is ``-loss / T``.
    :param proposal_type: one of ``'uniform'``, ``'gaussian'``, ``'student-t'``.
    :param deg_f: degrees of freedom of the student-t proposal.
    :param loader: data loader used by :meth:`sampling_with_weights`.
    :param data: stored as ``self.data``; not read by any method of this class.
    :param device: device for the model and for tensors created here.
    :param prior_scale: standard deviation ``s`` of the isotropic weight prior.
    :raises NotImplementedError: for an unknown ``proposal_type``.
    """

    def __init__(self, base, subspace, criterion, proposal_var=1., temperature=1.,
                 proposal_type='gaussian', deg_f=4, loader=None, data=None,
                 device=torch.device('cpu'), prior_scale=1.0, *args, **kwargs):
        super().__init__()
        self.base_model = base(*args, **kwargs).to(device)
        self.base_model.eval()
        self.subspace = subspace
        self.data = data
        self.criterion = criterion
        self.loader = loader
        self.temperature = temperature
        self.proposal_type = proposal_type
        self.proposal_var = proposal_var
        self.deg_f = deg_f
        self.device = device
        self.proj_mat = self.subspace.cov_factor.t().to(device)  # projection matrix, dim: (d, k), d >> k
        self.dim = self.proj_mat.shape[1]
        self.prior_scale = prior_scale
        # number of posterior evaluations made by calc_log_density_one_point
        self.post_eval_times = 0

        # Prior over z. Assuming ProjectedModel maps z to weights as
        # w = subspace.mean + P z with w ~ N(0, s^2 I), the induced density on z
        # is N(-(P^T P)^{-1} P^T mean, s^2 (P^T P)^{-1}).
        inverse_D = torch.inverse(self.proj_mat.t() @ self.proj_mat)  # dim: (k, k)
        self.prior = torch.distributions.MultivariateNormal(
            -inverse_D @ self.proj_mat.t() @ self.subspace.mean.to(device),
            covariance_matrix=inverse_D * (self.prior_scale ** 2),
            validate_args=False
        )

        # proposal distribution q(z), used to evaluate log q for the weights
        self.proposal = self._build_proposal(proposal_var)

    def _build_proposal(self, proposal_var):
        """Return the proposal distribution q(z) on R^dim with the given variance.

        Every returned distribution has event shape ``(dim,)``, so ``log_prob``
        of an ``(n, dim)`` batch yields ``n`` joint log densities.

        :param proposal_var: variance (per coordinate) of the proposal.
        :raises NotImplementedError: if ``self.proposal_type`` is unknown.
        """
        dim, device = self.dim, self.device
        std = float(np.sqrt(proposal_var))
        if self.proposal_type == 'uniform':
            # U[-std, std]^dim; Independent sums the per-coordinate log densities
            return torch.distributions.Independent(
                torch.distributions.Uniform(
                    -std * torch.ones(dim, device=device),
                    std * torch.ones(dim, device=device)
                ),
                1
            )
        if self.proposal_type == 'gaussian':
            return torch.distributions.MultivariateNormal(
                torch.zeros(dim, device=device),
                proposal_var * torch.eye(dim, device=device),
                validate_args=False
            )
        if self.proposal_type == 'student-t':
            # StudentT is univariate: it needs a positive per-coordinate scale
            # vector (an identity *matrix* has zero entries, an invalid scale)
            return torch.distributions.Independent(
                torch.distributions.StudentT(
                    self.deg_f,
                    torch.zeros(dim, device=device),
                    std * torch.ones(dim, device=device)
                ),
                1
            )
        raise NotImplementedError

    def _dataset_loss(self, proj_model, loader):
        """Accumulate ``proj_model``'s loss over every batch of ``loader``.

        Each batch loss (first output of ``self.criterion``) is multiplied by
        the batch size before summing; if criterion returns a batch-mean loss
        (presumably — confirm), the total is the summed per-example loss.

        :return: ``(total_loss, num_datapoints)``; ``total_loss`` is the float
            ``0.`` when ``loader`` yields no batches.
        """
        total_loss = 0.
        num_datapoints = 0
        for data, target in loader:
            data = data.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)
            batch_loss, _, _ = self.criterion(proj_model, data, target)
            batch_size = data.size(0)
            total_loss += batch_loss * batch_size
            num_datapoints += batch_size
        return total_loss, num_datapoints

    def sampling_with_weights(self, num_samples, enable_qmc=True, enable_tqdm=False, proposal_var=0):
        """
        Importance sampling with self-normalized weights.

        Each sample z_i drawn from the proposal q gets log weight
        ``log prior(z_i) - loss(z_i) / T - log q(z_i)``, where ``loss`` is
        accumulated over ``self.loader``.

        :param num_samples: number of samples to generate
        :param enable_qmc: whether to use randomized QMC sampling
        :param enable_tqdm: whether to use tqdm for a progress bar
        :param proposal_var: variance of the proposal distribution; ``0`` (or
            ``None``) means use ``self.proposal_var``
        :return: ``(proposal_sample, weights)`` with weights summing to 1
        :raises NotImplementedError: for a student-t proposal
        """
        if proposal_var is None or proposal_var == 0:
            proposal_var = self.proposal_var
        if self.proposal_type == 'student-t':
            raise NotImplementedError("QMC sampling from a student-t proposal is not implemented")
        # the density q must match the variance the samples are actually drawn with
        if proposal_var == self.proposal_var:
            proposal = self.proposal
        else:
            proposal = self._build_proposal(proposal_var)

        # (1). sample from the proposal distribution
        qmcwrapper = QmcWrapper(dim=self.dim, device=self.device)
        # NOTE(review): assumes sample() returns (batch_size, gen_num, dim) points in [0, 1)
        unit_sample = qmcwrapper.sample(gen_num=num_samples, batch_size=1, enable_qmc=enable_qmc)[0]
        std = np.sqrt(proposal_var)
        if self.proposal_type == 'uniform':
            # map to sqrt(proposal_var) * [-1, 1]^dim
            proposal_sample = (unit_sample - 0.5) * 2.0 * std
        elif self.proposal_type == 'gaussian':
            # map to N(0, proposal_var * I_dim)
            proposal_sample = qmcwrapper.normal_transform(unit_sample) * std
        else:
            raise ValueError("Invalid proposal type")

        # (2). log weights: log prior - loss / T - log proposal
        log_weights = self.prior.log_prob(proposal_sample) - proposal.log_prob(proposal_sample)
        range_generator = tqdm(range(num_samples)) if enable_tqdm else range(num_samples)
        for i in range_generator:
            with torch.no_grad():
                proj_model = ProjectedModel(model=self.base_model, subspace=self.subspace, proj_params=proposal_sample[i])
                loss, _ = self._dataset_loss(proj_model, self.loader)
            log_weights[i] -= loss / self.temperature

        # (3). self-normalize: softmax(lw) == exp(lw - logsumexp(lw))
        weights = torch.softmax(log_weights, dim=0)
        return proposal_sample, weights

    def resample(self, proposal_sample, weights, resample_size):
        """Resample ``resample_size`` rows of ``proposal_sample`` with probabilities ``weights``."""
        resample_index = torch.distributions.Categorical(probs=weights).sample((resample_size,))
        return proposal_sample[resample_index]

    def gen_data(self, inp, proj_params):
        """Return the projected model's output on ``inp`` at coordinates ``proj_params``."""
        proj_model = ProjectedModel(model=self.base_model, subspace=self.subspace, proj_params=proj_params)
        return proj_model(inp)

    def gen_data_with_loader(self, loader, proj_params):
        """Concatenate (along dim 0) the projected model's outputs over the inputs of ``loader``."""
        proj_model = ProjectedModel(model=self.base_model, subspace=self.subspace, proj_params=proj_params)
        data_list = []
        for data, _ in loader:
            data = data.to(self.device)
            data_list.append(proj_model(data))
        return torch.cat(data_list, dim=0)

    def reconstruct_loss(self, proj_params, valid_loader, enable_tqdm=False):
        """Return ``exp(-total_loss)`` over ``valid_loader`` for each row of ``proj_params``."""
        num_samples = proj_params.size(0)
        loss_total = torch.zeros((num_samples,), device=self.device)
        range_generator = tqdm(range(num_samples)) if enable_tqdm else range(num_samples)
        for i in range_generator:
            with torch.no_grad():
                proj_model = ProjectedModel(model=self.base_model, subspace=self.subspace, proj_params=proj_params[i])
                # data is moved to self.device inside _dataset_loss
                loss, _ = self._dataset_loss(proj_model, valid_loader)
            loss_total[i] = loss
        return torch.exp(-loss_total)

    def calc_marginal(self, proposal_sample, valid_loader, enable_tqdm=False):
        """Return log[P(D | w(z))^(1/T) * P(z)] for each row z of ``proposal_sample``.

        The data term is ``-total_loss / T`` with the loss accumulated over
        ``valid_loader``.
        """
        num_samples = proposal_sample.size(0)
        self.base_model.eval()
        log_weights = self.prior.log_prob(proposal_sample)
        range_generator = tqdm(range(num_samples)) if enable_tqdm else range(num_samples)
        for i in range_generator:
            with torch.no_grad():
                proj_model = ProjectedModel(model=self.base_model, subspace=self.subspace, proj_params=proposal_sample[i])
                loss, _ = self._dataset_loss(proj_model, valid_loader)
            log_weights[i] -= loss / self.temperature
        return log_weights

    def calc_mean_loss(self, proposal_sample, valid_loader, enable_tqdm=False):
        """Return the per-datapoint mean loss over ``valid_loader`` for each row of ``proposal_sample``.

        :raises ValueError: if ``valid_loader`` yields no data.
        """
        num_samples = proposal_sample.size(0)
        self.base_model.eval()
        loss_result = torch.zeros((num_samples,), device=self.device)
        range_generator = tqdm(range(num_samples)) if enable_tqdm else range(num_samples)
        dataset = getattr(valid_loader, 'dataset', None)
        tensors = getattr(dataset, 'tensors', None)
        with torch.no_grad():
            if len(valid_loader) == 1 and tensors is not None:
                # single-batch TensorDataset: evaluate on the stored tensors
                # directly (presumably to avoid DataLoader overhead)
                data = tensors[0].to(self.device, non_blocking=True)
                target = tensors[1].to(self.device, non_blocking=True)
                for i in range_generator:
                    proj_model = ProjectedModel(model=self.base_model, subspace=self.subspace, proj_params=proposal_sample[i])
                    loss_result[i], _, _ = self.criterion(proj_model, data, target)
                return loss_result
            for i in range_generator:
                proj_model = ProjectedModel(model=self.base_model, subspace=self.subspace, proj_params=proposal_sample[i])
                loss, num_datapoints = self._dataset_loss(proj_model, valid_loader)
                if num_datapoints == 0:
                    raise ValueError("valid_loader yielded no data")
                loss_result[i] = loss / num_datapoints
        return loss_result

    def calc_log_density_one_point(self, proposal_sample, data_iter, valid_loader):
        """Stochastic estimate of the unnormalized log posterior at ``proposal_sample``.

        Uses one batch from ``data_iter`` (restarting from ``valid_loader`` when
        exhausted) and scales its loss by ``batch_size * len(valid_loader)`` to
        approximate the full-data loss (assuming criterion returns a batch mean).

        :return: ``(log_prior - scaled_loss / T, data_iter)``; pass the returned
            iterator back in on the next call.
        """
        try:
            data, target = next(data_iter)
        except StopIteration:
            data_iter = iter(valid_loader)
            data, target = next(data_iter)
        data = data.to(self.device, non_blocking=True)
        target = target.to(self.device, non_blocking=True)
        proj_model = ProjectedModel(model=self.base_model, subspace=self.subspace, proj_params=proposal_sample)
        loss, _, _ = self.criterion(proj_model, data, target)
        # scale batch loss to full dataset
        loss = loss * data.size(0) * len(valid_loader) / self.temperature
        # getattr keeps this usable on instances that never ran __init__'s counter setup
        self.post_eval_times = getattr(self, 'post_eval_times', 0) + 1
        if self.post_eval_times % 500 == 0:
            print("Posterior eval times: %d" % self.post_eval_times)
        return self.prior.log_prob(proposal_sample) - loss, data_iter

    def sampling_with_nuts(self, num_samples, valid_loader, thinning=5, warmup_steps=300):
        """Draw ``num_samples`` thinned posterior samples of z with pyro's NUTS.

        The potential is the negative stochastic log posterior from
        :meth:`calc_log_density_one_point`; the chain is run for
        ``(num_samples - 1) * thinning + 1`` steps and every ``thinning``-th
        sample is kept.

        :raises ValueError: if ``num_samples`` < 1.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")
        from pyro.infer.mcmc import NUTS, MCMC
        self.post_eval_times = 0
        data_iter = iter(valid_loader)

        def pot_fn(params):
            nonlocal data_iter
            log_density, data_iter = self.calc_log_density_one_point(params['params'], data_iter, valid_loader)
            return -log_density

        initial_params = {"params": torch.zeros(self.dim, device=self.device)}
        nuts_kernel = NUTS(potential_fn=pot_fn, target_accept_prob=0.5, max_tree_depth=5)
        mcmc = MCMC(nuts_kernel, num_samples=(num_samples - 1) * thinning + 1, warmup_steps=warmup_steps, initial_params=initial_params)
        mcmc.run()
        samples = mcmc.get_samples()["params"]
        return samples[::thinning]
