from typing import Any, Dict, Iterable, Literal, Tuple, Union
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.bernoulli import Bernoulli
import numpy as np
import os
import lightning as L
from .spns import (
    construct_fg,
    construct_ifg,
    SPNFG,
    ISPNFG
)


class BasicFG(torch.nn.Module):
    """
    Class used to generate P(U, V): the gene-to-factor / factor-to-gene and intervention-to-factor matrices.
    This is a simple baseline model.

    Learnable parameters:
        logits: ``(num_vars + num_interventions, num_modules, 2)``. Each (row, factor) pair has a 2-way
            categorical whose channel 0 is read as "edge present". The first ``num_vars`` rows belong to
            nodes; the remaining rows belong to interventions.
        logpy: ``(num_vars, num_modules + 1)``. Each node has a categorical over a split position ``k``.
            Factors ``j < k`` become factor-to-node edges, and factors ``j >= k`` become node-to-factor edges.
    """

    def __init__(
        self,
        num_vars,
        num_interventions,
        num_modules,
        tau=1.0,
        sample_mask=True,
        hard_mask=True,
        **kwargs
    ):
        super().__init__()

        self.tau = tau                  # Gumbel-softmax temperature
        self.sample_mask = sample_mask  # forward(): sample masks (True) or use the distribution itself (False)
        self.hard_mask = hard_mask      # emit binary masks rather than probabilities
        self.num_vars = num_vars
        self.num_interventions = num_interventions
        self.num_modules = num_modules
        # Extra keyword arguments are accepted and ignored.

        self.register_parameter('logits', nn.Parameter(torch.zeros((num_vars + num_interventions, num_modules, 2))))
        self.register_parameter('logpy', nn.Parameter(torch.zeros(num_vars, num_modules + 1)))

    @property
    def numel(self):
        """Total number of learnable scalars (``logits`` plus ``logpy``)."""
        return (self.num_vars + self.num_interventions) * self.num_modules * 2 + self.num_vars * (self.num_modules + 1)

    def _decode(self, mtx: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
        """Turn batched edge categoricals and partition variables into the output matrices.

        Args:
            mtx: ``(bs, num_vars + num_interventions, num_modules, 2)`` edge categoricals (channel 0 = edge).
            y: ``(bs, num_vars, num_modules + 1)`` partition one-hots or probabilities.

        Returns:
            ``(n2m, m2n)``. When interventions exist, the intervention rows are appended to ``n2m``
            along dim 1.
        """
        # e.g. y = [0, 1, 0, 0] -> is_in_m2n = [1, 0, 0, 0]
        is_in_m2n = 1 - torch.cumsum(y, dim=-1)
        edges = mtx[..., 0]
        node_edges = edges[:, :self.num_vars]
        n2m = node_edges * (1 - is_in_m2n[:, :, :-1])
        m2n = node_edges * is_in_m2n[:, :, :-1]
        if self.num_interventions > 0:
            return torch.cat([n2m, edges[:, self.num_vars:]], 1), m2n
        return n2m, m2n

    def forward(self, bs):
        """Samples factor graphs.

        Args:
            bs (int): batch size

        Returns:
            tuple: node-to-factor matrix (with the intervention-to-factor rows concatenated along dim 1
            when interventions exist) and factor-to-node matrix, each with leading batch dimension ``bs``.
        """
        if self.sample_mask:
            return self.sample(bs)

        # Deterministic path. The softmax must run over the 2-way edge axis; an implicit dim
        # would normalise over rows (dim 0) for this 3-D tensor.
        edge_probs = F.softmax(self.logits, dim=-1)
        if self.hard_mask:
            # Mode of each edge categorical, matching fixed_fg's default threshold.
            mtx = (edge_probs > 0.5).to(edge_probs.dtype)
            y = F.one_hot(self.logpy.argmax(dim=-1), num_classes=self.num_modules + 1)
        else:
            mtx = edge_probs
            y = F.softmax(self.logpy, dim=-1)  # e.g. y = [0.1, 0.7, 0.1, 0.1] -> is_in_m2n = [0.9, 0.2, 0.1, 0.0]
        return self._decode(mtx.repeat(bs, 1, 1, 1), y.repeat(bs, 1, 1))

    def reset_parameters(self):
        """Reset all logits to zero (uniform distributions).

        The reset is in place, so optimizers holding these parameters stay valid and the
        current device/dtype is preserved.
        """
        with torch.no_grad():
            self.logits.zero_()
            self.logpy.zero_()

    def kl_loss(self, p_edge: Tensor, p_y: Tensor):
        """Calculates the KL divergence between current PMF and a prior.

        Args:
            p_edge (:class:`torch.tensor`): The target distribution of single edge connection.
            We assume each edge is a Bernoulli random variable.
            p_y (:class:`torch.tensor`): The target distribution of partition variables.

        Returns:
            :class:`torch.tensor`: The KL divergence.
        """
        q_y = F.softmax(self.logpy, dim=-1)
        q_mtx = F.softmax(self.logits, dim=-1)
        kl = torch.sum(q_y * (torch.log(q_y) - torch.log(p_y))) \
            + torch.sum(q_mtx * (torch.log(q_mtx) - torch.log(p_edge)))
        return kl

    def sample(self, bs: int = 1):
        """Randomly sample factor graphs with Gumbel-softmax.

        A single graph is drawn and repeated ``bs`` times along the batch dimension.
        """
        mtx = F.gumbel_softmax(self.logits, tau=self.tau, hard=self.hard_mask).repeat(bs, 1, 1, 1)
        y = F.gumbel_softmax(self.logpy, tau=self.tau, hard=self.hard_mask).repeat(bs, 1, 1)
        return self._decode(mtx, y)

    def fixed_fg(self, thred: float = 0.5, hard: bool = True):
        """Get the factor graph representation of the current distribution.

        Args:
            thred (float, optional): Threshold on P(edge) for determining an edge when ``hard``. Defaults to 0.5.
            hard (bool, optional): Binarise edges and partitions (True) or return probabilities (False).

        Returns:
            tuple: ``(n2m, m2n, i2m)`` without a batch dimension. ``i2m`` is None when there are no interventions.
        """
        with torch.no_grad():
            if hard:
                mtx = (F.softmax(self.logits, dim=-1)[:, :, 0] > thred).int()
                y = F.one_hot(self.logpy.argmax(dim=-1), num_classes=self.num_modules + 1)
            else:
                mtx = F.softmax(self.logits, dim=-1)[:, :, 0]
                y = F.softmax(self.logpy, dim=-1)

            is_in_m2n = 1 - torch.cumsum(y, dim=-1)  # e.g. y = [0, 1, 0, 0] -> is_in_m2n = [1, 0, 0, 0]
            if self.num_interventions > 0:
                n2m = mtx[:self.num_vars, :] * (1 - is_in_m2n[:, :-1])
                m2n = mtx[:self.num_vars, :] * is_in_m2n[:, :-1]
                i2m = mtx[self.num_vars:, :]
                return n2m, m2n, i2m
            n2m = mtx * (1 - is_in_m2n[:, :-1])
            m2n = mtx * is_in_m2n[:, :-1]
        return n2m, m2n, None


class SPNFG(L.LightningModule):
    """
    Class used to generate P(U, V): the gene-to-factor / factor-to-gene and intervention-to-factor matrices.

    The distribution is delegated to ``self.fg``, built by ``construct_ifg`` when
    ``num_interventions > 0`` and by ``construct_fg`` otherwise.
    """
    def __init__(
        self,
        num_vars,
        num_interventions,
        num_modules,
        spn_target='factor',
        max_copies=8,
        tau=1.0,
        p_conn=0.1,
        sparsity_temp=0.1,
        sample_mask=True,
        hard_mask=True
    ):
        super().__init__()
        self.save_hyperparameters()

        self.tau = tau
        self.sample_mask = sample_mask  # forward(): sample (True) or use fg.sample_deterministic (False)
        self.hard_mask = hard_mask      # forwarded to the fg sampling calls
        self.num_vars = num_vars
        self.num_interventions = num_interventions
        self.num_modules = num_modules

        # Shared constructor options for both factor-graph variants.
        fg_kwargs = dict(
            num_nodes=self.num_vars,
            num_factors=self.num_modules,
            max_copies=max_copies,
            spn_target=spn_target,
            tau=tau,
            p_conn=p_conn,
            sparsity_temp=sparsity_temp,
        )
        if self.num_interventions > 0:
            self.fg = construct_ifg(num_interventions=self.num_interventions, **fg_kwargs)
        else:
            self.fg = construct_fg(**fg_kwargs)

    @property
    def numel(self):
        """Number of learnable elements, as reported by the wrapped factor graph."""
        return self.fg.numel

    def forward(self, bs: int):
        """Produce factor-graph masks with a leading batch dimension.

        When ``sample_mask`` is set, ``bs`` graphs come from ``fg.sample``. Otherwise a single
        deterministic graph is returned with batch size 1 (``bs`` is not used on that path).
        Intervention rows, when present, are concatenated onto the node-to-factor matrix.
        """
        if self.sample_mask:
            if self.num_interventions > 0:
                n2m, m2n, i2m = self.fg.sample(bs, self.hard_mask)
                return torch.cat([n2m, i2m], 1), m2n
            n2m, m2n = self.fg.sample(bs, self.hard_mask)
            return n2m, m2n
        if self.num_interventions > 0:
            n2m, m2n, i2m = self.fg.sample_deterministic(self.hard_mask)
            return torch.cat([n2m, i2m], 0).unsqueeze(0), m2n.unsqueeze(0)
        n2m, m2n = self.fg.sample_deterministic(self.hard_mask)
        return n2m.unsqueeze(0), m2n.unsqueeze(0)

    def reset_parameters(self):
        """Delegate to ``fg.ready()``."""
        self.fg.ready()

    def kl_loss(self, p_edge: Tensor, p_y: Tensor):
        """Return ``fg.kl_random()``.

        ``p_edge`` and ``p_y`` are accepted for interface compatibility with the other
        factor-graph generators but are not used here.
        """
        return self.fg.kl_random()

    def sample(self, bs: int = 1):
        """Sample ``bs`` factor graphs via ``fg.sample`` with the hard flag forced to True."""
        if self.num_interventions > 0:
            n2m, m2n, i2m = self.fg.sample(bs, True)
            return torch.cat([n2m, i2m], 1), m2n

        n2m, m2n = self.fg.sample(bs, True)
        return n2m, m2n

    def fixed_fg(self, hard: bool = True) -> Tuple[Tensor, Tensor, Union[Tensor, None]]:
        """Get the (extended) factor graph representation of the current distribution.

        Returns:
            tuple: ``(n2m, m2n, i2m)``; ``i2m`` is None when there are no interventions.
        """
        # This class statement rebinds the module-level name ``SPNFG``, shadowing the factor
        # graph imported from ``.spns``. An ``isinstance(self.fg, SPNFG)`` check would therefore
        # test against this wrapper and always fail. Dispatch on the same flag that selected
        # the constructor in ``__init__`` instead.
        if self.num_interventions > 0:
            n2m, m2n, i2m = self.fg.sample_deterministic(hard=hard)
            return n2m, m2n, i2m
        n2m, m2n = self.fg.sample_deterministic(hard=hard)
        return n2m, m2n, None


class DisentangledFG(torch.nn.Module):
    """
    Class used to generate a f-DAG with disentangled intervention-to-factor matrix.

    Learnable parameters:
        logits: ``(num_vars, num_modules, 2)`` node/factor edge categoricals (channel 0 = edge).
        logits_intv: ``(num_interventions, num_modules, 2)`` intervention/factor edge categoricals,
            or None when there are no interventions.
        logpy: ``(num_vars, num_modules + 1)`` per-node split position. Factors before the split
            become factor-to-node edges; the rest become node-to-factor edges.

    Interventions are processed in blocks of ``block_size`` rows. A sampled intervention edge to a
    factor is kept with probability ``p_repeat ** c``, where ``c`` counts the edges that earlier
    blocks already placed on that factor.
    """
    def __init__(
        self,
        num_vars: int,
        num_interventions: int,
        num_modules: int,
        tau: float = 1.0,
        p_repeat: float = 0.5,
        block_size: int = 1,
        sample_mask: bool = True,
        hard_mask: bool = True,
        **kwargs
    ):
        super().__init__()

        self.tau = tau
        self.register_buffer('p_repeat', torch.tensor([p_repeat]))
        self.block_size = block_size
        self.n_block = -(-num_interventions // block_size)  # ceil(num_interventions / block_size)
        self.sample_mask = sample_mask
        self.hard_mask = hard_mask
        self.num_vars = num_vars
        self.num_interventions = num_interventions
        self.num_modules = num_modules

        self.register_parameter('logits', nn.Parameter(torch.zeros((num_vars, num_modules, 2))))
        if self.num_interventions > 0:
            self.register_parameter('logits_intv', nn.Parameter(torch.zeros((num_interventions, num_modules, 2))))
        else:
            self.logits_intv = None
        self.register_parameter('logpy', nn.Parameter(torch.zeros(num_vars, num_modules + 1)))

    @property
    def numel(self):
        """Total number of learnable scalars."""
        return (self.num_vars + self.num_interventions) * self.num_modules * 2 + self.num_vars * (self.num_modules + 1)

    def _intv_block(
        self,
        block_idx: int,
        bit_count: Tensor,
        sample_mask: bool = True,
        hard_mask: bool = True,
        thred: float = 0.5,
    ) -> Tensor:
        """Produce the intervention-to-factor rows of one block.

        Args:
            block_idx: index of the block of ``block_size`` intervention rows.
            bit_count: per-factor count of edges placed by earlier blocks.
            sample_mask: Gumbel-softmax sample (True) or use the distribution (False).
            hard_mask: binary (True) or probabilistic (False) edges.
            thred: P(edge) threshold used when ``sample_mask`` is False and ``hard_mask`` is True.

        Returns:
            ``(rows_in_block, num_modules)`` edge mask after the repeat-dropout.
        """
        start = block_idx * self.block_size
        end = min(start + self.block_size, self.num_interventions)
        logits = self.logits_intv[start:end]

        if sample_mask:
            prob = F.gumbel_softmax(logits, tau=self.tau, hard=hard_mask)
        elif hard_mask:
            # Compare the two categories; thresholding raw logits ignores channel 1.
            prob = (F.softmax(logits, dim=-1) > thred).int()
        else:
            prob = F.softmax(logits, dim=-1)
        i2m = prob[:, :, 0]

        # Drop out 1's on factors that earlier blocks already hit. Only the mask sampling is
        # gradient-free; the product must stay in the graph so logits_intv receives gradients.
        with torch.no_grad():
            keep = Bernoulli(self.p_repeat.pow(bit_count)).sample([i2m.shape[0]]).to(i2m.device)
        return i2m * keep

    def _forward_i2m(self, bs: int, sample_mask: bool = True, hard_mask: bool = True, thred: float = 0.5) -> Tensor:
        """Assemble the full intervention-to-factor matrix block by block and repeat it ``bs`` times."""
        blocks = []
        # Keep the running count on the same device/dtype as p_repeat so pow() does not mix devices.
        bit_count = torch.zeros(self.num_modules, device=self.p_repeat.device, dtype=self.p_repeat.dtype)
        for block_idx in range(self.n_block):
            block = self._intv_block(block_idx, bit_count, sample_mask, hard_mask, thred)
            blocks.append(block)
            bit_count = bit_count + block.detach().to(bit_count).sum(dim=0)
        return torch.cat(blocks, dim=0).repeat(bs, 1, 1)

    def _decode(self, mtx: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
        """Split batched node edge categoricals by the partition variables into ``(n2m, m2n)``."""
        is_in_m2n = 1 - torch.cumsum(y, dim=-1)  # e.g. y = [0, 1, 0, 0] -> is_in_m2n = [1, 0, 0, 0]
        edges = mtx[:, :, :, 0]
        return edges * (1 - is_in_m2n[:, :, :-1]), edges * is_in_m2n[:, :, :-1]

    def forward(self, bs):
        """Samples factor graphs.

        Args:
            bs (int): batch size

        Returns:
            torch.Tensor: node-to-factor and factor-to-node matrices (the node-to-factor matrix is
            concatenated with the intervention-to-factor matrix along dim 1 when interventions exist)
        """
        if self.sample_mask:
            return self.sample(bs)

        if self.hard_mask:  # binary mask: mode of each categorical
            mtx = (F.softmax(self.logits, dim=-1) > 0.5).int()
            y = F.one_hot(self.logpy.argmax(dim=-1), num_classes=self.num_modules + 1)
        else:  # soft mask: probabilities
            mtx = F.softmax(self.logits, dim=-1)
            y = F.softmax(self.logpy, dim=-1)  # e.g. y = [0.1, 0.7, 0.1, 0.1] -> is_in_m2n = [0.9, 0.2, 0.1, 0.0]
        n2m, m2n = self._decode(mtx.repeat(bs, 1, 1, 1), y.repeat(bs, 1, 1))

        if self.num_interventions > 0:
            i2m = self._forward_i2m(bs, False, self.hard_mask)
            return torch.cat([n2m, i2m], 1), m2n
        return n2m, m2n

    def reset_parameters(self):
        """Reset all logits to zero in place.

        In-place reset keeps optimizer references and device placement valid. It also avoids
        re-registering ``logits_intv`` when it is None, which would raise a KeyError.
        """
        with torch.no_grad():
            self.logits.zero_()
            self.logpy.zero_()
            if self.logits_intv is not None:
                self.logits_intv.zero_()

    def kl_loss(self, p_edge: Tensor, p_y: Tensor):
        """Calculates the KL divergence between current PMF and a prior.

        Args:
            p_edge (:class:`torch.tensor`): The target distribution of single edge connection.
            We assume each edge is a Bernoulli random variable.
            p_y (:class:`torch.tensor`): The target distribution of partition variables.

        Returns:
            :class:`torch.tensor`: The KL divergence.
        """
        # NOTE(review): logits_intv does not contribute to this KL term.
        q_y = F.softmax(self.logpy, dim=-1)
        q_mtx = F.softmax(self.logits, dim=-1)
        kl = torch.sum(q_y * (torch.log(q_y) - torch.log(p_y))) \
            + torch.sum(q_mtx * (torch.log(q_mtx) - torch.log(p_edge)))
        return kl

    def sample(self, bs: int = 1):
        """Randomly sample factor graphs with Gumbel-softmax (one graph repeated ``bs`` times)."""
        mtx = F.gumbel_softmax(self.logits, tau=self.tau, hard=self.hard_mask).repeat(bs, 1, 1, 1)
        y = F.gumbel_softmax(self.logpy, tau=self.tau, hard=self.hard_mask).repeat(bs, 1, 1)
        n2m, m2n = self._decode(mtx, y)

        if self.num_interventions > 0:
            i2m = self._forward_i2m(bs, True, self.hard_mask)
            return torch.cat([n2m, i2m], 1), m2n
        return n2m, m2n

    def fixed_fg(self, thred: float = 0.5, hard: bool = True):
        """Get the factor graph representation of the current distribution.

        Args:
            thred (float, optional): Threshold on P(edge) for determining an edge. Defaults to 0.5.
            hard (bool, optional): Binarise (True) or return probabilities (False).

        Returns:
            tuple: ``(n2m, m2n, i2m)`` without batch dimension. ``i2m`` is None without interventions.
            The repeat-dropout on ``i2m`` is still applied and is random unless ``p_repeat`` is 0 or 1.
        """
        with torch.no_grad():
            if hard:
                mtx = (F.softmax(self.logits, dim=-1)[:, :, 0] > thred).int()
                y = F.one_hot(self.logpy.argmax(dim=-1), num_classes=self.num_modules + 1)
            else:
                mtx = F.softmax(self.logits, dim=-1)[:, :, 0]
                y = F.softmax(self.logpy, dim=-1)

            is_in_m2n = 1 - torch.cumsum(y, dim=-1)  # e.g. y = [0, 1, 0, 0] -> is_in_m2n = [1, 0, 0, 0]
            n2m = mtx * (1 - is_in_m2n[:, :-1])
            m2n = mtx * is_in_m2n[:, :-1]
            if self.num_interventions > 0:
                # Use the distribution itself (sample_mask=False), not a Gumbel sample.
                i2m = self._forward_i2m(1, False, hard, thred).squeeze(0)
                return n2m, m2n, i2m
        return n2m, m2n, None