# filename: codebase/merger_trees_processing.py
import torch
import numpy as np
import os
from collections import defaultdict
from torch_geometric.data import Data

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Serialized dataset of merger-tree graphs, read with torch.load in main().
F_TREE = '/mnt/home/fanonymous/public_www/Pablo_Bermejo/Pablo_merger_trees2.pt'

# Where the processed train/val/test splits are written.
OUTPUT_DIR = 'data'
PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, 'processed_merger_trees.pt')

# Substructure selection: within each tree, a secondary progenitor is kept
# when its log10 mass ratio is at or above this percentile of that tree's ratios.
SUBSTRUCTURE_CONFIG = dict(
    mass_ratio_percentile_threshold=20.0,
)

RANDOM_SEED = 42  # Seeds np.random before shuffling simulations into splits

# Module-level flag so the edge-direction warning is printed at most once.
edge_direction_warning_printed = False

# Helper function to build adjacency lists
def _build_adj_lists(data_obj):
    """
    Builds predecessor and successor maps for a graph.
    Assumes edge_index[0] are progenitors (source) and edge_index[1] are descendants (target),
    meaning scale_factor[source] <= scale_factor[target].
    """
    global edge_direction_warning_printed
    predecessors_map = defaultdict(list)
    successors_map = defaultdict(list)
    
    if data_obj.edge_index.shape[1] > 0 and not edge_direction_warning_printed:
        num_edges_to_check = min(10, data_obj.edge_index.shape[1])
        sample_edge_indices = torch.randperm(data_obj.edge_index.shape[1])[:num_edges_to_check]
        for i in sample_edge_indices:
            src, tgt = data_obj.edge_index[:, i]
            if data_obj.x[src, 3] > data_obj.x[tgt, 3]:
                print("Warning: Edge direction assumption (SF_source <= SF_target) might be violated.")
                edge_direction_warning_printed = True # Print only once
                break 

    for i in range(data_obj.edge_index.shape[1]):
        source_node, target_node = data_obj.edge_index[0, i].item(), data_obj.edge_index[1, i].item()
        successors_map[source_node].append(target_node)
        predecessors_map[target_node].append(source_node)
    return predecessors_map, successors_map

# Helper function to get all nodes and edges in a branch backwards from a root_node
def _get_branch_info(data_obj, root_node_idx, predecessors_map):
    """
    Performs a backward traversal (DFS) from root_node_idx to find all nodes in its branch.
    Extracts edges within this branch and computes intrinsic features.
    """
    branch_nodes = set()
    queue = [root_node_idx]
    head = 0
    while head < len(queue):
        curr_node = queue[head]
        head += 1
        if curr_node in branch_nodes:
            continue
        branch_nodes.add(curr_node)
        for progenitor in predecessors_map.get(curr_node, []):
            if progenitor not in branch_nodes:
                queue.append(progenitor)
    
    branch_nodes_list = sorted(list(branch_nodes)) 
    
    branch_edge_list_for_tensor = []
    node_to_new_idx = {old_idx: new_idx for new_idx, old_idx in enumerate(branch_nodes_list)}
    
    for u_original_idx, v_original_idx in data_obj.edge_index.T.tolist():
        if u_original_idx in branch_nodes and v_original_idx in branch_nodes:
            branch_edge_list_for_tensor.append([node_to_new_idx[u_original_idx], node_to_new_idx[v_original_idx]])

    if not branch_edge_list_for_tensor:
        branch_edge_index = torch.empty((2, 0), dtype=torch.long)
    else:
        branch_edge_index = torch.tensor(branch_edge_list_for_tensor, dtype=torch.long).T
    
    branch_nodes_tensor = torch.tensor(branch_nodes_list, dtype=torch.long)
    
    # Use original unnormalized features for scale factor stats, normalized for conc/vmax means
    if len(branch_nodes_tensor) > 0:
        mean_norm_log_conc_branch = torch.mean(data_obj.x_norm[branch_nodes_tensor, 1])
        mean_norm_log_vmax_branch = torch.mean(data_obj.x_norm[branch_nodes_tensor, 2])
        all_sfs_branch = data_obj.x[branch_nodes_tensor, 3]
        min_sf_branch = torch.min(all_sfs_branch) if all_sfs_branch.numel() > 0 else data_obj.x[root_node_idx, 3]
    else: # Should not happen if root_node_idx is always in branch_nodes
        mean_norm_log_conc_branch = torch.tensor(0.0)
        mean_norm_log_vmax_branch = torch.tensor(0.0)
        min_sf_branch = data_obj.x[root_node_idx, 3]


    intrinsic_stats = {
        'mean_norm_log_conc_branch': mean_norm_log_conc_branch.item(),
        'mean_norm_log_vmax_branch': mean_norm_log_vmax_branch.item(),
        'min_sf_branch': min_sf_branch.item(),
        'num_halos_in_branch': len(branch_nodes_list)
    }
    return branch_nodes_list, branch_edge_index, intrinsic_stats


def extract_substructures_for_tree(data_obj, config, predecessors_map):
    """
    Identify merger substructures in one tree and store them on ``data_obj.substructures``.

    A merger is any node with two or more progenitors. Progenitors are ranked
    by log10 mass (``x[:, 0]``), heaviest first. The heaviest is the main
    progenitor; every other one is a candidate with log10 mass ratio
    ``x[p_sub, 0] - x[p_main, 0]`` (<= 0). Candidates whose ratio is at or above
    the ``config['mass_ratio_percentile_threshold']`` percentile of all ratios in
    this tree are kept.

    Each kept candidate becomes a ``Data`` object holding the normalized
    features and edges of the secondary progenitor's branch, plus a 10-element
    ``physical_features`` vector.

    Args:
        data_obj: tree graph with ``x``, ``x_norm``, ``edge_index`` and ``num_nodes``.
        config: dict providing ``'mass_ratio_percentile_threshold'``.
        predecessors_map: mapping node index -> list of progenitor indices.

    Returns:
        None. The list is stored on ``data_obj.substructures`` and is empty
        when the tree contains no mergers.
    """
    candidates = _merger_progenitor_pairs(data_obj, predecessors_map)
    if not candidates:
        data_obj.substructures = []
        return

    mass_ratio_threshold = np.percentile(
        [log_ratio for _, _, _, log_ratio in candidates],
        config['mass_ratio_percentile_threshold'],
    )
    data_obj.substructures = [
        _build_substructure(data_obj, predecessors_map,
                            descendant_idx, p_main_idx, p_sub_idx, log_ratio)
        for descendant_idx, p_main_idx, p_sub_idx, log_ratio in candidates
        if log_ratio >= mass_ratio_threshold
    ]


def _merger_progenitor_pairs(data_obj, predecessors_map):
    """Return (descendant, main progenitor, secondary progenitor, log10 mass ratio) for every merger."""
    pairs = []
    for descendant_idx in range(data_obj.num_nodes):
        progenitors = predecessors_map.get(descendant_idx, [])
        if len(progenitors) < 2:
            continue
        # Stable sort, computed once per merger: equal-mass progenitors keep
        # their predecessors_map order, so ties pick a well-defined main progenitor.
        order = torch.sort(data_obj.x[progenitors, 0], descending=True, stable=True).indices.tolist()
        ranked = [progenitors[k] for k in order]
        p_main_idx = ranked[0]
        for p_sub_idx in ranked[1:]:
            log_ratio = (data_obj.x[p_sub_idx, 0] - data_obj.x[p_main_idx, 0]).item()
            pairs.append((descendant_idx, p_main_idx, p_sub_idx, log_ratio))
    return pairs


def _build_substructure(data_obj, predecessors_map, descendant_idx, p_main_idx, p_sub_idx, log_mass_ratio):
    """Build the Data object (branch graph + physical features) for one secondary progenitor."""
    x, x_norm = data_obj.x, data_obj.x_norm
    p_sub_scale_factor = x[p_sub_idx, 3].item()

    branch_nodes_list, branch_edge_index, branch_stats = \
        _get_branch_info(data_obj, p_sub_idx, predecessors_map)

    physical_features_list = [
        log_mass_ratio,                                         # log10 mass ratio p_sub / p_main
        x[descendant_idx, 3].item(),                            # merger scale factor
        (x_norm[p_sub_idx, 1] - x_norm[p_main_idx, 1]).item(),  # delta normalized log conc
        (x_norm[p_sub_idx, 2] - x_norm[p_main_idx, 2]).item(),  # delta normalized log Vmax
        x[p_sub_idx, 0].item(),                                 # p_sub log10 mass
        p_sub_scale_factor,                                     # p_sub scale factor
        branch_stats['mean_norm_log_conc_branch'],
        branch_stats['mean_norm_log_vmax_branch'],
        p_sub_scale_factor - branch_stats['min_sf_branch'],     # scale-factor span of branch
        float(branch_stats['num_halos_in_branch']),
    ]

    return Data(
        x=x_norm[torch.tensor(branch_nodes_list, dtype=torch.long), :],
        edge_index=branch_edge_index,
        physical_features=torch.tensor(physical_features_list, dtype=torch.float32),
        p_sub_idx_original=p_sub_idx,
        p_main_idx_original=p_main_idx,
        descendant_idx_original=descendant_idx,
        num_nodes=len(branch_nodes_list),
    )


def _load_full_dataset(path):
    """Load the merger-tree list stored at ``path`` onto the CPU; return None after printing an error."""
    try:
        # weights_only=False lets torch.load unpickle arbitrary objects. This is
        # acceptable only because the source file is trusted; never point it at
        # untrusted data.
        return torch.load(path, map_location=torch.device('cpu'), weights_only=False)
    except FileNotFoundError:
        print("Error: The file " + str(path) + " was not found.")
    except Exception as e:  # Script boundary: report any other loading error
        print("A general error occurred during dataset loading: " + str(e))
    return None


def _split_indices_by_simulation(full_dataset, seed):
    """
    Group trees by their flattened ``y`` parameters and split the groups 70/15/15.

    Returns (train_indices, val_indices, test_indices) as lists of tree indices,
    or None (after printing an error) when no tree carries ``y``.
    """
    sim_groups = defaultdict(list)
    for i, data_obj in enumerate(full_dataset):
        if getattr(data_obj, 'y', None) is None:
            print("Warning: Tree " + str(i) + " is missing 'y' attribute. Skipping for grouping.")
            continue
        sim_groups[tuple(data_obj.y.flatten().tolist())].append(i)

    unique_sim_params = list(sim_groups.keys())
    if not unique_sim_params:
        print("Error: No simulations found for splitting. Check 'y' attributes in data.")
        return None

    # Keep the global-RNG seed + in-place shuffle so existing splits reproduce exactly.
    np.random.seed(seed)
    np.random.shuffle(unique_sim_params)

    num_sims = len(unique_sim_params)
    train_end = int(0.7 * num_sims)
    val_end = train_end + int(0.15 * num_sims)

    def tree_indices(params_subset):
        return [idx for params in params_subset for idx in sim_groups[params]]

    return (tree_indices(unique_sim_params[:train_end]),
            tree_indices(unique_sim_params[train_end:val_end]),
            tree_indices(unique_sim_params[val_end:]))


def _compute_feature_normalization(full_dataset, train_indices):
    """Return per-feature (means, stds) over all training-tree nodes, or None if no ``x`` is found."""
    train_features = [full_dataset[i].x for i in train_indices
                      if getattr(full_dataset[i], 'x', None) is not None]
    if not train_features:
        print("Error: No node features 'x' found in the training set for normalization.")
        return None

    stacked = torch.cat(train_features, dim=0)
    feature_means = torch.mean(stacked, dim=0)
    feature_stds = torch.std(stacked, dim=0)
    # Constant features (std 0) and the single-node case (std NaN) would give
    # inf/NaN after division; leave such features unscaled instead.
    feature_stds[(feature_stds == 0) | torch.isnan(feature_stds)] = 1.0
    return feature_means, feature_stds


def _print_feature_statistics(feature_means, feature_stds):
    """Print the training-set mean/std of each node-feature column."""
    print("Node feature statistics (before normalization, from training set):")
    feature_names = ['log10(mass)', 'log10(concentration)', 'log10(Vmax)', 'scale_factor']
    for i, (mean, std) in enumerate(zip(feature_means.tolist(), feature_stds.tolist())):
        # Generic label if x carries more columns than known names.
        name = feature_names[i] if i < len(feature_names) else "feature_" + str(i)
        print("  " + name + ": Mean = " + str(mean) + ", Std = " + str(std))


def _normalize_node_features(full_dataset, feature_means, feature_stds):
    """Attach z-scored features as ``x_norm`` to every tree (an empty placeholder when ``x`` is missing)."""
    for data_obj in full_dataset:
        x = getattr(data_obj, 'x', None)
        if x is not None:
            data_obj.x_norm = (x - feature_means) / feature_stds
        else:
            # Placeholder only; trees without x are skipped during extraction.
            data_obj.x_norm = torch.empty((0, feature_means.shape[0]))


def _build_all_predecessor_maps(full_dataset):
    """
    Return one predecessor map (node -> list of progenitors) per tree.

    Trees with a missing/malformed ``edge_index``, no edges, or no ``x`` get an
    empty map.
    """
    predecessor_maps = []
    for i, data_obj in enumerate(full_dataset):
        pred_map = defaultdict(list)
        edge_index = getattr(data_obj, 'edge_index', None)
        if edge_index is None or edge_index.dim() != 2 or edge_index.shape[0] != 2:
            shape = edge_index.shape if edge_index is not None else "None"
            print("Warning: Tree " + str(i) + " has invalid or missing edge_index. Shape: " + str(shape) + ". Skipping map building.")
        elif edge_index.shape[1] > 0 and getattr(data_obj, 'x', None) is not None:
            # _build_adj_lists reads x[:, 3] for its direction check; trees
            # without x are skipped during extraction anyway.
            pred_map, _ = _build_adj_lists(data_obj)
        predecessor_maps.append(pred_map)
    return predecessor_maps


def _extract_all_substructures(full_dataset, predecessor_maps, config):
    """Run extract_substructures_for_tree on every usable tree; return per-tree substructure counts."""
    counts = []
    num_trees = len(full_dataset)
    for i, data_obj in enumerate(full_dataset):
        if i % 100 == 0:
            print("  Processing tree " + str(i) + "/" + str(num_trees))

        # Skip trees missing x / x_norm, or whose num_nodes is absent, None or 0.
        if (getattr(data_obj, 'x', None) is None
                or getattr(data_obj, 'x_norm', None) is None
                or not getattr(data_obj, 'num_nodes', 0)):
            print("  Skipping tree " + str(i) + " due to missing x, x_norm, or num_nodes=0.")
            data_obj.substructures = []
            counts.append(0)
            continue

        extract_substructures_for_tree(data_obj, config, predecessor_maps[i])
        counts.append(len(data_obj.substructures))
    return counts


def _report_substructure_stats(full_dataset, counts):
    """Print per-tree substructure counts and summary statistics of every physical feature."""
    total = sum(counts)
    print("Substructure extraction complete.")
    print("  Total substructures found: " + str(total))
    if len(full_dataset) > 0:
        print("  Average substructures per tree: " + "{:.2f}".format(total / len(full_dataset)))
    if counts:
        print("  Min substructures in a tree: " + str(min(counts)))
        print("  Max substructures in a tree: " + str(max(counts)))
        print("  Median substructures per tree: " + str(np.median(counts)))

    physical_features = [sub_data.physical_features
                         for data_obj in full_dataset
                         for sub_data in getattr(data_obj, 'substructures', [])]
    if not physical_features:
        print("No substructures found to compute feature statistics.")
        return

    stacked = torch.stack(physical_features)
    print("Statistics of substructure physical features (across all " + str(total) + " substructures):")
    feature_desc = [
        "log10_mass_ratio_ps_pm", "merger_scale_factor", "delta_norm_log_conc",
        "delta_norm_log_vmax", "p_sub_log_mass", "p_sub_scale_factor",
        "mean_norm_log_conc_branch", "mean_norm_log_vmax_branch",
        "sf_span_branch", "num_halos_in_branch"
    ]
    for name, feat_values in zip(feature_desc, stacked.T):
        print("  " + name + ":")
        print("    Mean: " + "{:.3f}".format(torch.mean(feat_values).item()) +
              ", Std: " + "{:.3f}".format(torch.std(feat_values).item()) +
              ", Min: " + "{:.3f}".format(torch.min(feat_values).item()) +
              ", Max: " + "{:.3f}".format(torch.max(feat_values).item()))


def main():
    """
    Run the preprocessing pipeline end to end.

    Steps:
      1. Load the trees from F_TREE.
      2. Split them into train/val/test by simulation parameters (``y``), so a
         simulation's trees never straddle splits.
      3. Z-score node features with training-set statistics.
      4. Extract merger substructures per tree and print summary statistics.
      5. Save the splits and normalization stats to PROCESSED_DATA_PATH.

    Prints an error and returns early when loading, splitting or normalization
    cannot proceed.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("Loading dataset...")
    full_dataset = _load_full_dataset(F_TREE)
    if full_dataset is None:
        return
    print("Dataset loaded. Number of trees: " + str(len(full_dataset)))

    splits = _split_indices_by_simulation(full_dataset, RANDOM_SEED)
    if splits is None:
        return
    train_indices, val_indices, test_indices = splits
    print("Data split (number of trees):")
    print("  Train: " + str(len(train_indices)))
    print("  Validation: " + str(len(val_indices)))
    print("  Test: " + str(len(test_indices)))

    print("Normalizing node features...")
    normalization = _compute_feature_normalization(full_dataset, train_indices)
    if normalization is None:
        return
    feature_means, feature_stds = normalization
    _print_feature_statistics(feature_means, feature_stds)
    _normalize_node_features(full_dataset, feature_means, feature_stds)

    if train_indices and full_dataset[train_indices[0]].x_norm.numel() > 0:
        sample_normalized_features = full_dataset[train_indices[0]].x_norm
        print("Normalized feature means (sample from train): " + str(torch.mean(sample_normalized_features, dim=0).tolist()))
        print("Normalized feature stds (sample from train): " + str(torch.std(sample_normalized_features, dim=0).tolist()))

    print("Extracting substructures...")
    print("Building adjacency lists for all trees...")
    predecessor_maps = _build_all_predecessor_maps(full_dataset)

    print("Processing trees for substructure extraction:")
    substructure_counts = _extract_all_substructures(full_dataset, predecessor_maps, SUBSTRUCTURE_CONFIG)
    _report_substructure_stats(full_dataset, substructure_counts)

    print("Saving processed data to: " + str(PROCESSED_DATA_PATH))
    torch.save({
        'train_data': [full_dataset[i] for i in train_indices],
        'val_data': [full_dataset[i] for i in val_indices],
        'test_data': [full_dataset[i] for i in test_indices],
        'feature_means': feature_means,
        'feature_stds': feature_stds,
        'substructure_config': SUBSTRUCTURE_CONFIG
    }, PROCESSED_DATA_PATH)
    print("Data saving complete.")


if __name__ == '__main__':
    # Script entry point: run the preprocessing pipeline only when executed
    # directly, not when this module is imported.
    main()