import pandas as pd
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os 
import json


# Benchmark groups: each key names a group of datasets, and each value lists
# the dataset names (matched against the "Dataset" column of the result
# sheets) that belong to that group.
categories = {
    "MMLU": [
        "MMLU_General",
    ],
    "BBL-BC": [
        "Play_Dialog",
        "StrategyQA",
        "Strange_Stories",
        "Winowhy",
    ],
    "BBL-MC": [
        "Vitaminc_Fact_Verification",
        "Language_Identification",
    ],
    "BBL-QA": [
        "BBQ_Lite",
        "Code_Line_Description",
        "Logical_Deduction",
        "Known_Unknowns",
        "Hindu_Knowledge",
        "Novel_Concepts",
    ],
}

def stat_on_adversarial_csv(csv_path):
    """Print the mean ``Performance`` for every ``Type`` value in one CSV sheet.

    Rows containing any missing value are discarded before averaging. One
    line is printed per distinct ``Type`` value, in order of first
    appearance, as ``<type>:<TAB><mean with two decimals>``.

    Args:
        csv_path: Path of a CSV file that has ``Type`` and ``Performance``
            columns.
    """
    results = pd.read_csv(csv_path, index_col=None).dropna()
    for label in results["Type"].unique():
        mean_perf = results.loc[results["Type"] == label, "Performance"].mean()
        print(label + ":\t" + format(mean_perf, "0.2f"))

# 0. Utility functions
# Columns every result sheet is expected to provide; they are placed first
# (in this order) in every frame returned by the loaders below.
_RESULT_COLUMNS = [
    "Model", "Dataset", "Collection", "Type", "ID",
    "Shots", "Performance", "Metric", "Include?",
]


def _load_result_sheets(*sheet_dirs):
    """Read every result sheet found in ``sheet_dirs`` into one filtered frame.

    Directories are processed in the order given and the files inside each
    directory in sorted name order, so the resulting row order is
    reproducible (``os.listdir`` alone gives no ordering guarantee).

    Only rows whose ``Include?`` value equals ``True`` and whose
    ``Performance`` is not the placeholder string ``"PD"`` are kept.

    Args:
        *sheet_dirs: Directories containing CSV result sheets.

    Returns:
        A DataFrame with a fresh ``0..n-1`` index whose first columns are
        ``_RESULT_COLUMNS`` (filled with NaN where a sheet lacks one),
        followed by any extra columns the sheets contain. When no sheet is
        found, an empty frame with just ``_RESULT_COLUMNS``.

    Raises:
        FileNotFoundError: If one of ``sheet_dirs`` does not exist.
    """
    frames = []
    for sheet_dir in sheet_dirs:
        for sheet in sorted(os.listdir(sheet_dir)):
            # Presumably meant to skip macOS ".DS_Store" entries; note that it
            # also skips any other file whose name contains "DS".
            if "DS" in sheet:
                continue
            frames.append(
                pd.read_csv(os.path.join(sheet_dir, sheet), index_col=None, header=0)
            )

    if not frames:
        return pd.DataFrame(columns=_RESULT_COLUMNS)

    # Concatenate once instead of growing a frame sheet by sheet, then put the
    # expected columns first (reindex creates missing ones as NaN).
    df = pd.concat(frames, ignore_index=True)
    extra_columns = [col for col in df.columns if col not in _RESULT_COLUMNS]
    df = df.reindex(columns=_RESULT_COLUMNS + extra_columns)

    df = df[df["Include?"] == True]  # noqa: E712 - element-wise comparison
    df = df[df["Performance"] != "PD"]
    # reset_index returns a new frame; its result must be kept.
    return df.reset_index(drop=True)


def load_main_results():
    """Load the filtered sheets under ``./result_csv/Main`` (MMLU, then BBL)."""
    return _load_result_sheets("./result_csv/Main/MMLU", "./result_csv/Main/BBL")


def load_adv_results():
    """Load the filtered sheets under ``./result_csv/Adversial/``."""
    # The directory name is spelled "Adversial" on purpose: it must match the
    # path this script has always read from.
    return _load_result_sheets("./result_csv/Adversial/")


def load_icl_results():
    """Load the filtered sheets under ``./result_csv/Shots`` (MMLU, then BBL)."""
    return _load_result_sheets("./result_csv/Shots/MMLU", "./result_csv/Shots/BBL")


def load_scaling_results():
    """Load the filtered sheets under ``./result_csv/Scaling`` (MMLU, then BBL)."""
    return _load_result_sheets("./result_csv/Scaling/MMLU", "./result_csv/Scaling/BBL")

def _format_latex_row(accs, stds):
    """Format accuracy and std fractions as two ``" & "``-joined LaTeX rows.

    Values are printed as percentages with one decimal. Every accuracy equal
    to the largest non-NaN accuracy is wrapped in ``\\textbf{}``; NaN stds
    are rendered as ``-``.
    """
    finite_accs = [a for a in accs if not np.isnan(a)]
    # Ignore NaN when looking for the best value: max() over a list that
    # contains NaN depends on where the NaN sits.
    best = max(finite_accs) if finite_accs else None

    acc_cells = []
    for acc in accs:
        cell = "{:0.1f}".format(acc * 100)
        if best is not None and acc == best:
            cell = "\\textbf{" + cell + "}"
        acc_cells.append(cell)

    std_cells = [
        "\\pm {:0.1f}".format(s * 100) if not np.isnan(s) else "-" for s in stds
    ]
    return " & ".join(acc_cells), " & ".join(std_cells)


def print_adversarial_latex(csv_dir):
    """Print LaTeX rows of per-type mean/std performance for each CSV sheet.

    For every ``*.csv`` file in ``csv_dir`` (sorted by name) the rows with
    missing values are dropped, then the value of ``Dataset`` in the first
    remaining row is printed, followed by one row of mean ``Performance`` and
    one row of its standard deviation for the types Correct, Incorrect,
    Unobserved, Default, Negation and Random (in that order). An "Overall"
    section aggregating all files is printed last.

    Args:
        csv_dir: Directory containing CSV sheets with ``Dataset``, ``Type``
            and ``Performance`` columns.

    Raises:
        IndexError: If a sheet has no rows left after dropping missing values.
    """
    type_names = ["Correct", "Incorrect", "Unobserved", "Default", "Negation", "Random"]
    csv_files = [
        os.path.join(csv_dir, name)
        for name in sorted(os.listdir(csv_dir))
        if name.endswith(".csv")
    ]
    overall_acc = {name: [] for name in type_names}
    overall_std = {name: [] for name in type_names}

    for path in csv_files:
        df = pd.read_csv(path, index_col=None).dropna()
        # Positional access: after dropna() the label 0 may no longer exist.
        print(df.iloc[0]["Dataset"])

        accs = []
        stds = []
        for name in type_names:
            performance = df.loc[df["Type"] == name, "Performance"]
            acc = performance.mean()
            std = performance.std()
            accs.append(acc)
            stds.append(std)
            overall_acc[name].append(acc)
            overall_std[name].append(std)

        acc_row, std_row = _format_latex_row(accs, stds)
        print(acc_row)
        print(std_row)
        print()

    print("Overall")
    overall_accs = [np.mean(overall_acc[name]) for name in type_names]
    # NOTE(review): this is the standard deviation of the per-file standard
    # deviations, not their mean nor the spread of the per-file accuracies.
    # Kept as-is; confirm that this is the intended statistic.
    overall_stds = [np.std(overall_std[name]) for name in type_names]

    acc_row, std_row = _format_latex_row(overall_accs, overall_stds)
    print(acc_row)
    print(std_row)

    
def _main_split_stats(dataset_df):
    """Return mean/std of ``Performance`` for observed vs. unobserved rows.

    Rows whose ``Type`` is "Unobserved" or "Default" form the unobserved
    split; all other rows form the observed split. Performance values that
    cannot be parsed as numbers are ignored.
    """
    # errors="coerce" turns unparsable entries into NaN so they can be
    # dropped; without it to_numeric raises on the first such entry.
    performance = pd.to_numeric(dataset_df["Performance"], errors="coerce")
    is_unobserved = dataset_df["Type"].isin(["Unobserved", "Default"])
    observed = performance[~is_unobserved].dropna()
    unobserved = performance[is_unobserved].dropna()
    return {
        "observed_acc": observed.mean(),
        "observed_std": observed.std(),
        "unobserved_acc": unobserved.mean(),
        "unobserved_std": unobserved.std(),
    }


def _main_summary_cells(stats_list):
    """Average per-dataset stats (as percentages) and format them as cells.

    An average that is not strictly positive (including NaN, or 0 when
    ``stats_list`` is empty) is rendered as ``-``.
    """
    def mean_percent(key):
        values = [stats[key] for stats in stats_list]
        return sum(values) * 100 / len(values) if values else 0

    observed_acc = mean_percent("observed_acc")
    observed_std = mean_percent("observed_std")
    unobserved_acc = mean_percent("unobserved_acc")
    unobserved_std = mean_percent("unobserved_std")
    delta = observed_acc - unobserved_acc

    return {
        "observed_acc": "{:.1f}".format(observed_acc) if observed_acc > 0 else "-",
        "observed_std": "{:.2f}".format(observed_std) if observed_std > 0 else "-",
        "unobserved_acc": "{:.1f}".format(unobserved_acc) if unobserved_acc > 0 else "-",
        "unobserved_std": "{:.2f}".format(unobserved_std) if unobserved_std > 0 else "-",
        # NOTE(review): the placeholder here is "_" while the other cells use
        # "-"; kept unchanged, confirm whether that is intentional.
        "delta": "{:.2f}".format(delta) if delta != 0 else "_",
    }


def print_main_latex():
    """Print the main results table: one Obs/Uno/Del row triple per model.

    The frame returned by ``load_main_results()`` is split per model and
    dataset into observed / unobserved statistics, which are then averaged
    over the datasets of every group in the module-level ``categories``
    mapping. For each model three lines are printed, covering the groups
    MMLU, BBL-QA, BBL-BC and BBL-MC in that order:

    * ``Obs:``  observed accuracy and std for each group,
    * ``Uno:``  unobserved accuracy and std for each group,
    * ``Del:``  observed minus unobserved accuracy, each followed by ``&``.
    """
    df = load_main_results()

    # data[model][dataset] -> raw (fractional) observed/unobserved statistics.
    data = {}
    for model in df["Model"].unique():
        model_df = df[df["Model"] == model]
        data[model] = {}
        for dataset in model_df["Dataset"].unique():
            data[model][dataset] = _main_split_stats(
                model_df[model_df["Dataset"] == dataset]
            )

    # overall_data[model][category] -> formatted table cells.
    overall_data = {}
    for model, per_dataset in data.items():
        overall_data[model] = {}
        for category, members in categories.items():
            member_stats = [
                stats for dataset, stats in per_dataset.items() if dataset in members
            ]
            overall_data[model][category] = _main_summary_cells(member_stats)

    # Print Main Table Results
    column_order = ["MMLU", "BBL-QA", "BBL-BC", "BBL-MC"]
    print("\t", "MMLU", "\t", "BBL-QA", "\t", "BBL-BC", "\t", "BBL-MC")
    for model, cells in overall_data.items():
        for label, split in (("Obs:", "observed"), ("Uno:", "unobserved")):
            row = []
            for column in column_order:
                row.append(cells[column][split + "_acc"])
                row.append(cells[column][split + "_std"])
            print(model, label, *row)

        delta_row = []
        for column in column_order:
            delta_row.append(cells[column]["delta"])
            delta_row.append("&")
        print(model, "Del:", *delta_row)

def _disaggregated_split_stats(dataset_df):
    """Return mean/std of ``Performance`` for observed vs. unobserved rows.

    Rows whose ``Type`` is "Unobserved" or "Default" form the unobserved
    split; all other rows form the observed split. Performance values that
    cannot be parsed as numbers are ignored.
    """
    # errors="coerce" turns unparsable entries into NaN so they can be
    # dropped; without it to_numeric raises on the first such entry.
    performance = pd.to_numeric(dataset_df["Performance"], errors="coerce")
    is_unobserved = dataset_df["Type"].isin(["Unobserved", "Default"])
    observed = performance[~is_unobserved].dropna()
    unobserved = performance[is_unobserved].dropna()
    return {
        "observed_acc": observed.mean(),
        "observed_std": observed.std(),
        "unobserved_acc": unobserved.mean(),
        "unobserved_std": unobserved.std(),
    }


def _disaggregated_cell(acc, std):
    """Format one ``mean ($\\pm$ std)`` cell, both as one-decimal percentages."""
    return "{:.1f}".format(acc * 100) + " ($\\pm$ " + "{:.1f})".format(std * 100)


def print_disaggregated_latex(df: pd.DataFrame, model_orders: list):
    """Print per-dataset LaTeX rows comparing observed and unobserved results.

    For every dataset in ``df`` (in order of first appearance) a block is
    printed containing the dataset name, one ``" & "``-joined row of observed
    results and one row of unobserved results, with one cell per model of
    ``model_orders``. Within a model, the strictly larger of the two means is
    wrapped in ``\\textbf{}``. A model without any row for the dataset is
    rendered as ``-`` in both rows.

    Args:
        df: Results with ``Model``, ``Dataset``, ``Type`` and ``Performance``
            columns.
        model_orders: Model names, in the column order of the table.
    """
    # data[model][dataset] -> raw (fractional) observed/unobserved statistics.
    data = {}
    for model in df["Model"].unique():
        model_df = df[df["Model"] == model]
        data[model] = {}
        for dataset in model_df["Dataset"].unique():
            data[model][dataset] = _disaggregated_split_stats(
                model_df[model_df["Dataset"] == dataset]
            )

    for dataset in df["Dataset"].unique():
        print("--------------------")
        print(dataset)
        print()
        observed_results = []
        unobserved_results = []
        for model in model_orders:
            stats = data.get(model, {}).get(dataset)
            if stats is None:
                # Keep the table columns aligned when a model was not
                # evaluated on this dataset.
                observed_results.append("-")
                unobserved_results.append("-")
                continue

            ob_result = stats["observed_acc"] * 100
            unob_result = stats["unobserved_acc"] * 100
            ob_str = _disaggregated_cell(stats["observed_acc"], stats["observed_std"])
            unob_str = _disaggregated_cell(stats["unobserved_acc"], stats["unobserved_std"])
            # Comparisons with NaN are False, so nothing is bolded when one
            # of the two splits is empty.
            if ob_result > unob_result:
                ob_str = "\\textbf{" + ob_str + "}"
            if unob_result > ob_result:
                unob_str = "\\textbf{" + unob_str + "}"
            observed_results.append(ob_str)
            unobserved_results.append(unob_str)

        print(" & ".join(observed_results))
        print()
        print(" & ".join(unobserved_results))
        print("--------------------")

def print_main_disaggregated_latex():
    """Print the disaggregated LaTeX rows for the main results."""
    model_order = ["Flan-T5-XL", "Flan-T5-XXL", "T0++", "Alpaca-7B", "Alpaca-13B"]
    print_disaggregated_latex(load_main_results(), model_order)

def print_icl_disaggregated_latex():
    """Print the disaggregated LaTeX rows for the results under ``Shots``."""
    model_order = [
        "Flan-T5-Small",
        "Flan-T5-Base",
        "Flan-T5-Large",
        "Flan-T5-XL",
        "Flan-T5-XXL",
    ]
    print_disaggregated_latex(load_icl_results(), model_order)

def print_scaling_latex():
    """Print the disaggregated LaTeX rows for the scaling results."""
    model_order = [
        "Flan-T5-Small",
        "Flan-T5-Base",
        "Flan-T5-Large",
        "Flan-T5-XL",
        "Flan-T5-XXL",
    ]
    print_disaggregated_latex(load_scaling_results(), model_order)




    
    


    


    
