from collections import defaultdict
from unittest import result
import pandas as pd
import os
import sys
import argparse
from parsers.config import get_config, get_constraint_config
import ast
import re
import yaml

##
## Argument Parser
##
parser = argparse.ArgumentParser(
    description="Collect sampling-log metrics into all_results.csv / all_results.tex."
)

# Both options are mandatory: they are interpolated into config names and
# output paths below, so a missing value would otherwise become the string
# "None" and fail much later with an unrelated error.
parser.add_argument("--dataset", type=str, required=True,
                    help="dataset name (selects config 'sample_<dataset>')")
parser.add_argument("--constraint", type=str, required=True,
                    help="constraint name (selects the constraint config)")

args = parser.parse_args(sys.argv[1:])

# Resolve configuration names.
# - The dataset config is always "sample_<dataset>".
# - The constraints listed here have one config per dataset,
#   named "<constraint>_<dataset>".
# - Every other constraint uses a single config named after the constraint.
dataset_config = f"sample_{args.dataset}"
per_dataset_constraint = args.constraint in ("valency", "atomCount", "molweight", "regression")
constraint_config = (
    f"{args.constraint}_{args.dataset}" if per_dataset_constraint else f"{args.constraint}"
)

config = get_config(dataset_config, seed=42)
constr_config = get_constraint_config(constraint_config)

# Directory containing the test-split sampling logs for this dataset/constraint.
log_dir = f"logs_sample/{config.data.data}/test/{constr_config.constraint}/"
# Log file names are hyphen-separated fields, parsed as:
#   <prefix>-<method_op>-<method_gamma>-<schedule>-<burnin>-<add_diff_steps>-<params>-<rounding>.log
# with "p" standing in for "." inside numeric fields.

# Column name -> list of per-run values; the order here is the column order
# of the final CSV/LaTeX tables.
_RESULT_COLUMNS = (
    "method_op", "method_gamma", "add_diff_steps", "params",
    "rounding", "constraint_val", "time", "degree", "cluster",
    "orbit", "spectral", "constraint_val_preround", "burnin",
    "schedule", "schedule_params",
)
results_df = {column: [] for column in _RESULT_COLUMNS}

# Parse every ".log" file in log_dir.
# - The file NAME encodes the run configuration: hyphen-separated fields, where
#   "p" stands in for "." and "," separates multi-valued params (stored joined
#   by "|").
# - The file BODY holds the evaluation metrics.
# A run is added to results_df only when every column could be filled.
# Files are visited in sorted order so the output row order is reproducible
# (os.listdir returns entries in arbitrary order).
for log_file in sorted(f for f in os.listdir(log_dir) if f.endswith(".log")):
    log_df = {res_key: None for res_key in results_df}
    try:
        attrs = log_file[:-4].split("-")
        # attrs[0] (the name prefix) is not used.
        method_op, method_gamma, schedule, burnin, add_diff_steps, params, rounding = attrs[1:]
        if method_op == "None":
            # No guidance operator: recorded as the "proj" baseline with gamma 0.
            log_df["method_op"] = "proj"
            log_df["method_gamma"] = 0.0
        else:
            log_df["method_op"] = method_op
            log_df["method_gamma"] = float(method_gamma.replace("p", "."))
        log_df["add_diff_steps"] = int(add_diff_steps)
        log_df["params"] = params.replace("p", ".").replace(",", "|")
        log_df["rounding"] = rounding
        log_df["burnin"] = int(burnin)
        # The schedule token is a name followed directly by its numeric
        # parameters; split it at the first digit ("nan" means no schedule).
        first_digit = re.search(r'\d', schedule)
        split_id = first_digit.start() if first_digit is not None else len(schedule)
        sch_name, sch_params = schedule[:split_id], schedule[split_id:]
        log_df["schedule"] = sch_name if sch_name != 'nan' else 'none'
        log_df["schedule_params"] = sch_params.replace("p", ".").replace(",", "|")
    except ValueError as err:
        # Wrong number of name fields, or a non-numeric numeric field. The row
        # could never be complete, so skip the file instead of reusing values
        # (e.g. method_op) left over from the previous iteration.
        print(log_dir, log_file, "has errors.", err)
        continue

    with open(os.path.join(log_dir, log_file), 'r') as log_f:
        for line in log_f:
            # Strip only the line terminator. line[:-1] would eat a real
            # character on a last line without "\n", and would leave "\r"
            # behind on CRLF files.
            line = line.rstrip("\r\n")
            if line.startswith("Constraint Validity:"):
                log_df["constraint_val"] = float(line.split(": ")[1])
            elif line.startswith("Constraint Validity before round:"):
                log_df["constraint_val_preround"] = float(line.split(": ")[1])
            elif line.startswith("Round"):
                # The value ends with a one-character suffix (presumably a
                # time unit such as "s"), which is dropped before parsing.
                log_df["time"] = float(line.split(": ")[1][:-1])
            elif line.startswith("MMD_full"):
                # Rest of the line is a Python dict literal of metric -> value
                # (presumably degree/cluster/orbit/spectral).
                mmd = ast.literal_eval(line.split(" ", 1)[1])
                for mmd_key, mmd_val in mmd.items():
                    log_df[mmd_key] = mmd_val

    if any(res_log is None for res_log in log_df.values()):
        print(log_dir, log_file, "is not complete.")
        continue

    # A run without a guidance operator is replicated once per rounding mode
    # (presumably so the baseline appears next to every rounding variant);
    # all other runs are recorded once with their own rounding value.
    if method_op == "None":
        roundings = ["none", "randomized", "repeated"]
    else:
        roundings = [log_df["rounding"]]
    for rounding in roundings:
        log_df["rounding"] = rounding
        for res_key in results_df:
            results_df[res_key].append(log_df[res_key])

results_df = pd.DataFrame(results_df)
# Create the output directory if needed. exist_ok only tolerates an existing
# directory; real failures (e.g. permissions) still raise here, instead of
# being silently swallowed and resurfacing later at to_csv.
os.makedirs(f"all_results/{args.dataset}/{args.constraint}", exist_ok=True)

# Optional schedule filters (disabled):
# results_df = results_df.loc[results_df["schedule"].isin(['poly'])]
# results_df = results_df.loc[results_df['schedule'] == 'fixed']

# Keep only runs whose constraint params are listed for this dataset in
# config/constraints/master_<constraint>.yaml. The file is read inside a
# `with` block so the handle is closed.
with open(f"config/constraints/master_{args.constraint}.yaml", 'r') as master_f:
    config_params = yaml.load(master_f, Loader=yaml.FullLoader)[f"{args.dataset}"]

if args.constraint == 'cheeger':
    # Cheeger params are multi-valued ("a|b|..."). They are compared as float
    # lists against the master entries, which are presumably lists of numbers.
    results_df = results_df.loc[results_df['params'].apply(
        lambda x: list(map(float, x.split('|'))) in config_params)]
else:
    results_df = results_df.loc[results_df['params'].isin(config_params)]

out_dir = f"all_results/{args.dataset}/{args.constraint}"
results_df.to_csv(f"{out_dir}/all_results.csv", index=False)
with open(f"{out_dir}/all_results.tex", "w") as wf:
    wf.write(results_df.to_latex(index=False))
