
import os
from pathlib import Path
import json
import pandas
import glob
import pandas  as pd
import matplotlib.pyplot as plt
import numpy as np
import ast
from collections import defaultdict
import sys
from datetime import datetime

# Transforms strings to python objects. e.g. "True" into the value True
def string_to_python(string):
    """Convert the textual form of a Python literal into its value.

    Args:
        string: Text such as ``"True"``, ``"2"`` or ``"0.5"``.

    Returns:
        The value produced by ``ast.literal_eval`` when ``string`` is a valid
        Python literal; otherwise ``string`` itself, unchanged (e.g.
        ``"skipsum"`` stays a string).
    """
    try:
        return ast.literal_eval(string)
    # Only the failures literal_eval is documented to raise on malformed
    # input are treated as "not a literal"; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit.
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return string

# Reads experiment name (e.g. "meta-l_mp=2-stage=skipsum-agg=add-us=20-action=edge-meta_nl=2-rw=loss_clf_pred_penalty-ps=0.0-ps=False-mel=1-dataset=TU_PROTEINS")
# Returns a dict: {'l_mp': 2, 'stage': 'skipsum', ...}
def name_to_dict(run):
    """Parse the ``key=value`` tokens of an experiment name into a dict.

    The name is split on ``-`` and the first token (the experiment prefix)
    is dropped.  Each remaining ``key=value`` token contributes one entry,
    with the value converted by ``string_to_python``.  When a key occurs more
    than once, the last value wins.

    Args:
        run: Experiment (folder) name.

    Returns:
        Dict mapping each key to its converted value.
    """
    cols = run.split('-')[1:]
    parsed = {}

    i = 0
    n = len(cols)
    while i < n:
        col = cols[i]
        if '=' in col:
            # Split on the first '=' only, so a value that itself contains
            # '=' no longer raises on unpacking.
            key, _, val = col.partition('=')
            # An empty value followed by another token is a negative value
            # whose leading '-' was eaten by the split on '-' ("ps=-1").
            # A trailing "key=" with nothing after it keeps the empty value.
            if val == '' and i + 1 < n:
                val = '-' + cols[i + 1]
                if '=' not in cols[i + 1]:
                    # That token was the value itself: consume it so it is
                    # not reported as skipped on the next iteration.
                    i += 1
            parsed[key] = string_to_python(val)
        else:
            print('Skip')
        i += 1

    return parsed

# Given a list of dicts and kwargs, keep only the dicts whose entries match every passed key-value pair
# e.g. filter_by(data, dataset="TU_BZR", stage='skipsum')
def filter_by(data, **kwargs):
    """Select the dicts in ``data`` that match all given key-value pairs.

    ``dataset='all'`` acts as a wildcard and applies no filtering.  When no
    filter is applied at all, ``data`` itself is returned; otherwise a new
    list is returned.  A dict that lacks a filtered key raises ``KeyError``.
    """
    remaining = data
    for field, wanted in kwargs.items():
        is_wildcard = field == 'dataset' and wanted == 'all'
        if not is_wildcard:
            remaining = [entry for entry in remaining if entry[field] == wanted]

    return remaining

# agg across grid search
def agg_batch(dir, metric_name='accuracy', maximize_metric=True,
              policy='best',
              num_splits=3):
    """Collect one result row per (config, seed) of a grid search and pickle them.

    Files are looked up at ``<dir>/<config>/<seed>/<split>/stats.json`` and
    read as JSON lines.  For every config folder and seed, one epoch is
    selected on the validation split and the matching row of each split is
    merged into a single record made of: the values parsed from the config
    folder name (via ``name_to_dict``), ``seed``, ``best_epoch``, every column
    of the selected row suffixed with ``_<split>``, and ``config_folder``.

    Args:
        dir: Grid-search folder. (The name shadows the builtin; it is kept
            for backward compatibility with keyword callers.)
        metric_name: Column of the validation stats used to select the epoch.
        maximize_metric: True if larger metric values are better.
        policy: ``'early5'`` / ``'early10'`` select the epoch with the
            patience-window rule below; any other value selects the best
            epoch on the validation split.  With ``'last'`` the last row of
            every split is reported instead of the row at the selected epoch
            (``best_epoch`` is still recorded).
        num_splits: How many of ``['train', 'val', 'test']`` (in that order)
            are read.  Selection uses ``'val'`` only when this is 3,
            otherwise ``'train'``.

    Returns:
        None.  The records are written as a pickled DataFrame inside ``dir``:
        ``results_<metric_name>.pkl`` for policy ``'best'``, otherwise
        ``results_<policy>_<metric_name>.pkl``.  If ``dir`` is not a
        directory nothing is written.
    """
    print("Starting!", datetime.now().strftime("%H:%M:%S"))

    splits_all = ['train', "val", 'test'] # only putting "val" still adds test_acc to the dicts. Only add "test" here if you really need something from the test results.
    splits = splits_all[:num_splits]

    if num_splits == 3:
        split_val = 'val'
    else:
        split_val = 'train'

    print()

    experiment_path = dir
    print(f"Parsing folder {experiment_path}...")
    print(f'Metric name: {metric_name} Maximize: {maximize_metric}')
    print(f'Splits: {splits} Split validation: {split_val}')

    if not Path(experiment_path).is_dir():
        print("Not a folder -> skipping")
        print()
        return

    configs = os.listdir(experiment_path)
    print(f"Found {len(configs)} configs...")

    # One dict per (config, seed) holding both the config values and the results.
    records = []

    for idx, cfg in enumerate(configs):
        if idx % 100 == 0: print(f"Parsing {idx}/{len(configs)}")
        if cfg == "agg": # skip aggregation folder
            continue

        cfg_fullpath = os.path.join(experiment_path, cfg)

        if not os.path.isdir(cfg_fullpath):
            continue

        for seed_name in os.listdir(cfg_fullpath):
            if seed_name == "agg": # skip aggregation folder
                continue
            try:
                s = int(seed_name)
            except ValueError:
                # Stray non-numeric entries are not seed folders; skip them
                # instead of aborting the whole aggregation.
                continue

            # Build the paths from the folder's real name so that a seed
            # folder such as "01" is still found.
            filepath_dict = {}
            all_files_found = True
            for split in splits:
                filepath_dict[split] = os.path.join(cfg_fullpath, seed_name, split, 'stats.json')

                if not os.path.exists(filepath_dict[split]):
                    print(f"[ERROR] Filepath not found: {filepath_dict[split]}")
                    all_files_found = False

            if not all_files_found:
                continue

            config_dict = name_to_dict(cfg)
            config_dict['seed'] = s

            df_dict = {}
            for split, filepath in filepath_dict.items():
                df_dict[split] = pd.read_json(filepath, lines=True)

            df_val = df_dict[split_val]
            if policy in ['early5', 'early10']:
                patience = int(policy.replace('early', ''))
                metric_epochs = df_val[metric_name]
                # metric_diff[i] = metric[i + 1] - metric[i]; the candidates
                # are the rows right before the metric gets worse.
                metric_diff = metric_epochs.iloc[1:].values - metric_epochs.iloc[:-1].values
                best_epoch_idx_list = np.where(metric_diff < 0)[0] if maximize_metric else np.where(metric_diff > 0)[0]
                # Take the first candidate that is still the best row within
                # the `patience` rows starting at it.  If none qualifies the
                # loop ends on the last candidate, which is then used.
                for best_idx in best_epoch_idx_list:
                    best_next = metric_epochs.iloc[best_idx:(best_idx + patience)]
                    fn = best_next.argmax if maximize_metric else best_next.argmin
                    if fn() == 0: break

                # The metric never got worse: fall back to the first row.
                if len(best_epoch_idx_list) == 0: best_idx = 0
                best_epoch = df_val.iloc[best_idx]['epoch']

            else:
                lower_to_larger = df_val[metric_name].argsort().values
                if maximize_metric:
                    best_to_worse_idx = lower_to_larger[::-1]
                else:
                    best_to_worse_idx = lower_to_larger

                # Best-ranked row whose epoch is > 0; if every epoch is <= 0
                # the worst-ranked row ends up selected.
                for best_idx in best_to_worse_idx:
                    best_epoch = df_val.iloc[best_idx]['epoch']
                    if best_epoch > 0: break

            config_dict.update({"best_epoch": best_epoch})

            for split in splits:
                df_i = df_dict[split]
                if policy == 'last':
                    split_i_dict = df_i.iloc[-1].to_dict()
                    split_i_dict = {f"{key}_{split}": value for key, value in split_i_dict.items()}
                    config_dict.update(split_i_dict)
                else:
                    cond = df_i['epoch'] == best_epoch
                    if cond.sum() == 1:
                        split_i_dict = df_i.loc[cond].iloc[0].to_dict()
                        split_i_dict = {f"{key}_{split}": value for key, value in split_i_dict.items()}
                        config_dict.update(split_i_dict)
                    else:
                        # Also reached when no row matches (count of 0).
                        print(f"[ERROR] Too many best epochs [{best_epoch}]: {cond.sum()}")
                        # Reuse the flag to drop this (config, seed) below.
                        all_files_found = False

            if not all_files_found:
                continue

            config_dict['config_folder'] = cfg_fullpath
            records.append(config_dict)

    print()
    print("Done!", datetime.now().strftime("%H:%M:%S"))
    df = pd.DataFrame(records)
    if policy != 'best':
        df.to_pickle(os.path.join(dir, f'results_{policy}_{metric_name}.pkl'))
    else:
        df.to_pickle(os.path.join(dir, f'results_{metric_name}.pkl'))


## Example usage
# dir = 'results/nc_example'
# agg_batch(dir)



