# gnn-meta-graph/src/dataset_simulated_setting1.py
import os
import numpy as np
import torch
from torch_geometric.data import Data

def generate_synthetic_data(modalities=5, num_samples=100, node_size=10, use_node=10, p=30, seed=123):
    """Generate per-modality synthetic node-feature matrices and binary labels.

    Each sample is a ``(node_size, p)`` matrix. Its first ``use_node`` rows are
    drawn from a zero-mean multivariate normal whose covariance has a unit
    diagonal and a constant off-diagonal value of ``0.1 + 0.1 * modality``.
    The remaining ``node_size - use_node`` rows are uniform noise in
    ``[0, 0.5)``. A sample's label is 1 when the mean of its informative rows
    is non-negative, else 0.

    The legacy global NumPy RNG is seeded with ``seed``. Any later code that
    draws from ``np.random`` continues this same stream.

    Args:
        modalities: Number of modalities to generate.
        num_samples: Number of samples per modality.
        node_size: Total number of rows (nodes) per sample.
        use_node: Number of informative rows per sample; must not exceed
            ``node_size``.
        p: Number of columns (features) per node.
        seed: Value passed to ``np.random.seed``.

    Returns:
        ``(data_modalities, labels_modalities)``: two lists of length
        ``modalities``. They hold arrays of shape
        ``(num_samples, node_size, p)`` and ``(num_samples,)`` respectively.

    Raises:
        ValueError: If ``use_node > node_size``.
    """
    if use_node > node_size:
        raise ValueError(
            f"use_node ({use_node}) must not exceed node_size ({node_size})"
        )
    np.random.seed(seed)
    mean = np.zeros(p)  # invariant across modalities and samples
    data_modalities = []
    labels_modalities = []

    for modality in range(modalities):
        # The off-diagonal value grows with the modality index. Once it goes
        # above 1 (modality >= 10) the matrix is no longer positive
        # semi-definite, and np.random.multivariate_normal only warns.
        noise_level = 0.1 + modality * 0.1
        cov = np.full((p, p), noise_level)
        np.fill_diagonal(cov, 1)

        samples = []
        labels = []
        for _ in range(num_samples):
            x1 = np.random.multivariate_normal(mean, cov, size=use_node)
            x2 = np.random.uniform(0, 0.5, size=(node_size - use_node, p))
            samples.append(np.vstack((x1, x2)))
            labels.append(int(np.mean(x1) >= 0))
        data_modalities.append(np.stack(samples))
        labels_modalities.append(np.array(labels))
    return data_modalities, labels_modalities

def create_edge_index_from_random(num_nodes, quantile=0.5):
    """Build a random symmetric graph with self-loops as a COO ``edge_index``.

    Draws a ``num_nodes x num_nodes`` matrix of uniform values from the global
    ``np.random`` state. The matrix is symmetrized by averaging it with its
    transpose, and its diagonal is set to 1. Every entry at or above the
    ``quantile``-th quantile of the resulting matrix becomes an edge. The
    diagonal holds the maximum value, so self-loops are always kept.

    Args:
        num_nodes: Number of nodes in the graph.
        quantile: Quantile in ``[0, 1]`` used as the edge threshold. Lower
            values give denser graphs; the default 0.5 (the median) matches
            the original behavior.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``. Row 0 holds the
        row indices and row 1 the column indices of the kept entries. Both
        directions of every undirected edge are present.
    """
    adj = np.random.rand(num_nodes, num_nodes)
    adj = (adj + adj.T) / 2  # symmetrize so (i, j) and (j, i) agree
    np.fill_diagonal(adj, 1)
    threshold = np.quantile(adj, quantile)
    # Build the mask once from the unmodified values. Zeroing entries in place
    # and then re-comparing the modified matrix only works while threshold > 0.
    mask = adj >= threshold
    return torch.tensor(np.array(mask.nonzero()), dtype=torch.long)

def prepare_synthetic_gnn_datasets(data_modalities, labels_modalities):
    """Wrap per-modality feature arrays and labels as lists of PyG ``Data`` graphs.

    For every sample:
    - a fresh random topology is drawn with ``create_edge_index_from_random``,
      which consumes the global ``np.random`` state;
    - the node features become a float tensor;
    - the label becomes a 0-d long tensor.

    Args:
        data_modalities: Iterable with one entry per modality. Each entry is a
            sized sequence of per-sample node-feature arrays, e.g. as returned
            by ``generate_synthetic_data``.
        labels_modalities: Matching iterable of per-sample integer labels.
            Pairing with ``data_modalities`` uses ``zip``, so any extra
            modalities on either side are ignored.

    Returns:
        A list with one entry per modality. Each entry is a list of
        ``Data(x=..., edge_index=..., y=...)`` objects in sample order.

    Raises:
        ValueError: If a modality has a different number of feature samples
            than labels.
    """
    datasets = []
    for modality, (mod_data, mod_labels) in enumerate(zip(data_modalities, labels_modalities)):
        # Previously extra labels were silently dropped, and missing labels
        # raised a bare IndexError; make the mismatch explicit instead.
        if len(mod_data) != len(mod_labels):
            raise ValueError(
                f"modality {modality}: {len(mod_data)} feature samples but "
                f"{len(mod_labels)} labels"
            )
        mod_dataset = []
        for features, label in zip(mod_data, mod_labels):
            x = torch.tensor(features, dtype=torch.float)
            edge_index = create_edge_index_from_random(x.size(0))
            y = torch.tensor(label, dtype=torch.long)
            mod_dataset.append(Data(x=x, edge_index=edge_index, y=y))
        datasets.append(mod_dataset)
    return datasets

def load_simulated_datasets(**generator_kwargs):
    """Generate the simulated multi-modal dataset and wrap it as PyG graphs.

    Keyword arguments are forwarded unchanged to ``generate_synthetic_data``
    (e.g. ``seed``, ``modalities``, ``num_samples``). When none are given,
    that function's defaults apply, exactly as before.

    Returns:
        A list with one entry per modality. Each entry is a list of PyG
        ``Data`` graphs, as built by ``prepare_synthetic_gnn_datasets``.
    """
    data_modalities, labels_modalities = generate_synthetic_data(**generator_kwargs)
    return prepare_synthetic_gnn_datasets(data_modalities, labels_modalities)
