import torch
import torch as ch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR
from torch.autograd import grad
from torch.utils.data import random_split
import torch_geometric
import torch_geometric.utils as utils
from torch_geometric.utils import to_scipy_sparse_matrix, to_torch_sparse_tensor
from torch_geometric.nn import GCNConv, SAGEConv, GATv2Conv, GATConv
from torch_geometric.nn import global_mean_pool, global_add_pool#, SAGPooling
from torch.autograd import Function
# Toy example
import numpy as np
import time
import math
import copy
from tqdm import tqdm
import os
# from OTCoarsening.src.Sinkhorn import sinkhorn_loss_default
# Logging
import logging
import sys
import warnings
import random 
from torch.utils.data import TensorDataset, DataLoader, Dataset, random_split
# from Sinkhorn import sinkhorn_loss_default, sinkhorn_norm_default
from poolings import FisherPooling, SimPooling, SAGPooling
import math
from similarity import cos_similarity_cubed_single as similarity_fn


class Unweighten(Function):
    """Binarize a tensor in the forward pass with a pass-through gradient.

    Forward maps every non-zero entry to 1.0 and every zero entry to 0.0.
    Backward returns the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, input):
        # Nothing is saved on ctx: backward does not need the input, so
        # keeping a reference would only hold the tensor alive for no reason.
        return torch.where(input != 0, 1.0, 0.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient: the thresholding is treated as if it were x -> x.
        return grad_output

def unweighten(x):
    """Apply the ``Unweighten`` autograd function to ``x`` and return the result."""
    binarized = Unweighten.apply(x)
    return binarized

class GCNLayer(nn.Module):
    """Dense graph-convolution layer: aggregate with ``adj`` and then apply a linear map."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.lin = nn.Linear(in_features, out_features)

    def forward(self, adj, z, node_weights=None):
        """Return ``lin(adj @ z)``; ``z`` is first multiplied by ``node_weights`` when given.

        Args:
            adj: Matrix used to aggregate node features via ``torch.matmul``.
            z: Node features whose last dimension is ``in_features``.
            node_weights: Optional multiplier broadcast against ``z``.
        """
        features = z if node_weights is None else z * node_weights
        aggregated = torch.matmul(adj, features)
        return self.lin(aggregated)

class ConceptGNN(nn.Module):
    def __init__(self, input_dim, output_dim, num_concepts=370, hidden_dim=512, tau=0.3, ratio=0.1,
                module='gcn', temperature=0.3, num_layers=3, alpha=0.7, beta=0.8, image_encoder=None, cdm=False):
        """Build the concept-graph model.

        Args:
            input_dim: Input feature size of the GCN layers, ``mlp``, ``text_lin``
                and ``score_layer``.
            output_dim: Number of output logits produced by the prediction heads.
            num_concepts: Number of concept nodes; also the side length of each
                learnable edge matrix.
            hidden_dim: Feature size of the q/k/v projections and GCN outputs.
            tau: Weight of the contrastive term in the forward loss.
            ratio: Fraction of concepts kept by ``SelectTopK``.
            module: GNN layer type; only ``'gcn'`` is supported.
            temperature: Divisor applied to cosine similarities in ``cl_loss``.
            num_layers: Number of GCN layers to create.
            alpha: Weight of the adjacency L1 regularizer in the forward loss.
            beta: Threshold used by ``init_edge_weight``.
            image_encoder: Optional encoder applied to ``x`` in the forward pass.
            cdm: When true, adds the ``presence`` linear layer used in forward.

        Raises:
            ValueError: If ``module`` is not ``'gcn'``.
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_concepts = num_concepts
        self.hidden_dim = hidden_dim
        self.image_encoder = image_encoder
        self.cdm = cdm
        if cdm:
            self.presence = nn.Linear(hidden_dim, num_concepts)

        self.tau = tau
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta

        if module == 'gcn':
            # NOTE(review): every layer maps input_dim -> hidden_dim, so stacking
            # presumably relies on input_dim == hidden_dim - confirm with callers.
            self.layers = nn.ModuleList([GCNLayer(input_dim, hidden_dim) for i in range(num_layers)])
        else:
            # Raising a bare string is itself a TypeError; raise a real exception.
            raise ValueError("Error for not recognizing the GNN module")

        self.bn = nn.BatchNorm1d(num_concepts)

        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim)

        self.lin1 = nn.Linear(self.num_concepts, self.output_dim)
        self.lin2 = nn.Linear(self.num_concepts, self.output_dim)

        self.dim_project = MLP(self.hidden_dim, self.num_concepts, self.hidden_dim)

        self.sim = nn.CosineSimilarity(dim=-1)
        self.criterion = nn.CrossEntropyLoss()

        self.text_lin = nn.Linear(self.input_dim, self.output_dim)

        # One learnable edge matrix per propagation step. They are registered as
        # individual attributes so they appear in parameters()/state_dict(); the
        # plain list below is only a convenience view for iteration.
        self.edge_param1 = nn.Parameter(torch.empty((num_concepts, num_concepts)))
        self.edge_param2 = nn.Parameter(torch.empty((num_concepts, num_concepts)))
        self.edge_param3 = nn.Parameter(torch.empty((num_concepts, num_concepts)))
        # xavier_uniform (no trailing underscore) is the deprecated spelling.
        torch.nn.init.xavier_uniform_(self.edge_param1)
        torch.nn.init.xavier_uniform_(self.edge_param2)
        torch.nn.init.xavier_uniform_(self.edge_param3)
        self.edge_param = [self.edge_param1, self.edge_param2, self.edge_param3]

        self.act = nn.LeakyReLU(0.1)
        self.c_lins = nn.Sequential(
            nn.Linear(self.num_concepts, self.num_concepts),
            nn.ReLU(),
            nn.Linear(self.num_concepts, self.num_concepts),
        )

        self.test_lin = nn.Linear(self.hidden_dim, self.output_dim)

        self.mlp = nn.Sequential(
            nn.Linear(self.input_dim, self.num_concepts),
            nn.ReLU(),
            nn.Linear(self.num_concepts, self.num_concepts)
        )

        self.score_layer = nn.Linear(self.input_dim, 1)
        self.ratio = ratio
    
    def init_module(self):
        """Reset the query, key and value projection weights to the identity pattern."""
        for projection in (self.q_proj, self.k_proj, self.v_proj):
            torch.nn.init.eye_(projection.weight)

    def init_ground_truth_graph(self, weight):
        """Overwrite all three edge matrices with the values of ``weight``.

        Each parameter receives its own copy. Assigning the same tensor to all
        of them would make the parameters (and the caller's ``weight``) share
        storage, so an in-place update of one would silently change the others.

        Args:
            weight: Tensor holding the adjacency values to load.
        """
        for param in (self.edge_param1, self.edge_param2, self.edge_param3):
            param.data = weight.detach().clone()
    
    def init_edge_weight(self, weight):
        """Threshold ``weight`` at ``self.beta`` and load it into every edge matrix.

        Entries that are not strictly greater than ``self.beta`` are set to zero.
        The printed edge count subtracts ``weight.size(0)`` and halves the rest.

        Args:
            weight: Square tensor of candidate edge weights.
        """
        weight = torch.where(weight > self.beta, weight, torch.zeros_like(weight))
        # NOTE(review): this count presumably assumes a symmetric matrix whose
        # diagonal survives the threshold - confirm against the caller.
        num_edges = (torch.count_nonzero(weight).item() - weight.size(0)) / 2
        print(f"Number of edges is {num_edges}")
        # self.edge_param is a plain list of parameters, so it has no ``.data``;
        # write an independent copy into each parameter instead.
        for param in self.edge_param:
            param.data = weight.clone()

    def attention(self, query, key, value, mask=None, dropout=None):
        """Cosine-similarity attention.

        Scores are dot products of the L2-normalized ``query`` and ``key``.
        Positions where ``mask == 0`` are filled with ``-1e9`` before the softmax.

        Returns:
            A pair ``(output, scores)``. ``scores`` are the (masked) pre-softmax
            similarities. ``output`` is the softmax-weighted sum of ``value``
            when ``value``'s first dimension equals the first dimension of
            ``query``; otherwise it is ``value`` tiled that many times along a
            new leading axis, with no attention applied.
        """
        batch = query.size(0)
        unit_query = F.normalize(query, dim=-1)
        unit_key = F.normalize(key, dim=-1)
        scores = unit_query @ unit_key.transpose(-2, -1)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = F.softmax(scores, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        if value.size(0) == batch:
            return torch.matmul(weights, value), scores
        # Leading dimensions differ: hand back the tiled values untouched.
        return value.repeat(batch, 1, 1), scores

    def SelectTopK(self, x):
        """Split indices into the top-scoring ``k`` and the remainder.

        ``k`` is ``ceil(ratio * num_concepts)``. ``x`` is squeezed first; a
        1-D result is ranked as a whole, anything else is ranked along dim 1.

        Returns:
            ``(top_indices, other_indices)``, both ordered by descending score.
        """
        k = math.ceil(float(self.ratio) * self.num_concepts)
        scores = x.squeeze()
        if scores.dim() == 1:
            ranking = torch.argsort(scores, descending=True)
            return ranking[:k], ranking[k:]
        ranking = torch.argsort(scores, dim=1, descending=True)
        return ranking[:, :k], ranking[:, k:]

    def cl_loss(self, z1, z2, z3=None):
        """Contrastive cross-entropy between rows of ``z1`` and rows of ``z2``.

        Pairwise cosine similarities divided by ``self.temperature`` are used as
        logits, and row ``i`` is labelled with class ``i``. When ``z3`` is given,
        its similarities to ``z1`` are appended as extra logit columns.
        """
        anchors = z1.unsqueeze(1)
        logits = self.sim(anchors, z2.unsqueeze(0)) / self.temperature
        if z3 is not None:
            extra = self.sim(anchors, z3.unsqueeze(0)) / self.temperature
            logits = torch.cat((logits, extra), dim=-1)
        targets = torch.arange(logits.size(0), dtype=torch.long, device=z1.device)
        return self.criterion(logits, targets)

    def get_degree(self, A):
        """Count, per row, the entries of ``A + I`` that stay positive after clamping to [0, 1].

        Returns:
            The row counts as a detached CPU tensor.
        """
        identity = torch.eye(A.size(1), device=A.device)
        clamped = (A + identity).clamp(min=0, max=1)
        # ceil turns every value in (0, 1] into 1, so the row sum is a count.
        degree = clamped.ceil().sum(dim=-1)
        return degree.detach().cpu()

    def get_characteristic_matrix(self, A):
        """Return the symmetrically normalized adjacency built from ``A``.

        Negative entries are dropped, the identity is added, and values are
        clamped to [0, 1]. Each row's degree is its number of non-zero entries,
        and entry ``(i, j)`` is scaled by ``degree[i] ** -0.5 * degree[j] ** -0.5``.
        """
        identity = torch.eye(A.size(1), device=A.device)
        weights = (F.relu(A) + identity).clamp(min=0, max=1)
        inv_sqrt_degree = weights.ceil().sum(dim=-1).pow(-0.5).unsqueeze(-1)

        row_scaled = inv_sqrt_degree * weights
        # Scaling the rows of the transpose scales the columns of the original.
        return (inv_sqrt_degree * row_scaled.t()).t()
    
    def intervention(self, concept_score, no_graph=False):
        """Predict class logits from concept scores, optionally via the concept graph.

        With ``no_graph`` the scores go straight through ``lin1``. Otherwise
        they are propagated once per edge matrix as
        ``relu(A_norm @ c + c)``, the input scores are added back, and the
        result goes through ``lin1``.

        Args:
            concept_score: Concept scores whose last dimension indexes concepts.
            no_graph: Skip graph propagation when true.

        Returns:
            The output of ``self.lin1``.
        """
        if no_graph:
            return self.lin1(concept_score)

        propagated = concept_score.unsqueeze(-1)
        for edge_param in self.edge_param:
            adj = 0.5 * (edge_param + edge_param.t())
            new_adj = self.get_characteristic_matrix(adj)
            propagated = F.relu(torch.matmul(new_adj, propagated) + propagated)

        # Drop only the trailing axis added above. A bare squeeze() would also
        # remove a size-1 batch or concept axis and could broadcast to the
        # wrong shape when adding the input scores back.
        concept_score = propagated.squeeze(-1) + concept_score

        return self.lin1(concept_score)
    
    def get_score(self, x, text_features=None):
        """Return the squeezed attention scores used as concept scores.

        Queries come from ``x``. Keys and values come from ``text_features``
        when provided, otherwise from ``x`` itself; in that case only the
        scores of position 0 against positions 1 onward are kept.
        """
        source = x if text_features is None else text_features
        query = self.act(self.q_proj(x))
        key = self.act(self.k_proj(source))
        value = self.act(self.v_proj(source))
        _, attns = self.attention(query, key, value)

        if text_features is None:
            attns = attns[:, 0, 1:]
        return attns.squeeze()


    
    def forward(self, x, target, graph=True, concepts=None, text_features=None, ssl=False, score_type=None):
        """Compute the loss and class logits for a batch.

        Three paths exist:

        * ``concepts`` given: concept logits come from ``self.mlp`` and are
          trained with binary cross-entropy against ``concepts``, then
          propagated over the learned adjacency matrices.
        * ``concepts`` is None and ``graph`` is true: concept scores come from
          cosine attention, are propagated over the learned adjacency matrices,
          and two contrastive terms are added to the loss.
        * ``concepts`` is None and ``graph`` is false: attention scores are fed
          directly to ``lin1``.

        Args:
            x: Input features, or encoder input when ``image_encoder`` is set.
                # assumes (batch, 1 + num_concepts, hidden_dim) on the
                # attention path without text_features - TODO confirm
            target: Class indices for the cross-entropy terms.
            graph: Enable graph propagation on the attention path.
            concepts: Optional concept labels in [0, 1].
            text_features: Optional features used as attention keys/values.
            ssl: On the graph path, return early with only the contrastive and
                L1 terms.
            score_type: ``'mi1'`` or ``'mi2'`` selects a single contrastive term;
                anything else uses their sum.

        Returns:
            ``(loss, pred, new_adj)``, where ``new_adj`` is the last normalized
            adjacency (None when no graph was used). With ``ssl`` on the graph
            path the tuple is ``(loss, c1, new_adj)`` instead.
        """
        if self.image_encoder is not None:
            x = self.image_encoder(x).projected_global_embedding
            text_features = text_features.repeat(x.shape[0], 1, 1)
            x = torch.cat([x.unsqueeze(1), text_features], dim=1)

        presence = None  # only populated on the attention path when cdm is on
        new_adj = None

        if concepts is not None:
            c0 = self.mlp(x)
            loss = F.binary_cross_entropy(torch.sigmoid(c0), concepts)

            # self.edge_param is a list of matrices, so the former
            # ``self.edge_param + self.edge_param.t()`` could not run. Apply the
            # same residual propagation once per matrix instead.
            # NOTE(review): sequential application is my reading of the intent;
            # with a single matrix it reduces to the previous formula.
            c1 = c0.unsqueeze(-1)
            l1_regularizer = 0
            for edge_param in self.edge_param:
                adj = F.relu(0.5 * (edge_param + edge_param.t()))
                new_adj = self.get_characteristic_matrix(adj)
                l1_regularizer = l1_regularizer + 0.5 * (
                    torch.norm(new_adj, p=1, dim=-1).mean() + torch.norm(new_adj.T, p=1, dim=-1).mean()
                )
                c1 = torch.matmul(new_adj, c1) + c1
            c1 = c1.squeeze(-1)
            mi = 0

        else:
            loss = 0.0

            query = self.act(self.q_proj(x))
            if text_features is not None:
                key = self.act(self.k_proj(text_features))
                value = self.act(self.v_proj(text_features))
            else:
                key = self.act(self.k_proj(x))
                value = self.act(self.v_proj(x))
            value, attns = self.attention(query, key, value)

            if text_features is not None:
                c0 = attns.unsqueeze(-1)
            else:
                # Scores of position 0 against every later position.
                c0 = attns[:, 0, 1:].squeeze().unsqueeze(-1)

            if graph:
                c0 = F.relu(c0)

                if text_features is not None:
                    image_feature = query
                    text_feature = value * c0
                else:
                    image_feature = value[:, 0, :].squeeze()
                    text_feature = value[:, 1:, :]

                adjs = []
                l1_regularizer = 0
                for gnn, edge_param in zip(self.layers, self.edge_param):
                    adj = 0.5 * (edge_param + edge_param.t())
                    new_adj = self.get_characteristic_matrix(adj)
                    adjs.append(new_adj)
                    l1_regularizer = l1_regularizer + 0.5 * (
                        torch.norm(new_adj, p=1, dim=-1).mean() + torch.norm(new_adj.T, p=1, dim=-1).mean()
                    )
                    text_feature = self.act(gnn(new_adj, text_feature, c0))

                text_feature = text_feature * c0
                graph_feature = text_feature.mean(dim=1)

                if self.cdm:
                    presence = torch.bernoulli(torch.sigmoid(self.presence(x)))
                    c1 = (c0.squeeze() * presence).unsqueeze(-1)
                else:
                    c1 = c0

                # Propagate the concept scores over the same adjacencies.
                for layer_adj in adjs:
                    c1 = F.relu(torch.matmul(layer_adj, c1) + c1)

                c1 = c1.squeeze() + c0.squeeze()

                mi1 = self.cl_loss(graph_feature, image_feature)
                projected_x = self.dim_project(image_feature)

                mi2 = self.cl_loss(projected_x, c1)
                if score_type == 'mi1':
                    mi = mi1
                elif score_type == 'mi2':
                    mi = mi2
                else:
                    mi = mi1 + mi2

                if ssl:
                    loss = mi + self.alpha * l1_regularizer
                    return loss, c1, new_adj

                pred = self.lin2(projected_x)
                loss = loss + F.cross_entropy(pred, target, reduction='mean')
            else:
                if self.cdm:
                    presence = torch.bernoulli(torch.sigmoid(self.presence(x)))
                    c1 = c0.squeeze() * presence
                else:
                    c1 = c0.squeeze()
                mi, l1_regularizer = 0, 0
                new_adj = None

        pred = self.lin1(c1)

        loss = loss + F.cross_entropy(pred, target, reduction='mean')
        loss = loss + self.tau * mi
        loss = loss + self.alpha * l1_regularizer
        # presence is only sampled on the attention path; skipping the term
        # otherwise avoids a NameError when cdm is combined with ``concepts``.
        if self.cdm and presence is not None:
            p = torch.tensor([10e-4], dtype=torch.float32, device=x.device)
            # NOTE(review): F.kl_div expects log-probabilities as its first
            # argument and ``presence`` is a Bernoulli sample - verify this term.
            kl = F.kl_div(torch.sigmoid(presence), p, reduction='mean')
            loss = loss + 10e-4 * kl
        return loss, pred, new_adj

    def graph_intervention(self, x, target, intervention_graph, graph=True, concepts=None, text_features=None, ssl=False):
        """Run the forward computation with ``intervention_graph`` added to each adjacency.

        On the attention path with ``graph`` enabled, every normalized adjacency
        is replaced by ``new_adj + intervention_graph`` both in the GCN layers
        and in the concept-score propagation. The other paths do not use
        ``intervention_graph``.

        Args:
            x: Input features, or encoder input when ``image_encoder`` is set.
            target: Class indices for the cross-entropy terms.
            intervention_graph: Matrix added to each normalized adjacency.
                # assumes shape (num_concepts, num_concepts) - TODO confirm
            graph: Enable graph propagation on the attention path.
            concepts: Optional concept labels in [0, 1].
            text_features: Optional features used as attention keys/values.
            ssl: On the graph path, return early with only the contrastive and
                L1 terms.

        Returns:
            ``(loss, pred, new_adj)``; with ``ssl`` on the graph path,
            ``(loss, c1, new_adj)``. ``new_adj`` is the last normalized
            adjacency without the intervention added, or None when no graph
            was used.
        """
        if self.image_encoder is not None:
            x = self.image_encoder(x).projected_global_embedding
            text_features = text_features.repeat(x.shape[0], 1, 1)
            x = torch.cat([x.unsqueeze(1), text_features], dim=1)

        new_adj = None

        if concepts is not None:
            c0 = self.mlp(x)
            loss = F.binary_cross_entropy(torch.sigmoid(c0), concepts)

            # self.edge_param is a list of matrices, so the former
            # ``self.edge_param + self.edge_param.t()`` could not run. Apply the
            # same residual propagation once per matrix instead.
            # NOTE(review): sequential application is my reading of the intent,
            # and this branch never used intervention_graph - confirm both.
            c1 = c0.unsqueeze(-1)
            l1_regularizer = 0
            for edge_param in self.edge_param:
                adj = F.relu(0.5 * (edge_param + edge_param.t()))
                new_adj = self.get_characteristic_matrix(adj)
                l1_regularizer = l1_regularizer + 0.5 * (
                    torch.norm(new_adj, p=1, dim=-1).mean() + torch.norm(new_adj.T, p=1, dim=-1).mean()
                )
                c1 = torch.matmul(new_adj, c1) + c1
            c1 = c1.squeeze(-1)
            mi = 0

        else:
            loss = 0.0

            query = self.act(self.q_proj(x))
            if text_features is not None:
                key = self.act(self.k_proj(text_features))
                value = self.act(self.v_proj(text_features))
            else:
                key = self.act(self.k_proj(x))
                value = self.act(self.v_proj(x))
            value, attns = self.attention(query, key, value)

            if text_features is not None:
                c0 = attns.unsqueeze(-1)
            else:
                # Scores of position 0 against every later position.
                c0 = attns[:, 0, 1:].squeeze().unsqueeze(-1)

            if graph:
                c0 = F.relu(c0)
                if text_features is not None:
                    image_feature = query
                    text_feature = value * c0
                else:
                    image_feature = value[:, 0, :].squeeze()
                    text_feature = value[:, 1:, :]

                adjs = []
                l1_regularizer = 0
                for gnn, edge_param in zip(self.layers, self.edge_param):
                    adj = 0.5 * (edge_param + edge_param.t())
                    new_adj = self.get_characteristic_matrix(adj)
                    adjs.append(new_adj)
                    # The regularizer is computed on the learned adjacency only.
                    l1_regularizer = l1_regularizer + 0.5 * (
                        torch.norm(new_adj, p=1, dim=-1).mean() + torch.norm(new_adj.T, p=1, dim=-1).mean()
                    )
                    text_feature = self.act(gnn(new_adj + intervention_graph, text_feature, c0))

                text_feature = text_feature * c0
                graph_feature = text_feature.mean(dim=1)

                c1 = c0
                for layer_adj in adjs:
                    c1 = F.relu(torch.matmul(layer_adj + intervention_graph, c1) + c1)

                c1 = c1.squeeze() + c0.squeeze()

                mi1 = self.cl_loss(graph_feature, image_feature)
                projected_x = self.dim_project(image_feature)

                mi2 = self.cl_loss(projected_x, c1)
                mi = mi2 + mi1

                if ssl:
                    loss = mi + self.alpha * l1_regularizer
                    return loss, c1, new_adj

                pred = self.lin2(projected_x)
                loss = loss + F.cross_entropy(pred, target, reduction='mean')
            else:
                c1 = c0.squeeze()
                mi, l1_regularizer = 0, 0
                new_adj = None

        pred = self.lin1(c1)
        loss = loss + F.cross_entropy(pred, target, reduction='mean')
        loss = loss + self.tau * mi
        loss = loss + self.alpha * l1_regularizer
        return loss, pred, new_adj

def bin_concrete_sample(a, temperature=0.1, eps=1e-8):
    """Sample from the binary concrete distribution.

    Uniform noise shaped like ``a`` is clamped to ``[eps, 1 - eps]`` and turned
    into logistic noise, which is added to ``a``, divided by ``temperature`` and
    passed through a sigmoid.
    """
    uniform = torch.rand_like(a).clamp(eps, 1. - eps)
    logistic_noise = uniform.log() - (1. - uniform).log()
    return torch.sigmoid((logistic_noise + a) / temperature)

import random
        
def intervene_on_graph(edge_param, components=None):
    """Fill the missing edges inside each concept group, in place.

    For every group, each ordered pair of distinct nodes is inspected. Pairs
    whose entry in ``edge_param`` is exactly zero are set to 0.1 times the mean
    of the group's non-zero entries, or left at zero if the group has no
    non-zero entry.

    Args:
        edge_param: Square edge-weight tensor, modified in place.
        components: Optional iterable of node-index groups. Defaults to the
            three built-in groups ``(0, 1, 4)``, ``(2, 3, 5, 6)`` and
            ``(7, 8, 9, 10)``.

    Returns:
        The same ``edge_param`` tensor.
    """
    if components is None:
        components = ((0, 1, 4), (2, 3, 5, 6), (7, 8, 9, 10))

    for nodes in components:
        # Both directions of every pair are visited. The previous hand-written
        # edge list repeated (0, 4) and omitted (4, 0).
        pairs = [(i, j) for i in nodes for j in nodes if i != j]

        active_values = []
        inactive_pairs = []
        for i, j in pairs:
            value = edge_param[i, j].item()
            if value != 0:
                active_values.append(value)
            else:
                inactive_pairs.append((i, j))

        if active_values:
            edge_value = 0.1 * (sum(active_values) / len(active_values))
        else:
            edge_value = 0

        for i, j in inactive_pairs:
            edge_param[i, j] = edge_value

    return edge_param
        





class Linear(nn.Module):
    """Linear classifier that returns its logits together with the cross-entropy loss."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        self.model = nn.Linear(input_dim, output_dim)

    def forward(self, x, target):
        """Return ``(preds, loss)``; 3-D logits are averaged over dim 1 first."""
        logits = self.model(x)
        if logits.dim() == 3:
            logits = logits.mean(dim=1)
        return logits, F.cross_entropy(logits, target, reduction='mean')

class End2EndModel(torch.nn.Module):
    """Two-stage model: ``model1`` produces concept scores, ``model2`` maps them to the target.

    Between the two stages the concept scores can optionally be refined by
    message passing over three learnable concept-to-concept adjacency
    matrices (``use_graph``), or passed through GELU (``use_relu``) or a
    sigmoid (``use_sigmoid``). ``model3``/``model4`` form an optional
    auxiliary branch whose scores are tied to the concept scores through a
    contrastive loss.

    When ``use_graph`` is True, ``graph_init`` must be called before any
    method that performs message passing; ``__init__`` does not create the
    edge parameters.
    """

    def __init__(self, model1, model2, model3=None, model4=None, use_relu=False, use_sigmoid=False, use_graph=False, alpha=0.1, beta=0.1, temperature=0.3, n_class_attr=2):
        super(End2EndModel, self).__init__()
        self.first_model = model1
        self.sec_model = model2
        self.trd_model = model3
        self.fth_model = model4
        self.use_relu = use_relu
        self.use_sigmoid = use_sigmoid
        self.use_graph = use_graph

        # Weights of the contrastive term (alpha) and the L1 adjacency term (beta).
        self.alpha = alpha
        self.beta = beta

        # ``cl_loss`` is used by ``forward`` whenever the auxiliary branch is
        # present, independently of ``use_graph``; create its helpers
        # unconditionally so that combination does not raise AttributeError.
        # Neither module owns parameters, so the state dict is unaffected.
        self.sim = nn.CosineSimilarity(dim=-1)
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = temperature

    def graph_init(self, num_concepts):
        """Create and Xavier-initialise the three ``(num_concepts, num_concepts)`` edge parameters."""
        self.edge_param1 = nn.Parameter(torch.empty((num_concepts, num_concepts)))
        self.edge_param2 = nn.Parameter(torch.empty((num_concepts, num_concepts)))
        self.edge_param3 = nn.Parameter(torch.empty((num_concepts, num_concepts)))

        # In-place initialiser; the un-suffixed ``xavier_uniform`` is deprecated.
        torch.nn.init.xavier_uniform_(self.edge_param1)
        torch.nn.init.xavier_uniform_(self.edge_param2)
        torch.nn.init.xavier_uniform_(self.edge_param3)

        # One adjacency per message-passing step, applied in this order.
        self.edge_param = [self.edge_param1, self.edge_param2, self.edge_param3]

        return

    def get_characteristic_matrix(self, A):
        """Return the symmetrically normalised adjacency ``D^-1/2 (A' ) D^-1/2``.

        ``A'`` is ``relu(A) + I`` clamped to ``[0, 1]``; the degree of a node
        is the number of non-zero entries in its row of ``A'`` (always >= 1
        because of the self loop, so the inverse square root is finite).
        ``A`` is expected to be a square 2-D tensor.
        """
        identity = torch.eye(A.size(1), device=A.device)
        A = torch.clamp(F.relu(A) + identity, min=0, max=1)
        # ceil() turns every positive weight into 1, i.e. counts active edges.
        degree = torch.ceil(A).sum(dim=-1)
        inv_sqrt_degree = degree.pow(-0.5).unsqueeze(-1)
        return inv_sqrt_degree * A * inv_sqrt_degree.t()

    def cl_loss(self, z1, z2, z3=None):
        """InfoNCE-style loss: row ``i`` of ``z1`` should match row ``i`` of ``z2``.

        Rows of ``z3``, when given, are appended as additional negatives.
        """
        logits = self.sim(z1.unsqueeze(1), z2.unsqueeze(0)) / self.temperature
        if z3 is not None:
            hard_negatives = self.sim(z1.unsqueeze(1), z3.unsqueeze(0)) / self.temperature
            logits = torch.cat([logits, hard_negatives], dim=-1)

        # The positive for row i sits on the diagonal of the z1/z2 block.
        labels = torch.arange(logits.size(0), device=z1.device).long()
        return self.criterion(logits, labels)

    def message_passing(self, stage1_output):
        """Propagate concept scores through every learned adjacency (no mask).

        Returns ``(attr_outputs, l1_regularizer)``; see
        ``message_passing_with_mask`` for details.
        """
        return self.message_passing_with_mask(stage1_output, graph_mask=None)

    def message_passing_with_mask(self, stage1_output, graph_mask=None):
        """Propagate concept scores through every learned adjacency.

        For each edge parameter the matrix is symmetrised, normalised with
        ``get_characteristic_matrix``, optionally multiplied element-wise by
        ``graph_mask``, and applied to the scores followed by GELU.

        Returns:
            attr_outputs: propagated scores, same shape as ``stage1_output``.
            l1_regularizer: sum over the adjacencies of the mean row/column
                L1 norm of the (masked) normalised adjacency.
        """
        attr_outputs = stage1_output.unsqueeze(-1)
        l1_regularizer = 0
        for edge_param in self.edge_param:
            adj = 0.5 * (edge_param + edge_param.t())
            new_adj = self.get_characteristic_matrix(adj)
            if graph_mask is not None:
                new_adj = new_adj * graph_mask
            attr_outputs = F.gelu(torch.matmul(new_adj, attr_outputs))
            row_l1 = torch.norm(new_adj, p=1, dim=-1).mean()
            col_l1 = torch.norm(new_adj.T, p=1, dim=-1).mean()
            l1_regularizer = l1_regularizer + 0.5 * (row_l1 + col_l1)

        # Remove only the axis added above; a bare squeeze() would also drop
        # the batch axis when the batch holds a single sample.
        return attr_outputs.squeeze(-1), l1_regularizer

    def _stage2_inputs(self, stage1_out, graph_mask=None, masked=False):
        """Build the stage-2 input from stage-1 scores.

        Returns ``(attr_outputs, stage2_inputs, l1_regularizer)``. With
        ``use_graph`` the propagated scores are added to the raw scores
        (residual connection); otherwise the optional activation is applied
        and the regulariser is 0. ``masked`` selects the masked
        message-passing entry point.
        """
        if self.use_graph:
            if masked:
                attr_outputs, l1_regularizer = self.message_passing_with_mask(stage1_out, graph_mask)
            else:
                attr_outputs, l1_regularizer = self.message_passing(stage1_out)
            return attr_outputs, attr_outputs + stage1_out, l1_regularizer

        if self.use_relu:
            # The flag is named "relu" but the activation applied is GELU.
            attr_outputs = F.gelu(stage1_out)
        elif self.use_sigmoid:
            attr_outputs = torch.sigmoid(stage1_out)
        else:
            attr_outputs = stage1_out
        return attr_outputs, attr_outputs, 0

    def intervene(self, stage1_out):
        """Run stage 2 on externally supplied stage-1 scores.

        Returns ``(sigmoid(stage1_out), stage2_out)``.
        """
        _, stage2_inputs, _ = self._stage2_inputs(stage1_out)
        stage2_out = self.sec_model(stage2_inputs)
        return torch.sigmoid(stage1_out), stage2_out

    def intervene_on_graph(self, stage1_out, graph_mask=None):
        """Like ``intervene`` but with ``graph_mask`` applied to every adjacency.

        The mask only has an effect when ``use_graph`` is True.
        """
        _, stage2_inputs, _ = self._stage2_inputs(stage1_out, graph_mask=graph_mask, masked=True)
        stage2_out = self.sec_model(stage2_inputs)
        return torch.sigmoid(stage1_out), stage2_out

    def forward_stage2(self, stage1_out):
        """Run stage 2 and expose the intermediate quantities.

        Returns ``(concept_pred, target_pred, concept_score, l1_regularizer)``
        where ``concept_pred`` is ``sigmoid(stage1_out)`` and ``concept_score``
        is the transformed score before the residual addition.
        """
        attr_outputs, stage2_inputs, l1_regularizer = self._stage2_inputs(stage1_out)
        stage2_out = self.sec_model(stage2_inputs)
        return torch.sigmoid(stage1_out), stage2_out, attr_outputs, l1_regularizer

    def forward_stage1(self, x):
        """Return the raw (pre-sigmoid) concept scores of the first model."""
        return self.first_model(x)

    def get_score(self, x):
        """Alias of ``forward_stage1``."""
        return self.forward_stage1(x)

    def forward(self, x, target, concept=None):
        """Compute the training loss and predictions.

        Args:
            x: model input.
            target: class indices for the cross-entropy terms.
            concept: concept labels for the binary cross-entropy term, or
                ``None`` to skip concept supervision.

        Returns:
            ``(loss, concept_pred, target_pred)``.
        """
        if concept is None:
            concept_pred = torch.sigmoid(self.forward_stage1(x))
            # NOTE(review): without concept labels the second model is fed
            # ``x`` directly rather than the concept scores - confirm this is
            # the intended baseline behaviour.
            target_pred = self.sec_model(x)
            loss = F.cross_entropy(target_pred, target, reduction='mean')
            return loss, concept_pred, target_pred

        outputs = self.forward_stage1(x)
        if self.trd_model is not None:
            auxiliary_score = F.gelu(self.trd_model(x))
            auxiliary = self.fth_model(auxiliary_score)
        else:
            auxiliary_score = None
            auxiliary = None

        concept_pred, target_pred, concept_score, l1_regularizer = self.forward_stage2(outputs)

        loss = F.cross_entropy(target_pred, target, reduction='mean')
        loss += F.binary_cross_entropy(concept_pred, concept, reduction='mean')

        if auxiliary is not None:
            # Align auxiliary scores with concept scores and supervise the
            # auxiliary classifier with the same targets.
            loss += self.alpha * self.cl_loss(auxiliary_score, concept_score)
            loss += F.cross_entropy(auxiliary, target, reduction='mean')

        loss += self.beta * l1_regularizer

        return loss, concept_pred, target_pred

class MLP(nn.Module):
    """Linear classifier with an optional single hidden layer.

    With a truthy ``expand_dim`` the network is
    ``Linear(input_dim, expand_dim) -> ReLU -> Linear(expand_dim, num_classes)``;
    otherwise it is a single ``Linear(input_dim, num_classes)``. The output
    is raw logits (softmax is left to the loss function).
    """

    def __init__(self, input_dim, num_classes, expand_dim=None):
        super().__init__()
        self.expand_dim = expand_dim
        if not self.expand_dim:
            self.linear = nn.Linear(input_dim, num_classes)
            return
        self.linear = nn.Linear(input_dim, expand_dim)
        self.activation = torch.nn.ReLU()
        self.linear2 = nn.Linear(expand_dim, num_classes)

    def forward(self, x):
        """Return the logits for ``x``."""
        hidden = self.linear(x)
        # getattr keeps instances that lack ``expand_dim`` on the single-layer path.
        if not getattr(self, 'expand_dim', None):
            return hidden
        return self.linear2(self.activation(hidden))

class Transformer(nn.Module):
    """Transformer encoder followed by mean pooling and a linear decoder."""

    def __init__(self, input_dim, output_dim, nhead=1, num_layers=2, dropout_rate=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.nhead = nhead
        encoder_layer = nn.TransformerEncoderLayer(d_model=self.input_dim, nhead=self.nhead)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.num_layers)
        self.decoder = nn.Linear(self.input_dim, self.output_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return ``(decoded, pooled)`` for input ``x``.

        ``pooled`` is the dropout-regularised encoder output averaged over
        dimension 1 and squeezed; ``decoded`` is the linear decoder applied
        to it.
        """
        # NOTE(review): the encoder layer is built with the default
        # batch_first=False, so dimension 1 is treated as the batch axis by
        # the encoder - confirm the caller's tensor layout matches the
        # pooling axis used below.
        encoded = self.encoder(x)
        pooled = self.dropout(encoded).mean(dim=1).squeeze()
        return self.decoder(pooled), pooled

# class ConceptGNN(torch.nn.Module):
#     def __init__(self, input_dim, output_dim, hidden_dim=1024, module='gat', num_layers=1, 
#                 temperature=0.05, tau=0.93, num_concepts=370, grad_lambda=0.2):
#         super(ConceptGNN, self).__init__()
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.hidden_dim = hidden_dim
#         self.num_layers = num_layers
#         self.temperature = temperature
#         self.tau = tau
#         self.num_concepts = num_concepts
#         self.grad_lambda = grad_lambda
#         if module == 'gat':
#             self.layers1 = nn.ModuleList([GATConv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
#             self.layers2 = nn.ModuleList([GATConv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
#         elif module == 'gat_v2':
#             self.layers1 = nn.ModuleList([GATv2Conv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
#             self.layers2 = nn.ModuleList([GATv2Conv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
#         elif module == 'gcn':
#             self.layers1 = nn.ModuleList([GCNConv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
#             self.layers2 = nn.ModuleList([GCNConv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
            
#         self.classifier = GCNConv(1, 1, edge_dim=1, bias=True) # bias term important!
#         # nn.init.eye_(self.classifier.lin.weight.data)
#         # self.classifier.requires_grad = False
#         self.bns1 = nn.ModuleList([nn.BatchNorm1d(self.hidden_dim) for _ in range(num_layers)])
#         self.bns2 = nn.ModuleList([nn.BatchNorm1d(self.hidden_dim) for _ in range(num_layers)])
#         self.act = nn.LeakyReLU(0.1)

#         self.linear = nn.Linear(self.num_concepts, self.output_dim, bias=True)
#         self.lin = nn.Linear(self.hidden_dim, self.output_dim)

#         self.edge_param = nn.Parameter(torch.empty(num_concepts, num_concepts))
#         torch.nn.init.xavier_uniform(self.edge_param)

#         # self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
#         # self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
#         # self.classifier = 
#         self.sim = nn.CosineSimilarity(dim=-1)
#         self.criterion = nn.CrossEntropyLoss()

#     def update(self, growth_rate=1.005, threshold=0.95):
#         self.tau = min(self.tau * growth_rate, threshold)
#         return

#     # def update(self, decay_rate=0.98, threshold=0.4):
#     #     self.tau = max(self.tau * decay_rate, threshold)
#     #     return

#     def attention(self, query, key, value, mask=None, dropout=None):
#         "Compute 'Scaled Dot Product Attention'"
#         bs, d_k = query.size(0), query.size(-1)
#         scores = torch.matmul(query, key.transpose(-2, -1)) \
#                 / math.sqrt(d_k)
#         if mask is not None:
#             scores = scores.masked_fill(mask == 0, -1e9)
#         p_attn = F.softmax(scores, dim = -1)
#         if dropout is not None:
#             p_attn = dropout(p_attn)
#         p_attn = p_attn.view(bs, -1, 1)
#         return p_attn * value, p_attn
    
#     def adj_to_edge_index(self, batch):
#         bs = len(torch.unique(batch))
#         symm = self.edge_param @ self.edge_param.t()
#         symm = F.relu(torch.triu(symm, diagonal=1), self.tau)
#         edge_index = get_edge_index(symm)
#         edge_index = edge_index.repeat(bs) + torch.unique(batch)
#         return edge_index.t()


#     def cl_loss(self, z1, z2):
#         device = z1.device
#         batch_size = z1.size(0)

#         # z1 = self.projection_head(z1)
#         # z2 = self.projection_head(z2)

#         z1 = F.normalize(z1, dim=-1)
#         z2 = F.normalize(z2, dim=-1)

#         cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0)) / self.temperature

#         labels = torch.arange(cos_sim.size(0)).long().to(device)

#         loss = self.criterion(cos_sim, labels)
#         return loss
    
#     def graph_construct(self, x, edge_index, image_feature, batch):
#         batch_size = image_feature.size(0)

#         x_prime = x
#         concept_score = F.normalize(x.view(batch_size, -1, x.size(-1)), dim=-1) @ F.normalize(image_feature, dim=-1).unsqueeze(-1)
#         concept_score = concept_score.view(-1, 1)
#         for i, layer in enumerate(self.layers1):
#             x_prime = layer(x_prime, edge_index)
#             x_prime = self.bns1[i](x_prime)
#             x_prime = self.act(x_prime)

#             # x_prime = x_prime.view(batch_size, -1, x_prime.shape[-1])

#             # concept_score = None
#             # for j in range(batch_size):
#             #     if concept_score is None:
#             #         concept_score = x_prime[j, :, :].squeeze() @ image_feature[j].unsqueeze(-1)
#             #     else:
#             #         concept_score = torch.cat([concept_score, 
#             #             x_prime[j, :, :].squeeze() @ image_feature[j].unsqueeze(-1)], dim=0)
            
#             x_prime = x_prime * concept_score

#         concept_feature = global_mean_pool(x_prime, batch)

#         # query, key, value = image_feature.unsqueeze(1), x_prime.view(batch_size, -1, x.size(-1)), x.view(batch_size, -1, x.size(-1))
#         # concept_feature, concept_score = self.attention(query, key, value)

#         # concept_feature = concept_feature.sum(dim=1).squeeze()
#         # concept_score = concept_score.squeeze().view(batch_size, -1)

#         loss = self.cl_loss(concept_feature, image_feature)

#         grad = torch.autograd.grad(loss, x, create_graph=True)[0]
#         grad = grad.view(batch_size, -1, grad.shape[-1])

#         edge_index = None
#         grad_align = 0.0
#         for k in range(batch_size):
#             gradient = grad[k, :, :].squeeze()
#             similarity = F.normalize(gradient, dim=-1) @ F.normalize(gradient, dim=-1).t()
            
#             grad_align += torch.mean(1.0 - similarity)
#             similarity = torch.triu(similarity, diagonal=1)
#             similarity = torch.where(similarity > self.tau, similarity, 0)

#             if edge_index is None:
#                 edge_index = (torch.nonzero(similarity) + k * similarity.size(0)).long()
#             else:
#                 edge_index = torch.cat([
#                     edge_index, (torch.nonzero(similarity) + k * similarity.size(0)).long()
#                 ], dim=0)
#         #     print(edge_index.shape)
#         # exit()
                
#         edge_index = edge_index.t()
        
#         # loss += self.grad_lambda * (grad_align / batch_size)
    
#         return loss, edge_index, concept_score

#     def propagation(self, concept_vectors, edge_indices):
#         updated_cv = None
#         device = concept_vectors.device
#         for concept_vector, edge_index in zip(concept_vectors, edge_indices):
#             adj = to_torch_sparse_tensor(edge_index, size=concept_vectors.size(-1)).to_dense()
#             adj += torch.eye(concept_vectors.size(-1)).long().to(device)
#             adj = adj.to(device)
#             concept_vector = (concept_vector @ adj.T).unsqueeze(0)
#             if updated_cv is None:
#                 updated_cv = concept_vector
#             else:
#                 updated_cv = torch.cat([updated_cv, concept_vector], dim=0)
                
        
#         return updated_cv
    
#     def forward(self, x=None, edge_index=None, image_feature=None, batch=None, 
#                 target=None, concept_score=None, end2end=False, graph_construction=False):
#         batch_size, num_concepts = concept_score.shape
#         if not end2end:
#             # key = self.act(self.k_proj(x.view(batch_size, -1, x.size(-1))))
#             # query = self.act(self.q_proj(image_feature)).unsqueeze(1)
#             # x, attn = self.attention(query=query, key=key, value=key)
#             # print(x.shape)
#             # print(attn.squeeze())
#             # exit()
#             if graph_construction:
#                 loss, edge_index, concept_score = self.graph_construct(x, edge_index, image_feature, batch)
#                 edge_indices = utils.unbatch_edge_index(edge_index, batch, batch_size=image_feature.size(0))
#                 edge_index = torch_geometric.utils.to_undirected(edge_index)
#                 # concept_score = self.propagation(concept_score, edge_indices)
#             else:
#                 loss = 0.0
#                 edge_indices = utils.unbatch_edge_index(edge_index, batch, batch_size=image_feature.size(0))
#                 # concept_score = self.propagation(concept_score, edge_indices)
#             concept_score = concept_score.view(-1, 1)

#             x = x * concept_score
#             for i, layer in enumerate(self.layers1):
#                 x = layer(x, edge_index)
#                 x = self.bns1[i](x)
#                 x = self.act(x)
            
#             x = global_add_pool(x, batch)
#             # x = torch.sum(x.view(batch_size, -1, x.size(-1)) * concept_score.unsqueeze(-1), dim=1).squeeze()
            
#             # concept_score = concept_score.view(batch_size * num_concepts, 1)
#             # concept_score = self.classifier(concept_score, edge_index) + concept_score
#             # concept_score = concept_score.view(batch_size, -1)
#         else:
#             loss = 0.0
#             edge_indices = None
        
#         # pred = self.linear(concept_score)
#         pred = self.lin(x)
        
#         loss += F.cross_entropy(pred, target, reduction='mean')
#         return pred, loss, edge_indices

        
        


            




class GNN(torch.nn.Module):
    def __init__(self, input_dim, output_dim, module='gat', hidden_dim=1024, linear_input_dim=None, num_layers=1, ratio=0.5, temperature=0.05, pooling='sag'):
        super(GNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.temperature = temperature
        if linear_input_dim is None:
            self.linear_input_dim = self.hidden_dim
        else:
            self.linear_input_dim = linear_input_dim
        self.ratio = ratio
        if module == 'gat':
            self.layers1 = nn.ModuleList([GATConv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
            self.layers2 = nn.ModuleList([GATConv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
            self.gnn_classifier = GATConv(hidden_dim, 1, edge_dim=1)
        elif module == 'gatv2':
            self.layers1 = nn.ModuleList([GATv2Conv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
            self.layers2 = nn.ModuleList([GATv2Conv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
            self.gnn_classifier = GATv2Conv(hidden_dim, 1, edge_dim=1)
        elif module == 'gcn':
            self.layers1 = nn.ModuleList([GCNConv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
            self.layers2 = nn.ModuleList([GCNConv(input_dim, hidden_dim, edge_dim=1) for _ in range(num_layers)])
            self.gnn_classifier = GCNConv(hidden_dim, 1, edge_dim=1)
        self.linears = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])
        self.bns1 = nn.ModuleList([nn.BatchNorm1d(self.hidden_dim) for _ in range(num_layers)])
        self.bns2 = nn.ModuleList([nn.BatchNorm1d(self.hidden_dim) for _ in range(num_layers)])
        self.act = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(p=0.1)
        self.regularizer = nn.KLDivLoss(reduction="mean")
        if self.ratio >= 1:
            self.pooling = None
        else:
            if pooling == 'sag':
                self.pooling = SAGPooling(in_channels=hidden_dim, ratio=self.ratio)
            elif pooling == 'fisher':
                self.pooling = FisherPooling(in_channels=hidden_dim, ratio=self.ratio)
            elif pooling == 'sim':
                self.pooling = SimPooling(in_channels=hidden_dim, ratio=self.ratio)
            self.pooling_method = pooling
        self.lin = nn.Linear(self.linear_input_dim, output_dim)
        self.lin1 = nn.Linear(self.hidden_dim, output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.sim = nn.CosineSimilarity(dim=-1)
        self.projection_head = nn.Sequential(
            nn.Linear(self.linear_input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256)
        )

    def update_edge_weight(self, x, text_features, batch_size):
        """Compute edge weights as similarities between image nodes and text features.

        ``x`` is reshaped to ``(batch_size, nodes_per_graph, dim)`` and the
        first node of every graph is taken as the image feature. Each image
        feature is L2-normalised (norm clamped to at least 1e-8) and
        multiplied with ``text_features.t()``.

        Returns:
            A column tensor of shape ``(num_weights, 1)`` holding the
            flattened similarity matrix.

        The input ``x`` is left unmodified.
        """
        x = x.view(batch_size, -1, x.shape[-1])
        image_features = x[:, 0, :]
        norm = torch.clamp(image_features.norm(dim=-1, keepdim=True), min=1e-8)
        # Out-of-place division: ``image_features`` is a view of the caller's
        # tensor, so ``/=`` would silently rescale the caller's node features
        # (and is rejected by autograd for leaf tensors that require grad).
        image_features = image_features / norm
        image_features = image_features.squeeze()
        edge_weight = image_features @ text_features.t()
        return edge_weight.flatten().unsqueeze(-1)
    
    def attribute_masking(self, input_feature, batch_size, drop_percent=0.2):
        """Return a copy of ``input_feature`` with a random subset of rows zeroed.

        ``int(num_rows * drop_percent)`` row indices are drawn with
        ``random.sample``; a drawn row is zeroed unless its index is a
        multiple of ``batch_size``. The input tensor is not modified.

        Args:
            input_feature: tensor whose first dimension indexes nodes.
            batch_size: rows whose index is divisible by this value are
                never masked.
            drop_percent: fraction of rows drawn as masking candidates.
        """
        node_num = input_feature.shape[0]
        mask_num = int(node_num * drop_percent)
        mask_idx = random.sample(list(range(node_num)), mask_num)

        # clone() instead of copy.deepcopy(): deepcopy raises for non-leaf
        # tensors that require grad, and the in-place write below is illegal
        # on a deep-copied leaf that requires grad.
        aug_feature = input_feature.clone()

        # NOTE(review): protected rows are chosen by ``j % batch_size``;
        # other code in this class picks each graph's first node after a
        # view(batch_size, -1, dim) - confirm the modulus is the intended one.
        rows_to_zero = [j for j in mask_idx if j % batch_size != 0]
        if rows_to_zero:
            aug_feature[rows_to_zero] = 0
        return aug_feature
    
    def d_forward(self, image_features, text_features, edge_index, edge_weight, batch):
        """Unfinished stub: draws one random perturbation rate and returns ``None``."""
        # The drawn value is not used yet; the draw is kept so the global
        # ``random`` state advances by one call.
        _rate = random.random()
        return None

    def cl_loss(self, z1, z2):
        device = z1.device
        batch_size = z1.size(0)

        # z1 = self.projection_head(z1)
        # z2 = self.projection_head(z2)

        z1 = z1 ** 3
        z2 = z2 ** 3

        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)

        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0)) / self.temperature

        labels = torch.arange(cos_sim.size(0)).long().to(device)

        loss = self.criterion(cos_sim, labels)
        return loss, cos_sim, batch_size
    
    def cl_foward(self, image_features, text_features, edge_index, edge_weight, batch, perturbation_rate=0.2):
        """Contrastive pass over three perturbed views of the text features.

        Two positive views replace a random ``perturbation_rate`` fraction of
        the rows of ``text_features`` with Gaussian noise; a third, negative
        view is entirely uniform noise. Normalised ``image_features`` are
        projected onto each view, passed through the graph layers, reduced to
        the first node of every graph and fed to ``projection_head``.

        Returns:
            ``(loss, cos_sim, batch_size)`` where ``cos_sim`` concatenates the
            view1/view2 similarities with the view1/view3 similarities along
            dimension 1, and ``batch_size`` is the number of unique values in
            ``batch``.

        NOTE(review): this method reads ``self.layers`` and ``self.bns``,
        while the visible ``__init__`` only defines ``layers1``/``layers2``
        and ``bns1``/``bns2`` - confirm these attributes exist before use.
        """
        device = edge_index.device
        batch_size = len(torch.unique(batch))

        # Independently choose which text rows are replaced in each positive view.
        num_perturb = int(text_features.size(0) * perturbation_rate)
        indices1 = torch.randperm(text_features.size(0))[:num_perturb]
        indices2 = torch.randperm(text_features.size(0))[:num_perturb]

        perturbation1, perturbation2 = text_features.clone(), text_features.clone()
        perturbation1[indices1] = torch.randn_like(perturbation1[indices1])
        perturbation2[indices2] = torch.randn_like(perturbation2[indices2])
        # Negative view: every row is uniform noise.
        perturbation3 = torch.rand_like(text_features)

        image_features = F.normalize(image_features, dim=-1)
        perturbation1 = F.normalize(perturbation1, dim=-1)
        perturbation2 = F.normalize(perturbation2, dim=-1)
        perturbation3 = F.normalize(perturbation3, dim=-1)

        # Cosine similarities between image features and each perturbed text view.
        z1 = (image_features @ perturbation1.t()).to(device)
        z2 = (image_features @ perturbation2.t()).to(device)
        z3 = (image_features @ perturbation3.t()).to(device)

        # Flatten any leading axes so every row is one node's feature vector.
        z1 = z1.view(-1, z1.shape[-1])
        z2 = z2.view(-1, z2.shape[-1])
        z3 = z3.view(-1, z3.shape[-1])

        # print(z1.shape)

        # The three views share the same layers and normalisation modules.
        for i, layer in enumerate(self.layers):
            z1 = layer(z1, edge_index, edge_weight)
            z2 = layer(z2, edge_index, edge_weight)
            z3 = layer(z3, edge_index, edge_weight)
            
            z1 = self.bns[i](z1)
            z1 = F.relu(z1)
            # z1 = self.act(z1)
            z2 = self.bns[i](z2)
            z2 = F.relu(z2)
            # z2 = self.act(z2)
            z3 = self.bns[i](z3)
            z3 = F.relu(z3)
            # z3 = self.act(z3)

            # z1 = self.linears[i](z1)
            # z1 = self.act(z1)
            # z2 = self.linears[i](z2)
            # z2 = self.act(z2)
            # z3 = self.linears[i](z3)
            # z3 = self.act(z3)

        # Keep only the first node of each graph as that graph's representation.
        z1 = z1.view(batch_size, -1, z1.shape[-1])
        z1 = z1[:, 0, :]
        z2 = z2.view(batch_size, -1, z2.shape[-1])
        z2 = z2[:, 0, :]
        z3 = z3.view(batch_size, -1, z3.shape[-1])
        z3 = z3[:, 0, :]

        z1 = self.projection_head(z1)
        z2 = self.projection_head(z2)
        z3 = self.projection_head(z3)

        # Positives sit on the diagonal of the z1/z2 block; the z1/z3 block
        # only contributes extra negatives.
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0)) / self.temperature
        z1_z3_cos = self.sim(z1.unsqueeze(1), z3.unsqueeze(0)) / self.temperature
        cos_sim = torch.cat([cos_sim, z1_z3_cos], 1)

        labels = torch.arange(cos_sim.size(0)).long().to(device)

        loss = self.criterion(cos_sim, labels)
        return loss, cos_sim, batch_size
        # positives = (z1 * z2).sum(dim=-1)


        #concept1 and concept2 are positive pairs, and concept3 is the negative
    
    def propagation(self, concept_vectors, edge_indices):
        """Propagate each concept vector one hop along its own graph.

        For every ``(concept_vector, edge_index)`` pair the dense adjacency
        is built from ``edge_index``, its first row and column are dropped,
        the identity is added (self loops), and the vector is multiplied by
        the transposed result.

        Returns:
            The propagated vectors stacked along a new leading dimension, or
            ``None`` when there are no pairs to process.
        """
        device = concept_vectors.device
        n_concepts = concept_vectors.size(-1)

        propagated = []
        for vector, edges in zip(concept_vectors, edge_indices):
            dense = torch.FloatTensor(to_scipy_sparse_matrix(edges).toarray())
            # Drop node 0's row/column and add self loops on the remainder.
            adjacency = (dense[1:, 1:] + torch.eye(n_concepts)).to(device)
            propagated.append((vector @ adjacency.T).unsqueeze(0))

        if not propagated:
            return None
        return torch.cat(propagated, dim=0)


    def forward(self, x, edge_index, target, batch, attn=None, edge_weight=None, end2end=False, text_features=None, concept_score=None):
        if batch is not None:
            batch_size = len(torch.unique(batch))
        else:
            batch_size = 1

        loss = 0.0
        # if target is None:
        #     # if not self.training:
        #     #     if self.pooling is not None:
        #     #         if self.pooling_method == 'fisher':
        #     #             x, edge_index, edge_weight, batch, perm, score, mutual_info = self.pooling(x, edge_index, edge_weight, batch, attn)
        #     #             loss += mutual_info
        #     #         else:
        #     #             x, edge_index, edge_weight, batch, perm, score = self.pooling(x, edge_index, edge_weight, batch, attn)
        #     #     else:
        #     #         perm, score = None, None
        #     #     for i, layer in enumerate(self.layers):
        #     #         x = layer(x, edge_index, edge_weight)
        #     #         x = self.bns[i](x)
        #     #         x = self.act(x)
                
        #     #     x = x.view(batch_size, -1, x.shape[-1])
        #     #     x = x[:, 0, :]

        #     #     return x
                
        #     x1, x2 = self.attribute_masking(x, batch_size), self.attribute_masking(x, batch_size)
        #     if self.pooling is not None:
        #         if self.pooling_method == 'fisher':
        #             x1, edge_index1, edge_weight1, batch1, perm1, score1, mutual_info1 = self.pooling(x1, edge_index, edge_weight, batch, attn)
        #             x2, edge_index2, edge_weight2, batch2, perm2, score2, mutual_info2 = self.pooling(x2, edge_index, edge_weight, batch, attn)
        #             loss += mutual_info
        #         else:
        #             x1, edge_index1, edge_weight1, batch1, perm1, score1 = self.pooling(x1, edge_index, edge_weight, batch, attn)
        #             x2, edge_index2, edge_weight2, batch2, perm2, score2 = self.pooling(x2, edge_index, edge_weight, batch, attn)
        #     else:
        #         perm, score = None, None
            
        #     for i, layer in enumerate(self.layers):
        #         x1 = layer(x1, edge_index1, edge_weight1)
        #         x2 = layer(x2, edge_index2, edge_weight2)
                
        #         x1 = self.bns[i](x1)
        #         x1 = self.act(x1)
        #         x2 = self.bns[i](x2)
        #         x2 = self.act(x2)

        #         x1 = self.linears[i](x1)
        #         x1 = self.act(x1)
        #         x2 = self.linears[i](x2)
        #         x2 = self.act(x2)

        #     # x1 = global_mean_pool(x1, batch1)
        #     # x2 = global_mean_pool(x2, batch2)

        #     x1 = x1.view(batch_size, -1, x.shape[-1])
        #     x1 = x1[:, 0, :]
        #     x2 = x2.view(batch_size, -1, x.shape[-1])
        #     x2 = x2[:, 0, :]

        #     z1 = self.projection_head(x1)
        #     z2 = self.projection_head(x2)

        #     z1 = torch.nn.functional.normalize(z1, dim=-1)
        #     z2 = torch.nn.functional.normalize(z2, dim=-1)

        #     similarities = torch.matmul(z1, z2.T) / self.temperature
        #     # exit()

        #     labels = torch.arange(similarities.size(0)).to(z1.device)

        #     loss = self.criterion(similarities, labels)

        #     return loss, batch_size, similarities

        # edge_weights = [edge_weight]
        if end2end:
            image_feature = x.view(batch_size, -1, x.shape[-1])[:, 0, :]
            if text_features is not None:
                image_feature = F.normalize(image_feature, dim=-1) @ F.normalize(text_features, dim=-1).t()
            # print(image_feature.shape)
            # exit()
            pred = self.lin(image_feature)
            perm = None
            score = None
        else:
            image_feature = x.view(batch_size, -1, x.shape[-1])[:, 0, :]

            for i, layer in enumerate(self.layers1):
                x = layer(x, edge_index, edge_weight)
                x = self.bns1[i](x)
                x = self.act(x)

            # if self.pooling is not None:
            #     if self.pooling_method == 'fisher':
            #         x, edge_index, edge_weight, batch, perm, score, mutual_info = self.pooling(x, edge_index, edge_weight, batch, attn)
            #         loss += mutual_info
            #     else:
            #         x, edge_index, edge_weight, batch, perm, score = self.pooling(x, edge_index, edge_weight, batch, attn)
            #         # print(perm.view(batch_size, -1))
            #         # exit()
            # else:
            #     perm, score = None, None
            perm, score = None, None

            # for i, layer in enumerate(self.layers2):
            #     x = layer(x, edge_index, edge_weight)
            #     x = self.bns2[i](x)
            #     x = self.act(x)

            # Node classification: extract the image node from concept graph
            # TODO: 
            # x = global_mean_pool(x, batch)
            c = self.act(self.gnn_classifier(x, edge_index, edge_weight))
            c = c.view(batch_size, -1, c.shape[-1]).squeeze()
            c = c[:, 1:]
            # print(c)
            
            # print(c.shape)
            # exit()
            x = x.view(batch_size, -1, x.shape[-1])
            x = x[:, 1:, :].mean(dim=1).squeeze()

            edge_indices = utils.unbatch_edge_index(edge_index, batch)

            # exit()

            # x = x.view(batch_size, -1, x.shape[-1])
            # x = x[:, 0, :]
            # x = F.softmax(x, dim=-1)
            # x += image_feature

            # print(x)
            # exit()
            # x = self.dropout(x)
            if text_features is not None and concept_score is not None:
                # mutual_info, _, _ = self.cl_loss(x, image_feature)
                # loss += mutual_info
                out = F.normalize(x, dim=-1) @ F.normalize(text_features, dim=-1).t()
                # loss += F.mse_loss(out, concept_score, reduction='mean')
                loss += -torch.mean(similarity_fn(concept_score, out))
                # print(loss)
                # exit()
                # print(x)
                # print(loss)
                out = self.propagation(c, edge_indices)
                pred = self.lin(c)
                # pred = self.lin(out)
                # print(pred)
                # exit()
                # print(pred)
            elif target is not None and concept_score is not None:
                # x = F.normalize(x, dim=-1) @ F.normalize(text_features, dim=-1).t()
                pred = self.lin(x)
                # pred = self.lin(c)
            else:
                pred = self.lin1(x)

            # if self.training:
            #     positives =  
        if target is None:
            # return self.cl_loss(out, concept_score)
            return -torch.mean(similarity_fn(concept_score, out)), out, x.size(0)
        else:
            loss += F.cross_entropy(pred, target, reduction='mean')
        # reg = self.regularizer(F.log_softmax(pred, dim=-1), F.log_softmax(original_pred, dim=-1))

        # loss += 0.3 * reg

        # return pred, x, loss, batch_size, edge_index, perm, score
        return pred, c, loss, batch_size, edge_index, perm, score

class GCN(torch.nn.Module):
    """Graph classifier: optional pooling, a stack of GATConv layers, then a linear head.

    The prediction is read from the first node of every graph in the batch
    (the code reshapes node features to ``(batch_size, -1, feat)`` and takes
    index 0), so every graph in a batch is assumed to contribute the same
    number of nodes -- TODO confirm this still holds after pooling.

    Args:
        input_dim: Feature width of the incoming node features.
        output_dim: Number of classes produced by the linear head.
        hidden_dim: Output width of every GATConv layer.
        num_layers: Number of GATConv + BatchNorm stages.
        ratio: Ratio forwarded to the pooling operator.
        pooling: ``'sag'``, ``'fisher'`` or ``'sim'``. Any other value
            (including ``None``) disables pooling.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=1024, num_layers=1, ratio=0.5, pooling='sag'):
        super(GCN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.ratio = ratio
        # Only the first convolution consumes input_dim features; every later
        # one receives the hidden_dim output of its predecessor. (Building all
        # of them with input_dim breaks num_layers > 1 when the widths differ.)
        self.layers = nn.ModuleList([
            GATConv(input_dim if i == 0 else hidden_dim, hidden_dim, edge_dim=1)
            for i in range(num_layers)
        ])
        self.bns = nn.ModuleList([nn.BatchNorm1d(self.hidden_dim) for _ in range(num_layers)])
        self.act = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(p=0.5)
        # Not used by forward() at the moment (the regularisation term is disabled).
        self.regularizer = nn.KLDivLoss(reduction="mean")
        # NOTE(review): pooling runs on the raw input features in forward(), yet
        # the operators are sized with hidden_dim -- this looks like it assumes
        # input_dim == hidden_dim; confirm against the pooling implementations.
        if pooling == 'sag':
            self.pooling = SAGPooling(in_channels=hidden_dim, ratio=self.ratio)
        elif pooling == 'fisher':
            self.pooling = FisherPooling(in_channels=hidden_dim, ratio=self.ratio)
        elif pooling == 'sim':
            self.pooling = SimPooling(ratio=self.ratio)
        else:
            # Always define the attribute so forward() can test it instead of
            # failing with AttributeError on an unrecognised method.
            if pooling is not None:
                warnings.warn(f"Unknown pooling method {pooling!r}; pooling is disabled.")
            self.pooling = None
        self.pooling_method = pooling
        self.lin = nn.Linear(hidden_dim, output_dim)

    def update_edge_weight(self, x, text_features, batch_size):
        """Score the first node of every graph against ``text_features``.

        The first node of each graph is L2-normalised (norm clamped at 1e-8)
        and multiplied with ``text_features`` transposed; ``text_features``
        is used as given, without normalisation.

        Args:
            x: Node features, reshaped internally to ``(batch_size, -1, feat)``.
            text_features: Matrix whose rows are matched against the first
                node of each graph -- assumed ``(num_texts, feat)``.
            batch_size: Number of graphs stacked in ``x``.

        Returns:
            The scores flattened into a single column, shape ``(N, 1)``.
        """
        x = x.view(batch_size, -1, x.shape[-1])
        image_features = x[:, 0, :]
        norm = torch.clamp(image_features.norm(dim=-1, keepdim=True), min=1e-8)
        # Out-of-place division: `x[:, 0, :]` is a view, so `/=` would silently
        # rescale the caller's node features.
        image_features = image_features / norm
        # No squeeze here: the slice is already (batch_size, feat), and
        # squeezing would break the matmul for 1-wide features.
        edge_weight = image_features @ text_features.t()
        return edge_weight.flatten().unsqueeze(-1)

    def forward(self, x, edge_index, target, text_features, batch, attn=None, edge_weight=None, end2end=False):
        """Classify every graph in the batch and compute the training loss.

        Args:
            x: Stacked node features of all graphs in the batch.
            edge_index: Edge index passed to the pooling and GATConv layers.
            target: Targets handed to ``F.cross_entropy``.
            text_features: Unused by this method.
            batch: Graph id per node; ``None`` means a single graph.
            attn: Forwarded to the pooling operator.
            edge_weight: Forwarded to the pooling operator and, as the third
                positional argument, to every GATConv layer.
            end2end: When true, skip pooling and convolutions and apply the
                linear head directly to the first node of every graph.

        Returns:
            ``(pred, loss, batch_size, edge_index, perm, score)``. ``perm`` and
            ``score`` are ``None`` when no pooling was applied; ``edge_index``
            is the pooled edge index when pooling ran. ``loss`` is the mean
            cross entropy, plus the term returned by the pooling operator for
            the ``'fisher'`` method.
        """
        if batch is not None:
            batch_size = len(torch.unique(batch))
        else:
            batch_size = 1

        loss = 0.0
        perm, score = None, None
        if end2end:
            image_feature = x.view(batch_size, -1, x.shape[-1])[:, 0, :]
            pred = self.lin(image_feature)
        else:
            if self.pooling is not None:
                if self.pooling_method == 'fisher':
                    x, edge_index, edge_weight, batch, perm, score, mutual_info = self.pooling(x, edge_index, edge_weight, batch, attn)
                    loss += mutual_info
                else:
                    x, edge_index, edge_weight, batch, perm, score = self.pooling(x, edge_index, edge_weight, batch, attn)

            for conv, bn in zip(self.layers, self.bns):
                x = self.act(bn(conv(x, edge_index, edge_weight)))

            # Node classification: extract the image node (index 0) of every
            # graph. The slice is already (batch_size, hidden); a bare
            # squeeze() would drop the batch axis for a single-graph batch and
            # make `pred` disagree with the end2end branch.
            x = x.view(batch_size, -1, x.shape[-1])[:, 0, :]
            x = self.dropout(x)

            pred = self.lin(x)

        loss += F.cross_entropy(pred, target, reduction='mean')

        return pred, loss, batch_size, edge_index, perm, score


# class SAGE(torch.nn.Module):
#     def __init__(self, input_dim, output_dim, hidden_dim=512, num_nodes=200, coarsening=False, tau=0.7, eps=0.01, niter=100):
#         super(SAGE, self).__init__()
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.hidden_dim = hidden_dim
#         self.num_nodes = num_nodes
#         self.coarsening = coarsening

#         self.conv1 = SAGEConv(input_dim, hidden_dim)
#         self.bn1 = nn.BatchNorm1d(self.hidden_dim)
#         self.conv2 = SAGEConv(hidden_dim, hidden_dim)
#         self.bn2 = nn.BatchNorm1d(self.hidden_dim)
#         self.conv3 = SAGEConv(hidden_dim, hidden_dim)
#         self.bn3 = nn.BatchNorm1d(self.hidden_dim)

#         self.linear = nn.Linear(hidden_dim, output_dim)
#         self.proj = nn.Sequential(
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#         )
        
#         self.bn1 = nn.BatchNorm1d(self.hidden_dim)

#         if self.coarsening:
#             self.pooling = SAGPooling(in_channels=hidden_dim, ratio=0.5)

#         self.tau = tau
#         self.eps = eps
#         self.niter = niter

#     def forward(self, x, target, edge_index, batch=None, criterion='contrast'):
#         device = x.device
#         if batch is None:
#             batch = edge_index.new_zeros(x.size(0))
#         batch_size = len(torch.unique(batch))
#         pool = self.pooling(x, edge_index, batch=batch)
#         y, corasened_index, corasened_batch, perm= pool[0], pool[1], pool[3], pool[4]

#         x = F.relu(self.bn1(self.conv1(x, edge_index)))
#         x = F.relu(self.bn2(self.conv2(x, edge_index)))
#         x = F.relu(self.bn3(self.conv3(x, edge_index)))

#         y = F.relu(self.bn1(self.conv1(y, corasened_index)))
#         y = F.relu(self.bn2(self.conv2(y, corasened_index)))
#         y = F.relu(self.bn3(self.conv3(y, corasened_index)))

#         loss = 0.0

#         if self.coarsening:
#             # pool = self.pooling(x, edge_index, batch=batch)
#             # y, corasened_index, corasened_batch, perm= pool[0], pool[1], pool[3], pool[4]
#             corasened_graph_embed = global_mean_pool(y, corasened_batch)
#             if batch_size == 1:
#                 return y, corasened_index
#             if criterion == 'contrast':
#                 # Graph level metric learning
#                 graph_embed = global_mean_pool(x, batch)
#                 x_graph, y_graph = self.proj(graph_embed), self.proj(corasened_graph_embed)
#                 x_graph, y_graph = F.normalize(x_graph, dim=-1), F.normalize(y_graph, dim=-1)
#                 score = F.cosine_similarity(x_graph.unsqueeze(1), y_graph.unsqueeze(0), dim=-1)


#                 positive = torch.diag(torch.ones(len(x_graph), dtype=torch.bool, device=device))
#                 mutual_info = (score[positive] - score.logsumexp(dim=-1)).mean()

#                 loss -= mutual_info
                
#             elif criterion == 'wasserstein':
#                 # Node level metric learning
#                 w_2 = 0.0
#                 for i in range(batch_size):
#                     indices, corasened_indices = batch == i, corasened_batch == i
#                     x1, y1 = x[indices], y[corasened_indices]
#                     w_2 += sinkhorn_loss_default(x1, y1, epsilon=self.eps, niter=self.niter)
                
#                 loss += w_2 / batch_size

#             out = self.linear(corasened_graph_embed) # TODO: attentional-based pooling, sparsely connected layer
            
#             score = pool[5]
#             edge_index = corasened_index
#         else:
#             graph_embed = global_mean_pool(x, batch)
#             out = self.linear(graph_embed)
#             perm, score = None, None

#         fct_loss = F.cross_entropy(out, target, reduction='mean')
#         loss += fct_loss

#         return out, edge_index, loss, perm, score

class SAGE_Scalar(torch.nn.Module):
    """One scalar SAGEConv followed by a linear classifier over per-node scalars.

    Each node carries a single scalar feature. After the convolution the
    scalars are scattered into a dense ``(batch_size, num_nodes)`` matrix,
    which is fed to a linear layer to produce class logits.

    Args:
        input_dim: Must be 1.
        output_dim: Number of classes produced by the linear layer.
        num_nodes: Number of node slots per graph in the dense embedding.
    """

    def __init__(self, input_dim, output_dim, num_nodes):
        super(SAGE_Scalar, self).__init__()
        assert input_dim == 1, "SAGE_Scalar only supports scalar node features (input_dim == 1)"
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_nodes = num_nodes
        self.conv1 = SAGEConv(input_dim, input_dim)
        self.linear = nn.Linear(num_nodes, output_dim)

    def forward(self, x, target, edge_index, batch, norms, score=None, perm=None, linear=None):
        """Run the convolution, densify the node scalars and classify.

        Args:
            x: Scalar node features of the nodes present in the batch.
            target: Targets handed to ``F.cross_entropy``.
            edge_index: Edge index passed to the convolution.
            batch: Graph id per node; its number of unique values is the
                batch size.
            norms: Entries greater than 1e-8 mark the slots of the dense
                embedding that receive a convolved node value.
            score: Initial values of the dense embedding (flattened). When
                ``None`` the embedding starts from zeros.
            perm: Optional index into the selected slots, applied before the
                scatter.
            linear: Optional replacement for ``self.linear``; it is moved to
                the device of ``x``.

        Returns:
            ``(graph_embed, out, loss)``: the dense embedding, the logits and
            the cross-entropy loss.
        """
        batch_size = len(torch.unique(batch))
        device = x.device

        x = self.conv1(x, edge_index)

        if score is None:
            # `score` is optional in the signature: fall back to an all-zero
            # canvas instead of dereferencing None.
            graph_embed = torch.zeros(batch_size * self.num_nodes, dtype=x.dtype, device=device)
        else:
            # flatten() may return a view; clone so the scatter below does not
            # overwrite the caller's tensor.
            graph_embed = score.flatten().clone()
        indices = torch.nonzero((norms > 1e-8).flatten()).to(device)
        if perm is not None:
            indices = indices[perm]
        graph_embed[indices] = x

        # Drop only the trailing singleton axis. A bare squeeze() would also
        # remove the batch axis for a single-graph batch (and the node axis
        # when num_nodes == 1), leaving the classifier with the wrong shape.
        graph_embed = graph_embed.view(batch_size, self.num_nodes, -1).squeeze(-1)

        if linear is None:
            out = self.linear(graph_embed)
        else:
            linear = linear.to(device)
            out = linear(graph_embed)
        loss = F.cross_entropy(out, target)

        return graph_embed, out, loss
        

        
'''
class SAGE(torch.nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=512, num_nodes=200, auxiliary=False):
        super(SAGE, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_nodes = num_nodes
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(self.num_nodes)
        if auxiliary:
            self.pool = SAGPooling(in_channels=input_dim, ratio=0.1)
        # self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        # self.conv3 = SAGEConv(hidden_dim, hidden_dim)
        self.auxiliary = auxiliary
        if self.hidden_dim == 1:
            self.lin = nn.Linear(num_nodes, output_dim)
        else:
            self.lin = nn.Linear(hidden_dim, output_dim)
        self.proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, edge_index, batch, target, indices, image_embeds=None, batch_size=None, node_classification=False, num_nodes=None):
        # 1. Obtain node embeddings 
        # print(x.shape, self.input_dim, self.output_dim)
        x = self.conv1(x, edge_index)
        device = x.device
        if self.auxiliary:
            pool = self.pool(x, edge_index)
            x, edge_index, batch = pool[0], pool[1], pool[3]
        # x = x.relu()
        # x = self.conv2(x, edge_index)
        # x = x.relu()
        # x = self.conv3(x, edge_index)

        # 2. Readout layer
        if self.hidden_dim == 1:
            if node_classification:
                assert num_nodes is not None
                start = 0
                values = None
                for num_node in num_nodes:
                    if values is None:
                        values = x[start+num_node:start+num_node+self.output_dim]
                    else:
                        values = torch.cat([values, x[start+num_node:start+num_node+self.output_dim]])
                    start += num_node + self.output_dim
                values = values.long().to(device)
                graph_embeds = x[values]
                graph_embeds = graph_embeds.view(batch_size, -1)
            else:    
                graph_embed = torch.zeros((batch_size * self.num_nodes), requires_grad=True).to(device)
                indices = torch.nonzero((indices > 1e-8).flatten()).to(device)
                graph_embed[indices] = x

                # graph_embed = x
                graph_embeds = self.bn(graph_embed.view(batch_size, -1))

            # print(graph_embeds)
            # out = F.dropout(graph_embeds, p=0.3, training=self.training)
            out = self.lin(graph_embeds)
            loss = F.cross_entropy(out, target, reduction='mean')

            return out, loss, graph_embeds, x
        else:
            graph_embeds = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        
        # print(graph_embeds.shape)

        # 3. Apply a final classifier
        # out = F.dropout(graph_embeds, p=0.1, training=self.training)
        # out = self.lin(graph_embeds)

        # if image_embeds is not None:
        #     return out, self.proj(graph_embeds), self.proj(image_embeds), x
        # return out, self.proj(graph_embeds), None, x
        if not self.auxiliary:

            out = F.dropout(graph_embeds, p=0.1, training=self.training)
            out = self.lin(out)

            loss = F.cross_entropy(out, target, reduction='mean')

            if image_embeds is not None:
                return out, loss, self.proj(graph_embeds), self.proj(image_embeds), x
            return out, loss, self.proj(graph_embeds), None, x
        else:
            if image_embeds is not None:
                return None, None, self.proj(graph_embeds), self.proj(image_embeds), x
            return None, None, self.proj(graph_embeds), None, x
'''       