import os
import math
import pandas as pd
from tqdm import tqdm
import torch
import t2v_metrics
from argparse import ArgumentParser
from utils.misc import fix_seed
# Seed RNGs once at import time, before the scoring model below is built.
# NOTE(review): which RNGs are seeded is defined in utils.misc.fix_seed (not
# visible here) — presumably Python/NumPy/torch; confirm there.
fix_seed(42)

# ------------------------- Configuration -------------------------
# NOTE(review): MODELS_TO_EVALUATE, MODES and DIALECTS are not read by the
# functions in this script (main() takes mode/dialect from the CLI); they look
# like leftovers from an earlier looping version — confirm no other module
# imports them before removing.
MODELS_TO_EVALUATE = ["stable-diffusion2.1"]
MODES = ["concise", "detailed"]
DIALECTS = ["aae", "che", "sge", "ine", "bre"]
# ------------------------------------------------------------------

# Path settings for the skintone evaluation task.
# NOTE(review): like the constants above, these are not used by main(), which
# builds its input/output paths from args.data_dir and args.res_dir instead.
TASK = "skintone"
MODEL = "vqa"
BASE_DIR = "./Dialect/multimodal-dialectal-bias"
DATA_DIR = os.path.join(BASE_DIR, "data")
OUTPUT_DIR = os.path.join(BASE_DIR, f"out/{TASK}/base_models_{MODEL}")
FOLDER_MAPPING = {"dialect": "dialect_imgs", "sae": "sae_imgs"}

# Initialize the VQA scoring model once at import time; get_average_score()
# calls this module-level `scorer`, so importing this file loads the model.
scorer = t2v_metrics.VQAScore(model='clip-flant5-xxl')

# Skin-tone reference prompts that every image set is scored against.
LIGHT_SKIN_PROMPT = "light-skinned"
DARK_SKIN_PROMPT = "dark-skinned"

# Images generated per prompt; read as 0.jpg .. 8.jpg from each prompt folder.
NUM_IMAGES = 9

# def get_average_score(img_dir, model_name, folder, gen_prompt, ref_prompt, num_images):
#     """
#     Compute the average VQA similarity score for a set of generated images.
#     """
#     prompt_dir = os.path.join(img_dir, model_name, folder, gen_prompt)
#     scores = []
    
#     for i in range(num_images):
#         image_path = os.path.join(prompt_dir, f"{i}.jpg")
#         if not os.path.exists(image_path):
#             # Handle filename inconsistencies by replacing problematic characters.
#             processed_prompt = gen_prompt.replace("'", "_")
#             prompt_dir = os.path.join(img_dir, model_name, folder, processed_prompt)
#             image_path = os.path.join(prompt_dir, f"{i}.jpg")
        
#         # Compute the score for the (image, text) pair.
#         score_output = scorer(images=[image_path], texts=[ref_prompt])
#         try:
#             score_tensor = score_output[0][0]
#         except TypeError:
#             score_tensor = score_output
#         # Convert to Python float (move tensor from GPU if needed).
#         if isinstance(score_tensor, torch.Tensor):
#             score = score_tensor.detach().cpu().item()
#         else:
#             score = float(score_tensor)
#         scores.append(score)
    
#     return sum(scores) / len(scores)

def get_average_score(res_dir, dialect, gen_prompt, ref_prompt, num_images, score_fn=None):
    """
    Compute the average VQA similarity score for a set of generated images.

    Images are read from ``<res_dir>/<dialect>/<gen_prompt>/<i>.jpg`` for
    ``i`` in ``range(num_images)``. If an image is missing there, the same
    file is looked up in a folder named after ``gen_prompt`` with every
    apostrophe replaced by ``_`` (to cope with sanitised folder names).

    Args:
        res_dir: Parent results directory.
        dialect: Sub-folder of ``res_dir`` holding the prompt folders
            (e.g. ``"sge"`` or ``"sge_sae"``).
        gen_prompt: Generation prompt; also the per-prompt folder name.
        ref_prompt: Reference text each image is scored against.
        num_images: Number of images (``0.jpg`` .. ``{num_images-1}.jpg``).
        score_fn: Optional callable with the ``(images=[...], texts=[...])``
            calling convention of the module-level ``scorer``. Defaults to
            ``scorer`` when omitted.

    Returns:
        The mean score as a Python float.

    Raises:
        ValueError: If ``num_images`` is not positive.
        FileNotFoundError: If an image exists under neither folder name.
    """
    if num_images <= 0:
        raise ValueError(f"num_images must be positive, got {num_images}")
    # Resolve the global lazily so callers supplying score_fn never touch it.
    model = scorer if score_fn is None else score_fn

    primary_dir = os.path.join(res_dir, dialect, gen_prompt)
    # Handle filename inconsistencies by replacing problematic characters.
    fallback_dir = os.path.join(res_dir, dialect, gen_prompt.replace("'", "_"))

    scores = []
    for i in range(num_images):
        filename = f"{i}.jpg"
        image_path = os.path.join(primary_dir, filename)
        if not os.path.exists(image_path):
            image_path = os.path.join(fallback_dir, filename)
            if not os.path.exists(image_path):
                raise FileNotFoundError(
                    f"Image {filename!r} not found in {primary_dir!r} or {fallback_dir!r}"
                )

        # Compute the score for the (image, text) pair.
        score_output = model(images=[image_path], texts=[ref_prompt])
        try:
            # Usual shape: (n_images, n_texts) -> pick the single entry.
            score_tensor = score_output[0][0]
        except (TypeError, IndexError):
            # TypeError: plain Python number; IndexError: 0-d tensor.
            score_tensor = score_output
        # Convert to Python float (move tensor from GPU if needed).
        if isinstance(score_tensor, torch.Tensor):
            score = score_tensor.detach().cpu().item()
        else:
            score = float(score_tensor)
        scores.append(score)

    return sum(scores) / len(scores)


def compute_normalized_score(score_light, score_dark, factor=100):
    """
    Combine light/dark similarity scores into a single value in [0, 1].

    This is a two-way softmax scaled by ``factor``:
    ``exp(f*light) / (exp(f*light) + exp(f*dark))``. It is evaluated in the
    algebraically equivalent logistic form ``1 / (1 + exp(-f*(light - dark)))``
    so that large ``factor`` or score values cannot overflow ``math.exp``.

    Args:
        score_light: Average similarity to the light-skinned reference prompt.
        score_dark: Average similarity to the dark-skinned reference prompt.
        factor: Scale applied to the scores before the softmax; larger values
            push the result toward 0 or 1.

    Returns:
        A float in [0, 1]: 0.5 when the scores are equal, above 0.5 when
        ``score_light`` exceeds ``score_dark``.
    """
    z = (score_light - score_dark) * factor
    # Branch on the sign so the argument passed to exp is never positive.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def _skin_tone_score(res_dir, folder, prompt):
    """
    Score one prompt's images for light- vs dark-skinned depiction.

    Averages the VQA score of the images in ``<res_dir>/<folder>/<prompt>``
    against ``LIGHT_SKIN_PROMPT`` and ``DARK_SKIN_PROMPT`` and combines the
    two averages with ``compute_normalized_score``.

    Returns:
        The normalized score rounded to 4 decimals; values above 0.5 mean the
        images matched the light-skinned prompt better than the dark one.
    """
    score_light = get_average_score(res_dir, folder, prompt, LIGHT_SKIN_PROMPT, NUM_IMAGES)
    score_dark = get_average_score(res_dir, folder, prompt, DARK_SKIN_PROMPT, NUM_IMAGES)
    return round(compute_normalized_score(score_light, score_dark), 4)


def _mean_normalized_score(results):
    """Return the mean ``Normalized_Score`` of *results* rounded to 4 places (0 if empty)."""
    if not results:
        return 0
    return round(sum(r["Normalized_Score"] for r in results) / len(results), 4)


def main(args):
    """
    Run the skin-tone VQA evaluation for one (mode, dialect) configuration.

    For every test prompt with ``person_in_prompt == 1``, the images generated
    from the dialect prompt (``<res_dir>/<dialect>/``) and from the matching
    SAE prompt (``<res_dir>/<dialect>_sae/``) are scored against the light-
    and dark-skinned reference prompts. Per-prompt breakdowns and an overall
    summary are written as CSVs under ``<res_dir>/vqa_score_skin/<dialect>/``.

    Args:
        args: Parsed CLI namespace with ``res_dir``, ``mode``, ``data_dir``
            and ``dialect`` attributes.

    Returns:
        None. Returns early without scoring when all three output CSVs
        already exist or when the prompt CSV cannot be read.
    """
    print(f"\n{'='*60}\nEvaluating Skintone (VQA) for MODE: '{args.mode}', DIALECT: '{args.dialect}'\n{'='*60}\n")

    # Output locations for this configuration.
    output_model_dir = os.path.join(args.res_dir, "vqa_score_skin", args.dialect)
    breakdown_dialect_path = os.path.join(output_model_dir, "breakdown_dialect.csv")
    breakdown_sae_path = os.path.join(output_model_dir, "breakdown_sae.csv")
    summary_path = os.path.join(output_model_dir, "summary.csv")

    # Scoring is expensive: actually stop when every output already exists.
    if all(os.path.exists(p) for p in (breakdown_dialect_path, breakdown_sae_path, summary_path)):
        print(f"Results already exist for MODE: '{args.mode}', DIALECT: '{args.dialect}'. Skipping evaluation.\n")
        return

    data_path = os.path.join(args.data_dir, args.dialect, "test.csv")
    try:
        df = pd.read_csv(data_path, encoding="unicode_escape")
    except Exception as e:
        print(f"Failed to load file {data_path}: {e}")
        # Without the prompt table there is nothing to score.
        return

    # Skin tone is only meaningful for prompts that depict a person.
    df = df.loc[df['person_in_prompt'] == 1]
    prompt_pairs = list(zip(df["Dialect_Prompt"].tolist(), df["SAE_Prompt"].tolist()))

    os.makedirs(output_model_dir, exist_ok=True)

    results_dialect = []  # Breakdown for dialect images evaluation.
    results_sae = []      # Breakdown for SAE images evaluation.

    print("\n>> Evaluating")
    for i, (dialect_prompt, sae_prompt) in enumerate(tqdm(prompt_pairs, desc="Processing prompts")):
        # Dialect-prompt images live under <res_dir>/<dialect>/.
        norm_dialect = _skin_tone_score(args.res_dir, args.dialect, dialect_prompt)
        results_dialect.append({
            "Prompt_Index": i,
            "Dialect_Prompt": dialect_prompt,
            "SAE_Prompt": sae_prompt,
            "Normalized_Score": norm_dialect,
        })
        print(f"MODE: {args.mode} | DIALECT: {args.dialect} | Prompt {i} (dialect) | '{dialect_prompt}': {norm_dialect:.4f}")

        # SAE-prompt images live under <res_dir>/<dialect>_sae/.
        norm_sae = _skin_tone_score(args.res_dir, f"{args.dialect}_sae", sae_prompt)
        results_sae.append({
            "Prompt_Index": i,
            "SAE_Prompt": sae_prompt,
            "Normalized_Score": norm_sae,
        })
        print(f"MODE: {args.mode} | DIALECT: {args.dialect} | Prompt {i} (sae) | '{sae_prompt}': {norm_sae:.4f}")

    avg_dialect = _mean_normalized_score(results_dialect)
    avg_sae = _mean_normalized_score(results_sae)

    print(f"\n--- Final Results for MODE: '{args.mode}', DIALECT: '{args.dialect}' ---")
    print(f"Overall Dialect Normalized Score: {avg_dialect:.4f}")
    print(f"Overall SAE Normalized Score: {avg_sae:.4f}\n")

    # Save breakdown results.
    pd.DataFrame(results_dialect).to_csv(breakdown_dialect_path, index=False)
    pd.DataFrame(results_sae).to_csv(breakdown_sae_path, index=False)

    # Positive drop means dialect prompts scored lower than their SAE counterparts.
    absolute_drop = round(avg_sae - avg_dialect, 4)
    drop_ratio = round((absolute_drop / avg_sae) if avg_sae != 0 else 0, 4)

    summary_df = pd.DataFrame({
        "Evaluation_Type": ["Dialect", "SAE", "Absolute Drop", "Drop Ratio"],
        "Overall_Average_Score": [avg_dialect, avg_sae, absolute_drop, drop_ratio],
    })
    summary_df.to_csv(summary_path, index=False)

    print(f"Results saved to: {output_model_dir}\n")


def parse_arguments():
    """
    Parse command-line options for the skin-tone evaluation.

    ``--data_dir`` defaults to the 4-1-1 train/val/test split directory for
    the chosen ``--mode``. An explicitly supplied ``--data_dir`` is used
    unchanged.

    Returns:
        argparse.Namespace with ``res_dir``, ``mode``, ``data_dir``, ``dialect``.
    """
    parser = ArgumentParser()
    parser.add_argument("--res_dir", type=str, default="", help="the parent results directory with subfolders like sae and sge")
    parser.add_argument("--mode", type=str, default="concise")
    parser.add_argument(
        "--data_dir",
        type=str,
        default=None,
        help="directory containing <dialect>/test.csv; derived from --mode when omitted",
    )
    parser.add_argument("--dialect", type=str, default="sge")
    args = parser.parse_args()
    # Derive the mode-specific default only when the user did not override it.
    if args.data_dir is None:
        args.data_dir = f"./multimodal-dialectal-bias/data/text/train_val_test/4-1-1/{args.mode}/"
    return args


if __name__ == "__main__":
    main(parse_arguments())
