import torch
from torch import optim
import torchmetrics
from torch_geometric.loader import DataLoader
from torch.nn import functional as F
from prompt_graph.model import GAT, GCN, GCov, GIN, GraphSAGE, GraphTransformer
from prompt_graph.data import load4graph, load4zero
from prompt_graph.prompt import MorpherGraphPrompt, MorpherTextPrompt
from prompt_graph.utils import center_embedding, seed_torch
from transformers import PreTrainedModel, PreTrainedTokenizer
import time
import os.path as osp
import numpy as np
import pdb



# Always (re)train the zero-shot projector in ZeroShot.run() and overwrite the file in
# trained_projector/, even when a saved projector already exists there.
SAVE_PROJ: bool = True

# At the end of ZeroShot.run(), write the per-epoch loss/accuracy curves to saved_data/*.npy.
SAVE_DATA: bool = True






class ZeroShot:
    def __init__(self, pre_train_model_path=None, pretrain_method=None, gnn_type='GCN', hid_dim = 128, num_layer = 2, dataset_name='MUTAG', prompt_type='Morpher', epochs=10, shot_num=10, 
                batch_size=16, prompt_graph_token_num = 10, tokenizer: PreTrainedTokenizer = None, llm: PreTrainedModel = None, device : int = 1,
                projector_lr=0.01, projector_weight_decay=0.1, projector_tune_lr=0.001, projector_tune_weight_decay=0.1,
                pg_lr=0.001, pg_weight_decay=0.001, text_prompt_lr=0.001, text_prompt_weight_decay=0.001,
                projector_dropout_ratio=0.2, temperature=2.0, text_prompt_start_vocab='a graph with property',
                projector_epochs=2001, projector_train_eval_diff_threshold=0.1, projector_train_modular=100, projector_tune_epochs=50, prompt_tune_epochs=50,
                train_val_acc_diff_tol=0.0, val_acc_threshold=1.0, warmup_epochs=0, random_seed=42,
                train_good_threshold = 0.9):
        
        # pretrained GNN, LLM and experiment settings hyperparameters
        self.pre_train_model_path = pre_train_model_path
        self.pretrain_method = pretrain_method
        self.device = torch.device('cuda:'+ str(device) if torch.cuda.is_available() else 'cpu')
        self.hid_dim = hid_dim
        self.num_layer = num_layer
        self.dataset_name = dataset_name
        self.shot_num = shot_num
        self.batch_size = batch_size
        self.gnn_type = gnn_type
        self.prompt_type = prompt_type
        self.epochs = epochs
        self.prompt_graph_token_num = prompt_graph_token_num
        self.tokenizer = tokenizer
        self.llm = llm
        self.llm_dim = llm.config.hidden_size

        self.projector_epochs = projector_epochs
        self.projector_train_eval_diff_threshold = projector_train_eval_diff_threshold
        self.projector_train_modular = projector_train_modular
        self.projector_tune_epochs = projector_tune_epochs
        self.prompt_tune_epochs = prompt_tune_epochs

        # optimization hyperparameters
        self.projector_lr = projector_lr
        self.projector_weight_decay = projector_weight_decay
        self.projector_tune_lr = projector_tune_lr
        self.projector_tune_weight_decay = projector_tune_weight_decay
        self.pg_lr = pg_lr
        self.pg_weight_decay = pg_weight_decay
        self.text_prompt_lr = text_prompt_lr
        self.text_prompt_weight_decay = text_prompt_weight_decay
        self.projector_dropout_ratio = projector_dropout_ratio
        self.temperature = temperature
        self.text_prompt_start_vocab = text_prompt_start_vocab
        self.initialize_lossfn()

        # picking best model
        self.train_val_acc_diff_tol = train_val_acc_diff_tol
        self.val_acc_threshold = val_acc_threshold
        self.warmup_epochs = warmup_epochs
        self.random_seed = random_seed

        self.load_data()
        self.initialize_gnn()
        self.initialize_prompt()
        # self.projector = torch.nn.Sequential(torch.nn.Linear(self.hid_dim, self.llm_dim)).to(self.device)
        self.projector = torch.nn.Sequential(torch.nn.Dropout(self.projector_dropout_ratio), torch.nn.Linear(self.hid_dim, self.llm_dim), torch.nn.Tanh()).to(self.device)
        # self.projector = torch.nn.Sequential(torch.nn.Linear(self.hid_dim, self.llm_dim)).to(self.device)
        self.answering =  torch.nn.Sequential(torch.nn.Linear(self.hid_dim, self.output_dim),
                                            torch.nn.Softmax(dim=1)).to(self.device)
        self.initialize_optimizer()

        self.gnn_temp = 0.01
        self.train_good_threshold = train_good_threshold


    def meannormalize_labeltotextemb(self):
        # calculate the mean and substract it from the embeddings
        if self.dataset_name in ['zero1', 'zero2', 'zero3']:
            mean_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])[:-1].mean(dim=0)
            for label in self.label_to_text_emb:
                self.label_to_text_emb[label] = self.label_to_text_emb[label] - mean_emb

        else:
            mean_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb]).mean(dim=0)
            for label in self.label_to_text_emb:
                self.label_to_text_emb[label] = self.label_to_text_emb[label] - mean_emb


    def load_data(self):
        if self.dataset_name in ['MUTAG', 'ENZYMES', 'PROTEINS', 'MSRC_21', 'MSRC_21C']:
            self.input_dim, self.output_dim, self.train_dataset, self.test_dataset, self.val_dataset, _= load4graph(self.dataset_name, self.shot_num)
        elif self.dataset_name in ['zero1', 'zero2', 'zero3']:
            self.input_dim, self.output_dim, self.train_dataset, self.test_dataset, self.graph_list = load4zero(self.dataset_name)
            # print dataset statistics, given self.graph_list to be a list of Data(x, edge_index, y)
            print(f"Dataset: {self.dataset_name}, Number of graphs: {len(self.graph_list)}, Number of classes: {self.output_dim}, Number of features: {self.input_dim}")
            # average number of nodes per graph
            avg_num_nodes = sum([graph.num_nodes for graph in self.graph_list]) / len(self.graph_list)
            print(f"Average number of nodes per graph: {avg_num_nodes}")
            # average number of edges per graph
            avg_num_edges = sum([graph.num_edges for graph in self.graph_list]) / len(self.graph_list)
            print(f"Average number of edges per graph: {avg_num_edges}")


        else:
            raise ValueError(f"Unsupported dataset: {self.dataset_name}")
        
        if self.dataset_name == 'MUTAG':
            self.label_to_text = {0: 'non-mutagenic on Salmonella typhimurium', 1: 'mutagenic on Salmonella typhimurium'}

        if self.dataset_name == 'ENZYMES':
            # for enzymes dataset, the labels are Enzyme Commission top level enzyme classes (EC classes)
            self.label_to_text = {0: 'oxidoreductases', 1: 'transferases', 2: 'hydrolases', 3: 'lyases', 4: 'isomerases', 5: 'ligases'}

        if self.dataset_name == 'PROTEINS':
            self.label_to_text = {0: 'enzyme', 1: 'non-enzyme'}

        if self.dataset_name == 'MSRC_21':
            self.label_to_text = {1: 'building', 2: 'grass', 3: 'tree', 4: 'cow', 5: 'sheep', 6: 'sky', 7: 'airplane', 8: 'water', 9: 'face', 10: 'car', 
                                11: 'bicycle', 12: 'flower', 13: 'sign', 14: 'bird', 15: 'book', 16: 'chair', 17: 'road', 18: 'cat', 19: 'dog', 20: 'body', 21: 'boat'}
            
        if self.dataset_name == 'MSRC_21C':
            self.label_to_text = {1: 'building', 2: 'grass', 3: 'tree', 4: 'cow', 5: 'sheep', 6: 'sky', 7: 'airplane', 8: 'water', 9: 'face', 10: 'car', 
                                11: 'bicycle', 12: 'flower', 13: 'sign', 14: 'bird', 15: 'book', 16: 'chair', 17: 'road', 18: 'cat', 19: 'dog', 20: 'body', 21: 'boat'}
            
        if self.dataset_name == 'zero1':
            # self.label_to_text = {0: 'bio', 1: 'informatics', 2: 'bioinformatics'}
            self.label_to_text = {0: 'machine learning', 1: 'theory', 2: 'machine learning theory'}

        if self.dataset_name == 'zero2':
            self.label_to_text = {0: 'biology', 1: 'informatics', 2: 'bioinformatics'}

        if self.dataset_name == 'zero3':
            # self.label_to_text = {0: 'algebra', 1: 'geometry', 2: 'algebraic geometry'}
            self.label_to_text = {0: 'cardiology', 1: 'neurology', 2: 'neurocardiology'}


        self.tokenized_label_to_text = {i: self.tokenizer.encode(self.label_to_text[i], return_tensors='pt').to(self.device) for i in self.label_to_text}
        # for each tokenized label to text, remove the first and the last token. note that each tokenized label to text is a tensor of shape (1, n_tokens)
        self.tokenized_label_to_text = {i: self.tokenized_label_to_text[i][0, 1:-1].unsqueeze(0) for i in self.tokenized_label_to_text}

        self.label_to_text_emb = {i: self.llm(self.tokenized_label_to_text[i])[0].mean(dim=1).squeeze() for i in self.label_to_text}
        self.meannormalize_labeltotextemb()

    def initialize_lossfn(self):
        # self.criterion = torch.nn.CrossEntropyLoss()
        # projector criterion is the norm of the difference between the projected embeddings and the text prompt embeddings
        self.projector_criterion = torch.nn.MSELoss()
        # self.critierion is the similarity loss between the projected embeddings and the text prompt embeddings, similar to CLIP loss.
        # projected embeddings and text prompt embeddings are both vectors of size (batch_size, llm_dim). First compute the cosine similarity 
        # between the two embeddings, then compute softmax
        self.criterion = self.contrastive_loss_with_label
    

    def contrastive_loss_with_label(self, graph_embeddings, text_embeddings_of_y, y):
        # normalize embeddings
        # text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
        # graph_embeddings = F.normalize(graph_embeddings, p=2, dim=-1)
        # contrastive loss
        logits = (graph_embeddings @ text_embeddings_of_y.T) / self.temperature

        exp_logits = torch.exp(logits)
        sum_exp_logits = exp_logits.sum(dim=1)

        # pdb.set_trace()
        # logits is in shape (batch_size, num_classes). for each row, the value at the index of the true class is the logit for that class. 
        # retrieve the logit for the true class
        true_class_logits = logits[torch.arange(logits.shape[0]), y]

        loss = true_class_logits - torch.log(sum_exp_logits)

        return -loss.mean()


    def initialize_gnn(self):
        if self.gnn_type == 'GAT':
            self.gnn = GAT(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GCN':
            self.gnn = GCN(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GraphSAGE':
            self.gnn = GraphSAGE(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GIN':
            self.gnn = GIN(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GCov':
            self.gnn = GCov(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GraphTransformer':
            self.gnn = GraphTransformer(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        else:
            raise ValueError(f"Unsupported GNN type: {self.gnn_type}")
        self.gnn.to(self.device)

        if self.pre_train_model_path != 'None':
            if self.gnn_type not in self.pre_train_model_path :
                raise ValueError(f"the Downstream gnn '{self.gnn_type}' does not match the pre-train model")
            if self.dataset_name not in self.pre_train_model_path :
                # raise ValueError(f"the Downstream dataset '{self.dataset_name}' does not match the pre-train dataset")
                print(f"Warning: the Downstream dataset '{self.dataset_name}' does not match the pre-train dataset")

            self.gnn.load_state_dict(torch.load(self.pre_train_model_path, map_location=self.device))
            print("Successfully loaded pre-trained weights!")

    
    def initialize_prompt(self):
        self.prompt = MorpherGraphPrompt(token_dim=self.input_dim, token_num=self.prompt_graph_token_num).to(self.device)
        self.start_vocab = self.text_prompt_start_vocab
        self.start_vocab_tokens = self.tokenizer.encode(self.start_vocab, return_tensors='pt').to(self.device)
        start_vocab_emb = self.llm(self.start_vocab_tokens)[0].squeeze()
        
        self.n_tokens = len(self.start_vocab_tokens[0])
        self.text_prompt = MorpherTextPrompt(self.llm.get_input_embeddings(), n_tokens=self.n_tokens, start_vocab_emb=start_vocab_emb).to(self.device)


    def initialize_optimizer(self):
        self.projector_opi = optim.Adam(filter(lambda p: p.requires_grad, self.projector.parameters()), lr=self.projector_lr, weight_decay=self.projector_weight_decay)
        self.projector_tune_opi = optim.Adam(filter(lambda p: p.requires_grad, self.projector.parameters()), lr=self.projector_tune_lr, weight_decay= self.projector_tune_weight_decay)
        self.pg_opi = optim.Adam(filter(lambda p: p.requires_grad, self.prompt.parameters()), lr=self.pg_lr, weight_decay= self.pg_weight_decay)
        self.text_prompt_opi = optim.Adam(filter(lambda p: p.requires_grad, self.text_prompt.parameters()), lr=self.text_prompt_lr, weight_decay= self.text_prompt_weight_decay)


    def eval_projector(self, eval_loader: DataLoader):
        self.projector.eval()
        total_loss = 0.0
        for batch in eval_loader:
            batch = batch.to(self.device)

            out = self.gnn(batch.x, batch.edge_index, batch.batch)
            out = self.projector(out)

            # create the text embeddings from batch.y according to the encoded_label_to_text dictionary
            with torch.no_grad():
                text_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])

            out = F.normalize(out, p=2, dim=-1)
            text_emb = F.normalize(text_emb, p=2, dim=-1)

            loss = self.criterion(out, text_emb, batch.y)
            total_loss += loss.item()
        return total_loss/len(eval_loader)
                

    def train_projector(self, train_loader: DataLoader, val_loader: DataLoader, projector_epochs):
        best_projector = None
        self.projector.train()
        best_eval_loss = 1000000
        for epoch in range(projector_epochs):
            total_loss = 0.0
            for batch in train_loader:
                batch = batch.to(self.device)

                out = self.gnn(batch.x, batch.edge_index, batch.batch)
                out = self.projector(out)

                # create the text embeddings from batch.y according to the encoded_label_to_text dictionary
                with torch.no_grad():
                    text_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])

                # row normalization
                out = F.normalize(out, p=2, dim=-1)
                text_emb = F.normalize(text_emb, p=2, dim=-1)

                # pdb.set_trace()

                loss = self.criterion(out, text_emb, batch.y)
                self.projector_opi.zero_grad()
                loss.backward()
                self.projector_opi.step()
                total_loss += loss.item()

            if epoch % self.projector_train_modular == 0:
                print(f"Projector Epoch: {epoch}, Train Loss: {total_loss/len(train_loader)}")
                # pdb.set_trace()
                eval_loss = self.eval_projector(val_loader)
                print(f"Projector Epoch: {epoch}, Eval Loss: {eval_loss}")
                if eval_loss < best_eval_loss and abs(eval_loss - total_loss/len(train_loader)) < self.projector_train_eval_diff_threshold:
                    best_eval_loss = eval_loss
                    print("Checkpointing best projector model...")
                    best_projector = self.projector.state_dict()

        print("Projector training finished! Loading best projector model...")
        self.projector.load_state_dict(best_projector)
                    

    def MorpherTrain(self, train_loader):
        total_loss = 0.0
        for batch in train_loader:
            batch = batch.to(self.device)
            prompted_graph = self.prompt(batch)
            # prompted_graph = batch
            graph_emb = self.gnn(prompted_graph.x, prompted_graph.edge_index, prompted_graph.batch)
            out = self.projector(graph_emb)

            self.llm.set_input_embeddings(self.text_prompt)

            # update the prompted text embeddings
            self.label_to_text_emb = {i: self.llm(self.tokenized_label_to_text[i])[0].mean(dim=1).squeeze() for i in self.label_to_text}
            self.meannormalize_labeltotextemb()

            # text_emb = torch.stack([self.label_to_text_emb[label.item()].clone() for label in batch.y])
            text_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])
            # similarity-based loss between out and the text embeddings
            # pdb.set_trace()
            # row normalization
            out = F.normalize(out, p=2, dim=-1)
            text_emb = F.normalize(text_emb, p=2, dim=-1)

            loss = self.criterion(out, text_emb, batch.y)

            self.pg_opi.zero_grad()
            self.text_prompt_opi.zero_grad()
            self.projector_tune_opi.zero_grad()
            loss.backward()
            self.pg_opi.step()
            self.text_prompt_opi.step()
            self.projector_tune_opi.step()
            total_loss += loss.item()
        
        return total_loss/len(train_loader)
    

    def MorpherEval(self, eval_loader, num_class, device):
        self.prompt.eval()
        self.text_prompt.eval()
        accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_class).to(device)
        macro_f1 = torchmetrics.classification.F1Score(task="multiclass", num_classes=num_class, average="weighted").to(device)
        accuracy.reset()
        macro_f1.reset()
        for batch in eval_loader:
            batch = batch.to(self.device)
            prompted_graph = self.prompt(batch)
            # prompted_graph = batch
            graph_emb = self.gnn(prompted_graph.x, prompted_graph.edge_index, prompted_graph.batch)
            out = self.projector(graph_emb)

            self.llm.set_input_embeddings(self.text_prompt)
            # update the prompted text embeddings
            self.label_to_text_emb = {i: self.llm(self.tokenized_label_to_text[i])[0].mean(dim=1).squeeze() for i in self.label_to_text}
            self.meannormalize_labeltotextemb()

            text_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])
            out = F.normalize(out, p=2, dim=-1)
            text_emb = F.normalize(text_emb, p=2, dim=-1)
            sims = out @ text_emb.T
            pred = sims.argmax(dim=1)

            acc = accuracy(pred, batch.y)
            f1 = macro_f1(pred, batch.y)
        acc = accuracy.compute()
        f1 = macro_f1.compute()

        return acc, f1



    def GNNForwardZero(self, batch):
        prompted_graph = self.prompt(batch)
        # prompted_graph = batch
        graph_emb = self.gnn(prompted_graph.x, prompted_graph.edge_index, prompted_graph.batch) / self.gnn_temp
        # for each embedding, substract the mean embedding from it
        # mean_graph_emb = graph_emb.mean(dim=0)
        # graph_emb = graph_emb - mean_graph_emb
        out = self.projector(graph_emb)
        return graph_emb, out


    


    def train_projector_zero(self, train_loader: DataLoader, projector_epochs):
        self.projector.train()
        for epoch in range(projector_epochs):
            total_loss = 0.0
            for batch in train_loader:
                batch = batch.to(self.device)
                out = self.gnn(batch.x, batch.edge_index, batch.batch) / self.gnn_temp
                out = out - out.mean(dim=0)
                out = self.projector(out)

                # create the text embeddings from batch.y according to the encoded_label_to_text dictionary
                with torch.no_grad():
                    text_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])[:-1]

                # row normalization
                out = F.normalize(out, p=2, dim=-1)
                text_emb = F.normalize(text_emb, p=2, dim=-1)

                loss = self.criterion(out, text_emb, batch.y)
                self.projector_opi.zero_grad()
                loss.backward()
                self.projector_opi.step()
                total_loss += loss.item()

        print("Projector training finished.")






    def MorpherTrainZero(self, train_loader):
        total_loss = 0.0
        for batch in train_loader:
            batch = batch.to(self.device)
            prompted_graph = self.prompt(batch)
            # prompted_graph = batch
            graph_emb = self.gnn(prompted_graph.x, prompted_graph.edge_index, prompted_graph.batch) / self.gnn_temp
            out = self.projector(graph_emb)

            self.llm.set_input_embeddings(self.text_prompt)

            # update the prompted text embeddings
            self.label_to_text_emb = {i: self.llm(self.tokenized_label_to_text[i])[0].mean(dim=1).squeeze() for i in self.label_to_text}
            self.meannormalize_labeltotextemb()

            # text_emb = torch.stack([self.label_to_text_emb[label.item()].clone() for label in batch.y])
            text_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])[:-1]
            # similarity-based loss between out and the text embeddings
            # pdb.set_trace()
            # row normalization
            out = F.normalize(out, p=2, dim=-1)
            text_emb = F.normalize(text_emb, p=2, dim=-1)

            # pdb.set_trace()

            loss = self.criterion(out, text_emb, batch.y)

            self.pg_opi.zero_grad()
            self.text_prompt_opi.zero_grad()
            self.projector_tune_opi.zero_grad()
            loss.backward()
            self.pg_opi.step()
            self.text_prompt_opi.step()
            self.projector_tune_opi.step()
            total_loss += loss.item()
        
        return total_loss/len(train_loader)




    def MorpherEvalZero(self, eval_loader, num_class, device):
        self.prompt.eval()
        self.text_prompt.eval()
        accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_class).to(device)
        macro_f1 = torchmetrics.classification.F1Score(task="multiclass", num_classes=num_class, average="weighted").to(device)
        accuracy.reset()
        macro_f1.reset()

        for batch in eval_loader:
            batch = batch.to(self.device)
            prompted_graph = self.prompt(batch)
            # prompted_graph = batch
            graph_emb = self.gnn(prompted_graph.x, prompted_graph.edge_index, prompted_graph.batch) / self.gnn_temp
            out = self.projector(graph_emb)

            self.llm.set_input_embeddings(self.text_prompt)
            # update the prompted text embeddings
            self.label_to_text_emb = {i: self.llm(self.tokenized_label_to_text[i])[0].mean(dim=1).squeeze() for i in self.label_to_text}
            self.meannormalize_labeltotextemb()


            text_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])[:-1]
            out = F.normalize(out, p=2, dim=-1)
            text_emb = F.normalize(text_emb, p=2, dim=-1)
            sims = out @ text_emb.T
            pred = sims.argmax(dim=1)



            acc = accuracy(pred, batch.y)
            f1 = macro_f1(pred, batch.y)
            # pdb.set_trace()
        acc = accuracy.compute()
        f1 = macro_f1.compute()

        return acc, f1







    def MorpherTestZero(self, train_loader, test_loader, device, num_class = 3):
        self.prompt.eval()
        self.text_prompt.eval()
        accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_class).to(device)
        accuracy.reset()

        for batch in train_loader:
            batch = batch.to(self.device)
            train_batch = batch
            # prompted_graph = self.prompt(batch)
            # # prompted_graph = batch
            # graph_emb_train = self.gnn(prompted_graph.x, prompted_graph.edge_index, prompted_graph.batch) / self.gnn_temp
            # original_graph_emb_train = self.gnn(batch.x, batch.edge_index, batch.batch) / self.gnn_temp
            # out_train = self.projector(graph_emb_train) 
            graph_emb_train, out_train = self.GNNForwardZero(train_batch)
            out_train_norm = F.normalize(out_train, p=2, dim=-1)
            # pdb.set_trace()

        for batch in test_loader:
            batch = batch.to(self.device)
            # prompted_graph = self.prompt(batch)
            # # prompted_graph = batch
            # graph_emb = self.gnn(prompted_graph.x, prompted_graph.edge_index, prompted_graph.batch) / self.gnn_temp
            # out = self.projector(graph_emb)
            graph_emb, out = self.GNNForwardZero(batch)

            self.llm.set_input_embeddings(self.text_prompt)
            # update the prompted text embeddings
            self.label_to_text_emb = {i: self.llm(self.tokenized_label_to_text[i])[0].mean(dim=1).squeeze() for i in self.label_to_text}
            self.meannormalize_labeltotextemb()

            text_emb = torch.stack([self.label_to_text_emb[label].clone() for label in self.label_to_text_emb])
            unnormalized_out = out.clone()
            out = F.normalize(out, p=2, dim=-1)
            text_emb_normalized = F.normalize(text_emb, p=2, dim=-1)
            sims = out @ text_emb_normalized.T
            train_sims = out_train_norm @ text_emb_normalized.T
            # sims_train = out_train_norm @ text_emb.T
            pred = sims.argmax(dim=1)
            # sims_zero is to compute the distance between each out and text_emb
            # sims_zero = torch.cdist(out, text_emb)
            # pred = sims_zero.argmin(dim=1)
            acc = accuracy(pred, batch.y)

        pred_train = train_sims.argmax(dim=1)
        acctrain = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_class).to(device)
        acc_train = acctrain(pred_train, train_batch.y)
        acc_train = acctrain.compute()
        # pdb.set_trace()

        acc = accuracy.compute()

        return acc, acc_train





    def run(self):
        train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
        print("prepare data is finished!")
        print("Setting Language Model and GNN to eval mode.")
        self.llm.eval()
        self.gnn.eval()
        self.llm.set_input_embeddings(self.text_prompt)

        for i in self.tokenized_label_to_text:
            self.tokenized_label_to_text[i] = torch.cat([self.start_vocab_tokens, self.tokenized_label_to_text[i]], 1)

        self.label_to_text_emb = {i: self.llm(self.tokenized_label_to_text[i])[0].mean(dim=1).squeeze() for i in self.label_to_text}
        self.meannormalize_labeltotextemb()

        projector_path = osp.join('trained_projector', f'best_projector_{self.dataset_name}_{self.pretrain_method}_{self.gnn_type}.pt')

        if not osp.exists(projector_path) or SAVE_PROJ:
            self.train_projector_zero(train_loader, projector_epochs=self.projector_epochs)
            torch.save(self.projector.state_dict(), projector_path)
        else:
            self.projector.load_state_dict(torch.load(projector_path, map_location=self.device))

        self.projector.eval()

        if self.dataset_name != 'MUTAG' or self.pretrain_method != 'GraphCL' or self.gnn_type != 'GCN':
            # reset the random seed to 42
            seed_torch(self.random_seed)

        train_losses = []
        train_accs = []
        test_accs = []
        test_f1s = []

        train_accs_with_new_class = []

        avg_test_zero_after_train_good = []
        avg_train_zero_after_train_good = []
        avg_test_of_train_zero_after_train_good = []

        for epoch in range(1, self.epochs+1):
            start = time.time()
            for epoch_prompt in range(self.prompt_tune_epochs):
                self.prompt.train()
                self.text_prompt.train()
                self.projector.eval()
                train_loss = self.MorpherTrainZero(train_loader)
                print(f"Tuning Prompt, prompt epoch: {epoch_prompt}, Train Loss: {train_loss}")
            for epoch_projector in range(self.projector_tune_epochs):
                self.prompt.eval()
                self.text_prompt.eval()
                self.projector.train()
                train_loss = self.MorpherTrainZero(train_loader)
                print(f"Tuning Projector, projector epoch: {epoch_projector}, Train Loss: {train_loss}")

            print(f"Epoch: {epoch}, Train Loss: {train_loss}, Time: {time.time()-start}")
            train_acc, train_f1 = self.MorpherEvalZero(train_loader, self.output_dim, self.device)
            test_acc, train_acc_with_new_class = self.MorpherTestZero(train_loader, test_loader, self.device, num_class = self.output_dim + 1)
            print(f"Epoch: {epoch}, Train Acc: {train_acc:.5f}, Test Acc: {test_acc:.5f}, Train Acc with new class: {train_acc_with_new_class}")

            train_losses.append(train_loss)
            train_accs.append(train_acc)
            test_accs.append(test_acc)
            train_accs_with_new_class.append(train_acc_with_new_class)

            if train_acc > self.train_good_threshold:
                avg_test_zero_after_train_good.append(test_acc)
                avg_train_zero_after_train_good.append(train_acc)
                avg_test_of_train_zero_after_train_good.append(train_acc_with_new_class)

        print(f"Average Test Acc after train good: {sum(avg_test_zero_after_train_good)/len(avg_test_zero_after_train_good)}, Std: {torch.tensor(avg_test_zero_after_train_good).std()}")
        print(f"Average Train Acc after train good: {sum(avg_train_zero_after_train_good)/len(avg_train_zero_after_train_good)}, Std: {torch.tensor(avg_train_zero_after_train_good).std()}")
        print(f"Average Test Acc of Train after train good: {sum(avg_test_of_train_zero_after_train_good)/len(avg_test_of_train_zero_after_train_good)}, Std: {torch.tensor(avg_test_of_train_zero_after_train_good).std()}")

        if SAVE_DATA:
            # save the train losses, train accs, test accs, train accs with new class
            save_folder = 'saved_data'
            np.save(osp.join(save_folder, f'{self.dataset_name}_{self.pretrain_method}_{self.gnn_type}_train_losses.npy'), np.array(train_losses))
            np.save(osp.join(save_folder, f'{self.dataset_name}_{self.pretrain_method}_{self.gnn_type}_train_accs.npy'), np.array([acc.cpu() for acc in train_accs]))
            np.save(osp.join(save_folder, f'{self.dataset_name}_{self.pretrain_method}_{self.gnn_type}_test_accs.npy'), np.array([acc.cpu() for acc in test_accs]))
            np.save(osp.join(save_folder, f'{self.dataset_name}_{self.pretrain_method}_{self.gnn_type}_train_accs_with_new_class.npy'), np.array([acc.cpu() for acc in train_accs_with_new_class]))