from typing import Union, Tuple, List, Optional
import math
from torch import Tensor
from torch_geometric.data import Data
import torch
import torch.nn as nn


class Evidence(nn.Module):
    """Module for computing evidence from logits/log-probabilities.

    The evidence is always an element-wise exponential of the (optionally
    rescaled or shifted) input, optionally multiplied by ``further_scale``.
    """

    def __init__(self, scale='latent-new'):
        """
        Initializes a new evidence layer.

        Parameters
        ----------
        scale: str, default: 'latent-new'
            The scaling to use for evidence calculation. Can be one of:
            'latent', 'latent-new', 'latent-sqrt', 'latent-plus-classes', 'latent-sqrt-plus-classes'.
            An unknown value is only rejected when ``forward`` is called.
        """
        super().__init__()
        self.scale = scale

    def forward(self, log_softmax, dim=None, further_scale=1.0):
        """
        Computes evidence from log-softmax values.

        Parameters
        ----------
        log_softmax: Tensor
            Log-softmax values. For the 'latent-new' scale the input must have
            at least two dimensions; the row maximum is taken over dim 1.
        dim: int, optional
            Only used by the 'latent-sqrt' and 'latent-sqrt-plus-classes'
            scales: the input is divided by ``sqrt(2 * dim)`` before the
            exponential (by ``sqrt(2)`` when ``dim`` is None).
        further_scale: float, default: 1.0
            Multiplicative factor applied only by the '*-plus-classes' scales.

        Returns
        -------
        Tensor
            The computed evidence values, same shape as ``log_softmax``.

        Raises
        ------
        ValueError
            If ``self.scale`` is not one of the supported scales.
        """
        if self.scale == 'latent':
            evidence = torch.exp(log_softmax)

        elif self.scale == 'latent-new':
            max_logits = torch.max(log_softmax, dim=1, keepdim=True)[0]
            # A row consisting solely of -inf would give -inf - (-inf) = NaN;
            # shift such rows by 0 instead so they produce zero evidence.
            max_logits = max_logits.masked_fill(max_logits == float('-inf'), 0.0)
            # stabilized logits, i.e., all values <= 0, i.e., all resulting values in [0, 1]
            stab_logits = log_softmax - max_logits
            evidence = torch.exp(stab_logits)

        elif self.scale == 'latent-plus-classes':
            # `dim` has no effect for this scale.
            evidence = torch.exp(log_softmax) * further_scale

        elif self.scale in ('latent-sqrt', 'latent-sqrt-plus-classes'):
            divisor = math.sqrt(2) if dim is None else math.sqrt(dim * 2)
            evidence = torch.exp(log_softmax / divisor)
            if self.scale == 'latent-sqrt-plus-classes':
                evidence = evidence * further_scale

        else:
            raise ValueError(f"Unknown evidence scale '{self.scale}'")

        return evidence


class Density(nn.Module):
    """
    encapsulates the PostNet step of transforming latent space
    embeddings z into alpha-scores with normalizing flows

    ``forward`` returns a matrix of log-densities with one row per input
    sample and one column per mixture element (``log p(z | c)``).
    """

    def __init__(self,
                 dim_latent: int,
                 num_mixture_elements: int,
                 radial_layers: int = 6,
                 maf_layers: int = 0,
                 gaussian_layers: int = 0,
                 flow_size: float = 0.5,
                 maf_n_hidden: int = 2,
                 flow_batch_norm: bool = False,
                 use_batched_flow: bool = False):
        """
        Parameters
        ----------
        dim_latent: int
            Dimensionality of the latent embeddings z.
        num_mixture_elements: int
            Number of density models / output columns.
        radial_layers: int, default: 6
            Radial flow length (batched flow) or radial layers per
            ``NormalizingFlow``.
        maf_layers: int, default: 0
            MAF layers per ``NormalizingFlow``. If both this and
            ``radial_layers`` are 0, ``MixtureDensity`` models are used.
        gaussian_layers: int, default: 0
            Number of components passed to each ``MixtureDensity``.
        flow_size: float, default: 0.5
            Forwarded to ``NormalizingFlow``.
        maf_n_hidden: int, default: 2
            Forwarded to ``NormalizingFlow`` as ``n_hidden``.
        flow_batch_norm: bool, default: False
            Forwarded to ``NormalizingFlow`` as ``batch_norm``.
        use_batched_flow: bool, default: False
            Use a single ``BatchedNormalizingFlowDensity`` (radial flows)
            instead of one model per mixture element.
        """
        super().__init__()
        self.num_mixture_elements = num_mixture_elements
        self.dim_latent = dim_latent
        self.use_batched_flow = use_batched_flow

        # Per-element flows are used unless the batched flow is requested or
        # no flow layers are configured at all (then Gaussian mixtures are used).
        self.use_flow = (not use_batched_flow) and (maf_layers != 0 or radial_layers != 0)

        if self.use_batched_flow:
            self.flow = BatchedNormalizingFlowDensity(
                c=num_mixture_elements,
                dim=dim_latent,
                flow_length=radial_layers,
                flow_type='radial_flow')

        elif self.use_flow:
            self.flow = nn.ModuleList([
                NormalizingFlow(
                    dim=self.dim_latent,
                    radial_layers=radial_layers,
                    maf_layers=maf_layers,
                    flow_size=flow_size,
                    n_hidden=maf_n_hidden,
                    batch_norm=flow_batch_norm)
                for _ in range(num_mixture_elements)])

        else:
            self.flow = nn.ModuleList([MixtureDensity(
                dim=self.dim_latent,
                n_components=gaussian_layers) for _ in range(num_mixture_elements)])

    def forward(self, z: Tensor) -> Tensor:
        """Return log p(z|c) for every sample and mixture element."""
        if self.use_batched_flow:
            log_q_c = self.forward_batched(z)

        elif self.use_flow:
            log_q_c = self.forward_flow(z)

        else:
            log_q_c = self.forward_mixture(z)

        if not self.training:
            # If we're evaluating and observe a NaN value, this is always caused by the
            # normalizing flow "diverging". We force these values to minus infinity.
            # Out-of-place so no autograd-tracked tensor/view is modified in place.
            log_q_c = log_q_c.masked_fill(torch.isnan(log_q_c), float('-inf'))

        return log_q_c

    def forward_batched(self, z: Tensor) -> Tensor:
        """Log-densities from the batched flow, transposed to (samples, elements).

        NOTE(review): assumes ``self.flow.log_prob`` returns
        (num_mixture_elements, n_samples) — confirm against the flow class.
        """
        return self.flow.log_prob(z).transpose(0, 1)

    @staticmethod
    def _log_prob_standard_normal(x: Tensor) -> Tensor:
        """Log-density of a standard normal N(0, I), reduced over the last axis."""
        d = x.size(-1)
        return -0.5 * (d * math.log(2 * math.pi) + x.pow(2).sum(-1))

    def forward_flow(self, z: Tensor) -> Tensor:
        """Log-densities via change of variables through each per-element flow."""
        n_nodes = z.size(0)
        # Allocate on z's exact device (including its index, e.g. cuda:1).
        log_q = torch.zeros((n_nodes, self.num_mixture_elements), device=z.device)

        for c in range(self.num_mixture_elements):
            out, log_det = self.flow[c](z)
            log_p = self._log_prob_standard_normal(out) + log_det
            log_q[:, c] = log_p

        return log_q

    def forward_mixture(self, z: Tensor) -> Tensor:
        """Log-densities from the per-element mixture density models."""
        n_nodes = z.size(0)
        # Allocate on z's exact device (including its index, e.g. cuda:1).
        log_q = torch.zeros((n_nodes, self.num_mixture_elements), device=z.device)

        for c in range(self.num_mixture_elements):
            log_q[:, c] = self.flow[c](z)

        return log_q
