# src/utils.py

import os
import random
from itertools import product

import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler

"""
utils.py

Utility functions for reproducibility (RNG seeding), train-fitted feature
normalization, model weight initialization, and hyperparameter-grid expansion
in the GNN meta-graph fusion pipeline.
"""


def set_seed(seed=1):
    """Seed all random number generators and force deterministic kernels.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's RNG.
    ``torch.manual_seed`` seeds every device. It then configures cuDNN and
    PyTorch to use deterministic algorithms only.

    Args:
        seed: Integer seed applied to every generator. Defaults to 1.

    Note:
        ``torch.use_deterministic_algorithms(True)`` makes PyTorch raise a
        RuntimeError for operations that have no deterministic
        implementation. On CUDA >= 10.2, cuBLAS additionally requires
        ``CUBLAS_WORKSPACE_CONFIG``. A default is provided below unless the
        caller already set one.
    """
    # Only effective if set before cuBLAS is first used in this process;
    # setdefault never overrides a value the user configured explicitly.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


def set_weights(model, seed=1):
    """Re-initialise every ``nn.Linear`` / ``nn.Conv2d`` layer of ``model`` in place.

    Weights get Xavier/Glorot-uniform initialisation, and biases (when
    present) are set to zero. The global torch RNG is seeded once with
    ``seed`` before the layers are visited. The result is therefore
    reproducible, while layers of identical shape still receive different
    weights.

    Args:
        model: The ``nn.Module`` whose sub-modules (recursively, via
            ``model.modules()``) are re-initialised.
        seed: Seed applied to the global torch RNG before initialisation.
            Defaults to 1.

    Note:
        Calling this resets the global torch RNG state as a side effect.
    """
    # Seed exactly once: re-seeding per layer would give every layer of the
    # same shape the exact same initial weights.
    torch.manual_seed(seed)
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)



def _standardize_splits(train_data, *other_splits):
    """Fit a StandardScaler on ``train_data`` and apply it to every given split.

    Each array is flattened to 2-D as ``(-1, shape[-1])``. Statistics are
    therefore computed per entry of the last axis, pooled over all leading
    axes. Each result is reshaped back to its input's original shape.

    Returns:
        tuple: the scaled ``train_data`` followed by each of ``other_splits``
        scaled with the same (train-fitted) statistics, in the given order.

    Raises:
        ValueError: (from scikit-learn) if a split's last dimension differs
            from the training data's.
    """
    scaler = StandardScaler()
    # Fit only on training data so test/val statistics never leak into scaling.
    scaled_train = scaler.fit_transform(
        train_data.reshape(-1, train_data.shape[-1])
    ).reshape(train_data.shape)
    scaled_others = tuple(
        scaler.transform(split.reshape(-1, split.shape[-1])).reshape(split.shape)
        for split in other_splits
    )
    return (scaled_train, *scaled_others)


def normalize_features(train_data, test_data, val_data):
    """Standardize train/test/val splits (e.g. the LFP dataset) using train statistics.

    A StandardScaler is fitted on ``train_data`` only (flattened over all but
    the last axis) and applied to all three splits. Shapes are preserved.

    Note:
        The outputs are the arrays produced by ``StandardScaler.transform``,
        so they are floating point. Integer inputs are promoted to float. Float
        inputs are expected to keep their precision, but confirm this for the
        installed scikit-learn version.

    Returns:
        tuple: ``(train_data, test_data, val_data)`` after scaling.
    """
    return _standardize_splits(train_data, test_data, val_data)


def normalize_features_eval(train_data, test_data):
    """Standardize train/test splits (e.g. the LFP dataset) using train statistics.

    This is the two-split variant of ``normalize_features`` for evaluation
    runs without a validation set. The scaler is fitted on ``train_data`` only
    and applied to both splits. Shapes are preserved.

    Note:
        The outputs are the arrays produced by ``StandardScaler.transform``,
        so they are floating point. Integer inputs are promoted to float.

    Returns:
        tuple: ``(train_data, test_data)`` after scaling.
    """
    return _standardize_splits(train_data, test_data)


def generate_hyperparam_combinations(hyperparameters):
    """Expand a hyperparameter grid into every combination (Cartesian product).

    Args:
        hyperparameters: Mapping of hyperparameter name to an iterable of
            candidate values, e.g. ``{"lr": [1e-3, 1e-4], "hidden": [32, 64]}``.

    Returns:
        list[dict]: One dict per combination, keyed by the same names. Order
        follows ``itertools.product``: the first key varies slowest and the
        last key varies fastest. An empty mapping yields ``[{}]``, a single
        configuration with no overrides. Any empty candidate list yields ``[]``.
    """
    names = list(hyperparameters)
    candidate_lists = [hyperparameters[name] for name in names]
    # product() with no arguments yields a single empty tuple, so an empty grid
    # naturally produces [{}] instead of failing on tuple unpacking.
    return [dict(zip(names, combo)) for combo in product(*candidate_lists)]




