import matplotlib.pyplot as plt
# import colors
from matplotlib import cm
import numpy as np
from scipy import interpolate
from scipy.spatial import ConvexHull
from scipy.ndimage import gaussian_filter1d
import pandas as pd
import numpy as np
import json
import os
from loguru import logger

# Mapping from model identifiers (as stored in the pipeline configs) to
# human-readable display names.
renames = {
    "gpt-3.5-turbo": "chatgpt",
    "llama-7b": "Llama 7B",
    "llama-13b": "Llama 13B",
    "vicuna-7b": "Vicuna 7B",
    # Lowercase key, consistent with every other identifier in this mapping.
    "vicuna-13b": "Vicuna 13B",
    # Original mixed-case key, kept so existing lookups keep working.
    "vicuna-13B": "Vicuna 13B",
    "tiiuae/falcon-7b": "Falcon 7B",
    "tiiuae/falcon-7b-instruct": "Falcon-Instruct 7B",
    "davinci": "GPT 175B",
    "text-davinci-001": "InstructGPT 175B",
}

def get_few_shot_number(prompts):
    prompt = prompts[0]["prompt"]
    if not isinstance(prompt, str):
        prompt = prompt["prompt"]
    is_in_all = 0
    if "1." in prompt:
        if "7." in prompt:
            return 8
        
        return 3
    return 0


def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, "r") as handle:
        return json.load(handle)


def get_results(folder, include_std=False):
    """Load the metrics of one generation run into a DataFrame.

    The run directory is expected to contain:

    * ``results/dict_real_metrics.json``: metrics of the real data;
    * a ``results/list_dict_metrics*`` file: a JSON list with one metrics
      dict per generated dataset;
    * ``pipelines/pipeline_0.json`` ... ``pipeline_{n-1}.json``.

    Args:
        folder: Path of the run directory.
        include_std: If True, also return a ``<metric>_std`` column for every
            selected metric. The source column ``<source>_std`` must then
            exist in the metrics; otherwise a ``KeyError`` is raised.

    Returns:
        ``None`` if the real-data metrics file or the list of generated
        metrics is missing. Otherwise a DataFrame with one row per generated
        dataset plus a final row for the real data.

        Metric columns (``complexity``, ``diversity``, ``performance``,
        ``faithfulness``, ``conformity``) are included only when present in
        the metrics. ``complexity`` is ``1 - accuracy`` of the fake-to-fake
        finetune.

        The metadata columns ``temperature``, ``model``, ``top_p``,
        ``few_shot`` and ``diverse`` describe each generator. The real-data
        row gets ``1.0``, ``"real"``, ``1.0``, ``0`` and ``False``.
    """
    results_dir = os.path.join(folder, "results")
    real_metrics_path = os.path.join(results_dir, "dict_real_metrics.json")
    if not os.path.exists(real_metrics_path):
        return None

    # os.listdir order is arbitrary; sort so the chosen file is deterministic.
    metrics_file = next(
        (
            name
            for name in sorted(os.listdir(results_dir))
            if name.startswith("list_dict_metrics")
        ),
        None,
    )
    if metrics_file is None:
        # Without the per-pipeline metrics there is nothing to report.
        return None
    rows = _load_json(os.path.join(results_dir, metrics_file))
    rows.append(_load_json(real_metrics_path))

    pipelines_dir = os.path.join(folder, "pipelines")
    # NOTE(review): assumes pipelines/ holds only pipeline_<i>.json files.
    pipelines = [
        _load_json(os.path.join(pipelines_dir, f"pipeline_{index}.json"))
        for index in range(len(os.listdir(pipelines_dir)))
    ]

    dataframe = pd.DataFrame(rows)
    metric_sources = {
        "complexity": "finetune_temporal_fake_to_fake.accuracy",
        "diversity": "average_distinctness_spacy.1",
        "performance": "finetune_temporal_fake_to_real.accuracy",
        "faithfulness": "finetune_temporal_real_to_fake.accuracy",
        "conformity": "mauve",
    }
    # Keep only the metrics this run actually computed.
    columns = {
        name: source
        for name, source in metric_sources.items()
        if source in dataframe.columns
    }
    if include_std:
        for name in list(columns):
            columns[name + "_std"] = columns[name] + "_std"

    # For the "Generator" class the querier/prompt config sits directly on
    # the generator; for any other class the first converter is used.
    generators = [
        pipeline["generator"]
        if pipeline["generator"]["class"] == "Generator"
        else pipeline["generator"]["converters"][0]
        for pipeline in pipelines
    ]

    table = dataframe[list(columns.values())].reset_index(drop=True)
    table.columns = list(columns.keys())

    # Every row but the last comes from generated data; the last is real data.
    # NOTE(review): row i is assumed to belong to pipeline_i.json.
    n_generated = len(table) - 1
    used_generators = generators[:n_generated]
    classes = list(pipelines[0]["generator"]["prompts"]["prompts"].keys())
    table["temperature"] = [g["querier"]["temperature"] for g in used_generators] + [1.0]
    table["model"] = [g["querier"]["model"] for g in used_generators] + ["real"]
    table["top_p"] = [g["querier"].get("top_p", None) for g in used_generators] + [1.0]
    table["few_shot"] = [
        get_few_shot_number(g["prompts"]["prompts"][classes[0]])
        for g in used_generators
    ] + [0]
    table["diverse"] = [
        pipeline["generator"]["class"] != "Generator"
        for pipeline in pipelines[:n_generated]
    ] + [False]

    # The column only exists when the fake-to-fake accuracy was computed
    # (see the filter above), so guard instead of raising KeyError.
    if "complexity" in table.columns:
        table["complexity"] = 1 - table["complexity"]

    return table

def get_results_eval(eval_folder, include_std=False):
    """Recursively collect the per-run tables produced by ``get_results``.

    A sub-directory that contains a ``pipelines`` directory is treated as a
    single run and loaded with ``get_results``. Any other sub-directory is
    searched recursively, and non-directory entries are skipped. A run that
    fails to load is logged and skipped, so one broken run does not abort
    the whole aggregation.

    Args:
        eval_folder: Root directory to search.
        include_std: Forwarded to ``get_results``.

    Returns:
        One DataFrame with the rows of every run found, in sorted directory
        order. It is empty if no run was found.
    """
    frames = []
    for entry in sorted(os.listdir(eval_folder)):
        path = os.path.join(eval_folder, entry)
        if not os.path.isdir(path):
            continue
        try:
            if os.path.exists(os.path.join(path, "pipelines")):
                new_results = get_results(path, include_std=include_std)
            else:
                new_results = get_results_eval(path, include_std=include_std)
        except Exception as exc:  # best-effort: skip broken runs, but report them
            logger.warning("Skipping {}: {!r}", path, exc)
            continue
        if new_results is not None and not new_results.empty:
            frames.append(new_results)

    # Concatenate once at the end instead of growing a frame inside the loop.
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)



if __name__ == "__main__":
    # Aggregate the meta-evaluation runs of every dataset into one
    # tab-separated file per dataset under ../processed/.
    datasets = ["AGNews", "eli5", "goemotions", "SST"]
    os.makedirs("../processed", exist_ok=True)
    for dataset in datasets:
        results = get_results_eval(f"../../data/generations/{dataset}/meta_eval", include_std=True)
        if results.empty:
            logger.warning("No results found for dataset {}", dataset)
        # Missing values (e.g. a querier without top_p, or a metric absent
        # from some runs) are written as 1.
        results.fillna(1, inplace=True)
        results.to_csv(f"../processed/{dataset}.csv", sep="\t", index=False)