import pandas as pd
import seaborn as sns
import os

# Metric columns plotted for generic graph results tables.
mmd_vars = ["degree", "cluster", "orbit", "spectral"]
# Metric columns plotted instead when a results table has a 'valid' column
# (the molecule results tables).
mol_qual_vars = ["valid", "unique", "fcd", "novelty", "nspdk"]

def plot_dataCons_wise(df):
    """Save one facet plot per (dataset, constraint) pair.

    Each group's metric columns (plus ``constr_val``) are melted to long form
    and drawn as a seaborn FacetGrid:
    - one column per metric and one row per ``method``;
    - lines coloured by ``method_param``;
    - ``value`` plotted against ``param``.

    The figure is saved to
    ``all_results/<dataset>/<constraint>/all_newResults_legend1.pdf``.

    Args:
        df: results table with ``dataset``, ``constraint``, ``method``,
            ``method_param``, ``param`` and ``constr_val`` columns plus the
            metrics in ``mmd_vars`` (or ``mol_qual_vars`` when a ``valid``
            column is present). It is not modified.
    """
    # Molecule result tables are recognised by their 'valid' column.
    metric_vars = mol_qual_vars if 'valid' in df.columns else mmd_vars
    value_vars = metric_vars + ['constr_val']
    for (dataset, constraint), df_dc in df.groupby(['dataset', 'constraint']):
        # Copy so the FCD rescaling below never writes into a view of ``df``.
        df_dc = df_dc.copy()
        if 'fcd' in df_dc.columns:
            # Min-max rescale FCD to [0, 1] within this group. Series.min/max
            # skip NaN (builtin min/max do not). A constant column maps to 0
            # instead of producing 0/0 = NaN.
            fcd = df_dc['fcd']
            lo, hi = fcd.min(), fcd.max()
            span = hi - lo
            df_dc['fcd'] = (fcd - lo) / span if span else fcd - lo
        id_vars = [col for col in df_dc.columns if col not in value_vars]
        df_long = df_dc.melt(id_vars=id_vars, value_vars=value_vars)
        grid = sns.FacetGrid(df_long, col="variable", row="method", hue="method_param", margin_titles=True)
        grid.map_dataframe(sns.lineplot, x="param", y="value", marker='o', markersize=10)
        grid.add_legend()
        grid.set_titles(col_template="{col_name}", row_template="{row_name}")
        grid.tight_layout()
        if constraint in ['valency', 'atomCount']:
            # presumably these constraints have long tick labels; rotate them
            grid.set_xticklabels(rotation=90)
        print(dataset, constraint)
        out_dir = f"all_results/{dataset}/{constraint}"
        # exist_ok replaces a bare `except: pass` that also hid permission errors.
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): figures are never closed; with many groups matplotlib
        # will warn about open figures - consider closing grid.figure here.
        grid.savefig(f"{out_dir}/all_newResults_legend1.pdf")

def plot_dataCons_wise2(df):
    """Save one facet plot per (dataset, constraint) with methods merged into the hue.

    This is like ``plot_dataCons_wise``, except that the hue label is
    ``method + method_param`` and there is no per-method row facet.

    The figure is saved to
    ``all_results/<dataset>/<constraint>/all_newResults_legend2.pdf``.

    Args:
        df: results table with ``dataset``, ``constraint``, ``method``,
            ``method_param``, ``param`` and ``constr_val`` columns plus the
            metrics in ``mmd_vars`` (or ``mol_qual_vars`` when a ``valid``
            column is present). It is NOT modified. The original rewrote
            ``df["method_param"]`` in place, so any later call on the same
            frame prefixed the method name a second time.
    """
    # Build the combined label on a new frame instead of mutating the caller's.
    df = df.assign(method_param=[x + y for x, y in zip(df["method"], df["method_param"])])
    # Molecule result tables are recognised by their 'valid' column.
    metric_vars = mol_qual_vars if 'valid' in df.columns else mmd_vars
    value_vars = metric_vars + ['constr_val']
    for (dataset, constraint), df_dc in df.groupby(['dataset', 'constraint']):
        # Copy so the FCD rescaling below never writes into a view of ``df``.
        df_dc = df_dc.copy()
        if 'fcd' in df_dc.columns:
            # Min-max rescale FCD to [0, 1] within this group. It is NaN-safe,
            # and a constant column maps to 0 instead of NaN.
            fcd = df_dc['fcd']
            lo, hi = fcd.min(), fcd.max()
            span = hi - lo
            df_dc['fcd'] = (fcd - lo) / span if span else fcd - lo
        id_vars = [col for col in df_dc.columns if col not in value_vars]
        df_long = df_dc.melt(id_vars=id_vars, value_vars=value_vars)
        grid = sns.FacetGrid(df_long, col="variable", hue="method_param", margin_titles=True)
        grid.map_dataframe(sns.lineplot, x="param", y="value", marker='o', markersize=10)
        grid.add_legend()
        grid.set_titles(col_template="{col_name}", row_template="{row_name}")
        grid.tight_layout()
        if constraint in ['valency', 'atomCount']:
            # presumably these constraints have long tick labels; rotate them
            grid.set_xticklabels(rotation=90)
        print(dataset, constraint)
        out_dir = f"all_results/{dataset}/{constraint}"
        os.makedirs(out_dir, exist_ok=True)
        grid.savefig(f"{out_dir}/all_newResults_legend2.pdf")

def plot_cons_wise1(df):
    """Save one facet plot per constraint, with one row per dataset.

    Lines are coloured by the combined ``method + method_param`` label, and
    there is one column per metric. Unlike the per-(dataset, constraint)
    plots, FCD is not rescaled here.

    The figure is saved to
    ``all_results/<constraint>/all_newResults_legend2.pdf``.

    Args:
        df: results table with ``dataset``, ``constraint``, ``method``,
            ``method_param``, ``param`` and ``constr_val`` columns plus the
            metrics in ``mmd_vars`` (or ``mol_qual_vars`` when a ``valid``
            column is present). It is NOT modified. The original rewrote
            ``df["method_param"]`` in place.
    """
    # Build the combined label on a new frame instead of mutating the caller's.
    df = df.assign(method_param=[x + y for x, y in zip(df["method"], df["method_param"])])
    # Molecule result tables are recognised by their 'valid' column.
    metric_vars = mol_qual_vars if 'valid' in df.columns else mmd_vars
    value_vars = metric_vars + ['constr_val']
    for constraint, df_dc in df.groupby('constraint'):
        id_vars = [col for col in df_dc.columns if col not in value_vars]
        df_long = df_dc.melt(id_vars=id_vars, value_vars=value_vars)
        grid = sns.FacetGrid(df_long, col="variable", row="dataset", hue="method_param", margin_titles=True)
        grid.map_dataframe(sns.lineplot, x="param", y="value", marker='o', markersize=10)
        grid.add_legend()
        grid.set_titles(col_template="{col_name}", row_template="{row_name}")
        grid.tight_layout()
        if constraint in ['valency', 'atomCount']:
            # presumably these constraints have long tick labels; rotate them
            grid.set_xticklabels(rotation=90)
        out_dir = f"all_results/{constraint}"
        # exist_ok replaces a bare `except: pass` that also hid permission errors.
        os.makedirs(out_dir, exist_ok=True)
        print(constraint)
        grid.savefig(f"{out_dir}/all_newResults_legend2.pdf")

# Draw every results table three ways, printing a separator after each:
#   1. per (dataset, constraint) with one row per method;
#   2. per (dataset, constraint) with methods merged into the hue;
#   3. per constraint with one row per dataset.
# Columns noted for these CSVs: dataset,constraint,param,method,method_param,constr_val
for results_csv in ("all_results/all_results_bisect3.csv",
                    "all_results/all_mol_results_bisect3.csv"):
    df = pd.read_csv(results_csv)
    for plot_fn in (plot_dataCons_wise, plot_dataCons_wise2, plot_cons_wise1):
        plot_fn(df)
        print("-------")

# Disabled: older molecule results file (path kept verbatim).
# df_mol = pd.read_csv("all_results/all_esults_mol.csv")
# plot_dataCons_wise (df_mol)
# plot_dataCons_wise2 (df_mol)
# plot_cons_wise1 (df_mol)
