import numpy as np
import math
import sys
from tqdm import tqdm
from utils import mat_to_toporder
from synthetic import Uniform_torch
from synthetic import Gaussian_torch, Uniform_torch
from synthetic import ErdosRenyi, ScaleFree, Yeast, Ecoli
from synthetic.noise_scale import init_noise_dist_torch
import os
import torch
import torch.utils
import torch.nn as nn
import torch.nn.functional as F
import argparse
import pickle
import gc
from tqdm import tqdm
from models import policy_opt, BaseModel_torch

# Use the first visible CUDA device when available; otherwise fall back to CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# The simulator and training code below place their tensors on `device`.

def make_parser():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: parser for every experiment option; the
        defaults describe 10-node Erdos-Renyi graphs, 5 stages of 5
        interventional samples and 30 observational samples.
    """
    parser = argparse.ArgumentParser()
    # `choices` turns an unknown graph prior into a usage error at parse time.
    parser.add_argument("--env", default="er", type=str, choices=["er", "sf", "yeast", "ecoli"],
                        help="environment (graph prior) name")
    parser.add_argument("--edge_prob", default=0.2, type=float,
                        help="edge density for er/sf: edges per variable = ceil(edge_prob * n_nodes)")
    parser.add_argument("--num_steps", default=200, type=int, help="total number of steps")
    parser.add_argument("--n_nodes", default=10, type=int, help="number of nodes")
    parser.add_argument("--n_stage", default=5, type=int, help="number of stages for intervention")
    parser.add_argument("--torch_manual_seed", default=123, type=int, help="torch manual seed")
    parser.add_argument("--env_seed", default=283956, type=int,
                        help="environment seed (recorded in the output directory names)")
    parser.add_argument("--n_int", default=5, type=int, help="number of interventions")
    parser.add_argument("--n_obs", default=30, type=int, help="number of observations")
    parser.add_argument("--train_n_envs", default=10, type=int, help="number of envs in each training step")
    parser.add_argument("--eval_n_envs", default=10, type=int, help="number of envs in each evaluation step")
    parser.add_argument("--n_update_posterior", default=5, type=int, help="number of updates of posterior network")

    parser.add_argument("--int_value", default=5.0, type=float, help="intervention value")

    parser.add_argument("--eval_every", default=5, type=int, help="evaluate after every n steps")
    parser.add_argument("--save_every", default=500, type=int, help="save model after every n steps")

    parser.add_argument("--save_dir", default="results/model_ckpts_linear",
                        type=str, help="save directory of the model and results")

    return parser


class Sequential_Graph(nn.Module):
    """Sequential intervention design on randomly sampled linear SCMs.

    Holds three networks:

    * ``actor_network`` -- the policy being trained;
    * ``actor_network_untrained`` -- a baseline policy that no optimizer in
      this class updates;
    * ``post_network`` -- the posterior network.

    A rollout draws ``n_obs`` observational samples per SCM, then ``n_stage``
    rounds of ``n_int`` samples under the intervention chosen by a policy
    from all data seen so far. The first output (``[0]``) of
    ``post_network`` is the loss minimized by both the posterior network and
    the actor.
    """

    # The LR schedulers step when the actor-update counter is > 1 and a
    # multiple of this value.
    lr_decay_every = 1000

    def __init__(self, param=None, bias=None, noise=None,
                 noise_scale_constant=True, noise_scale=None, noise_scale_heteroscedastic=None,
                 n_stage=10,
                 interv_dist=None, bias_with_ancestor=False, n_nodes=10,
                 actor_network=None, actor_network_untrained=None, post_network=None,
                 env_avici=None, n_int=5, n_obs=50, int_value=5.0,
                 post_lr=1e-4, post_gamma=0.8, actor_lr=1e-4, actor_gamma=0.8):
        """Store simulator settings and build the optimizers and schedulers.

        ``param``, ``bias``, ``noise`` and ``env_avici`` default to
        ``Gaussian_torch(0, 1.414)``, ``Uniform_torch(-1, 1)``,
        ``Gaussian_torch(0, 0.316)`` and ``ErdosRenyi(edges_per_var=2)``.
        They are built per instance, not once at class-definition time.

        Raises:
            ValueError: if ``actor_network`` or ``post_network`` is missing;
                both need an optimizer.
        """
        super().__init__()
        if actor_network is None or post_network is None:
            raise ValueError("Sequential_Graph needs both actor_network and post_network")

        self.param = Gaussian_torch(0, 1.414) if param is None else param
        self.bias = Uniform_torch(-1, 1) if bias is None else bias
        self.bias_with_ancestor = bias_with_ancestor
        self.noise = Gaussian_torch(0, 0.316) if noise is None else noise
        self.noise_scale = noise_scale
        self.noise_scale_constant = noise_scale_constant
        self.noise_scale_heteroscedastic = noise_scale_heteroscedastic
        self.interv_dist = interv_dist

        self.env_avici = ErdosRenyi(edges_per_var=2) if env_avici is None else env_avici
        self.n_nodes = n_nodes

        # Training-step counter; it also seeds every sampling RNG.
        self.step = 0
        self.n_stage = n_stage

        self.n_int = n_int
        self.n_obs = n_obs

        self.int_value = int_value

        self.actor_network = actor_network
        self.actor_optimizer = torch.optim.Adam(self.actor_network.parameters(), lr=actor_lr)
        self.actor_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.actor_optimizer, gamma=actor_gamma)

        self.actor_network_untrained = actor_network_untrained

        self.post_network = post_network
        self.post_optimizer = torch.optim.Adam(self.post_network.parameters(), lr=post_lr)
        self.post_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.post_optimizer, gamma=post_gamma)

    # ------------------------------------------------------------------ #
    # SCM sampling
    # ------------------------------------------------------------------ #

    def sample_graphs_and_weights(self, n_envs, seed=0):
        """Sample ``n_envs`` random linear SCMs.

        Args:
            n_envs: number of SCMs to draw.
            seed: base seed; numpy and torch generators are seeded with
                ``seed + 1``.

        Returns:
            ``(graph_set, f_set, nse_set)``, three lists of length ``n_envs``:

            * adjacency matrices from ``env_avici`` (column ``j`` marks the
              parents of node ``j``);
            * per-node mechanisms ``f(x, is_parent, z) = x @ (w * is_parent)
              + b + z``;
            * per-node noise models from ``init_noise_dist_torch``.
        """
        rng_numpy = np.random.default_rng(seed=int(seed + 1))
        rng_torch = torch.Generator().manual_seed(int(seed + 1))

        graph_set, f_set, nse_set = [], [], []
        for _ in range(n_envs):
            g = self.env_avici(rng_numpy, self.n_nodes)
            graph_set.append(g)

            mechanisms = []
            for _node in range(self.n_nodes):
                w = self.param(rng_torch, shape=(self.n_nodes,)).to(device)
                b = self.bias(rng_torch, shape=(1,)).to(device)
                # Bind w and b as defaults so each lambda keeps its own weights.
                mechanisms.append(
                    lambda x, is_parent, z, theta=w, bias=b: (x @ (theta * is_parent)) + bias + z)

            # One noise model per node, built once per graph after all weights.
            noise_models = [
                init_noise_dist_torch(rng=rng_torch,
                                      seed=seed,
                                      dim=int(g[:, j].sum().item()),
                                      dist=self.noise,
                                      noise_scale_constant=self.noise_scale_constant,
                                      noise_scale=self.noise_scale,
                                      noise_scale_heteroscedastic=self.noise_scale_heteroscedastic)
                for j in range(self.n_nodes)
            ]
            f_set.append(mechanisms)
            nse_set.append(noise_models)

        return graph_set, f_set, nse_set

    def sample_recursive_scm(self, rng, n_obs, n_int, g, f, nse,
                             int_target=None, int_value=None, bias_with_ancestor=False):
        """Draw ``n_int`` joint samples from one SCM under a soft intervention.

        Nodes are filled in the order returned by ``mat_to_toporder(g)``.
        Node ``j`` becomes ``int_target[j] * int_value[j] + (1 -
        int_target[j]) * f[j](...)``. ``int_target``/``int_value`` default to
        zeros, which gives plain observational sampling. ``n_obs`` and
        ``bias_with_ancestor`` are accepted for compatibility and unused.

        Returns:
            Tensor of shape ``(n_int, n_vars)`` on ``device``.
        """
        n_vars = g.shape[-1]
        if int_target is None:
            int_target = torch.zeros(n_vars, device=device)
        if int_value is None:
            int_value = torch.zeros(n_vars, device=device)
        toporder = mat_to_toporder(g)

        # Deliberately NOT a requires_grad leaf: writing into such a leaf in
        # place raises (on CPU `.to(device)` returns the same leaf). Gradients
        # from the policy still flow through the column assignments below.
        x_int = torch.zeros((n_int, n_vars), device=device)

        for j in toporder:
            is_parent = torch.from_numpy(g[:, j])
            z_j = nse[j](rng=rng, x=x_int, is_parent=is_parent).to(device)
            structural = f[j](x=x_int, z=z_j, is_parent=is_parent.to(device))
            weight = int_target[j]
            x_int[:, j] = weight * int_value[j] + (1 - weight) * structural

        return x_int

    def standardize(self, y_hist):
        """Standardize ``y_hist[..., 0]`` along dim -2; a zero std is treated as 1."""
        ref_y = y_hist[..., 0]

        mean = ref_y.mean(dim=-2, keepdim=True)
        std = ref_y.std(dim=-2, keepdim=True)
        ref_y = (ref_y - mean) / torch.where(std == 0.0, torch.tensor(1.0, dtype=std.dtype, device=device), std)

        return ref_y

    # ------------------------------------------------------------------ #
    # Rollouts
    # ------------------------------------------------------------------ #

    def _env_rng(self, idx_env):
        """Per-environment sampling generator derived from ``self.step``."""
        return torch.Generator().manual_seed(int(self.step * 106 + 1 + idx_env * 50417))

    def _observational_block(self, rng, g, f, nse):
        """Draw ``n_obs`` un-intervened samples as an ``(n_obs, n_nodes, 2)`` block.

        Channel 0 holds the values; channel 1 (intervention weights) is zero.
        """
        no_intervention = torch.zeros(self.n_nodes, device=device)
        x_obs = self.sample_recursive_scm(rng=rng, n_obs=0, n_int=self.n_obs, g=g, f=f, nse=nse,
                                          int_target=no_intervention, int_value=no_intervention,
                                          bias_with_ancestor=False)
        return torch.stack((x_obs, torch.zeros_like(x_obs)), dim=-1)

    @staticmethod
    def _graphs_to_tensor(graph_set):
        """Stack the sampled adjacency matrices into one tensor on ``device``."""
        return torch.from_numpy(np.asarray(graph_set)).to(device)

    def _rollout(self, policy, n_envs, graph_set, f_set, nse_set, is_training):
        """Roll ``policy`` out on each SCM and build the posterior-network input.

        At stage ``t``, ``policy(actor_input, is_training=..., step=...)``
        receives every sample seen so far, shape
        ``(n_obs + t * n_int, n_nodes, 2)``. It must return
        ``(int_target_vec, int_value_vec)``, assumed 1-D of length
        ``n_nodes``.

        Returns:
            ``(post_net_input, graphs)``. ``post_net_input`` has shape
            ``(n_envs, n_obs + n_stage * n_int, n_nodes, 2)``: channel 0
            holds sampled values and channel 1 the intervention-target
            weights. ``graphs`` stacks ``graph_set`` on ``device``.
        """
        n_rows = self.n_obs + self.n_stage * self.n_int
        post_net_input = torch.zeros((n_envs, n_rows, self.n_nodes, 2), device=device)

        for idx_env in range(n_envs):
            g, f, nse = graph_set[idx_env], f_set[idx_env], nse_set[idx_env]
            # Every policy starts each SCM from the same RNG state, and the
            # observational draw comes first.
            rng = self._env_rng(idx_env)
            obs = self._observational_block(rng, g, f, nse)

            y_history = torch.zeros((self.n_stage, self.n_int, self.n_nodes), device=device)
            d_history = torch.zeros_like(y_history)

            for t in range(self.n_stage):
                past = torch.stack((y_history[:t], d_history[:t]), dim=-1)
                actor_input = torch.cat([obs, past.reshape(t * self.n_int, self.n_nodes, 2)], dim=0)
                int_target_vec, int_value_vec = policy(actor_input, is_training=is_training, step=self.step)

                y_history[t] = self.sample_recursive_scm(rng=rng, n_obs=0, n_int=self.n_int,
                                                         g=g, f=f, nse=nse,
                                                         int_target=int_target_vec,
                                                         int_value=int_value_vec,
                                                         bias_with_ancestor=False)
                # All n_int rows of this stage share the same target weights.
                d_history[t] = int_target_vec.unsqueeze(0).expand(self.n_int, -1)

            stages = torch.stack((y_history, d_history), dim=-1)
            stages = stages.reshape(self.n_stage * self.n_int, self.n_nodes, 2)
            post_net_input[idx_env] = torch.cat([obs, stages], dim=0)

        return post_net_input, self._graphs_to_tensor(graph_set)

    def _random_policy(self, actor_input, is_training=False, step=None):
        """Intervene on one uniformly chosen node with uniform values.

        Values lie in ``[actor_network.min_val, actor_network.max_val)``.
        Randomness comes from torch's global RNG, which ``forward`` reseeds.
        """
        target = int(torch.randint(0, self.n_nodes, (1,)).item())
        int_target_vec = torch.zeros(self.n_nodes, device=device)
        int_target_vec[target] = 1
        low, high = self.actor_network.min_val, self.actor_network.max_val
        int_value_vec = ((high - low) * torch.rand(self.n_nodes) + low).to(device)
        return int_target_vec, int_value_vec

    def model(self, n_envs, if_eval=False):
        """Roll out the trainable actor (``is_training=True``) on fresh SCMs.

        SCMs are seeded with ``self.step``. ``if_eval`` is unused.
        """
        graph_set, f_set, nse_set = self.sample_graphs_and_weights(n_envs, seed=self.step)
        return self._rollout(self.actor_network, n_envs, graph_set, f_set, nse_set, is_training=True)

    def model_eval_trained_policy(self, n_envs, if_eval=False, graph_set=None, f_set=None, nse_set=None):
        """Roll out the trained actor (``is_training=False``) on the given SCMs."""
        return self._rollout(self.actor_network, n_envs, graph_set, f_set, nse_set, is_training=False)

    def model_eval_untrained_policy(self, n_envs, if_eval=False, graph_set=None, f_set=None, nse_set=None):
        """Roll out the untrained baseline policy on the given SCMs."""
        return self._rollout(self.actor_network_untrained, n_envs, graph_set, f_set, nse_set, is_training=False)

    def model_eval_random(self, n_envs, if_eval=False, graph_set=None, f_set=None, nse_set=None):
        """Roll out uniformly random single-node interventions on the given SCMs.

        Uses the same observational block and output layout as the learned
        policies, so all baselines give the posterior the same data budget.
        """
        return self._rollout(self._random_policy, n_envs, graph_set, f_set, nse_set, is_training=False)

    # ------------------------------------------------------------------ #
    # Optimization
    # ------------------------------------------------------------------ #

    def _is_decay_step(self, n_update_actor):
        """True when the LR schedulers should step at this actor-update count."""
        return (n_update_actor is not None
                and n_update_actor > 1
                and n_update_actor % self.lr_decay_every == 0)

    def _posterior_step(self, hist_y_d, graph, n_update=5, training_post=True, n_update_actor=None):
        """Return the posterior loss, optionally after updating the posterior.

        With ``training_post=True``, the posterior network takes ``n_update``
        clipped Adam steps on detached data, so no gradient reaches the actor.
        It returns the loss computed in the last iteration, before that
        iteration's optimizer step. With ``training_post=False``, the loss is
        computed once without detaching ``hist_y_d``, so the caller can
        backpropagate into the actor.

        Raises:
            ValueError: if ``training_post`` is set and ``n_update < 1``.
        """
        if not training_post:
            self.post_network.eval()
            self.actor_network.train()
            return self.post_network(hist_y_d, graph, False)[0]

        if n_update < 1:
            raise ValueError("n_update must be >= 1 when training_post=True")
        self.post_network.train()
        self.actor_network.eval()
        hist_y_d, graph = hist_y_d.detach(), graph.detach()
        for _ in range(n_update):
            self.post_optimizer.zero_grad()
            log_prob_post = self.post_network(hist_y_d, graph, True)[0]
            log_prob_post.backward()
            torch.nn.utils.clip_grad_norm_(self.post_network.parameters(), 5)
            self.post_optimizer.step()
        if self._is_decay_step(n_update_actor):
            self.post_scheduler.step()
        self.post_network.eval()
        return log_prob_post

    def train(self, hist_y_d=True, graph=None, n_envs=10, n_update=5, training_post=True, n_update_actor=None):
        """Posterior update step, or ``nn.Module.train(mode)`` when given a bool.

        This name overrides ``nn.Module.train``, and ``nn.Module.eval()``
        calls ``self.train(False)``. A bool first argument is therefore
        forwarded to the base class so ``.train()``/``.eval()`` keep working.
        Otherwise this behaves like ``_posterior_step``. ``n_envs`` is unused.
        """
        if isinstance(hist_y_d, bool):
            return super().train(hist_y_d)
        return self._posterior_step(hist_y_d, graph, n_update=n_update,
                                    training_post=training_post, n_update_actor=n_update_actor)

    def _evaluate(self, n_envs):
        """Posterior losses of trained, untrained and random policies on shared SCMs."""
        self.post_network.eval()
        self.actor_network.eval()
        # Without this the never-trained baseline runs with dropout enabled.
        self.actor_network_untrained.eval()
        losses = []
        with torch.no_grad():
            graph_set, f_set, nse_set = self.sample_graphs_and_weights(n_envs, seed=12789 * self.step + 121)
            for rollout in (self.model_eval_trained_policy,
                            self.model_eval_untrained_policy,
                            self.model_eval_random):
                y_hist, graph_torch = rollout(n_envs, if_eval=True, graph_set=graph_set,
                                              f_set=f_set, nse_set=nse_set)
                losses.append(self.post_network(y_hist.to(device), graph_torch.to(device), False)[0])
        return tuple(losses)

    def forward(self, n_envs=10, n_update_posterior=10, n_update_actor=None, if_eval=False):
        """Run one joint training step, or a three-way policy evaluation.

        Each call reseeds torch's global RNG with ``self.step + 1``.

        Training (``if_eval=False``):

        1. Roll out the actor on fresh SCMs.
        2. Take one actor step on the posterior loss.
        3. Take ``n_update_posterior`` posterior steps on the same detached
           data.
        4. Advance ``self.step``.

        Returns ``(posterior_loss, actor_loss)``. ``n_update_actor`` drives
        the LR schedule; ``None`` never decays.

        Evaluation (``if_eval=True``): returns the posterior losses of the
        trained, untrained and random policies, computed without gradients.
        ``self.step`` is not advanced.
        """
        torch.manual_seed(self.step + 1)
        if if_eval:
            return self._evaluate(n_envs)

        # Train mode must be set before the rollout for it to take effect.
        self.actor_network.train()
        self.post_network.eval()
        y_hist, graph_torch = self.model(n_envs)
        y_hist, graph_torch = y_hist.to(device), graph_torch.to(device)

        self.actor_optimizer.zero_grad()
        loss_actor = self._posterior_step(y_hist, graph_torch, n_update=n_update_posterior,
                                          training_post=False).mean()
        loss_actor.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_network.parameters(), 5)
        self.actor_optimizer.step()
        if self._is_decay_step(n_update_actor):
            self.actor_scheduler.step()
        self.actor_network.eval()

        log_prob_post = self._posterior_step(y_hist, graph_torch, n_update=n_update_posterior,
                                             training_post=True, n_update_actor=n_update_actor).mean()

        self.step += 1
        return (log_prob_post, loss_actor)

def save_checkpoint(state, filename):
    """Write ``state`` (typically a dict of state_dicts) to ``filename`` via ``torch.save``."""
    torch.save(obj=state, f=filename)

if __name__ == "__main__":
    # Anomaly detection makes autograd report the forward op behind a failing
    # backward pass; PyTorch documents it as a debugging aid that slows training.
    torch.autograd.set_detect_anomaly(True)
    parser = make_parser()
    kwargs = parser.parse_args()
    torch_manual_seed = kwargs.torch_manual_seed
    torch.manual_seed(torch_manual_seed)
    n_nodes = kwargs.n_nodes
    n_stage = kwargs.n_stage
    n_int = kwargs.n_int
    n_obs = kwargs.n_obs
    train_n_envs = kwargs.train_n_envs
    eval_n_envs = kwargs.eval_n_envs
    int_value = kwargs.int_value
    n_update_posterior = kwargs.n_update_posterior
    env_seed = kwargs.env_seed
    num_steps = kwargs.num_steps

    env_name = kwargs.env
    edge_prob = kwargs.edge_prob
    edge_per_var = math.ceil(edge_prob * n_nodes)
    # Graph prior per --env. An unknown name is reported as a usage error
    # instead of leaving `env` undefined.
    env_factories = {
        "er": lambda: ErdosRenyi(edge_per_var),
        "sf": lambda: ScaleFree(edge_per_var),
        "yeast": Yeast,
        "ecoli": Ecoli,
    }
    if env_name not in env_factories:
        parser.error(f"unknown --env {env_name!r}; expected one of {sorted(env_factories)}")
    env = env_factories[env_name]()

    actor_network = policy_opt(dropout=0.05, layers=4, dim=64, key_size=32, num_heads=8, widening_factor=4, input_shape=2).to(device)
    post_network = BaseModel_torch(layers=8, num_heads=8, key_size=64, dim=128, dropout=0.05, input_shape=2).to(device)

    # Same architecture as the actor; used only as an evaluation baseline.
    actor_network_untrained = policy_opt(dropout=0.05, layers=4, dim=64, key_size=32, num_heads=8, widening_factor=4, input_shape=2).to(device)

    seq_graph = Sequential_Graph(n_int=n_int, n_obs=n_obs, n_nodes=n_nodes, actor_network=actor_network,
                                 actor_network_untrained=actor_network_untrained, env_avici=env,
                                 post_network=post_network, n_stage=n_stage, int_value=int_value).to(device)

    actor_network.num_steps = num_steps + 100

    saved_dir_env = os.path.join(kwargs.save_dir, "linear_w_obs_" + env_name)
    saved_dir_node = os.path.join(
        saved_dir_env,
        "saved_dir_node_{}_env_seed_{}_torch_seed_{}".format(n_nodes, env_seed, torch_manual_seed))
    checkpoint_dir = os.path.join(
        saved_dir_node,
        "checkpoints_{}_{}_{}_{}_env_seed_{}_train_steps_{}_n_envs_{}".format(
            n_nodes, n_stage, n_int, n_update_posterior, env_seed, num_steps, train_n_envs))
    # Creates the parent directories as well.
    os.makedirs(checkpoint_dir, exist_ok=True)

    loss_set_trained = []
    loss_set_eval_trained = []
    loss_set_eval_untrained = []
    loss_set_eval_random = []
    for step in tqdm(range(num_steps)):
        log_kld, loss_actor = seq_graph.forward(n_envs=train_n_envs, n_update_posterior=n_update_posterior, n_update_actor=step)

        torch.cuda.empty_cache()

        if step % kwargs.eval_every == 0:
            log_kld_eval_trained, log_kld_eval_untrained, log_kld_eval_random = seq_graph.forward(n_envs=eval_n_envs, n_update_posterior=n_update_posterior, n_update_actor=step, if_eval=True)
            print(f"Evaluation: Step {step}, log_kld_eval_trained_policy: {log_kld_eval_trained}, log_kld_eval_untrained_policy: {log_kld_eval_untrained}, log_kld_eval_random: {log_kld_eval_random}", flush=True)
            loss_set_trained.append(loss_actor.detach().cpu().numpy())
            loss_set_eval_trained.append(log_kld_eval_trained.detach().cpu().numpy())
            loss_set_eval_untrained.append(log_kld_eval_untrained.detach().cpu().numpy())
            loss_set_eval_random.append(log_kld_eval_random.detach().cpu().numpy())

        if step % kwargs.save_every == 0:
            checkpoint = {
                'step': step,
                'actor_network_state_dict': seq_graph.actor_network.state_dict(),
                'post_network_state_dict': seq_graph.post_network.state_dict(),
                'actor_optimizer_state_dict': seq_graph.actor_optimizer.state_dict(),
                'post_optimizer_state_dict': seq_graph.post_optimizer.state_dict(),
                'actor_scheduler_state_dict': seq_graph.actor_scheduler.state_dict(),
                'post_scheduler_state_dict': seq_graph.post_scheduler.state_dict(),
                'log_kld': float(log_kld),
            }
            save_checkpoint(checkpoint, filename=os.path.join(checkpoint_dir, f"checkpoint_step_{step}.pth.tar"))

            torch.cuda.empty_cache()  # Clear cache after saving checkpoint

    trained_exp_results = {}
    trained_exp_results["loss_set_trained"] = loss_set_trained
    trained_exp_results["loss_set_eval_trained"] = loss_set_eval_trained
    trained_exp_results["loss_set_eval_untrained"] = loss_set_eval_untrained
    trained_exp_results["loss_set_eval_random"] = loss_set_eval_random

    results_trained_policy = []
    results_untrained_policy = []
    results_random_policy = []

    # Evaluation seeds its SCMs and sampling RNGs from `seq_graph.step`, so
    # each repeat gets its own step beyond the training range. Reusing one
    # step would replay identical data and make the reported std meaningless.
    n_final_eval = 20
    final_step = seq_graph.step
    for i in range(n_final_eval):
        seq_graph.step = final_step + 1 + i
        log_kld_eval_trained, log_kld_eval_untrained, log_kld_eval_random = seq_graph.forward(n_envs=eval_n_envs, n_update_posterior=n_update_posterior, n_update_actor=None, if_eval=True)

        results_trained_policy.append(log_kld_eval_trained.detach().cpu().numpy())
        results_untrained_policy.append(log_kld_eval_untrained.detach().cpu().numpy())
        results_random_policy.append(log_kld_eval_random.detach().cpu().numpy())
    seq_graph.step = final_step

    print("--------------------Results------------------------")
    print(f"Mean results of trained policy: {np.mean(results_trained_policy)}, Std results of trained policy: {np.std(results_trained_policy)}")
    print(f"Mean results of untrained policy: {np.mean(results_untrained_policy)}, Std results of untrained policy: {np.std(results_untrained_policy)}")
    print(f"Mean results of random interventions: {np.mean(results_random_policy)}, Std results of random interventions: {np.std(results_random_policy)}")

    # Keep the per-repeat final results alongside the training curves.
    trained_exp_results["results_trained_policy"] = results_trained_policy
    trained_exp_results["results_untrained_policy"] = results_untrained_policy
    trained_exp_results["results_random_policy"] = results_random_policy

    with open(os.path.join(checkpoint_dir, "trained_exp_results.pkl"), "wb") as f:
        pickle.dump(trained_exp_results, f)



