import json
import re
import os
import argparse
import sys

def get_args():
    """Define the script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``paths_file``, ``output_dir``,
        ``min_reviews``, the three ``max_<split>_papers_per_conference``
        limits and ``output_paths_file`` (``None`` when not supplied, since
        that option declares no default).
    """
    cli = argparse.ArgumentParser()

    # Where the list of review files comes from and where results go.
    cli.add_argument(
        "--paths_file",
        default="/ai-involvement-in-peer-reviews/PathFiles/all_paths.txt",
        help="File containing all review file paths.",
    )
    cli.add_argument(
        "--output_dir",
        default="/saved_keypoints",
        help="Directory to write filtered reviews.",
    )

    # Selection thresholds.
    cli.add_argument(
        "--min_reviews",
        type=int,
        default=3,
        help="Minimum number of human reviews required for a paper to be retained.",
    )
    cli.add_argument(
        "--max_train_papers_per_conference",
        type=int,
        default=5,
        help="Maximum number of papers from train split to retain per conference.",
    )
    cli.add_argument(
        "--max_test_papers_per_conference",
        type=int,
        default=5,
        help="Maximum number of papers from test split to retain per conference.",
    )
    cli.add_argument(
        "--max_dev_papers_per_conference",
        type=int,
        default=5,
        help="Maximum number of papers from dev split to retain per conference.",
    )

    # Optional listing of the source paths of everything that was retained.
    cli.add_argument(
        "--output_paths_file",
        type=str,
        help="where to save the filepaths of the retained reviews",
    )

    parsed = cli.parse_args()
    return parsed

# Parse the command line once; the rest of the script reads the module-level `args`.
args = get_args()

# Create the destination directory up front so the summary file can be opened in it.
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

# From here on, everything print()ed goes to <output_dir>/summary.txt instead of
# the terminal (buffering=1 requests line buffering for the text file).
summary_path = os.path.join(output_dir, "summary.txt")
sys.stdout = open(summary_path, "w", buffering=1)

# Compiled once at import time; the function below is called for every path.
# Groups: 1 = conference, 2 = split, 3 = level directory, 4 = paper number,
# 5 = reviewer number (digits right after the last underscore).
_REVIEW_PATH_PATTERN = re.compile(
    r".*cleandata/(.*)/(train|test|dev)/.*(level[1-4]|reviews)/(.*)_([0-9]*).*"
)


def parse_paper_details(review_filepath):
    """Split a review file path into its identifying components.

    The path must contain ``cleandata/<conference>/<split>/`` followed
    (possibly after further directories) by a ``level1``..``level4`` or
    ``reviews`` directory and a name of the form ``<paper>_<reviewer>...``.

    Args:
        review_filepath: Path of one review file, as a string.

    Returns:
        Tuple ``(conference, split, level, paper_number, reviewer_number)``,
        all strings. ``split`` is ``"train"``, ``"test"`` or ``"dev"``;
        ``level`` is ``"level1"``..``"level4"`` or ``"reviews"``.
        ``reviewer_number`` is the run of digits after the last underscore and
        is the empty string when no digit follows that underscore.

    Raises:
        ValueError: If the path does not follow the expected layout.
    """
    match = _REVIEW_PATH_PATTERN.search(review_filepath)
    if match is None:
        # Fail with the offending path instead of an AttributeError on None.
        raise ValueError(
            f"Review path does not match the expected layout: {review_filepath!r}"
        )

    # The reviewer group accepts every digit (0-9): with the former [1-9]
    # class, "304_10" was truncated to reviewer "1" and clashed with "304_1".
    conference, split, level, paper_number, reviewer_number = match.groups()
    return conference, split, level, paper_number, reviewer_number

# Index of every review named in the paths file, nested as
#   all_review_data[conference][split][paper_number][level][reviewer_number][author]
#       -> {"review_text": <file contents>, "file_path": <path as listed>}
all_review_data = dict()

with open(args.paths_file,"r") as fin:
    for line in fin:
        line = line.replace("\n","")
        if not line.strip():
            # Blank lines (e.g. a trailing empty line) carry no path to parse.
            continue

        conference, split, level, paper_number, reviewer_number = parse_paper_details(line)

        # Paths listed under the '/Project/Human_or_AI/...' prefix are read from
        # the '/ai-involvement-in-peer-reviews/...' location; the path as listed
        # is what gets recorded in "file_path".
        with open(line.replace('/Project/Human_or_AI/Data_Preprocessing/','/ai-involvement-in-peer-reviews/Data_Preprocessing/'),"r") as freview:
            review_text = freview.read()

        # The author is inferred from markers in the path. "/reviews/" is the
        # key used for reviews that come from the plain reviews directory.
        if "gpt_4o_latest" in line:
            author = "gpt_4o_latest"
        elif "Llama-3.3-70B-Instruct" in line:
            author = "meta-llama-Llama-3.3-70B-Instruct"
        elif "/reviews/" in line:
            author = "/reviews/"
        else:
            # Without this branch `author` would be unbound on the first line or
            # silently keep the previous line's value, mis-attributing the review.
            raise ValueError(f"Cannot determine the author of review file: {line!r}")

        # Create any missing level of the nested index on the way down.
        reviewer_entry = (
            all_review_data
            .setdefault(conference, dict())
            .setdefault(split, dict())
            .setdefault(paper_number, dict())
            .setdefault(level, dict())
            .setdefault(reviewer_number, dict())
        )
        reviewer_entry[author] = {
            "review_text": review_text,
            "file_path": line
        }

# Accumulators for the selection pass:
#   retained_papers   - per conference, bookkeeping of the per-split paper quota
#   retained_reviews  - per conference / split / level, number of reviews kept
#   retained_filepaths - source path of every review that was kept
retained_papers, retained_reviews = {}, {}
retained_filepaths = []

# for conference in all_review_data.keys():
#     for split in all_review_data[conference].keys():
#         for paper_number in all_review_data[conference][split].keys():
#             print(f"Conference: {conference}, Split: {split}, Paper Number: {paper_number}")
#             print(json.dumps(all_review_data[conference][split][paper_number], indent=4))
#             break
#         break
#     break

# Selection pass. Papers are visited in the order they were indexed; a paper is
# kept when it has at least --min_reviews entries under "reviews" and its
# conference still has quota left for the paper's split. Every review of a kept
# paper (all levels, reviewers and authors) is copied to
#   <output_dir>/<conference>/<split>/<paper_number>/<level>/<author>/<reviewer_number>.txt
# where <author> is empty for reviews stored under the "/reviews/" author key.
for conference, splits in all_review_data.items():
    for split, papers in splits.items():
        for paper_number, levels in papers.items():
            # A paper may have only generated levels and no "reviews" entry at
            # all; treat that as zero reviews instead of raising KeyError.
            if len(levels.get("reviews", {})) < args.min_reviews:
                continue

            # Quotas are created the first time a conference has an eligible
            # paper, then checked uniformly. Previously the first eligible
            # paper skipped the check, so a quota of 0 still retained one paper.
            quota = retained_papers.setdefault(conference, {
                "train_remaining": args.max_train_papers_per_conference,
                "test_remaining": args.max_test_papers_per_conference,
                "dev_remaining": args.max_dev_papers_per_conference
            })
            if quota[f"{split}_remaining"] <= 0:
                continue
            quota[f"{split}_remaining"] -= 1

            # if we reach here, we are retaining this paper

            for level, reviewers in levels.items():
                for reviewer_number, authors in reviewers.items():
                    for author, review in authors.items():
                        write_to_dir = os.path.join(output_dir, f"{conference}/{split}/{paper_number}/{level}/{author if author != '/reviews/' else ''}/")
                        os.makedirs(write_to_dir, exist_ok=True)
                        with open(os.path.join(write_to_dir, f"{reviewer_number}.txt"), "w") as fout:
                            fout.write(review['review_text'])

                        retained_filepaths.append(review['file_path'])

                        # Count the review under conference -> split -> level.
                        level_counts = retained_reviews.setdefault(conference, dict()).setdefault(split, dict())
                        level_counts[level] = level_counts.get(level, 0) + 1

# Record the source path of every retained review, one per line.
# --output_paths_file has no default, so it is None when omitted; skip the
# listing in that case instead of crashing in open(None) before the summary
# statistics are printed.
if args.output_paths_file is not None:
    with open(args.output_paths_file, "w") as fout:
        fout.writelines(f"{filepath}\n" for filepath in retained_filepaths)
else:
    # stdout is redirected to the summary file, so report on stderr.
    print("--output_paths_file not given; retained file paths were not saved.", file=sys.stderr)

# Turn the "quota remaining" bookkeeping into "papers retained" counts:
# retained = configured maximum - what is left. Each conference entry is
# replaced by a dict holding only the "train", "test" and "dev" counts.
for conference, quota_left in retained_papers.items():
    # Rebinding the value of an existing key does not resize the dict, so this
    # is safe while iterating over it.
    retained_papers[conference] = {
        "train": args.max_train_papers_per_conference - quota_left["train_remaining"],
        "test": args.max_test_papers_per_conference - quota_left["test_remaining"],
        "dev": args.max_dev_papers_per_conference - quota_left["dev_remaining"],
    }

# Dump both statistics tables as indented JSON, each under its own heading.
for heading, stats in (
    ("RETAINED PAPER STATISTICS:", retained_papers),
    ("RETAINED REVIEW STATISTICS:", retained_reviews),
):
    print(heading)
    print(json.dumps(stats, indent=4))

# print the total number of reviews retained
# print(f"Total reviews retained: {sum(retained_reviews.values())}")

# Sum the retained-review counts over all conferences, per level and split,
# and keep a running grand total across levels.
total_reviews = 0

print("AGGREGATE REVIEW COUNTS:")
for level in ["level1","level2","level3","level4","reviews"]:
    split_totals = {"train": 0, "test": 0, "dev": 0}
    for per_split_counts in retained_reviews.values():
        for split_name in split_totals:
            # A conference may have retained nothing in a given split (or at a
            # given level); count that as 0 for every split, not only for dev.
            split_totals[split_name] += per_split_counts.get(split_name, {}).get(level, 0)

    total_reviews += sum(split_totals.values())

    print(f"{level} reviews retained: {split_totals['train']} train, {split_totals['test']} test, {split_totals['dev']} dev")

print(f"Total reviews retained: {total_reviews}")

