from collections import defaultdict
import pandas as pd
import os
import sys
import argparse
from parsers.config import get_config, get_constraint_config
import ast
import re

##
## Argument Parser 
##
# Command-line interface: the dataset and the constraint whose logs are collected.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str)
parser.add_argument("--constraint", type=str)

# Without an explicit argument list, argparse reads sys.argv[1:].
args = parser.parse_args()

dataset_config = f"sample_{args.dataset}"

# These constraints use a config name suffixed with the dataset name;
# every other constraint uses its own name as the config name.
constraint_config = (
    f"{args.constraint}_{args.dataset}"
    if args.constraint in ("valency", "atomCount", "molWeight", "regression")
    else f"{args.constraint}"
)

config = get_config(dataset_config, seed=42)
constr_config = get_constraint_config(constraint_config)

# Directory that is scanned for ".log" files of this dataset/constraint pair.
log_dir = f"logs_sample/{config.data.data}/test/{constr_config.constraint}/"

# NOTE(review): judging by the parsing loop below, log file names look like
#   <prefix>-<method_op>-<method_gamma>-<schedule>-<burnin>-<add_diff_steps>-<params>-<rounding>.log
# with "." written as "p" -- confirm against the script that writes the logs.

# Accumulator for the results table: one (initially empty) list per column.
# The tuple order is the column order of the final table. A "time" column is
# not collected.
results_df = {
    column: []
    for column in (
        "method_op",
        "method_gamma",
        "add_diff_steps",
        "params",
        "rounding",
        "constraint_val",
        "constraint_val_preround",
        "validity_wo_corr",
        "valid",
        "unique",
        "fcd_test",
        "novelty",
        "nspdk_mmd",
        "burnin",
        "schedule",
        "schedule_params",
    )
}

# Prefix of a metric line in a log file -> results column it fills.
# The prefixes are mutually exclusive (each ends in ":" or "@"), so the
# lookup order does not matter.
_METRIC_PREFIXES = {
    "Constraint Validity:": "constraint_val",
    "Constraint Validity before round:": "constraint_val_preround",
    "validity w/o correction:": "validity_wo_corr",
    "valid:": "valid",
    "unique@": "unique",
    "FCD/Test:": "fcd_test",
    "Novelty:": "novelty",
    "NSPDK MMD:": "nspdk_mmd",
}

for log_file in os.listdir(log_dir):
    if not log_file.endswith(".log"):
        continue

    # Every column starts as None; a row is only kept once all are filled.
    log_df = dict.fromkeys(results_df)
    # True when the file name carries method_op == "None" (recorded as "proj").
    is_unconstrained = False

    # --- Hyper-parameters encoded in the file name -------------------------
    try:
        attrs = log_file[:-4].split("-")  # drop ".log", split the fields
        # The first field is ignored; the remaining seven must unpack exactly.
        (method_op, method_gamma, schedule, burnin,
         add_diff_steps, params, rounding) = attrs[1:]
        is_unconstrained = method_op == "None"
        if is_unconstrained:
            log_df["method_op"] = "proj"
            log_df["method_gamma"] = 0.0
        else:
            log_df["method_op"] = method_op
            # "p" stands in for the decimal point in file names.
            log_df["method_gamma"] = float(method_gamma.replace("p", "."))
        log_df["add_diff_steps"] = int(add_diff_steps)
        log_df["params"] = params.replace("p", ".")
        log_df["rounding"] = rounding
        log_df["burnin"] = int(burnin)
        # The schedule field is "<name><params>"; the params start at the
        # first digit.
        first_digit = re.search(r"\d", schedule)
        if first_digit is None:
            raise ValueError(f"no schedule parameters in {schedule!r}")
        split_id = first_digit.start()
        log_df["schedule"] = schedule[:split_id]
        log_df["schedule_params"] = (
            schedule[split_id:].replace("p", ".").replace(",", "|")
        )
    except ValueError:
        # Wrong number of fields or a non-numeric value. Some columns stay
        # None, so the file is reported as incomplete below and skipped.
        print(log_dir, log_file, "has errors.")

    # --- Metrics reported inside the log file ------------------------------
    with open(os.path.join(log_dir, log_file), "r") as log_f:
        for line in log_f:
            # Strip only the newline: slicing off the last character would
            # truncate the value on a final line without a trailing newline.
            line = line.rstrip("\n")
            for prefix, column in _METRIC_PREFIXES.items():
                if line.startswith(prefix):
                    log_df[column] = float(line.split(": ")[1])
                    break

    if any(value is None for value in log_df.values()):
        print(log_dir, log_file, "is not complete.")
        continue

    # A run with method_op "None" is recorded once per rounding mode;
    # any other run is recorded once with its own rounding.
    if is_unconstrained:
        rounding_modes = ["none", "randomized", "repeated"]
    else:
        rounding_modes = [log_df["rounding"]]
    for rounding_mode in rounding_modes:
        log_df["rounding"] = rounding_mode
        for res_key, column_values in results_df.items():
            column_values.append(log_df[res_key])

# Turn the per-column lists into a table.
results_df = pd.DataFrame(results_df)

# Create the output directory if needed. exist_ok=True tolerates an existing
# directory while still surfacing real failures (e.g. permission errors)
# instead of hiding them until the first write.
out_dir = f"all_results/{args.dataset}/{args.constraint}"
os.makedirs(out_dir, exist_ok=True)

# Compact the "params" column into a constraint-specific form.
if args.constraint in ('valency', 'atomCount'):
    # Parse the literal and join its elements with "v" (valency) or
    # "c" (atomCount).
    separator = 'v' if args.constraint == 'valency' else 'c'
    results_df["params"] = [
        separator.join(map(str, ast.literal_eval(param)))
        for param in results_df["params"]
    ]
elif args.constraint == 'molWeight':
    # only store the max mol weight (the number following the first "],")
    results_df["params"] = [
        float(param.split("],")[1]) for param in results_df["params"]
    ]

# Write the same table as CSV and as a LaTeX tabular.
results_df.to_csv(f"{out_dir}/all_results.csv", index=False)
with open(f"{out_dir}/all_results.tex", "w") as wf:
    wf.write(results_df.to_latex(index=False))
