# gnn-meta-graph/src/dataset_simulated_setting2.py
import os
import numpy as np
import torch
from torch_geometric.data import Data
import GPy
import random
import warnings

class SyntheticMultimodalGraphDataset:
    """Simulated multimodal graph-classification dataset.

    For every modality, ``n`` graphs are generated. Each graph has
    ``node_size`` nodes with ``p`` features per node. The first ``use_node``
    nodes ("important" nodes) are sampled from a GP with a unit-variance RBF
    kernel and determine a per-modality binary label through a nonlinear rule;
    the remaining nodes are sampled from a GP whose RBF variance depends on the
    modality and do not enter the label. The per-modality labels are then
    merged into one shared label by a weighted vote, and every graph of every
    modality is tagged with that shared label.

    Args:
        n: Number of graphs per modality.
        node_size: Total number of nodes per graph.
        use_node: Number of label-relevant nodes (``1 <= use_node <= node_size``).
        p: Number of features per node.
        modalities: Number of modalities to generate.
        seed: Seed applied to the NumPy, ``random`` and torch global RNGs.
        modality_variances: Optional per-modality RBF variances for the
            non-important nodes; defaults to ``DEFAULT_MODALITY_VARIANCES``.
        weights: Optional per-modality vote weights for the shared label;
            defaults to ``DEFAULT_WEIGHTS``.
        threshold: The shared label is 1 when the weighted vote is at least
            this value.

    Raises:
        ValueError: If ``use_node`` is outside ``[1, node_size]`` or fewer
            variances/weights than ``modalities`` are available.

    Note:
        Constructing an instance has process-wide side effects: it reseeds the
        global RNGs, silences all warnings, and switches torch to
        deterministic algorithms.
    """

    DEFAULT_MODALITY_VARIANCES = (1.7, 1.2, 1.5, 1.1, 1.9)
    DEFAULT_WEIGHTS = (0.4, 0.5, 0.7, 0.45, 0.35)

    def __init__(self, n=1000, node_size=50, use_node=40, p=10, modalities=5, seed=1,
                 *, modality_variances=None, weights=None, threshold=0.3):
        if not 0 < use_node <= node_size:
            raise ValueError(
                f"use_node must satisfy 1 <= use_node <= node_size, "
                f"got use_node={use_node}, node_size={node_size}"
            )

        variances = list(
            self.DEFAULT_MODALITY_VARIANCES if modality_variances is None else modality_variances
        )
        vote_weights = list(self.DEFAULT_WEIGHTS if weights is None else weights)
        # Fail early instead of raising IndexError deep inside generation or
        # silently dropping modalities from the weighted vote.
        if modalities > len(variances) or modalities > len(vote_weights):
            raise ValueError(
                f"modalities={modalities} requires at least that many "
                f"modality_variances ({len(variances)}) and weights ({len(vote_weights)})"
            )

        self.n = n
        self.node_size = node_size
        self.use_node = use_node
        self.p = p
        self.modalities = modalities
        self.seed = seed
        self.modality_variances = variances
        self.weights = vote_weights
        self.threshold = threshold
        self.gnn_datasets = []

        self._set_seed()
        # NOTE: both calls below change global (process-wide) state.
        warnings.filterwarnings('ignore')
        torch.use_deterministic_algorithms(True)

        self.data_modalities = []
        self.labels_modalities = []

        self._generate_modalities()
        self._create_shared_labels()
        self._build_gnn_datasets()

    def _set_seed(self):
        """Seed the NumPy, ``random`` and torch global RNGs with ``self.seed``."""
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)

    def _sample_gp_block(self, num_nodes, variance, lengthscale):
        """Sample a ``(num_nodes, p)`` feature block from a GP regression model.

        Inputs are drawn uniformly from ``[0, 1)^p``; the model is fit to
        all-zero targets with an RBF kernel, and ``p`` posterior function
        draws at those same inputs become the node features.
        """
        kernel = GPy.kern.RBF(input_dim=self.p, variance=variance, lengthscale=lengthscale)
        inputs = np.random.uniform(0, 1, (num_nodes, self.p))
        model = GPy.models.GPRegression(inputs, np.zeros((num_nodes, 1)), kernel)
        samples = model.posterior_samples_f(inputs, full_cov=True, size=self.p)
        # Explicit reshape instead of squeeze(): squeeze would also drop the
        # node or feature axis when num_nodes == 1 or p == 1.
        # NOTE(review): assumes GPy returns num_nodes * p values here
        # (either (N, 1, p) or (N, p) depending on version) - confirm.
        return samples.reshape(num_nodes, self.p)

    def _label_from_important(self, x1):
        """Return the 0/1 label for one graph from its important-node block.

        ``x1`` has one row per important node and one column per feature. The
        rule is applied to the per-feature 0.4 quantile across nodes.
        """
        col_x1 = np.quantile(x1, 0.4, axis=0)
        quarter = self.p // 4
        # NOTE(review): the cosine term spans quarters 2-4 and overlaps the
        # cubic term; kept exactly as in the original rule.
        score = (
            np.sin(np.sum(col_x1[:quarter])) *
            np.cos(np.sum(col_x1[quarter:4 * quarter])) +
            0.1 * np.sum(col_x1[2 * quarter:4 * quarter] ** 3)
        )
        return int(score >= 0)

    def _generate_modalities(self):
        """Populate ``data_modalities`` and ``labels_modalities``.

        Appends, per modality, an array of shape ``(n, p, node_size)`` and a
        label array of shape ``(n,)``.
        """
        noise_nodes = self.node_size - self.use_node

        for mod in range(self.modalities):
            data_samples = []
            labels = []

            for _ in range(self.n):
                # Important nodes drive the label.
                x1 = self._sample_gp_block(self.use_node, variance=1.0, lengthscale=1.0)
                labels.append(self._label_from_important(x1))

                # Non-important nodes; skipped when every node is important,
                # since a GP cannot be fit on zero rows.
                if noise_nodes > 0:
                    x2 = self._sample_gp_block(
                        noise_nodes, variance=self.modality_variances[mod], lengthscale=0.5
                    )
                else:
                    x2 = np.empty((0, self.p))

                graph_x = np.transpose(np.concatenate((x1, x2), axis=0))  # (p, node_size)
                data_samples.append(graph_x)

            self.data_modalities.append(np.stack(data_samples))  # (n, p, node_size)
            self.labels_modalities.append(np.array(labels))      # (n,)

    def _create_shared_labels(self):
        """Set ``y_shared`` to 1 where the weighted label vote reaches ``threshold``."""
        weighted_vote = np.sum(
            [w * labels for w, labels in zip(self.weights, self.labels_modalities)], axis=0
        )
        self.y_shared = (weighted_vote >= self.threshold).astype(int)

    @staticmethod
    def _correlation_adjacency(node_features, quantile=0.5):
        """Build a binary node-by-node adjacency from feature correlations.

        Args:
            node_features: Array of shape ``(num_nodes, num_features)``.
            quantile: Quantile of the absolute correlations used as the edge
                threshold.

        Returns:
            Float array of shape ``(num_nodes, num_nodes)`` with entries in
            ``{0, 1}`` and ones on the diagonal (self-loops).
        """
        # Rows are the variables for np.corrcoef, so passing the features
        # un-transposed correlates nodes with each other. Transposing would
        # give a feature-by-feature matrix whose indices do not address nodes.
        with np.errstate(invalid='ignore', divide='ignore'):
            corr = np.corrcoef(node_features)
        # Constant feature rows give NaN; treat them as uncorrelated.
        abs_corr = np.abs(np.nan_to_num(corr))
        thres = np.quantile(abs_corr, quantile)
        # Entries exactly equal to 1 are dropped, as in the original rule;
        # the diagonal is restored right after.
        adj = ((abs_corr >= thres) & (abs_corr != 1)).astype(float)
        np.fill_diagonal(adj, 1)
        return adj

    def _build_gnn_datasets(self):
        """Convert each modality's arrays into a list of PyG ``Data`` graphs.

        Every graph gets node features of shape ``(node_size, p)``, edges from
        the correlation adjacency, and the shared label as ``y``.
        """
        for mod_data in self.data_modalities:
            dataset = []

            for idx in range(self.n):
                node_features = mod_data[idx].T  # (node_size, p)
                adj = self._correlation_adjacency(node_features)

                edge_index = torch.tensor(adj, dtype=torch.float).nonzero(as_tuple=False).t().contiguous()
                x_tensor = torch.tensor(node_features, dtype=torch.float)
                y_tensor = torch.tensor(self.y_shared[idx], dtype=torch.long)

                dataset.append(Data(x=x_tensor, edge_index=edge_index, y=y_tensor))

            self.gnn_datasets.append(dataset)

    def get_all_modalities(self):
        """Return the list of per-modality graph lists."""
        return self.gnn_datasets

    def get_dataset_by_index(self, index):
        """Return the graph list of the modality at ``index``."""
        return self.gnn_datasets[index]
