import torch
import torch.nn as nn 
import numpy as np 
import math 
import time 
from models.layers.masks import GumbelAdjacency, GumbelInNOut, GumbelIntervWeight
from models.realnvp import MyNFlows
# per-node Bernoulli logits used; no categorical import needed


class iResBlock(nn.Module):
    """
    ----------------------------------------------------------------------------------------
    The class for a single residual map, i.e., (I - f)(x) = e, followed by per-node
    normalizing flows applied to the residual e.
    ----------------------------------------------------------------------------------------
    The forward method computes the residual map and also the log-det-Jacobian of the map.

    Parameters:
    1) func - (nn.Module) - torch module for modelling the function f in (I - f); used on
                            coordinates the intervention mask marks as observational.
    2) func_i - (nn.Module) - torch module used in place of f on intervened coordinates.
    3) n_power_series - (int/None) - Number of terms used for the log-det-Jac power series
                                     in training; set it to None to use the Russian roulette
                                     estimator.
    4) neumann_grad - (bool) - If True, the Neumann gradient estimator is used in training.
    5) n_dist - (string) - distribution used to sample n when using Russian roulette estimator.
                           'geometric' - geometric distribution.
                           'poisson' - poisson distribution.
    6) lamb - (float) - parameter of poisson distribution.
    7) geom_p - (float) - parameter of geometric distribution (stored internally as a logit).
    8) n_samples - (int) - number of samples to be sampled from n_dist.
    9) grad_in_forward - (bool) - memory-efficient gradient path; not supported by this block
                                  (training with it enabled raises NotImplementedError).
    10) n_exact_terms - (int) - Minimum number of power-series terms in training.
    11) dag_input - (bool) - If True, inputs are centred by the learned `mu` and scaled by
                             exp(`Lambda`) before f / f_i are applied.
    12) lin_logdet - (bool) - If True, the log-det is computed exactly from the weights of
                              `func.layer` / `func_i.layer` (intended for linear f / f_i).
    13) centered - (bool) - stored; not read by this class.
    14) total_exp - (int) - Total number of experimental regimes.
    15) batch_size - (int) - Batch size used during training (stored; not read by this class).
    16) learn_interv - (bool) - If True, the intervention targets will be learned.
    17) tau - (float) - Temperature parameter for Gumbel softmax.
    ----------------------------------------------------------------------------------------
    """

    # Number of exact power-series terms used by the Russian-roulette estimator in eval
    # mode (training uses `n_exact_terms`); mirrors the Residual Flows reference code.
    _N_EXACT_TERMS_EVAL = 20

    def __init__(self, func, func_i, n_power_series, neumann_grad=True, n_dist='geometric', lamb=2., geom_p=0.5, n_samples=1, grad_in_forward=False, n_exact_terms=2, dag_input=False, lin_logdet=False, centered=True, total_exp=None, batch_size=128, learn_interv=True, tau=0.5):
        super(iResBlock, self).__init__()
        self.f = func
        self.f_i = func_i
        # Stored as a logit; mapped back with a sigmoid when sampling.
        self.geom_p = nn.Parameter(torch.tensor(np.log(geom_p) - np.log(1. - geom_p)))
        self.lamb = nn.Parameter(torch.tensor(lamb))
        self.n_dist = n_dist
        self.n_power_series = n_power_series
        self.neumann_grad = neumann_grad
        self.grad_in_forward = grad_in_forward
        self.n_exact_terms = n_exact_terms
        self.n_samples = n_samples
        self.dag_input = dag_input
        self.gumbel_soft_layer = GumbelAdjacency(self.f.n_nodes)
        self.lin_logdet = lin_logdet
        self.centered = centered
        self.total_exp = total_exp
        self.batch_size = batch_size
        self.learn_interv = learn_interv
        self.tau = tau
        self.mu = nn.Parameter(torch.zeros(self.f.n_nodes).float())
        if dag_input:
            self.Lambda = nn.Parameter(torch.zeros(self.f.n_nodes).float())
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        # per-node normalizing flows for observational / interventional data
        self.realnvp_by_node_obs = nn.ModuleDict()
        for d in range(1, self.f.n_nodes + 1):
            self.realnvp_by_node_obs[str(d)] = MyNFlows(1).to(device)
        self.realnvp_by_node_int = nn.ModuleDict()
        for d in range(1, self.f.n_nodes + 1):
            self.realnvp_by_node_int[str(d)] = MyNFlows(1).to(device)

        # learned intervention mask
        self.trained_interv = GumbelIntervWeight(self.f.n_nodes, self.total_exp, tau=self.tau)
        self.last_interv_mask = None
        self.last_interv_indices = None

    def forward(self, x, intervention_mask=None, neumann_grad=True, exp_id=None):
        """Map x to latents through (I - f) and the per-node flows.

        Args:
            x: (B, D) input batch.
            intervention_mask: optional mask broadcastable to x. Entries > 0 route a
                coordinate to the observational flow, entries <= 0 to the interventional
                flow; the mask also linearly mixes f (weight mask) and f_i (weight 1 - mask).
                If None, it is sampled from `trained_interv` (learn_interv=True) or set to
                all ones.
            neumann_grad: use the Neumann gradient estimator in training.
            exp_id: regime index (int or tensor); required when the mask is learned.

        Returns:
            z_full: (B, D) latents after the per-node flows.
            logdetgrad: (B, 1) log-det-Jacobian estimate of the residual map.
            logdet_per_sample: (B,) sum over nodes of the per-node flow log-dets.

        Raises:
            ValueError: if the mask must be learned but exp_id is None.
        """
        self.neumann_grad = neumann_grad

        if intervention_mask is None:
            if self.learn_interv:
                if exp_id is None:
                    raise ValueError(
                        "exp_id is required when learn_interv=True and no intervention_mask is given."
                    )
                if isinstance(exp_id, torch.Tensor):
                    regime_idx = exp_id.detach().to(device='cpu', dtype=torch.long).clone()
                else:
                    regime_idx = torch.tensor(exp_id, dtype=torch.long)
                sampled = self.trained_interv(1, regime=regime_idx)
                # The sampled value is treated as an "intervened" indicator; flip it so
                # that 1 means observational everywhere below.
                intervention_mask = 1 - sampled.to(x.device).type_as(x)
            else:
                intervention_mask = torch.ones(x.shape[0], x.shape[1], device=x.device)
        # Broadcast a single-row (e.g. per-regime) mask to every sample; otherwise the
        # per-node routing below would only ever process row 0.
        intervention_mask = intervention_mask.expand_as(x)

        if self.dag_input:
            scale = torch.exp(self.Lambda)
            Lamb_mat = torch.diag(scale)
            Lamb_mat_inv = torch.diag(1 / scale)
            x_inp = (x - self.mu) @ Lamb_mat
        else:
            x_inp = x
        f_x, f_x_i, logdetgrad, _ = self._logdetgrad(x_inp, intervention_mask)

        # f acts on observational coordinates, f_i on intervened ones.
        g_x = f_x * intervention_mask + f_x_i * (1 - intervention_mask)
        if self.dag_input:
            e = (x - self.mu) - g_x @ Lamb_mat_inv
        else:
            e = x - g_x
        # e stays on x's device: the flows are registered submodules and follow the module,
        # whereas `self.device` is fixed at construction time.

        B, D = e.shape
        z_obs_full = e.new_zeros((B, D))
        z_int_full = e.new_zeros((B, D))
        logdet_obs_per_node = e.new_zeros((B, D))
        logdet_int_per_node = e.new_zeros((B, D))

        for j in range(D):
            col = e[:, j]
            mask_col = intervention_mask[:, j]
            key = str(j + 1)
            routes = (
                ((mask_col > 0).nonzero(as_tuple=True)[0], self.realnvp_by_node_obs,
                 z_obs_full, logdet_obs_per_node),
                ((mask_col <= 0).nonzero(as_tuple=True)[0], self.realnvp_by_node_int,
                 z_int_full, logdet_int_per_node),
            )
            for idx, flows, z_out, ld_out in routes:
                if idx.numel() == 0:
                    continue
                z_j, ld_j = flows[key](col.index_select(0, idx).reshape(-1, 1))
                z_out[idx, j] = z_j.reshape(-1)
                ld_out[idx, j] = ld_j.reshape(-1)

        # combine per-coordinate latents using the mask
        z_full = z_obs_full * intervention_mask + z_int_full * (1.0 - intervention_mask)

        # per-sample scalar logdet (sum over nodes)
        logdet_per_sample = logdet_obs_per_node.sum(dim=1) + logdet_int_per_node.sum(dim=1)

        return z_full, logdetgrad, logdet_per_sample

    def return_adjacency(self):
        """Return the edge probabilities reported by the Gumbel adjacency layer."""
        return self.gumbel_soft_layer.get_proba()

    def predict_from_latent(self, latent_vec, n_iter=10, intervention_set=None, init_provided=False, x_init=None):
        """Recover x from latent noise by iterating x <- f(x - mu) + latent + mu.

        Columns listed in `intervention_set` are clamped to `x_init[:, j] + mu[j]`; the
        remaining (observational) columns are updated for `n_iter` iterations.

        Args:
            latent_vec: latent tensor; its shape is used for the random initial x.
            n_iter: number of fixed-point iterations.
            intervention_set: column indices to clamp; None (or [None]) clamps nothing.
            init_provided: if True, start from `x_init` instead of Gaussian noise.
            x_init: array-like initial values; required if init_provided or when
                intervention_set is non-empty.

        NOTE(review): f is called here without the adjacency argument that `_logdetgrad`
        passes — confirm f supports the one-argument form.
        """
        if intervention_set is None:
            intervention_set = [None]
        if init_provided:
            x = torch.tensor(x_init).float().to(latent_vec.device)
        else:
            x = torch.randn(latent_vec.size(), device=latent_vec.device)
        c = torch.zeros_like(x)
        obs_set = np.setdiff1d(np.arange(x.shape[1]), intervention_set)
        # Diagonal selector that keeps only observational columns.
        U = torch.zeros(x.shape[1], x.shape[1], device=x.device)
        U[obs_set, obs_set] = 1
        if intervention_set[0] is not None:
            c[:, intervention_set] = torch.tensor(x_init[:, intervention_set]).float().to(latent_vec.device)

        for _ in range(n_iter):
            x = self.f(x - self.mu) @ U + latent_vec @ U + c + self.mu

        return x

    def _power_series_schedule(self):
        """Return (n_power_series, coeff_fn) for the power-series log-det estimators.

        In training with a fixed `n_power_series`, the series is simply truncated and
        every coefficient is 1. Otherwise n is drawn `n_samples` times from `n_dist`.
        The series length is max(n) + n_exact. Term k is reweighted by
        (fraction of draws >= k - n_exact) / P(n >= k - n_exact); this is the Russian
        roulette estimator. n_exact is `n_exact_terms` in training and
        `_N_EXACT_TERMS_EVAL` in eval mode.

        Raises:
            ValueError: if sampling is needed and `n_dist` is unknown.
        """
        if self.training and self.n_power_series is not None:
            # Truncated estimation.
            return self.n_power_series, lambda k: 1.

        n_exact = self.n_exact_terms if self.training else self._N_EXACT_TERMS_EVAL
        if self.n_dist == 'geometric':
            geom_p = torch.sigmoid(self.geom_p).item()
            sample_fn = lambda m: geometric_sample(geom_p, m)
            rcdf_fn = lambda k, offset: geometric_1mcdf(geom_p, k, offset)
        elif self.n_dist == 'poisson':
            lamb = self.lamb.item()
            sample_fn = lambda m: poisson_sample(lamb, m)
            rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset)
        else:
            raise ValueError(
                f"Unknown n_dist {self.n_dist!r}; expected 'geometric' or 'poisson'."
            )

        n_samples = sample_fn(self.n_samples)
        n_power_series = max(n_samples) + n_exact

        def coeff_fn(k):
            return 1 / rcdf_fn(k, n_exact) * sum(n_samples >= k - n_exact) / len(n_samples)

        return n_power_series, coeff_fn

    def _linear_logdetgrad(self, batch_size, intervention_mask):
        """Exact per-sample log det(I - W_b) for linear f / f_i.

        Row i of W_b is (mask[b, i] * W_f + (1 - mask[b, i]) * W_fi)[i]. W_f and W_fi are
        the `layer.weight`s of f and f_i with their diagonals (self-loops) zeroed. This
        mirrors the mask mixing of f_x and f_x_i used elsewhere. Returns shape (batch_size,).
        """
        weight = self.f.layer.weight
        weight_i = self.f_i.layer.weight
        no_self_loops = 1 - torch.eye(weight.shape[0], weight.shape[1],
                                      device=weight.device, dtype=weight.dtype)
        obs_rows = intervention_mask.expand(batch_size, -1).unsqueeze(-1).to(
            device=weight.device, dtype=weight.dtype)
        w_b = obs_rows * (no_self_loops * weight) + (1 - obs_rows) * (no_self_loops * weight_i)
        eye = torch.eye(weight.shape[0], device=weight.device, dtype=weight.dtype)
        return torch.log(torch.det(eye - w_b))

    def _logdetgrad(self, x, intervention_mask):
        """Evaluate f and f_i on x and estimate the log-det-Jacobian of the residual map.

        Args:
            x: input to f / f_i; `requires_grad` is switched on in place.
            intervention_mask: mask mixing f (weight mask) and f_i (weight 1 - mask).

        Returns:
            (f_x, f_x_i, logdetgrad reshaped to (-1, 1), seconds spent computing the log-det)

        Raises:
            NotImplementedError: if `grad_in_forward` is requested in training mode.
        """
        with torch.enable_grad():
            if self.training and self.grad_in_forward:
                # The previous memory-efficient path relied on an undefined helper and never
                # produced f_x_i, so it could not run; fail explicitly instead.
                raise NotImplementedError(
                    "grad_in_forward is not supported by iResBlock (requires f_i / intervention mixing)."
                )

            if not self.lin_logdet:
                n_power_series, coeff_fn = self._power_series_schedule()
                if self.training and self.neumann_grad:
                    estimator_fn = neumann_logdet_estimator
                else:
                    estimator_fn = basic_logdet_estimator

            vareps = torch.randn_like(x)
            x = x.requires_grad_(True)
            graph_adj = self.gumbel_soft_layer(x.shape[0])
            f_x = self.f(x, graph_adj)
            f_x_i = self.f_i(x, graph_adj)

            tic = time.perf_counter()
            if self.lin_logdet:
                logdetgrad = self._linear_logdetgrad(x.shape[0], intervention_mask)
            else:
                g_x = f_x * intervention_mask + f_x_i * (1 - intervention_mask)
                logdetgrad = estimator_fn(g_x, x, n_power_series, vareps, coeff_fn, self.training)
            comp_time = time.perf_counter() - tic

        return f_x, f_x_i, logdetgrad.view(-1, 1), comp_time
        
def basic_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
    """Estimate log det(I - J), with J = dg/dx, by a weighted power series.

    Uses log det(I - J) = -sum_{k>=1} tr(J^k) / k. Each trace is estimated as
    v^T J^k v with the probe v = `vareps`, using repeated vector-Jacobian products.
    Term k is multiplied by `coeff_fn(k)`. With `training` True, the products are built
    with `create_graph` so the estimate is differentiable.

    Returns:
        Tensor of shape (batch,): one estimate per sample.
    """
    batch = x.shape[0]
    probe = vareps.view(batch, -1)
    estimate = torch.tensor(0.).to(x)
    vec = vareps
    for power in range(1, n_power_series + 1):
        vec = torch.autograd.grad(g, x, vec, create_graph=training, retain_graph=True)[0]
        trace = torch.sum(vec.view(batch, -1) * probe, 1)
        term = -1 / power * coeff_fn(power) * trace
        estimate = estimate + term
    return estimate


def neumann_logdet_estimator(g, x, n_power_series, vareps, coeff_fn, training):
    """Gradient surrogate for log det(I - J), with J = dg/dx, via a Neumann series.

    The returned VALUE is not an estimate of the log-determinant. Only its gradient
    with respect to the parameters of g is meaningful, because

        d log det(I - J) = -tr((I - J)^{-1} dJ) = -E_v[ v^T (sum_{k>=0} J^k) dJ v ].

    The row vector N = sum_k coeff_fn(k) v^T J^k (with coeff_fn(0) implicitly 1) is built
    under `no_grad` and treated as a constant. The surrogate -(N J) v therefore has
    exactly the gradient above. The previous version built N = v - sum_{k>=1} v J^k and
    returned +(N J) v, which flipped the sign of the leading term of the gradient.

    Args:
        g: output of the map whose Jacobian J is used (must depend on x).
        x: input tensor with requires_grad.
        n_power_series: number of series terms beyond the identity.
        vareps: probe vector(s) shaped like x.
        coeff_fn: weight for term k (1 for truncation, reweighting for Russian roulette).
        training: if True, the final product keeps its graph (create_graph).

    Returns:
        Tensor of shape (batch,): per-sample surrogate.
    """
    vjp = vareps
    neumann_vjp = vareps
    with torch.no_grad():
        for k in range(1, n_power_series + 1):
            vjp = torch.autograd.grad(g, x, vjp, retain_graph=True)[0]
            # (I - J)^{-1} = sum_k J^k: every term enters with a positive sign.
            neumann_vjp = neumann_vjp + coeff_fn(k) * vjp
    vjp_jac = torch.autograd.grad(g, x, neumann_vjp, create_graph=training)[0]
    # Leading minus: d log det(I - J) = -tr((I - J)^{-1} dJ).
    logdetgrad = -torch.sum(vjp_jac.view(x.shape[0], -1) * vareps.view(x.shape[0], -1), 1)
    return logdetgrad

def linear_logdet_estimator(W, bs):
    """Return log det(I - W) for a square matrix W, repeated as a (bs, 1) column.

    The determinant is computed exactly, with no power series. The log is taken of det
    directly, so a negative determinant yields NaN.
    """
    size = W.shape[0]
    log_det = torch.log(torch.det(torch.eye(size, device=W.device) - W))
    column = torch.ones(bs, 1, device=W.device)
    return log_det * column


def geometric_sample(p, n_samples):
    """Draw `n_samples` Geometric(p) integers from NumPy's global RNG."""
    draws = np.random.geometric(p=p, size=n_samples)
    return draws

def geometric_1mcdf(p, k, offset):
    """Return P(n >= k - offset) for n ~ Geometric(p) on {1, 2, ...}.

    When k <= offset the term is always computed exactly, so the probability is 1.
    """
    if k <= offset:
        return 1.
    shifted = k - offset
    return (1 - p)**max(shifted - 1, 0)

def poisson_sample(lamb, n_samples):
    """Draw `n_samples` Poisson(lamb) integers from NumPy's global RNG."""
    draws = np.random.poisson(lam=lamb, size=n_samples)
    return draws

def poisson_1mcdf(lamb, k, offset):
    """Return P(n >= k - offset) for n ~ Poisson(lamb).

    When k <= offset the term is always computed exactly, so the probability is 1.
    Otherwise this is 1 - exp(-lamb) * sum_{i=0}^{k-offset-1} lamb^i / i!.
    """
    if k <= offset:
        return 1.
    shifted = k - offset
    partial = 1.
    for i in range(1, shifted):
        partial += lamb**i / math.factorial(i)
    return 1 - np.exp(-lamb) * partial