"""Use LLM/LM models to generate embeddings"""
import torch 
from tqdm import tqdm
import os
import argparse
import sys
import pickle
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
sys.path.append("../..")
from common import get_cur_time, TextEncoder


# Encoder names accepted by --encoder_name, grouped by how they are produced.
# LM_ENCODERS are passed to TextEncoder with encoder_type="LM"; every other
# non-shallow choice is passed with encoder_type="LLM".
LM_ENCODERS = ("MiniLM", "SentenceBert", "e5-large", "roberta")
LLM_ENCODERS = ("Qwen-3B", "Qwen-7B", "Mistral-7B", "Llama-8B")
SHALLOW_ENCODERS = ("bow", "tfidf")


def build_arg_parser():
    """Return the command-line parser for the embedding-generation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--dataset``, ``--device``,
        ``--encoder_name``, ``--use_cls``, ``--save_emb``, ``--emb_dim`` and
        ``--base_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="cora")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument(
        "--encoder_name",
        type=str,
        default="SentenceBert",
        choices=[*LM_ENCODERS, *LLM_ENCODERS, *SHALLOW_ENCODERS],
    )
    parser.add_argument("--use_cls", type=int, default=0)
    parser.add_argument("--save_emb", type=int, default=1)
    parser.add_argument("--emb_dim", type=int, default=500, help="Embedding dimension for BOW and TF-IDF")
    parser.add_argument("--base_path", type=str, default="/path/to/GraphAD_data/datasets", help="Base path for all dataset related files")
    return parser


def resolve_graph_path(base_path, dataset):
    """Return the path of the serialized graph for ``dataset``.

    Datasets named ``taglas_<name>`` are stored as
    ``<base_path>/taglas_data/<name>.pt``; all others as
    ``<base_path>/<dataset>.pt``.

    Args:
        base_path: Root directory of the dataset files.
        dataset: Dataset name as given on the command line.

    Returns:
        str: Path to the ``.pt`` file to load.
    """
    prefix = "taglas_"
    if dataset.startswith(prefix):
        return os.path.join(base_path, "taglas_data", f"{dataset[len(prefix):]}.pt")
    return os.path.join(base_path, f"{dataset}.pt")


def load_or_fit_vectorizer(vectorizer_cls, vocab_path, texts, max_features):
    """Load a cached vectorizer, or fit and cache a new one.

    A cached vectorizer is reused only when its ``max_features`` equals the
    requested value. Otherwise it would silently yield embeddings whose
    dimension ignores ``--emb_dim``, so a fresh vectorizer is fitted and the
    cache file is overwritten.

    Args:
        vectorizer_cls: Callable accepting ``max_features=`` and returning an
            object with a ``fit(texts)`` method (e.g. ``CountVectorizer``).
        vocab_path: Pickle file used as the cache.
        texts: Corpus the vocabulary is built from.
        max_features: Requested upper bound on the vocabulary size.

    Returns:
        The fitted vectorizer.
    """
    if os.path.exists(vocab_path):
        # NOTE(security): pickle.load can execute arbitrary code; the cache
        # file is assumed to be a trusted artifact written by this script.
        with open(vocab_path, 'rb') as f:
            vectorizer = pickle.load(f)
        if getattr(vectorizer, "max_features", None) == max_features:
            return vectorizer
        print(f"Cached vocabulary at {vocab_path} does not match max_features={max_features}; refitting.")

    vectorizer = vectorizer_cls(max_features=max_features)
    vectorizer.fit(texts)  # Build vocabulary on all texts
    with open(vocab_path, 'wb') as f:
        pickle.dump(vectorizer, f)
    return vectorizer


def generate_shallow_embeddings(encoder_name, texts, vocab_dir, emb_dim):
    """Build bag-of-words or TF-IDF embeddings for ``texts``.

    Args:
        encoder_name: ``"bow"`` (CountVectorizer) or ``"tfidf"``
            (TfidfVectorizer).
        texts: Raw texts, one per node.
        vocab_dir: Directory holding the cached ``<encoder_name>_vocabulary.pkl``;
            created if missing.
        emb_dim: ``max_features`` passed to the vectorizer.

    Returns:
        torch.FloatTensor: Dense matrix with one row per text.
    """
    vectorizer_cls = CountVectorizer if encoder_name == "bow" else TfidfVectorizer
    os.makedirs(vocab_dir, exist_ok=True)
    vocab_path = os.path.join(vocab_dir, f"{encoder_name}_vocabulary.pkl")
    vectorizer = load_or_fit_vectorizer(vectorizer_cls, vocab_path, texts, emb_dim)
    matrix = vectorizer.transform(texts)  # Transform texts to vectors
    return torch.FloatTensor(matrix.toarray())


def generate_lm_embeddings(encoder_name, texts, device, use_cls):
    """Encode ``texts`` one at a time with the project ``TextEncoder``.

    Args:
        encoder_name: Name forwarded to ``TextEncoder``.
        texts: Raw texts, one per node.
        device: ``torch.device`` the encoder is created on.
        use_cls: Truthy selects ``"cls"`` pooling, otherwise ``"mean"``.

    Returns:
        torch.Tensor: Per-text encoder outputs concatenated along dim 0.
        Presumably each output is a ``(1, dim)`` tensor — TODO confirm
        against ``TextEncoder``.
    """
    encoder_type = "LM" if encoder_name in LM_ENCODERS else "LLM"
    text_encoder = TextEncoder(encoder_name=encoder_name, encoder_type=encoder_type, device=device)
    pooling = "cls" if use_cls else "mean"
    # Clearing the CUDA cache is pointless when encoding on a non-CUDA device.
    on_cuda = device.type == "cuda"

    text_embeddings = []
    with torch.no_grad():
        for text in tqdm(texts, desc=f"Generating {encoder_type} Embedding"):
            text_embeddings.append(text_encoder(input_text=text, pooling=pooling))
            if on_cuda:
                torch.cuda.empty_cache()
    return torch.cat(text_embeddings, dim=0)


def main(argv=None):
    """Generate node embeddings for one dataset and optionally save them.

    Embeddings are written to ``<base_path>/<encoder_name>/<dataset>.pt``.
    If that file already exists the function returns before loading the
    graph data or running any encoder.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)
    device = torch.device(args.device)

    print('= ' * 20)
    print('## Starting Time:', get_cur_time(), flush=True)
    print(args, "\n")

    write_dir = os.path.join(args.base_path, args.encoder_name)
    emb_path = os.path.join(write_dir, f"{args.dataset}.pt")
    # Check for existing output first so the graph file is not loaded for nothing.
    if os.path.exists(emb_path):
        print(f"[{args.dataset}-{args.encoder_name}] Embedding file already exists, Quit!")
        print('= ' * 20)
        return

    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load dataset files from a trusted source.
    graph_data = torch.load(resolve_graph_path(args.base_path, args.dataset), weights_only=False)

    if args.encoder_name in SHALLOW_ENCODERS:
        # Directory for saving vocabularies, one sub-directory per dataset.
        vocab_dir = os.path.join(args.base_path, "vocab", args.dataset)
        generated_node_emb = generate_shallow_embeddings(
            args.encoder_name, graph_data.raw_texts, vocab_dir, args.emb_dim
        )
    else:
        generated_node_emb = generate_lm_embeddings(
            args.encoder_name, graph_data.raw_texts, device, args.use_cls
        )

    print(f"[{args.dataset}-{args.encoder_name}] Node Embedding Shape {generated_node_emb.shape}")
    if args.save_emb:
        os.makedirs(write_dir, exist_ok=True)
        torch.save(generated_node_emb, emb_path)

    print('\n## Finishing Time:', get_cur_time(), flush=True)
    print('= ' * 20)
    print("Done!")


if __name__ == "__main__":
    main()
