import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sb
import os
import functools
import json
import argparse

# Command-line interface:
#   --filepath   root directory that contains one sub-directory per dataset
#                (defaults to the current working directory)
#   --dataset    name of the dataset sub-directory to aggregate (required)
#   --attribute  optional attribute used to select which job results are read
parser = argparse.ArgumentParser()
parser.add_argument("--filepath", default=os.getcwd(), type=str)
parser.add_argument("--dataset", required=True, type=str)
parser.add_argument("--attribute", default=None, type=str)
args = parser.parse_args()

filepath, dataset, relevant_attribute = args.filepath, args.dataset, args.attribute

# The result CSVs of cars196 and cub200 are read with three leading index
# columns; those of every other dataset with two.
index_cols = list(range(3 if dataset in ("cars196", "cub200") else 2))

# <filepath>/<dataset>/CSV_output holds the outputs; its "jobs" sub-directory
# contains one directory per job.
PATH = os.path.join(filepath, dataset, "CSV_output")
JOBS_PATH = os.path.join(PATH, "jobs")

# Load the embedding and downstream result tables of every job and tag each
# row with the job's "method", "loss" and "batch_mining" hyper-parameters
# (appended as extra index levels). The per-job frames are gathered in lists
# and concatenated once at the end: DataFrame.append no longer exists in
# pandas >= 2.0 and re-copied the accumulated data on every iteration.
embed_frames = []
downstream_frames = []

# Sorted so the row order of the combined tables is reproducible;
# os.listdir returns entries in arbitrary order.
for job_id in sorted(os.listdir(JOBS_PATH)):
    job_path = os.path.join(JOBS_PATH, job_id)
    if not os.path.isdir(job_path):
        # A stray file next to the job directories cannot be listed.
        continue

    job_entries = os.listdir(job_path)
    if relevant_attribute in job_entries:
        # Results are stored in a sub-directory named after the attribute.
        result_dir = os.path.join(job_path, relevant_attribute)
    elif len(job_entries) < 4:
        # Fewer than the four expected files: the job is treated as incomplete.
        print("Did not open job information at path {} due to insufficient number of files.".format(job_path))
        continue
    else:
        # Results are stored directly in the job directory.
        result_dir = job_path

    with open(os.path.join(result_dir, "hparam.json"), "r") as fp:
        hparam = json.load(fp)

    # Kept under its own name so the parsed command-line `args` is not clobbered.
    with open(os.path.join(result_dir, "downstream.json"), "r") as fp:
        job_args = json.load(fp)

    # When an attribute was requested, skip jobs that were run for another one.
    if relevant_attribute and relevant_attribute != job_args["attribute"]:
        continue

    embed = pd.read_csv(os.path.join(result_dir, "embed.csv"), sep=",", header=[0, 1], index_col=index_cols)
    downstream = pd.read_csv(os.path.join(result_dir, "downstream.csv"), sep=",", header=[0, 1], index_col=index_cols)

    # Every method whose name contains "parade" is collapsed into the single
    # label "parade". A missing "method" entry stays None instead of raising
    # a TypeError on the membership test.
    method = hparam.get("method")
    tags = {
        "method": "parade" if method is not None and "parade" in method else method,
        "loss": hparam.get("loss"),
        "batch_mining": hparam.get("batch_mining"),
    }
    for tag_name, tag_value in tags.items():
        embed[tag_name] = tag_value
        downstream[tag_name] = tag_value

    tag_names = list(tags)
    embed = embed.set_index(tag_names, append=True)
    downstream = downstream.set_index(tag_names, append=True)

    embed_frames.append(embed)
    downstream_frames.append(downstream)

# pd.concat rejects an empty list, so fall back to an empty frame when no job
# was loaded (the same value the tables had before the loop originally).
embeds = pd.concat(embed_frames) if embed_frames else pd.DataFrame()
downstreams = pd.concat(downstream_frames) if downstream_frames else pd.DataFrame()

if not embed_frames:
    print("No job results were loaded from {}.".format(JOBS_PATH))

# Aggregate the collected results across jobs and export them: for each table,
# rows are grouped by every index level except "job_id", and the per-group
# mean and standard deviation are written to <PATH>/tables as CSV files whose
# names end with the selected attribute.
#
# NOTE(review): an earlier draft restricted `embeds` to a single "split"
# (preferring "test") and `downstreams` to a single "model" before
# aggregating; that filtering is disabled, so all rows are aggregated.
tables_dir = os.path.join(PATH, "tables")
os.makedirs(tables_dir, exist_ok=True)

export_plan = (
    (embeds, "embeds_means_{}.csv", "embeds_std_{}.csv"),
    (downstreams, "downstreams_means_{}.csv", "downstreams_std_{}.csv"),
)

for frame, means_pattern, std_pattern in export_plan:
    # Grouping by everything but the job identifier averages over repeated jobs.
    group_levels = [level for level in frame.index.names if level != "job_id"]
    grouped = frame.groupby(group_levels)

    group_means = grouped.mean()
    group_std = grouped.std()

    for summary, pattern in ((group_means, means_pattern), (group_std, std_pattern)):
        target = os.path.join(tables_dir, pattern.format(relevant_attribute))
        summary.to_csv(target, sep=",", header=True, index=True)
