# models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
import math
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal

# Select the first CUDA device when torch can see one, otherwise fall back to CPU.
# Both the device string (`dev`) and the resulting torch.device (`device`) are
# module-level names.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

class Torch_MultiHeadAttention(nn.Module):
    """MHA module, intended for attending over sequences of vectors.

      - Compute keys (K), queries (Q), and values (V) as projections of inputs.
      - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)).
      - Output is another projection of WV^T.

    Attention runs over axis -2 of the inputs; every axis before it is treated as
    a batch axis. Queries and keys/values may have different sequence lengths
    (cross-attention), as long as their batch axes agree.

    Args:
        num_heads: number of attention heads (H).
        key_size: per-head size of queries and keys.
        model_size: feature size of the inputs and of the output; falls back to
            ``key_size * num_heads`` when falsy.
        with_bias: whether the linear projections carry a bias term.
        value_size: per-head size of the values; defaults to ``key_size``.
        init_func: optional callable invoked as ``init_func(module, scale=w_init_scale)``
            on every projection right after it is created.
        w_init_scale: forwarded to ``init_func`` as ``scale``.
    """

    def __init__(self, num_heads=8, key_size=32, model_size=128, with_bias=True,
                 value_size=None, init_func=None, w_init_scale=None):
        super(Torch_MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.key_size = key_size
        self.value_size = value_size or key_size
        self.model_size = model_size or key_size * num_heads
        self.init_func = init_func
        self.w_init_scale = w_init_scale

        qk_width = self.num_heads * self.key_size
        # Values use their own per-head size (equal to key_size unless value_size is given).
        v_width = self.num_heads * self.value_size

        self.linear_projection_q = nn.Linear(self.model_size, qk_width, bias=with_bias)
        self._init_weights(self.linear_projection_q)

        self.linear_projection_k = nn.Linear(self.model_size, qk_width, bias=with_bias)
        self._init_weights(self.linear_projection_k)

        self.linear_projection_v = nn.Linear(self.model_size, v_width, bias=with_bias)
        self._init_weights(self.linear_projection_v)

        self.final_projection = nn.Linear(v_width, self.model_size, bias=with_bias)
        self._init_weights(self.final_projection)

    def _init_weights(self, module):
        """Apply the user-supplied initializer to ``module`` (no-op when none was given)."""
        if self.init_func is not None:
            self.init_func(module, scale=self.w_init_scale)

    def _split_heads(self, projected, head_size):
        """Reshape ``[..., T, H * head_size]`` into ``[..., T, H, head_size]``."""
        return projected.reshape(*projected.shape[:-1], self.num_heads, head_size)

    def forward(self, query, key, value, layer, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: ``[..., T_q, model_size]``.
            key, value: ``[..., T_k, model_size]``.
            layer: unused; kept so existing call sites keep working.
            mask: optional boolean tensor broadcastable to ``[..., H, T_q, T_k]``
                with the same number of dimensions as the logits; ``False`` entries
                are excluded from attention.

        Returns:
            ``[..., T_q, model_size]``.

        Raises:
            ValueError: if ``mask.ndim`` differs from the logits' dimensionality.
        """
        # Each input is split with its OWN leading shape, so T_k may differ from T_q.
        query_heads = self._split_heads(self.linear_projection_q(query), self.key_size)
        key_heads = self._split_heads(self.linear_projection_k(key), self.key_size)
        value_heads = self._split_heads(self.linear_projection_v(value), self.value_size)

        attn_logits = torch.einsum("...thd,...Thd->...htT", query_heads, key_heads)
        attn_logits = attn_logits / np.sqrt(self.key_size)

        if mask is not None:
            if mask.ndim != attn_logits.ndim:
                raise ValueError(
                    f"Mask dimensionality {mask.ndim} must match logits dimensionality "
                    f"{attn_logits.ndim}."
                )
            # Most negative finite value of the logits' dtype: a literal like -1e30
            # overflows to -inf in half precision.
            attn_logits = torch.where(mask, attn_logits, torch.finfo(attn_logits.dtype).min)

        attn_weights = F.softmax(attn_logits, dim=-1)
        attn = torch.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
        attn = attn.reshape(*attn.shape[:-2], -1)  # [..., T_q, H*V]

        return self.final_projection(attn)

class policy_opt(nn.Module):
    """Transformer policy that proposes an intervention target and an intervention value.

    ``x`` is embedded by a linear layer and processed by ``2 * layers``
    pre-LayerNorm blocks (self-attention then feed-forward, each added back
    residually with dropout). Axes -3 and -2 are swapped after every block, so
    consecutive blocks attend along alternating axes; the even block count
    restores the original axis order. The result is layer-normalized and
    max-pooled over axis -3.

    Two heads read the pooled representation:
      * ``opt_i``: a Gumbel-softmax sample over the last axis of the target
        logits -- soft while training, hard one-hot otherwise -- with temperature
        ``max(5 * 0.9995 ** step, 0.1)``.
      * ``opt_s``: a tanh squashed into ``(min_val, max_val)``.

    Args:
        layers: number of block *pairs*; ``2 * layers`` blocks are built.
        dim: embedding width.
        key_size, num_heads: attention configuration.
        widening_factor: hidden-width multiplier of the feed-forward sublayers.
        dropout: residual dropout probability used when ``is_training`` is true.
        out_dim: stored for interface parity; not used by this class.
        max_val, min_val: output range of ``opt_s``.
        input_shape: size of the last axis of ``x``.
        w_init_scale, init_func: optional initializer hook, called as
            ``init_func(module, scale=w_init_scale)``.
        num_steps: stored; not used by this class.
    """

    def __init__(self, layers=8, dim=128, key_size=16, num_heads=8, widening_factor=4,
                 dropout=0.05, out_dim=None, max_val=10, min_val=-10,
                 input_shape=2, w_init_scale=1, init_func=None, num_steps=1100):
        super(policy_opt, self).__init__()
        # NOTE: submodules are created in a fixed order; changing it would change both
        # the random initialization under a fixed seed and the state_dict layout.
        self.dim = dim
        self.out_dim = out_dim or dim
        self.layers = 2 * layers
        self.dropout = dropout

        self.widening_factor = widening_factor
        self.num_heads = num_heads
        self.key_size = key_size

        self.w_init_scale = w_init_scale
        self.init_func = init_func

        self.input_shape = input_shape
        self.linears_1 = nn.Linear(self.input_shape, self.dim, bias=True)
        self._init_weights(self.linears_1)

        self.layernorm_qs = nn.ModuleList([
            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])
        self.layernorm_ks = nn.ModuleList([
            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])
        self.layernorm_vs = nn.ModuleList([
            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])

        self.layernorm_ffns_0 = nn.ModuleList([
            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])

        self.attention_list = nn.ModuleList([
            Torch_MultiHeadAttention(num_heads=self.num_heads,
                                     key_size=self.key_size,
                                     model_size=self.dim,
                                     init_func=self.init_func,
                                     w_init_scale=self.w_init_scale)
            for _ in range(self.layers)])

        self.ffn_sequential_list = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.dim, self.widening_factor * self.dim, bias=True),
                nn.ReLU(),
                nn.Linear(self.widening_factor * self.dim, self.dim, bias=True))
            for _ in range(self.layers)])
        self._init_weights(self.ffn_sequential_list)

        self.layernorm_ffn_1 = nn.LayerNorm(self.dim, eps=1e-05)

        self.target_layer = nn.Linear(self.dim, 1, bias=True)
        self.state_layer = nn.Linear(self.dim, 1, bias=True)

        self.min_val = min_val
        self.max_val = max_val

        # Learnable affine on the target logits (from linear_causal_discovery_policy_train.py).
        self.scale_target = nn.Parameter(torch.ones(1))
        self.bias_target = nn.Parameter(torch.zeros(1))
        # Registered but not used by forward(), which applies no affine to z_state.
        self.scale_state = nn.Parameter(torch.ones(1))
        self.bias_state = nn.Parameter(torch.zeros(1))

        self.num_steps = num_steps  # from linear_causal_discovery_policy_train.py
        self.step = 0  # from nonlinear_causal_discovery_policy_train.py; forward uses its `step` argument

    def _init_weights(self, module):
        """Apply the user-supplied initializer to ``module`` (no-op when none was given)."""
        if self.init_func is not None:
            self.init_func(module, scale=self.w_init_scale)

    def forward(self, x, is_training: bool, step):
        """Return ``(opt_i, opt_s)``.

        Args:
            x: tensor whose last axis has size ``input_shape``; axes -3 and -2 are
                the two attended axes and axis -3 is max-pooled away.
            is_training: enables residual dropout and selects soft (True) or
                hard (False) Gumbel-softmax sampling.
            step: training step used for the Gumbel temperature schedule.
        """
        # Dropout probability is 0 outside training, which makes F.dropout the identity.
        dropout_rate = self.dropout if is_training else 0.0

        z = self.linears_1(x)

        for i in range(self.layers):
            q_in = self.layernorm_qs[i](z)
            k_in = self.layernorm_ks[i](z)
            v_in = self.layernorm_vs[i](z)

            z_attn = self.attention_list[i](q_in, k_in, v_in, i)
            z = z + F.dropout(z_attn, p=dropout_rate)

            z_in = self.layernorm_ffns_0[i](z)
            z_ffn = self.ffn_sequential_list[i](z_in)
            z = z + F.dropout(z_ffn, p=dropout_rate)
            # Alternate the attended axis for the next block.
            z = torch.swapaxes(z, -3, -2)

        z = self.layernorm_ffn_1(z)

        z, _ = torch.max(z, dim=-3)

        # Target logits with the learnable affine (linear_causal_discovery_policy_train.py).
        z_target = self.target_layer(z).squeeze(-1) * self.scale_target + self.bias_target

        # Exponentially annealed temperature, floored at 0.1 (nonlinear_causal_discovery_policy_train.py).
        tau = max(5 * 0.9995 ** (step), 0.1)

        opt_i = F.gumbel_softmax(z_target, tau=tau, hard=not is_training)

        # z_state from linear_causal_discovery_policy_train.py (without scale and bias).
        z_state = self.state_layer(z).squeeze(-1)

        half_range = (self.max_val - self.min_val) / 2
        midpoint = (self.max_val + self.min_val) / 2
        opt_s = half_range * torch.tanh(z_state) + midpoint

        return opt_i, opt_s

class BaseModel_torch(nn.Module):
    """Transformer that scores every ordered variable pair and returns a BCE loss.

    ``x`` is embedded by a linear layer and processed by ``2 * layers``
    pre-LayerNorm blocks (self-attention then feed-forward, each added back
    residually with dropout). Axes -3 and -2 are swapped after every block, and the
    even block count restores the original axis order. After a final LayerNorm, axis -3 is
    max-pooled. Two heads produce unit-norm embeddings ``u_i`` and ``v_j``, and the
    pair logit is ``exp(temp) * <u_i, v_j> + final_matrix_bias``.

    The loss is the element-wise binary cross-entropy between ``sigmoid(logit_ij)``
    and ``y``. It is averaged over the off-diagonal entries when ``mask_diag`` is true
    (default), otherwise over all entries, then averaged over the remaining
    leading axes.

    Args:
        layers: number of block *pairs*; ``2 * layers`` blocks are built.
        dim: embedding width; out_dim: width of the u/v embeddings (defaults to dim).
        key_size, num_heads, widening_factor: attention / feed-forward configuration.
        dropout: residual dropout probability used when ``is_training`` is true.
        logit_bias_init, cosine_temp_init: initial values of the learnable bias and
            log-temperature of the pair logits.
        ln_axis, name: stored / accepted for interface parity; not otherwise used.
        input_shape: size of the last axis of ``x``.
        w_init_scale, init_func: optional initializer hook, called as
            ``init_func(module, scale=w_init_scale)``.
    """

    def __init__(self, layers=8, dim=128, key_size=16, num_heads=8, widening_factor=4,
                 dropout=0.05, out_dim=None, logit_bias_init=-3.0, cosine_temp_init=2.0,  # dropout from nonlinear
                 ln_axis=-1, name="BaseModel", input_shape=2, w_init_scale=1, init_func=None):
        super(BaseModel_torch, self).__init__()
        # NOTE: submodules are created in a fixed order; changing it would change both
        # the random initialization under a fixed seed and the state_dict layout.
        self.dim = dim
        self.out_dim = out_dim or dim
        self.layers = 2 * layers
        self.dropout = dropout
        self.ln_axis = ln_axis
        self.widening_factor = widening_factor
        self.num_heads = num_heads
        self.key_size = key_size
        self.logit_bias_init = logit_bias_init
        self.cosine_temp_init = cosine_temp_init
        self.w_init_scale = w_init_scale

        self.init_func = init_func

        self.input_shape = input_shape
        self.linears_1 = nn.Linear(self.input_shape, self.dim, bias=True)
        self._init_weights(self.linears_1)

        self.layernorm_qs = nn.ModuleList([
            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])
        self.layernorm_ks = nn.ModuleList([
            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])
        self.layernorm_vs = nn.ModuleList([
            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])

        self.layernorm_ffns_0 = nn.ModuleList([
            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])

        self.attention_list = nn.ModuleList([
            Torch_MultiHeadAttention(num_heads=self.num_heads,
                                     key_size=self.key_size,
                                     model_size=self.dim,
                                     init_func=self.init_func,
                                     w_init_scale=self.w_init_scale)
            for _ in range(self.layers)])

        self.ffn_sequential_list = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.dim, self.widening_factor * self.dim, bias=True),
                nn.ReLU(),
                nn.Linear(self.widening_factor * self.dim, self.dim, bias=True))
            for _ in range(self.layers)])
        self._init_weights(self.ffn_sequential_list)

        self.layernorm_ffn_1 = nn.LayerNorm(self.dim, eps=1e-05)

        self.sequential_u = nn.Sequential(nn.LayerNorm(self.dim, eps=1e-05),
                                          nn.Linear(self.dim, self.out_dim, bias=True))
        self._init_weights(self.sequential_u)

        self.sequential_v = nn.Sequential(nn.LayerNorm(self.dim, eps=1e-05),
                                          nn.Linear(self.dim, self.out_dim, bias=True))
        self._init_weights(self.sequential_v)

        # Learnable log-temperature and bias of the cosine-similarity logits.
        self.temp = nn.Parameter(torch.tensor(self.cosine_temp_init))
        self.final_matrix_bias = nn.Parameter(torch.tensor(self.logit_bias_init))

        # When True, diagonal (self-pair) entries are excluded from the loss.
        self.mask_diag = True

    def _init_weights(self, module):
        """Apply the user-supplied initializer to ``module`` (no-op when none was given)."""
        if self.init_func is not None:
            self.init_func(module, scale=self.w_init_scale)

    def forward(self, x, y, is_training: bool):
        """Compute the pairwise BCE loss.

        Args:
            x: tensor whose last axis has size ``input_shape``. Axis -3 is max-pooled
                away and axis -2 indexes the rows/columns of the logit matrix
                (presumably observations and variables respectively -- confirm with callers).
            y: targets broadcastable against the ``[..., n, n]`` logit matrix, where
                ``n = x.shape[-2]``; ``y.shape[-1]`` is used as the variable count.
            is_training: enables residual dropout.

        Returns:
            ``(loss, logp1, x, y)`` where ``loss`` is a scalar and
            ``logp1 = logsigmoid(logit_ij)``.
        """
        # Dropout probability is 0 outside training, which makes F.dropout the identity.
        dropout_rate = self.dropout if is_training else 0.0

        z = self.linears_1(x)

        for i in range(self.layers):
            q_in = self.layernorm_qs[i](z)
            k_in = self.layernorm_ks[i](z)
            v_in = self.layernorm_vs[i](z)

            z_attn = self.attention_list[i](q_in, k_in, v_in, i)
            z = z + F.dropout(z_attn, p=dropout_rate)

            z_in = self.layernorm_ffns_0[i](z)
            z_ffn = self.ffn_sequential_list[i](z_in)
            z = z + F.dropout(z_ffn, p=dropout_rate)
            # Alternate the attended axis for the next block.
            z = torch.swapaxes(z, -3, -2)

        z = self.layernorm_ffn_1(z)

        assert z.shape[-2] == x.shape[-2] and z.shape[-3] == x.shape[-3], "Do we have an odd number of layers?"
        z, _ = torch.max(z, dim=-3)

        u = self.sequential_u(z)
        v = self.sequential_v(z)
        u = u / torch.linalg.norm(u, ord=2, dim=-1, keepdim=True)
        v = v / torch.linalg.norm(v, ord=2, dim=-1, keepdim=True)

        # Scaled cosine similarity between every (i, j) pair plus a shared bias.
        logit_ij = torch.einsum("...id,...jd->...ij", u, v)
        logit_ij = logit_ij * torch.exp(self.temp)
        logit_ij = logit_ij + self.final_matrix_bias

        assert logit_ij.shape[-1] == x.shape[-2] and logit_ij.shape[-2] == x.shape[-2]

        logp1 = F.logsigmoid(logit_ij)
        logp0 = F.logsigmoid(-logit_ij)

        loss_eltwise = - (y * logp1 + (1 - y) * logp0)

        n_vars = y.shape[-1]

        if self.mask_diag:
            # Zero the diagonal with a broadcast mask: works for any number of leading
            # axes and does not mutate an autograd intermediate in place.
            diag = torch.eye(n_vars, dtype=torch.bool, device=loss_eltwise.device)
            batch_loss = loss_eltwise.masked_fill(diag, 0.0).sum((-1, -2)) / (n_vars * (n_vars - 1))
        else:
            batch_loss = loss_eltwise.sum((-1, -2)) / (n_vars * n_vars)

        # Acyclicity penalty was not applied in nonlinear, maintaining that.
        loss = batch_loss.mean()

        return loss, logp1, x, y

class Pyro_NN(PyroModule):
    """Bayesian MLP ``in_dim -> hid_dim_1 -> hid_dim_2 -> out_dim`` with ReLU activations.

    Every weight matrix and bias vector is given an i.i.d. ``Normal(0, prior_scale)``
    prior through ``PyroSample``. The input is multiplied element-wise by
    ``x_parents`` before the first layer. Observations are modelled as
    ``Normal(mu, sqrt(0.1))`` inside a ``"data"`` plate over the leading axis.
    """

    def __init__(self, in_dim=10, out_dim=1, hid_dim_1=8, hid_dim_2 =8, prior_scale=1.):
        super().__init__()

        # ReLU is fixed here (in the nonlinear variant it was passed in by SequentialGraph).
        self.activation = nn.ReLU()

        # (fan_in, fan_out) of layer1, layer2 and layer3, in that order.
        shapes = ((in_dim, hid_dim_1), (hid_dim_1, hid_dim_2), (hid_dim_2, out_dim))
        self.layer1, self.layer2, self.layer3 = [
            PyroModule[nn.Linear](fan_in, fan_out) for fan_in, fan_out in shapes
        ]

        # Replace each deterministic weight/bias with a Normal prior of matching shape.
        for layer, (fan_in, fan_out) in zip((self.layer1, self.layer2, self.layer3), shapes):
            weight_prior = dist.Normal(0., prior_scale).expand([fan_out, fan_in]).to_event(2)
            bias_prior = dist.Normal(0., prior_scale).expand([fan_out]).to_event(1)
            layer.weight = PyroSample(weight_prior)
            layer.bias = PyroSample(bias_prior)

    def forward(self, x, x_parents, y=None):
        """Return the predicted mean; conditions on ``y`` when it is given."""
        masked = x * x_parents  # parent mask applied element-wise (presumably 0/1 entries)
        h1 = self.activation(self.layer1(masked))
        h2 = self.activation(self.layer2(h1))
        mu = self.layer3(h2).squeeze(-1)  # squeeze the trailing axis (no-op unless it has size 1)

        # Noise std of 0.1 ** 0.5 carried over from the nonlinear variant.
        with pyro.plate("data", h2.shape[0]):
            pyro.sample("obs", dist.Normal(mu, 0.1**0.5), obs=y)

        return mu

class Post_sampling(nn.Module):
    """Per-node variational inference over Bayesian-MLP mechanisms.

    For node ``j``, ``simulate_parents`` draws ``num_simulation`` candidate parent
    masks from independent ``Bernoulli(post_prob[:, j])`` variables and tallies the
    distinct masks. ``infer_node`` then fits, for every distinct mask, a ``Pyro_NN``
    regressing ``x_obs[:, j]`` on the masked ``x_obs`` with SVI and a mean-field
    ``AutoDiagonalNormal`` guide. It stores the guide's per-layer posterior means and
    standard deviations in ``posterior_weights[j]``, keyed by the mask tuple.

    Args:
        n_nodes: number of variables (columns of ``x_obs``).
        post_prob: edge probabilities indexed as ``post_prob[:, node_id]``
            (presumably a ``[n_nodes, n_nodes]`` torch tensor -- confirm with callers).
        num_simulation: number of Bernoulli draws per node.
        x_obs: observations, indexed as ``x_obs[:, node_id]``.
        learning_rate, num_epochs: Adam learning rate and number of SVI steps.
        hidden_1, hidden_2: hidden widths of the per-node ``Pyro_NN``.
        prior_scale: std of the Normal weight priors.
        warm_samples, post_samples, split_ratio: stored; not used by this class.
    """

    def __init__(self, n_nodes, post_prob, num_simulation, x_obs,
                       learning_rate = 5e-2, num_epochs = 30000,
                       warm_samples = 200, post_samples = 200, hidden_1 = 8, hidden_2 = 8,
                       prior_scale = 1, split_ratio = 4/5):

        super().__init__()

        self.x_obs = x_obs
        self.n_nodes = n_nodes
        self.post_prob = post_prob
        self.posterior_weights = [None] * n_nodes
        self.warm_samples = warm_samples
        self.post_samples = post_samples
        self.prior_scale = prior_scale
        self.num_simulation = num_simulation

        self.hidden_1 = hidden_1
        self.hidden_2 = hidden_2

        self.split_ratio = split_ratio

        self.learning_rate = learning_rate
        self.num_epochs = num_epochs

        self.graph_probabilities = None
        self.nodes_unique_parents = [self.simulate_parents(i) for i in range(self.n_nodes)]

    def simulate_parents(self, node_id, seed = 547467524):
        """Sample candidate parent masks for ``node_id`` and rank them by frequency.

        Reseeds torch's global RNG with ``seed`` (so every node starts from the same
        RNG state), then draws ``num_simulation`` masks from
        ``Bernoulli(post_prob[:, node_id])``.

        Returns:
            dict ``{mask_tuple: {'count': int, 'probability': float}}`` ordered by
            decreasing count; ties keep first-seen order. Mask tuples hold the sampled
            float values.
        """
        torch.manual_seed(seed)
        bernoulli_dist = torch.distributions.Bernoulli(self.post_prob[:, node_id])

        counts = {}
        for _ in range(self.num_simulation):
            mask = tuple(bernoulli_dist.sample().tolist())
            counts[mask] = counts.get(mask, 0) + 1

        # sorted() is stable, so equally frequent masks keep their first-seen order.
        ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
        return {
            mask: {'count': count, 'probability': count / self.num_simulation}
            for mask, count in ranked
        }

    def _posterior_layout(self):
        """Return ``(name, shape)`` for each latent block of the flattened posterior.

        Assumes the guide flattens the model's sample sites as layer1 weight,
        layer1 bias, layer2 weight, layer2 bias, layer3 weight, layer3 bias, with a
        single output unit.
        """
        n, h1, h2 = self.n_nodes, self.hidden_1, self.hidden_2
        return [
            ('layer1_weights', (h1, n)),
            ('layer1_bias', (h1,)),
            ('layer2_weights', (h2, h1)),
            ('layer2_bias', (h2,)),
            ('layer3_weights', (1, h2)),
            ('layer3_bias', (1,)),
        ]

    def _unflatten_posterior(self, flat, suffix):
        """Split the flat vector ``flat`` into per-layer tensors keyed ``'<name>_<suffix>'``."""
        blocks = {}
        offset = 0
        for name, shape in self._posterior_layout():
            size = math.prod(shape)
            blocks[f'{name}_{suffix}'] = flat[offset: offset + size].reshape(shape)
            offset += size
        return blocks

    def perform_inference(self, node_id, model, parent_nodes, check_convergence = True):
        """Fit a mean-field Gaussian guide to ``model`` and return its per-layer moments.

        Seeds Pyro's RNG with 42 and clears the global Pyro param store. It then runs
        ``num_epochs`` SVI steps, using Adam wrapped in an ExponentialLR schedule with
        gamma 0.9995 that is stepped once per epoch. Each step passes ``x_obs`` as the
        input, ``parent_nodes`` as the input mask and ``x_obs[:, node_id]`` as the target.

        Returns:
            dict with keys ``'layer{1,2,3}_{weights,bias}_{mean,scale}'`` holding the
            guide's posterior means and standard deviations reshaped per layer.

        ``check_convergence`` is accepted for interface compatibility but unused.
        """
        pyro.set_rng_seed(42)
        mean_field_guide = AutoDiagonalNormal(model)

        scheduler = pyro.optim.ExponentialLR({
                'optimizer': torch.optim.Adam,  # Use the optimizer class (not instance)
                'optim_args': {'lr': self.learning_rate},  # Optimizer arguments
                'gamma': 0.9995  # Decay rate
            })
        svi = SVI(model, mean_field_guide, scheduler, loss=Trace_ELBO())

        # Presumably clears leftovers from earlier fits so they do not leak into this one.
        pyro.clear_param_store()

        target = self.x_obs[:, node_id]
        for _ in range(self.num_epochs):
            svi.step(x=self.x_obs, x_parents=parent_nodes, y=target)
            scheduler.step()

        posterior = mean_field_guide.get_posterior()
        posterior_samples = self._unflatten_posterior(posterior.mean, 'mean')
        posterior_samples.update(self._unflatten_posterior(posterior.stddev, 'scale'))
        return posterior_samples

    def infer_node(self, node_id):
        """Fit one posterior per distinct parent mask of ``node_id``.

        Stores ``{mask_tuple: posterior_dict}`` in ``posterior_weights[node_id]``.
        """
        post_samples_id = {}

        for idx, parents in enumerate(self.nodes_unique_parents[node_id]):
            print('-----parent_structure------', idx, parents )
            # Hidden widths must match the layout used to unflatten the posterior.
            model = Pyro_NN(in_dim=self.n_nodes, out_dim=1,
                            hid_dim_1=self.hidden_1, hid_dim_2=self.hidden_2,
                            prior_scale=self.prior_scale)
            post_samples_id[parents] = self.perform_inference(
                node_id=node_id, model=model, parent_nodes=torch.tensor(list(parents)))
        self.posterior_weights[node_id] = post_samples_id
    
class FCNN(nn.Module):
    """
    Simple fully connected neural network.

    Layout: ``Linear(in_dim, nodes[0]) -> act -> ... -> Linear(nodes[-1], out_dim)``,
    i.e. one Linear + activation pair per hidden width in ``nodes``, followed by an
    output Linear. With an empty ``nodes`` the network is a single
    ``Linear(in_dim, out_dim)``.

    Parameters
    ----------
    in_dim : int
        Input feature size.
    out_dim : int
        Output feature size.
    nodes : sequence of int
        Hidden layer widths.
    act : torch.nn.Module
        Activation module; the same instance is reused after every hidden layer.
    """
    def __init__(self, in_dim, out_dim, nodes, act):
        super().__init__()

        nodes = list(nodes)
        self.layers = len(nodes)

        # Consecutive (fan_in, fan_out) pairs for the hidden layers.
        widths = [in_dim, *nodes]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(act)

        modules.append(nn.Linear(widths[-1], out_dim))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        return self.network(x)

class cINN(nn.Module):
    """
    The class for conditional Invertible Neural Network (one affine coupling block).

    Parameters
    ----------
    n_theta: int
        Dimension of parameter space
    n_input : int
        Equals to the number of experiments  * (the dimension of observation + the dimension of design space)
    nodes_s; nodes_t : list
        The hidden layers and the activation function in s and t networks
    act: torch.nn
        Activation functions
    Split1, Split2: list
        List of indexes that partition pois into two vectors of similar sizes.
        If either contains None, the block runs in "unsplit" mode: a single
        scale/shift pair (one output column each), conditioned only on the
        inputs, is broadcast over the selected parameters (all of them when
        Split1 is None).
    alpha: int
        Limit the value in exp(s()) not greater than alpha, avoid overflow
        (s is squashed smoothly into (-alpha, alpha) with a scaled arctan).

    Methods
    -------
    forward()
        Produces the transformed parameter in Gaussian space and the log-determinant along the transformation.
        In split mode the output columns are ordered [Split1 part, Split2 part].
    inverse()
        Maps transformed parameters back; in split mode the result is returned in
        the original parameter order.
    """
    def __init__(self, n_theta, n_input,
                    nodes_s, nodes_t, act,  # the hidden layers and the activation function in s and t networks
                    Split1, Split2,         # Split: list of indexes that partition pois into two vectors of similar sizes
                    network=FCNN,
                    alpha = None
                    ):
        super().__init__()
        self.theta_dim = n_theta
        self.input_dim = n_input
        self.Split1 = Split1
        self.Split2 = Split2

        # `== None` (not `is None`) on purpose: it also catches numpy object arrays holding None.
        if np.any(self.Split1 == None) or np.any(self.Split2 == None):  # noqa: E711
            self.s1 = network(self.input_dim, 1, nodes_s, act).to(device)
            self.t1 = network(self.input_dim, 1, nodes_t, act).to(device)
            self.s2 = None
            self.t2 = None
        else:
            self.s1 = network(self.input_dim + len(self.Split2), len(self.Split1), nodes_s, act).to(device)
            self.t1 = network(self.input_dim + len(self.Split2), len(self.Split1), nodes_t, act).to(device)
            self.s2 = network(self.input_dim + len(self.Split1), len(self.Split2), nodes_s, act).to(device)
            self.t2 = network(self.input_dim + len(self.Split1), len(self.Split2), nodes_t, act).to(device)

        self.alpha = alpha

    def _squash(self, s):
        """Smoothly bound log-scales to (-alpha, alpha); identity when alpha is None."""
        if self.alpha is None:
            return s
        return (2 * self.alpha / torch.pi) * torch.atan(s / self.alpha)

    def forward(self, inputs, thetas):
        """Return ``(transformed_theta, log_det)`` for ``thetas`` of shape ``[m, n]``."""
        if self.s2 is not None:
            theta1, theta2 = thetas[:, self.Split1], thetas[:, self.Split2]

            # Transform the Split1 part conditioned on (inputs, Split2 part).
            cond1 = torch.cat((inputs, theta2), 1)
            s1_trans = self._squash(self.s1(cond1))
            t1_trans = self.t1(cond1)
            theta1_trans = theta1 * torch.exp(s1_trans) + t1_trans

            # Transform the Split2 part conditioned on (inputs, transformed Split1 part).
            cond2 = torch.cat((inputs, theta1_trans), 1)
            s2_trans = self._squash(self.s2(cond2))
            t2_trans = self.t2(cond2)
            theta2_trans = theta2 * torch.exp(s2_trans) + t2_trans

            transformed_theta = torch.cat((theta1_trans, theta2_trans), 1)
            log_det = torch.sum(s1_trans, dim=1) + torch.sum(s2_trans, dim=1)
        else:
            # thetas[:, None] would insert a new axis and broadcast to [m, m, n], so
            # index only when a real split is given.
            theta1 = thetas if self.Split1 is None else thetas[:, self.Split1]

            s1_trans = self._squash(self.s1(inputs))
            t1_trans = self.t1(inputs)
            transformed_theta = theta1 * torch.exp(s1_trans) + t1_trans

            # The single s column scales every transformed coordinate, so each one
            # contributes s to the log-Jacobian.
            log_det = torch.sum(s1_trans.expand_as(transformed_theta), dim=1)

        return transformed_theta, log_det

    def inverse(self, inputs, theta_transform):
        """Invert :meth:`forward` (log-determinant is not returned)."""
        if self.s2 is not None:
            n1 = len(self.Split1)
            z1, z2 = theta_transform[:, 0:n1], theta_transform[:, n1:]

            # Undo the second coupling first, conditioned on the still-transformed z1.
            cond2 = torch.cat((inputs, z1), 1)
            s2_trans = self._squash(self.s2(cond2))
            t2_trans = self.t2(cond2)
            theta2 = (z2 - t2_trans) * torch.exp(-s2_trans)

            cond1 = torch.cat((inputs, theta2), 1)
            s1_trans = self._squash(self.s1(cond1))
            t1_trans = self.t1(cond1)
            theta1 = (z1 - t1_trans) * torch.exp(-s1_trans)

            # Columns are currently [Split1..., Split2...]; argsort of that index list
            # restores the original parameter order.
            theta = torch.cat((theta1, theta2), 1)
            order = np.concatenate((self.Split1, self.Split2), axis=0).argsort()
            theta = theta[:, order]
        else:
            s1_trans = self._squash(self.s1(inputs))
            t1_trans = self.t1(inputs)
            theta = (theta_transform - t1_trans) * torch.exp(-s1_trans)

        return theta

class NFs(nn.Module):
    """Conditional normalizing flow whose conditioning comes from a transformer summary network.

    ``inputs`` are embedded by a linear layer and processed by ``2 * layers``
    pre-LayerNorm blocks (self-attention then feed-forward, residual, swapping axes -3
    and -2 after every block). The result is layer-normalized, max-pooled over axis
    -3 and flattened to ``[inputs.shape[0], -1]``. That feature vector conditions a
    stack of ``num_trans`` ``cINN`` coupling blocks (hidden widths [256, 256, 256],
    ReLU, ``alpha=5``).

    Args:
        network: network class passed to every ``cINN`` for its s/t nets.
        n_input: width of the conditioning features given to each ``cINN``
            (presumably must equal the flattened pooled feature width -- confirm).
        n_theta: dimension of the modelled parameters.
        Split1, Split2: coupling split indices; a random half/half permutation is
            drawn with ``np.random`` when ``Split1`` is None.
        num_trans: number of coupling blocks.
        layers, dim, key_size, num_heads, widening_factor, dropout, input_shape,
        w_init_scale: summary-network configuration (``2 * layers`` blocks).
        out_dim, logit_bias_init, cosine_temp_init, ln_axis: stored only.
        name, init_func: accepted for interface parity; not used.
    """
    def __init__(self, network, n_input, n_theta, Split1, Split2, num_trans,
                       layers=8, dim=128, key_size=16, num_heads=8, widening_factor=4,
                       dropout=0.05, out_dim=None, logit_bias_init=-3.0, cosine_temp_init=2.0,
                       ln_axis=-1, name="BaseModel", input_shape = 2, w_init_scale = 1, init_func = None
                ):
        super().__init__()
        # NOTE: submodules are created in a fixed order; changing it would change both
        # the random initialization under a fixed seed and the state_dict layout.

        # ---- flow ---------------------------------------------------------------
        self.n_theta = n_theta
        self.n_input = n_input
        if Split1 is None:
            self.Split1, self.Split2 = np.split(np.random.permutation(n_theta), [int(n_theta / 2)])
        else:
            self.Split1 = np.array(Split1)
            self.Split2 = np.array(Split2)

        self.flows = [cINN(n_theta, n_input, [256, 256, 256], [256, 256, 256], nn.ReLU(),
                           self.Split1, self.Split2, network, alpha = 5).to(device)
                      for _ in range(num_trans)]
        self.flows_list = nn.ModuleList(self.flows)

        # ---- summary network -----------------------------------------------------
        self.dim = dim
        self.out_dim = out_dim or dim
        self.layers = 2 * layers
        self.dropout = dropout
        self.ln_axis = ln_axis
        self.widening_factor = widening_factor
        self.num_heads = num_heads
        self.key_size = key_size
        self.logit_bias_init = logit_bias_init
        self.cosine_temp_init = cosine_temp_init
        self.w_init_scale = w_init_scale

        self.input_shape = input_shape  # size of the last axis of `inputs`
        self.linears_1 = nn.Linear(self.input_shape, self.dim, bias = True).to(device)

        self.layernorm_qs = nn.ModuleList([
                            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])
        self.layernorm_ks = nn.ModuleList([
                            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])
        self.layernorm_vs = nn.ModuleList([
                            nn.LayerNorm(self.dim, eps=1e-05) for _ in range(self.layers)])
        self.layernorm_ffns_0 = nn.ModuleList([nn.LayerNorm(self.dim, eps=1e-05)
                                               for _ in range(self.layers)])
        # init_func is intentionally not forwarded to the attention layers.
        self.attention_list = nn.ModuleList([Torch_MultiHeadAttention(num_heads=self.num_heads,
                                   key_size=self.key_size,
                                   model_size=self.dim,
                                   init_func = None,
                                   w_init_scale = self.w_init_scale) for _ in range(self.layers)])

        self.ffn_sequential_list = nn.ModuleList([nn.Sequential(
                            nn.Linear(self.dim, self.widening_factor * self.dim, bias = True),
                            nn.ReLU(),
                            nn.Linear(self.widening_factor * self.dim, self.dim, bias = True))
                            for _ in range(self.layers)])

        self.layernorm_ffn_1 = nn.LayerNorm(self.dim, eps=1e-05)

    def _encode(self, inputs, dropout_rate):
        """Run the summary network and return pooled features ``[inputs.shape[0], -1]``.

        Shared by :meth:`forward` and :meth:`inverse` so both condition the flows on
        identical features.
        """
        z = self.linears_1(inputs)
        for i in range(self.layers):
            q_in = self.layernorm_qs[i](z)
            k_in = self.layernorm_ks[i](z)
            v_in = self.layernorm_vs[i](z)

            z_attn = self.attention_list[i](q_in, k_in, v_in, i)
            z = z + F.dropout(z_attn, p=dropout_rate)

            z_in = self.layernorm_ffns_0[i](z)
            z_ffn = self.ffn_sequential_list[i](z_in)
            z = z + F.dropout(z_ffn, p=dropout_rate)
            # Alternate the attended axis for the next block.
            z = torch.swapaxes(z, -3, -2)
        z = self.layernorm_ffn_1(z)

        z, _ = torch.max(z, dim=-3)
        return z.reshape(z.shape[0], -1)

    def forward(self, inputs, thetas, is_training = True, thetas_std_global = None):
        """Return the conditional log-density of ``thetas`` (shape ``[m, n_theta]``).

        ``log N(f(theta); 0, I) + sum of coupling log-determinants``, floored softly
        via ``logaddexp`` with ``log(1e-27)``. When ``thetas_std_global`` is given,
        ``log(prod(thetas_std_global))`` is subtracted (a change of variables for
        standardized parameters).
        """
        # Dropout probability is 0 outside training, which makes F.dropout the identity.
        dropout_rate = self.dropout if is_training else 0.0
        inputs_feature = self._encode(inputs, dropout_rate)

        m, _ = thetas.shape
        log_det = torch.zeros(m).to(device)

        for flow in self.flows_list:
            thetas, ld = flow.forward(inputs_feature, thetas)
            log_det += ld

        norm_logprob = torch.distributions.MultivariateNormal(
            torch.zeros(thetas.shape[1]).to(device),
            torch.eye(thetas.shape[1]).to(device)).log_prob(thetas)
        logprobs = norm_logprob + log_det
        logprobs = torch.logaddexp(logprobs, torch.tensor(math.log(1e-27)).to(device))

        if thetas_std_global is not None:
            return logprobs - torch.log(torch.prod(thetas_std_global))
        return logprobs

    def inverse(self, inputs, thetas):
        """Map latent ``thetas`` back to parameter space through the inverted flow stack.

        Uses the same (dropout-free) pooled features as :meth:`forward` in evaluation mode.
        """
        inputs_feature = self._encode(inputs, 0.0)
        for flow in self.flows_list[::-1]:
            thetas = flow.inverse(inputs_feature, thetas)
        return thetas
