import json

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pandas import DataFrame


def load_json(file_path):
    """Parse the JSON document stored at *file_path* and return the result.

    The file is always decoded as UTF-8 rather than with the platform's
    locale encoding, so the same log file parses identically on every OS.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)

def get_training_logs(entries):
    """Return the entries whose ``Details.Stage.!Value`` equals "Training".

    Entries with no "Details", or whose details have no "Stage", are skipped.
    The relative order of the kept entries is preserved.
    """
    selected = []
    for entry in entries:
        details = entry.get("Details")
        # A missing "Details" means there is no stage to inspect.
        stage = None if details is None else details.get("Stage")
        if stage is not None and stage.get("!Value") == "Training":
            selected.append(entry)
    return selected

def get_evaluation_logs(entries):
    """Return the entries whose ``Details.Stage.!Value`` equals "Evaluation".

    Entries with no "Details", or whose details have no "Stage", are skipped.
    The relative order of the kept entries is preserved.
    """
    selected = []
    for entry in entries:
        details = entry.get("Details")
        # A missing "Details" means there is no stage to inspect.
        stage = None if details is None else details.get("Stage")
        if stage is not None and stage.get("!Value") == "Evaluation":
            selected.append(entry)
    return selected

def get_details(entries):
    """Return the "Details" value of every entry that has a non-None one."""
    candidates = (entry.get("Details") for entry in entries)
    return [details for details in candidates if details is not None]
    
def get_time_consumption(entries):
    """Return ``ElapsedMilliseconds.!Value`` of each entry's details as floats.

    Entries without details are ignored (see ``get_details``); a details
    record lacking the "ElapsedMilliseconds" or "!Value" key raises KeyError.
    """
    durations = []
    for details in get_details(entries):
        raw_value = details["ElapsedMilliseconds"]["!Value"]
        durations.append(float(raw_value))
    return durations

def _collect_time_consumption(logs_by_model, stage_filter, max_ms):
    """Build a (Model, Time Consumption) frame for one stage of the logs.

    Args:
        logs_by_model: mapping of model display name to its parsed log entries.
        stage_filter: callable that keeps only the entries of the wanted
            stage (``get_training_logs`` or ``get_evaluation_logs``).
        max_ms: samples strictly greater than this many milliseconds are
            dropped so that outliers do not stretch the plot.
    """
    rows = []
    for model, entries in logs_by_model.items():
        for elapsed in get_time_consumption(stage_filter(entries)):
            # Phrased as "not greater" to mirror the original cut-off exactly.
            if not elapsed > max_ms:
                rows.append([model, elapsed])
    return DataFrame(rows, columns=["Model", "Time Consumption"])


def _plot_time_consumption(frame, tick_stop, output_path):
    """Draw a horizontal box plot of *frame* and save it to *output_path*.

    X-axis ticks are placed every 2500 ms from 0 up to, but not including,
    *tick_stop*.
    """
    plt.figure(figsize=(9, 8), dpi=150)
    sns.set_theme(style="whitegrid", context="paper")
    sns.boxplot(x="Time Consumption", y="Model", data=frame)
    plt.subplots_adjust(left=0.12, right=0.97, top=0.97, bottom=0.07)
    plt.xticks(np.arange(0, tick_stop, step=2500))
    plt.xlabel("Time Consumption (ms)")
    plt.savefig(output_path)


if __name__ == "__main__":
    # Display name -> base name of the log file under ``logs/``.
    model_maps = {
        "Qwen2.5-0.5B": "Titanic-Qwen2.5-0.5b-Context_40-46.8860%",
        "Llama3.2-1B": "Titanic-Llama3.2-1b-Context_40-50.9988%",
        "Qwen2.5-1.5B": "Titanic-Qwen2.5-1.5b-Context_40-52.2914%",
        "Llama3.2-3B": "Titanic-Llama3.2-3b-Context_40-55.9342%",
        "Qwen2.5-3B": "Titanic-Qwen2.5-3b-Context_40-69.9177%",
        "Mistral-7B": "Titanic-Mistral-7b-Context_40-66.2750%",
        "Qwen2.5-7B": "Titanic-Qwen2.5-7b-Context_40-71.3278%",
        "Llama3.1-8B": "Titanic-Llama3.1-8b-Context_40-44.6533%",
        "Gemma2-9B": "Titanic-Gemma2-9b-Context_40-68.9777%",
        "Llama3.2-11B": "Titanic-Llama3.2-11b-Context_40-43.1257%",
        "Qwen2.5-14B": "Titanic-Qwen2.5-14b-Context_40-81.3161%",
        "Qwen2.5-32B": "Titanic-Qwen2.5-32b-Context_40-80.0235%",
        "Llama3.1-70B": "Titanic-Llama3.1-70b-Context_40-76.1457%",
        "Qwen2.5-72B": "Titanic-Qwen2.5-72b-Context_40-78.0259%",
        "GPT-3.5-Turbo": "Titanic-GPT-3.5-Turbo-Context_40-77.5558%",
        "GPT-4-Turbo": "Titanic-GPT-4-Turbo-Context_40-82.0212%",
        "GPT-4o-Mini": "Titanic-GPT-4o-mini-Context_40-79.9060%",
        "GPT-4o": "Titanic-GPT-4o-Context_40-81.0811%",
    }

    # Parse every log file once; both plots are derived from the same entries.
    logs_by_model = {
        model: load_json(f"logs/{file}.json")
        for model, file in model_maps.items()
    }

    # Training stage: keep samples up to 45000 ms.
    training = _collect_time_consumption(logs_by_model, get_training_logs, 45000)
    _plot_time_consumption(
        training, 45000, "Evaluation - Time Consumption - Training.png"
    )

    # Evaluation stage: keep samples up to 17500 ms; ticks still stop before 22500.
    evaluation = _collect_time_consumption(logs_by_model, get_evaluation_logs, 17500)
    _plot_time_consumption(
        evaluation, 22500, "Evaluation - Time Consumption - Evaluation.png"
    )

    plt.show()