import torch
import os
import random
import utils
import data_utils
import similarity
import argparse
import datetime
import json
from sklearn.metrics import f1_score
import math 
import numpy as np
from dataset import datasets, ImbalancedDatasetSampler
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.optim import Adam, SGD
import torch.nn.functional as F
import torch_geometric
from transformers import get_scheduler, get_cosine_schedule_with_warmup
from tqdm import tqdm
from models import GCN, GNN, ConceptGNN
import matplotlib.pyplot as plt
import logging
import networkx as nx
from selection_algo import global_submodular_select

# Command-line interface of the intervention script.
# NOTE: --graph is a *disable* switch: it defaults to True and passing the
# flag on the command line stores False.
parser = argparse.ArgumentParser(description='Settings for creating CBM')

# --- data, concept set and feature extractor ---------------------------------
parser.add_argument("--dataset", default="cifar10", type=str)
parser.add_argument(
    "--concept_set",
    default=None,
    type=str,
    help="path to concept set name",
)
parser.add_argument(
    "--clip_name",
    default="ViT-B/16",
    type=str,
    help="Which CLIP model to use",
)

# --- runtime / optimisation settings -----------------------------------------
parser.add_argument(
    "--device",
    default="cuda",
    type=str,
    help="Which device to use",
)
parser.add_argument(
    "--batch_size",
    default=128,
    type=int,
    help="Batch size used when saving model/CLIP activations",
)
parser.add_argument("--lr", default=0.01, type=float)
parser.add_argument("--scheduler", default=None, type=str)

# --- input / output locations -------------------------------------------------
parser.add_argument(
    "--activation_dir",
    default='saved_activations',
    type=str,
    help="save location for backbone and CLIP activations",
)
parser.add_argument(
    "--save_dir",
    default='saved_models',
    type=str,
    help="where to save trained models",
)
parser.add_argument(
    "--saved_model",
    default=None,
    type=str,
    help='directory to where the desired checkpoint is location',
)

# --- training schedule ----------------------------------------------------------
parser.add_argument(
    "--epoch",
    default=100,
    type=int,
    help="number of training epochs",
)
parser.add_argument("--seed", default=42, type=int)
parser.add_argument(
    "--scheme",
    default='joint',
    type=str,
    help='training schema [independent, sequential, joint]',
)
parser.add_argument("--supervision", action='store_true', default=False)

# --- concept filtering ----------------------------------------------------------
parser.add_argument(
    "--filtering",
    default=None,
    type=str,
    help="choose from [None, submodular, random, similarity, discriminative, coverage]",
)
parser.add_argument("--filtering_size", default=200, type=int)

# --- model hyper-parameters and intervention budget ------------------------------
parser.add_argument("--tau", type=float, default=0.9)
parser.add_argument("--reg_alpha", type=float, default=0.7)
parser.add_argument("--init_beta", type=float, default=0.8)
parser.add_argument("--ratio", type=float, default=1)
parser.add_argument("--graph", default=True, action='store_false')

# --- logging ----------------------------------------------------------------------
parser.add_argument("--logger", action='store_true')
parser.add_argument("--output_dir", default='./log/', type=str)

def graph_visualize1(data_set, fig_name, labels=None, edge_indices=None, color=None):
    """Draw several edge sets, each in its own colour, over one shared node layout.

    The figure is written to ``./visualization/<fig_name>.pdf``.

    Args:
        data_set: Object used as a fallback: ``len(data_set)`` gives the node
            count when ``labels`` is None, and ``data_set.edge_index`` supplies
            the edges for any ``None`` entry of ``edge_indices``.
        fig_name: File name (without extension) of the saved figure.
        labels: Optional sequence; when given, ``len(labels)`` nodes are created.
        edge_indices: Iterable of edge tensors (or ``None`` entries). A tensor
            whose last dimension is not 2 is transposed before use. ``None``
            (the default) draws only the fallback edges of ``data_set``.
        color: Iterable of edge colours, paired with ``edge_indices``.
            ``None`` (the default) draws every edge set in black.
    """
    plt.figure(figsize=(500, 100))

    G = nx.Graph()
    if labels is None:
        # The parameter is called ``data_set``; the previous body referenced an
        # undefined name ``data`` here and raised NameError.
        G.add_nodes_from(range(len(data_set)))
        labels = range(len(data_set))
    else:
        G.add_nodes_from(range(len(labels)))
    pos = nx.drawing.spring_layout(G)

    # The None defaults used to crash inside zip(); give them usable meanings.
    if edge_indices is None:
        edge_indices = [None]
    if color is None:
        color = ['k'] * len(edge_indices)

    for edge_index, c in zip(edge_indices, color):
        print(c)
        if edge_index is None:
            edge_index = data_set.edge_index
        # draw_networkx_edges needs a list of (u, v) pairs.
        if edge_index.size(-1) != 2:
            edge_index = edge_index.t()
        nx.draw_networkx_edges(G, pos=pos, edgelist=edge_index.detach().tolist(), edge_color=c)

    # Draw the nodes once, on top of the edges, re-using the SAME layout.
    # Without ``pos`` nx.draw computes a fresh layout and the nodes no longer
    # line up with the edges drawn above.
    nx.draw(G, label=labels, with_labels=True, node_size=1000, pos=pos, node_color='#dde5b6')

    os.makedirs('./visualization', exist_ok=True)
    plt.savefig('./visualization/'+fig_name+'.pdf',dpi=199)
    return

def graph_visualize(data, fig_name, labels=None, edge_index=None, node_size=1000, font_size=12):
    """Draw a single graph and save it to ``./visualization/<fig_name>.pdf``.

    Args:
        data: Fallback source: ``len(data)`` gives the node count when
            ``labels`` is None and ``data.edge_index`` supplies the edges when
            ``edge_index`` is None. May be None when both are provided.
        fig_name: File name (without extension) of the saved figure.
        labels: Optional iterable of node identifiers added to the graph.
        edge_index: Optional edge tensor; transposed when its last dimension
            is not 2 so that each row is one (u, v) pair.
        node_size: Node size forwarded to ``nx.draw``.
        font_size: Label font size forwarded to ``nx.draw``.
    """
    plt.figure(figsize=(75, 75))

    G = nx.Graph()
    if labels is None:
        G.add_nodes_from(range(len(data)))
        labels = range(len(data))
    else:
        G.add_nodes_from(labels)

    if edge_index is not None:
        if edge_index.size(-1) != 2:
            edge_index = edge_index.t()
        G.add_edges_from(edge_index.detach().tolist())
    else:
        # The previous code passed ``pos=pos`` here, but ``pos`` is never
        # defined in this function (NameError); it would also only have been
        # stored as an edge attribute, so it is simply dropped.
        G.add_edges_from(data.edge_index.t().detach().tolist())

    nx.draw(G, label=labels, with_labels=True, node_size=node_size, node_color='#bde0fe', font_size=font_size)

    os.makedirs('./visualization', exist_ok=True)
    plt.savefig('./visualization/'+fig_name+'.pdf')
    return

def get_ucp(score):
    """Return ``1 / |relu(score) - 0.5| ** 2`` element-wise.

    Values whose rectified score is exactly 0.5 divide by zero and therefore
    come out as ``inf``.
    """
    squared_gap = torch.abs(F.relu(score) - 0.5) ** 2
    return 1 / squared_gap

def minimum_spanning_tree(adj_matrix):
    """Order the nodes of a graph by a greedy, degree-guided walk.

    Despite its name this is not a minimum-spanning-tree algorithm. Starting
    from the node with the largest (weighted) degree, the walk repeatedly
    moves to the not-yet-visited neighbour of the *current* node that has the
    largest degree and stops when the current node has no unvisited
    neighbour. The walk is then restarted, in index order, from every node
    that is still unvisited, so disconnected or skipped nodes are covered too.

    Args:
        adj_matrix: Square 2-D tensor; entry ``[i, j] > 0`` means ``i`` and
            ``j`` are connected. Row sums are used as node degrees.

    Returns:
        ``(node_order, mst_edges)`` where ``node_order`` lists every node
        index exactly once in visiting order and ``mst_edges`` lists the
        traversed edges as ``(from_node, to_node, weight)`` tuples.
    """
    num_nodes = adj_matrix.size(0)

    # Row sums act as node degrees; the walk starts at the best-connected node.
    degrees = adj_matrix.sum(dim=1)
    start_node = torch.argmax(degrees).item()

    # Convert to plain Python once. Indexing the tensor element by element in
    # the inner loop costs one tensor op (and one device sync when the matrix
    # lives on a GPU) per visited entry; the values compared are the same.
    weights = adj_matrix.tolist()
    degree_of = degrees.tolist()

    mst_edges = []
    node_order = []
    visited = [False] * num_nodes

    def _greedy_walk(current_node):
        visited[current_node] = True
        node_order.append(current_node)

        while True:
            row = weights[current_node]
            candidates = [j for j in range(num_nodes) if row[j] > 0 and not visited[j]]
            if not candidates:
                break

            # max() keeps the first maximal candidate, i.e. ties on degree go
            # to the lowest node index.
            next_node = max(candidates, key=degree_of.__getitem__)
            mst_edges.append((current_node, next_node, row[next_node]))

            visited[next_node] = True
            node_order.append(next_node)
            current_node = next_node

    _greedy_walk(start_node)

    # Cover nodes the first walk never reached.
    for node in range(num_nodes):
        if not visited[node]:
            _greedy_walk(node)

    return node_order, mst_edges


def _build_intervention_logger(args):
    """Return a logger that also writes to a per-run log file, or None when ``--logger`` is off."""
    if not args.logger:
        return None

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    file_name = f'/intervention/{args.dataset}/seed_{args.seed}'
    if 'no_graph' in args.saved_model:
        file_name += "_no_graph"
    if args.filtering is not None:
        file_name += f'_{args.filtering}_{args.filtering_size}'
    file_name += f'_intervention_ratio_{args.ratio}'
    file_name += '.log'

    # Create the directory the handler actually writes to. The previous
    # ``os.system("mkdir log/intervention/...")`` was non-recursive and ignored
    # ``--output_dir``.
    log_path = args.output_dir + file_name
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    logger.addHandler(logging.FileHandler(log_path, mode="w", encoding="utf-8"))
    return logger


def _activation_paths(args, dataset_train, dataset_val):
    """Return ``(train_image_path, val_image_path, text_path)`` of the cached feature files."""
    # Fall back to the historical hard-coded folder for callers that build
    # their own namespace without this attribute.
    activation_dir = getattr(args, "activation_dir", "saved_activations")
    backbone = args.clip_name.replace('/', '')
    concept_set_name = (args.concept_set.split("/")[-1]).split(".")[0]

    if args.clip_name in ['patch_BioViL_T', 'BioViL_T', 'biomedclip']:
        image_pattern = "{}/{}_{}.pt"
        if args.clip_name == 'patch_BioViL_T':
            # Kept from the original code: the namespace is rewritten after the
            # file names have been derived from the patch_ variant.
            args.clip_name = 'BioViL_T'
    else:
        image_pattern = "{}/{}_clip_{}.pt"

    image_save_name = image_pattern.format(activation_dir, dataset_train, backbone)
    val_image_save_name = image_pattern.format(activation_dir, dataset_val, backbone)
    text_save_name = "{}/{}_{}.pt".format(activation_dir, concept_set_name, backbone)
    return image_save_name, val_image_save_name, text_save_name


def _keep_top_ranked(delta, ranking, size):
    """Return a copy of ``delta`` that is zero outside the ``size`` highest-``ranking`` positions."""
    kept = delta.clone()
    _, order = torch.sort(ranking, descending=True)
    kept[order[size:]] = 0
    return kept


def main(args):
    """Evaluate a trained ConceptGNN and measure test-time concept intervention.

    Steps:
      1. Load cached validation image features and concept text features,
         rebuild the model and load ``<saved_model>/model.pkl``.
      2. Run the validation set once, recording loss/accuracy and the concept
         scores (``model.get_score``) of correctly and wrongly classified
         samples.
      3. For every class compute ``mean(right scores) - mean(wrong scores)``
         and, for each selection strategy, keep that correction only on
         ``int(num_concepts * ratio)`` concepts (random subset, highest mean
         right score, largest difference, highest ``get_ucp`` value and, with
         the graph enabled, highest graph degree / first nodes of the greedy
         walk over the learned adjacency).
      4. Add the correction of each wrong sample's true class to its scores,
         re-classify with ``model.intervention`` and count how many become
         correct.

    Only the validation line and the UCP result are printed/logged, as before.
    The function additionally returns a dict with ``val_loss``, ``val_acc``
    and the post-intervention accuracy of every strategy.

    Raises:
        ValueError: if ``--saved_model`` or ``--concept_set`` is missing.
    """
    import time

    if args.saved_model is None:
        raise ValueError("--saved_model is required: directory that contains model.pkl")
    if args.concept_set is None:
        raise ValueError("--concept_set is required: path to the concept set file")

    logger = _build_intervention_logger(args)
    report = logger.info if logger is not None else print
    if logger is not None:
        logger.info(f"Intervention this model {args.saved_model}")

    dataset_train = args.dataset + "_train"
    dataset_val = args.dataset + "_val"
    image_save_name, val_image_save_name, text_save_name = _activation_paths(
        args, dataset_train, dataset_val)

    # Honour --device, but fall back to CPU when CUDA was requested and is not
    # available (the default "cuda" therefore behaves exactly as before).
    requested_device = str(getattr(args, "device", "cuda"))
    if requested_device.startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"
    else:
        device = requested_device

    val_image_features = torch.load(val_image_save_name, map_location="cpu").float()
    text_features = torch.load(text_save_name, map_location="cpu").float()
    val_labels = torch.LongTensor(data_utils.get_targets_only(dataset_val))

    # NOTE(review): these scores are computed against the *unfiltered* concept
    # set even when submodular filtering below shrinks text_features; this
    # mirrors the original order of operations - confirm that ConceptDataset
    # does not rely on the two having matching concept counts.
    val_concept_scores = F.normalize(val_image_features, dim=-1) @ F.normalize(text_features, dim=-1).t()

    if args.filtering == 'submodular':
        # Training features/labels are only needed for concept selection.
        image_features = torch.load(image_save_name, map_location="cpu").float()
        labels = torch.LongTensor(data_utils.get_targets_only(dataset_train))
        # Per-class sample counts, in ascending label order.
        _, num_images_per_class = torch.unique(labels, return_counts=True)
        text_features, _ = global_submodular_select(
            image_features, text_features, 1, args.filtering_size, num_images_per_class, [1, 1])

    num_concepts = text_features.size(0)

    model = ConceptGNN(input_dim=val_image_features.size(-1), output_dim=len(torch.unique(val_labels)), num_concepts=num_concepts,
        hidden_dim=text_features.size(-1), module='gcn', tau=args.tau, alpha=args.reg_alpha, beta=args.init_beta)
    # map_location keeps a GPU-saved checkpoint loadable on CPU-only machines.
    model.load_state_dict(torch.load(os.path.join(args.saved_model, 'model.pkl'), map_location="cpu"))
    model = model.to(device)

    val_set = datasets.ConceptDataset(image_features=val_image_features, text_features=text_features, tau=args.tau,
                                        concept_scores=val_concept_scores, labels = val_labels, graph_construction=False)
    loader = torch_geometric.loader.DataLoader(val_set, batch_size=args.batch_size, shuffle=False)

    model.eval()
    text_features = text_features.to(device)

    # ------------------------------------------------------------------ #
    # 1) Plain validation pass; remember concept scores of right / wrong
    #    predictions (kept in loader order).
    # ------------------------------------------------------------------ #
    n, total_loss, overall_acc, corrects = 0, 0.0, 0.0, 0
    wrong_score_chunks, wrong_target_chunks = [], []
    right_score_chunks, right_target_chunks = [], []

    with torch.no_grad():
        for instance in tqdm(loader):
            target = instance['label'].to(device)
            image = instance['image'].to(device)
            batch_size = instance['x'].size(0)

            loss, pred, _ = model(image, target, args.graph, text_features=text_features)
            is_right = pred.max(1)[1] == target
            is_wrong = ~is_right
            score = model.get_score(image, text_features=text_features)

            wrong_score_chunks.append(score[is_wrong, :].detach().cpu())
            right_score_chunks.append(score[is_right, :].detach().cpu())
            wrong_target_chunks.append(target[is_wrong].detach().cpu())
            right_target_chunks.append(target[is_right].detach().cpu())

            corrects += is_right.sum().item()
            overall_acc += is_right.float().mean().item() * batch_size
            total_loss += loss.item() * batch_size
            n += batch_size

    wrong_scores = torch.cat(wrong_score_chunks, dim=0)
    right_scores = torch.cat(right_score_chunks, dim=0)
    wrong_targets = torch.cat(wrong_target_chunks)
    right_targets = torch.cat(right_target_chunks)

    # ------------------------------------------------------------------ #
    # 2) Choose which concepts each strategy is allowed to correct.
    # ------------------------------------------------------------------ #
    size = int(num_concepts * args.ratio)

    random_nodes = torch.zeros(num_concepts, dtype=torch.bool)
    random_nodes[torch.randperm(num_concepts)[:size]] = True

    graph_nodes = None
    ordering_nodes = None
    if args.graph:
        with torch.no_grad():
            degree, adj = 0, 0
            for edge_param in model.edge_param:
                A = 0.5 * (edge_param + edge_param.t())
                adj += A
                degree += model.get_degree(A)
            degree = degree / len(model.edge_param) - 1

            _, graph_indices = torch.sort(degree, descending=True)
            graph_nodes = torch.zeros(num_concepts, dtype=torch.bool)
            # The mask lives on the CPU, so the indices must be moved there too.
            graph_nodes[graph_indices[:size].cpu()] = True

            # Binary adjacency without self loops.
            adj.fill_diagonal_(0)
            adj = torch.where(adj > 0, 1, 0)

        node_order, _ = minimum_spanning_tree(adj.cpu())
        ordering_nodes = torch.zeros(num_concepts, dtype=torch.bool)
        ordering_nodes[node_order[:size]] = True

    # ------------------------------------------------------------------ #
    # 3) Per-class score corrections for every strategy.
    # ------------------------------------------------------------------ #
    strategies = ['random', 'corrects', 'difference', 'ucp']
    if args.graph:
        strategies += ['graph', 'spanning']
    offsets = {name: {} for name in strategies}

    for label in torch.unique(wrong_targets).tolist():
        wrong_c = wrong_scores[wrong_targets == label]
        right_c = right_scores[right_targets == label]

        if right_c.size(0) == 0:
            # No correctly classified sample to learn from: leave scores as is.
            no_change = torch.zeros_like(wrong_c[0])
            for name in strategies:
                offsets[name][label] = no_change
            continue

        r_mean, w_mean = right_c.mean(dim=0), wrong_c.mean(dim=0)
        delta = r_mean - w_mean

        offsets['random'][label] = delta.masked_fill(~random_nodes, 0)
        offsets['corrects'][label] = _keep_top_ranked(delta, r_mean, size)
        offsets['difference'][label] = _keep_top_ranked(delta, delta, size)
        offsets['ucp'][label] = _keep_top_ranked(delta, get_ucp(r_mean), size)
        if args.graph:
            offsets['graph'][label] = delta.masked_fill(~graph_nodes, 0)
            offsets['spanning'][label] = delta.masked_fill(~ordering_nodes, 0)

    # ------------------------------------------------------------------ #
    # 4) Apply each correction and re-classify the wrong samples.
    # ------------------------------------------------------------------ #
    wrong_labels = wrong_targets.tolist()
    wrong_targets_on_device = wrong_targets.to(device)

    def _corrects_after(per_class_offset, timed=False):
        """Return the number of correct predictions after applying one strategy."""
        if not wrong_labels:
            return corrects
        # Build the correction row by row in the SAME order as wrong_targets.
        # The previous code concatenated the samples class by class and then
        # compared against targets in loader order, which only lines up when
        # the validation set happens to be sorted by label.
        shift = torch.stack([per_class_offset[label] for label in wrong_labels], dim=0)
        score = (wrong_scores + shift).to(device)
        start_time = time.time()
        with torch.no_grad():
            pred = model.intervention(score, not args.graph)
        if timed:
            print("--- %s seconds ---" % (time.time() - start_time))
        return corrects + (pred.max(1)[1] == wrong_targets_on_device).sum().item()

    corrected = {}
    for name in strategies:
        corrected[name] = _corrects_after(offsets[name], timed=(name == 'random'))

    # ------------------------------------------------------------------ #
    # 5) Report.
    # ------------------------------------------------------------------ #
    num_val = val_image_features.size(0)
    report("Validation Loss: {:.4f}, Acc: {:.4f}".format(total_loss / n, overall_acc / n))
    report("After intervention accuracy (UCP) is {:.4f}".format(corrected['ucp'] / num_val))

    results = {'val_loss': total_loss / n, 'val_acc': overall_acc / n}
    for name in strategies:
        results[name] = corrected[name] / num_val
    return results

if __name__ == '__main__':
    args = parser.parse_args()

    # Seed every random number generator the script touches with the same
    # value, then ask cuDNN for deterministic kernels.
    seed = args.seed
    for seed_everything in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_everything(seed)
    torch.backends.cudnn.deterministic = True

    main(args)