import pandas as pd
import sys
import argparse
from parsers.config import get_constraint_config
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Selects which summary is produced: 'poly-fixed' draws one combined grid of the
# poly and fixed schedules; any other value runs the per-params-group,
# per-schedule plots in the else branch further down.
sch_specific = 'poly-fixed'

##
## Argument Parser
##
parser = argparse.ArgumentParser()
for _flag in ("--dataset", "--constraint"):
    parser.add_argument(_flag, type=str)

args = parser.parse_args(sys.argv[1:])

# For these constraints the config name carries a "_<dataset>" suffix;
# every other constraint is looked up by its bare name.
_dataset_suffixed_constraints = ("valency", "atomCount", "molWeight", "regression")
constraint_config = (
    f"{args.constraint}_{args.dataset}"
    if args.constraint in _dataset_suffixed_constraints
    else f"{args.constraint}"
)

constr_config = get_constraint_config(constraint_config)

log_dir = f"logs_sample/{args.dataset}/test/{constr_config.constraint}/"

# Aggregated per-run results for the selected dataset / constraint.
results_df = pd.read_csv(f"all_results/{args.dataset}/{args.constraint}/all_results.csv")

# res_none = results_df[results_df["method_op"] == "None"]
# Keep only the "proj" runs. .copy() makes this an independent frame so the
# column assignment below does not write through a filtered slice
# (SettingWithCopyWarning / chained assignment).
res_proj = results_df[results_df["method_op"] == "proj"].copy()

# Min-max scale FCD into [0, 1] (presumably so it is comparable with the other
# metrics on the shared y-axes of the grids below). When all values are equal
# the range is 0; emit 0 for those rows instead of 0/0 = NaN, keeping NaN inputs NaN.
_fcd = res_proj["fcd_test"]
_fcd_shifted = _fcd - _fcd.min()
_fcd_range = _fcd.max() - _fcd.min()
res_proj["fcd_test"] = _fcd_shifted / _fcd_range if _fcd_range != 0 else _fcd_shifted

# if args.constraint == "valency":
#     res_dfg = res_proj.groupby("valency")
# elif args.constraint == "atomCount":
#     res_dfg = res_proj.groupby("atomCount")
# Group the "proj" runs by their "params" value; consumed by the per-schedule plots.
res_dfg = res_proj.groupby("params")

# mmd_vars = ["degree", "cluster", "orbit", "spectral"]
# Metric columns that are melted into long format ("variable" / "value").
mol_vars = ["validity_wo_corr", "valid", "unique", "novelty", "nspdk_mmd", "fcd_test"]

# Only request the pre-rounding constraint column when the results file has it:
# DataFrame.melt raises KeyError for value_vars that are not present.
if "constraint_val_preround" in results_df.columns:
    print("preround")
    value_vars = mol_vars + ["constraint_val", "constraint_val_preround"]
else:
    value_vars = mol_vars + ["constraint_val"]

# Every remaining column identifies a run and is carried through melt() unchanged.
id_vars = [col for col in results_df.columns if col not in value_vars]

# Atom symbols per dataset, in the order their fields appear in the "params" string.
_PARAM_ATOMS = {
    "qm9": ["C", "N", "O", "F"],
    "zinc250k": ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I"],
}
# Field separator inside "params" for each constraint whose labels are rewritten.
_PARAM_SEPARATORS = {"valency": "v", "atomCount": "c"}
# Every figure below is written into this results directory.
_out_dir = f"all_results/{args.dataset}/{args.constraint}"


def _atom_labels(params, atoms, sep):
    """Return readable labels for ``params`` by prefixing each field with its atom.

    Each string is split on ``sep`` and the i-th field is prefixed with
    ``atoms[i]``; e.g. atoms ``["C", "N"]`` and sep ``"v"`` turn ``"4v3"``
    into ``"C4N3"``. Fields beyond ``len(atoms)`` are dropped by ``zip``.
    """
    return ["".join(atom + field for atom, field in zip(atoms, p.split(sep))) for p in params]


def _finish_grid(grid, out_path, xtick_rotation=None):
    """Add legend and facet titles, lay out, save ``grid`` to ``out_path``, then close it.

    Closing matters because the per-schedule loops below can create many grids,
    and pyplot keeps every unclosed figure alive.
    """
    grid.add_legend()
    grid.set_titles(col_template="{col_name}", row_template="{row_name}")
    grid.tight_layout()
    if xtick_rotation is not None:
        grid.set_xticklabels(rotation=xtick_rotation)
    grid.savefig(out_path)
    plt.close(grid.fig)


if sch_specific == 'poly-fixed':
    # Poly and fixed schedules with no additional diffusion steps. .copy() so the
    # .loc assignment below acts on an independent frame, not a slice of res_proj.
    res_schs = res_proj.loc[
        res_proj["schedule"].isin(['fixed', 'poly']) & (res_proj["add_diff_steps"] == 0)
    ].copy()
    # Relabel fixed-schedule params: zero gamma -> 'gamma0', anything else -> 'gamma1'.
    # float() instead of int(): int("0.0") raises and int(0.5) would truncate to 0.
    is_fixed = res_schs["schedule"] == "fixed"
    res_schs.loc[is_fixed, "schedule_params"] = [
        'gamma0' if float(x) == 0 else 'gamma1' for x in res_schs.loc[is_fixed, "schedule_params"]
    ]
    keep_params = ['0|1', '0|3', '0|5', '0.1|1', '0.1|3', '0.1|5',
                   '0.0001|1', '0.0001|3', '0.0001|5', 'gamma0', 'gamma1']
    res_schs = res_schs.loc[res_schs["schedule_params"].isin(keep_params)]
    res_schs = res_schs.melt(id_vars=id_vars, value_vars=value_vars)
    # Turn "params" into atom-prefixed labels for the datasets/constraints that have them.
    param_atoms = _PARAM_ATOMS.get(args.dataset)
    param_sep = _PARAM_SEPARATORS.get(args.constraint)
    if param_atoms is not None and param_sep is not None:
        res_schs["params"] = _atom_labels(res_schs["params"], param_atoms, param_sep)
    grid = sns.FacetGrid(res_schs, col="variable", row="add_diff_steps", hue="schedule_params", margin_titles=True)
    # res_schs["poly_init"] = [sch_params.split("|")[0] if (sch == 'poly') else '0' 
    #                             for sch, sch_params in zip(res_schs["schedule"], res_schs["schedule_params"])]
    # res_schs["poly_pow"] = [sch_params.split("|")[1] if (sch == 'poly') else ('inf' if int(sch_params) == 0 else '0') 
    #                             for sch, sch_params in zip(res_schs["schedule"], res_schs["schedule_params"])]
    # res_schs = res_schs.melt (id_vars=id_vars + ["poly_init", "poly_pow"], value_vars=value_vars)
    # grid = sns.FacetGrid(res_schs, col="variable", row="poly_init", hue="poly_pow", margin_titles=True)
    grid.map_dataframe(sns.lineplot, x="params", y="value", marker='o', markersize=10)
    _finish_grid(grid, f"{_out_dir}/summarized_results_new.pdf", xtick_rotation=90)
else:
    for b, res_df_b in res_dfg:
        # res_none_b = res_none.loc[res_none["budget"] == b].loc[0]
        for sch in np.unique(res_df_b["schedule"]):
            if sch == 'poly':
                res_df_b_poly = res_df_b.loc[res_df_b["schedule"] == sch]
                # Only roundings that actually occur with this schedule: one seen only
                # under other schedules would give an empty frame and an unusable grid.
                for rounding in np.unique(res_df_b_poly["rounding"]):
                    res_df_b_sch = res_df_b_poly.loc[res_df_b_poly["rounding"] == rounding].copy()
                    # Poly params are "<init>|<power>".
                    res_df_b_sch["poly_init"] = [float(x.split("|")[0]) for x in res_df_b_sch["schedule_params"]]
                    res_df_b_sch["poly_pow"] = [int(x.split("|")[1]) for x in res_df_b_sch["schedule_params"]]
                    res_df_b_sch = res_df_b_sch.melt(id_vars=id_vars + ["poly_init", "poly_pow"], value_vars=value_vars)
                    grid = sns.FacetGrid(res_df_b_sch, row="add_diff_steps", col="poly_init",
                                         hue="variable", margin_titles=True)
                    grid.map_dataframe(sns.lineplot, x="poly_pow", y="value", marker='o', markersize=10)
                    _finish_grid(grid, f"{_out_dir}/new_summ_{b}_{sch}_{rounding}.pdf")
            elif sch == 'batched' or sch == 'cyclical':
                res_df_b_sch = res_df_b.loc[res_df_b["schedule"] == sch].copy()
                # First "|"-separated field of the params is the cycle size.
                res_df_b_sch["cycle_size"] = [int(x.split("|")[0]) for x in res_df_b_sch["schedule_params"]]
                res_df_b_sch = res_df_b_sch.melt(id_vars=id_vars + ["cycle_size"], value_vars=value_vars)
                grid = sns.FacetGrid(res_df_b_sch, row="add_diff_steps", col="cycle_size",
                                     hue="variable", margin_titles=True)
                grid.map_dataframe(sns.lineplot, x="method_gamma", y="value", marker='o', markersize=10)
                _finish_grid(grid, f"{_out_dir}/new_summ_{b}_{sch}.pdf")
            elif sch == 'polymid':
                res_df_b_sch = res_df_b.loc[res_df_b["schedule"] == sch]
                for rounding in np.unique(res_df_b_sch["rounding"]):
                    res_df_b_sch_rnd = res_df_b_sch.loc[res_df_b_sch["rounding"] == rounding].copy()
                    # Polymid params are "<init>|<power>|<onestep>".
                    res_df_b_sch_rnd["polymid_init"] = [float(x.split("|")[0]) for x in res_df_b_sch_rnd["schedule_params"]]
                    res_df_b_sch_rnd["polymid_pow"] = [int(x.split("|")[1]) for x in res_df_b_sch_rnd["schedule_params"]]
                    res_df_b_sch_rnd["polymid_onestep"] = [int(x.split("|")[2]) for x in res_df_b_sch_rnd["schedule_params"]]
                    res_df_b_sch_rnd = res_df_b_sch_rnd.melt(
                        id_vars=id_vars + ["polymid_init", "polymid_pow", "polymid_onestep"],
                        value_vars=value_vars,
                    )
                    grid = sns.FacetGrid(res_df_b_sch_rnd, row="polymid_onestep", col="polymid_init",
                                         hue="variable", margin_titles=True)
                    grid.map_dataframe(sns.lineplot, x="polymid_pow", y="value", marker='o', markersize=10)
                    # NOTE: file-name field order differs from the poly branch; kept for compatibility.
                    _finish_grid(grid, f"{_out_dir}/new_summ_{b}_{rounding}_{sch}.pdf")
            elif sch == 'fixed':
                res_df_b_sch = res_df_b.loc[res_df_b["schedule"] == sch].copy()
                # Fixed-schedule params are used directly as the gamma value.
                res_df_b_sch["gamma"] = res_df_b_sch["schedule_params"]
                res_df_b_sch = res_df_b_sch.melt(id_vars=id_vars + ["gamma"], value_vars=value_vars)
                grid = sns.FacetGrid(res_df_b_sch, row="add_diff_steps", hue="variable", margin_titles=True)
                grid.map_dataframe(sns.lineplot, x="gamma", y="value", marker='o', markersize=10)
                _finish_grid(grid, f"{_out_dir}/new_summ_{b}_{sch}.pdf")