"""
Goal
---
1. Read test results from log.txt files
2. Compute mean and std across different folders (seeds)

Usage
---
Assume the output files are saved under output/my_experiment,
which contains results of different seeds, e.g.,

my_experiment/
    seed1/
        log.txt
    seed2/
        log.txt
    seed3/
        log.txt

Run the following command from the root directory:

$ python tools/parse_test_res.py output/my_experiment

Add --ci95 to the arguments if you want the 95% confidence
interval instead of the standard deviation:

$ python tools/parse_test_res.py output/my_experiment --ci95

If my_experiment/ has the following structure,

my_experiment/
    exp-1/
        seed1/
            log.txt
            ...
        seed2/
            log.txt
            ...
        seed3/
            log.txt
            ...
    exp-2/
        ...
    exp-3/
        ...

Run

$ python tools/parse_test_res.py output/my_experiment --multi-exp
"""
import re
import numpy as np
import os.path as osp
import argparse
from collections import OrderedDict, defaultdict
import pandas as pd
from dassl.utils import check_isfile, listdir_nohidden
import os
import json

def compute_ci95(res):
    """Return the half-width of a normal-approximation 95% confidence interval.

    Computed as ``1.96 * std(res) / sqrt(len(res))`` where ``std`` is NumPy's
    default population standard deviation (``ddof=0``).
    """
    z_score = 1.96
    spread = np.std(res)
    root_n = np.sqrt(len(res))
    return z_score * spread / root_n

def results_to_csv(args, directory, key, results):
    """Route one summarized metric to the CSV writer for *directory*'s protocol.

    Directories whose path contains ``train_base`` or ``test_new`` go to the
    base-to-new writer; otherwise those containing ``xd_test`` or ``xd_train``
    go to the cross-dataset writer. Any other directory is silently ignored.
    """
    if any(tag in directory for tag in ("train_base", "test_new")):
        base2new_results_to_csv(args, directory, key, results)
        return
    if any(tag in directory for tag in ("xd_test", "xd_train")):
        xd_results_to_csv(args, directory, key, results)


def base2new_results_to_csv(args, directory, key, results):
    """Append one base-to-new summary row to ``output/base2new/logs_base2new.csv``.

    Row fields are read positionally from *directory*. The indices used below
    imply a layout like
    ``<p0>/<p1>/<split>/<dataset>/<prefix>_<shot>/<algorithm>/<cfgs>[/<reg_dir>]``,
    e.g. ``output/base2new/train_base/caltech101/shots_16/CoOp/vit_b16``.

    Args:
        args: parsed CLI namespace; only ``args.calibration_config`` (a JSON
            string, or a falsy value) is read. When set, the calibration method
            names are appended to the algorithm label, e.g. ``CoOp+<scaler>+DAC``.
        directory: the experiment directory that was summarized.
        key: metric name (stored in column ``metrics``).
        results: already-formatted metric value (stored in column ``results``).

    Raises:
        IndexError / ValueError: if *directory* does not follow the layout above.
        KeyError: if the calibration JSON lacks one of the expected keys.
    """
    # normpath drops a trailing "/" so it cannot turn the optional reg_dir slot
    # into an empty string.
    parts = osp.normpath(directory).split("/")
    split = parts[2]
    dataset = parts[3]
    shot = int(parts[4].split("_")[1])
    algorithm = parts[5]
    cfgs = parts[6]
    reg_dir = parts[7] if len(parts) > 7 else None

    if args.calibration_config:
        calibration_cfgs = json.loads(args.calibration_config)

        # Calibrator names are only part of the label in base-calibration mode;
        # DAC / ProCal tags are added independently of it.
        if calibration_cfgs['BASE_CALIBRATION_MODE']:
            if calibration_cfgs['SCALING_CONFIG']:
                algorithm = algorithm + '+' + calibration_cfgs['SCALING_CALIBRATOR_NAME']
            if calibration_cfgs['BIN_CALIBRATOR_NAME']:
                algorithm = algorithm + '+' + calibration_cfgs['BIN_CALIBRATOR_NAME']
        if calibration_cfgs['IF_DAC']:
            algorithm = algorithm + '+DAC'
        if calibration_cfgs['IF_PROCAL']:
            algorithm = algorithm + '+ProCal'

    df = pd.DataFrame({
        "dataset": [dataset],
        "split": [split],
        "shot": [shot],
        "algorithm": [algorithm],
        "cfgs": [cfgs],
        "ood_reg": [str(reg_dir)],
        "metrics": [key],
        "results": [results]
    })

    csv_file = "output/base2new/logs_base2new.csv"

    # to_csv does not create missing parent directories.
    os.makedirs(osp.dirname(csv_file), exist_ok=True)
    # Append rather than read/concat/rewrite: rewriting sends every earlier row
    # back through read_csv, which re-types it (e.g. "85.20" -> 85.2, and
    # possibly NA-like strings) and costs O(rows) per call. Assumes an existing
    # file was written by this function, i.e. has the same column order.
    write_header = not osp.exists(csv_file)
    df.to_csv(csv_file, mode="a", header=write_header, index=False)


def xd_results_to_csv(args, directory, key, results):
    """Append one cross-dataset (xd) summary row to ``output/xd/logs_xd.csv``.

    Row fields are read positionally from *directory*. The indices used below
    imply a layout like ``<p0>/<p1>/<split>/<algorithm>/<cfgs>/<dataset>``,
    e.g. ``output/xd/xd_test/CoOp/vit_b16/imagenet_a``.

    Args:
        args: parsed CLI namespace. ``args.calibration_config`` (JSON string or
            falsy) tags the algorithm label the same way the base-to-new writer
            does. A legacy ``args.calibration`` string is honored if present.
        directory: the experiment directory that was summarized.
        key: metric name (stored in column ``metrics``).
        results: already-formatted metric value (stored in column ``results``).

    Raises:
        IndexError: if *directory* does not follow the layout above.
        KeyError: if the calibration JSON lacks one of the expected keys.
    """
    # normpath drops a trailing "/" (and a leading "./") before positional parsing.
    parts = osp.normpath(directory).split("/")
    split = parts[2]
    algorithm = parts[3]
    cfgs = parts[4]
    dataset = parts[5]

    # The CLI only defines --calibration-config, so reading `args.calibration`
    # directly raised AttributeError for every xd directory. Read it defensively
    # in case an external caller still sets it.
    legacy_calibration = getattr(args, "calibration", None)
    if legacy_calibration:
        algorithm = algorithm + '+' + legacy_calibration

    calibration_config = getattr(args, "calibration_config", None)
    if calibration_config:
        calibration_cfgs = json.loads(calibration_config)
        # Same labeling rules as base2new_results_to_csv, so rows from both
        # protocols are comparable.
        if calibration_cfgs['BASE_CALIBRATION_MODE']:
            if calibration_cfgs['SCALING_CONFIG']:
                algorithm = algorithm + '+' + calibration_cfgs['SCALING_CALIBRATOR_NAME']
            if calibration_cfgs['BIN_CALIBRATOR_NAME']:
                algorithm = algorithm + '+' + calibration_cfgs['BIN_CALIBRATOR_NAME']
        if calibration_cfgs['IF_DAC']:
            algorithm = algorithm + '+DAC'
        if calibration_cfgs['IF_PROCAL']:
            algorithm = algorithm + '+ProCal'

    df = pd.DataFrame({
        "dataset": [dataset],
        "split": [split],
        "algorithm": [algorithm],
        "cfgs": [cfgs],
        "metrics": [key],
        "results": [results]
    })

    csv_file = "output/xd/logs_xd.csv"

    # to_csv does not create missing parent directories.
    os.makedirs(osp.dirname(csv_file), exist_ok=True)
    # Append rather than read/concat/rewrite so earlier rows are not re-typed by
    # a read_csv round-trip. Assumes an existing file has the same column order.
    write_header = not osp.exists(csv_file)
    df.to_csv(csv_file, mode="a", header=write_header, index=False)





def _log_basename(calibration_cfgs):
    """Return the log file name to read for a (decoded) calibration config.

    ``None`` means no calibration and yields ``log.txt``; otherwise suffixes are
    appended, e.g. ``log_<scaler>_<binner>_dac.txt``.
    """
    base_name = 'log'
    if calibration_cfgs is not None:
        # Calibrator names only appear in base-calibration mode; the DAC/ProCal
        # suffixes are independent of it.
        if calibration_cfgs['BASE_CALIBRATION_MODE']:
            if calibration_cfgs['SCALING_CONFIG']:
                base_name = base_name + '_' + calibration_cfgs['SCALING_CALIBRATOR_NAME']
            if calibration_cfgs['BIN_CALIBRATOR_NAME']:
                base_name = base_name + '_' + calibration_cfgs['BIN_CALIBRATOR_NAME']
        if calibration_cfgs['IF_DAC']:
            base_name = base_name + '_dac'
        if calibration_cfgs['IF_PROCAL']:
            base_name = base_name + '_procal'
    return base_name + '.txt'


def _parse_log_file(fpath, metrics, end_signal):
    """Extract metric values from the part of *fpath* after *end_signal*.

    Lines are compared (stripped) against *end_signal*. From that line on,
    including the line itself, each metric's ``regex`` is searched, and the
    float in group 1 is stored under the metric's ``name``. A later match
    overwrites an earlier one. The returned OrderedDict starts with a ``"file"``
    entry when anything matched, and is empty otherwise.
    """
    output = OrderedDict()
    good_to_go = False
    with open(fpath, "r") as f:
        for line in f:
            line = line.strip()
            if line == end_signal:
                good_to_go = True
            if not good_to_go:
                continue
            for metric in metrics:
                match = metric["regex"].search(line)
                if match:
                    if "file" not in output:
                        output["file"] = fpath
                    output[metric["name"]] = float(match.group(1))
    return output


def parse_function(*metrics, directory="", args=None, end_signal=None):
    """Summarize metrics over the run folders (e.g. seeds) inside *directory*.

    Every non-hidden entry of *directory* must contain the log file selected
    by ``args.calibration_config`` (``log.txt`` when unset). Each metric is a
    dict with ``"name"`` and a compiled ``"regex"`` whose group 1 is the value.

    Per-run values and a mean +- (std or 95% CI when ``args.ci95``) summary are
    printed. Each metric mean is also passed to ``results_to_csv``, which may
    append it to a CSV depending on *directory*.

    Returns:
        OrderedDict mapping metric name to its mean across runs.

    Raises:
        FileNotFoundError: if a run folder lacks the expected log file.
        RuntimeError: if no metric was found in any log.
    """
    print(f"Parsing files in {directory}")
    subdirs = listdir_nohidden(directory, sort=True)

    # The calibration config is the same for every run folder: decode it once.
    calibration_cfgs = (
        json.loads(args.calibration_config) if args.calibration_config else None
    )
    log_name = _log_basename(calibration_cfgs)

    outputs = []
    for subdir in subdirs:
        fpath = osp.join(directory, subdir, log_name)
        print(fpath)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not check_isfile(fpath):
            raise FileNotFoundError(f"Log file not found: {fpath}")
        output = _parse_log_file(fpath, metrics, end_signal)
        if output:
            outputs.append(output)

    if not outputs:
        raise RuntimeError(f"Nothing found in {directory}")

    metrics_results = defaultdict(list)
    for output in outputs:
        msg = ""
        for key, value in output.items():
            if isinstance(value, float):
                msg += f"{key}: {value:.2f}%. "
            else:
                msg += f"{key}: {value}. "
            if key != "file":
                metrics_results[key].append(value)
        print(msg)

    output_results = OrderedDict()

    print("===")
    print(f"Summary of directory: {directory}")
    for key, values in metrics_results.items():
        avg = np.mean(values)
        std = compute_ci95(values) if args.ci95 else np.std(values)
        print(f"* {key}: {avg:.2f}% +- {std:.2f}%")
        output_results[key] = avg
        # Only the mean goes to the CSV; the spread is printed above.
        results_to_csv(args, directory, key, f"{avg:.2f}")
    print("===")

    return output_results


def main(args, end_signal):
    """Build the metric extractor for ``args.keyword`` and summarize results.

    Without ``--multi-exp`` the single experiment at ``args.directory`` is
    summarized. With it, every non-hidden sub-directory is summarized as its own
    experiment, and then the mean of each metric across experiments is printed.
    """
    keyword = args.keyword
    metric = dict(
        name=keyword,
        regex=re.compile(fr"\* {keyword}: ([\.\deE+-]+)%"),
    )

    if not args.multi_exp:
        parse_function(
            metric, directory=args.directory, args=args, end_signal=end_signal
        )
        return

    per_key = defaultdict(list)
    for exp_name in listdir_nohidden(args.directory, sort=True):
        exp_dir = osp.join(args.directory, exp_name)
        exp_results = parse_function(
            metric, directory=exp_dir, args=args, end_signal=end_signal
        )
        for key, value in exp_results.items():
            per_key[key].append(value)

    print("Average performance")
    for key, values in per_key.items():
        print(f"* {key}: {np.mean(values):.2f}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str, help="path to directory")
    # argparse %-formats help strings, so a literal percent sign must be "%%".
    # The previous r"95\% ..." contained the conversion "% c" and crashed help
    # formatting.
    parser.add_argument(
        "--ci95", action="store_true", help="compute 95%% confidence interval"
    )
    parser.add_argument("--test-log", action="store_true", help="parse test-only logs")
    parser.add_argument(
        "--multi-exp", action="store_true", help="parse multiple experiments"
    )
    parser.add_argument(
        "--keyword", default="accuracy", type=str, help="which keyword to extract"
    )
    parser.add_argument(
        "--calibration-config",
        default=None,
        type=str,
        help="JSON string describing the calibration setup (keys "
        "BASE_CALIBRATION_MODE, SCALING_CONFIG, SCALING_CALIBRATOR_NAME, "
        "BIN_CALIBRATOR_NAME, IF_DAC, IF_PROCAL); selects which log_*.txt "
        "file is parsed and tags the algorithm name in the CSV output",
    )
    args = parser.parse_args()

    # Metrics are only collected after this line appears in the log.
    end_signal = "Finished training"
    if args.test_log:
        end_signal = "=> result"

    main(args, end_signal)
