import torch
import os
import random
import utils
import data_utils
import similarity
import argparse
import datetime
import json
from sklearn.metrics import f1_score
from sklearn.manifold import TSNE
import seaborn as sns
import colorcet as cc
import math 
import numpy as np
from dataset import datasets, ImbalancedDatasetSampler
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.optim import Adam, SGD
import torch.nn.functional as F
import torch_geometric
from transformers import get_scheduler, get_cosine_schedule_with_warmup
from tqdm import tqdm
from models import GCN, GNN, ConceptGNN
import matplotlib.pyplot as plt
import logging
import networkx as nx
from selection_algo import global_submodular_select

parser = argparse.ArgumentParser(description='Settings for creating CBM')

# Several options (lr, scheduler, epoch, scheme, supervision, save_dir) are shared
# with the training scripts and are not read by this intervention script.
parser.add_argument("--dataset", type=str, default="cifar10")
parser.add_argument("--concept_set", type=str, default=None,
                    help="path to concept set name")
parser.add_argument("--clip_name", type=str, default="ViT-B/16", help="Which CLIP model to use")

parser.add_argument("--device", type=str, default="cuda", help="Which device to use")
parser.add_argument("--batch_size", type=int, default=128, help="Batch size used when saving model/CLIP activations")
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument('--scheduler', type=str, default=None)
parser.add_argument("--activation_dir", type=str, default='saved_activations', help="save location for backbone and CLIP activations")
parser.add_argument("--save_dir", type=str, default='saved_models', help="where to save trained models")
parser.add_argument("--saved_model", type=str, default=None, help='directory to where the desired checkpoint is location')
parser.add_argument("--epoch", type=int, default=100, help="number of training epochs")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--scheme", type=str, default='joint', help='training schema [independent, sequential, joint]')
parser.add_argument("--supervision", default=False, action='store_true')
parser.add_argument("--filtering", type=str, default=None, help="choose from [None, submodular, random, similarity, discriminative, coverage]")
parser.add_argument("--filtering_size", type=int, default=200)
parser.add_argument("--tau", default=0.9, type=float)
parser.add_argument("--reg_alpha", default=0.7, type=float)
parser.add_argument("--init_beta", default=0.8, type=float)
parser.add_argument("--ratio", default=1, type=float,
                    help="fraction of concepts that an intervention may modify")
# NOTE: store_false -- the graph is used by default and passing --graph DISABLES it.
parser.add_argument("--graph", action='store_false', default=True,
                    help="pass this flag to disable the concept graph (enabled by default)")
parser.add_argument("--logger", action='store_true',
                    help="log to a file under <output_dir>/intervention/<dataset>/ instead of printing")
parser.add_argument("--output_dir", type=str, default='./log/')
parser.add_argument("--tsne", action='store_true', default=False,
                    help="save a t-SNE plot of validation concept scores instead of running interventions")


def visualize_tsne(features, labels, dataset, graph=True, title_suffix="102",
                   out_dir='./visualization_tsne'):
    """
    Project features to 2-D with t-SNE, scatter them coloured by label and save a PDF.

    Parameters:
    features (torch.Tensor or np.ndarray): Input features of shape (n_samples, n_features)
    labels (torch.Tensor or np.ndarray): Labels of shape (n_samples,)
    dataset (str): dataset name, used in the plot title and the output file name
    graph (bool): True for the graph model ("G-PCBM" title, "Graph" file tag),
        False for the baseline ("PCBM" title, "Baseline" file tag)
    title_suffix (str): appended to the upper-cased dataset name in the title
        (default "102" reproduces the original hard-coded title)
    out_dir (str): directory that receives ``<dataset>_<Graph|Baseline>.pdf``;
        it is created if missing
    """
    # detach/cpu so GPU tensors or tensors that require grad can be converted.
    if torch.is_tensor(features):
        features = features.detach().cpu().numpy()
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()

    unique_labels = np.unique(labels)
    # One distinct colour per class.
    colors = sns.color_palette(cc.glasbey, n_colors=len(unique_labels))

    fig = plt.figure(figsize=(12, 5))
    try:
        tsne_2d = TSNE(n_components=2, random_state=42).fit_transform(features)
        for color, label in zip(colors, unique_labels):
            mask = (labels == label)
            plt.scatter(tsne_2d[mask, 0], tsne_2d[mask, 1],
                        color=color, label=f'Label {label}', alpha=0.7)

        name = "G-PCBM" if graph else "PCBM"
        plt.title(f't-SNE 2D Visualization for {dataset.upper()}{title_suffix} ({name})')
        plt.xlabel('t-SNE Feature 1')
        plt.ylabel('t-SNE Feature 2')
        plt.tight_layout()

        os.makedirs(out_dir, exist_ok=True)
        file_tag = "Graph" if graph else "Baseline"
        plt.savefig(os.path.join(out_dir, f'{dataset}_{file_tag}.pdf'))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def graph_visualize1(data_set, fig_name, labels=None, edge_indices=None, color=None):
    """Draw one node layout with several edge sets overlaid in different colours.

    Saves the figure to ``./visualization/<fig_name>.pdf``.

    Args:
        data_set: sized object used for the node count when ``labels`` is None.
            When an entry of ``edge_indices`` is None, it must also expose
            ``edge_index``. That tensor is presumably PyG-style (2, E) -- TODO confirm.
        fig_name: output file stem.
        labels: optional node labels. Nodes are always ``0..len(labels)-1``.
        edge_indices: iterable of edge tensors, either (E, 2) or (2, E), or None
            entries meaning ``data_set.edge_index``. Defaults to ``[None]``.
        color: one edge colour per entry of ``edge_indices``. Defaults to black.
    """
    fig = plt.figure(figsize=(500, 100))
    try:
        G = nx.Graph()
        if labels is None:
            G.add_nodes_from(range(len(data_set)))
            labels = range(len(data_set))
        else:
            G.add_nodes_from(range(len(labels)))
        # One shared layout so every edge set and the node overlay line up.
        pos = nx.drawing.spring_layout(G)
        nx.draw(G, label=labels, with_labels=True, node_size=1000, pos=pos, node_color='#dde5b6')

        edge_indices = [None] if edge_indices is None else list(edge_indices)
        if color is None:
            color = ['k'] * len(edge_indices)

        for edge_index, c in zip(edge_indices, color):
            if edge_index is None:
                edge_index = data_set.edge_index
            # networkx wants a list of (u, v) pairs, so transpose (2, E) layouts.
            if edge_index.size(-1) != 2:
                edge_index = edge_index.t()
            nx.draw_networkx_edges(G, pos=pos, edgelist=edge_index.detach().tolist(), edge_color=c)

        # Redraw the nodes on top of the coloured edges using the SAME layout.
        # Without pos, a fresh random layout would be computed and would not match.
        nx.draw(G, label=labels, with_labels=True, node_size=1000, pos=pos, node_color='#dde5b6')
        plt.savefig('./visualization/'+fig_name+'.pdf', dpi=199)
    finally:
        plt.close(fig)
    return

def graph_visualize(data, fig_name, labels=None, edge_index=None, node_size=1000, font_size=12):
    """Draw a graph and save it to ``./visualization/<fig_name>.pdf``.

    Args:
        data: sized object used for the node count when ``labels`` is None. When
            ``edge_index`` is None it must expose ``edge_index``, which is
            transposed before use (presumably PyG (2, E) -- TODO confirm).
        fig_name: output file stem.
        labels: node identifiers to add. Defaults to ``range(len(data))``.
        edge_index: edge tensor, either (E, 2) or (2, E). A trailing dim other
            than 2 is transposed.
        node_size, font_size: forwarded to ``nx.draw``.
    """
    fig = plt.figure(figsize=(75, 75))
    try:
        G = nx.Graph()
        if labels is None:
            G.add_nodes_from(range(len(data)))
            labels = range(len(data))
        else:
            G.add_nodes_from(labels)

        if edge_index is not None:
            if edge_index.size(-1) != 2:
                edge_index = edge_index.t()
            G.add_edges_from(edge_index.detach().tolist())
        else:
            # The original passed `pos=pos` here, but `pos` was never defined
            # (NameError), so it is dropped.
            G.add_edges_from(data.edge_index.t().detach().tolist())

        nx.draw(G, label=labels, with_labels=True, node_size=node_size, node_color='#bde0fe', font_size=font_size)
        plt.savefig('./visualization/'+fig_name+'.pdf')
    finally:
        plt.close(fig)
    return

def get_ucp(score):
    """Uncertainty priority: inverse squared distance of relu(score) from 0.5.

    Larger values mean the score is closer to 0.5. The result is inf exactly
    at 0.5, so those entries sort first under a descending sort.
    """
    distance_from_half = (F.relu(score) - 0.5).abs()
    return 1 / distance_from_half.pow(2)

def minimum_spanning_tree(adj_matrix):
    """Greedy, degree-guided traversal of a graph given as a dense adjacency matrix.

    NOTE: despite its name this does not compute a minimum spanning tree.
    It works as follows:
    - Start from the node with the largest row sum ("degree").
    - Repeatedly step to the unvisited neighbour (entry > 0) with the
      largest degree. Ties go to the lowest index.
    - Stop when the current node has no unvisited neighbour. There is no
      backtracking.
    - Every node still unvisited afterwards starts a new walk, in increasing
      index order, whether or not it is reachable from an earlier walk.

    Args:
        adj_matrix: square tensor. ``adj_matrix[i, j] > 0`` is treated as an
            edge i -> j, and the degree of i is the sum of row i.

    Returns:
        (node_order, mst_edges): every node index in visiting order, and the
        ``(from, to, adj_matrix[from, to].item())`` steps taken by the walks.
    """
    size = adj_matrix.size(0)
    row_degree = adj_matrix.sum(dim=1)
    seen = torch.zeros(size, dtype=torch.bool)
    order = []
    steps = []

    # The first walk starts at the highest-degree node; then sweep all nodes
    # so that unvisited ones start their own walks.
    seeds = [torch.argmax(row_degree).item(), *range(size)]
    for seed in seeds:
        if seen[seed]:
            continue
        node = seed
        seen[node] = True
        order.append(node)
        while True:
            frontier = [j for j in range(size) if adj_matrix[node, j] > 0 and not seen[j]]
            if len(frontier) == 0:
                break
            nxt = max(frontier, key=lambda j: row_degree[j])
            steps.append((node, nxt, adj_matrix[node, nxt].item()))
            seen[nxt] = True
            order.append(nxt)
            node = nxt
    return order, steps


# Intervention strategies whose post-intervention accuracy main() reports.
# Available names: "random", "rank corrects", "rank difference", "UCP", and,
# for graph models only, "graph" and "spanning".
REPORTED_STRATEGIES = ("UCP",)


def _setup_logger(args):
    """Return an INFO logger that also writes a per-run log file, or None when --logger is off."""
    if not args.logger:
        return None
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    # os.makedirs replaces `os.system("mkdir ...")`. It creates missing parents,
    # honours --output_dir, and never routes the dataset name through a shell.
    log_dir = os.path.join(args.output_dir, 'intervention', args.dataset)
    os.makedirs(log_dir, exist_ok=True)

    file_name = f'seed_{args.seed}'
    if 'no_graph' in args.saved_model:
        file_name += "_no_graph"
    if args.filtering is not None:
        file_name += f'_{args.filtering}_{args.filtering_size}'
    file_name += f'_intervention_ratio_{args.ratio}.log'

    logger.addHandler(logging.FileHandler(os.path.join(log_dir, file_name), mode="w", encoding="utf-8"))
    return logger


def _activation_paths(args):
    """Return the (train image, val image, concept text) activation file paths.

    Side effect, kept from the original script: a ``patch_BioViL_T`` clip_name
    is rewritten to ``BioViL_T`` after the backbone tag has been derived.
    """
    backbone = args.clip_name.replace('/', '')
    concept_set_name = (args.concept_set.split("/")[-1]).split(".")[0]
    # The medical backbones store image activations without the "_clip" infix.
    infix = "" if args.clip_name in ['patch_BioViL_T', 'BioViL_T', 'biomedclip'] else "_clip"
    if args.clip_name == 'patch_BioViL_T':
        args.clip_name = 'BioViL_T'
    root = args.activation_dir
    image_save_name = f"{root}/{args.dataset}_train{infix}_{backbone}.pt"
    val_image_save_name = f"{root}/{args.dataset}_val{infix}_{backbone}.pt"
    text_save_name = f"{root}/{concept_set_name}_{backbone}.pt"
    return image_save_name, val_image_save_name, text_save_name


def _evaluate(model, loader, text_features, use_graph, device):
    """Run ``model`` over ``loader`` and split concept scores by prediction correctness.

    Returns a dict with:
    - ``total_loss``: summed batch loss, weighted by batch size.
    - ``n``: number of samples.
    - ``corrects``: number of correct predictions.
    - CPU tensors ``right_scores``/``right_targets``, ``wrong_scores``/``wrong_targets``
      and ``all_scores``, each kept in loader order.
    """
    total_loss, n, corrects = 0.0, 0, 0
    chunks = {key: [] for key in ("right_scores", "right_targets", "wrong_scores", "wrong_targets", "all_scores")}
    for instance in tqdm(loader):
        target = instance['label'].to(device)
        image = instance['image'].to(device)
        batch_size = instance['x'].size(0)

        loss, pred, _ = model(image, target, use_graph, text_features=text_features)
        # detach + cpu so the autograd graph of each batch is not kept alive.
        score = model.get_score(image, text_features=text_features).detach().cpu()
        target_cpu = target.cpu()
        hit = (pred.max(1)[1] == target).cpu()

        chunks["right_scores"].append(score[hit])
        chunks["right_targets"].append(target_cpu[hit])
        chunks["wrong_scores"].append(score[~hit])
        chunks["wrong_targets"].append(target_cpu[~hit])
        chunks["all_scores"].append(score)

        corrects += int(hit.sum().item())
        total_loss += loss.item() * batch_size
        n += batch_size
        torch.cuda.empty_cache()

    if n == 0:
        raise ValueError("validation loader produced no samples")
    stats = {key: torch.cat(parts, dim=0) for key, parts in chunks.items()}
    stats.update(total_loss=total_loss, n=n, corrects=corrects)
    return stats


def _rows_for_class(scores, targets, label):
    """Return the rows of ``scores`` whose target is ``label`` (original order), or None if there are none."""
    mask = targets == label
    return scores[mask] if bool(mask.any()) else None


def _graph_node_masks(model, num_concepts, size):
    """Select ``size`` concepts by learned-graph degree and by greedy traversal order.

    Returns two boolean masks of length ``num_concepts``: ``(by_degree, by_traversal)``.
    """
    with torch.no_grad():
        degree, adj = 0, 0
        for edge_param in model.edge_param:
            sym = 0.5 * (edge_param + edge_param.t())
            adj = adj + sym
            degree = degree + model.get_degree(sym)
        # Move to CPU so the index tensors below match the CPU masks.
        degree = (degree / len(model.edge_param) - 1).detach().cpu()
        adj = adj.detach().cpu()

    by_degree = torch.zeros(num_concepts, dtype=torch.bool)
    _, degree_order = torch.sort(degree, descending=True)
    by_degree[degree_order[:size]] = True

    adj.fill_diagonal_(0)
    adj = torch.where(torch.clamp(adj, max=1, min=0) > 0, 1, 0)
    node_order, _ = minimum_spanning_tree(adj)
    by_traversal = torch.zeros(num_concepts, dtype=torch.bool)
    by_traversal[torch.tensor(node_order[:size], dtype=torch.long)] = True
    return by_degree, by_traversal


def _keep_top(gap, key, size):
    """Copy of ``gap`` with entries zeroed outside the ``size`` largest values of ``key``."""
    kept = gap.clone()
    _, order = torch.sort(key, descending=True)
    kept[order[size:]] = 0
    return kept


def _strategy_offsets(classification, size, masks):
    """Build the per-class concept-score correction for every intervention strategy.

    For each class with misclassified samples, the correction starts as the
    mean concept-score gap (right minus wrong). It is then zeroed outside the
    concepts the strategy selects, which is either the top ``size`` by a
    ranking key or a boolean mask from ``masks``. Classes without any
    correctly classified sample get an all-zero correction.

    Returns ``{strategy_name: {class_label: correction_tensor}}``.
    """
    names = ["rank corrects", "rank difference", "UCP", *masks]
    offsets = {name: {} for name in names}
    for label, group in classification.items():
        right, wrong = group["right"], group["wrong"]
        if wrong is None:
            continue  # nothing to intervene on for this class
        if right is None:
            for name in names:
                offsets[name][label] = torch.zeros_like(wrong[0])
            continue
        r_mean = right.mean(dim=0)
        gap = r_mean - wrong.mean(dim=0)
        offsets["rank corrects"][label] = _keep_top(gap, r_mean, size)
        offsets["rank difference"][label] = _keep_top(gap, gap, size)
        offsets["UCP"][label] = _keep_top(gap, get_ucp(r_mean), size)
        for name, mask in masks.items():
            kept = gap.clone()
            kept[~mask] = 0
            offsets[name][label] = kept
    return offsets


def _count_corrected(model, classification, offsets, use_graph, device):
    """Count misclassified samples that ``model.intervention`` gets right after the per-class shift."""
    shifted, targets = [], []
    for label, group in classification.items():
        wrong = group["wrong"]
        if wrong is None:
            continue
        shifted.append(wrong + offsets[label])
        # Targets are built in the same class-grouped order as the scores. The
        # loader-ordered wrong targets would be misaligned with this ordering.
        targets.append(torch.full((wrong.size(0),), label, dtype=torch.long))
    if not shifted:
        return 0.0  # every validation sample was already classified correctly
    pred = model.intervention(torch.cat(shifted, dim=0).to(device), not use_graph)
    return (pred.max(1)[1] == torch.cat(targets).to(device)).float().sum().item()


def main(args):
    """Evaluate a trained ConceptGNN on the validation split and measure concept interventions.

    Steps:
    - Load cached image/concept activations.
    - Restore ``<saved_model>/model.pkl``.
    - Compute validation loss and accuracy.

    With ``--tsne`` it then only saves a t-SNE plot of the validation concept
    scores. Otherwise, for each strategy in ``REPORTED_STRATEGIES``:
    - Shift the concept scores of every misclassified sample by its class's
      right-minus-wrong mean gap, restricted to the concepts the strategy selects.
    - Re-classify the shifted scores with ``model.intervention``.
    - Report accuracy over the whole validation set.

    Raises:
        ValueError: if --saved_model or --concept_set is missing, or the
            validation loader is empty.
    """
    if args.saved_model is None:
        raise ValueError("--saved_model is required")
    if args.concept_set is None:
        raise ValueError("--concept_set is required")

    logger = _setup_logger(args)
    if logger is not None:
        logger.info(f"Intervention this model {args.saved_model}")

    image_save_name, val_image_save_name, text_save_name = _activation_paths(args)
    # Honour --device, but fall back to CPU when CUDA is requested and unavailable.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    val_image_features = torch.load(val_image_save_name, map_location="cpu").float()
    text_features = torch.load(text_save_name, map_location="cpu").float()
    val_labels = torch.LongTensor(data_utils.get_targets_only(args.dataset + "_val"))

    if args.filtering == 'submodular':
        # Training activations and labels are only needed for concept selection.
        image_features = torch.load(image_save_name, map_location="cpu").float()
        labels = torch.LongTensor(data_utils.get_targets_only(args.dataset + "_train"))
        num_images_per_class = torch.LongTensor(
            [int((labels == c).sum().item()) for c in torch.unique(labels)])
        text_features, _ = global_submodular_select(image_features, text_features, 1, args.filtering_size,
                                                    num_images_per_class, [1, 1])

    # Computed after filtering so the score columns match the selected concepts.
    val_concept_scores = F.normalize(val_image_features, dim=-1) @ F.normalize(text_features, dim=-1).t()

    model = ConceptGNN(input_dim=val_image_features.size(-1), output_dim=len(torch.unique(val_labels)),
                       num_concepts=text_features.size(0), hidden_dim=text_features.size(-1), module='gcn',
                       tau=args.tau, alpha=args.reg_alpha, beta=args.init_beta)
    # map_location="cpu" lets a GPU-saved checkpoint load on any machine.
    model.load_state_dict(torch.load(os.path.join(args.saved_model, 'model.pkl'), map_location="cpu"))
    model = model.to(device)
    model.eval()

    val_set = datasets.ConceptDataset(image_features=val_image_features, text_features=text_features, tau=args.tau,
                                      concept_scores=val_concept_scores, labels=val_labels, graph_construction=False)
    loader = torch_geometric.loader.DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    text_features = text_features.to(device)

    stats = _evaluate(model, loader, text_features, args.graph, device)

    if getattr(args, "tsne", False):
        visualize_tsne(stats["all_scores"], val_labels, dataset=args.dataset, graph=args.graph)
        return

    classification = {
        label: {
            "right": _rows_for_class(stats["right_scores"], stats["right_targets"], label),
            "wrong": _rows_for_class(stats["wrong_scores"], stats["wrong_targets"], label),
        }
        for label in torch.unique(val_labels).tolist()
    }

    num_concepts = text_features.size(0)
    size = int(num_concepts * args.ratio)
    random_nodes = torch.zeros(num_concepts, dtype=torch.bool)
    random_nodes[torch.randperm(num_concepts)[:size]] = True
    masks = {"random": random_nodes}
    # The greedy traversal is expensive, so only compute it when a graph strategy is reported.
    if args.graph and any(name in REPORTED_STRATEGIES for name in ("graph", "spanning")):
        masks["graph"], masks["spanning"] = _graph_node_masks(model, num_concepts, size)
    offsets = _strategy_offsets(classification, size, masks)

    n = stats["n"]
    total = val_image_features.size(0)
    emit = print if logger is None else logger.info
    emit("Validation Loss: {:.4f}, Acc: {:.4f}".format(stats["total_loss"] / n, stats["corrects"] / n))
    for name in REPORTED_STRATEGIES:
        if name not in offsets:
            continue  # graph-only strategy requested for a non-graph model
        fixed = _count_corrected(model, classification, offsets[name], args.graph, device)
        emit("After intervention accuracy ({}) is {:.4f}".format(name, (stats["corrects"] + fixed) / total))

if __name__ == '__main__':
    args = parser.parse_args()

    # Seed the Python, NumPy and PyTorch RNGs (CPU, current CUDA device and all
    # CUDA devices) in the same order as before.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(args.seed)
    torch.backends.cudnn.deterministic = True

    main(args)