from pathlib import Path

from easy_io import read_json, dump_json
import matplotlib.pyplot as plt
import scipy.stats

from src.path import baseline_performance_dir, baseline_analysis_dir
from src.config import new_datasets_names, new_datasets_initial_models, baseline_models_open, baseline_models_closed, covnert_dataset_full_name_dict, convert_model_names_short

# NOTE: these two lists rebind the names of the same spelling imported from
# src.config above, so this script analyses exactly the models listed here.

# Open-weight baseline models.
baseline_models_open = [
    "google/gemma-7b-it",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Qwen/Qwen1.5-14B-Chat",
    "Qwen/Qwen1.5-72B-Chat",
]

# API-served (closed) baseline models.
baseline_models_closed = [
    "gpt-3.5-turbo-0125",
    "models/gemini-1.0-pro-001",
    "claude-3-opus-20240229",
    "gpt-4-0613",
    "gpt-4-0125-preview",
]


# Model identifier -> short mathtext label. The label is drawn as text at the
# model's data point in each figure (it replaces a conventional marker).
makers_dict = {
    # open-weight models
    "google/gemma-7b-it": "$Gemma$",
    "meta-llama/Llama-2-13b-chat-hf": "$Llama{13}$",
    "meta-llama/Llama-2-70b-chat-hf": "$Llama{70}$",
    "mistralai/Mixtral-8x7B-Instruct-v0.1": "$Mixtral$",
    "Qwen/Qwen1.5-14B-Chat": "$Qwen{14}$",
    "Qwen/Qwen1.5-72B-Chat": "$Qwen{72}$",
    # closed models
    "gpt-3.5-turbo-0125": "$GPT3.5$",
    "models/gemini-1.0-pro-001": "$Gemini$",
    "claude-3-opus-20240229": "$Claude3$",
    "gpt-4-0613": "$GPT4_{23}$",
    "gpt-4-0125-preview": "$GPT4_{24}$",
}

# Directory that receives the generated figures (PDF) and the per-task
# correlation JSON files; created on demand when the script runs.
output_dir = baseline_analysis_dir / "other_tasks_analysis_results"

def collect_points(baseline_performance: dict, other_task_performance: dict,
                   dataset_name: str, initial_model: str, metric: str,
                   baseline_models: list) -> list:
    """Collect the data points to plot/correlate for one dataset and metric.

    A baseline model contributes a point only when it has BOTH a score in
    ``other_task_performance`` and a non-null entry in ``baseline_performance``
    for the given dataset / initial model.

    Args:
        baseline_performance: Nested dict indexed as
            ``[dataset_name]["initial_model=<name>"]["baseline_model=<name>"]
            ["average"]["metrics"][metric]["average"]``.
        other_task_performance: Mapping from baseline model name to its score
            on the other task (used as the x coordinate).
        dataset_name: Dataset key in ``baseline_performance``.
        initial_model: Initial-model name (without the ``initial_model=`` prefix).
        metric: Metric key, e.g. ``"recall"`` or ``"precision"``.
        baseline_models: Candidate baseline model names, in plotting order.

    Returns:
        A list of ``(baseline_model, x, y)`` tuples in the order of
        ``baseline_models``; models lacking either value are skipped.
    """
    points = []
    for baseline_model in baseline_models:
        if baseline_model not in other_task_performance:
            continue

        entry = baseline_performance[dataset_name][f"initial_model={initial_model}"].get(
            f"baseline_model={baseline_model}"
        )
        if entry is None:
            continue

        points.append((
            baseline_model,
            other_task_performance[baseline_model],
            entry["average"]["metrics"][metric]["average"],
        ))
    return points


def compute_correlations(x: list, y: list) -> dict:
    """Return Pearson and Spearman correlations (with p-values) of ``x`` vs ``y``.

    Args:
        x: Other-task scores.
        y: Baseline metric values, aligned with ``x``.

    Returns:
        ``{"pearson": {"correlation", "p"}, "spearman": {"correlation", "p"}}``.
        When fewer than two points are available a correlation is undefined,
        so every value is ``None`` rather than letting scipy raise and abort
        the remaining figures.
    """
    if len(x) < 2:
        return {
            "pearson": {"correlation": None, "p": None},
            "spearman": {"correlation": None, "p": None},
        }

    pearson, p_p = scipy.stats.pearsonr(x, y)
    spearman, p_s = scipy.stats.spearmanr(x, y)
    # Cast to built-in floats so the result serializes as plain JSON numbers.
    return {
        "pearson": {"correlation": float(pearson), "p": float(p_p)},
        "spearman": {"correlation": float(spearman), "p": float(p_s)},
    }


if __name__ == "__main__":
    # make matplotlib text larger
    plt.rcParams.update({'font.size': 15})

    output_dir.mkdir(parents=True, exist_ok=True)

    baseline_performance = read_json(baseline_performance_dir / "simple_prompt_baseline/performance.json")
    baseline_models_list = baseline_models_open + baseline_models_closed

    # Text colour used for each metric's labels.
    metric_colors = {"precision": "red", "recall": "blue"}
    # Fixed x-axis range for each other task.
    x_limits = {"ero_rating": [1005, 1270], "mmlu": [50, 90]}

    for other_task in ["ero_rating", "mmlu"]:
        other_task_performance: dict[str, dict] = read_json(baseline_analysis_dir / f"other_tasks_performance/{other_task}.json")["metric"]
        other_task_name = {"ero_rating": "LMSYS Chatbot Arena Elo Rating", "mmlu": "Accuracy on MMLU [%]"}[other_task]

        # initial model -> dataset -> metric -> correlation results
        correlation_dict = {}
        for initial_model in new_datasets_initial_models:
            for dataset_name in new_datasets_names:
                # One figure per (initial model, dataset); both metrics share it.
                fig = plt.figure(figsize=[6, 3.5])
                for metric in ["recall", "precision"]:
                    points = collect_points(
                        baseline_performance, other_task_performance,
                        dataset_name, initial_model, metric, baseline_models_list,
                    )

                    # Draw a label only for models that actually contributed a
                    # point, so a model without a baseline result is never
                    # drawn on top of the previous model's coordinates.
                    for baseline_model, x_value, y_value in points:
                        plt.text(x_value, y_value, makers_dict[baseline_model], color=metric_colors[metric],
                                 fontsize=13, horizontalalignment='center', verticalalignment='center')

                    x = [x_value for _, x_value, _ in points]
                    y = [y_value for _, _, y_value in points]

                    # pearson's correlation and spearman's correlation
                    correlation_dict.setdefault(initial_model, {}).setdefault(dataset_name, {})[metric] = compute_correlations(x, y)

                # plt.text does not autoscale the axes, so set the limits explicitly.
                plt.ylim([-.05, 1.05])
                plt.xlim(x_limits[other_task])

                plt.xlabel(other_task_name)
                plt.title(f"{covnert_dataset_full_name_dict[dataset_name]}", y=1.05)

                plt.tight_layout()
                plt.savefig(output_dir / f"{other_task}_{initial_model.split('/')[-1]}_{dataset_name}.pdf")
                # Release the figure; otherwise every figure stays open until exit.
                plt.close(fig)

        dump_json(correlation_dict, output_dir / f"{other_task}_correlation.json")
