import numpy as np
import torch

from .base import Model
from .base import ModelFactory
from .np_priors import DirichletPrior


class DirichletFactory(ModelFactory):
    """``ModelFactory`` specialisation bound to ``DirichletModel``."""

    def __init__(self, config, benchmark, simulation_budget):
        """Forward the arguments to ``ModelFactory`` with the model class.

        Args:
            config: Passed through unchanged to the parent initializer.
            benchmark: Passed through unchanged to the parent initializer.
            simulation_budget: Passed through unchanged to the parent
                initializer.
        """
        # The only thing this subclass contributes is the model class.
        model_class = DirichletModel
        super().__init__(config, benchmark, simulation_budget, model_class)


class DirichletModel(Model):
    """Model whose log-probabilities are read off a ``DirichletPrior``.

    Nothing is fitted or persisted here: ``train_models``, ``save``, ``load``,
    ``train``, ``eval`` and ``to`` are no-ops, ``is_trained`` always answers
    ``True`` and ``sampling_enabled`` always answers ``False``.
    """

    def __init__(self, benchmark, model_path, config, normalization_constants):
        """Collect benchmark/config settings and build the Dirichlet prior.

        Args:
            benchmark: Queried for its prior, device, parameter dimension and
                parameter domain through the matching ``get_*`` methods.
            model_path: Accepted but not used by this constructor.
            config: Mapping read for its ``"nb_estimators"`` and ``"bins"``
                entries; also handed to ``DirichletPrior``.
            normalization_constants: Passed unchanged to the parent
                initializer.
        """
        super().__init__(normalization_constants)

        self.benchmark = benchmark
        self.config = config
        self.prior = benchmark.get_prior()
        self.nb_estimators = config["nb_estimators"]
        self.bins = config["bins"]
        self.device = benchmark.get_device()
        # NOTE(review): the prior is always built with "cpu", independently of
        # `self.device` -- confirm this is intentional.
        self.dirichlet_prior = DirichletPrior(benchmark, config, "cpu")

        self.parameter_dim = benchmark.get_parameter_dim()
        domain = benchmark.get_domain()
        self.LOWER = domain[0]
        self.UPPER = domain[1]

        # Width of a single bin along each parameter dimension when the
        # [lower, upper] range is cut into `bins` equal parts.
        bin_widths = np.array(
            [(hi - lo) / self.bins for lo, hi in zip(self.LOWER, self.UPPER)]
        )
        # Sum of log-widths, i.e. the log of the product of the bin widths.
        # Only the disabled `log_prob` variant further down reads this value.
        self.bin_log_volume = np.sum(np.log(bin_widths))

        # Only the disabled implementation further down populates this.
        self.estimators = None

    def log_prob(self, theta, x):
        """Average ``dirichlet_prior.sample_functions`` over dimension 1.

        Args:
            theta: Forwarded as first argument to ``sample_functions``.
            x: Forwarded as second argument to ``sample_functions``.

        Returns:
            The ``sample_functions`` result with dimension 2 squeezed out and
            then averaged along dimension 1.
        """
        # presumably dim 1 indexes the `nb_estimators` sampled functions and
        # dim 2 is a singleton -- TODO confirm against DirichletPrior.
        # NOTE(review): the disabled variant below combined estimators with
        # logsumexp - log(nb_estimators); this one is a plain arithmetic mean
        # of whatever `sample_functions` returns -- confirm that is intended.
        sampled = self.dirichlet_prior.sample_functions(
            theta, x, self.nb_estimators
        )
        per_function = sampled.squeeze(dim=2)
        return per_function.mean(dim=1)

    # ------------------------------------------------------------------
    # Disabled earlier implementation, kept verbatim for reference. It drew
    # explicit distributions with `dirichlet_prior.sample_distribution()` and
    # stored them in `self.estimators`.
    # ------------------------------------------------------------------

    # def sample_theta_from_bin(self, bin):
    #     thetas = []
    #     for dim in range(self.parameter_dim):
    #         dim_bin = (bin // self.bins**dim) % self.bins
    #         bin_size = (self.UPPER[dim] - self.LOWER[dim]) / self.bins
    #         theta_lower = self.LOWER[dim] + dim_bin * bin_size
    #         u = torch.rand(bin.shape[0])
    #         u = u * bin_size

    #         theta = theta_lower + u
    #         thetas.append(theta)

    #     return torch.stack(thetas, dim=1)

    # def log_prob(self, theta, x):
    #     output_list = []
    #     for estimator in self.estimators:
    #         bin_output = estimator[self.dirichlet_prior.get_bin(theta)].log()
    #         output = bin_output - self.bin_log_volume
    #         output_list.append(output)

    #     output_list = torch.stack(output_list, dim=-1)
    #     average_output = torch.logsumexp(output_list, dim=-1) - torch.log(
    #         torch.Tensor([self.nb_estimators])
    #     )
    #     return average_output

    # def sample(self, x, shape):
    #     nb_samples = np.prod(np.array(shape))
    #     samples = []
    #     for _ in range(nb_samples):
    #         index = np.random.randint(self.nb_estimators)
    #         estimator = self.estimators[index]
    #         theta_bin = torch.multinomial(estimator, 1)
    #         samples.append(self.sample_theta_from_bin(theta_bin))

    #     return torch.stack(samples, dim=0).view(*shape, -1)

    # def get_posterior_fct(self):
    #     def get_posterior(x):
    #         class Posterior:
    #             def __init__(self, sampling_fct, log_prob_fct):
    #                 self.sample = sampling_fct
    #                 self.log_prob = log_prob_fct

    #         return Posterior(
    #             lambda shape: self.sample(x.to(self.device), shape).cpu(),
    #             lambda theta: self.log_prob(
    #                 theta.to(self.device), x.to(self.device)
    #             ).cpu(),
    #         )

    #     return get_posterior

    # def train_models(self):
    #     self.estimators = [
    #         self.dirichlet_prior.sample_distribution()
    #         for _ in range(self.nb_estimators)
    #     ]

    def train_models(self):
        """Do nothing; this model has no training step."""

    @classmethod
    def is_trained(cls, model_path):
        """Report the model as trained, whatever ``model_path`` is."""
        return True

    # Disabled counterpart of the method below, from the sampling-capable
    # implementation above:
    # def sampling_enabled(self):
    #     return True

    def sampling_enabled(self):
        """Report that sampling is not available for this model."""
        return False

    def __call__(self, theta, x):
        """Make the instance callable; same result as ``log_prob(theta, x)``."""
        return self.log_prob(theta, x)

    def get_loss_fct(self):
        """Return ``None``; no loss function is defined for this model."""

    def save(self):
        """Do nothing; there is no state written out."""

    def load(self):
        """Do nothing; there is no state read back."""

    # NOTE(review): `train`, `eval` and `to` return None, unlike the
    # torch.nn.Module methods of the same names, which return the module --
    # confirm no caller chains on the return value.

    def train(self):
        """Do nothing; provided so callers can switch modes uniformly."""

    def eval(self):
        """Do nothing; provided so callers can switch modes uniformly."""

    def to(self, device):
        """Do nothing; ``device`` is ignored and ``None`` is returned."""
