# pip install t2v_metrics==1.1
import os
import pandas as pd
from tqdm import tqdm
import t2v_metrics
import json
from argparse import ArgumentParser
from collections import defaultdict, OrderedDict
from utils.misc import fix_seed
from const import *

# Seed random number generators via the project helper (utils.misc.fix_seed;
# presumably Python/NumPy/torch RNGs — confirm there). This runs before the
# scorer is constructed, so keep this ordering.
fix_seed(42)

# Initialize the new scoring metric.
# Module-level VQAScore scorer built once at import time and shared by every
# get_average_score call; model='clip-flant5-xxl' selects the backbone.
scorer = t2v_metrics.VQAScore(model='clip-flant5-xxl')

def get_average_score(res_dir, dialect, gen_prompt, ref_prompt, num_images=9, score_fn=None):
    """
    Compute the mean VQAScore of a generated image set against a reference prompt.

    Images are looked up at ``<res_dir>/<dialect>/<gen_prompt>/<i>.jpg`` for
    ``i`` in ``range(num_images)``. Some result folders were saved with the
    prompt's apostrophes replaced by underscores, so that spelling is tried as
    a fallback for each image.

    Args:
        res_dir: Parent results directory.
        dialect: Sub-folder holding the image set (e.g. ``"sge"`` or ``"sge_sae"``).
        gen_prompt: Prompt the images were generated from; also the folder name.
        ref_prompt: Text every image is scored against.
        num_images: Number of images per prompt (``0.jpg`` .. ``{num_images-1}.jpg``).
        score_fn: Callable invoked as ``score_fn(images=[path], texts=[text])``.
            Defaults to the module-level ``scorer``; pass one explicitly to use
            a different metric or a stub in tests.

    Returns:
        The average score as a Python float.

    Raises:
        ValueError: If ``num_images`` is not positive.
        FileNotFoundError: If an image is missing under both folder spellings.
    """
    if num_images <= 0:
        raise ValueError(f"num_images must be positive, got {num_images}")
    if score_fn is None:
        score_fn = scorer

    primary_dir = os.path.join(res_dir, dialect, gen_prompt)
    # Handle filename inconsistencies: some folders were saved with
    # apostrophes in the prompt replaced by underscores.
    fallback_dir = os.path.join(res_dir, dialect, gen_prompt.replace("'", "_"))

    scores = []
    for i in range(num_images):
        file_name = f"{i}.jpg"
        image_path = os.path.join(primary_dir, file_name)
        if not os.path.exists(image_path):
            image_path = os.path.join(fallback_dir, file_name)
            if not os.path.exists(image_path):
                # Fail here with a clear message instead of letting the scorer
                # choke on a non-existent path.
                raise FileNotFoundError(
                    f"Image {file_name} not found in {primary_dir!r} or {fallback_dir!r}"
                )

        # Compute the score for the single (image, text) pair.
        score_output = score_fn(images=[image_path], texts=[ref_prompt])
        try:
            # Matrix-like output indexed [image][text]; take the only entry.
            score = score_output[0][0]
        except (TypeError, IndexError):
            # Not subscriptable (plain number) or a 0-dim tensor: use as-is.
            score = score_output
        # Convert right away so we accumulate plain floats rather than
        # whatever object the scorer returned (e.g. a tensor).
        scores.append(float(score))

    return sum(scores) / len(scores)


def _score_prompts(res_dir, dialect, dialect_prompts, sae_prompts):
    """Score the dialect and SAE image sets for every prompt pair.

    Returns a mapping from ``dialect`` and ``f"{dialect}_sae"`` to lists of
    per-prompt average scores, in prompt order.
    """
    sae_key = f"{dialect}_sae"
    results = defaultdict(list)
    pairs = zip(dialect_prompts, sae_prompts)
    for i, (dialect_prompt, sae_prompt) in enumerate(
        tqdm(pairs, total=len(dialect_prompts), desc="Processing prompts")
    ):
        # Evaluate dialect images (using SAE prompt as reference).
        score = get_average_score(res_dir, dialect, dialect_prompt, sae_prompt)
        results[dialect].append(score)
        print(f"Prompt {i} ({dialect}): {score:.4f}")

        # Evaluate SAE images (using SAE prompt for both generated and reference).
        score = get_average_score(res_dir, sae_key, sae_prompt, sae_prompt)
        results[sae_key].append(score)
        print(f"Prompt {i} (sae): {score:.4f}")
    return results


def _summarize(results):
    """Print the mean of each score list and return them keyed as ``<set>_avg``."""
    print("\n------------------- Final Results -------------------")
    avgs = {}
    for set_type, scores in results.items():
        avg_score = sum(scores) / len(scores)
        print(f"{set_type.capitalize()} total score: {avg_score:.4f}")
        avgs[f"{set_type}_avg"] = avg_score
    return avgs


def _merge_and_save(output_file, new_results):
    """Merge ``new_results`` into the JSON file at ``output_file`` and rewrite it.

    Keys from earlier runs (e.g. other dialects) are kept; keys present in
    ``new_results`` overwrite them. Score lists are written first, then the
    ``*_avg`` entries, each group sorted by key. The data is written to a
    temporary sibling file and swapped in with ``os.replace``, so an
    interrupted write cannot truncate previously accumulated results.
    """
    merged = {}
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            merged = json.load(f)
    merged.update(new_results)

    score_keys = sorted(key for key in merged if not key.endswith("_avg"))
    avg_keys = sorted(key for key in merged if key.endswith("_avg"))
    ordered = {key: merged[key] for key in score_keys + avg_keys}

    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(ordered, f, indent=4)
    os.replace(tmp_file, output_file)


def main(args):
    """Score one dialect's test prompts and merge the results into the JSON report.

    Reads ``<args.data_dir>/<args.dialect>/test.csv`` (columns ``Dialect_Prompt``
    and ``SAE_Prompt``), scores the matching image folders under
    ``args.res_dir``, prints per-prompt and overall averages, and updates
    ``<args.res_dir>/vqa_score_understanding.json``.
    """
    data_path = os.path.join(args.data_dir, args.dialect, "test.csv")
    # NOTE(review): "unicode_escape" is an unusual CSV encoding; kept as-is on
    # the assumption the data files depend on it.
    df = pd.read_csv(data_path, encoding="unicode_escape")

    results = dict(
        _score_prompts(
            args.res_dir,
            args.dialect,
            df["Dialect_Prompt"].tolist(),
            df["SAE_Prompt"].tolist(),
        )
    )
    # _summarize finishes iterating before update() adds the *_avg keys.
    results.update(_summarize(results))

    output_file = os.path.join(args.res_dir, "vqa_score_understanding.json")
    _merge_and_save(output_file, results)


def parse_arguments(argv=None):
    """Parse command-line options for the VQAScore evaluation.

    When ``--data_dir`` is not given, it defaults to the split directory for
    ``--mode`` (``./multimodal-dialectal-bias/data/text/train_val_test/4-1-1/<mode>``).
    An explicitly passed ``--data_dir`` is used as given. (Previously the
    ``__main__`` block overwrote it unconditionally, silently ignoring the flag.)

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with ``data_dir`` always set.
    """
    parser = ArgumentParser()
    parser.add_argument("--res_dir", type=str, default="", help="the parent results directory with subfolders like sae and sge")
    parser.add_argument(
        "--data_dir",
        type=str,
        default=None,
        help="directory containing <dialect>/test.csv; defaults to the split for --mode",
    )
    parser.add_argument("--mode", type=str, default="concise")
    parser.add_argument("--dialect", type=str, default="sge")
    args = parser.parse_args(argv)
    if args.data_dir is None:
        args.data_dir = f"./multimodal-dialectal-bias/data/text/train_val_test/4-1-1/{args.mode}"
    return args


if __name__ == "__main__":
    main(parse_arguments())
