import pandas as pd
import numpy as np
import os
import math

# Forwarded to celeba_reformat by dataset_specific_reformat; when true the
# "Skintone_6" subgroup columns are dropped before the gap is computed.
exclude = False

# Dataset key -> attribute names whose result tables are processed.
# The main loop walks this dict in insertion order.
# NOTE(review): this name shadows the builtin ``map`` for the whole module.
map = {
    "celeba": ["Skintone"],
    "cub200": ["class", "color"],
    "LFW": ["Race"],
    "cars196": ["class"],
}

# Losses whose appendix label omits the batch-mining strategy.
no_batchminer = ["arcface", "multisimilarity", "proxynca"]
# Root folder; tables live under <path>/<dataset>/CSV_output/tables.
path = "Downstream_Results/"
# Per-attribute CSV name templates ("{}" is replaced by the attribute name).
files = [
    "embeds_means_{}.csv",
    "embeds_std_{}.csv",
    "downstreams_means_{}.csv",
    "downstreams_std_{}.csv",
]

def appendix_ready(dfs):
    """Combine the representation- and downstream-space tables for the appendix.

    For each frame in *dfs* the ``loss`` and ``batch_mining`` index levels are
    merged into one ``loss/batch_mining`` level. Labels are capitalised and
    joined with "/", but losses listed in ``no_batchminer`` keep only their
    capitalised loss name. A ``model`` level filled with ``"NA"`` is appended
    when the frame has none.

    Each frame is then stacked and pivoted so that:

    * rows are ``(model, metric)``;
    * columns are the ``overall``/``gap`` values crossed with the remaining
      index levels.

    The two results are concatenated under a new ``Space`` row level
    (``"Representation"``, ``"Downstream"``).

    The input frames are left unmodified.

    Args:
        dfs: Two frames, the representation-space table first and the
            downstream table second.

    Returns:
        The combined frame with its column levels reordered and sorted for
        display.
    """
    fixed = []
    for df in dfs:
        pairs = zip(df.index.get_level_values("loss"), df.index.get_level_values("batch_mining"))
        labels = [
            loss.capitalize() if loss in no_batchminer else "/".join(s.capitalize() for s in (loss, miner))
            for loss, miner in pairs
        ]
        # Attach the label directly as an index level instead of writing a
        # temporary column into the caller's frame.
        df = df.set_index([pd.Index(labels, name="loss/batch_mining")], append=True)
        df = df.droplevel(["loss", "batch_mining"])
        if "model" not in df.index.names:
            df = df.set_index([pd.Index(["NA"] * len(df), name="model")], append=True)

        # Every index level except ``model`` becomes a column level of the pivot.
        cols_to_stack = [name for name in df.index.names if name != "model"]
        fixed.append(
            df.stack()
            .rename_axis(None, axis=1)
            .reset_index()
            .pivot_table(index=["model", "metric"], columns=cols_to_stack, values=["overall", "gap"])
        )
    df = pd.concat(fixed, keys=["Representation", "Downstream"], names=["Space"])

    if len(df.columns.get_level_values("method").unique()) == 1:
        # A single method carries no information in the table: drop the level.
        df = df.droplevel("method", axis=1)
        df.columns = df.columns.reorder_levels([None, "loss/batch_mining", "balance"])
        ascending = [False, True, False]
    elif "balance" in df.columns.names:
        df.columns = df.columns.reorder_levels([None, "loss/batch_mining", "balance", "method"])
        ascending = [False, True, True, True]
    else:
        df.columns = df.columns.reorder_levels([None, "loss/batch_mining", "method"])
        ascending = [False, True, True]

    return df.sort_index(axis=1, level=list(range(df.columns.nlevels)), ascending=ascending)

def paper_ready(dfs):
    """Combine representation- and downstream-space gap tables for the paper.

    For each frame the ``loss`` and ``batch_mining`` index levels are merged
    into a single ``loss/batch_mining`` label (each part capitalised, joined
    with "/"). The frame is then transposed so those runs become columns. The
    two frames are stacked under a new ``Space`` row level
    (``"Representation"``, ``"Downstream (LR)"``), and the ``subset`` row level
    is dropped. When more than one method is present, only the
    ``Margin/Distance`` columns are kept.

    The input frames are left unmodified.

    Args:
        dfs: Two frames, the representation-space table first and the
            downstream table second.

    Returns:
        The combined frame, columns sorted by all column levels.
    """
    fixed = []
    for df in dfs:
        labels = [
            "/".join(s.capitalize() for s in pair)
            for pair in zip(df.index.get_level_values("loss"), df.index.get_level_values("batch_mining"))
        ]
        # Build the label as a new index level rather than adding a column to
        # the caller's frame.
        df = df.set_index([pd.Index(labels, name="loss/batch_mining")], append=True)
        fixed.append(df.droplevel(["loss", "batch_mining"]).T)
    df = pd.concat(fixed, keys=["Representation", "Downstream (LR)"], names=["Space"])
    df = df.droplevel("subset")

    if len(df.columns.get_level_values("method").unique()) == 1:
        df = df.droplevel("method", axis=1)
        df.columns = df.columns.reorder_levels(["loss/batch_mining", "balance"])
        ascending = [True, False]
    else:
        # Several methods: compare them on the Margin/Distance runs only.
        df = df.loc[:, df.columns.get_level_values("loss/batch_mining") == "Margin/Distance"]
        if "balance" in df.columns.names:
            df.columns = df.columns.reorder_levels(["loss/batch_mining", "balance", "method"])
        else:
            df.columns = df.columns.reorder_levels(["loss/batch_mining", "method"])
        ascending = True

    # Sort over the *column* levels (previously the row index's level count
    # was used here, which only matched by coincidence).
    return df.sort_index(axis=1, level=list(range(df.columns.nlevels)), ascending=ascending)

def cub200_reformat(df, type, order=None):
    """Compute the CUB-200 subgroup gap for the balanced Margin/Distance runs.

    Only rows are kept whose ``balance`` level is ``True`` (bool or the string
    ``"True"``), whose ``loss`` is ``"margin"`` and whose ``batch_mining`` is
    ``"distance"``. For every metric, the subgroup columns (every ``subset``
    except ``"overall"`` and ``"gap"``) are split into the six leading
    subgroups and the rest:

    * ``type="mean"``: subgroups are ranked by the values of the first
      remaining row, in descending order. The gap is mean(leading six) minus
      mean(rest), per row, and the ranking is recorded in ``order``.
    * ``type="std"``: the ranking from ``order`` is reused. The gap is
      sqrt(mean(value**2 over leading six) + mean(value**2 over rest)).

    Args:
        df: Result table with a (subset, metric) column MultiIndex.
        type: ``"mean"`` or ``"std"``.
        order: Per-metric subset ranking returned by the ``"mean"`` call.
            Required for ``"std"``; ignored (and rebuilt) for ``"mean"``.

    Returns:
        ``(df, order)``: a filtered copy of *df* with its ``("gap", metric)``
        columns filled in, and the ranking that was used.

    Raises:
        ValueError: *type* is not ``"mean"``/``"std"``, or ``"std"`` is
            requested without an ``order``.
    """
    if type not in ("mean", "std"):
        raise ValueError("Type {} not accepted. Only 'std' and 'mean' can be used.".format(type))
    if type == "std" and order is None:
        raise ValueError("type='std' needs the `order` returned by the matching type='mean' call.")

    top_k = 6  # size of the leading group of subgroups

    balance = df.index.get_level_values("balance")
    # The balance level may hold real booleans or their string form.
    df = df.loc[np.logical_or(balance == True, balance == "True"), :]  # noqa: E712
    is_margin_distance = np.logical_and(
        df.index.get_level_values("batch_mining") == "distance",
        df.index.get_level_values("loss") == "margin",
    )
    df = df.loc[is_margin_distance, :].copy()

    metrics = df.columns.get_level_values("metric").unique().values
    subsets = df.columns.get_level_values("subset")
    subgroup_df = df.loc[:, np.logical_and(subsets != "overall", subsets != "gap")]

    gaps = {}
    if type == "mean":
        order = {}
        for metric in metrics:
            # Rank once and reuse the result for both the gap and the order.
            ranked = subgroup_df.xs(metric, axis=1, level=1, drop_level=False).sort_values(
                by=subgroup_df.index[0], axis=1, ascending=False
            )
            gaps[metric] = (ranked.iloc[:, :top_k].mean(axis=1) - ranked.iloc[:, top_k:].mean(axis=1)).values
            order[metric] = ranked.columns.get_level_values("subset").values
    else:
        for metric in metrics:
            ranked = subgroup_df.loc[:, order[metric]].xs(metric, axis=1, level=1, drop_level=False)
            gaps[metric] = (
                ranked.iloc[:, :top_k].pow(2).mean(axis=1) + ranked.iloc[:, top_k:].pow(2).mean(axis=1)
            ).pow(0.5).values

    # Assign a labelled frame so each value lands under its own metric rather
    # than relying on the positional column order of df["gap"].
    df["gap"] = pd.DataFrame(gaps, index=df.index, columns=metrics)
    return df, order

def celeba_reformat(df, type, exclude=False):
    """Compute the CelebA skin-tone gap between two halves of the subgroups.

    Subgroup columns are every ``subset`` other than ``"overall"`` and
    ``"gap"``. They are split by position into two halves: the first
    ``n // 2`` subgroups (light skin) and the remaining subgroups (dark
    skin). The split assumes the columns are laid out subset by subset, with
    all metrics of one subset adjacent — TODO confirm against the CSV layout.

    * ``type="mean"``: gap = per-metric mean of the first half minus the
      per-metric mean of the second half.
    * ``type="std"``: each half of k subgroups contributes
      (sqrt(sum of squared stds) / k)**2, and the gap is the square root of
      the two contributions' sum.

    Args:
        df: Result table with a (subset, metric) column MultiIndex.
        type: ``"mean"`` or ``"std"``.
        exclude: Drop the ``"Skintone_6"`` columns before computing the gap.

    Returns:
        A copy of *df* with its ``("gap", metric)`` columns filled in.

    Raises:
        ValueError: *type* is neither ``"mean"`` nor ``"std"``.
    """
    if exclude:
        df = df.loc[:, df.columns.get_level_values("subset") != "Skintone_6"]
    df = df.copy()

    subsets = df.columns.get_level_values("subset")
    subgroups = df.loc[:, np.logical_and(subsets != "overall", subsets != "gap")]
    n_subgroups = subgroups.columns.get_level_values("subset").nunique()
    n_metrics = subgroups.columns.get_level_values(1).nunique()
    n_light = n_subgroups // 2
    n_dark = n_subgroups - n_light
    # Column index of the split: each subgroup spans one column per metric.
    halfway_point = n_metrics * n_light
    light = subgroups.iloc[:, :halfway_point]
    dark = subgroups.iloc[:, halfway_point:]

    def by_metric(frame):
        # Group subgroup values by metric (column level 1). Transposing avoids
        # the ``level=`` reduction keyword that pandas 2.0 removed.
        return frame.T.groupby(level=1, sort=False)

    if type == "mean":
        df["gap"] = by_metric(light).mean().T - by_metric(dark).mean().T
    elif type == "std":
        # Divide by the number of subgroups actually in each half (previously a
        # fixed 3, which is wrong once exclude=True leaves an odd count).
        light_var = by_metric(light.pow(2)).sum().T.pow(0.5).div(n_light).pow(2)
        dark_var = by_metric(dark.pow(2)).sum().T.pow(0.5).div(n_dark).pow(2)
        df["gap"] = (light_var + dark_var).pow(0.5)
    else:
        raise ValueError("Type {} not accepted. Only 'std' and 'mean' can be used.".format(type))
    return df
    
def lfw_reformat(df, type):
    """Compute the LFW gap between the last subgroup and every other subgroup.

    Subgroup columns are every ``subset`` other than ``"overall"`` and
    ``"gap"``; with ``n`` subgroups there are ``n - 1`` comparisons. "Last"
    means last by column position. This assumes the columns are laid out
    subset by subset, with all metrics of a subset adjacent — TODO confirm
    against the CSV layout.

    * ``type="mean"``: per metric, the mean absolute difference between the
      last subgroup and each other subgroup. A missing (NaN) difference
      counts as 0.
    * ``type="std"``: per metric, sqrt(sum of squared subgroup values) / (n - 1).

    When the row index has a ``model`` level, only the ``"rf"`` rows are
    returned.

    Args:
        df: Result table with a (subset, metric) column MultiIndex.
        type: ``"mean"`` or ``"std"``.

    Returns:
        A copy of *df* with its ``("gap", metric)`` columns filled in.

    Raises:
        ValueError: *type* is not ``"mean"``/``"std"``, or ``"mean"`` is
            requested with fewer than two subgroups.
    """
    df = df.copy()
    subsets = df.columns.get_level_values("subset")
    subgroups = df.loc[:, np.logical_and(subsets != "overall", subsets != "gap")]
    n_subgroups = subgroups.columns.get_level_values("subset").nunique()
    n_metrics = subgroups.columns.get_level_values(1).nunique()
    n_comparisons = n_subgroups - 1

    if type == "mean":
        if n_comparisons < 1:
            raise ValueError("LFW gap needs at least two subgroups, got {}.".format(n_subgroups))
        # Per-metric groups over the rows of the transposed frame; this
        # replaces the deprecated ``groupby(axis=1)``.
        by_metric = subgroups.T.groupby(level=1, sort=False)
        gaps = None
        for period in range(1, n_subgroups):
            # Difference to the subgroup `period` positions earlier; the last
            # n_metrics columns are the last subgroup's values.
            diffs = by_metric.diff(periods=period).T.fillna(0).abs().iloc[:, -n_metrics:]
            gaps = diffs if gaps is None else gaps + diffs
        df["gap"] = gaps.droplevel("subset", axis=1) / float(n_comparisons)
    elif type == "std":
        sum_sq = subgroups.pow(2).T.groupby(level=1, sort=False).sum().T
        df["gap"] = sum_sq.pow(0.5) / float(n_comparisons)
    else:
        raise ValueError("Type {} not accepted. Only 'std' and 'mean' can be used.".format(type))

    if "model" in df.index.names:
        df = df.loc[df.index.get_level_values("model") == "rf", :]
    return df
    
def dataset_specific_reformat(df, dataset, type, order=None):
    """Route *df* to the gap computation for *dataset*.

    Args:
        df: Result table to reformat.
        dataset: ``"cub200"``, ``"celeba"`` or ``"LFW"``.
        type: ``"mean"`` or ``"std"``, forwarded to the helper.
        order: Subset ranking forwarded to ``cub200_reformat`` (ignored for
            the other datasets). The celeba helper instead receives the
            module-level ``exclude`` flag.

    Returns:
        Whatever the helper returns: a ``(df, order)`` tuple for ``"cub200"``,
        a frame for the other datasets.

    Raises:
        NotImplementedError: there is no gap computation for *dataset*
            (for example ``"cars196"``).
    """
    if dataset == "cub200":
        return cub200_reformat(df, type=type, order=order)
    elif dataset == "celeba":
        return celeba_reformat(df, type=type, exclude=exclude)
    elif dataset == "LFW":
        return lfw_reformat(df, type=type)
    else:
        raise NotImplementedError(
            "No dataset-specific gap computation for dataset {!r}.".format(dataset)
        )

def paper_reformat(df, dataset, type, order=None):
    """Reduce one result table to the gap columns used in the paper tables.

    Rows are narrowed in several steps, and the ``split``/``model`` index
    levels are dropped:

    * to a single split: the only split present, otherwise ``"test"``;
    * to a single model: the only model present, otherwise ``"lr"``;
    * if any ``"triplet"`` loss is present, to the ``"margin"`` and
      ``"triplet"`` rows only;
    * ``"npair"`` rows are always removed.

    If the table already holds a non-NaN ``gap`` it is kept. For
    ``type="std"`` the gap is recomputed per metric as
    sqrt(inflated**2 + reduced**2). Otherwise the gap is computed by
    ``dataset_specific_reformat``.

    Args:
        df: Result table with a (subset, metric) column MultiIndex.
        dataset: Dataset key, forwarded to the dataset-specific helper.
        type: ``"mean"`` or ``"std"``.
        order: Subset ranking forwarded to the dataset-specific helper
            (used by the cub200 ``"std"`` pass).

    Returns:
        ``(gap_df, order)``. *gap_df* keeps only the ``gap`` columns. *order*
        is the ranking the helper returned when it returns one, otherwise the
        *order* passed in.
    """
    if "split" in df.index.names:
        splits = np.unique(df.index.get_level_values("split"))
        keep = splits.item() if len(splits) == 1 else "test"
        df = df.loc[df.index.get_level_values("split") == keep, :].droplevel("split")
    if "model" in df.index.names:
        models = np.unique(df.index.get_level_values("model"))
        keep = models.item() if len(models) == 1 else "lr"
        df = df.loc[df.index.get_level_values("model") == keep, :].droplevel("model")

    losses = df.index.get_level_values("loss")
    if "triplet" in losses:
        df = df.loc[np.logical_or(losses == "margin", losses == "triplet"), :]
    df = df.loc[df.index.get_level_values("loss") != "npair"]

    subsets = df.columns.get_level_values("subset")
    if not np.all(np.isnan(df.loc[:, subsets == "gap"].values)):
        if type == "std":
            df = df.copy()
            squared = df[["inflated", "reduced"]].pow(2)
            # Per-metric sum via transpose + groupby: ``sum(level=...)`` was
            # removed in pandas 2.0.
            df["gap"] = squared.T.groupby(level=1, sort=False).sum().T.pow(0.5)
    else:
        result = dataset_specific_reformat(df, dataset, type=type, order=order)
        # Only the cub200 helper returns (df, order); the others return a frame.
        if isinstance(result, tuple):
            df, order = result
        else:
            df = result

    df = df.loc[:, df.columns.get_level_values("subset") == "gap"]
    return df, order
    
def appendix_reformat(df, dataset, type, order=None):
    """Prepare one result table for the appendix tables.

    Rows are narrowed to a single split (the only one present, otherwise
    ``"test"``), and the ``split`` level is dropped.

    * If the table already holds a non-NaN ``gap``, it is returned with all
      subsets. For ``type="std"`` the gap is first recomputed per metric as
      sqrt(inflated**2 + reduced**2).
    * Otherwise the gap is computed by ``dataset_specific_reformat``, and only
      the ``overall`` and ``gap`` columns are returned.

    Args:
        df: Result table with a (subset, metric) column MultiIndex.
        dataset: Dataset key, forwarded to the dataset-specific helper.
        type: ``"mean"`` or ``"std"``.
        order: Optional subset ranking forwarded to the dataset-specific
            helper (the cub200 ``"std"`` pass needs the ranking from its
            ``"mean"`` pass).

    Returns:
        The reformatted frame; the input frame is not modified.
    """
    if "split" in df.index.names:
        splits = np.unique(df.index.get_level_values("split"))
        keep = splits.item() if len(splits) == 1 else "test"
        df = df.loc[df.index.get_level_values("split") == keep, :].droplevel("split")

    subsets = df.columns.get_level_values("subset")
    if not np.all(np.isnan(df.loc[:, subsets == "gap"].values)):
        if type == "std":
            df = df.copy()
            squared = df[["inflated", "reduced"]].pow(2)
            # Per-metric sum via transpose + groupby: ``sum(level=...)`` was
            # removed in pandas 2.0.
            df["gap"] = squared.T.groupby(level=1, sort=False).sum().T.pow(0.5)
        return df

    result = dataset_specific_reformat(df, dataset, type=type, order=order)
    # The cub200 helper returns (df, order); calling .loc on that tuple
    # previously raised AttributeError.
    if isinstance(result, tuple):
        result = result[0]
    subsets = result.columns.get_level_values("subset")
    return result.loc[:, np.logical_or(subsets == "overall", subsets == "gap")]

def _read_result_tables(csv_paths, index_cols):
    """Load the four result tables listed in ``files`` for one attribute.

    Each CSV is read with two header rows (a column MultiIndex) and
    *index_cols* as index columns. The two downstream tables (positions 2 and
    3) are converted to absolute values.

    Returns:
        ``[embeds_means, embeds_std, downstreams_means, downstreams_std]``
    """
    tables = [pd.read_csv(csv_path, sep=",", header=[0, 1], index_col=index_cols) for csv_path in csv_paths]
    tables[2:] = [table.abs() for table in tables[2:]]
    return tables


for key in map:
    print(key)
    tables_path = os.path.join(path, key, "CSV_output", "tables")
    # The celeba and LFW tables have four index columns, the others five.
    index_cols = list(range(4)) if key in ["celeba", "LFW"] else list(range(5))

    for attr in map[key]:
        curr_files = [os.path.join(tables_path, file.format(attr)) for file in files]
        # Read each CSV once; every pass below works on its own copies because
        # the reformat helpers can write into the frames they are given.
        tables = _read_result_tables(curr_files, index_cols)

        # Appendix tables. NOTE: these are built but not printed below.
        embeds_means, embeds_std, downstreams_means, downstreams_std = [t.copy() for t in tables]
        embeds_means = appendix_reformat(embeds_means, key, type="mean")
        embeds_std = appendix_reformat(embeds_std, key, type="std")
        downstreams_means = appendix_reformat(downstreams_means, key, type="mean")
        downstreams_std = appendix_reformat(downstreams_std, key, type="std")
        appendix_means = appendix_ready([embeds_means, downstreams_means])
        appendix_stds = appendix_ready([embeds_std, downstreams_std])

        # Paper tables; the "std" passes reuse the ranking from the "mean" passes.
        embeds_means, embeds_std, downstreams_means, downstreams_std = [t.copy() for t in tables]
        embeds_means, order = paper_reformat(embeds_means, key, type="mean")
        embeds_std, order = paper_reformat(embeds_std, key, type="std", order=order)
        downstreams_means, order = paper_reformat(downstreams_means, key, type="mean")
        downstreams_std, order = paper_reformat(downstreams_std, key, type="std", order=order)

        means = paper_ready([embeds_means, downstreams_means])
        stds = paper_ready([embeds_std, downstreams_std])

        # Raw string: "\p" is not a valid escape sequence in a plain literal.
        formatted = "$" + means.round(3).astype(str) + r" \pm " + stds.round(3).astype(str) + "$"
        print(formatted.to_latex(multicolumn=True, multirow=True))
