import os
import json
import csv
import re
import argparse
from collections import defaultdict
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix

# Command-line options for scoring one generated-answer run against ground truth.
parser = argparse.ArgumentParser(
    description="Score one run of generated yes/no answers against ground-truth files."
)
parser.add_argument(
    "--gt_dir",
    type=str,
    default="/path/to/data",
    help="Directory with the ground-truth *.json files (one JSON object per line).",
)
parser.add_argument(
    "--gen_dir",
    type=str,
    default="/path/to/data",
    help="Directory with the generated *.jsonl answer files.",
)
# type=int so args.seed has the same type whether it is defaulted or passed
# explicitly (the default is an int). It is not otherwise used by this script.
parser.add_argument("--seed", type=int, default=55, help="Run seed (informational).")
parser.add_argument(
    "--file-name",
    type=str,
    default="llava15_mme_answers_seed42",
    help=(
        "Run identifier: generated files are selected by containing it, and run "
        "metadata (seed/alpha/beta/layers, KV option, method flags, model) is "
        "parsed from it."
    ),
)
args = parser.parse_args()

def load_jsonl(file_path):
    """Load a JSON-lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``. Blank or whitespace-only
    lines (for example a trailing empty line) are skipped instead of raising
    ``json.JSONDecodeError``. A leading ``~`` in *file_path* is expanded.

    Args:
        file_path: Path to a UTF-8 file containing one JSON value per line.

    Returns:
        The parsed values in file order (dicts for this script's inputs).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(os.path.expanduser(file_path), "r", encoding="utf-8") as f:
        # json.loads tolerates the surrounding whitespace/newline itself.
        return [json.loads(line) for line in f if line.strip()]

# Ground-truth answers: every *.json file in gt_dir, in sorted order.
gt_files = sorted(entry for entry in os.listdir(args.gt_dir) if entry.endswith(".json"))
# Generated answers: *.jsonl files in gen_dir whose name contains the run identifier.
gen_files = sorted(
    entry
    for entry in os.listdir(args.gen_dir)
    if entry.endswith(".jsonl") and args.file_name in entry
)
seed = args.seed

# Accumulators filled in while scoring each task file.
total_scores = 0
score_dict = {}
csv_dict = {}  # one CSV row: run-metadata columns, then per-task scores
model_name = None

file_name = args.file_name

# Run hyper-parameters encoded in the run name, e.g.
# "..._seed42_..._alpha0.5_beta0.1_<tag>_start8_end16".
pattern = r".*?_seed(\d+).*?_alpha([\d.]+)_beta([\d.]+)_.*?_start(\d+)_end(\d+)"
# Alternative naming scheme with no tag between beta and start:
# pattern = r".*?_seed(\d+).*?_alpha([\d.]+)_beta([\d.]+)_start(\d+)_end(\d+)"
match = re.search(pattern, file_name)

# When the name does not follow the scheme, csv_dict keeps no metadata columns.
if match is not None:
    seed_str, alpha_str, beta_str, start_str, end_str = match.groups()
    csv_dict = {
        "seed": int(seed_str),
        "alpha": float(alpha_str),
        "beta": float(beta_str),
        "start_layer": int(start_str),
        "end_layer": int(end_str),
    }

# "KV option" column: "KV", "V" or "K" depending on which marker the run name
# contains; omitted entirely when none matches. "key_value" is tested first
# because it also contains both "value" and "key".
for marker, kv_option in (("key_value", "KV"), ("value", "V"), ("key", "K")):
    if marker in file_name:
        csv_dict["KV option"] = kv_option
        break

# Model family, used to name the output CSV; left as-is (None) if no known
# family appears in the run name.
for family in ("llava", "qwen", "mplug"):
    if family in file_name:
        model_name = family
        break

# Method flags: "T" when the run name contains the "use<method>" marker, else "F".
for column, flag_marker in (("vcd", "usevcd"), ("opera", "useopera"), ("dscr", "usedscr")):
    csv_dict[column] = "T" if flag_marker in file_name else "F"


def _parse_yes_no(answer):
    """Classify a free-form model answer as "yes", "no" or "other".

    The answer is stripped and lower-cased, then searched for the substrings
    "yes" and "no"; "yes" wins when both occur. Classifying the prediction once,
    before comparing it with the label, keeps the verdict independent of the
    label, so an answer mentioning both words cannot be counted as correct for
    yes-questions and no-questions alike.

    NOTE(review): substring matching is coarse ("no" also matches "not",
    "know", ...); consider restricting the check to the start of the answer.
    """
    answer = answer.strip().lower()
    if "yes" in answer:
        return "yes"
    if "no" in answer:
        return "no"
    return "other"


def _find_gen_file(task_name, candidates, run_name):
    """Return the first candidate whose stem, minus *task_name*, equals *run_name*.

    Underscores left at either end after removing the task name are ignored, so
    both "<task>_<run>.jsonl" and "<run>_<task>.jsonl" match. Note that every
    occurrence of *task_name* is removed, including any inside the run name.
    Returns None when nothing matches.
    """
    return next(
        (
            name
            for name in candidates
            if os.path.splitext(name)[0].replace(task_name, "").strip("_") == str(run_name)
        ),
        None,
    )


print(f"file_name: {file_name}")
for gt_file in gt_files:
    # Task key used in score_dict / csv_dict: the file name up to its first ".".
    task_name = gt_file.split(".")[0]
    gen_file = _find_gen_file(task_name, gen_files, file_name)
    print(f"gen_file: {gen_file}")

    if gen_file is None:
        print(f"No matching generated file for {gt_file}")
        continue

    gt_data = load_jsonl(os.path.join(args.gt_dir, gt_file))
    gen_data = load_jsonl(os.path.join(args.gen_dir, gen_file))
    # Answers are paired with questions by position; report a short answer
    # file clearly instead of failing with a bare IndexError.
    if len(gen_data) < len(gt_data):
        raise ValueError(
            f"{gen_file} has {len(gen_data)} answers but {gt_file} has "
            f"{len(gt_data)} questions"
        )

    true_pos = true_neg = false_pos = false_neg = 0
    unknown = 0  # questions whose ground-truth label is neither "yes" nor "no"
    other_count = 0  # answers that are neither "yes" nor "no" (scored as wrong)
    total_questions = len(gt_data)
    image_results = defaultdict(lambda: {"correct": 0, "total": 0})

    for gt_row, gen_row in zip(gt_data, gen_data):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if gt_row["question_id"] != gen_row["question_id"]:
            raise ValueError(
                f"question_id mismatch between {gt_file} and {gen_file}: "
                f"{gt_row['question_id']!r} != {gen_row['question_id']!r}"
            )

        image_name = gt_row["image"]
        gt_answer = gt_row["label"].strip().lower()
        prediction = _parse_yes_no(gen_row["text"])

        if gt_answer not in ("yes", "no"):
            unknown += 1
        elif prediction == "other":
            other_count += 1
        elif prediction == gt_answer:
            image_results[image_name]["correct"] += 1
            if prediction == "yes":
                true_pos += 1
            else:
                true_neg += 1
        elif prediction == "yes":  # label is "no"
            false_pos += 1
        else:  # predicted "no", label is "yes"
            false_neg += 1

        image_results[image_name]["total"] += 1

    # "yes" predictions on yes/no-labelled questions.
    yes_answers = true_pos + false_pos

    # acc+: share of images with exactly two correctly answered questions
    # (presumably every image has two questions — the == 2 test assumes it).
    acc_plus_correct_num = sum(1 for r in image_results.values() if r["correct"] == 2)
    acc_plus = acc_plus_correct_num / len(image_results) if image_results else 0

    print(f"TP {true_pos}")
    print(f"FP {false_pos}")
    print(f"FN {false_neg}")
    print(f"TN {true_neg}")

    # Precision, recall, F1, accuracy and answer proportions, each guarded
    # against an empty denominator (e.g. an empty ground-truth file).
    precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0.0
    recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    accuracy = (true_pos + true_neg) / total_questions if total_questions > 0 else 0.0
    yes_proportion = yes_answers / total_questions if total_questions > 0 else 0.0
    unknown_prop = unknown / total_questions if total_questions > 0 else 0.0

    # Task score = accuracy + acc+, both as percentages (so at most 200).
    task_score = accuracy * 100 + acc_plus * 100
    score_dict[task_name] = task_score
    csv_dict[task_name] = task_score
    total_scores += task_score

    # Verbose per-task diagnostics (uncomment to enable):
    # print(f"\n **Results for {gt_file}**")
    # print(f"Accuracy: {accuracy}")
    # print(f"Precision: {precision}")
    # print(f"Recall: {recall}")
    # print(f"F1-Score: {f1}")
    # print(f"acc_plus: {acc_plus}\n")
    # print(f'yes: {yes_proportion}')
    # print(f'unknown: {unknown_prop}')
    # print(f'other answers: {other_count}')

print("\n**Total score:", total_scores)
for task_name, task_score in score_dict.items():
    print(f"\t{task_name}: {task_score}")

score_dict["Total Score"] = total_scores
csv_dict["Total Score"] = total_scores

csv_file_path = f"./eval/result/{model_name}_mme_dscr_score.csv"
# The results directory may not exist yet (e.g. on a fresh checkout).
os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)

# Rows are appended to one CSV shared across runs. Reuse the existing header as
# the column order so values land under the right column even when this run's
# csv_dict has a different key order or is missing columns (filled with "").
existing_header = None
if os.path.isfile(csv_file_path):
    with open(csv_file_path, "r", newline="", encoding="utf-8") as existing:
        existing_header = next(csv.reader(existing), None)

if existing_header:
    csv_columns = existing_header
    dropped = [key for key in csv_dict if key not in existing_header]
    if dropped:
        print(
            f"Warning: columns {dropped} are not in the header of "
            f"{csv_file_path} and will not be written"
        )
else:
    csv_columns = list(csv_dict.keys())

with open(csv_file_path, "a", newline="", encoding="utf-8") as csvfile:
    writer = csv.DictWriter(
        csvfile, fieldnames=csv_columns, restval="", extrasaction="ignore"
    )

    # New or empty file: emit the header first.
    if not existing_header:
        writer.writeheader()

    writer.writerow(csv_dict)