import torch
from torch import nn, einsum
import copy
import torchvision
import clip
import argparse
import random
import numpy as np
import os
from tqdm import tqdm
from dataset import datasets, ImbalancedDatasetSampler
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch_geometric
from models import GCN, GNN, ConceptGNN
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
import math
import torch.nn.functional as F
import networkx as nx
import matplotlib.pyplot as plt
import logging
import math
from transformers import get_scheduler, get_cosine_schedule_with_warmup
from sklearn.manifold import TSNE
import pickle
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

import utils
import data_utils

from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
from data_utils import get_data

from selection_algo import global_submodular_select


# Command-line interface for training / evaluating the concept-graph model.
parser = argparse.ArgumentParser(description='Settings for creating CBM')

parser.add_argument("--dataset", type=str, default="cifar10")
# Required: every code path in main() derives activation/log file names from it.
parser.add_argument("--concept_set", type=str, required=True,
                    help="path to concept set name")
parser.add_argument("--raw_set", type=str, default=None, help='path to raw images')
parser.add_argument("--backbone", type=str, default="RN50", help="Which pretrained model to use as backbone")
parser.add_argument("--clip_name", type=str, default="ViT-B/16", help="Which CLIP model to use [ViT-B/16, RN50]")
parser.add_argument("--seed", type=int, default=42, help="Random seed configuration")
parser.add_argument("--saved_activations", action='store_true', help='Use saved features as image features and text features')
parser.add_argument("--output_dir", default='./log', type=str)
parser.add_argument("--logger", action='store_true')
parser.add_argument('--random_concept', action='store_true', default=False)

parser.add_argument("--device", type=str, default="cuda", help="Which device to use")
parser.add_argument("--batch_size", type=int, default=512, help="Batch size used when saving model/CLIP activations")
parser.add_argument("--saga_batch_size", type=int, default=256, help="Batch size used when fitting final layer")
parser.add_argument("--proj_batch_size", type=int, default=50000, help="Batch size to use when learning projection layer")

parser.add_argument("--feature_layer", type=str, default='layer4',
                    help="Which layer to collect activations from. Should be the name of second to last layer in the model")
parser.add_argument("--activation_dir", type=str, default='saved_activations', help="save location for backbone and CLIP activations")
parser.add_argument("--save_dir", type=str, default='saved_models', help="where to save trained models")
parser.add_argument("--clip_cutoff", type=float, default=0.25, help="concepts with smaller top5 clip activation will be deleted")
parser.add_argument("--proj_steps", type=int, default=1000, help="how many steps to train the projection layer for")
parser.add_argument("--interpretability_cutoff", type=float, default=0.45, help="concepts with smaller similarity to target concept will be deleted")
parser.add_argument("--lam", type=float, default=0.0007, help="Sparsity regularization parameter, higher->more sparse")
parser.add_argument("--n_iters", type=int, default=1000, help="How many iterations to run the final layer solver for")
parser.add_argument("--print", action='store_true', help="Print all concepts being deleted in this stage")
parser.add_argument("--epoch", default=200, type=int, help="Set the number of epochs for training")
parser.add_argument("--model", default='linear', help='Choose the classifier model, [linear, mlp, transformer, gnn]')
parser.add_argument("--lr", default=0.001, type=float, help='Set the learning rate for optimizer')
parser.add_argument("--scheduler", default=None, type=str,
                    help='LR scheduler: cosine, polynomial, inverse_sqrt, reduce_lr_on_plateau, '
                         'cosine_with_restarts or step; any other value means no scheduler')
parser.add_argument("--coarsening", action='store_true', help='learn a sparse graph structure')
parser.add_argument("--sparse", action='store_true', help='Sparse linear layer for final classification')
parser.add_argument("--ratio", default=0.5, type=float, help="Coarsening ratio for graph pooling layer")
parser.add_argument("--end2end", action='store_true', default=False, help='Use End2End training')
parser.add_argument("--pooling", default='sag', type=str, help='Set the pooling module in model')
parser.add_argument("--save_stats", default=False, action='store_true')
parser.add_argument("--prototype", default=False, action='store_true', help="Clustering image features by their labels and taking the cluster center as prototype feature")
parser.add_argument("--auxiliary_mlp", default=False, action='store_true')
parser.add_argument("--module", default='gat_v2', type=str, help="define module in GNN model")
parser.add_argument("--sparsity", default=1, type=float)
parser.add_argument("--use_concept_vectors", action='store_true', default=False)
parser.add_argument("--independent", action='store_true', default=False)
parser.add_argument("--independent_epoch", type=int, default=100)
parser.add_argument("--method", default='gnn', type=str, help='choose a methodology from [gnn, diffusion]')
parser.add_argument("--checkpoint", default=False, action='store_true')
parser.add_argument("--linear_lr", default=0.1, type=float)
# A string option; None (not False) means "no scheduler" for the linear phase.
parser.add_argument("--linear_scheduler", default=None, type=str,
                    help='LR scheduler for the --independent linear phase (same names as --scheduler)')
parser.add_argument("--fixed_structure", action='store_true', default=False)
parser.add_argument("--tau", default=0.9, type=float)
parser.add_argument("--num_cycles", default=0.5, type=float)
parser.add_argument("--reg_alpha", default=0.7, type=float)
parser.add_argument("--init_beta", default=0.8, type=float)
parser.add_argument("--learnable_graph", action='store_true', default=False)
# store_false: the graph is ON by default and passing the flag turns it OFF.
# "--no_graph" is kept alongside "--graph" as a self-explanatory alias.
parser.add_argument("--graph", "--no_graph", dest="graph", action='store_false', default=True,
                    help='Disable the concept graph (graph is enabled by default)')
parser.add_argument("--with_supervision", action='store_true', default=False)
parser.add_argument("--ddi_name", type=str, default='deepderm')
parser.add_argument("--filtering", type=str, default=None,
                    help="choose from [None, submodular, random, similarity, discriminative, coverage]; "
                         "only 'submodular' performs concept selection, other values only affect output names")
parser.add_argument("--filtering_size", type=int, default=200)
parser.add_argument("--ssl", default=False, action='store_true')
parser.add_argument("--ground_graph", default=False, action='store_true')
parser.add_argument("--cdm", default=False, action='store_true')
parser.add_argument("--score_type", default=None, type=str)

def graph_visualize(data, fig_name, labels=None, edge_index=None):
    """Draw a graph with networkx and save it to ``./visualization/<fig_name>.pdf``.

    Args:
        data: Graph-like object. When ``labels`` is None, ``len(data)`` sets the
            node count. When ``edge_index`` is None, ``data.edge_index`` supplies
            the edges.
        fig_name: Output file stem (without extension).
        labels: Optional iterable of node labels; node ``i`` is drawn with the
            ``i``-th label. Defaults to ``range(len(data))``.
        edge_index: Optional edge tensor shaped ``(num_edges, 2)`` or
            ``(2, num_edges)``. It is transposed when its last dim is not 2.

    The ``./visualization`` directory is created if missing. The figure is
    always closed so repeated calls do not accumulate open figures.
    """
    if labels is None:
        labels = range(len(data))
    labels = list(labels)

    G = nx.Graph()
    G.add_nodes_from(range(len(labels)))

    if edge_index is None:
        # Assumed PyG layout (2, num_edges); transpose to a list of (u, v) pairs.
        edges = data.edge_index.t()
    elif edge_index.size(-1) != 2:
        edges = edge_index.t()
    else:
        edges = edge_index
    G.add_edges_from(edges.detach().tolist())

    fig = plt.figure(figsize=(50, 50))
    try:
        # `labels=` (a node -> text dict) is networkx's node-label argument;
        # `label=` would only set a matplotlib legend entry.
        nx.draw(G, labels=dict(enumerate(labels)), with_labels=True)
        os.makedirs('./visualization', exist_ok=True)
        fig.savefig(os.path.join('./visualization', fig_name + '.pdf'))
    finally:
        plt.close(fig)

def train(dataset, model, opt, scheduler, args, device, logger=None, val_set=None, val_loader=None, text_features=None, score_type=None):
    """Train ``model`` on ``dataset`` for ``args.epoch`` epochs.

    Args:
        dataset: Training set whose items expose 'x', 'concept', 'label' and 'image'.
        model: Callable returning ``(loss, pred, c_adj)``.
        opt: Optimizer over the model parameters.
        scheduler: LR scheduler or None.
            - ``StepLR`` is stepped once per epoch; its gamma is switched to 0.9
              at epoch 100.
            - ``ReduceLROnPlateau`` is stepped once per epoch with the mean
              training loss.
            - Any other scheduler is stepped after every optimizer step.
        args: Parsed CLI namespace. Reads with_supervision, batch_size, epoch,
            graph, ssl and independent.
        device: Device the batches are moved to.
        logger: Optional logger; progress goes to ``print`` when None.
        val_set, val_loader: Accepted for API compatibility; unused.
        text_features: Optional concept text features, moved to ``device``.
        score_type: Forwarded to the model in the unsupervised path.
    """
    if args.with_supervision:
        sampler = ImbalancedDatasetSampler(dataset)
        loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)
    else:
        loader = torch_geometric.loader.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
    if text_features is not None:
        text_features = text_features.to(device)

    log = logger.info if logger is not None else print
    # Epoch-level schedulers; warmup schedulers built in main are sized in
    # optimizer steps and therefore advance every batch.
    per_epoch_scheduler = isinstance(scheduler, (StepLR, ReduceLROnPlateau))
    track_acc = not args.independent and not args.ssl

    pred = c_adj = None
    for epoch in tqdm(range(args.epoch)):
        n, total_loss, overall_acc = 0, 0.0, 0.0
        for instance in loader:
            opt.zero_grad()

            x = instance['x'].to(device)
            concepts = instance['concept'].to(device)
            target = instance['label'].to(device)
            image = instance['image'].to(device)
            batch_size = x.size(0)

            if args.with_supervision:
                loss, pred, c_adj = model(x, target, args.graph, concepts=concepts)
            else:
                loss, pred, c_adj = model(image, target, args.graph, text_features=text_features, ssl=args.ssl, score_type=score_type)

            loss.backward()
            # Clip only after backward(): before it the grads were just zeroed,
            # so clipping there would have no effect on this step's update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            if scheduler is not None and not per_epoch_scheduler:
                scheduler.step()

            if track_acc:
                acc = (pred.max(1)[1] == target).float().mean().item()
                overall_acc += acc * batch_size

            total_loss += loss.item() * batch_size
            n += batch_size

            torch.cuda.empty_cache()

        if (epoch + 1) % 10 == 0:
            if args.independent:
                log("Epoch {} Loss: {:.4f}".format(epoch + 1, total_loss / n))
            else:
                log("Epoch {} Loss: {:.4f}, Acc: {:.4f}".format(epoch + 1, total_loss / n, overall_acc / n))

        if isinstance(scheduler, StepLR):
            if epoch == 100:
                scheduler.gamma = 0.9
            scheduler.step()
        elif isinstance(scheduler, ReduceLROnPlateau):
            # ReduceLROnPlateau.step requires a metric; use this epoch's mean loss.
            scheduler.step(total_loss / n)

    # Debug dump of the last batch's predictions and learned concept edges.
    if args.graph and c_adj is not None:
        print(pred)
        print(torch.nonzero(torch.triu(c_adj, diagonal=1)))
        print(torch.nonzero(torch.triu(c_adj, diagonal=1)).size())

# 1-based concept-index pairs of the hand-specified CUB concept graph used with
# --ground_graph; each pair becomes an undirected edge.
_CUB_GROUND_TRUTH_EDGES = [
    (10, 12), (10, 4), (27, 26), (19, 20), (36, 16), (6, 39), (17, 19), (13, 14),
    (8, 40), (43, 17), (16, 29), (43, 44), (24, 23), (33, 16), (28, 30), (51, 92),
    (11, 12), (49, 4), (53, 54), (14, 13), (41, 14), (69, 70), (81, 75), (73, 72),
    (77, 72), (80, 75), (63, 61), (68, 5), (88, 90), (110, 111), (121, 122), (125, 124),
    (126, 125), (3, 117), (130, 131), (135, 136), (140, 145), (160, 161), (162, 163),
    (169, 170), (174, 176), (179, 180), (186, 187), (196, 228), (197, 210), (199, 200),
    (203, 207), (211, 212), (216, 199), (217, 218)
]


def _setup_logger(args, backbone):
    """Configure logging, attach a per-run log file under args.output_dir, and return the logger.

    The log directory is created if missing. With --cdm the directory is
    ``<backbone>_cdm``. A concept-set path containing 'new' prefixes the log
    file *name* (not the path) with ``new_``.
    """
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    if args.ssl:
        log_dir = os.path.join(args.output_dir, args.dataset, backbone)
        stem = f'ssl_num_epochs_{args.epoch}'
    else:
        # Concept-set family: file name up to its first '.' and then its first '_'.
        dname = (args.concept_set.split("/")[-1]).split(".")[0].split('_')[0]
        backbone_dir = backbone + '_cdm' if args.cdm else backbone
        parts = [args.output_dir, dname]
        if args.filtering is not None:
            parts.append(args.filtering)
        parts.append(backbone_dir)
        log_dir = os.path.join(*parts)
        stem = f'seed_{args.seed}_num_epochs_{args.epoch}_alpha_{args.reg_alpha}_beta_{args.init_beta}'
        if args.ground_graph:
            stem += '_ground_graph'
        stem += f"_tau_{args.tau}" if args.learnable_graph else "_fixed"

    if 'new' in args.concept_set:
        stem = 'new_' + stem

    os.makedirs(log_dir, exist_ok=True)
    fh = logging.FileHandler(os.path.join(log_dir, stem + '.log'), mode="w", encoding="utf-8")
    logger.addHandler(fh)
    return logger


def _activation_paths(args, backbone):
    """Return (train image, val image, concept text) activation file paths under args.activation_dir.

    Side effect kept from the original pipeline: 'patch_BioViL_T' in
    args.clip_name is rewritten to 'BioViL_T' (paths still use ``backbone``).
    """
    if args.with_supervision:
        tag = args.ddi_name
    elif args.clip_name in ['patch_BioViL_T', 'BioViL_T', 'biomedclip']:
        tag = backbone
        if args.clip_name == 'patch_BioViL_T':
            args.clip_name = 'BioViL_T'
    else:
        tag = f"clip_{backbone}"
    concept_set_name = (args.concept_set.split("/")[-1]).split(".")[0]
    act_dir = args.activation_dir
    return ("{}/{}_train_{}.pt".format(act_dir, args.dataset, tag),
            "{}/{}_val_{}.pt".format(act_dir, args.dataset, tag),
            "{}/{}_{}.pt".format(act_dir, concept_set_name, backbone))


def _cub_ground_truth_adjacency(num_nodes):
    """Return a symmetric (num_nodes, num_nodes) float32 0/1 adjacency from _CUB_GROUND_TRUTH_EDGES."""
    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.float32)
    for i, j in _CUB_GROUND_TRUTH_EDGES:
        adj[i - 1, j - 1] = 1.0
        adj[j - 1, i - 1] = 1.0  # undirected
    return adj


def _build_scheduler(name, opt, num_samples, batch_size, num_epochs, num_cycles=None):
    """Create an LR scheduler by name, or return None for an unknown/empty name.

    Warmup is 5 epochs of full batches. Total steps are batches-per-epoch times
    ``num_epochs``. 'cosine_with_restarts' uses get_cosine_schedule_with_warmup
    with ``num_cycles`` when it is given, and transformers.get_scheduler otherwise.
    """
    num_warmup_steps = (num_samples // batch_size) * 5
    num_training_steps = math.ceil(num_samples / batch_size) * num_epochs
    if name == "cosine_with_restarts" and num_cycles is not None:
        return get_cosine_schedule_with_warmup(
            opt,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
        )
    if name in ('cosine', "cosine_with_restarts", "polynomial", "inverse_sqrt", "reduce_lr_on_plateau"):
        return get_scheduler(name, opt,
                             num_warmup_steps=num_warmup_steps,
                             num_training_steps=num_training_steps)
    if name == 'step':
        return StepLR(opt, step_size=10, gamma=0.99)
    return None


def _evaluate(model, loader, text_features, args, device):
    """Run ``model`` over ``loader`` without gradients.

    Returns (mean loss, mean accuracy, c_adj of the last batch). Accuracy is
    0.0 when args.ssl is set.
    """
    n, total_loss, overall_acc = 0, 0.0, 0.0
    c_adj = None
    with torch.no_grad():
        for instance in tqdm(loader):
            x = instance['x'].to(device)
            target = instance['label'].to(device)
            batch_size = x.size(0)

            if args.with_supervision:
                concepts = instance['concept'].to(device)
                loss, pred, c_adj = model(x, target, args.graph, concepts=concepts)
            else:
                image = instance['image'].to(device)
                loss, pred, c_adj = model(image, target, args.graph, text_features=text_features,
                                          ssl=args.ssl, score_type=args.score_type)

            if not args.ssl:
                overall_acc += (pred.max(1)[1] == target).float().mean().item() * batch_size
            total_loss += loss.item() * batch_size
            n += batch_size

            torch.cuda.empty_cache()
    return total_loss / n, overall_acc / n, c_adj


def _collect_concept_scores(model, loader, text_features, args, device):
    """Return the model's SSL concept scores for every batch of ``loader``, concatenated on CPU along dim 0."""
    chunks = []
    with torch.no_grad():
        for instance in tqdm(loader):
            target = instance['label'].to(device)
            image = instance['image'].to(device)
            _, scores, _ = model(image, target, args.graph, text_features=text_features,
                                 ssl=args.ssl, score_type=args.score_type)
            chunks.append(scores.detach().cpu())
    return torch.cat(chunks, dim=0)


def _model_dir_name(args, backbone, num_edges=None):
    """Return the checkpoint directory name encoding the run configuration."""
    if args.ssl:
        return f"backbone_{backbone}_num_epochs_{args.epoch}_{args.dataset}_ssl"
    if args.graph:
        name = f"backbone_{backbone}_num_epochs_{args.epoch}_{args.dataset}_seed_{args.seed}_num_edges_{num_edges}_alpha_{args.reg_alpha}_beta_{args.init_beta}_tau_{args.tau}"
    else:
        name = f"backbone_{backbone}_num_epochs_{args.epoch}_{args.dataset}_seed_{args.seed}_no_graph"
    if args.ground_graph:
        name += '_ground_graph'
    if args.filtering is not None:
        name = f'{args.filtering}_{name}'
    if 'new' in args.concept_set:
        name = 'new_' + name
    return name


def main(args):
    """Train a ConceptGNN on cached activations, evaluate it, and save the results.

    Steps:
      1. Load train/val image activations and concept text activations.
      2. Optionally apply submodular concept selection.
      3. Build datasets and the model; optionally load the CUB ground-truth graph.
      4. Train (plus an optional second ``--independent`` phase with ``layers1`` frozen).
      5. Evaluate on the validation set.
      6. Save the state dict, and the concept edge list when ``args.graph`` is set.
      7. With ``--ssl``, dump train/val concept scores.
    """
    backbone = args.clip_name.replace('/', '')
    logger = _setup_logger(args, backbone) if args.logger else None
    log = logger.info if logger is not None else print

    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset_train = args.dataset + "_train"
    dataset_val = args.dataset + "_val"
    image_save_name, val_image_save_name, text_save_name = _activation_paths(args, backbone)

    image_features = torch.load(image_save_name, map_location="cpu").float()
    val_image_features = torch.load(val_image_save_name, map_location="cpu").float()
    if 'patch' in os.path.basename(image_save_name):
        # Patch-level activations: average away dims 2 and 3 (presumably spatial).
        image_features = image_features.mean(dim=(2, 3))
        val_image_features = val_image_features.mean(dim=(2, 3))

    text_features = torch.load(text_save_name, map_location="cpu").float()
    if args.random_concept:
        text_features = torch.rand(text_features.size())
        print("Use random concept features!")

    labels = torch.LongTensor(data_utils.get_targets_only(dataset_train))
    val_labels = torch.LongTensor(data_utils.get_targets_only(dataset_val))
    # Counts are ordered like the sorted unique label values.
    class_ids, num_images_per_class = torch.unique(labels, return_counts=True)
    num_classes = class_ids.numel()

    if args.filtering == 'submodular':
        text_features, _selected = global_submodular_select(
            image_features, text_features, 1, args.filtering_size, num_images_per_class, [1, 1])

    if args.with_supervision:
        concept_scores = data_utils.get_concepts_only(dataset_train)
        val_concept_scores = data_utils.get_concepts_only(dataset_val)
        train_set = datasets.ImageDataset(image_features=image_features, concept_scores=concept_scores, labels=labels)
    else:
        # Cosine similarity between every image and every concept.
        normed_text = F.normalize(text_features, dim=-1)
        concept_scores = F.normalize(image_features, dim=-1) @ normed_text.t()
        val_concept_scores = F.normalize(val_image_features, dim=-1) @ normed_text.t()
        train_set = datasets.ConceptDataset(image_features=image_features, text_features=text_features, tau=args.tau,
                                            concept_scores=concept_scores, labels=labels, graph_construction=args.fixed_structure)

    model = ConceptGNN(input_dim=image_features.size(-1), output_dim=num_classes, num_concepts=text_features.size(0),
        hidden_dim=text_features.size(-1), module='gcn', tau=args.tau, alpha=args.reg_alpha, beta=args.init_beta, cdm=args.cdm)
    if args.dataset == 'cub' and args.ground_graph:
        model.init_ground_truth_graph(_cub_ground_truth_adjacency(text_features.size(0)))

    for edge_param in (model.edge_param1, model.edge_param2, model.edge_param3):
        edge_param.requires_grad = args.learnable_graph
    print("Is graph learnable:", args.learnable_graph)
    model = model.to(device)

    num_samples = image_features.size(0)
    opt = Adam(model.parameters(), lr=args.lr)
    scheduler = _build_scheduler(args.scheduler, opt, num_samples, args.batch_size, args.epoch,
                                 num_cycles=args.num_cycles)

    model.train()
    train(train_set, model, opt, scheduler, args, device, logger, text_features=text_features, score_type=args.score_type)
    if args.independent:
        # Second phase: freeze layers1 (requires_grad_ works for modules and
        # tensors; assigning .requires_grad on a Module freezes nothing) and
        # retrain the rest with the linear-phase LR and scheduler.
        model.layers1.requires_grad_(False)
        args.independent = False
        args.epoch = args.independent_epoch
        opt = Adam(model.parameters(), lr=args.linear_lr)
        scheduler = _build_scheduler(args.linear_scheduler, opt, num_samples, args.batch_size, args.epoch)
        train(train_set, model, opt, scheduler, args, device, logger, text_features=text_features, score_type=args.score_type)

    if args.with_supervision:
        val_set = datasets.ImageDataset(image_features=val_image_features, concept_scores=val_concept_scores, labels=val_labels)
        loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    else:
        val_set = datasets.ConceptDataset(image_features=val_image_features, text_features=text_features, tau=args.tau,
                                          concept_scores=val_concept_scores, labels=val_labels, graph_construction=args.fixed_structure)
        loader = torch_geometric.loader.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=8)

    model.eval()
    text_features = text_features.to(device)
    val_loss, val_acc, c_adj = _evaluate(model, loader, text_features, args, device)

    edge_index = None
    if args.graph:
        # Upper triangle only: each undirected concept edge listed once, as (i, j) rows.
        edge_index = torch.nonzero(torch.triu(c_adj, diagonal=1))
        graph_visualize(None, f'{args.dataset}_gen_num_edges_{edge_index.size(0)}_acc_{val_acc:.4f}',
                        labels=range(text_features.size(0)), edge_index=edge_index)

    log("Validation Loss: {:.4f}, Acc: {:.4f}".format(val_loss, val_acc))

    num_edges = edge_index.size(0) if edge_index is not None else None
    saved_path = os.path.join(args.save_dir, _model_dir_name(args, backbone, num_edges))
    if os.path.isdir(saved_path):
        log(f"{saved_path} exists")
    else:
        os.makedirs(saved_path)
        log(f"Model and graph structure is saving to {saved_path}")

    model_file = f'{args.score_type}_model.pkl' if args.score_type is not None else 'model.pkl'
    torch.save(model.state_dict(), os.path.join(saved_path, model_file))
    if edge_index is not None:
        np.save(os.path.join(saved_path, 'edge_index.npy'), edge_index.detach().cpu().numpy())

    if args.ssl:
        train_loader = torch_geometric.loader.DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=8)
        train_c = _collect_concept_scores(model, train_loader, text_features, args, device)
        val_c = _collect_concept_scores(model, loader, text_features, args, device)
        print(train_c.shape, val_c.shape)

        # NOTE(review): machine-specific absolute output location kept unchanged;
        # consider deriving it from args.activation_dir.
        score_dir = "/home/hxu2/CBM-Graph/saved_activations/concept_scores"
        suffix = f'_{args.score_type}' if args.score_type else ''
        torch.save(train_c, f"{score_dir}/{args.dataset}_train_c{suffix}.pt")
        torch.save(val_c, f"{score_dir}/{args.dataset}_val_c{suffix}.pt")
        

    

            
if __name__ == '__main__':
    args = parser.parse_args()
    seed = args.seed

    # Seed Python, NumPy and torch (CPU, current CUDA device, all CUDA devices)
    # in that order, then request deterministic cuDNN algorithms.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

    main(args)