import os
import json
import numpy as np
from typing import Dict
from model import DummySafeguardModel
from benchmark import SEASafeguardBench


def get_general_performance(result_paths: Dict[str, str]) -> Dict[str, list]:
    """Collect AUPRC scores for the ``general`` subset.

    Args:
        result_paths: Maps ``"general_<split>_<language>"`` keys to the path
            of a JSON results file. Each file must hold a ``"performance"``
            object with ``"prompt_classification"`` and
            ``"response_classification"`` entries, each carrying an
            ``"auprc"`` value.

    Returns:
        A dict keyed by ``"prompt_classification"`` and
        ``"response_classification"``. Each value is a two-element list:
        ``[English AUPRC, mean AUPRC over the seven Local splits]``.

    Raises:
        KeyError: If a required key is missing from ``result_paths`` or from
            a results file.
        OSError: If a results file cannot be opened.
    """
    tasks = ("prompt_classification", "response_classification")

    def load_performance(key: str) -> dict:
        # Use a context manager so the file is closed as soon as it has been
        # parsed instead of leaking the handle until garbage collection.
        with open(result_paths[key]) as results_file:
            return json.load(results_file)["performance"]

    # English score: a single split, so no averaging is needed.
    english = load_performance("general_EN_English")
    performance = {task: [english[task]["auprc"]] for task in tasks}

    # Local score: unweighted mean over the per-language splits.
    local_scores = {task: [] for task in tasks}
    for split in ("IN", "MS", "MY", "TA", "TH", "TL", "VI"):
        local = load_performance(f"general_{split}_Local")
        for task in tasks:
            local_scores[task].append(local[task]["auprc"])

    for task in tasks:
        # .item() converts the NumPy scalar to a plain Python float.
        performance[task].append(np.mean(local_scores[task]).item())
    return performance

def get_cultural_itw_performance(result_paths: Dict[str, str]) -> Dict[str, list]:
    """Collect prompt-classification AUPRC for the ``cultural_in_the_wild`` subset.

    Args:
        result_paths: Maps ``"cultural_in_the_wild_<split>_<language>"`` keys
            to the path of a JSON results file. Each file must hold a
            ``"performance"`` object with a ``"prompt_classification"`` entry
            carrying an ``"auprc"`` value.

    Returns:
        ``{"prompt_classification": [english_mean, local_mean]}``, where each
        mean is the unweighted average over the seven ``*_EN`` splits for
        that language setting.

    Raises:
        KeyError: If a required key is missing from ``result_paths`` or from
            a results file.
        OSError: If a results file cannot be opened.
    """
    splits = ("IN_EN", "MS_EN", "MY_EN", "TA_EN", "TH_EN", "TL_EN", "VI_EN")

    def mean_prompt_auprc(language: str) -> float:
        scores = []
        for split in splits:
            key = f"cultural_in_the_wild_{split}_{language}"
            # Use a context manager so each results file is closed right
            # after parsing instead of leaking the handle.
            with open(result_paths[key]) as results_file:
                split_performance = json.load(results_file)["performance"]
            scores.append(split_performance["prompt_classification"]["auprc"])
        # .item() converts the NumPy scalar to a plain Python float.
        return np.mean(scores).item()

    # English first, then Local, matching the order the report expects.
    return {
        "prompt_classification": [
            mean_prompt_auprc("English"),
            mean_prompt_auprc("Local"),
        ]
    }

def get_cultural_cg_performance(result_paths: Dict[str, str]) -> Dict[str, list]:
    """Collect AUPRC scores for the ``cultural_content_generation`` subset.

    Args:
        result_paths: Maps ``"cultural_content_generation_<split>_<language>"``
            keys to the path of a JSON results file. Each file must hold a
            ``"performance"`` object with ``"prompt_classification"`` and
            ``"response_classification"`` entries, each carrying an
            ``"auprc"`` value.

    Returns:
        A dict keyed by ``"prompt_classification"`` and
        ``"response_classification"``. Each value is ``[english_mean,
        local_mean]``. The English mean covers the ``EN`` split plus the
        seven ``*_EN`` splits; the Local mean covers the seven ``*_EN``
        splits only.

    Raises:
        KeyError: If a required key is missing from ``result_paths`` or from
            a results file.
        OSError: If a results file cannot be opened.
    """
    tasks = ("prompt_classification", "response_classification")
    local_splits = ("IN_EN", "MS_EN", "MY_EN", "TA_EN", "TH_EN", "TL_EN", "VI_EN")
    # The English average additionally includes the standalone "EN" split.
    english_splits = ("EN",) + local_splits

    def mean_auprc(splits: tuple, language: str) -> Dict[str, float]:
        scores = {task: [] for task in tasks}
        for split in splits:
            key = f"cultural_content_generation_{split}_{language}"
            # Use a context manager so each results file is closed right
            # after parsing instead of leaking the handle.
            with open(result_paths[key]) as results_file:
                split_performance = json.load(results_file)["performance"]
            for task in tasks:
                scores[task].append(split_performance[task]["auprc"])
        # .item() converts the NumPy scalar to a plain Python float.
        return {task: np.mean(scores[task]).item() for task in tasks}

    english = mean_auprc(english_splits, "English")
    local = mean_auprc(local_splits, "Local")
    return {task: [english[task], local[task]] for task in tasks}

def report_performance(result_paths: Dict[str, str]):
    """Print one row of aggregated AUPRC scores, as percentages.

    The row lists the prompt-classification scores (general, cultural
    in-the-wild, cultural content-generation, then their mean) followed by
    the response-classification scores (general, cultural
    content-generation, then their mean). Values are multiplied by 100,
    rounded to one decimal place and joined with " & ".

    Args:
        result_paths: Mapping from result key to JSON results file path,
            passed through unchanged to the three ``get_*_performance``
            helpers.
    """
    general = get_general_performance(result_paths)
    in_the_wild = get_cultural_itw_performance(result_paths)
    content_generation = get_cultural_cg_performance(result_paths)

    # Prompt classification is reported for all three subsets.
    prompt_scores = [
        *general["prompt_classification"],
        *in_the_wild["prompt_classification"],
        *content_generation["prompt_classification"],
    ]
    prompt_scores.append(np.mean(prompt_scores).item())

    # The in-the-wild subset contributes no response-classification scores.
    response_scores = [
        *general["response_classification"],
        *content_generation["response_classification"],
    ]
    response_scores.append(np.mean(response_scores).item())

    cells = [str(round(score * 100, 1)) for score in prompt_scores + response_scores]
    print(" & ".join(cells))

def main():
    """Evaluate the dummy safeguard model on every benchmark split and report.

    For each subset/split/language combination, the evaluation is run only
    when its ``performance.json`` does not already exist on disk; existing
    result files are reused. The collected paths are then passed to
    ``report_performance``.
    """
    model_name = "dummy"
    model = DummySafeguardModel()

    result_paths = {}
    benchmark = SEASafeguardBench()
    for subset in benchmark.available_subsets_splits.keys():
        for subset_split in benchmark.available_subsets_splits[subset]:
            # Splits containing an underscore (e.g. "IN_EN") are evaluated in
            # both language settings; "EN" is English-only; any other split
            # is Local-only.
            if len(subset_split.split("_")) > 1:
                languages = ["English", "Local"]
            elif subset_split == "EN":
                languages = ["English"]
            else:
                languages = ["Local"]

            for language in languages:
                result_key = f"{subset}_{subset_split}_{language}"
                save_path = f"./results/{model_name}/{subset}/{subset_split}/{language}/performance.json"
                already_evaluated = os.path.exists(save_path)
                if not already_evaluated:
                    metrics = benchmark.eval(
                        model,
                        subset=subset,
                        split=subset_split,
                        language=language,
                        verbose=True,
                    )
                    metrics.save_to_json(save_path)
                result_paths[result_key] = save_path

    report_performance(result_paths)


# Run the evaluation and report only when executed as a script, not on import.
if __name__ == "__main__":
    main()