import transformers
import numpy as np
import tqdm
import argparse
import os
import tempfile
import sys
import json
import re
import functools
import yaml
import torch
from openai import OpenAI

'''
NOTE: if you want to use GPT-2 as the scoring model, you need to truncate the input before passing it to GPT-2 for computing log likelihoods; GPT-2 has a maximum input length of 1024 tokens.
'''

def parse_args(argv=None):
    """Parse the command-line options for an embedding-extraction run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default (``None``) keeps the normal CLI
            behaviour; passing a list lets tests or notebooks drive the
            parser explicitly.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Anchor')

    # Debugging options
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Enable debug mode')
    parser.add_argument('--pilot', action='store_true', default=False,
                        help='Enable pilot mode')

    # Context parameters (kept only for legacy reasons; only affects the
    # results directory name in this script)
    parser.add_argument('--context', action='store_true', default=False,
                        help='context = True => the paper contents will be prepended to the original/perturbed texts, and then passed through the scoring model for computing log likelihoods, however, of course, the log likelihoods will be computed only on the review tokens')

    # Embedding model
    parser.add_argument('--embedding_model', type=str, default="text-embedding-3-small",
                        help='Embedding model to use')
    parser.add_argument('--cache_dir', type=str, default='/data/assets/hub',
                        help='Cache directory for models')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Number of reviews embedded per model call / API request')

    # Dataset and prompt options
    parser.add_argument('--reviewset', type=str, default="original",
                        help='Review dataset to use')
    parser.add_argument('--prompt_level', type=int, default=1,
                        help='in case context = True and you are using an instruct model as scoring model, which level of prompt would you like to prepend to the candidate review along with the paper, can be 1 or 2, because creating level 3, 4 requires additional information which we dont have at test time')

    # Resumption
    parser.add_argument('--continue_from_last', action='store_true', default=False,
                        help='Continue from last checkpoint')

    return parser.parse_args(argv)

args = parse_args()

print(args.cache_dir)

# Load the embedding backend selected by --embedding_model. Each branch defines
# the globals the embedding loop relies on:
#   OpenAI models      -> `client`
#   specter2           -> `tokenizer`, `model`
#   linq-embed-mistral -> `tokenizer`, `model`, `last_token_pool`, `F`, `max_length`
if args.embedding_model in ("text-embedding-3-small", "text-embedding-3-large"):
    openai_key = os.getenv("OPENAI_API_KEY")
    client = OpenAI(api_key=openai_key)

elif args.embedding_model == "specter2":
    from transformers import AutoTokenizer
    from adapters import AutoAdapterModel

    tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base', cache_dir=args.cache_dir)
    model = AutoAdapterModel.from_pretrained('allenai/specter2_base', cache_dir=args.cache_dir)
    # Attach and activate the SPECTER2 adapter on top of the base encoder.
    model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)
    model = model.to("cuda:0")
    # The adapter modules are added after from_pretrained(); force inference
    # mode on the whole model so any dropout is disabled.
    model.eval()

elif args.embedding_model == "linq-embed-mistral":
    import torch.nn.functional as F
    from torch import Tensor
    from transformers import AutoTokenizer, AutoModel

    def last_token_pool(last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        """Pick the hidden state of the last real (non-padding) token per row.

        If every row's final mask position is 1, the batch is treated as
        left-padded (or unpadded) and the final position is taken; otherwise
        each row is indexed at (number of real tokens - 1), i.e. right padding.
        """
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    bnb_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,                # enable 4-bit quantization
        bnb_4bit_use_double_quant=True,   # nested quantization for memory saving
        bnb_4bit_quant_type="nf4",        # NormalFloat4
        bnb_4bit_compute_dtype="bfloat16" # computation dtype (fp16 also works if bf16 not available)
    )

    tokenizer = AutoTokenizer.from_pretrained(
        'Linq-AI-Research/Linq-Embed-Mistral',
        cache_dir=args.cache_dir
    )
    # bitsandbytes-quantized weights must be placed on the GPU at load time;
    # calling `.to()` on a 4-bit model raises with older bitsandbytes releases.
    # {"": 0} puts the whole model on cuda:0.
    model = AutoModel.from_pretrained(
        'Linq-AI-Research/Linq-Embed-Mistral',
        quantization_config=bnb_config,
        device_map={"": 0},
        cache_dir=args.cache_dir,
    )
    model.eval()
    max_length = 4096  # tokenizer truncation length for this model
else:
    raise ValueError("Embedding model not implemented")

# Unpack the parsed CLI options into the module-level names used below.
embedding_model = args.embedding_model

DEBUG = args.debug
PILOT = args.pilot

context = args.context

reviewset = args.reviewset
# Prompt level (1 or 2) to prepend alongside the paper when context is enabled.
prompt_level = args.prompt_level
continue_from_last = args.continue_from_last

# One results directory per configuration, e.g. "results/specter2_humanized_pilot/".
# Each optional tag appears only when its option differs from the default.
results_path_prefix = "results/" + embedding_model + "".join(
    tag
    for tag, enabled in (
        (f"_{reviewset}", reviewset != 'original'),
        ('_wcontext', context),
        ('_pilot', PILOT),
    )
    if enabled
) + "/"

print(f"Writing results to {results_path_prefix}")

os.makedirs(results_path_prefix, exist_ok=True)

# From here on, everything printed goes to a line-buffered log file inside the
# results directory (appended to when resuming, truncated otherwise).
log_mode = "a" if continue_from_last else "w"
sys.stdout = open(f"{results_path_prefix}/loggings_org.txt", log_mode, buffering=1)

def _load_yaml_file(path):
    """Parse a YAML file with PyYAML's safe loader (plain data types only)."""
    with open(path, "r") as handle:
        return yaml.safe_load(handle)


# Review guidelines and the level-1 / level-2 review-generation prompts.
# NOTE(review): neither object is referenced later in this script; presumably
# kept for the legacy --context / --prompt_level options — confirm before removing.
review_guidelines = _load_yaml_file("/ai-involvement-in-peer-reviews/AI_generation/guidelines.yaml")
review_generation_prompts = _load_yaml_file("/ai-involvement-in-peer-reviews/AI_generation/prompts.yaml")

def extract_paper_contents_from_filepath(paper_filepath):
    """Stitch a paper's title, introduction and conclusion into plain text.

    Reads a parsed-PDF JSON file and collects, in order:
      * ``"Title: <title>\\n"`` when ``metadata.title`` is present, and
      * ``"<heading>\\n<text>"`` for every section whose heading contains
        "introduction" or "conclusion" (case-insensitive).
    The pieces are joined with blank lines.

    Missing or ``None`` ``metadata`` / ``sections`` / ``heading`` / ``text``
    fields are tolerated instead of raising.

    Args:
        paper_filepath: Path to the parsed-PDF JSON file.

    Returns:
        The stitched text, or ``""`` when nothing matching is present.
    """
    with open(paper_filepath, "r", encoding="utf-8") as fin:
        file_content = json.load(fin)

    # Guard once here; the original only guarded "metadata" for the title.
    metadata = file_content.get("metadata") or {}
    stitched_content = []

    title = metadata.get("title")
    if title is not None:
        stitched_content.append(f"Title: {title.strip()}\n")

    # Only the introduction and conclusion are used for now.
    # TODO: consider other sections too (e.g. the abstract?).
    for section in metadata.get("sections") or []:
        heading = section.get("heading")
        if heading is None:
            continue
        lowered = heading.strip().lower()
        if "introduction" in lowered or "conclusion" in lowered:
            text = section.get("text") or ""
            stitched_content.append(f"{heading.strip()}\n{text.strip()}")

    return "\n\n".join(stitched_content)

def extract_paper_fp_from_review_fp(review_filepath):
    """Map a review file path to the parsed-PDF JSON of the reviewed paper.

    The review path is expected to look like
    ``.../cleandata/<conference>/<train|test|dev>/.../<level1-4|reviews>/<paper>_<suffix>``;
    the paper number is everything before the last underscore of the file name.

    Args:
        review_filepath: Path of a review text file.

    Returns:
        Tuple ``(paper_filepath, conference, split, paper_number)``.

    Raises:
        ValueError: if ``review_filepath`` does not follow the layout above.
    """
    pattern = r".*cleandata/(.*)/(train|test|dev)/.*(level[1-4]|reviews)/(.*)_.*"
    match = re.search(pattern, review_filepath)
    if match is None:
        # Fail with a readable message instead of AttributeError on None.
        raise ValueError(f"Unrecognised review path layout: {review_filepath!r}")
    conference, split, _, paper_number = match.groups()

    paper_filepath = f"/ai-involvement-in-peer-reviews/data/{conference}/{split}/parsed_pdfs/{paper_number}.pdf.json"
    return paper_filepath, conference, split, paper_number


# Resolve the input reviews and the output files for --reviewset:
#   "new2024" / "new2025" -> recent reviews, one JSON record per JSONL line
#   "humanized"           -> review paths listed in the humanized path file
#   anything else         -> review paths listed in the original path file
is_humanized = reviewset == "humanized"
is_new = reviewset in ("new2024", "new2025")

if is_new:
    year = reviewset[3:]  # "new2024" -> "2024"
    review_json = f"/ai-involvement-in-peer-reviews/RecentReviews/recentreviews/reviews_{year}.jsonl"
    with open(review_json, "r", encoding="utf-8") as file:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        all_paths = [json.loads(line) for line in file if line.strip()]
    legacy_result_file = results_path_prefix + f"/legacy_result_anchor_newcode_{reviewset}.json"
    result_file = results_path_prefix + f"/result_anchor_newcode_{reviewset}.json"
else:
    if is_humanized:
        pathfile = "/ai-involvement-in-peer-reviews/Extras/all_paths_humanized.txt"
        # Keep results inside the per-run results directory (its name already
        # encodes the embedding model and reviewset). A bare relative name made
        # runs with different embedding models overwrite one shared file in
        # the current working directory.
        result_file = results_path_prefix + "/result_anchor_humanized.json"
    else:
        pathfile = "/ai-involvement-in-peer-reviews/PathFiles/all_paths.txt"
        legacy_result_file = results_path_prefix + "/legacy_result_anchor_newcode.json"
        result_file = results_path_prefix + "/result_anchor_newcode.json"

    with open(pathfile, "r", encoding="utf-8") as file:
        all_paths = [path.strip() for path in file]


papers = []

# Narrow `all_paths` down to the reviews that will actually be embedded.
if reviewset == "original":
    # Pilot (trial) run: keep only the first `each_num_ex` paths of each kind.
    # Counter slots 1-4 track "level1".."level4" paths (presumably AI-generated
    # reviews) and slot 5 tracks "reviews" paths; slot 0 is unused.
    # NOTE(review): these are plain substring tests, so a path containing both a
    # "levelN" marker and "reviews" would be appended twice — confirm the path
    # layout rules that out.
    if PILOT:
        each_num_ex = 4
        level_counter = [0] * 6
        buffer = []

        markers = {slot: f"level{slot}" for slot in range(1, 5)}
        markers[5] = "reviews"

        for each_review_path in all_paths:
            for slot, marker in markers.items():
                if marker in each_review_path and level_counter[slot] < each_num_ex:
                    buffer.append(each_review_path)
                    level_counter[slot] += 1

        all_paths = buffer

elif is_new:
    # Keep only the recent reviews that also have a Pangram score available.
    with open(f"/ai-involvement-in-peer-reviews/DetectorEval/Pangram/RecentReviewAnalysis/results_iclr{year}_Pangram.json", "r") as fin:
        pangram_scores_avl = json.load(fin)
    scored_review_files = pangram_scores_avl.keys()
    all_paths = [
        record
        for record in all_paths
        if f"{record['review_id']}.json" in scored_review_files
    ]

else:
    print("case not handled yet")


# Parse each paper's JSON at most once. Reviews whose file names share a paper
# number resolve to the same paper file, so without this the same file would
# be re-read and re-parsed for every such review.
paper_contents_cache = {}

# Conference-name suffixes (years / editions) to strip so only the bare
# conference name remains; applied in this order.
conference_suffixes = ('_2016', '_2017', '_2013', '-2017', '/2013', '/2014', '/2015', '/2016', '/2017')

for each_review_path in tqdm.tqdm(all_paths):  # a parsed JSONL record (not a path) for new reviews
    if is_new:
        # Flatten the structured review fields into one newline-separated text.
        review_text = "\n".join([
            each_review_path["summary"],
            each_review_path["strengths"],
            each_review_path["weaknesses"],
            each_review_path["questions"],
        ])
        review_text = review_text.replace("\n\n", "\n")
        unique_id = each_review_path["review_id"]
        papers.append({
            "filename": each_review_path['review_id'],
            # TODO: paper contents for recent reviews are not available here.
            "content": review_text
        })
    else:
        # Paths in the path file point at the old project root; remap them.
        resolved_filepath = each_review_path.replace("/Project/Human_or_AI/Data_Preprocessing/", "/ai-involvement-in-peer-reviews/Data_Preprocessing/")
        with open(resolved_filepath, "r") as file:
            review_text = file.read()

        paper_filepath, conference, _, _ = extract_paper_fp_from_review_fp(each_review_path)
        if paper_filepath not in paper_contents_cache:
            paper_contents_cache[paper_filepath] = extract_paper_contents_from_filepath(paper_filepath)
        paper_contents = paper_contents_cache[paper_filepath]

        for suffix in conference_suffixes:
            conference = conference.replace(suffix, '')

        unique_id = '-'.join(each_review_path.split('/')[-2:]).replace('.txt', '')
        papers.append({
            "filename": each_review_path,
            'paper_content': paper_contents,
            "content": review_text,
            "conference": conference
        })

# The batch size doubles as the stride of the embedding loop.
save_after_examples = args.batch_size
ctr = 0

results_dict = {}
all_embeddings = []
completed_till = 0

# Resume: reload the saved metadata JSON and the embedding matrix (.npy next to
# it) and skip the reviews both of them already cover.
embeddings_file = result_file.replace(".json", ".npy")
if continue_from_last and os.path.exists(result_file) and os.path.exists(embeddings_file):
    with open(result_file, "r") as fin:
        results_dict = json.load(fin)
    saved_embeddings = np.load(embeddings_file).tolist()
    completed_till = min(len(results_dict), len(saved_embeddings))
    all_embeddings = saved_embeddings[:completed_till]
    print(f"Continuing from last checkpoint, {completed_till} embeddings")

print(f"TOTAL REVIEWS TO BE EVALUATED = {len(papers)}")


def _embed_batch(texts):
    """Return one embedding (list of floats) per text, using the model chosen by --embedding_model."""
    if args.embedding_model in ("text-embedding-3-small", "text-embedding-3-large"):
        response = client.embeddings.create(
            model=embedding_model,
            input=texts
        )
        return [item.embedding for item in response.data]

    # Local models: pure inference, so skip autograd bookkeeping (saves GPU memory).
    with torch.inference_mode():
        if args.embedding_model == "specter2":
            inputs = tokenizer(
                texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                return_token_type_ids=False,
                max_length=512
            ).to("cuda:0")
            print(inputs['input_ids'].shape)
            output = model(**inputs)
            # First token's hidden state (the [CLS] position for BERT-style tokenizers).
            embeddings = output.last_hidden_state[:, 0, :]
        elif args.embedding_model == "linq-embed-mistral":
            batch_dict = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors="pt").to("cuda:0")
            print(batch_dict['input_ids'].shape)
            outputs = model(**batch_dict)
            embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            # L2-normalise each embedding.
            embeddings = F.normalize(embeddings, p=2, dim=1)
        else:
            raise ValueError("Embedding model not implemented")

        # .float() is an exact upcast; it keeps .numpy() working if the model emits bf16.
        return embeddings.float().cpu().numpy().tolist()


def _save_results():
    """Write the per-review metadata JSON and the embedding matrix (.npy) to disk."""
    with open(result_file, "w") as fout:
        json.dump(results_dict, fout, indent=4)
    # np.save on a large list of lists is slow (the whole list is converted on
    # every call), so only checkpoint occasionally.
    np.save(result_file.replace(".json", ".npy"), all_embeddings)


for batch_idx, ctr in enumerate(tqdm.tqdm(range(completed_till, len(papers), save_after_examples))):

    current_paper_batch = papers[ctr:ctr + save_after_examples]

    review_filenames = [x['filename'] for x in current_paper_batch]
    review_contents = [x['content'] for x in current_paper_batch]

    review_embeddings = _embed_batch(review_contents)

    # embedding_id is the row of this review's vector in the saved .npy matrix;
    # rows are appended in order, so it is the current length plus the offset.
    first_row = len(all_embeddings)
    for i, (review_filename, _) in enumerate(zip(review_filenames, review_embeddings)):
        results_dict[review_filename] = {
            "filename": review_filename,
            "content": current_paper_batch[i]['content'],
            "embedding_id": first_row + i
        }

    all_embeddings.extend(review_embeddings)

    # Checkpoint after the first batch of this run and every 1000 batches after
    # that. Counting batches (instead of testing `ctr` itself) keeps checkpoints
    # firing when resuming from an offset that is not a multiple of
    # 1000 * batch_size.
    if batch_idx % 1000 == 0:
        print(f"Saving intermediate results at ctr = {ctr}")
        _save_results()

_save_results()

print("Embedding extraction completed.")