import os
import json
import csv
import re
import argparse
from tqdm import tqdm

# Command-line interface. Only --gt_files, --gen_files and --file-name are read
# by this script; run metadata is parsed out of --file-name.
parser = argparse.ArgumentParser(
    description="Score POPE yes/no answers and append the metrics to a CSV."
)
parser.add_argument(
    "--gt_files",
    type=str,
    default="/path/to/data",
    help="JSONL file; each line needs 'question_id' and 'label'.",
)
parser.add_argument(
    "--gen_files",
    type=str,
    default="/path/to/data",
    help="JSONL file; each line needs 'question_id' and 'text'.",
)
# type=int so the value has the same type whether or not the flag is given
# (argparse does not convert non-string defaults).
parser.add_argument(
    "--seed",
    type=int,
    default=42,
    help="Currently unused; the seed written to the CSV is parsed from --file-name.",
)
parser.add_argument(
    "--file-name",
    type=str,
    default="llava15_pope_aokvqa_useopera_usedscr_answers_seed42_key_value_alpha0.6_beta0.8_start0_end30_aokvqa_pope_adversarial.jsonl",
    help="Name of the generated-answers file; model, dataset, split, method "
    "flags and hyperparameters are parsed from it.",
)
args = parser.parse_args()

def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*.

    ``~`` in *path* is expanded. Blank lines are skipped instead of crashing
    ``json.loads``. The file is closed before returning.
    """
    with open(os.path.expanduser(path), "r", encoding="utf-8") as handle:
        return [json.loads(row) for row in handle if row.strip()]


# open ground truth answers
gt_files = _read_jsonl(args.gt_files)

# open generated answers
gen_files = _read_jsonl(args.gen_files)

# One CSV row of run metadata and metrics, filled in by the rest of the script.
csv_dict = dict()
file_name = args.file_name
dataset_name = None
model_name = None

# Hyperparameters are encoded in the file name, e.g.
# "..._seed42_..._alpha0.6_beta0.8_start0_end30_...".
pattern = r".*?_seed(\d+).*?_alpha([\d.]+)_beta([\d.]+)_start(\d+)_end(\d+)"
match = re.search(pattern, file_name)

if match:
    csv_dict = {
        "seed": int(match.group(1)),
        "alpha": float(match.group(2)),
        "beta": float(match.group(3)),
        "start_layer": int(match.group(4)),
        "end_layer": int(match.group(5)),
    }
else:
    # Make the missing seed/alpha/beta/start_layer/end_layer columns visible
    # instead of silently writing a shorter row.
    print(
        "Warning: could not parse seed/alpha/beta/start/end from file name: "
        f"{file_name}"
    )

# Model family: the first of these keywords found in the file name, tried in
# this order; None when none of them occurs.
model_name = next(
    (candidate for candidate in ("llava", "qwen", "mplug") if candidate in file_name),
    None,
)

# Dataset: the first of these keywords found in the file name, tried in this
# order; None when none of them occurs.
dataset_name = next(
    (candidate for candidate in ("coco", "aokvqa", "gqa") if candidate in file_name),
    None,
)

# POPE split recorded under the "type" column: the first of these keywords
# found in the file name, tried in this order. No "type" key is added when
# none of them occurs.
for pope_split in ("adversarial", "popular", "random"):
    if pope_split in file_name:
        csv_dict["type"] = pope_split
        break

# Method flags: each column is "T" when its marker substring appears in the
# file name and "F" otherwise (columns are added in vcd, opera, dscr order).
csv_dict.update(
    {
        column: "T" if marker in file_name else "F"
        for column, marker in (
            ("vcd", "usevcd"),
            ("opera", "useopera"),
            ("dscr", "usedscr"),
        )
    }
)

# Tally the confusion matrix; "yes" is the positive class.
# NOTE(review): the generated answer is matched by *substring* on the
# lowercased text, so e.g. "not"/"know"/"nothing" satisfy the "no" check and
# "eyes" satisfies the "yes" check. The rule is kept unchanged so scores stay
# comparable with rows already appended to the results CSV.
true_pos = 0
true_neg = 0
false_pos = 0
false_neg = 0
unknown = 0  # ground-truth labels that are neither "yes" nor "no"
total_questions = len(gt_files)
yes_answers = 0  # answers counted as "yes" (true_pos + false_pos)

# Explicit checks instead of `assert`, which is stripped under `python -O`.
if len(gen_files) < total_questions:
    raise ValueError(
        f"generated answers file has {len(gen_files)} entries but ground truth "
        f"has {total_questions}"
    )
if len(gen_files) > total_questions:
    print(
        f"Warning: ignoring {len(gen_files) - total_questions} extra generated "
        "answers beyond the ground-truth length"
    )

# compare answers; both files must list questions in the same order
for index, (gt_item, gen_item) in enumerate(zip(gt_files, gen_files)):
    if gt_item["question_id"] != gen_item["question_id"]:
        raise ValueError(
            f"question_id mismatch at line {index + 1}: "
            f"gt={gt_item['question_id']!r}, gen={gen_item['question_id']!r}"
        )

    gt_answer = gt_item["label"].strip().lower()
    gen_answer = gen_item["text"].strip().lower()

    if gt_answer == 'yes':
        if 'yes' in gen_answer:
            true_pos += 1
            yes_answers += 1
        else:
            false_neg += 1
    elif gt_answer == 'no':
        if 'no' in gen_answer:
            true_neg += 1
        else:
            yes_answers += 1
            false_pos += 1
    else:
        print(f'Warning: unknown gt_answer: {gt_answer}')
        unknown += 1
def _safe_div(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero.

    Precision/recall/F1 are undefined e.g. when the model never answers "yes"
    or when no ground-truth label is "yes". Report 0.0 rather than crashing
    the run with ZeroDivisionError.
    """
    return numerator / denominator if denominator else 0.0


# calculate precision, recall, f1, accuracy, and the proportion of 'yes' answers
precision = _safe_div(true_pos, true_pos + false_pos)
recall = _safe_div(true_pos, true_pos + false_neg)
f1 = _safe_div(2 * precision * recall, precision + recall)
# Denominators below include questions with unknown ground-truth labels.
accuracy = _safe_div(true_pos + true_neg, total_questions)
yes_proportion = _safe_div(yes_answers, total_questions)
unknown_prop = _safe_div(unknown, total_questions)

csv_dict["precision"] = precision
csv_dict["recall"] = recall
csv_dict["f1"] = f1
csv_dict["accuracy"] = accuracy

# Append this run's row to a per-model, per-dataset results CSV.
csv_file_path = f"./result/{model_name}_pope_{dataset_name}_score.csv"
csv_columns = list(csv_dict.keys())

# The output directory may not exist yet (e.g. on a fresh checkout).
os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)

file_exists = os.path.isfile(csv_file_path)
if file_exists:
    # The keys of csv_dict depend on file_name: hyperparameter columns and
    # "type" are only present when they could be parsed. So write by the
    # header already in the file; otherwise values land under the wrong
    # columns.
    with open(csv_file_path, "r", newline="", encoding="utf-8") as existing:
        existing_header = next(csv.reader(existing), [])
    if existing_header:
        dropped = [column for column in csv_columns if column not in existing_header]
        if dropped:
            print(
                f"Warning: {csv_file_path} has no column(s) {dropped}; "
                "those values are not written"
            )
        csv_columns = existing_header
    else:
        # An existing but empty file has no header yet; write one.
        file_exists = False

with open(csv_file_path, "a", newline="", encoding="utf-8") as csvfile:
    # restval fills header columns this run lacks; extrasaction="ignore"
    # skips keys the existing header does not have (reported above).
    writer = csv.DictWriter(
        csvfile, fieldnames=csv_columns, restval="", extrasaction="ignore"
    )

    if not file_exists:
        writer.writeheader()

    writer.writerow(csv_dict)

# report results: every CSV field, then the answer-distribution statistics
# that are computed earlier but are not part of the CSV row.
print("eval_pope")
for name, score in csv_dict.items():
    print(f"\t{name}: {score}")
print(f"\tyes_proportion: {yes_proportion}")
print(f"\tunknown_prop: {unknown_prop}")

