#Example code of SP

import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d as BN
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import GINConv, global_mean_pool, GCNConv
from torch_geometric.data import Data, Batch
from torch_scatter import scatter_mean


class SPGCN(torch.nn.Module):
    """GCN graph model with a global branch and a subgraph branch.

    Both branches share the ``GCNConv`` stack in ``forward_shared``. The
    global branch mean-pools the node embeddings of each input graph; the
    subgraph branch mean-pools the node embeddings of each subgraph listed in
    ``data.subgraphs`` and averages those. The two embeddings are
    concatenated and mapped by ``final_lin`` to one output value per graph.

    Args:
        dataset: object exposing ``num_features`` (node feature size).
        num_layers: total number of GCN layers (``conv1`` plus
            ``num_layers - 1`` further layers).
        hidden: width of every hidden embedding.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList([
            GCNConv(hidden, hidden) for _ in range(num_layers - 1)
        ])
        self.lin = Linear(hidden, hidden)  # projects the pooled global embedding
        self.sub_lin = Linear(hidden, hidden)  # projects the aggregated subgraph embedding
        self.final_lin = Linear(2 * hidden, 1)  # [global || subgraph] -> 1 output

    def reset_parameters(self):
        """Re-initialise every learnable layer."""
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin.reset_parameters()
        self.sub_lin.reset_parameters()
        self.final_lin.reset_parameters()

    def forward_shared(self, x, edge_index):
        """Apply the shared GCN stack, with a ReLU after every layer."""
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        return x

    @staticmethod
    def _node_offsets(data, batch_size):
        """Return, for each graph, the row of its first node in ``data.x``.

        Uses ``data.ptr`` when present, otherwise derives the offsets from
        ``data.batch``; without either, every graph starts at row 0.
        """
        ptr = getattr(data, 'ptr', None)
        if ptr is not None:
            return ptr[:batch_size].tolist()
        batch = getattr(data, 'batch', None)
        if batch is None:
            return [0] * batch_size
        counts = torch.bincount(batch, minlength=batch_size)
        # Exclusive prefix sum: first node index of each graph.
        return (torch.cumsum(counts, dim=0) - counts).tolist()

    def _encode_subgraphs(self, sub_data_list):
        """Embed each subgraph as the mean of its shared-GNN node embeddings."""
        batch_sub = Batch.from_data_list(sub_data_list)
        sub_x = self.forward_shared(batch_sub.x, batch_sub.edge_index)
        return global_mean_pool(sub_x, batch_sub.batch)

    def _aggregate_subgraphs(self, data, subgraphs, batch_size, ref):
        """Pool the subgraph embeddings into one row per graph in the batch.

        ``subgraphs`` is either:
          * per-sample: one list of subgraph dicts per graph in the batch.
            Graph ``i`` gets the mean of its own subgraph embeddings, or a
            zero row if it has none; or
          * flat: a list of subgraph dicts whose overall mean embedding is
            shared by every graph in the batch.
        The layout is decided by element type rather than by length, so a
        flat list that happens to hold ``batch_size`` dicts is not misread as
        per-sample data.

        Each dict holds ``'subset'`` (indices of the subgraph's nodes) and
        ``'edge_index'`` (the subgraph's edges). In the per-sample layout
        ``'subset'`` is assumed to index nodes of its own graph (NOTE(review):
        confirm against the preprocessing), so the graph's offset into the
        batched ``data.x`` is added before gathering features.

        Args:
            data: batch providing ``x`` and optionally ``ptr`` / ``batch``.
            subgraphs: the non-empty ``data.subgraphs`` value.
            batch_size: number of graphs in ``data``.
            ref: tensor whose shape, dtype and device an all-zero result copies.

        Raises:
            ValueError: per-sample lists whose count differs from ``batch_size``.
        """
        per_sample = all(isinstance(subs, (list, tuple)) for subs in subgraphs)

        if not per_sample:
            sub_data_list = [
                Data(x=data.x[sub['subset']], edge_index=sub['edge_index'])
                for sub in subgraphs
            ]
            sub_means = self._encode_subgraphs(sub_data_list)
            return sub_means.mean(dim=0, keepdim=True).repeat(batch_size, 1)

        if len(subgraphs) != batch_size:
            raise ValueError(
                f'expected {batch_size} per-sample subgraph lists, '
                f'got {len(subgraphs)}')

        offsets = self._node_offsets(data, batch_size)
        sub_data_list = []
        sample_idx = []
        for i, (subs, offset) in enumerate(zip(subgraphs, offsets)):
            for sub in subs:
                subset = sub['subset']
                if offset:
                    # Shift graph-local node indices into the batched x.
                    subset = torch.as_tensor(subset, device=data.x.device) + offset
                sub_data_list.append(
                    Data(x=data.x[subset], edge_index=sub['edge_index']))
                sample_idx.append(i)

        if not sub_data_list:
            return torch.zeros_like(ref)

        sub_means = self._encode_subgraphs(sub_data_list)
        index = torch.tensor(sample_idx, device=ref.device)
        # Graphs without subgraphs receive a zero row via dim_size.
        return scatter_mean(sub_means, index, dim=0, dim_size=batch_size)

    def forward(self, data):
        """Return a ``(batch_size, 1)`` output for the graph(s) in ``data``.

        ``data`` must provide ``x``, ``edge_index`` and ``batch``. When
        ``data.subgraphs`` is missing or empty, the subgraph half of the
        combined embedding is all zeros.
        """
        # Global branch.
        x_global = self.forward_shared(data.x, data.edge_index)
        global_emb = global_mean_pool(x_global, data.batch)
        global_emb = F.relu(self.lin(global_emb))
        global_emb = F.dropout(global_emb, p=0.5, training=self.training)
        batch_size = global_emb.size(0)

        # Subgraph branch.
        subgraphs = getattr(data, 'subgraphs', None)
        if subgraphs:
            aggregated_sub = self._aggregate_subgraphs(
                data, subgraphs, batch_size, global_emb)
            aggregated_sub = F.relu(self.sub_lin(aggregated_sub))
            aggregated_sub = F.dropout(aggregated_sub, p=0.5, training=self.training)
        else:
            aggregated_sub = torch.zeros_like(global_emb)

        # Final combination.
        combined = torch.cat([global_emb, aggregated_sub], dim=1)
        return self.final_lin(combined)


class SPGIN(torch.nn.Module):
    """GIN graph model with a global branch and a subgraph branch.

    Every ``GINConv`` wraps a two-stage MLP (Linear -> ReLU -> BatchNorm,
    twice) and learns its epsilon (``train_eps=True``). Both branches share
    this stack via ``forward_shared``. The global branch mean-pools the node
    embeddings of each input graph; the subgraph branch mean-pools each
    subgraph in ``data.subgraphs`` and averages those. The two embeddings are
    concatenated and mapped by ``final_lin`` to one output value per graph.

    Args:
        dataset: object exposing ``num_features`` (node feature size).
        num_layers: total number of GIN layers (``conv1`` plus
            ``num_layers - 1`` further layers).
        hidden: width of every hidden embedding.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = GINConv(
            Sequential(
                Linear(dataset.num_features, hidden),
                ReLU(),
                BN(hidden),
                Linear(hidden, hidden),
                ReLU(),
                BN(hidden),
            ), train_eps=True)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(
                GINConv(
                    Sequential(
                        Linear(hidden, hidden),
                        ReLU(),
                        BN(hidden),
                        Linear(hidden, hidden),
                        ReLU(),
                        BN(hidden),
                    ), train_eps=True))
        self.lin = Linear(hidden, hidden)  # projects the pooled global embedding
        self.sub_lin = Linear(hidden, hidden)  # projects the aggregated subgraph embedding
        self.final_lin = Linear(2 * hidden, 1)  # [global || subgraph] -> 1 output

    def reset_parameters(self):
        """Re-initialise every learnable layer."""
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin.reset_parameters()
        self.sub_lin.reset_parameters()
        self.final_lin.reset_parameters()

    def forward_shared(self, x, edge_index):
        """Apply the shared GIN stack, with a ReLU after every layer."""
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        return x

    @staticmethod
    def _node_offsets(data, batch_size):
        """Return, for each graph, the row of its first node in ``data.x``.

        Uses ``data.ptr`` when present, otherwise derives the offsets from
        ``data.batch``; without either, every graph starts at row 0.
        """
        ptr = getattr(data, 'ptr', None)
        if ptr is not None:
            return ptr[:batch_size].tolist()
        batch = getattr(data, 'batch', None)
        if batch is None:
            return [0] * batch_size
        counts = torch.bincount(batch, minlength=batch_size)
        # Exclusive prefix sum: first node index of each graph.
        return (torch.cumsum(counts, dim=0) - counts).tolist()

    def _encode_subgraphs(self, sub_data_list):
        """Embed each subgraph as the mean of its shared-GNN node embeddings."""
        batch_sub = Batch.from_data_list(sub_data_list)
        sub_x = self.forward_shared(batch_sub.x, batch_sub.edge_index)
        return global_mean_pool(sub_x, batch_sub.batch)

    def _aggregate_subgraphs(self, data, subgraphs, batch_size, ref):
        """Pool the subgraph embeddings into one row per graph in the batch.

        ``subgraphs`` is either:
          * per-sample: one list of subgraph dicts per graph in the batch.
            Graph ``i`` gets the mean of its own subgraph embeddings, or a
            zero row if it has none; or
          * flat: a list of subgraph dicts whose overall mean embedding is
            shared by every graph in the batch.
        The layout is decided by element type rather than by length, so a
        flat list that happens to hold ``batch_size`` dicts is not misread as
        per-sample data.

        Each dict holds ``'subset'`` (indices of the subgraph's nodes) and
        ``'edge_index'`` (the subgraph's edges). In the per-sample layout
        ``'subset'`` is assumed to index nodes of its own graph (NOTE(review):
        confirm against the preprocessing), so the graph's offset into the
        batched ``data.x`` is added before gathering features.

        Args:
            data: batch providing ``x`` and optionally ``ptr`` / ``batch``.
            subgraphs: the non-empty ``data.subgraphs`` value.
            batch_size: number of graphs in ``data``.
            ref: tensor whose shape, dtype and device an all-zero result copies.

        Raises:
            ValueError: per-sample lists whose count differs from ``batch_size``.
        """
        per_sample = all(isinstance(subs, (list, tuple)) for subs in subgraphs)

        if not per_sample:
            sub_data_list = [
                Data(x=data.x[sub['subset']], edge_index=sub['edge_index'])
                for sub in subgraphs
            ]
            sub_means = self._encode_subgraphs(sub_data_list)
            return sub_means.mean(dim=0, keepdim=True).repeat(batch_size, 1)

        if len(subgraphs) != batch_size:
            raise ValueError(
                f'expected {batch_size} per-sample subgraph lists, '
                f'got {len(subgraphs)}')

        offsets = self._node_offsets(data, batch_size)
        sub_data_list = []
        sample_idx = []
        for i, (subs, offset) in enumerate(zip(subgraphs, offsets)):
            for sub in subs:
                subset = sub['subset']
                if offset:
                    # Shift graph-local node indices into the batched x.
                    subset = torch.as_tensor(subset, device=data.x.device) + offset
                sub_data_list.append(
                    Data(x=data.x[subset], edge_index=sub['edge_index']))
                sample_idx.append(i)

        if not sub_data_list:
            return torch.zeros_like(ref)

        sub_means = self._encode_subgraphs(sub_data_list)
        index = torch.tensor(sample_idx, device=ref.device)
        # Graphs without subgraphs receive a zero row via dim_size.
        return scatter_mean(sub_means, index, dim=0, dim_size=batch_size)

    def forward(self, data):
        """Return a ``(batch_size, 1)`` output for the graph(s) in ``data``.

        ``data`` must provide ``x``, ``edge_index`` and ``batch``. When
        ``data.subgraphs`` is missing or empty, the subgraph half of the
        combined embedding is all zeros.
        """
        # Global branch.
        x_global = self.forward_shared(data.x, data.edge_index)
        global_emb = global_mean_pool(x_global, data.batch)
        global_emb = F.relu(self.lin(global_emb))
        global_emb = F.dropout(global_emb, p=0.5, training=self.training)
        batch_size = global_emb.size(0)

        # Subgraph branch.
        subgraphs = getattr(data, 'subgraphs', None)
        if subgraphs:
            aggregated_sub = self._aggregate_subgraphs(
                data, subgraphs, batch_size, global_emb)
            aggregated_sub = F.relu(self.sub_lin(aggregated_sub))
            aggregated_sub = F.dropout(aggregated_sub, p=0.5, training=self.training)
        else:
            aggregated_sub = torch.zeros_like(global_emb)

        # Final combination.
        combined = torch.cat([global_emb, aggregated_sub], dim=1)
        return self.final_lin(combined)


class subGIN_TU(torch.nn.Module):
    """GIN graph classifier with a global branch and a subgraph branch.

    Every ``GINConv`` wraps a two-stage MLP (Linear -> ReLU -> BatchNorm,
    twice) and learns its epsilon (``train_eps=True``). Both branches share
    this stack via ``forward_shared``. The global branch mean-pools the node
    embeddings of each input graph; the subgraph branch mean-pools each
    subgraph in ``data.subgraphs`` and averages those. The concatenated
    embedding is mapped to log-probabilities over ``dataset.num_classes``.

    Args:
        dataset: object exposing ``num_features`` and ``num_classes``.
        num_layers: total number of GIN layers (``conv1`` plus
            ``num_layers - 1`` further layers).
        hidden: width of every hidden embedding.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = GINConv(
            Sequential(
                Linear(dataset.num_features, hidden),
                ReLU(),
                BN(hidden),
                Linear(hidden, hidden),
                ReLU(),
                BN(hidden),
            ), train_eps=True)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(
                GINConv(
                    Sequential(
                        Linear(hidden, hidden),
                        ReLU(),
                        BN(hidden),
                        Linear(hidden, hidden),
                        ReLU(),
                        BN(hidden),
                    ), train_eps=True))
        self.lin = Linear(hidden, hidden)  # projects the pooled global embedding
        self.sub_lin = Linear(hidden, hidden)  # projects the aggregated subgraph embedding
        self.final_lin = Linear(2 * hidden, dataset.num_classes)  # class logits

    def reset_parameters(self):
        """Re-initialise every learnable layer."""
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin.reset_parameters()
        self.sub_lin.reset_parameters()
        self.final_lin.reset_parameters()

    def forward_shared(self, x, edge_index):
        """Apply the shared GIN stack, with a ReLU after every layer."""
        x = F.relu(self.conv1(x, edge_index))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        return x

    @staticmethod
    def _node_offsets(data, batch_size):
        """Return, for each graph, the row of its first node in ``data.x``.

        Uses ``data.ptr`` when present, otherwise derives the offsets from
        ``data.batch``; without either, every graph starts at row 0.
        """
        ptr = getattr(data, 'ptr', None)
        if ptr is not None:
            return ptr[:batch_size].tolist()
        batch = getattr(data, 'batch', None)
        if batch is None:
            return [0] * batch_size
        counts = torch.bincount(batch, minlength=batch_size)
        # Exclusive prefix sum: first node index of each graph.
        return (torch.cumsum(counts, dim=0) - counts).tolist()

    def _encode_subgraphs(self, sub_data_list):
        """Embed each subgraph as the mean of its shared-GNN node embeddings."""
        batch_sub = Batch.from_data_list(sub_data_list)
        sub_x = self.forward_shared(batch_sub.x, batch_sub.edge_index)
        return global_mean_pool(sub_x, batch_sub.batch)

    def _aggregate_subgraphs(self, data, subgraphs, batch_size, ref):
        """Pool the subgraph embeddings into one row per graph in the batch.

        ``subgraphs`` is either:
          * per-sample: one list of subgraph dicts per graph in the batch.
            Graph ``i`` gets the mean of its own subgraph embeddings, or a
            zero row if it has none; or
          * flat: a list of subgraph dicts whose overall mean embedding is
            shared by every graph in the batch.
        The layout is decided by element type rather than by length, so a
        flat list that happens to hold ``batch_size`` dicts is not misread as
        per-sample data.

        Each dict holds ``'subset'`` (indices of the subgraph's nodes) and
        ``'edge_index'`` (the subgraph's edges). In the per-sample layout
        ``'subset'`` is assumed to index nodes of its own graph (NOTE(review):
        confirm against the preprocessing), so the graph's offset into the
        batched ``data.x`` is added before gathering features.

        Args:
            data: batch providing ``x`` and optionally ``ptr`` / ``batch``.
            subgraphs: the non-empty ``data.subgraphs`` value.
            batch_size: number of graphs in ``data``.
            ref: tensor whose shape, dtype and device an all-zero result copies.

        Raises:
            ValueError: per-sample lists whose count differs from ``batch_size``.
        """
        per_sample = all(isinstance(subs, (list, tuple)) for subs in subgraphs)

        if not per_sample:
            sub_data_list = [
                Data(x=data.x[sub['subset']], edge_index=sub['edge_index'])
                for sub in subgraphs
            ]
            sub_means = self._encode_subgraphs(sub_data_list)
            return sub_means.mean(dim=0, keepdim=True).repeat(batch_size, 1)

        if len(subgraphs) != batch_size:
            raise ValueError(
                f'expected {batch_size} per-sample subgraph lists, '
                f'got {len(subgraphs)}')

        offsets = self._node_offsets(data, batch_size)
        sub_data_list = []
        sample_idx = []
        for i, (subs, offset) in enumerate(zip(subgraphs, offsets)):
            for sub in subs:
                subset = sub['subset']
                if offset:
                    # Shift graph-local node indices into the batched x.
                    subset = torch.as_tensor(subset, device=data.x.device) + offset
                sub_data_list.append(
                    Data(x=data.x[subset], edge_index=sub['edge_index']))
                sample_idx.append(i)

        if not sub_data_list:
            return torch.zeros_like(ref)

        sub_means = self._encode_subgraphs(sub_data_list)
        index = torch.tensor(sample_idx, device=ref.device)
        # Graphs without subgraphs receive a zero row via dim_size.
        return scatter_mean(sub_means, index, dim=0, dim_size=batch_size)

    def forward(self, data):
        """Return ``(batch_size, num_classes)`` log-probabilities for ``data``.

        ``data`` must provide ``x``, ``edge_index`` and ``batch``. When
        ``data.subgraphs`` is missing or empty, the subgraph half of the
        combined embedding is all zeros.
        """
        # Global branch.
        x_global = self.forward_shared(data.x, data.edge_index)
        global_emb = global_mean_pool(x_global, data.batch)
        global_emb = F.relu(self.lin(global_emb))
        global_emb = F.dropout(global_emb, p=0.5, training=self.training)
        batch_size = global_emb.size(0)

        # Subgraph branch.
        subgraphs = getattr(data, 'subgraphs', None)
        if subgraphs:
            aggregated_sub = self._aggregate_subgraphs(
                data, subgraphs, batch_size, global_emb)
            aggregated_sub = F.relu(self.sub_lin(aggregated_sub))
            aggregated_sub = F.dropout(aggregated_sub, p=0.5, training=self.training)
        else:
            aggregated_sub = torch.zeros_like(global_emb)

        # Final combination -> class log-probabilities.
        combined = torch.cat([global_emb, aggregated_sub], dim=1)
        x = self.final_lin(combined)
        return F.log_softmax(x, dim=-1)
