import transformers
import numpy as np
import tqdm
import argparse
import os
import tempfile
import sys
import json
import re
import functools
import yaml
import torch
from openai import OpenAI
from datetime import datetime
import time

def parse_args():
    """Parse the command-line options for reference-review generation.

    Returns:
        argparse.Namespace: the parsed options (``pathfile``, ``output_dir``,
        ``model_name``, ``base_url``, ``num_generations``,
        ``create_batch_file``, ``batch_file_path``,
        ``global_openai_metadata``, ``add_dev``, ``verbose``).

    Raises:
        SystemExit: raised by argparse on invalid arguments, including when
        ``--create_batch_file`` is given without ``--batch_file_path``.
    """
    parser = argparse.ArgumentParser(description='Generate reference peer reviews')

    parser.add_argument('--pathfile', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--model_name', type=str, required=True, help='Model name to use for generation')
    parser.add_argument('--base_url', type=str, required=False, default=None, help='Base URL for the model API')
    parser.add_argument('--num_generations', type=int, default=1, help='Number of generations per prompt')
    parser.add_argument('--create_batch_file', action='store_true', help='whether to create batch file for generation instead of individual chat completion calls')
    parser.add_argument('--batch_file_path', type=str, default=None, help='Path to save the batch file if create_batch_file is set')
    parser.add_argument('--global_openai_metadata', type=str, default=None, help='Path to global openai metadata dump file')
    parser.add_argument('--add_dev', action='store_true', help='Whether to generate references for the dev set as well')
    parser.add_argument('--verbose', action='store_true', help='Whether to print verbose logs')

    parsed = parser.parse_args()

    # Batch mode writes every request to --batch_file_path; without it the
    # script would only fail later with an opaque TypeError on a None path.
    if parsed.create_batch_file and parsed.batch_file_path is None:
        parser.error('--batch_file_path is required when --create_batch_file is set')

    return parsed

# Command-line options are parsed once at import time; the rest of the script
# reads them from this module-level namespace.
args = parse_args()

api_key = os.getenv("OPENAI_API_KEY")

if args.create_batch_file:
    # Batch mode appends one request per line, so start from a clean file.
    if os.path.exists(args.batch_file_path):
        os.remove(args.batch_file_path)
else:
    # Direct mode talks to the API, optionally through a custom endpoint.
    client_args = {"api_key": api_key}
    if args.base_url is not None:
        client_args["base_url"] = args.base_url
    client = OpenAI(**client_args)

pathfile = args.pathfile

pattern = re.compile(r"<extra_id_\d+>")

# Review guidelines, later indexed by conference name (year suffix removed).
with open("/ai-involvement-in-peer-reviews/AI_generation/guidelines.yaml", "r") as guidelines_file:
    review_guidelines = yaml.safe_load(guidelines_file)

# Prompt templates, keyed by prompt-variation name.
with open("/ai-involvement-in-peer-reviews/reference_generation/prompt_variations.yaml", "r") as prompts_file:
    review_generation_prompts = yaml.safe_load(prompts_file)

def get_openai_response(model_name, base_url, api_key, prompt):
    """Send one user prompt to the chat-completions API and return the reply.

    The request goes through the module-level ``client``; ``base_url`` and
    ``api_key`` are accepted but not read by this function.

    Args:
        model_name: model identifier passed as ``model`` in the request.
        base_url: unused here.
        api_key: unused here.
        prompt: text sent as the user message.

    Returns:
        The ``content`` of the first choice's message.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model=model_name,
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def remove_linenumbers(text):
    """Strip a leading run of digits (with surrounding whitespace) from lines.

    At every line start, optional whitespace followed by digits, trailing
    whitespace and an optional newline is deleted. Because the whitespace
    classes also match newlines, a line holding only a number disappears
    together with its line break.
    """
    leading_number = re.compile(r'^\s*\d+\s*\n?', flags=re.MULTILINE)
    return leading_number.sub('', text)

def format_paper_contents(paper_contents):
    """Flatten a parsed-paper dict into one plain-text string.

    Reads ``paper_contents["metadata"]`` and emits, in order: the title, the
    abstract (under an ``ABSTRACT`` header), every section that has a heading,
    and the references (under a ``References:`` header). Each text fragment is
    passed through ``remove_linenumbers``.

    Missing or ``None`` pieces are skipped instead of raising: an absent
    ``metadata`` entry yields an empty string, a section without text keeps
    only its heading, and non-string reference values are ignored.

    Args:
        paper_contents: dict with an optional ``"metadata"`` dict holding the
            optional keys ``"title"``, ``"abstractText"``, ``"sections"`` (list
            of dicts with ``"heading"`` and ``"text"``) and ``"references"``
            (list of dicts whose values are strings or lists of strings).

    Returns:
        str: the concatenated text with leading/trailing whitespace stripped.
    """
    metadata = paper_contents.get("metadata") or {}
    # Collect fragments and join once; repeated ``+=`` on a long paper is
    # quadratic in the worst case.
    parts = []

    title = metadata.get("title")
    if title is not None:
        parts.append(f"{remove_linenumbers(title.strip())}\n\n")

    abstract = metadata.get("abstractText")
    if abstract is not None:
        parts.append(f"ABSTRACT\n{remove_linenumbers(abstract.strip())}\n\n")

    for section in metadata.get("sections") or []:
        section_heading = section.get("heading")
        if section_heading is None:
            continue
        # A section may carry a heading but no body text.
        section_text = remove_linenumbers(section.get("text") or "")
        parts.append(f"{section_heading}\n{section_text}\n\n")

    references = metadata.get("references")
    if references is not None:
        parts.append("References:\n")
        for reference in references:
            for ref_value in reference.values():
                if isinstance(ref_value, str):
                    parts.append(f"{remove_linenumbers(ref_value)}, ")
                elif isinstance(ref_value, list):
                    # Only string items can be cleaned; anything else is skipped.
                    parts.extend(
                        f"{remove_linenumbers(item)}, "
                        for item in ref_value
                        if isinstance(item, str)
                    )
            parts.append("\n\n")

    return "".join(parts).strip()

def extract_paper_contents_from_filepath(paper_filepath):
    """Load a parsed-paper JSON file and return its flattened text.

    Args:
        paper_filepath: path to a JSON file accepted by
            ``format_paper_contents``.

    Returns:
        The string produced by ``format_paper_contents`` for that file.
    """
    with open(paper_filepath, "r") as paper_file:
        parsed_paper = json.load(paper_file)
    formatted_text = format_paper_contents(parsed_paper)
    return formatted_text

def extract_paper_fp_from_review_fp(review_filepath):
    """Derive the parsed-paper JSON path and identifiers from a review path.

    The review path must look like
    ``.../cleandata/<conference>/<split>/...<level>/<paper>_<reviewer>.txt``
    where ``<split>`` is ``train``, ``test`` or ``dev`` and ``<level>`` is
    ``level1``..``level4`` or ``reviews``.

    Args:
        review_filepath: path of a review text file.

    Returns:
        tuple: ``(paper_json_path, conference, split, level, paper_number,
        reviewer_number)``, all strings.

    Raises:
        ValueError: if ``review_filepath`` does not follow the layout above.
    """
    # ``\d*`` accepts reviewer numbers that contain a zero (e.g. ``_10``), and
    # the extension dot is escaped so only a literal ``.txt`` matches.
    review_path_pattern = r".*cleandata/(.*)/(train|test|dev)/.*(level[1-4]|reviews)/(.*)_(\d*)\.txt"
    match = re.search(review_path_pattern, review_filepath)
    if match is None:
        raise ValueError(
            f"Review path does not match the expected cleandata layout: {review_filepath!r}"
        )

    conference, split, level, paper_number, reviewer_number = match.groups()

    return (
        f"/ai-involvement-in-peer-reviews/data/{conference}/{split}/parsed_pdfs/{paper_number}.pdf.json",
        conference,
        split,
        level,
        paper_number,
        reviewer_number
    )

def shorthand(model_name):
    """Return a short display name for a known local model snapshot path.

    Any ``model_name`` that is not one of the listed snapshot paths is
    returned unchanged.
    """
    known_snapshots = {
        "/.cache/huggingface/hub/models--Qwen--Qwen3-30B-A3B-Thinking-2507-FP8/snapshots/60d80c83c53c3b611c642dbb8c942b3f90c5948a": "Qwen3-30B-A3B-Thinking-2507-FP8",
        "/.cache/huggingface/hub/models--google--gemma-3-27b-it/snapshots/005ad3404e59d6023443cb575daa05336842228a": "gemma-3-27b-it",
    }
    if model_name in known_snapshots:
        return known_snapshots[model_name]
    return model_name

# Read the list of review files to process, one path per line. Blank lines
# are dropped so they are not mistaken for (empty) file paths.
with open(pathfile, "r") as file:
    all_paths = [line.strip() for line in file if line.strip()]

# Year suffixes removed, in this order, from the conference identifier to get
# the bare conference name used for guideline lookup and prompt filling.
remove_list = ['_2016', '_2017', '_2013', '-2017', '/2013', '/2014', '/2015', '/2016', '/2017']

papers = []

for each_review_path in tqdm.tqdm(all_paths):  # one review file per line
    paper_filepath, conference, split, level, paper_number, reviewer_number = extract_paper_fp_from_review_fp(each_review_path)

    # Filter on the split before touching the disk, so skipped dev entries
    # cost no file reads.
    if split == "dev" and not args.add_dev:
        continue

    # The path file may still use the old project root; map it to the current one.
    resolved_filepath = each_review_path.replace("/Project/Human_or_AI/Data_Preprocessing/", "/ai-involvement-in-peer-reviews/Data_Preprocessing/")
    with open(resolved_filepath, "r") as file:
        review_text = file.read()

    paper_text = extract_paper_contents_from_filepath(paper_filepath)

    conference_wo_yr = conference
    for year_suffix in remove_list:
        conference_wo_yr = conference_wo_yr.replace(year_suffix, '')

    papers.append({
        "filename": each_review_path,
        "paper_content": paper_text,
        "content": review_text,  # review content
        "conference_wo_yr": conference_wo_yr,
        "conference": conference,
        "split": split,
        "level": level,
        "paper_number": paper_number,
        "reviewer_number": reviewer_number,
    })

save_after_examples = 1  # number of papers handled per outer-loop step
reviews_generated = 0
batch_entries = 0

# Output paths already written or queued during this run, so the same
# destination is never produced twice.
included_destination = set()

print(len(papers), "papers to process")

for ctr in tqdm.tqdm(range(0, len(papers), save_after_examples)):

    # Slicing clamps at the end of the list, so no explicit upper bound is needed.
    current_paper_batch = papers[ctr:ctr + save_after_examples]

    for instance in current_paper_batch:

        for prompt_key, prompt_template in review_generation_prompts.items():

            # The filled prompt depends only on the paper and the template,
            # not on the generation index, so build it once per pair.
            prompt_filled = prompt_template.replace("{PAPER_CONTENT}", instance['paper_content'])
            prompt_filled = prompt_filled.replace("{CONFERENCE}", instance['conference_wo_yr'])
            prompt_filled = prompt_filled.replace("{GUIDELINES}", review_guidelines[instance['conference_wo_yr']])

            for gen_idx in range(args.num_generations):

                out_file = os.path.join(
                    args.output_dir,
                    f"{instance['conference']}/{instance['split']}/{instance['paper_number']}/references/{shorthand(args.model_name)}/{prompt_key}_{gen_idx}_reference_review.txt"
                )

                os.makedirs(os.path.dirname(out_file), exist_ok=True)

                # Skip destinations that already hold a non-empty review or
                # were handled earlier in this run.
                if (os.path.exists(out_file) and os.path.getsize(out_file) > 0) or out_file in included_destination:
                    if args.verbose:
                        print(f"File {out_file} already exists and is not empty. Skipping.")
                    continue

                if args.create_batch_file:
                    # Batch mode: queue the request instead of calling the API.
                    batch_file_entry = {
                        "custom_id": f"{out_file}::{datetime.now().strftime('%Y%m%d%H%M%S')}",
                        "method": "POST",
                        "url": "/v1/chat/completions",
                        "body": {
                            "model": args.model_name,
                            "messages": [
                                {"role": "system", "content": "You are a helpful assistant."},
                                {"role": "user", "content": prompt_filled},
                            ],
                        }
                    }

                    with open(args.batch_file_path, "a") as batchout:
                        batchout.write(json.dumps(batch_file_entry) + "\n")

                    batch_entries += 1
                    included_destination.add(out_file)
                    continue

                # Direct mode: call the API and store the generated review.
                if args.verbose:
                    print(f"making api call with {len(prompt_filled.split())} input words")

                start_time = time.time()

                model_response = get_openai_response(
                    model_name=args.model_name,  # or the name you used to serve
                    base_url=args.base_url,
                    api_key=api_key,
                    prompt=prompt_filled
                )

                end_time = time.time()
                print("Time taken for generation: ", end_time - start_time)

                if model_response is None:
                    # Writing None would raise and leave an empty file behind;
                    # report it and leave the destination for a later run.
                    print(f"No content returned for {out_file}. Skipping.", file=sys.stderr)
                    continue

                reviews_generated += 1

                with open(out_file, "w") as fout:
                    fout.write(model_response)

                included_destination.add(out_file)

                generation_metadata = {
                    "model_name": args.model_name,
                    "prompt_template": prompt_template,
                    "file_written_to": out_file,
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                }

                # --global_openai_metadata is optional (defaults to None);
                # only log when a destination was actually given.
                if args.global_openai_metadata is not None:
                    with open(args.global_openai_metadata, "a") as metaout:
                        metaout.write(json.dumps(generation_metadata) + "\n")


if not args.create_batch_file:
    print(reviews_generated, "reviews generated, phew!")
else:
    print(f"Batch file created at {args.batch_file_path} with {batch_entries} entries.")
