import warnings
import functools
import numpy as np
import math
import os
import torch
import warnings
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
from tqdm.auto import tqdm 
import pyro
pyro.clear_param_store()
import torch.nn as nn
import torch.nn.functional as F 
from utils import mat_to_toporder, acyclic_constr_nograd
from synthetic import ErdosRenyi, ScaleFree, Yeast, Ecoli 
from synthetic.distributions import Gaussian_torch 
from models import policy_opt, Post_sampling, FCNN, NFs 
import argparse

# Pick the first CUDA device when one is visible, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

def parse_list(arg):
    """Parse a command-line list of integers such as ``"[1,2]"`` or ``"1, 2"``.

    Surrounding whitespace and brackets are optional; items are separated by
    commas and may be padded with spaces.

    Raises:
        argparse.ArgumentTypeError: if any item is not an integer, so argparse
            reports a clean usage error instead of a traceback.
    """
    try:
        # Strip whitespace first so " [1,2] " is handled, then the brackets.
        return [int(x) for x in arg.strip().strip('[]').split(',')]
    except (AttributeError, TypeError, ValueError) as e:
        raise argparse.ArgumentTypeError(f"Failed to parse list: {e}") from e


def make_parser():
    """Build the command-line parser for the nonlinear goal-prediction experiment.

    Returns:
        argparse.ArgumentParser: parser whose defaults reproduce the reference
        configuration (10 nodes, 5 stages of 5 interventions, goal nodes [1, 2]).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="er", type=str, help="Environment name")
    parser.add_argument("--edge_prob", default=0.2, type=float,
                        help="edge density; the ER/SF generators get ceil(edge_prob * n_nodes) edges per variable")
    parser.add_argument("--num_steps", default=200, type=int, help="total number of steps")
    parser.add_argument("--n_nodes", default=10, type=int, help="number of nodes")
    parser.add_argument("--n_stage", default=5, type=int, help="number of stages for intervention")
    parser.add_argument("--torch_manual_seed", default=123, type=int, help="torch manual seed")
    parser.add_argument("--env_seed", default=283956, type=int, help="seed for environment generation in data loading")
    parser.add_argument("--do_node", default=0, type=int, help="the node to do intervention")
    # parse_list turns "[1,2]" into [1, 2]; the list default is used as-is by argparse.
    parser.add_argument("--goal_nodes", type=parse_list, default=[1,2], help="the goal nodes (format: [n1,n2,...])")
    parser.add_argument("--do_mean", type=int, default=0, help="the mean of the do distribution")
    parser.add_argument("--do_std", type=float, default=1.0, help="the std of the do distribution")
    parser.add_argument("--n_int", default=5, type=int, help="number of interventions per stage")
    parser.add_argument("--n_obs", default=50, type=int, help="number of observational samples (used for Post_sampling x_obs)")
    parser.add_argument("--train_n_envs", default=10, type=int, help="number of envs in each training step")
    parser.add_argument("--eval_n_envs", default=10, type=int, help="number of envs for evaluation")
    parser.add_argument("--n_update_posterior", default=5, type=int, help="number of updates of posterior network")
    parser.add_argument("--eval_every", default=5, type=int, help="evaluate after every n steps (used for print/log)")
    parser.add_argument("--save_every", default=50, type=int, help="save model after every n steps")
    parser.add_argument("--save_dir", default="results/nonlinear_goal_experiments",
                        type=str, help="save directory of the model and results")
    parser.add_argument("--posterior_data_dir", default="data/nonlinear_posterior_samples", type=str,
                        help="Directory containing precomputed posterior samples (p1.pt, posterior_weights.pt, torch_obs_data.pt)")

    return parser


class Pred_Oriented_Policy( Post_sampling ):
    """Sequential intervention-design policy for predicting goal nodes.

    Wraps a fitted ``Post_sampling`` instance, which provides a posterior over
    per-node parent sets and over per-node neural-network mechanisms. Two
    networks are trained on environments sampled from that posterior:

    * ``actor_network`` proposes, at every stage, per-node intervention
      weights and values. The environment is then simulated under that design.
    * ``post_network`` returns log-probabilities of the (standardized)
      goal-node values of a separate ``do`` sample, given the simulated
      history.

    The actor is trained to maximize the mean of that log-probability. The
    posterior network is trained to minimize its negative mean.

    NOTE(review): ``Post_sampling`` appears to be a ``torch.nn.Module`` (the
    script calls ``.to(device)`` on it). If so, ``train`` below overrides
    ``nn.Module.train(mode)``, and calling ``.train()``/``.eval()`` on this
    object would route into the custom method. TODO confirm.
    """

    def __init__(self, post_sampling_instance, hidden_size_1, hidden_size_2, activation = nn.ReLU(),
                       do_node = None, do_distribution = None,
                       goal_node = None, n_int = 5, n_stage = 10,
                       actor_network = None, post_network = None,
                       post_lr = 1e-4, post_gamma = 0.8,
                       actor_lr = 1e-4, actor_gamma = 0.8):
        """Copy the fitted posterior and set up both networks with their optimizers.

        Args:
            post_sampling_instance: fitted ``Post_sampling`` object. Its
                ``n_nodes``, ``post_prob``, ``num_simulation``, ``x_obs``,
                ``nodes_unique_parents``, ``posterior_weights`` and
                ``graph_probabilities`` are reused.
            hidden_size_1, hidden_size_2: hidden widths of every node's
                mechanism MLP (see ``draw_nn_params``).
            activation: activation used inside the mechanism MLPs.
            do_node: index of the node set by ``do_distribution`` in
                ``sample_goal``.
            do_distribution: callable ``(rng, shape=...)`` returning the
                ``do`` values.
            goal_node: index, or list of indices, of the nodes to predict.
            n_int: number of samples simulated per stage.
            n_stage: number of sequential design stages.
            actor_network, post_network: required in practice, because their
                ``parameters()`` are handed to Adam here.
            post_lr, post_gamma, actor_lr, actor_gamma: Adam learning rates
                and ``ExponentialLR`` decay factors.

        Construction also runs ``compute_normalization_stats(n_samples=10000)``.
        """
        super().__init__(n_nodes  = post_sampling_instance.n_nodes ,
                        post_prob = post_sampling_instance.post_prob,
                        num_simulation = post_sampling_instance.num_simulation,
                        x_obs = post_sampling_instance.x_obs)

        self.nodes_unique_parents = post_sampling_instance.nodes_unique_parents
        self.posterior_weights    = post_sampling_instance.posterior_weights
        self.graph_probabilities  = post_sampling_instance.graph_probabilities

        self.hidden_size_1 = hidden_size_1
        self.hidden_size_2 = hidden_size_2
        self.activation    = activation

        self.n_int           = n_int
        self.n_stage         = n_stage

        self.actor_network   = actor_network
        self.actor_optimizer = torch.optim.Adam(self.actor_network.parameters(), lr = actor_lr)
        self.actor_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.actor_optimizer, gamma = actor_gamma)

        self.post_network    = post_network
        self.post_optimizer  = torch.optim.Adam(self.post_network.parameters(), lr = post_lr)
        self.post_scheduler  = torch.optim.lr_scheduler.ExponentialLR(self.post_optimizer, gamma = post_gamma)

        # Training-step counter; mixed into the per-environment generator seeds in ``model``.
        self.step  = 0

        self.do_node         = do_node
        self.do_distribution = do_distribution
        self.goal_node       = goal_node

        # Fixed statistics used by ``forward`` to standardize the goal columns.
        self.z_mean_global, self.z_std_global = self.compute_normalization_stats(n_samples=10000)

    @staticmethod
    def _lr_decay_due(n_update):
        """Return True when a scheduler step is due (``n_update`` = 1000, 2000, ...).

        ``None`` (the default of ``train``/``forward``) never triggers a step.
        """
        return n_update is not None and n_update > 1 and n_update % 1000 == 0

    def compute_normalization_stats(self, n_samples=10000):
        """Estimate the mean and std of the goal nodes under the ``do`` distribution.

        Each of the ``n_samples`` draws uses a freshly sampled graph and set of
        mechanisms (``simulate_dag_and_weights(n_envs=1)``) and one
        ``sample_goal`` call.

        Returns:
            ``(mean, std)``: the per-node statistics over all draws, with a
            leading axis of size 1, indexed by ``self.goal_node``. The shape is
            ``(1, len(goal_node))`` when ``goal_node`` is a list. Both are
            detached.

        NOTE(review): this re-seeds the *global* NumPy and torch RNGs with a
        fixed constant. Every random draw made after construction therefore
        follows the same stream, whatever seed was set earlier. Confirm this
        is intended.
        """
        np.random.seed(528717506)
        torch.manual_seed(528717506)
        z = torch.zeros((n_samples, self.n_nodes), requires_grad=False).to(device)

        # Only the (detached) statistics are needed. Running without autograd
        # avoids building a single graph that spans every sampled network.
        with torch.no_grad():
            for i in tqdm(range(n_samples)):
                rng = torch.Generator().manual_seed(int(i*999))
                graph_set, f_set = self.simulate_dag_and_weights(n_envs = 1, is_training = False)
                z[i] = self.sample_goal(rng = rng, g = graph_set[0].float(), f = f_set[0])

        z_mean_global = z.mean(axis =0).unsqueeze(0)[:, self.goal_node]
        z_std_global = z.std(axis =0).unsqueeze(0)[:, self.goal_node]

        return z_mean_global.detach(), z_std_global.detach()

    @staticmethod
    def _nn_mechanism(*, x, z, parents, fc1, fc2, fc3, activation):
        """Evaluate one node's mechanism: a 3-layer MLP on its masked parents plus noise.

        ``x * parents`` zeroes every non-parent column before the MLP. The
        final layer's size-1 output axis is squeezed, and ``z`` is added.
        """
        x_msk   = x * parents
        hidden1 = activation(fc1(x_msk))
        hidden2 = activation(fc2(hidden1))
        output  = fc3(hidden2).squeeze(1) + z
        return output

    def draw_nn_params(self, *, seed, param_samples, is_training):
        """Draw one random instantiation of a node's mechanism MLP.

        Builds three ``nn.Linear`` layers (``n_nodes -> hidden_size_1 ->
        hidden_size_2 -> 1``). It then replaces each weight and bias with a
        sample from ``Normal(mean, scale)``, using the
        ``layer{k}_{weights,bias}_{mean,scale}`` entries of ``param_samples``.

        ``seed`` and ``is_training`` are accepted but unused. The layers'
        default initialization still runs, and so consumes the torch RNG,
        before it is overwritten.
        """
        fc1 = nn.Linear(self.n_nodes, self.hidden_size_1).to(device)
        fc2 = nn.Linear(self.hidden_size_1, self.hidden_size_2).to(device)
        fc3 = nn.Linear(self.hidden_size_2, 1).to(device)

        # Same draw order as layer-by-layer: weight then bias for fc1, fc2, fc3.
        for k, layer in enumerate((fc1, fc2, fc3), start=1):
            layer.weight.data = torch.normal(mean=param_samples[f'layer{k}_weights_mean'],
                                             std=param_samples[f'layer{k}_weights_scale']).to(device)
            layer.bias.data   = torch.normal(mean=param_samples[f'layer{k}_bias_mean'],
                                             std=param_samples[f'layer{k}_bias_scale']).to(device)

        return dict( fc1=fc1,  fc2=fc2, fc3 = fc3, activation=self.activation )

    def simulate_dag_and_weights(self, n_envs, is_training = True):
        """Sample ``n_envs`` DAGs and matching mechanisms from the posterior.

        For every node, a parent set is drawn from the keys of
        ``self.nodes_unique_parents[node]`` with the cached per-node
        probabilities. The whole graph is redrawn until
        ``acyclic_constr_nograd`` returns 0 (rejection sampling). For each
        accepted graph, one mechanism per node is instantiated from
        ``self.posterior_weights[node][parent_set]``.

        Returns:
            ``(graphs, f_set)``: ``graphs`` stacks the ``n_envs`` adjacency
            matrices, where column ``j`` holds node ``j``'s parents. ``f_set``
            is a list of ``n_envs`` lists of per-node mechanism callables.
        """
        assert not any(item is None for item in self.posterior_weights), "The list should not have None values!"

        graph_set = []
        f_set     = []

        if self.graph_probabilities is None:
            # Cache each node's parent-set probabilities, in the key order of
            # ``nodes_unique_parents``, for reuse on every call.
            self.graph_probabilities = [
                np.array([parent_structures[ps]['probability'] for ps in list(parent_structures.keys())])
                for parent_structures in self.nodes_unique_parents
            ]

        for _ in range(n_envs):
            is_dag = False
            while not is_dag:
                simulated_graph = np.zeros((self.n_nodes, self.n_nodes))
                for node_id, parent_structures in enumerate(self.nodes_unique_parents):
                    parents         = list(parent_structures.keys())
                    sampled_parents = parents[np.random.choice(len(parents), p=self.graph_probabilities[node_id])]
                    simulated_graph[:, node_id] = sampled_parents
                is_dag = acyclic_constr_nograd(simulated_graph, simulated_graph.shape[-1]) == 0
            graph_set.append( torch.from_numpy(simulated_graph).to(device) )

            f = []
            for node_id in range(self.n_nodes):
                param_samples = self.posterior_weights[node_id][tuple(simulated_graph[:, node_id].tolist())]
                nn_params = self.draw_nn_params( seed=int(self.step + n_envs  * node_id + 234567) , param_samples=param_samples, is_training = is_training)
                f.append(functools.partial(Pred_Oriented_Policy._nn_mechanism, **nn_params))
            f_set.append( f )

        return torch.stack(graph_set).to(device), f_set

    def sample_recursive_scm_pertubation(self, rng, g, f, int_target = None,  int_value = None):
        """Simulate ``self.n_int`` samples from the SCM under a soft intervention.

        Nodes are filled in the topological order given by
        ``mat_to_toporder(g)``. Node ``j`` is set to
        ``p_j * v_j + (1 - p_j) * f_j(x)``, with ``p = int_target`` and
        ``v = int_value``. Each mechanism gets additive noise
        ``randn * sqrt(0.1)``.

        When ``int_target`` is None, no node is intervened on and any
        ``int_value`` passed is ignored. ``rng`` is unused.

        Returns:
            Tensor of shape ``(n_int, n_vars)``.
        """
        n_vars   = g.shape[-1]
        toporder = mat_to_toporder(g)

        if int_target is None:
            int_target = torch.zeros( n_vars ).to(device)
            int_value  = torch.zeros( n_vars ).to(device)

        # Allocate directly on ``device`` without requires_grad. On CPU,
        # ``.to(device)`` would return the same leaf, and a requires_grad leaf
        # cannot be written in place below. Gradients still flow through the
        # assigned values.
        x_int = torch.zeros((self.n_int, n_vars), device=device)

        for j in toporder:
            int_percentage = int_target[j]
            z_j            = torch.randn( self.n_int ).to(device) * 0.1 ** 0.5
            raw_contrib    = f[j](x=x_int, z=z_j, parents = g[:, j].to(device))
            x_int[:, j]    = int_percentage * int_value[j] + (1-int_percentage) * raw_contrib

        return x_int

    def sample_goal(self, rng, g, f ):
        """Draw one joint sample with ``do_node`` set from ``do_distribution``.

        The remaining nodes follow their mechanisms in topological order, with
        additive noise ``randn * sqrt(0.1)``. ``rng`` is passed only to
        ``do_distribution``.

        Returns:
            Tensor of shape ``(1, n_vars)`` holding all nodes; callers select
            ``goal_node``.
        """
        n_vars   = g.shape[-1]
        toporder = mat_to_toporder(g)
        # See ``sample_recursive_scm_pertubation``: a requires_grad leaf here
        # breaks the in-place writes on CPU.
        x_pred   = torch.zeros((1, n_vars), device=device)

        for j in toporder:
            if j == self.do_node:
                x_pred[:, self.do_node] = self.do_distribution(rng, shape=(1,)).to(device)
            else:
                z_j          = torch.randn( 1 ).to(device) * 0.1**0.5
                raw_contrib  = f[j](x=x_pred, z=z_j, parents = g[:, j].to(device))
                x_pred[:, j] = raw_contrib

        return x_pred

    def model(self, n_envs, is_training = True):
        """Roll out the actor's sequential design in ``n_envs`` sampled environments.

        At each of ``n_stage`` stages, the actor sees a detached copy of the
        history so far. It returns per-node intervention weights and values,
        which drive ``sample_recursive_scm_pertubation``. After the last
        stage, one ``sample_goal`` draw is taken per environment.

        Returns:
            ``(history, z_full)``:

            * ``history`` has shape ``(n_envs, n_stage * n_int, n_nodes, 2)``.
              Index 0 of the last axis holds the simulated values; index 1
              holds the intervention weights applied.
            * ``z_full`` has shape ``(n_envs, n_nodes)``.
        """
        graph_set, f_set = self.simulate_dag_and_weights(n_envs = n_envs, is_training = is_training)
        y_hist_full      = torch.zeros(( n_envs, self.n_stage, self.n_int, self.n_nodes, 2)).to(device)
        z_full           = torch.zeros(( n_envs, self.n_nodes)).to(device)

        for idx_env in range(n_envs):
            rng       = torch.Generator().manual_seed(int(self.step  * 106 + 1 + idx_env * 50417 ))
            y_history = torch.zeros(( self.n_stage , self.n_int, self.n_nodes )).to(device)
            d_history = torch.zeros(( self.n_stage , self.n_int, self.n_nodes )).to(device)
            X         = torch.zeros(  self.n_stage,  self.n_int, self.n_nodes, 2).to(device)

            for t in range( self.n_stage ):
                # History up to and including stage t (row t is still zero here).
                y_hist_input = y_history.clone()[ :int(t+1), :, :]
                d_hist_input = d_history.clone()[ :int(t+1), :, :]   # t+1, n_int, n_nodes
                X[:int(t+1), :, :, :] = torch.stack((y_hist_input, d_hist_input), dim = -1)

                # Detached, so actor gradients do not flow back through earlier stages' inputs.
                actor_input_with_stage = X.clone().detach()[ : int(t+1), ... ]
                # (t+1, n_int, n_nodes, 2) -> ((t+1) * n_int, n_nodes, 2)
                act_input = actor_input_with_stage.permute(2, 0, 1, 3).reshape(self.n_nodes, -1, 2).moveaxis(0, 1)

                int_target_vec, int_value_vec = self.actor_network(act_input, is_training = is_training, step = self.step )
                mask_tiled = int_target_vec.unsqueeze(1).tile(1, self.n_int)    # n_node, n_int
                d_history[ t, :, : ] = mask_tiled.moveaxis(0, 1)

                x_int = self.sample_recursive_scm_pertubation( rng        = rng,
                                                               g          = graph_set[idx_env].float(),
                                                               f          = f_set[idx_env],
                                                               int_target = int_target_vec,
                                                               int_value  = int_value_vec)
                y_history[ t, :, : ] = x_int

            z_full[idx_env, :] = self.sample_goal(rng = rng,
                                                  g   = graph_set[idx_env].float(),
                                                  f   = f_set[idx_env] )

            y_hist_full[ idx_env, ...,  0 ] = y_history
            y_hist_full[ idx_env, ...,  1 ] = d_history

        post_net_input = y_hist_full.reshape(n_envs, -1,  self.n_nodes, 2)
        return post_net_input, z_full

    def train(self, hist_y_d, z_full, n_envs = 10, n_update_posterior = 5, training_post = True, n_update = None):
        """Posterior-network update, or the actor objective.

        With ``training_post=True``, the posterior network takes
        ``n_update_posterior`` Adam steps, with gradient clipping to norm 1.
        Each step minimizes the mean negative log-probability of the goal
        columns. Inputs are detached, so only ``post_network`` is updated. The
        posterior scheduler steps when ``n_update`` is a positive multiple of
        1000. The log-probabilities of the last step are returned.

        With ``training_post=False``, the posterior network is only evaluated
        with the graph kept, and the mean log-probability is returned so the
        caller can back-propagate into the actor.

        ``n_envs`` is unused.

        Raises:
            ValueError: if ``training_post`` is set and ``n_update_posterior < 1``,
                because there would be no result to return.
        """
        if training_post:
            if n_update_posterior < 1:
                raise ValueError("n_update_posterior must be >= 1 when training the posterior")
            self.post_network.train()
            self.actor_network.eval()
            for _ in range(n_update_posterior):
                self.post_optimizer.zero_grad()
                log_prob_post = self.post_network( hist_y_d[ :, : , :, :].detach(), z_full[:, self.goal_node].detach(),
                                                   is_training = training_post,
                                                   thetas_std_global = self.z_std_global)
                loss = torch.mean(-log_prob_post)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.post_network.parameters(), 1)
                self.post_optimizer.step()
            if self._lr_decay_due(n_update):
                self.post_scheduler.step()
            self.post_network.eval()
        else:
            self.post_network.eval()
            self.actor_network.train()
            log_prob_post = self.post_network(hist_y_d, z_full[:, self.goal_node], is_training = training_post,
                                              thetas_std_global = self.z_std_global).mean()

        return log_prob_post

    def forward(self, n_envs = 10, n_update_posterior = 10 , n_update = None, is_training = True):
        """Simulate designs, then either evaluate or take one training step.

        The goal columns of ``z_full`` are standardized in place with the
        statistics from ``compute_normalization_stats``.

        * Evaluation (``is_training`` is anything but ``True``): returns the
          mean posterior log-probability without any update.
        * Training: one actor step maximizing that mean log-probability, with
          gradient clipping to norm 1. This is followed by the posterior
          update of ``train(training_post=True)``. Both schedulers step only
          when ``n_update`` is a positive multiple of 1000.
          ``self.step`` is incremented.

        Returns:
            ``(log_prob_post, z_full, history)``.
        """
        y_hist, z_full = self.model(n_envs = n_envs, is_training = is_training)

        z_full[:, self.goal_node] = (z_full[:, self.goal_node] - self.z_mean_global)/self.z_std_global
        y_hist_reshape = y_hist

        if is_training is not True:
            self.post_network.eval()
            self.actor_network.eval()

            log_prob_post = self.post_network(y_hist_reshape, z_full[:, self.goal_node],  is_training = is_training,
                                              thetas_std_global = self.z_std_global).mean() # log q

            self.actor_network.eval()
            self.post_network.eval()
        else:
            self.actor_network.train()
            self.post_network.eval()
            self.actor_optimizer.zero_grad()

            # Actor step: minimize the negative posterior log-probability.
            log_prob_post = self.train(hist_y_d = y_hist_reshape.to(device), z_full = z_full.to(device),
                                       n_update_posterior = n_update_posterior, training_post = False).mean()
            loss_actor = - log_prob_post

            loss_actor.backward()
            torch.nn.utils.clip_grad_norm_(self.actor_network.parameters(), 1)
            self.actor_optimizer.step()

            if self._lr_decay_due(n_update):
                self.actor_scheduler.step()
            self.actor_network.eval()

            # Posterior step(s) on the same simulated batch.
            log_prob_post = self.train(hist_y_d = y_hist_reshape.to(device), z_full = z_full.to(device),
                                       n_update_posterior = n_update_posterior, training_post = True, n_update = n_update).mean()

            self.step += 1

        return log_prob_post, z_full, y_hist_reshape
  
def save_checkpoint(state, filename):
    """Serialize the checkpoint dictionary ``state`` to the path ``filename`` via ``torch.save``."""
    payload, target_path = state, filename
    torch.save(obj=payload, f=target_path)



if __name__ == "__main__":
    # Debug aid: autograd errors report the forward op that caused them (slows training).
    torch.autograd.set_detect_anomaly(True)
    kwargs = make_parser().parse_args()
    torch_manual_seed = kwargs.torch_manual_seed
    torch.manual_seed(torch_manual_seed)
    n_nodes = kwargs.n_nodes
    n_stage = kwargs.n_stage
    n_int = kwargs.n_int
    n_obs = kwargs.n_obs            # NOTE(review): parsed but not used below
    train_n_envs = kwargs.train_n_envs
    eval_n_envs = kwargs.eval_n_envs
    n_update_posterior = kwargs.n_update_posterior
    env_seed = kwargs.env_seed      # NOTE(review): parsed but not used below
    do_node = kwargs.do_node
    goal_nodes = kwargs.goal_nodes
    do_mean = int(kwargs.do_mean)
    do_std = float(kwargs.do_std)

    env_name = kwargs.env
    edge_prob = kwargs.edge_prob
    edge_per_var = math.ceil(edge_prob * n_nodes)
    # NOTE(review): ``env`` is not referenced after this point (the posterior
    # is loaded from disk) and stays undefined for unrecognized names.
    if env_name == "er":
        env = ErdosRenyi(edge_per_var)
    elif env_name == "sf":
        env = ScaleFree(edge_per_var)
    elif env_name == "yeast":
        env = Yeast()
    elif env_name == "ecoli":
        env = Ecoli()

    # Precomputed posterior artifacts.
    saved_post_dir = kwargs.posterior_data_dir
    post_weights_dir = os.path.join(saved_post_dir, "posterior_weights")
    p1 = torch.load(os.path.join(saved_post_dir, "p1.pt"))
    torch_obs_data = torch.load(os.path.join(saved_post_dir, "torch_obs_data.pt"))

    actor = policy_opt(dropout=0.05, layers=4, dim=32, key_size=16,
                       num_heads=8, widening_factor=4, input_shape = 2).to(device)

    # n_input = n_nodes * dim, so it must follow --n_nodes.
    # NOTE(review): n_theta=2 with Split1=[0] / Split2=[1] appears to assume
    # two goal nodes (the default --goal_nodes). TODO confirm.
    post_dim = 16
    post = NFs( input_shape = 2,  network = FCNN, Split1 = [ 0 ], Split2 = [1 ],
                num_trans = 4,    n_input = n_nodes * post_dim,  n_theta = 2, dim = post_dim).to(device)

    post_sampling = Post_sampling(n_nodes = n_nodes, post_prob= p1, num_simulation=200, x_obs=torch_obs_data[...,0].float().cpu(),
                                  warm_samples = 200, post_samples = 200, learning_rate = 5e-2, num_epochs = 30000,
                                  prior_scale = 1, split_ratio = 4/5)
    for i in range(n_nodes):
        post_sampling.posterior_weights[i] = torch.load(os.path.join(post_weights_dir, f"node_{i}.pt"))

    post_sampling = post_sampling.to(device)

    pop = Pred_Oriented_Policy(post_sampling_instance = post_sampling, hidden_size_1 = 8, hidden_size_2 = 8,
                               do_node   = do_node, do_distribution =  Gaussian_torch(do_mean,  do_std),
                               goal_node = goal_nodes,
                               n_int = n_int,
                               n_stage = n_stage,
                               actor_network = actor, post_network = post,
                               post_lr  = 5e-4, post_gamma = 0.8,
                               actor_lr = 5e-4, actor_gamma = 0.8)

    save_dir = kwargs.save_dir
    checkpoint_dir = os.path.join(save_dir, f"trained_policy_seed_{torch_manual_seed}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    num_steps = kwargs.num_steps
    for step in tqdm(range(num_steps)):
        log_kld, _, yd_hist = pop.forward(n_envs = train_n_envs, n_update_posterior = n_update_posterior,
                                          n_update = step, is_training = True)
        torch.cuda.empty_cache()

        # Saving is independent of the eval schedule, so --save_every is honored
        # even when it is not a multiple of --eval_every.
        if step % kwargs.save_every == 0:
            checkpoint = {
                'step': step,
                'actor_network_state_dict': pop.actor_network.state_dict(),
                'post_network_state_dict':  pop.post_network.state_dict(),
                'actor_optimizer_state_dict': pop.actor_optimizer.state_dict(),
                'post_optimizer_state_dict':  pop.post_optimizer.state_dict(),
                'actor_scheduler_state_dict': pop.actor_scheduler.state_dict(),
                'post_scheduler_state_dict':  pop.post_scheduler.state_dict(),
                'log_kld': float(log_kld),
            }
            save_checkpoint(checkpoint, filename=os.path.join(checkpoint_dir, f"nonlinear_goal_ckpt_step_{step}.pth.tar"))

        if step % kwargs.eval_every == 0:
            torch.cuda.empty_cache()

            log_kld, _, yd_hist  = pop.forward(n_envs = eval_n_envs, n_update_posterior = 1,
                                               n_update = step, is_training = False)

            print(f"Step {step}, log_kld: {log_kld}", flush=True)

            del log_kld, _, yd_hist

