import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from zuko.utils import broadcast

from .bayesian_npe import BayesianNPEFactory
from .bayesian_npe import BayesianNPEModel
from .bayesian_npe import NPEWithEmbedding
from .np_priors import DirichletPrior


class DiscretizedBayesianNPEFactory(BayesianNPEFactory):
    """Factory that hands ``DiscretizedBayesianNPEModel`` to ``BayesianNPEFactory``."""

    def __init__(self, config, benchmark, simulation_budget):
        # The parent constructor takes the model class as its fourth
        # positional argument; everything else is forwarded untouched.
        model_class = DiscretizedBayesianNPEModel
        super().__init__(config, benchmark, simulation_budget, model_class)


class DiscretizedNPE(nn.Module):
    """Posterior estimator that is piecewise constant on a regular grid.

    The box ``[LOWER, UPPER]`` is split into ``bins`` equal-width intervals
    along each of the ``parameter_dim`` axes, which yields
    ``bins ** parameter_dim`` cells. ``self.model`` produces one logit per
    cell; the log-density of a parameter is the log-probability of its cell
    minus the log-volume of a cell.

    Cells are addressed by a flat index ``sum_d idx_d * bins ** d`` where
    ``idx_d`` is the interval index along axis ``d``.
    """

    def __init__(
        self,
        parameter_dim,
        embedding_dim,
        bins,
        parameter_ranges,
        device,
        build,
        **kwargs
    ):
        """Build the classifier and cache the grid geometry.

        Args:
            parameter_dim: Number of parameter axes.
            embedding_dim: Size passed to ``build`` as the input dimension.
            bins: Number of intervals per axis.
            parameter_ranges: Pair ``(lower, upper)`` of tensors with one
                entry per parameter axis.
            device: Device on which the bounds are stored.
            build: Callable invoked as
                ``build(embedding_dim, bins ** parameter_dim, **kwargs)``;
                its result is used as the logit network.
            **kwargs: Forwarded to ``build``.
        """
        super().__init__()
        self.bins = bins
        self.parameter_dim = parameter_dim
        self.embedding_dim = embedding_dim
        self.device = device
        self.parameter_ranges = parameter_ranges
        self.model = build(embedding_dim, bins**parameter_dim, **kwargs)
        self.LOWER = parameter_ranges[0].to(self.device)
        self.UPPER = parameter_ranges[1].to(self.device)
        # Log-volume of one cell: sum over axes of log(width / bins). Stored
        # as a plain Python float so it can be subtracted from tensors on any
        # device.
        self.bin_log_volume = (
            torch.log((self.UPPER - self.LOWER) / bins).sum().item()
        )

    def get_bin(self, theta):
        """Return the flat cell index of each parameter vector.

        Args:
            theta: Tensor whose last axis has ``parameter_dim`` entries.

        Returns:
            Long tensor of flat indices with the last axis of ``theta``
            reduced away (a 1-D ``theta`` yields a tensor of shape ``(1,)``).
            Values on or outside the domain boundary are assigned to the
            nearest edge cell.
        """
        lower = self.LOWER.unsqueeze(0)
        upper = self.UPPER.unsqueeze(0)

        # Rescale so that 0 corresponds to the lower bound and 1 to the upper.
        unit = (theta - lower) / (upper - lower)

        # Per-axis interval index. Clamping keeps theta == UPPER (which would
        # otherwise map to index ``bins``) inside the last interval instead of
        # leaking into the neighbouring axis' stride or past the final cell.
        per_axis = (unit * self.bins).clamp(min=0, max=self.bins - 1).long()

        # Stride of each axis in the flat index.
        strides = torch.tensor(
            [self.bins**axis for axis in range(self.parameter_dim)],
            dtype=torch.long,
            device=self.device,
        ).unsqueeze(0)

        return (per_axis * strides).sum(dim=-1)

    def sample_theta_from_bin(self, bin):
        """Draw one parameter vector uniformly inside each given cell.

        Args:
            bin: Integer tensor of flat cell indices, of any shape.

        Returns:
            Tensor of shape ``bin.shape + (parameter_dim,)``.
        """
        thetas = []
        for dim in range(self.parameter_dim):
            # Recover the interval index along this axis from the flat index.
            dim_bin = (
                torch.div(bin, self.bins**dim, rounding_mode="trunc")
            ) % self.bins
            bin_size = (self.UPPER[dim] - self.LOWER[dim]) / self.bins
            theta_lower = self.LOWER[dim] + dim_bin * bin_size

            # Noise has the same shape as ``bin`` so that multi-dimensional
            # sample shapes work, not only flat ones.
            u = torch.rand(bin.shape).to(self.device) * bin_size
            thetas.append(theta_lower + u)

        return torch.stack(thetas, dim=-1)

    def forward(self, theta, x):
        """Evaluate the log-density of ``theta`` given ``x``.

        Args:
            theta: Parameters; last axis has ``parameter_dim`` entries.
            x: Input of ``self.model``; broadcast against ``theta`` on all
                but the last axis.

        Returns:
            Log-density tensor with all singleton dimensions squeezed out
            (so a single pair yields a 0-d tensor).
        """
        theta, x = broadcast(theta, x, ignore=1)
        # Normalise over the cell axis, which is always the last one; using a
        # fixed ``dim=1`` fails for unbatched (1-D) inputs.
        log_probs = F.log_softmax(self.model(x), dim=-1)
        bin_index = self.get_bin(theta)
        if len(x.shape) == 1:
            bin_output = log_probs[bin_index]
        else:
            bin_output = torch.gather(log_probs, -1, bin_index.unsqueeze(-1))

        # Convert cell probability mass into a density.
        return bin_output.squeeze() - self.bin_log_volume

    def sample(self, x, shape):
        """Sample parameters from the estimated posterior for ``x``.

        Args:
            x: Input of ``self.model``. NOTE(review): a batch with more than
                one row cannot be reshaped to ``shape`` below; callers appear
                to pass a single observation.
            shape: Shape of the batch of samples to draw.

        Returns:
            Tensor of shape ``shape + (parameter_dim,)``.
        """
        # Normalise over the cell axis; ``dim=0`` would normalise over the
        # batch axis for an input of shape (1, embedding_dim) and make every
        # cell equally likely.
        bin_probs = F.softmax(self.model(x), dim=-1)
        count = int(np.prod(np.array(shape)))
        theta_bins = torch.multinomial(bin_probs, count, replacement=True).view(
            shape
        )
        return self.sample_theta_from_bin(theta_bins)


class DiscretizedBayesianNPEModel(BayesianNPEModel):
    """``BayesianNPEModel`` whose density estimator is a ``DiscretizedNPE``.

    Attributes such as ``self.device``, ``self.embedding``,
    ``self.parameter_dim``, ``self.embedding_dim`` and ``self.model_path`` are
    read here but not assigned in this class; they are expected to be set up
    by the parent constructor.
    """

    def __init__(self, benchmark, model_path, config, normalization_constants):
        """Build the discretized classifier and wrap it with the embedding.

        Args:
            benchmark: Object providing ``get_classifier_build()`` and
                ``get_domain()``.
            model_path: Forwarded to the parent constructor.
            config: Mapping; ``config["bins"]`` gives the bins per axis.
            normalization_constants: Forwarded to the parent constructor.
        """
        super().__init__(benchmark, model_path, config, normalization_constants)

        self.benchmark = benchmark
        self.config = config
        classifier_build, classifier_kwargs = benchmark.get_classifier_build()
        # The classifier works on normalized parameters, so the domain bounds
        # are normalized before being handed to it.
        lower, upper = benchmark.get_domain()
        lower = self.normalize_parameters(lower)
        upper = self.normalize_parameters(upper)
        self.classifier = DiscretizedNPE(
            self.parameter_dim,
            self.embedding_dim,
            config["bins"],
            (lower, upper),
            self.device,
            classifier_build,
            **classifier_kwargs,
        )
        # This model uses the classifier instead of a flow.
        self.flow = None
        self.model = NPEWithEmbedding(
            self.classifier,
            self.embedding,
            self.normalize_observation,
            self.unnormalize_observation,
            self.normalize_parameters,
            self.unnormalize_parameters,
            self.get_normalization_log_jacobian(),
        )
        self.bnn_prior = self.get_bnn_prior()
        self.bnn_prior.to(self.device)

    def _chain_checkpoint_path(self, prefix, chain_id, index):
        """Return the path ``<model_path>/<prefix>_<chain_id>_<index>.pt``."""
        filename = "{}_{}_{}.pt".format(prefix, chain_id, index)
        return os.path.join(self.model_path, filename)

    def save_model(self, chain_id, index):
        """Write the embedding and classifier weights for one chain sample."""
        torch.save(
            self.embedding.state_dict(),
            self._chain_checkpoint_path("embedding", chain_id, index),
        )
        torch.save(
            self.classifier.state_dict(),
            self._chain_checkpoint_path("classifier", chain_id, index),
        )

    def save(self):
        """Do nothing; weights are written per sample by ``save_model``."""
        pass

    def load_model(self, chain_id, index):
        """Load the embedding and classifier weights for one chain sample.

        Tensors are mapped onto ``self.device`` so that checkpoints written on
        a different device (e.g. a GPU) can be restored here.
        """
        self.embedding.load_state_dict(
            torch.load(
                self._chain_checkpoint_path("embedding", chain_id, index),
                map_location=self.device,
            )
        )
        self.classifier.load_state_dict(
            torch.load(
                self._chain_checkpoint_path("classifier", chain_id, index),
                map_location=self.device,
            )
        )

    def load(self):
        """Do nothing; weights are read per sample by ``load_model``."""
        pass

    def get_np_distribution(self):
        """Return a ``DirichletPrior`` built from the benchmark, config and device."""
        return DirichletPrior(self.benchmark, self.config, self.device)
