# filename: codebase/step2_gnn_embedding.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import numpy as np
import os
from collections import defaultdict  # Required for loading data from Step 1 if not implicitly handled by torch.load

# ---------------------------------------------------------------------------
# Pipeline configuration
# ---------------------------------------------------------------------------

# Directory that receives every artefact written by this step.
OUTPUT_DIR = 'data'
# Input file (from Step 1).
PROCESSED_DATA_PATH_INPUT = 'data/processed_merger_trees.pt'
# Output files of this step: the encoder weights and the final tensors.
GNN_MODEL_PATH = 'data/gnn_encoder_model.pt'
FINAL_OUTPUT_PATH = 'data/final_processed_data.pt'

# GNN hyperparameters.
# Width of the (normalized) per-node feature vector.
NODE_FEATURE_DIM = 4
GNN_HIDDEN_DIM = 128
# Output width of the GraphSAGE encoder; also the pooled graph-embedding width.
GNN_EMBEDDING_DIM = 64
# Number of physical features per substructure (from Step 1).
NUM_PHYSICAL_FEATURES = 10

# Training hyperparameters for the GNN.
GNN_LEARNING_RATE = 1e-3
# Kept low for pipeline testing; increase for better embeddings.
GNN_EPOCHS = 5
GNN_BATCH_SIZE = 64

# Tensor construction: number of substructure rows per tree after
# padding/truncation.
MAX_N_SUB = 60

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)

# --- GNN Definition ---
class GraphSAGEEncoder(torch.nn.Module):
    """
    Two-layer GraphSAGE network mapping node features to node embeddings.

    The layers are ``SAGEConv(in_channels -> hidden_channels)`` followed by a
    ReLU and ``SAGEConv(hidden_channels -> out_channels)``. No activation is
    applied after the second convolution.
    """
    def __init__(self, in_channels, hidden_channels, out_channels):
        """
        Build the two convolution layers and the activation.
        Args:
            in_channels (int): Size of each input node feature vector.
            hidden_channels (int): Output size of the first convolution.
            out_channels (int): Output size of the second convolution.
        """
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.relu = torch.nn.ReLU()

    def forward(self, x, edge_index):
        """
        Encode node features into node embeddings.
        Args:
            x (Tensor): Node features.
            edge_index (Tensor): Graph connectivity.
        Returns:
            Tensor: Node embeddings produced by the second convolution.
        """
        hidden = self.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)


class GNNAutoencoder(torch.nn.Module):
    """
    Self-supervised wrapper used to train a GraphSAGEEncoder.

    A GraphSAGEEncoder produces node embeddings and a single linear layer
    maps them back to the input feature size, so the network can be trained
    to reconstruct its own node features.
    """
    def __init__(self, encoder_in_channels, encoder_hidden_channels, encoder_out_channels):
        """
        Build the encoder and the linear reconstruction head.
        Args:
            encoder_in_channels (int): Node feature size; also the decoder output size.
            encoder_hidden_channels (int): Hidden size of the encoder.
            encoder_out_channels (int): Node embedding size; also the decoder input size.
        """
        super().__init__()
        self.encoder = GraphSAGEEncoder(
            encoder_in_channels,
            encoder_hidden_channels,
            encoder_out_channels,
        )
        # Linear head: node embeddings -> reconstructed node features.
        self.decoder = torch.nn.Linear(encoder_out_channels, encoder_in_channels)

    def forward(self, x, edge_index):
        """
        Encode the node features and decode them back.
        Args:
            x (Tensor): Node features.
            edge_index (Tensor): Graph connectivity.
        Returns:
            Tensor: Reconstructed node features.
        """
        return self.decoder(self.encoder(x, edge_index))


# --- GNN Training Function ---

def train_gnn(train_substructures, num_node_features, gnn_hidden_dim, gnn_embedding_dim, epochs, lr, batch_size, device):
    """
    Trains the GNNAutoencoder on node-feature reconstruction (MSE loss).
    Args:
        train_substructures (list): List of Data objects (substructures) for training.
        num_node_features (int): Dimensionality of input node features.
        gnn_hidden_dim (int): Hidden dimension for GNN.
        gnn_embedding_dim (int): Output dimension of GNN encoder (node embeddings).
        epochs (int): Number of training epochs.
        lr (float): Learning rate.
        batch_size (int): Batch size for training.
        device (torch.device): Device to train on.
    Returns:
        GraphSAGEEncoder: The trained encoder part of the GNN. If no
        substructure with at least one node is supplied, a freshly
        initialized (untrained) encoder is returned instead.
    """
    print("Starting GNN training...")

    # Filter out substructures with no nodes, if any (should not happen with Step 1 logic).
    # This is done before any model is allocated so the no-data path does not
    # build an autoencoder and optimizer only to discard them.
    train_substructures_filtered = [s for s in train_substructures if s.num_nodes > 0]
    if not train_substructures_filtered:
        print("Warning: No valid substructures found for GNN training.")
        # Return an initialized encoder if no training data
        return GraphSAGEEncoder(num_node_features, gnn_hidden_dim, gnn_embedding_dim).to(device)

    model = GNNAutoencoder(num_node_features, gnn_hidden_dim, gnn_embedding_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    dataloader = DataLoader(train_substructures_filtered, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            batch = batch.to(device)
            # SAGEConv expects float inputs; cast once and reuse the tensor as
            # both the model input and the reconstruction target.
            node_features = batch.x.float()
            optimizer.zero_grad()
            reconstructed_x = model(node_features, batch.edge_index)
            loss = criterion(reconstructed_x, node_features)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Mean of the per-batch losses (not weighted by the number of nodes per batch).
        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print("GNN Training Epoch " + str(epoch + 1) + "/" + str(epochs) + ", Average Loss: " + "%.6f" % avg_loss)

    print("GNN training complete.")
    # Only the encoder is needed downstream; the linear decoder is dropped.
    return model.encoder



# --- Main Processing Function ---

def generate_embeddings_and_combine_features(datasets, trained_gnn_encoder, device):
    """
    Generates topological embeddings for substructures, combines them with physical features,
    and pads/truncates to create fixed-size tensors per tree.

    For every substructure the encoder's node embeddings are mean-pooled into
    one graph embedding, which is appended to the substructure's
    ``physical_features``. Each tree keeps at most MAX_N_SUB such vectors (the
    first ones with at least one node, in order) and is padded up to
    MAX_N_SUB rows with a "null" vector: zero physical features followed by
    the encoder's embedding of a single all-zero node with no edges.

    Args:
        datasets (dict): Maps a split name (e.g. 'train', 'val', 'test') to a
                         list of main tree Data objects. Each tree provides
                         ``y`` and, optionally, a ``substructures`` list.
        trained_gnn_encoder (GraphSAGEEncoder): The trained GNN encoder.
        device (torch.device): Device for inference.
    Returns:
        dict: For each split name ``s``, the keys ``s + "_tensors"`` (list of
              CPU tensors, one per tree) and ``s + "_labels"`` (list of CPU
              label tensors, one per tree).
    """
    print("Generating embeddings and combining features...")
    trained_gnn_encoder.eval()  # Set encoder to evaluation mode

    # Create null substructure for padding
    # Normalized features have mean ~0, so zeros are appropriate for a null node's features.
    null_substructure_x = torch.zeros(1, NODE_FEATURE_DIM, device=device, dtype=torch.float)
    null_substructure_edge_index = torch.empty((2, 0), dtype=torch.long, device=device)

    with torch.no_grad():
        null_node_embedding = trained_gnn_encoder(null_substructure_x, null_substructure_edge_index)
        # Global mean pool for a single node is just the node's embedding
        null_topological_embedding = global_mean_pool(null_node_embedding, torch.zeros(1, dtype=torch.long, device=device))

    null_physical_features = torch.zeros(NUM_PHYSICAL_FEATURES, device=device, dtype=torch.float)
    padding_vector = torch.cat([null_physical_features, null_topological_embedding.squeeze(0)], dim=0)

    D_feat_combined = NUM_PHYSICAL_FEATURES + GNN_EMBEDDING_DIM
    print("Combined feature dimension per substructure (D_feat_combined): " + str(D_feat_combined))
    print("Padding vector shape: " + str(padding_vector.shape))

    processed_output = {}
    for split_name, data_list in datasets.items():
        print("Processing " + str(split_name) + "...")
        all_tree_tensors = []
        all_tree_labels = []

        for i, tree_data_obj in enumerate(data_list):
            if (i + 1) % 100 == 0:
                print("  Processing " + str(split_name) + " tree " + str(i + 1) + "/" + str(len(data_list)))

            substructure_feature_vectors = []
            # Trees without a 'substructures' attribute end up fully padded.
            for sub_data in getattr(tree_data_obj, 'substructures', ()):
                # Anything past the first MAX_N_SUB vectors would be truncated
                # away below, so stop before spending encoder passes on it.
                if len(substructure_feature_vectors) >= MAX_N_SUB:
                    break
                if sub_data.num_nodes == 0:  # Should not happen
                    continue

                # Ensure sub_data tensors are on the correct device and dtype
                sub_x = sub_data.x.to(device).float()
                sub_edge_index = sub_data.edge_index.to(device)

                with torch.no_grad():
                    node_embeddings = trained_gnn_encoder(sub_x, sub_edge_index)
                    # Create batch vector for pooling: all nodes belong to graph 0
                    batch_vector = torch.zeros(sub_data.num_nodes, dtype=torch.long, device=device)
                    graph_embedding = global_mean_pool(node_embeddings, batch_vector)

                physical_features = sub_data.physical_features.to(device).float()
                combined_features = torch.cat([physical_features, graph_embedding.squeeze(0)], dim=0)
                substructure_feature_vectors.append(combined_features)

            # Pad up to MAX_N_SUB rows (truncation already happened in the loop above).
            num_padding_rows = MAX_N_SUB - len(substructure_feature_vectors)
            current_tree_tensor_list = substructure_feature_vectors + [padding_vector] * num_padding_rows

            # Stack to form the (MAX_N_SUB, D_feat_combined) tensor for the tree
            if current_tree_tensor_list:
                current_tree_tensor = torch.stack(current_tree_tensor_list)
            else:
                # Only reachable when MAX_N_SUB is 0: torch.stack rejects an
                # empty list, so build the zero-row tensor explicitly.
                current_tree_tensor = torch.zeros((0, D_feat_combined), device=device, dtype=torch.float)

            all_tree_tensors.append(current_tree_tensor.cpu())  # Store on CPU
            all_tree_labels.append(tree_data_obj.y.cpu())       # Store on CPU

        processed_output[split_name + "_tensors"] = all_tree_tensors
        processed_output[split_name + "_labels"] = all_tree_labels
        print("  Finished processing " + str(split_name) + ". Number of trees: " + str(len(all_tree_tensors)))
        if all_tree_tensors:
            print("  Shape of tensor for one tree: " + str(all_tree_tensors[0].shape))

    return processed_output


def main():
    """
    Main function to run Step 2 processing.

    Loads the Step 1 output, trains the GraphSAGE encoder on every
    substructure of the training trees, saves the encoder weights, builds the
    fixed-size per-tree tensors for all three splits, and saves them together
    with normalization statistics and the GNN configuration.

    Returns early (after printing an error) if the Step 1 file cannot be loaded.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("Loading processed data from Step 1: " + str(PROCESSED_DATA_PATH_INPUT))
    try:
        # The file holds pickled PyG Data objects, not just tensors. Recent
        # PyTorch releases default to weights_only=True, which refuses to
        # unpickle such objects, so full unpickling is requested explicitly.
        # SECURITY: this runs the pickle machinery; only load files produced
        # by this pipeline's own Step 1, never files from an untrusted source.
        loaded_data = torch.load(PROCESSED_DATA_PATH_INPUT, map_location='cpu', weights_only=False)  # Load to CPU first
    except FileNotFoundError:
        print("Error: Processed data file not found: " + str(PROCESSED_DATA_PATH_INPUT))
        return
    except Exception as e:
        # Top-level boundary: report any other load failure and stop.
        print("Error loading data: " + str(e))
        return

    train_data_processed = loaded_data['train_data']
    val_data_processed = loaded_data['val_data']
    test_data_processed = loaded_data['test_data']
    feature_means = loaded_data['feature_means']
    feature_stds = loaded_data['feature_stds']

    # Collect all substructures from the training set for GNN training
    # (validation and test trees are never used to fit the encoder).
    all_train_substructures = []
    for tree_obj in train_data_processed:
        if hasattr(tree_obj, 'substructures'):
            all_train_substructures.extend(tree_obj.substructures)

    print("Total number of substructures in training set for GNN: " + str(len(all_train_substructures)))

    # Train GNN
    # Note: For actual research, GNN_EPOCHS should be higher.
    trained_encoder = train_gnn(
        all_train_substructures,
        NODE_FEATURE_DIM,
        GNN_HIDDEN_DIM,
        GNN_EMBEDDING_DIM,
        GNN_EPOCHS,
        GNN_LEARNING_RATE,
        GNN_BATCH_SIZE,
        DEVICE
    )

    # Save trained GNN encoder model (weights only, via its state_dict)
    torch.save(trained_encoder.state_dict(), GNN_MODEL_PATH)
    print("Trained GNN encoder model saved to: " + str(GNN_MODEL_PATH))

    num_params = sum(p.numel() for p in trained_encoder.parameters() if p.requires_grad)
    print("Number of trainable parameters in GNN encoder: " + str(num_params))

    # Prepare datasets dictionary for processing; the keys become the
    # prefixes of the "<split>_tensors" / "<split>_labels" output entries.
    datasets_for_embedding_gen = {
        'train': train_data_processed,
        'val': val_data_processed,
        'test': test_data_processed
    }

    # Generate embeddings and combine features
    final_data = generate_embeddings_and_combine_features(
        datasets_for_embedding_gen,
        trained_encoder,  # Pass the model itself, not state_dict
        DEVICE
    )

    # Add metadata to the final output
    final_data['feature_means_original_nodes'] = feature_means
    final_data['feature_stds_original_nodes'] = feature_stds
    final_data['max_N_sub'] = MAX_N_SUB
    final_data['D_feat_combined'] = NUM_PHYSICAL_FEATURES + GNN_EMBEDDING_DIM
    final_data['gnn_config'] = {
        'hidden_dim': GNN_HIDDEN_DIM,
        'embedding_dim': GNN_EMBEDDING_DIM,
        'epochs': GNN_EPOCHS,
        'lr': GNN_LEARNING_RATE,
        'batch_size': GNN_BATCH_SIZE
    }

    # Save the final processed data
    torch.save(final_data, FINAL_OUTPUT_PATH)
    print("Final processed data with embeddings saved to: " + str(FINAL_OUTPUT_PATH))
    print("Summary of saved data contents:")
    for key, value in final_data.items():
        if isinstance(value, list):
            print("  " + str(key) + ": list of " + str(len(value)) + " items")
            if value and isinstance(value[0], torch.Tensor):
                print("    (example item shape: " + str(value[0].shape) + ")")
        elif isinstance(value, torch.Tensor):
            print("  " + str(key) + ": tensor of shape " + str(value.shape))
        else:
            print("  " + str(key) + ": " + str(value))


# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()