#!/usr/bin/env python
# coding: utf-8

from pathlib import Path
import re
import os
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModel, AutoModelForTokenClassification

sys.path.append("/workspace")
from ner_bert_finetuning.MatSciBERT.normalize_text import normalize


def extract_ner_tokens(title_list, token_classifier, chunk_size=512):
    """
    Run NER over paper titles and flatten each title's entities into one string.

    Titles are passed to ``token_classifier`` in slices of ``chunk_size``. For
    every title, the ``'word'`` field of each returned entity is joined with
    single spaces. A title with no entities yields an empty string.

    Args:
    title_list (list): Paper titles to process.
    token_classifier: Callable NER pipeline. It is called with a list of titles
        and is expected to return one list of entity dicts per title.
    chunk_size (int, optional): Number of titles per classifier call. Defaults to 512.

    Returns:
    list: One space-joined entity string per input title, in input order.
    """
    batch_starts = range(0, len(title_list), chunk_size)
    entity_strings = []
    for start in tqdm(batch_starts, desc="Processing NER"):
        batch_results = token_classifier(title_list[start:start + chunk_size])
        for entities in batch_results:
            words = (entity['word'] for entity in entities)
            entity_strings.append(" ".join(words))
    return entity_strings


def prepare_embedding_model(model_name):
    """
    Load the tokenizer/encoder pair used for text embeddings.

    Args:
    model_name (str): Model identifier or path forwarded to both
        ``from_pretrained`` calls.

    Returns:
    tuple: ``(tokenizer, model)`` loaded via ``AutoTokenizer`` and ``AutoModel``.
    """
    # The tokenizer is loaded first, then the model (same order as before).
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModel.from_pretrained(model_name),
    )


def compute_text_embeddings(tokens, tokenizer, model, normalize_func, chunk_size=512):
    """
    Compute sum-pooled text embeddings for a list of extracted tokens.

    Each item is converted to text, normalized with ``normalize_func``, tokenized
    (padded/truncated per chunk) and run through ``model``. The item's vector is
    the sum of the model's first output (last hidden states) over that item's
    non-padding positions.

    Args:
    tokens (list): Texts to embed. Each item may be a string (as produced by
        ``extract_ner_tokens``), which is used as-is, or an iterable of word
        strings, which is joined with single spaces.
    tokenizer: Tokenizer to use.
    model: Embedding model to use. It is moved to the selected device and put in
        eval mode.
    normalize_func: Function to normalize text strings.
    chunk_size (int, optional): Number of items per forward pass. Defaults to 512.

    Returns:
    list: One 1-D CPU tensor per input item, in input order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # inference only: make sure dropout is disabled

    text_embeddings = []

    for i in tqdm(range(0, len(tokens), chunk_size), desc="Processing embeddings"):
        chunk = tokens[i:i + chunk_size]
        # A plain string must not be passed through " ".join: that iterates its
        # characters and would turn "abc" into "a b c".
        norm_chunk = [normalize_func(s if isinstance(s, str) else " ".join(s)) for s in chunk]
        tokenized_chunk = tokenizer(norm_chunk, return_tensors='pt', padding=True, truncation=True)

        tokenized_chunk = {k: v.to(device) for k, v in tokenized_chunk.items()}

        with torch.no_grad():
            last_hidden_states = model(**tokenized_chunk)[0]

        attention_mask = tokenized_chunk.get("attention_mask")
        if attention_mask is not None:
            # Zero out padding positions so an item's embedding does not depend
            # on the longest sequence in its batch.
            mask = attention_mask.unsqueeze(-1).type_as(last_hidden_states)
            batch_embeddings = (last_hidden_states * mask).sum(1)
        else:
            batch_embeddings = last_hidden_states.sum(1)
        text_embeddings.extend(batch_embeddings.cpu())

    return text_embeddings


def main():
    """
    Embed the NER entities found in each COD entry title and pickle the result.

    Steps:
    1. Read the raw COD metadata CSV, filling missing titles with ''.
    2. Strip "~" from each title.
    3. Run the fine-tuned SciBERT NER model over every title.
    4. Embed the space-joined entity words with the MatSciBERT model.
    5. Store one NumPy vector per row in a ``title_embedding`` column.
    6. Pickle the DataFrame next to the CSV as ``<csv stem>_with_embedding.pkl``.
    """
    DEBUG = False  # when True, only the first 5000 rows are processed
    raw_csv_path = Path("/workspace/data/raw/cod_metadata_20230523.csv")
    cod_df = pd.read_csv(raw_csv_path)
    cod_df = cod_df.fillna({'title': ''})

    if DEBUG:
        cod_df = cod_df.iloc[:5000]

    # astype(str) keeps non-string cells (e.g. purely numeric titles) from
    # breaking the replace; "~" characters are removed from every title.
    title_list = cod_df["title"].astype(str).str.replace("~", "", regex=False).tolist()

    # setup ner pipeline
    ner_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
    checkpoint_path = "/workspace/ner_bert_finetuning/results/best_ner_scibert_20230616_0040"
    ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint_path)
    token_classifier = pipeline("token-classification",
                                model=ner_model,
                                tokenizer=ner_tokenizer,
                                device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
                                aggregation_strategy="simple",
                                batch_size=512,
                                num_workers=0)

    # perform ner
    extracted_tokens = extract_ner_tokens(title_list, token_classifier)

    # embed extracted words; "matscibert" is presumably a local copy of the
    # hub model "m3rg-iitd/matscibert" (previously referenced here) - verify
    tokenizer_embedding, model_embedding = prepare_embedding_model('matscibert')
    text_embeddings = compute_text_embeddings(extracted_tokens, tokenizer_embedding, model_embedding, normalize)

    # export
    cod_df['title_embedding'] = text_embeddings
    cod_df['title_embedding'] = cod_df['title_embedding'].apply(lambda x: x.numpy())

    export_path = raw_csv_path.with_name(f"{raw_csv_path.stem}_with_embedding.pkl")
    cod_df.to_pickle(export_path)
    print(f"saved: {export_path}")

if __name__ == "__main__":
    # Run the full NER + embedding export only when executed as a script.
    main()