import functools
import numpy as np
import math
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from tqdm.auto import trange
pyro.clear_param_store()
from pyro.nn import PyroModule, PyroSample
import torch.nn as nn
from tqdm import tqdm
import os
import pickle

# Import from local modules
from utils import mat_to_toporder
from synthetic import Gaussian_torch
from synthetic import ErdosRenyi, ScaleFree, Yeast, Ecoli
from synthetic.noise_scale import init_noise_dist

import torch
import torch.utils
import torch.nn as nn
import torch.nn.functional as F
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from models import policy_opt, BaseModel_torch, Pyro_NN
import argparse


def make_parser():
    """Build the command-line parser for the experiment script.

    Every option is optional and has a default, so ``parse_args([])`` yields a
    complete configuration.

    Returns:
        argparse.ArgumentParser: parser exposing the options listed below.
    """
    parser = argparse.ArgumentParser()

    # (flag, default, type, help) — kept in the order shown by ``--help``.
    options = (
        ("--env", "er", str, "Environment name"),
        ("--edge_prob", 0.2, float, "expected edges per node"),
        ("--num_steps", 10000, int, "total number of steps"),
        ("--n_nodes", 10, int, "number of nodes"),
        ("--n_stage", 5, int, "number of stages for intervention"),
        ("--torch_manual_seed", 123, int, "torch manual seed"),
        # The help text used to be a copy of the torch seed description.
        ("--env_seed", 283956, int, "environment seed"),
        ("--n_int", 5, int, "number of interventions"),
        ("--n_obs", 50, int, "number of observations"),
        ("--total_n_envs", 1000, int, "total number of envs"),
        ("--train_n_envs", 10, int, "number of envs in each training step"),
        ("--eval_n_envs", 10, int, "number of envs for evaluation"),
        ("--n_update_posterior", 5, int, "number of updates of posterior network"),
        ("--int_value", 5.0, float, "intervention value"),
        ("--eval_every", 5, int, "evaluate after every n steps"),
        ("--save_every", 50, int, "save model after every n steps"),
        ("--save_dir", "results/model_ckpts_nonlinear", str,
         "save directory of the model and results"),
    )
    for flag, default, arg_type, help_text in options:
        parser.add_argument(flag, default=default, type=arg_type, help=help_text)

    return parser


class Sequential_Graph(nn.Module):
    def __init__(self, hidden_size = 8, 
                       activation = 'ReLU', 
                       prior_scale = 1,
                       noise =  Gaussian_torch(0, 0.1**0.5),
                       noise_scale_constant = True, 
                       training_actor = True,
                       noise_scale_heteroscedastic = None ,
                       n_stage = 10, 
                       noise_scale   = None, interv_dist = None,  bias_with_ancestor = False, n_nodes = 10,
                       actor_network = None, post_network = None, env = ErdosRenyi(edges_per_var = 2),
                       actor_network_untrained = None,
                       n_batch = 5 , n_obs = 30,
                       post_lr = 1e-4, post_gamma = 0.8, actor_lr = 1e-4, actor_gamma = 0.8, 
                       post_sampling_instance = None
                       ):
        """Store the simulator configuration and build the optimizers.

        Args:
            hidden_size: width of the two hidden layers drawn by ``draw_nn_params``.
            activation: name of a ``torch.nn`` activation class that can be
                constructed without arguments (e.g. ``'ReLU'``, ``'Tanh'``).
            prior_scale: standard deviation used to initialise mechanism weights.
            noise, noise_scale, noise_scale_constant, noise_scale_heteroscedastic:
                forwarded to ``init_noise_dist`` when noise sources are created.
            training_actor: whether the actor network chooses the interventions.
            n_stage: number of intervention stages per environment.
            interv_dist: stored on the instance; not used by this constructor.
            n_nodes: number of graph variables.
            actor_network: policy module; when given, an Adam optimizer and an
                exponential LR scheduler are created for it.
            post_network: posterior module (required); gets its own Adam
                optimizer and exponential LR scheduler.
            env: callable invoked as ``env(rng, n_nodes)`` to draw a graph.
            actor_network_untrained: baseline policy kept for evaluation.
            n_batch: samples drawn per intervention stage.
            n_obs: number of observational samples drawn before stage 0.
            post_lr, post_gamma, actor_lr, actor_gamma: optimizer learning
                rates and scheduler decay factors.
            bias_with_ancestor, post_sampling_instance: accepted but not used
                by this constructor.

        Raises:
            ValueError: if ``activation`` does not name a ``torch.nn`` module
                class, or if ``post_network`` is ``None``.

        NOTE(review): the ``noise`` and ``env`` defaults are built once at
        definition time and therefore shared by every instance that relies on
        them.
        """
        super().__init__()

        # Resolve the activation by name instead of silently leaving
        # ``self.activation`` undefined for anything other than 'ReLU'.
        activation_cls = getattr(nn, activation, None) if isinstance(activation, str) else None
        if not (isinstance(activation_cls, type) and issubclass(activation_cls, nn.Module)):
            raise ValueError(f"Unsupported activation {activation!r}; expected the name of a torch.nn activation class such as 'ReLU'.")
        self.activation = activation_cls()
        self.prior_scale                = prior_scale
        
        self.noise                       = noise
        self.noise_scale                 = noise_scale
        self.noise_scale_constant        = noise_scale_constant
        self.noise_scale_heteroscedastic = noise_scale_heteroscedastic
        
        self.interv_dist                 = interv_dist
        self.training_actor              = training_actor
        
        self.n_nodes   = n_nodes
        self.step      = 0
        self.n_stage   = n_stage
        self.n_batch   = n_batch
        self.n_obs = n_obs
        self.env = env
        
        # Compare against None: container modules define __len__, so an empty
        # one would be falsy under a plain truthiness check.
        if actor_network is not None:
            self.actor_network   = actor_network       
            self.actor_optimizer = torch.optim.Adam(self.actor_network.parameters(), lr = actor_lr)
            self.actor_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.actor_optimizer, gamma = actor_gamma)

        self.actor_network_untrained = actor_network_untrained

        if post_network is None:
            raise ValueError("post_network is required: its parameters are needed to build the posterior optimizer.")
        self.post_network    = post_network
        self.post_optimizer  = torch.optim.Adam(self.post_network.parameters(), lr = post_lr)
        self.post_scheduler  = torch.optim.lr_scheduler.ExponentialLR(self.post_optimizer, gamma = post_gamma)

        self.hidden_size = hidden_size


    @staticmethod
    def _nn_mechanism(*, x, z, parents, fc1, fc2, fc3, activation):
        """Evaluate one node's mechanism on parent-masked inputs plus noise.

        ``x`` is multiplied element-wise by ``parents`` (masking out
        non-parents), passed through ``fc1 -> activation -> fc2 -> activation
        -> fc3``, squeezed along dimension 1 and added to ``z``.
        """
        masked_inputs = x * parents
        features = activation(fc2(activation(fc1(masked_inputs))))
        return fc3(features).squeeze(1) + z
    
    def draw_nn_params(self, *, seed = 123):   #change seed
        """Create the three linear layers of one mechanism with N(0, prior_scale) weights.

        Reseeds the global torch RNG with ``seed`` before anything is drawn, so
        the call is reproducible but also resets the global random state.

        Returns:
            dict with keys ``fc1``, ``fc2``, ``fc3`` (``nn.Linear`` layers on
            ``device``) and ``activation``; suitable as keyword arguments for
            ``_nn_mechanism``.
        """
        torch.manual_seed(seed)

        # All layers are constructed first and re-initialised afterwards, in
        # this exact order, so the random stream is consumed deterministically.
        layer_shapes = (
            (self.n_nodes,     self.hidden_size),
            (self.hidden_size, self.hidden_size),
            (self.hidden_size, 1),
        )
        layers = [nn.Linear(n_in, n_out).to(device) for n_in, n_out in layer_shapes]

        for layer in layers:
            for tensor in (layer.weight, layer.bias):
                nn.init.normal_(tensor, mean=0.0, std=self.prior_scale)

        first, second, third = layers
        return dict(fc1=first, fc2=second, fc3=third, activation=self.activation)

    def sample_graphs_and_weights(self, n_envs = 5, is_training = True, step = None):
        """Draw ``n_envs`` graphs together with per-node mechanisms and noise sources.

        Args:
            n_envs: number of environments to sample.
            is_training: accepted for interface compatibility; not used here.
            step: seed base; defaults to ``self.step``.

        Returns:
            ``(graph_set, f_set, nse_set)`` — three lists of length ``n_envs``.
            ``f_set[i][j]`` is a partial of ``_nn_mechanism`` for node ``j`` and
            ``nse_set[i][j]`` is whatever ``init_noise_dist`` returns for it.

        NOTE(review): every ``draw_nn_params`` call receives the same seed
        (``step``), so all nodes of all environments in one call appear to get
        identically initialised layers — confirm this is intended.
        """
        seed_base = self.step if step is None else step
        np_rng    = np.random.default_rng( seed=int(seed_base + 1))
        torch_rng = torch.Generator().manual_seed( int(seed_base + 1) )

        graph_set, f_set, nse_set = [], [], []

        for _ in range(n_envs):
            adjacency = self.env(np_rng, self.n_nodes)
            graph_set.append(adjacency)

            # Mechanisms first, then noise sources, matching the draw order
            # the random generators are consumed in.
            mechanisms = [
                functools.partial(Sequential_Graph._nn_mechanism,
                                  **self.draw_nn_params( seed = seed_base ))
                for _ in range(self.n_nodes)
            ]
            noise_sources = [
                init_noise_dist(rng                         = torch_rng,
                                seed                        = int(seed_base + 1),
                                dim                         = int(adjacency[:, node].sum().item()),
                                dist                        = self.noise,
                                noise_scale_constant        = self.noise_scale_constant,
                                noise_scale                 = self.noise_scale,
                                noise_scale_heteroscedastic = self.noise_scale_heteroscedastic)
                for node in range(self.n_nodes)
            ]

            f_set.append( mechanisms )
            nse_set.append( noise_sources )

        return graph_set, f_set, nse_set


    def sample_recursive_scm_custom(self, rng, n_batch, g, f, nse, n_obs = 0,
                 int_target = None, int_value = None, bias_with_ancestor = False):
        """Sample ``n_batch`` rows from the SCM by ancestral sampling.

        Nodes are visited in the order returned by ``mat_to_toporder(g)``. For
        node ``j`` the value is a blend of the intervention value and the
        mechanism output:
        ``int_target[j] * int_value[j] + (1 - int_target[j]) * f[j](...)``.

        Args:
            rng: generator forwarded to every noise source.
            n_batch: number of rows to sample.
            g: adjacency matrix (NumPy array, since columns are wrapped with
                ``torch.from_numpy``); column ``j`` marks the parents of ``j``.
            f: per-node mechanisms, called with ``x``, ``z`` and ``parents``.
            nse: per-node noise sources, called with ``rng``, ``x`` and
                ``is_parent``.
            n_obs, bias_with_ancestor: accepted for interface compatibility;
                not used here.
            int_target: per-node intervention weights; ``None`` means purely
                observational sampling (``int_value`` is then replaced by
                zeros as well).
            int_value: per-node intervention values.

        Returns:
            Tensor of shape ``(n_batch, n_vars)`` on ``device``.
        """
        n_vars   = g.shape[-1]
        toporder = mat_to_toporder(g)

        if int_target is None:
            int_target = torch.zeros(n_vars, device=device)
            int_value  = torch.zeros(n_vars, device=device)

        # The buffer must NOT be a leaf that requires grad: on CPU ``.to(device)``
        # returns the very same leaf and the in-place column writes below raise
        # "a view of a leaf Variable that requires grad is being used in an
        # in-place operation". Gradients still flow through the written values.
        x_int = torch.zeros((n_batch, n_vars), device=device)

        for j in toporder:
            parent_mask = torch.from_numpy(g[:, j])
            z_j = nse[j](rng=rng, x=x_int, is_parent=parent_mask).to(device)
            raw_contrib = f[j](x=x_int, z=z_j, parents=parent_mask.to(device))
            int_percentage = int_target[j]
            x_int[:, j] = int_percentage * int_value[j] + (1 - int_percentage) * raw_contrib

        return x_int

    def standardize(self,  y_hist):
        """Standardize channel 0 of ``y_hist`` along its second-to-last axis.

        Takes ``y_hist[..., 0]``, subtracts the mean and divides by the
        standard deviation, both computed over ``dim=-2``. Entries whose
        standard deviation is zero or undefined (NaN, e.g. a single sample)
        are divided by 1 instead, so they come out centred rather than as
        inf/NaN.

        Returns:
            Tensor with the shape of ``y_hist[..., 0]``.
        """
        ref_y = y_hist[..., 0]
        mean  = ref_y.mean(dim=-2, keepdim=True)
        std   = ref_y.std(dim=-2, keepdim=True)

        # Build the fallback on the input's own device/dtype rather than the
        # module-level ``device`` so inputs living elsewhere do not fail.
        degenerate = (std == 0.0) | torch.isnan(std)
        safe_std   = torch.where(degenerate, torch.ones_like(std), std)

        return (ref_y - mean) / safe_std
    
    def model(self, n_envs, is_training = True):  
        """Simulate a full experiment in ``n_envs`` freshly sampled environments.

        For each environment an observational batch of ``n_obs`` rows is drawn
        at stage 0 (only when ``training_actor`` is set), then ``n_stage``
        batches of ``n_batch`` rows are sampled. When ``training_actor`` is set
        the actor network picks the intervention for each stage from the
        observational data plus all earlier stages; otherwise every stage is
        sampled without intervention and the observational block stays zero.

        Args:
            n_envs: number of environments to simulate.
            is_training: forwarded to ``sample_graphs_and_weights``.

        Returns:
            ``(post_net_input, graphs)`` where ``post_net_input`` has shape
            ``(n_envs, n_obs + n_stage * n_batch, n_nodes, 2)`` (channel 0:
            sampled values, channel 1: intervention mask) and ``graphs`` is the
            stacked adjacency matrices on ``device``.

        NOTE(review): the actor is presumably expected to return two
        length-``n_nodes`` vectors (targets, values) — confirm against the
        actor implementation.
        """
        graph_set, f_set, nse_set = self.sample_graphs_and_weights( n_envs, is_training )
        X_obs_full  = torch.zeros(( n_envs, self.n_obs, self.n_nodes, 2)).to(device)
        y_hist_full = torch.zeros(( n_envs, self.n_stage, self.n_batch, self.n_nodes, 2)).to(device)

        for idx_env in range(n_envs): 
            rng       = torch.Generator().manual_seed(int(self.step  * 106 + 1 + idx_env * 50417 ))
            y_history = torch.zeros(( self.n_stage , self.n_batch, self.n_nodes )).to(device) 
            d_history = torch.zeros(( self.n_stage , self.n_batch, self.n_nodes )).to(device) 
            X_obs     = torch.zeros(( self.n_obs , self.n_nodes, 2 )).to(device)
            X         = torch.zeros(  self.n_stage,  self.n_batch, self.n_nodes, 2).to(device)

            for t in range( self.n_stage  ):
                int_target_vec = None 
                int_value_vec  = None
                # (n_nodes, n_batch) so the ``moveaxis`` below is also valid
                # when no actor runs; a 1-D default made that path raise.
                mask_tiled     = torch.zeros( self.n_nodes, self.n_batch ).to(device)
                
                if self.training_actor:
                    if t == 0:
                        # Purely observational batch: all-zero targets and values.
                        int_target_vec_0 = torch.zeros(self.n_nodes).to(device)
                        int_value_vec_0  = torch.zeros(self.n_nodes).to(device)
                        x_int_0 = self.sample_recursive_scm_custom(rng = rng, n_obs = 0, 
                                        n_batch = self.n_obs, 
                                        g = graph_set[idx_env], 
                                        f = f_set[idx_env], 
                                        nse = nse_set[idx_env],
                                        int_target = int_target_vec_0, 
                                        int_value = int_value_vec_0,
                                        bias_with_ancestor = False)
                        
                        X_obs[..., 0] = x_int_0
                        X_obs[..., 1] = int_target_vec_0.unsqueeze(1).tile(1, self.n_obs).moveaxis(0, 1)   # n_obs, n_nodes
                    else:                   
                        # Refresh the history of all stages completed so far.
                        y_hist_input = y_history[ :t, :, :] 
                        d_hist_input = d_history[ :t, :, :]   # t, n_batch, n_nodes
                        X[:t, :, :, :] = torch.stack((y_hist_input, d_hist_input), dim = -1)

                    # Flatten the completed stages to (t * n_batch, n_nodes, 2)
                    # and prepend the observational rows.
                    actor_input_with_stage = X.clone()[:t, ...]
                    X_new = actor_input_with_stage.permute(2, 0, 1, 3).reshape(self.n_nodes, -1, 2).moveaxis(0, 1)
                    act_input = torch.cat([X_obs, X_new], dim=0)

                    int_target_vec, int_value_vec = self.actor_network(act_input, is_training = True, step=self.step)

                    mask_tiled = int_target_vec.unsqueeze(1).tile(1, self.n_batch)    # n_nodes, n_batch

                x_int = self.sample_recursive_scm_custom(rng  = rng, 
                                                   n_batch    = self.n_batch, 
                                                   g          = graph_set[idx_env], 
                                                   f          = f_set[idx_env], 
                                                   nse        = nse_set[idx_env],
                                                   int_target = int_target_vec, 
                                                   int_value  = int_value_vec, 
                                                   bias_with_ancestor = False)
                
                y_history[ t, :, : ] = x_int
                d_history[ t, :, : ] = mask_tiled.moveaxis(0, 1)

            X_obs_full[ idx_env, ...] = X_obs

            y_hist_full[ idx_env, ...,  0 ] = y_history  
            y_hist_full[ idx_env, ...,  1 ] = d_history  
            
        # Merge the stage and batch axes, then put the observational rows first.
        y_hist_full_reshape = y_hist_full.reshape(n_envs, -1, self.n_nodes, 2)
        post_net_input = torch.cat([X_obs_full.to(device), y_hist_full_reshape], dim = 1)

        return  post_net_input,  torch.from_numpy(np.asarray(graph_set)).to(device) 


    
    def train(self, hist_y_d = True, graph = None, n_envs = 10, n_update_posterior = 5, training_post = True, n_update = None):  
        """Update the posterior network, or evaluate it for the actor update.

        This method shadows ``nn.Module.train``. ``nn.Module.eval()`` and any
        parent module switch modes by calling ``train(mode)`` with a single
        bool, so a bool first argument is delegated to ``nn.Module.train`` and
        the module itself is returned; that keeps ``.eval()`` / ``.train()``
        usable on this class.

        Args:
            hist_y_d: simulated history fed to the posterior network (or a
                bool, see above).
            graph: adjacency matrices matching ``hist_y_d``.
            n_envs: accepted for interface compatibility; not used here.
            n_update_posterior: number of gradient steps on the posterior
                network when ``training_post`` is true.
            training_post: if true, run the posterior updates on detached
                inputs; otherwise only evaluate the posterior network (the
                actor is put in train mode when ``training_actor`` is set) so
                gradients can flow back to the actor.
            n_update: global update counter; the posterior LR scheduler is
                stepped when it is a multiple of 1000 and greater than 1.
                ``None`` skips the scheduler.

        Returns:
            First element of the posterior network's output: from the last
            update iteration when ``training_post`` is true, otherwise from the
            single evaluation pass.
            NOTE(review): the value is minimised directly as a loss although
            it is named ``log_prob`` — presumably a negative log-probability;
            confirm against the posterior network.

        Raises:
            TypeError: if ``graph`` is missing for a non-bool ``hist_y_d``.
            ValueError: if ``training_post`` is true and
                ``n_update_posterior`` is smaller than 1.
        """
        if isinstance(hist_y_d, bool):
            return super().train(hist_y_d)
        if graph is None:
            raise TypeError("train() requires `graph` when `hist_y_d` is a tensor")

        # Only touch the actor's mode when the actor is part of training.
        actor = self.actor_network if self.training_actor else None

        if not training_post:
            self.post_network.eval()
            if actor is not None:
                actor.train()
            return self.post_network(hist_y_d, graph, False)[0]

        if n_update_posterior < 1:
            raise ValueError("n_update_posterior must be at least 1 when training_post is True")

        self.post_network.train()
        if actor is not None:
            actor.eval()

        log_prob_post = None
        for _ in range(n_update_posterior):
            self.post_optimizer.zero_grad()
            # Inputs are detached: posterior updates must not reach the actor.
            log_prob_post = self.post_network(hist_y_d.detach(), graph.detach(), True)[0]
            loss = log_prob_post
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.post_network.parameters(), 1)
            self.post_optimizer.step()

        # ``n_update`` defaults to None; guard before the modulo.
        if n_update is not None and n_update % 1000 == 0 and n_update > 1:
            self.post_scheduler.step()
        self.post_network.eval()

        return log_prob_post
    
    def forward(self, n_envs = 10, n_update_posterior = 10 , n_update = None, is_training = True):
        """Run one training step, or evaluate three intervention policies.

        The global torch RNG is reseeded with ``self.step + 1`` on every call.

        Training (``is_training is True``): simulate ``n_envs`` environments,
        take one clipped gradient step on the actor using the posterior
        network's output as the loss, then update the posterior network
        ``n_update_posterior`` times on the same data and increment
        ``self.step``.

        Evaluation (anything else): under ``torch.no_grad()``, sample one set
        of environments and score the trained actor, the untrained actor and
        the random policy on it. ``self.step`` is not changed.

        Args:
            n_envs: number of environments to simulate.
            n_update_posterior: posterior gradient steps per training call.
            n_update: global update counter; the LR schedulers are stepped
                when it is a multiple of 1000 and greater than 1. ``None``
                disables scheduler stepping.
            is_training: selects the mode, see above.

        Returns:
            Training: ``(log_prob_post, loss_actor)``.
            Evaluation: ``(trained, untrained, random)`` posterior outputs.
        """
        torch.manual_seed(self.step+1)

        if is_training is not True:
            self.actor_network.eval() 
            self.post_network.eval()
            with torch.no_grad():
                # All three policies are scored on the same environments.
                graph_set, f_set, nse_set = self.sample_graphs_and_weights(n_envs, is_training = False)
                shared = dict(is_training = False, graph_set = graph_set, f_set = f_set, nse_set = nse_set)

                y_hist_trained, graph_torch_trained = self.model_eval_trained_policy(n_envs, **shared)
                log_prob_post_trained = self.post_network(y_hist_trained.to(device), graph_torch_trained.to(device), False)[0]

                y_hist_untrained, graph_torch_untrained = self.model_eval_untrained_policy(n_envs, **shared)
                log_prob_post_untrained = self.post_network(y_hist_untrained.to(device), graph_torch_untrained.to(device), False)[0]

                y_hist_random, graph_torch_random = self.model_eval_random(n_envs, **shared)
                log_prob_post_random = self.post_network(y_hist_random.to(device), graph_torch_random.to(device), False)[0]

            return (log_prob_post_trained, log_prob_post_untrained, log_prob_post_random)

        # ``n_update`` defaults to None; 0 never triggers a scheduler step and
        # is safe to use in the modulo checks here and inside ``train``.
        update_idx = 0 if n_update is None else n_update

        y_hist, graph_torch = self.model(n_envs = n_envs, is_training = is_training)
        y_hist      = y_hist.to(device)
        graph_torch = graph_torch.to(device)

        # --- actor update: posterior frozen, gradients flow into the actor ---
        self.actor_network.train()
        self.post_network.eval()
        self.actor_optimizer.zero_grad()

        loss_actor = self.train(hist_y_d = y_hist, graph = graph_torch,
                                n_update_posterior = n_update_posterior, training_post = False).mean()
        loss_actor.backward()  
        torch.nn.utils.clip_grad_norm_(self.actor_network.parameters(), 1)
        self.actor_optimizer.step()

        if update_idx % 1000 == 0 and update_idx > 1: 
            self.actor_scheduler.step()
        self.actor_network.eval()

        # --- posterior update on the same (detached) simulated data ---
        log_prob_post = self.train(hist_y_d = y_hist, graph = graph_torch,
                                   n_update_posterior = n_update_posterior, training_post = True,
                                   n_update = update_idx).mean()

        self.step += 1

        return (log_prob_post, loss_actor)

    def model_eval_trained_policy(self, n_envs, is_training = False, graph_set = None, f_set = None, nse_set = None):  
        """Simulate the experiment in the given environments with ``self.actor_network``.

        Same rollout as ``model`` but on caller-supplied environments, so
        different policies can be compared on identical graphs.

        Args:
            n_envs: number of environments to simulate.
            is_training: accepted for interface compatibility; not used here.
            graph_set, f_set, nse_set: per-environment graphs, mechanisms and
                noise sources as returned by ``sample_graphs_and_weights``.
                They are indexed per environment, so the ``None`` defaults are
                not usable in practice.

        Returns:
            ``(post_net_input, graphs)`` where ``post_net_input`` has shape
            ``(n_envs, n_obs + n_stage * n_batch, n_nodes, 2)`` (channel 0:
            sampled values, channel 1: intervention mask) and ``graphs`` is the
            stacked adjacency matrices on ``device``.

        NOTE(review): the actor is called with ``is_training = True`` even
        though this is an evaluation rollout — confirm that is intended.
        """
        X_obs_full  = torch.zeros(( n_envs, self.n_obs, self.n_nodes, 2)).to(device)
        y_hist_full = torch.zeros(( n_envs, self.n_stage, self.n_batch, self.n_nodes, 2)).to(device)

        for idx_env in range(n_envs): 
            rng       = torch.Generator().manual_seed(int(self.step  * 106 + 1 + idx_env * 50417 ))
            y_history = torch.zeros(( self.n_stage , self.n_batch, self.n_nodes )).to(device) 
            d_history = torch.zeros(( self.n_stage , self.n_batch, self.n_nodes )).to(device) 
            X_obs     = torch.zeros(( self.n_obs , self.n_nodes, 2 )).to(device)
            X         = torch.zeros(  self.n_stage,  self.n_batch, self.n_nodes, 2).to(device)

            for t in range( self.n_stage  ):
                int_target_vec = None 
                int_value_vec  = None
                # (n_nodes, n_batch) so the ``moveaxis`` below is also valid
                # when no actor runs; a 1-D default made that path raise.
                mask_tiled     = torch.zeros( self.n_nodes, self.n_batch ).to(device)
                
                if self.training_actor:
                    if t == 0:
                        # Purely observational batch: all-zero targets and values.
                        int_target_vec_0 = torch.zeros(self.n_nodes).to(device)
                        int_value_vec_0  = torch.zeros(self.n_nodes).to(device)
                        x_int_0 = self.sample_recursive_scm_custom(rng = rng, n_obs = 0, 
                                        n_batch = self.n_obs, 
                                        g = graph_set[idx_env], 
                                        f = f_set[idx_env], 
                                        nse = nse_set[idx_env],
                                        int_target = int_target_vec_0, 
                                        int_value = int_value_vec_0,
                                        bias_with_ancestor = False)
                        
                        X_obs[..., 0] = x_int_0
                        X_obs[..., 1] = int_target_vec_0.unsqueeze(1).tile(1, self.n_obs).moveaxis(0, 1)   # n_obs, n_nodes
                    else:                   
                        # Refresh the history of all stages completed so far.
                        y_hist_input = y_history[ :t, :, :] 
                        d_hist_input = d_history[ :t, :, :]   # t, n_batch, n_nodes
                        X[:t, :, :, :] = torch.stack((y_hist_input, d_hist_input), dim = -1)

                    # Flatten the completed stages to (t * n_batch, n_nodes, 2)
                    # and prepend the observational rows.
                    actor_input_with_stage = X.clone()[:t, ...]
                    X_new = actor_input_with_stage.permute(2, 0, 1, 3).reshape(self.n_nodes, -1, 2).moveaxis(0, 1)
                    act_input = torch.cat([X_obs, X_new], dim=0)

                    int_target_vec, int_value_vec = self.actor_network(act_input, is_training = True, step=self.step)

                    mask_tiled = int_target_vec.unsqueeze(1).tile(1, self.n_batch)    # n_nodes, n_batch

                x_int = self.sample_recursive_scm_custom(rng  = rng, 
                                                   n_batch    = self.n_batch, 
                                                   g          = graph_set[idx_env], 
                                                   f          = f_set[idx_env], 
                                                   nse        = nse_set[idx_env],
                                                   int_target = int_target_vec, 
                                                   int_value  = int_value_vec, 
                                                   bias_with_ancestor = False)
                
                y_history[ t, :, : ] = x_int
                d_history[ t, :, : ] = mask_tiled.moveaxis(0, 1)
                                
            X_obs_full[ idx_env, ...] = X_obs
            y_hist_full[ idx_env, ...,  0 ] = y_history  
            y_hist_full[ idx_env, ...,  1 ] = d_history  
            
        # Merge the stage and batch axes, then put the observational rows first.
        y_hist_full_reshape = y_hist_full.reshape(n_envs, -1, self.n_nodes, 2)
        post_net_input = torch.cat([X_obs_full.to(device), y_hist_full_reshape], dim = 1)
                
        return  post_net_input,  torch.from_numpy(np.asarray(graph_set)).to(device) 

    def model_eval_untrained_policy(self, n_envs, is_training = False, graph_set = None, f_set = None, nse_set = None):  
        """Simulate the experiment in the given environments with ``self.actor_network_untrained``.

        Baseline counterpart of the trained-policy rollout: identical
        procedure, but interventions come from the untrained actor.

        Args:
            n_envs: number of environments to simulate.
            is_training: accepted for interface compatibility; not used here.
            graph_set, f_set, nse_set: per-environment graphs, mechanisms and
                noise sources as returned by ``sample_graphs_and_weights``.
                They are indexed per environment, so the ``None`` defaults are
                not usable in practice.

        Returns:
            ``(post_net_input, graphs)`` where ``post_net_input`` has shape
            ``(n_envs, n_obs + n_stage * n_batch, n_nodes, 2)`` (channel 0:
            sampled values, channel 1: intervention mask) and ``graphs`` is the
            stacked adjacency matrices on ``device``.

        NOTE(review): the actor is called with ``is_training = True`` even
        though this is an evaluation rollout — confirm that is intended.
        """
        X_obs_full  = torch.zeros(( n_envs, self.n_obs, self.n_nodes, 2)).to(device)
        y_hist_full = torch.zeros(( n_envs, self.n_stage, self.n_batch, self.n_nodes, 2)).to(device)

        for idx_env in range(n_envs): 
            rng       = torch.Generator().manual_seed(int(self.step  * 106 + 1 + idx_env * 50417 ))
            y_history = torch.zeros(( self.n_stage , self.n_batch, self.n_nodes )).to(device) 
            d_history = torch.zeros(( self.n_stage , self.n_batch, self.n_nodes )).to(device) 
            X_obs     = torch.zeros(( self.n_obs , self.n_nodes, 2 )).to(device)
            X         = torch.zeros(  self.n_stage,  self.n_batch, self.n_nodes, 2).to(device)

            for t in range( self.n_stage  ):
                int_target_vec = None 
                int_value_vec  = None
                # (n_nodes, n_batch) so the ``moveaxis`` below is also valid
                # when no actor runs; a 1-D default made that path raise.
                mask_tiled     = torch.zeros( self.n_nodes, self.n_batch ).to(device)
                
                if self.training_actor:
                    if t == 0:
                        # Purely observational batch: all-zero targets and values.
                        int_target_vec_0 = torch.zeros(self.n_nodes).to(device)
                        int_value_vec_0  = torch.zeros(self.n_nodes).to(device)
                        x_int_0 = self.sample_recursive_scm_custom(rng = rng, n_obs = 0, 
                                        n_batch = self.n_obs, 
                                        g = graph_set[idx_env], 
                                        f = f_set[idx_env], 
                                        nse = nse_set[idx_env],
                                        int_target = int_target_vec_0, 
                                        int_value = int_value_vec_0,
                                        bias_with_ancestor = False)
                        
                        X_obs[..., 0] = x_int_0
                        X_obs[..., 1] = int_target_vec_0.unsqueeze(1).tile(1, self.n_obs).moveaxis(0, 1)   # n_obs, n_nodes
                    else:                   
                        # Refresh the history of all stages completed so far.
                        y_hist_input = y_history[ :t, :, :] 
                        d_hist_input = d_history[ :t, :, :]   # t, n_batch, n_nodes
                        X[:t, :, :, :] = torch.stack((y_hist_input, d_hist_input), dim = -1)

                    # Flatten the completed stages to (t * n_batch, n_nodes, 2)
                    # and prepend the observational rows.
                    actor_input_with_stage = X.clone()[:t, ...]
                    X_new = actor_input_with_stage.permute(2, 0, 1, 3).reshape(self.n_nodes, -1, 2).moveaxis(0, 1)
                    act_input = torch.cat([X_obs, X_new], dim=0)

                    int_target_vec, int_value_vec = self.actor_network_untrained(act_input, is_training = True, step=self.step)

                    mask_tiled = int_target_vec.unsqueeze(1).tile(1, self.n_batch)    # n_nodes, n_batch

                x_int = self.sample_recursive_scm_custom(rng  = rng, 
                                                   n_batch    = self.n_batch, 
                                                   g          = graph_set[idx_env], 
                                                   f          = f_set[idx_env], 
                                                   nse        = nse_set[idx_env],
                                                   int_target = int_target_vec, 
                                                   int_value  = int_value_vec, 
                                                   bias_with_ancestor = False)
                
                y_history[ t, :, : ] = x_int
                d_history[ t, :, : ] = mask_tiled.moveaxis(0, 1)
                                
            X_obs_full[ idx_env, ...] = X_obs
            y_hist_full[ idx_env, ...,  0 ] = y_history  
            y_hist_full[ idx_env, ...,  1 ] = d_history  
            
        # Merge the stage and batch axes, then put the observational rows first.
        y_hist_full_reshape = y_hist_full.reshape(n_envs, -1, self.n_nodes, 2)
        post_net_input = torch.cat([X_obs_full.to(device), y_hist_full_reshape], dim = 1)

        return  post_net_input,  torch.from_numpy(np.asarray(graph_set)).to(device) 

    def model_eval_random(self, n_envs, is_training = False, graph_set = None, f_set = None, nse_set = None):  
        """Simulate the experiment in the given environments with random interventions.

        At every stage (when ``training_actor`` is set) one node is picked
        uniformly at random as the intervention target and intervention values
        are drawn uniformly between the actor's ``min_val`` and ``max_val``.
        Without ``training_actor`` every stage is sampled without intervention.

        Args:
            n_envs: number of environments to simulate.
            is_training: accepted for interface compatibility; not used here.
            graph_set, f_set, nse_set: per-environment graphs, mechanisms and
                noise sources as returned by ``sample_graphs_and_weights``.
                They are indexed per environment, so the ``None`` defaults are
                not usable in practice.

        Returns:
            ``(post_net_input, graphs)`` where ``post_net_input`` has shape
            ``(n_envs, n_stage * n_batch, n_nodes, 2)`` (channel 0: sampled
            values, channel 1: intervention mask) and ``graphs`` is the stacked
            adjacency matrices on ``device``.

        NOTE(review): unlike the actor-based rollouts, no observational rows
        are prepended here, so the posterior network sees fewer samples —
        confirm the comparison is meant to work this way.
        NOTE(review): the target is drawn from NumPy's global RNG, which this
        class never seeds, so this baseline is not reproducible from ``step``.
        """
        y_hist_full = torch.zeros(( n_envs, self.n_stage, self.n_batch, self.n_nodes, 2)).to(device)

        for idx_env in range(n_envs): 
            rng       = torch.Generator().manual_seed(int(self.step  * 106 + 1 + idx_env * 50417 ))
            y_history = torch.zeros(( self.n_stage , self.n_batch, self.n_nodes )).to(device) 
            d_history = torch.zeros(( self.n_stage , self.n_batch, self.n_nodes )).to(device) 

            for t in range( self.n_stage  ):
                int_target_vec = None 
                int_value_vec  = None
                
                if self.training_actor:
                    # One-hot target on a uniformly chosen node.
                    int_target_vec = torch.zeros(self.n_nodes).to(device)
                    int_target = np.random.choice(self.n_nodes,1)[0]
                    int_target_vec[int(int_target)] = 1

                    # Uniform values in [min_val, max_val); drawn on the CPU
                    # exactly as before, then moved next to the other tensors.
                    value_range   = self.actor_network.max_val - self.actor_network.min_val
                    int_value_vec = (value_range * torch.rand(self.n_nodes) + self.actor_network.min_val).to(device)

                    mask_tiled = int_target_vec.unsqueeze(1).tile(1, self.n_batch)    # n_nodes, n_batch
                    d_history[ t, :, : ] = mask_tiled.moveaxis(0, 1)
                
                x_int = self.sample_recursive_scm_custom(rng  = rng, 
                                                   n_batch    = self.n_batch, 
                                                   g          = graph_set[idx_env], 
                                                   f          = f_set[idx_env], 
                                                   nse        = nse_set[idx_env],
                                                   int_target = int_target_vec, 
                                                   int_value  = int_value_vec, 
                                                   bias_with_ancestor = False)
                
                y_history[ t, :, : ] = x_int                     
                                
            y_hist_full[ idx_env, ...,  0 ] = y_history  
            y_hist_full[ idx_env, ...,  1 ] = d_history  

        # Built once after the loop (it used to be rebuilt for every
        # environment and was unbound for n_envs == 0). The merged axis is
        # spelled out so the reshape is also well-defined for an empty batch.
        post_net_input = y_hist_full.reshape(n_envs, self.n_stage * self.n_batch, self.n_nodes, 2)
                
        return  post_net_input,  torch.from_numpy(np.asarray(graph_set)).to(device) 


class Post_sampling(nn.Module):
    def __init__(self, n_nodes, post_prob, num_simulation, x_obs, 
                       learning_rate = 5e-2, num_epochs = 30000,
                       warm_samples = 200, post_samples = 200, hidden_1 = 8, hidden_2 = 8,
                       prior_scale = 1, split_ratio = 4/5):
        
        super().__init__()

        self.x_obs                 = x_obs
        self.n_nodes               = n_nodes
        self.post_prob             = post_prob
        self.posterior_weights     = [None] * n_nodes
        self.warm_samples          = warm_samples
        self.post_samples          = post_samples
        self.prior_scale           = prior_scale
        self.num_simulation        = num_simulation
        
        self.hidden_1 = hidden_1
        self.hidden_2 = hidden_2
        
        self.split_ratio           = split_ratio
        
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        
        self.graph_probabilities   = None
        self.nodes_unique_parents  = [self.simulate_parents(i) for i in range( self.n_nodes )]
        
    def simulate_parents(self, node_id,  seed = 547467524):
        torch.manual_seed(seed)
        parents = {} 
        bernoulli_dist    = torch.distributions.Bernoulli(self.post_prob[:, node_id])

        for _ in range(self.num_simulation):
            simulated_parents   = bernoulli_dist.sample()
            parent_nodes        = tuple(simulated_parents.tolist()) 

            if parent_nodes not in parents:
                parents[parent_nodes] = {'count': 0, 'probability': 0}
            parents[parent_nodes]['count'] += 1

        for parent_nodes, info in parents.items():
            info['probability'] = info['count'] / self.num_simulation 

        sorted_parents = dict(
            sorted(parents.items(), key=lambda item: item[1]['count'], reverse=True) )

        return sorted_parents
    
    def perform_inference(self, node_id, model, parent_nodes, check_convergence = True):
        """Fit a mean-field Gaussian guide to ``model`` with SVI and unpack it per layer.

        Runs ``self.num_epochs`` SVI steps (Adam wrapped in an exponential
        learning-rate decay, ``Trace_ELBO`` loss) with column ``node_id`` of
        ``self.x_obs`` as the regression target, then splits the guide's flat
        posterior mean and standard deviation into per-layer tensors.

        Side effects: seeds Pyro's RNG with 42 and clears the global Pyro
        parameter store before training.

        NOTE(review): the unpacking assumes the guide's flat latent vector is
        laid out as layer1 weights, layer1 bias, layer2 weights, layer2 bias,
        layer3 weights, layer3 bias, and that the model's hidden widths equal
        ``self.hidden_1`` / ``self.hidden_2`` with a single output unit --
        confirm against ``Pyro_NN``.

        Args:
            node_id: Column of ``self.x_obs`` used as the target ``y``.
            model: Pyro model; called with keyword arguments ``x``,
                ``x_parents`` and ``y``.
            parent_nodes: Forwarded to the model as ``x_parents``.
            check_convergence: Currently unused; kept for interface
                compatibility.

        Returns:
            dict with keys ``layer{1,2,3}_{weights,bias}_{mean,scale}``.
            Weight entries are reshaped to ``(out_features, in_features)``;
            bias entries are flat slices.
        """
        pyro.set_rng_seed(42)
        mean_field_guide = AutoDiagonalNormal(model)

        scheduler = pyro.optim.ExponentialLR({
                'optimizer': torch.optim.Adam,  # Use the optimizer class (not instance)
                'optim_args': {'lr': self.learning_rate},  # Optimizer arguments
                'gamma': 0.9995  # Decay rate
            })
        svi = SVI(model, mean_field_guide, scheduler, loss=Trace_ELBO())

        pyro.clear_param_store()

        # The target column does not change between epochs.
        target = self.x_obs[:, node_id]
        for _ in range(self.num_epochs):
            svi.step(x = self.x_obs, x_parents = parent_nodes, y = target)
            scheduler.step()

        posterior = mean_field_guide.get_posterior()
        flat_vectors = (('mean', posterior.mean), ('scale', posterior.stddev))

        # (name, number of entries, target shape); ``None`` leaves the slice flat.
        layout = (
            ('layer1_weights', self.hidden_1 * self.n_nodes,  (self.hidden_1, self.n_nodes)),
            ('layer1_bias',    self.hidden_1,                 None),
            ('layer2_weights', self.hidden_1 * self.hidden_2, (self.hidden_2, self.hidden_1)),
            ('layer2_bias',    self.hidden_2,                 None),
            ('layer3_weights', self.hidden_2,                 (1, self.hidden_2)),
            ('layer3_bias',    1,                             None),
        )

        posterior_samples = {}
        # All means are inserted first, then all scales.
        for suffix, vector in flat_vectors:
            offset = 0
            for name, size, shape in layout:
                chunk = vector[offset: offset + size]
                if shape is not None:
                    chunk = chunk.reshape(*shape)
                posterior_samples[f'{name}_{suffix}'] = chunk
                offset += size

        return posterior_samples
    
    def infer_node(self, node_id):
        """Fit one network per sampled parent configuration of ``node_id``.

        For every configuration tuple stored in
        ``self.nodes_unique_parents[node_id]`` a fresh ``Pyro_NN`` is built and
        passed to ``perform_inference``. The resulting dictionaries are
        collected under their configuration tuple and stored in
        ``self.posterior_weights[node_id]``. Nothing is returned.

        NOTE(review): ``Pyro_NN`` is constructed without explicit hidden
        sizes -- presumably its defaults match ``self.hidden_1`` /
        ``self.hidden_2``; confirm.
        """
        candidates = self.nodes_unique_parents[node_id]
        fitted = {}

        # Only the configuration tuples (the keys) are needed here.
        for rank, configuration in enumerate(candidates):
            print('-----parent_structure------', rank, configuration )
            network = Pyro_NN(in_dim=self.n_nodes, out_dim=1, prior_scale=self.prior_scale)
            fitted[configuration] = self.perform_inference(
                node_id = node_id,
                model = network,
                parent_nodes = torch.tensor(list(configuration)),
            )

        self.posterior_weights[node_id] = fitted
    
def save_checkpoint(state, filename):
    """Serialize ``state`` to ``filename`` via ``torch.save``.

    Args:
        state: Object to serialize (the training script passes a dict).
        filename: Destination handed straight to ``torch.save``.
    """
    torch.save(obj=state, f=filename)

if __name__ == "__main__":
    # Script entry point: parse CLI options, build the environment and the
    # networks, train for `num_steps` steps with periodic evaluation and
    # checkpointing, then run a final repeated evaluation and pickle the
    # collected training curves.
    kwargs = make_parser().parse_args()
    torch_manual_seed = kwargs.torch_manual_seed
    torch.manual_seed(torch_manual_seed)
    np.random.seed(torch_manual_seed)

    # NOTE(review): `total_n_envs` and `int_value` are not used in this block.
    # They are kept as module-level names in case code elsewhere in the file
    # reads them as globals -- verify before removing or moving into a function.
    n_nodes = kwargs.n_nodes
    n_stage = kwargs.n_stage
    total_n_envs = kwargs.total_n_envs
    n_obs = kwargs.n_obs
    n_int = kwargs.n_int
    train_n_envs = kwargs.train_n_envs
    eval_n_envs = kwargs.eval_n_envs
    int_value = kwargs.int_value
    n_update_posterior = kwargs.n_update_posterior
    env_seed = kwargs.env_seed

    env_name = kwargs.env
    edge_prob = kwargs.edge_prob
    edge_per_var = math.ceil(edge_prob * n_nodes)
    if env_name == "er":
        env = ErdosRenyi(edge_per_var)
    elif env_name == "sf":
        env = ScaleFree(edge_per_var)
    elif env_name == "yeast":
        env = Yeast()
    elif env_name == "ecoli":
        env = Ecoli()
    else:
        # Fail here instead of with a NameError on `env` further down.
        raise ValueError(
            f"Unknown --env {env_name!r}; expected one of 'er', 'sf', 'yeast', 'ecoli'"
        )

    # Construction order is kept as-is: module initialisation consumes the
    # seeded RNG, so reordering would change the initial weights.
    actor_network = policy_opt(dropout=0.05, layers=4, dim=64, key_size=32, num_heads=8, widening_factor=4, input_shape=2).to(device)
    post_network = BaseModel_torch(layers= 8, num_heads= 8, key_size= 64, dim = 128, dropout = 0.05, input_shape = 2 ).to(device)

    actor_network_untrained = policy_opt(dropout=0.05, layers=4, dim=64, key_size=32, num_heads=8, widening_factor=4, input_shape=2).to(device)

    num_steps = kwargs.num_steps

    seq_graph     = Sequential_Graph(
                                    hidden_size = 8,
                                    activation = 'ReLU',
                                    prior_scale = 1,
                                    noise =  Gaussian_torch(0, 0.1**0.5),
                                    noise_scale_constant = True,
                                    training_actor = True,
                                    n_stage = n_stage,
                                    n_batch = n_int ,
                                    n_obs = n_obs,
                                    n_nodes = n_nodes,
                                    actor_network = actor_network,
                                    actor_network_untrained = actor_network_untrained,
                                    post_network = post_network,
                                    env = env,
                                    post_lr = 1e-4,  post_gamma = 0.8,
                                    actor_lr = 1e-4, actor_gamma = 0.8,
                                    post_sampling_instance = None).to(device)

    # Output layout: <save_dir>/fresh_post_<env>/saved_dir_node_.../checkpoints_...
    saved_dir_env = kwargs.save_dir + "/" + "fresh_post_" + env_name
    saved_dir_node = saved_dir_env + "/saved_dir_node_{}_env_seed_{}_torch_seed_{}".format(n_nodes, env_seed, torch_manual_seed)
    checkpoint_dir = saved_dir_node + "/checkpoints_{}_{}_{}_{}_env_seed_{}_train_steps_{}_n_envs_{}".format(n_nodes, n_stage, n_int, n_update_posterior, env_seed, num_steps, train_n_envs)
    # Creates the two parent directories as well; exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(checkpoint_dir, exist_ok=True)

    loss_set_trained = []
    loss_set_eval_trained = []
    loss_set_eval_untrained = []
    loss_set_eval_random = []
    for step in tqdm(range(num_steps)):
        log_kld, loss_actor = seq_graph.forward(n_envs = train_n_envs, n_update_posterior = n_update_posterior , n_update = step)

        torch.cuda.empty_cache()
        print(f"Training: Step {step}, log_kld: {log_kld}, loss_actor: {loss_actor}", flush=True)

        # Evaluation fires on steps whose remainder is 1 (step 1, 1 + eval_every, ...).
        if step % kwargs.eval_every == 1:
            log_kld_eval_trained, log_kld_eval_untrained, log_kld_eval_random = seq_graph.forward(n_envs = eval_n_envs, n_update_posterior = n_update_posterior , n_update = step, is_training = False)
            print(f"Evaluation: Step {step}, log_kld_eval_trained_policy: {log_kld_eval_trained}, log_kld_eval_untrained_policy: {log_kld_eval_untrained}, log_kld_eval_random: {log_kld_eval_random}", flush=True)
            loss_set_trained.append(loss_actor.detach().cpu().numpy())
            loss_set_eval_trained.append(log_kld_eval_trained.detach().cpu().numpy())
            loss_set_eval_untrained.append(log_kld_eval_untrained.detach().cpu().numpy())
            loss_set_eval_random.append(log_kld_eval_random.detach().cpu().numpy())

        if step % kwargs.save_every == 0:
            checkpoint = {
                'step': step,
                'actor_network_state_dict': seq_graph.actor_network.state_dict(),
                'post_network_state_dict': seq_graph.post_network.state_dict(),
                'actor_optimizer_state_dict': seq_graph.actor_optimizer.state_dict(),
                'post_optimizer_state_dict': seq_graph.post_optimizer.state_dict(),
                'actor_scheduler_state_dict': seq_graph.actor_scheduler.state_dict(),
                'post_scheduler_state_dict': seq_graph.post_scheduler.state_dict(),
                'log_kld': float(log_kld),
            }
            save_checkpoint(checkpoint, filename=f"{checkpoint_dir}/checkpoint_step_{step}.pth.tar")

            torch.cuda.empty_cache()  # Clear cache after saving checkpoint

    trained_exp_results = {
        "loss_set_trained": loss_set_trained,
        "loss_set_eval_trained": loss_set_eval_trained,
        "loss_set_eval_untrained": loss_set_eval_untrained,
        "loss_set_eval_random": loss_set_eval_random,
    }

    results_trained_policy = []
    results_untrained_policy = []
    results_random_policy = []

    # Index of the last training step. Computed explicitly instead of reusing
    # the leaked loop variable, which is undefined when num_steps == 0.
    final_step = max(num_steps - 1, 0)

    # Final evaluation, repeated 20 times to report mean and std per policy.
    for _ in range(20):

        log_kld_eval_trained, log_kld_eval_untrained, log_kld_eval_random = seq_graph.forward(n_envs = eval_n_envs, n_update_posterior = n_update_posterior , n_update = final_step, is_training = False)

        results_trained_policy.append(log_kld_eval_trained.detach().cpu().numpy())
        results_untrained_policy.append(log_kld_eval_untrained.detach().cpu().numpy())
        results_random_policy.append(log_kld_eval_random.detach().cpu().numpy())

    print("--------------------Results------------------------")
    print(f"Mean results of trained policy: {np.mean(results_trained_policy)}, Std results of trained policy: {np.std(results_trained_policy)}")
    print(f"Mean results of untrained policy: {np.mean(results_untrained_policy)}, Std results of untrained policy: {np.std(results_untrained_policy)}")
    print(f"Mean results of random interventions: {np.mean(results_random_policy)}, Std results of random interventions: {np.std(results_random_policy)}")

    with open(f"{checkpoint_dir}/trained_exp_results.pkl", "wb") as f:
        pickle.dump(trained_exp_results, f)

