# -*- coding: utf-8 -*-
import argparse
import glob
import os
import re

import pandas as pd

def save_list(data, save_path, _type='excel'):
    """Build a DataFrame from ``data`` and write it to ``save_path``.

    Columns that pandas stored with ``object`` dtype (e.g. numbers supplied
    as strings) are cast to ``float`` first; a non-numeric value in such a
    column raises ``ValueError``. Only ``_type == 'excel'`` writes a file
    (``DataFrame.to_excel`` without the index); any other ``_type`` writes
    nothing.
    """
    frame = pd.DataFrame(data=data)
    # Collect first, then convert: casting one column never changes another's dtype.
    object_cols = [name for name in frame.columns if frame[name].dtype == object]
    for name in object_cols:
        frame[name] = frame[name].astype(float)
    if _type != 'excel':
        return
    frame.to_excel(save_path, index=False)


def merge_excels(source_dir):
    """Merge per-run ``metric_scores.xlsx`` files into one sheet under ``source_dir``.

    Every immediate sub-directory of ``source_dir`` is treated as one run.
    Its directory name is recorded in a ``hyper-parameters`` column.

    A run without ``metric_scores.xlsx`` is first converted from its log with
    ``log_to_excel``. Only columns present in every readable sheet are kept,
    sorted by name. The merged table is written to
    ``<source_dir>/metric_scores.xlsx``.

    Runs that cannot be converted or read are reported on stdout and skipped.
    If no run remains, a message is printed and nothing is written.
    """
    runs = []  # (run directory name, DataFrame) for every readable sheet
    common_cols = None
    # glob.escape keeps a user-supplied directory containing [, * or ? literal;
    # sorting makes the merged row order deterministic.
    for run_dir in sorted(glob.glob(os.path.join(glob.escape(source_dir), '*'))):
        if not os.path.isdir(run_dir):
            continue
        excel = os.path.join(run_dir, 'metric_scores.xlsx')
        if not os.path.exists(excel):
            print(run_dir)
            try:
                log_to_excel(run_dir)
            except (OSError, ValueError, IndexError) as err:
                # One broken run (missing/malformed log) must not abort the merge.
                print("error log:", run_dir, err)
                continue
        try:
            df = pd.read_excel(excel, engine='openpyxl')
        except Exception:  # missing (e.g. nothing was logged) or unreadable workbook
            print("error excel:", excel)
            continue
        cols = set(df.columns)
        common_cols = cols if common_cols is None else common_cols & cols
        runs.append((os.path.basename(run_dir), df))

    if not runs:
        print("no metric_scores.xlsx found under:", source_dir)
        return

    keep = sorted(common_cols)
    # assign() returns a new frame, avoiding assignment into a column slice.
    merged = pd.concat(
        [df[keep].assign(**{'hyper-parameters': hp_str}) for hp_str, df in runs]
    )
    merged.to_excel(os.path.join(source_dir, 'metric_scores.xlsx'), index=False)

def log_to_excel(log_path):
    """Convert the metric lines of ``<log_path>/stdout.log`` into ``metric_scores.xlsx``.

    A metric line contains ``epoch: E/T, metric NAME: VALUE``, where ``T``
    may be any integer. Each such line stores ``float(VALUE)`` under column
    ``NAME`` in row ``E``, together with a ``train_epoch`` column holding
    ``E``. Lines without that pattern are ignored.

    The sheet has one row per ``*.pkl`` file in ``log_path``, extended if the
    log mentions a later epoch. Rows for epochs that logged nothing are
    empty. The result is written with ``save_list``. Nothing is written when
    ``log_path`` holds no ``*.pkl`` files.

    Raises:
        FileNotFoundError: if ``stdout.log`` is missing.
        ValueError: if a metric value cannot be parsed as a float.
    """
    # Non-greedy NAME stops at the first ": " after "metric ", as before;
    # the total-epoch count after "/" is no longer hard-coded to 10.
    metric_line = re.compile(r"epoch: \s*(\d+)\s*/\d+, metric (.*?): (.*)$")
    with open(os.path.join(log_path, 'stdout.log'), "r",
              encoding="utf-8", errors="replace") as f:
        lines = f.readlines()

    epochs = len(glob.glob(os.path.join(glob.escape(log_path), "*.pkl")))
    if epochs == 0:
        return

    rows = {}  # epoch index -> {"train_epoch": epoch, metric name: value, ...}
    for line in lines:
        match = metric_line.search(line)
        if match is None:
            continue
        epoch = int(match.group(1))
        row = rows.setdefault(epoch, {})
        row["train_epoch"] = epoch
        row[match.group(2)] = float(match.group(3))

    # Row position == epoch index; grow past the checkpoint count instead of
    # raising IndexError when the log reports a later epoch.
    n_rows = max([epochs] + [epoch + 1 for epoch in rows])
    all_metric_scores = [rows.get(epoch, {}) for epoch in range(n_rows)]
    save_list(all_metric_scores, os.path.join(log_path, 'metric_scores.xlsx'))

def main(argv=None):
    """Command-line entry point: merge the metric sheets of each ``--output_path``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Unrecognised arguments are ignored (``parse_known_args``) rather than rejected.
    """
    parser = argparse.ArgumentParser(description='merge_excels')
    # required=True: with the old default of None the loop below raised TypeError.
    parser.add_argument('--output_path', required=True, nargs='+',
                        help="The log dir")
    args, _unknown = parser.parse_known_args(argv)
    for output_path in args.output_path:
        merge_excels(output_path)


if __name__ == '__main__':
    main()