import pandas as pd
import sys
import argparse
from parsers.config import get_constraint_config
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Plot-mode switch read by the plotting section below: 'poly-fixed' produces a
# single combined figure, any other value produces one figure per
# (params group, schedule) pair.
sch_specific = 'poly-fixed'

# Constraints whose config name carries the dataset as a suffix
# ("<constraint>_<dataset>"); every other constraint is looked up by its bare
# name.
_DATASET_SUFFIXED_CONSTRAINTS = ("valency", "atomCount", "molWeight", "regression")

##
## Argument Parser
##
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str)
parser.add_argument("--constraint", type=str)
args = parser.parse_args(sys.argv[1:])

constraint_config = (
    f"{args.constraint}_{args.dataset}"
    if args.constraint in _DATASET_SUFFIXED_CONSTRAINTS
    else f"{args.constraint}"
)
constr_config = get_constraint_config(constraint_config)

# Log directory derived from the resolved config; it is not read again in the
# visible part of this script.
log_dir = f"logs_sample/{args.dataset}/test/{constr_config.constraint}/"

# Aggregated results table for the requested dataset/constraint pair.
results_df = pd.read_csv(
    f"all_results/{args.dataset}/{args.constraint}/all_results.csv"
)

# Keep only the rows whose "method_op" is "proj"; nothing else is plotted.
res_proj = results_df.loc[results_df["method_op"] == "proj"]

# One group per distinct value of the "params" column (used by the
# per-schedule plotting mode).
res_dfg = res_proj.groupby("params")

mmd_vars = ["degree", "cluster", "orbit", "spectral"]

# Metric columns that are melted into long format before plotting.  The
# pre-rounding constraint value is optional: request it only when the results
# file actually contains it, otherwise DataFrame.melt raises a KeyError.
# (The previous try/except around a print() could never take the fallback.)
value_vars = mmd_vars + ["constraint_val"]
if "constraint_val_preround" in results_df.columns:
    print ("preround")
    value_vars.append("constraint_val_preround")

# Everything that is not a metric is carried along as an identifier column.
id_vars = [col for col in results_df.columns if col not in value_vars]


def _save_line_grid(long_df, out_path, x, rotate_xticks=False, **facet_kwargs):
    """Draw a seaborn FacetGrid of line plots and write it to ``out_path``.

    Parameters
    ----------
    long_df : DataFrame in long (melted) form; its "value" column is plotted
        on the y axis.
    out_path : destination file passed to ``FacetGrid.savefig``.
    x : name of the column plotted on the x axis.
    rotate_xticks : when True, x tick labels are rotated by 90 degrees.
    **facet_kwargs : forwarded to ``sns.FacetGrid`` (``row``, ``col``, ``hue``).
    """
    grid = sns.FacetGrid(long_df, margin_titles=True, **facet_kwargs)
    grid.map_dataframe (sns.lineplot, x=x, y="value", marker='o', markersize=10)
    grid.add_legend()
    grid.set_titles(col_template="{col_name}", row_template="{row_name}")
    grid.tight_layout()
    if rotate_xticks:
        grid.set_xticklabels(rotation=90)
    grid.savefig(out_path)
    # Release the figure: the per-schedule mode creates one grid per
    # (group, schedule) pair and matplotlib otherwise keeps them all alive.
    plt.close(grid.figure)


# Per-schedule plotting recipe.
#   derived: new column -> (index into the "|"-separated schedule_params,
#            cast), or None to copy schedule_params unchanged.
#   facets : row/col arguments for the FacetGrid.
#   x      : column plotted on the x axis.
_CYCLE_SPEC = {
    "derived": {"cycle_size": (0, int)},
    "facets": {"row": "add_diff_steps", "col": "cycle_size"},
    "x": "method_gamma",
}
_PLOT_SPECS = {
    'poly': {
        "derived": {"poly_init": (0, float), "poly_pow": (1, int)},
        "facets": {"row": "add_diff_steps", "col": "poly_init"},
        "x": "poly_pow",
    },
    'polymid': {
        "derived": {
            "polymid_init": (0, float),
            "polymid_pow": (1, int),
            "polymid_onestep": (2, int),
        },
        "facets": {"row": "polymid_onestep", "col": "polymid_init"},
        "x": "polymid_pow",
    },
    'batched': _CYCLE_SPEC,
    'cyclical': _CYCLE_SPEC,
    'none': {
        "derived": {},
        "facets": {"row": "add_diff_steps"},
        "x": "method_gamma",
    },
    'fixed': {
        "derived": {"gamma": None},
        "facets": {"row": "add_diff_steps"},
        "x": "gamma",
    },
}

out_dir = f"all_results/{args.dataset}/{args.constraint}"

if sch_specific == 'poly-fixed':
    # Combined view: 'fixed' and 'poly' schedules in one figure, one line per
    # schedule_params setting.  Work on an explicit copy so the column
    # rewrite below does not act on a slice of res_proj.
    res_schs = res_proj.loc[res_proj["schedule"].isin(['fixed', 'poly'])].copy()
    # Relabel the 'fixed' settings as 'gamma0' (param == 0) / 'gamma1'
    # (anything else); 'poly' settings keep their original value.  The whole
    # column is rebuilt so the result does not depend on its original dtype.
    res_schs["schedule_params"] = [
        ('gamma0' if int(p) == 0 else 'gamma1') if sch == "fixed" else p
        for sch, p in zip(res_schs["schedule"], res_schs["schedule_params"])
    ]
    res_schs = res_schs.loc[res_schs["schedule_params"].isin(
        ['0|1', '0|3', '0|5', '0.1|1', '0.1|3', '0.1|5',
         '0.0001|1', '0.0001|3', '0.0001|5', 'gamma0', 'gamma1'])]
    res_schs = res_schs.melt (id_vars=id_vars, value_vars=value_vars)
    _save_line_grid(
        res_schs,
        f"{out_dir}/summarized_results_new.pdf",
        x="params",
        rotate_xticks=(args.constraint == 'cheeger'),
        col="variable",
        row="add_diff_steps",
        hue="schedule_params",
    )
else:
    # Per-schedule view: one figure for every (params group, schedule) pair.
    for b, res_df_b in res_dfg:
        for sch in res_df_b["schedule"].unique():
            spec = _PLOT_SPECS.get(sch)
            if spec is None:
                # Schedules without a recipe are skipped, as before.
                continue
            # Explicit copy: derived columns are added below.
            res_df_b_sch = res_df_b.loc[res_df_b["schedule"] == sch].copy()
            for new_col, parse in spec["derived"].items():
                raw_params = res_df_b_sch["schedule_params"]
                if parse is None:
                    res_df_b_sch[new_col] = raw_params
                else:
                    pos, cast = parse
                    res_df_b_sch[new_col] = [cast(p.split("|")[pos]) for p in raw_params]
            long_df = res_df_b_sch.melt (
                id_vars=id_vars + list(spec["derived"]),
                value_vars=value_vars,
            )
            _save_line_grid(
                long_df,
                f"{out_dir}/summ_{b}_{sch}.pdf",
                x=spec["x"],
                hue="variable",
                **spec["facets"],
            )
            