import os
import json
import re
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import openai as OpenAI
import argparse
import matplotlib.pyplot as plt
import difflib

from tqdm import tqdm

parser = argparse.ArgumentParser()
# Required: the list is opened unconditionally below, so a missing value would
# otherwise surface as an opaque TypeError from open(None).
parser.add_argument("--all_files_path", type=str, required=True, help="Text file listing the review file paths to score, one per line.")
parser.add_argument("--output_dir", type=str, default="results/new_word_count/", help="Directory for the distribution plot and new-word lists.")
parser.add_argument("--n", type=int, default=1, help="n-gram size for new n-gram count.")
parser.add_argument("--metric", type=str, choices=["jaccard", "difference", "lcs"], default="jaccard", help="Metric to compute new word count.")
parser.add_argument("--vertical", action="store_true", help="Whether to stack subplots vertically.")
# Optional TTBT mode: when both options below are given, each level-4 review is
# scored using its TTBT counterpart
# (<source_ttbt_dir>/<conference>/<split>/<paper>/level4/<ttbt_id>/<reviewer>.txt)
# instead of its own text.
parser.add_argument("--source_ttbt_dir", type=str, default=None, help="Directory containing TTBT reviews to replace level 4 examples with.")
parser.add_argument("--ttbt_id", type=str, default=None, help="TTBT id to replace level 4 examples with.")

args = parser.parse_args()

# The two TTBT options only make sense together.
if (args.source_ttbt_dir is None) != (args.ttbt_id is None):
    raise ValueError("Both --source_ttbt_dir and --ttbt_id should be provided together or both should be None.")

is_ttbt_replacement = (args.ttbt_id is not None)


def extract_paper_fp_from_review_fp(review_filepath):
    """Parse a review file path into its identifying components.

    Expected layout (any prefix may precede ``cleandata/``)::

        .../cleandata/<conference>/<split>/.../<level>/<paper>_<reviewer>.txt

    where ``<split>`` is ``train``, ``test`` or ``dev`` and ``<level>`` is
    ``level1``..``level4`` or ``reviews``.

    Args:
        review_filepath: Path of a review text file.

    Returns:
        tuple of str: ``(conference, split, level, paper_number, reviewer_number)``.

    Raises:
        ValueError: If the path does not follow the expected layout.
    """
    # Reviewer numbers may contain the digit 0 (e.g. ``_0.txt`` or ``_10.txt``;
    # human references are probed for indices 0-19), so accept any digits.
    # The dot before ``txt`` is escaped so it only matches a literal '.'.
    pattern = r".*cleandata/(.*)/(train|test|dev)/.*(level[1-4]|reviews)/(.*)_([0-9]+)\.txt"
    match = re.search(pattern, review_filepath)
    if match is None:
        raise ValueError(f"Unrecognised review file path: {review_filepath!r}")

    conference, split, level, paper_number, reviewer_number = match.groups()
    return conference, split, level, paper_number, reviewer_number

def get_source_filenames(conference, split, level, paper_number, reviewer_number):
    """Return the human review path(s) a given review should be compared against.

    For ``level1``/``level2`` reviews, every human review of the same paper
    (reviewer indices 0-19) that exists and is non-empty in the local
    ``/ai-involvement-in-peer-reviews`` mirror is returned. For any other
    level, only the human review with the same reviewer number is returned,
    without checking that it exists. Returned paths always carry the
    ``/Project/Human_or_AI`` prefix.

    Raises:
        ValueError: If a level1/level2 review has no non-empty human reference.
    """
    if level not in ("level1", "level2"):
        return [f"/Project/Human_or_AI/Data_Preprocessing/cleandata/{conference}/{split}/reviews/{paper_number}_{reviewer_number}.txt"]

    def _exists_and_nonempty_locally(path):
        # Existence is checked on the local mirror, but the original prefix is kept.
        mirror_path = path.replace('/Project/Human_or_AI', '/ai-involvement-in-peer-reviews')
        return os.path.exists(mirror_path) and os.path.getsize(mirror_path) > 0

    candidates = (
        f"/Project/Human_or_AI/Data_Preprocessing/cleandata/{conference}/{split}/reviews/{paper_number}_{ref_idx}.txt"
        for ref_idx in range(20)
    )
    found = [path for path in candidates if _exists_and_nonempty_locally(path)]
    if not found:
        raise ValueError(f"No human references found for {conference} {split} {paper_number} {reviewer_number}")
    return found

# Create the directory that the plot and word lists are actually written to
# (--output_dir), not a hard-coded default.
out_dir = args.output_dir
os.makedirs(out_dir, exist_ok=True)

all_files_path = args.all_files_path

# One review path per line; blank lines (e.g. trailing newlines) are skipped
# instead of turning into an open("") failure.
with open(all_files_path, "r") as fin:
    all_review_filepaths = [line.strip() for line in fin if line.strip()]

# Maps each listed review path (as written in the list) to its stripped text.
# Paths are expected to carry the /Project/Human_or_AI prefix, which is swapped
# for the local mirror when reading.
results_data = dict()
for review_filepath in all_review_filepaths:
    local_review_path = review_filepath.replace('/Project/Human_or_AI', '/ai-involvement-in-peer-reviews')
    with open(local_review_path, "r") as fin:
        results_data[review_filepath] = fin.read().strip()

def phrase_set(text, n):
    """Return the set of distinct lower-cased word n-grams found in ``text``.

    A "word" is a maximal run of ASCII letters, digits and apostrophes. Each
    n-gram is ``n`` consecutive words joined by single spaces.
    """
    tokens = re.findall(r"[A-Za-z0-9']+", text.lower())
    return {" ".join(tokens[start:start + n]) for start in range(len(tokens) - n + 1)}

def get_longest_common_sublist(list_a: list, list_b: list):
    """Find the longest contiguous run of elements shared by two lists.

    Classic longest-common-substring DP, keeping only the previous row.

    Returns:
        ``(length, sublist)``, where ``sublist`` is the run taken from
        ``list_a``. On ties, the run ending earliest in ``list_a`` wins.
        Returns ``(0, [])`` when nothing is shared.
    """
    width = len(list_b) + 1
    best_len = 0
    best_end = 0  # exclusive end index of the best run inside list_a
    prev_row = [0] * width

    for row_idx, elem_a in enumerate(list_a, start=1):
        cur_row = [0] * width
        for col_idx, elem_b in enumerate(list_b, start=1):
            if elem_a != elem_b:
                continue
            run = prev_row[col_idx - 1] + 1
            cur_row[col_idx] = run
            # Strict '>' keeps the first run found with the maximal length.
            if run > best_len:
                best_len, best_end = run, row_idx
        prev_row = cur_row

    return best_len, list_a[best_end - best_len:best_end]


def compute_word_metric(candidate_text, reference_text, metric="jaccard", n=None):
    """Compare a candidate text with one reference text.

    Args:
        candidate_text: Text being scored.
        reference_text: Text it is compared against.
        metric: One of
            - ``"jaccard"``: |C & R| / |C | R| over the sets of word n-grams
              (0 when both sets are empty);
            - ``"difference"``: fraction of the candidate's distinct n-grams
              that do not occur in the reference (0.0 for an empty candidate);
            - ``"lcs"``: length, in words, of the longest contiguous word run
              shared by both texts.
        n: n-gram size for ``"jaccard"``/``"difference"``. ``None`` means
            unigrams. Ignored for ``"lcs"``.

    Returns:
        ``(value, extra)``. ``extra`` is ``None`` for jaccard, the set of new
        candidate n-grams for difference, and the shared word run (a list)
        for lcs.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric == "lcs":
        list_cand = re.findall(r"[A-Za-z0-9']+", candidate_text.lower())
        list_ref = re.findall(r"[A-Za-z0-9']+", reference_text.lower())
        lcs_length, longest_sublist = get_longest_common_sublist(list_cand, list_ref)
        return lcs_length, longest_sublist

    if metric not in ("jaccard", "difference"):
        # Previously fell through and returned None, breaking the caller's unpacking.
        raise ValueError(f"Unsupported metric: {metric!r}")

    if n is None:
        n = 1  # the default used to crash inside phrase_set

    set_cand = phrase_set(candidate_text, n=n)
    set_ref = phrase_set(reference_text, n=n)

    if metric == "jaccard":
        union = set_cand | set_ref
        return (len(set_cand & set_ref) / len(union), None) if union else (0, None)

    # metric == "difference": n-grams in the candidate that the reference lacks.
    new_phrases = set_cand - set_ref
    if not set_cand:
        # Nothing in the candidate, so nothing is new; avoids ZeroDivisionError.
        return 0.0, new_phrases
    return len(new_phrases) / len(set_cand), new_phrases

# ANSI escape sequences used by color_diff to highlight word-level edits.
RESET = "\033[0m"
RED = "\033[91m"     # deletions
GREEN = "\033[92m"   # insertions
YELLOW = "\033[93m"  # replacements


def color_diff(a: str, b: str, *, label_a: str = "Source L5", label_b: str = "L4"):
    """Print a colour-coded, word-level diff that edits ``a`` into ``b``.

    Words are split on whitespace and aligned with ``difflib.SequenceMatcher``:

    - words removed from ``a`` -> red (on the ``a`` line)
    - words added in ``b`` -> green (on the ``b`` line)
    - replaced words -> yellow (on both lines)

    Args:
        a: Original text.
        b: Edited text.
        label_a: Prefix printed before the ``a`` line. Defaults to the label
            previously hard-coded for the source review.
        label_b: Prefix printed before the ``b`` line. Defaults to the label
            previously hard-coded for the level-4 review.
    """
    a_words = a.split()
    b_words = b.split()
    sm = difflib.SequenceMatcher(None, a_words, b_words)

    out_a, out_b = [], []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            out_a.extend(a_words[i1:i2])
            out_b.extend(b_words[j1:j2])
        elif tag == "delete":
            out_a.extend([f"{RED}{w}{RESET}" for w in a_words[i1:i2]])
        elif tag == "insert":
            out_b.extend([f"{GREEN}{w}{RESET}" for w in b_words[j1:j2]])
        elif tag == "replace":
            out_a.extend([f"{YELLOW}{w}{RESET}" for w in a_words[i1:i2]])
            out_b.extend([f"{YELLOW}{w}{RESET}" for w in b_words[j1:j2]])

    print()
    print(f"{label_a}: " + " ".join(out_a) + "\n")
    print(f"{label_b}: " + " ".join(out_b) + "\n")


# Review path -> metric of that review against its human source review(s).
filepath2metric = dict()

# For --metric difference: union over references of the candidate n-grams
# missing from that reference, collected per level.
level_3_newwords = set()
level_4_newwords = set()

# TTBT level-4 files already scored. Several listed level-4 keys can map onto
# the same TTBT file, and each file is counted only once. A set gives O(1) lookups.
unique_ttbt_review_keys = set()

# Human reference texts that were not listed in --all_files_path, read lazily.
extra_reference_texts = dict()


def _load_reference_text(path):
    """Return the stripped text of a human reference review.

    Uses the copy in ``results_data`` when the path was listed. Otherwise the
    file is read once from the local mirror and cached. ``results_data`` is
    never mutated, because it is being iterated below.
    """
    if path in results_data:
        return results_data[path]
    if path not in extra_reference_texts:
        local_path = path.replace('/Project/Human_or_AI', '/ai-involvement-in-peer-reviews')
        with open(local_path, "r") as fin:
            extra_reference_texts[path] = fin.read().strip()
    return extra_reference_texts[path]


for key, val in tqdm(results_data.items()):
    conference, split, level, paper_number, reviewer_number = extract_paper_fp_from_review_fp(key)
    source_human_review = get_source_filenames(conference, split, level, paper_number, reviewer_number)

    candidate_review = val

    if "level4" in key and is_ttbt_replacement:
        ttbt_l4_filepath = f"{args.source_ttbt_dir}/{conference}/{split}/{paper_number}/level4/{args.ttbt_id}/{reviewer_number}.txt"
        if ttbt_l4_filepath in unique_ttbt_review_keys:
            continue
        unique_ttbt_review_keys.add(ttbt_l4_filepath)
        with open(ttbt_l4_filepath, "r") as fin:
            candidate_review = fin.read().strip()

    metric_scores = []
    for src_human_rev in source_human_review:
        reference_text = _load_reference_text(src_human_rev)
        metric_val, new_words_set = compute_word_metric(
            candidate_review,
            reference_text,
            metric=args.metric,
            n=args.n
        )
        # Debug hook: to inspect long LCS matches, uncomment.
        # if args.metric == "lcs" and level == "level4" and metric_val >= 20:
        #     print(f"{key}<=={src_human_rev}: {metric_val}")
        #     color_diff(reference_text, candidate_review)  # edits a (source) into b (candidate)
        #     input()
        metric_scores.append(metric_val)

        if args.metric == "difference":
            if level == "level3":
                level_3_newwords.update(new_words_set)
            elif level == "level4":
                level_4_newwords.update(new_words_set)

    # LCS keeps the best-matching reference; the set metrics average over references.
    if args.metric != "lcs":
        filepath2metric[key] = np.mean(metric_scores)
    else:
        filepath2metric[key] = max(metric_scores)

# Histogram colour per author; unlisted authors fall back to green when plotting.
author2color = {
"gpt_4o_latest": "blue",
"meta-llama-Llama-3.3-70B-Instruct": "orange",
"human": "green"
}

# level_wise[i] maps author -> list of metric values for "LEVEL i+1".
# Index 4 ("LEVEL 5") collects the human-written reviews (paths under /reviews/).
level_wise = [{}, {}, {}, {}, {}]

for key, metric_value in filepath2metric.items():
    if is_ttbt_replacement and "level4" in key:
        # Level-4 scores were computed on the TTBT replacement text.
        author = args.ttbt_id
    elif "gpt_4o_latest" in key:
        author = "gpt_4o_latest"
    elif "meta-llama-Llama-3.3-70B-Instruct" in key:
        author = "meta-llama-Llama-3.3-70B-Instruct"
    elif "/reviews/" in key:
        author = "human"
    else:
        # Never inherit the previous iteration's author for an unrecognised path.
        author = "unknown"

    for i in range(1, 5):
        if f"level{i}" in key:
            level_wise[i - 1].setdefault(author, []).append(metric_value)

    if "/reviews/" in key:
        level_wise[4].setdefault(author, []).append(metric_value)

level_wise_dict = dict()
for i in range(1, 6):
    n_examples = sum(len(values) for values in level_wise[i - 1].values())
    print(f"LEVEL {i} examples: {n_examples}")
    level_wise_dict[f"LEVEL {i}"] = level_wise[i - 1]

num_levels = 5
# LCS omits LEVEL 5: human reviews are compared with themselves there
# (same reviews/<paper>_<reviewer>.txt path), so the panel is uninformative.
num_panels = num_levels - (args.metric == 'lcs')

if args.vertical:
    fig, axes = plt.subplots(nrows=num_panels, ncols=1, figsize=(6, 4 * num_panels))
else:
    fig, axes = plt.subplots(nrows=1, ncols=num_panels, figsize=(6 * num_panels, 4))

default_bins = 50


def _bin_edges(data, half_width, step):
    """Bin edges of width ``step`` from min(data)-half_width to max(data)+half_width inclusive."""
    # np.arange excludes its stop value. Extend it by half a step so the last
    # edge exists and the maximum data value is not dropped from the histogram.
    return np.arange(np.min(data) - half_width, np.max(data) + half_width + step / 2, step)


for i in range(1, num_panels + 1):
    level_name = f"LEVEL {i}"
    ax = axes[i - 1]

    for author, data in level_wise_dict[level_name].items():
        if not data:
            continue

        # Reset per author: LCS edge arrays must not leak into other histograms.
        bins = default_bins
        if np.min(data) >= -0.01 and np.max(data) <= 1.01:
            hist_range = (0, 1.01)
        elif args.metric == "lcs":
            if level_name != "LEVEL 4":
                hist_range = (np.min(data) - 0.1, np.max(data) + 0.1)
                bins = _bin_edges(data, half_width=0.1, step=0.2)
            else:
                hist_range = (np.min(data) - 1, np.max(data) + 1)
                bins = _bin_edges(data, half_width=0.5, step=1)
        else:
            hist_range = (np.min(data) - 0.01, np.max(data) + 0.01)

        # With explicit edges (LCS), matplotlib ignores ``range``.
        ax.hist(data, bins=bins, color=author2color.get(author, "green"), alpha=0.5, range=hist_range, label=author)

    ax.set_title(level_name)
    ax.set_ylabel('Count')
    ax.set_xlabel(f"{args.metric.capitalize()}")
    ax.legend()

plt.tight_layout()
# Ensure the destination exists even when --output_dir differs from the default.
os.makedirs(args.output_dir, exist_ok=True)
n_suffix = f"-{args.n}" if args.metric != 'lcs' else ''
plt.savefig(os.path.join(args.output_dir, f"{args.metric}-distribution{n_suffix}.png"), dpi=600)

# For --metric difference, write the collected level-3 and level-4 "new"
# n-grams to two files in --output_dir, sorted alphabetically, one per line.
if args.metric == "difference":
    outputs = (
        ("level_3_new_words.txt", level_3_newwords),
        ("level_4_new_words.txt", level_4_newwords),
    )
    for out_name, new_words in outputs:
        with open(os.path.join(args.output_dir, out_name), "w") as fout:
            fout.writelines(f"{word}\n" for word in sorted(new_words))