from nflows.distributions import Distribution, StandardNormal
import torch.distributions as D
import torch
import torch.nn as nn
import torch.nn.functional as F
from nflows.flows.base import Flow

from . import DensityEstimator
from ..utils import batch_or_dataloader
import pdb

class AdaptedMixtureSameFamily:
    """Adapt a ``torch.distributions`` distribution to the nflows base-distribution API.

    nflows calls ``log_prob``/``sample`` with a ``context`` keyword; this wrapper
    accepts that keyword and ignores it, delegating to the wrapped distribution.
    """

    def __init__(self, distribution):
        # Wrapped torch distribution (a D.MixtureSameFamily where used in this file).
        self.distribution = distribution

    def log_prob(self, inputs, context=None):
        """Return ``self.distribution.log_prob(inputs)``; ``context`` is ignored."""
        return self.distribution.log_prob(inputs)

    def sample(self, num_samples, context=None):  # Context to play nice with nflows
        """Draw ``num_samples`` samples stacked along a new leading dimension.

        A single batched ``sample`` call replaces the original per-sample Python
        loop; the result shape ``(num_samples, *batch_shape, *event_shape)`` is
        the same. ``context`` is ignored.
        """
        return self.distribution.sample((num_samples,))

class EmbeddingNet(nn.Module):
    """One-hot encode integer labels into conditioning vectors for the flow."""

    def __init__(self, conditioning_dimension):
        super().__init__()
        # Number of label classes, i.e. the length of each one-hot vector.
        self.conditioning_dimension = conditioning_dimension

    def forward(self, idxs):
        """Return a one-hot float tensor of shape ``(idxs.shape[0], conditioning_dimension)``.

        Args:
            idxs: integer label tensor; presumably 1-D of shape (batch,) -- confirm
                with callers. Row ``i`` of the output gets a 1 at column ``idxs[i]``.

        Returns:
            Tensor on ``idxs.device`` in the default float dtype.
        """
        batch_size = idxs.shape[0]
        # Allocate directly on the label's device rather than on the CPU followed
        # by a copy on every forward pass.
        conditioning_vector = torch.zeros(
            batch_size, self.conditioning_dimension, device=idxs.device
        )
        rows = torch.arange(batch_size, device=idxs.device)
        conditioning_vector[rows, idxs] = 1
        return conditioning_vector

class NormalizingFlow(DensityEstimator):
    """Density estimator backed by an nflows ``Flow``.

    Base distribution, selected by ``base_distribution``:

    * ``None``: ``StandardNormal([dim])``.
    * a name containing ``"mixture_of_gaussians"``: a Gaussian mixture with
      ``num_mixture_components`` components whose means lie on an evenly spaced
      grid (``distribution_mean_spacing`` apart) replicated across all ``dim``
      coordinates. If the name also contains

      - ``"completely_learned"``: weights, means and stds are trainable;
      - ``"learned"`` (but not ``"completely_learned"``): only weights are trainable;
      - neither: weights, means and stds are fixed.

    When ``conditioning`` is not None, integer labels are one-hot encoded by
    ``EmbeddingNet(conditioning_dimension)`` and passed to the flow as context.
    """

    model_type = "nf"

    def __init__(self,
        dim,
        transform,
        base_distribution=None,
        num_mixture_components=0,
        distribution_mean_spacing=1,
        conditioning=None,
        conditioning_dimension=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.transform = transform
        self.base_distribution_name = base_distribution  # TODO: fix clumsy naming
        self.dim = dim

        self.embedding_net = None
        self.conditioning = conditioning
        if conditioning is not None:
            self.embedding_net = EmbeddingNet(conditioning_dimension)

        if base_distribution is None:
            self.base_distribution = StandardNormal([dim])

        elif "mixture_of_gaussians" in base_distribution:
            self.num_mixture_components = num_mixture_components
            means = self._initial_mixture_means(
                num_mixture_components, dim, distribution_mean_spacing
            )
            weights = torch.ones(num_mixture_components)

            # TODO: refactor mixture distribution code
            # "completely_learned" must be tested before "learned" because it
            # contains "learned" as a substring.
            if "completely_learned" in base_distribution:
                self.mixture_weights = nn.Parameter(weights, requires_grad=True)
                self.means = nn.Parameter(means, requires_grad=True)
                self.stds = nn.Parameter(torch.ones_like(means), requires_grad=True)
            elif "learned" in base_distribution:
                self.mixture_weights = nn.Parameter(weights, requires_grad=True)
                # Fixed tensors are non-persistent buffers: they follow the module
                # through .to(device) (replacing the old hard-coded .cuda()) and
                # add no keys to the state_dict.
                self.register_buffer("means", means, persistent=False)
                self.register_buffer("stds", torch.ones_like(means), persistent=False)
            else:
                self.register_buffer("mixture_weights", weights, persistent=False)
                self.register_buffer("means", means, persistent=False)
                self.register_buffer("stds", torch.ones_like(means), persistent=False)

            self.reset_prior()

        else:
            raise NotImplementedError(f"Base distribution {base_distribution} not implemented")

        self._nflow = Flow(
            transform=self.transform,
            distribution=self.base_distribution,
            embedding_net=self.embedding_net
        )

    @staticmethod
    def _initial_mixture_means(num_components, dim, spacing):
        """Return ``num_components`` evenly spaced means, shape ``(num_components, dim)``.

        Row ``i`` has every coordinate equal to ``(i - num_components // 2) * spacing``.
        Building the grid from ``arange(num_components)`` guarantees exactly one row
        per component. The previous ``(-n // 2)`` parsed as ``(-n) // 2`` and produced
        ``n + 1`` rows for odd ``n``.
        """
        offsets = torch.arange(num_components) - num_components // 2
        means = (offsets * spacing).to(torch.float32)
        return means[:, None].repeat(1, dim)

    def reset_prior(self):
        """Rebuild the mixture base distribution and the wrapping ``Flow``.

        The mixture probabilities (softmax of ``mixture_weights``) are evaluated
        when the distribution objects are constructed. They are therefore rebuilt
        on every call so the current parameter values, and a fresh autograd graph,
        are used. This is a no-op for the standard-normal base.
        """
        if self.base_distribution_name is not None and "mixture_of_gaussians" in self.base_distribution_name:
            mix = D.Categorical(F.softmax(self.mixture_weights, dim=0))
            comp = D.Independent(D.Normal(self.means, self.stds), 1)
            self.base_distribution = AdaptedMixtureSameFamily(D.MixtureSameFamily(mix, comp))

            self._nflow = Flow(
                transform=self.transform,
                distribution=self.base_distribution,
                embedding_net=self.embedding_net
            )

    def sample_conditioning(self, n_samples):
        """Draw ``n_samples`` label indices, with replacement, weighted by ``self.conditioning_counts``.

        NOTE(review): ``conditioning_counts`` is not assigned in this class;
        presumably it is set elsewhere (e.g. from training-label frequencies). Confirm.
        """
        return torch.multinomial(self.conditioning_counts, n_samples, replacement=True)

    def sample(self, n_samples):
        """Draw ``n_samples`` samples and map them through ``_inverse_data_transform``.

        In the conditional case, ``n_samples`` labels are drawn with
        ``sample_conditioning`` and exactly one flow sample is drawn per label.
        """
        # TODO: batch in parent class
        # TODO: refactor reset_prior
        self.reset_prior()

        if self.conditioning is None:
            samples = self._nflow.sample(n_samples, context=None)
        else:
            conditioning = self.sample_conditioning(n_samples).to(self.device)
            # With a context, nflows returns shape (context_size, num_samples, ...).
            # Asking for n_samples per label would give n_samples**2 samples, so
            # draw one per label and drop that axis.
            samples = self._nflow.sample(1, context=conditioning)[:, 0]
        return self._inverse_data_transform(samples)

    @batch_or_dataloader(pass_label=True)
    def log_prob(self, x, conditioning=None):
        """Return the flow log-density of ``x``; a 1-D result is reshaped to ``(batch, 1)``.

        Args:
            x: input tensor, or a tuple whose first element is the input and whose
                second is the label. For a conditional model the label is used as
                the conditioning unless ``conditioning`` is given explicitly.
            conditioning: optional label tensor passed to the flow as context.

        NOTE(review): ``batch_or_dataloader`` is defined elsewhere; presumably it
        batches dataloader input and forwards labels. Confirm.
        """
        # TODO: Careful with log probability when using _data_transform()
        self.reset_prior()

        if isinstance(x, tuple):
            x, label = x[0], x[1]
            if conditioning is None and self.conditioning is not None:
                conditioning = label

        x = self._data_transform(x)
        # Errors propagate to the caller; the former bare `except: pdb.set_trace()`
        # hung non-interactive runs and otherwise left `log_prob` unbound.
        log_prob = self._nflow.log_prob(x, context=conditioning)

        if log_prob.dim() == 1:
            log_prob = log_prob.unsqueeze(1)

        return log_prob
