import os
import re
import time
import pandas
import collections
import numpy as np

# Time window of the experiment batch, written in the same
# "%Y-%m-%d-%H-%M-%S" layout as the timestamp prefix of each run directory.
# Runs whose timestamp falls outside [exp_start_time, exp_end_time] are skipped.
exp_start_time = time.strptime("2023-09-23-18-55-54", "%Y-%m-%d-%H-%M-%S")
exp_end_time = time.strptime("2023-09-26-11-47-05", "%Y-%m-%d-%H-%M-%S")

# Problem names; their position in this list determines the row order of the
# exported table.
exp_cases = [
    "Burgers1d-C",
    "Burgers2d-C",
    "Poisson2d-C",
    "Poisson2d-CG",
    "Poisson3d-CG",
    "Poisson2d-MS",
    "Heat2d-VC",
    "Heat2d-MS",
    "Heat2d-CG",
    "NS2d-C",
    "NS2d-CG",
    "Wave1d-C",
    "Wave2d-CG",
    "Wave2d-MS",
    "Heat2d-LT",
    "NS2d-LT",
    "GS",
    "KS",
]

class HashableDict(dict):
    """A ``dict`` that can be used as a dictionary key or set member.

    The hash is computed from the current ``(key, value)`` pairs, so every
    value must itself be hashable. Instances remain mutable; do not modify one
    after it has been inserted into a dict or set, because its hash would
    change.
    """

    def __hash__(self) -> int:
        # dict equality ignores insertion order, so the hash must ignore it
        # too: equal objects are required to have equal hashes. A frozenset of
        # the items is order-independent, unlike a tuple of them.
        return hash(frozenset(self.items()))

def read_log(filepath):
    """Collect the experiment settings recorded in a log file.

    Every line that starts with one of the known field labels contributes one
    entry: the text following the label (after skipping two separator
    characters) is stripped and converted with the field's converter. If a
    label appears on several lines, the last occurrence wins.

    Args:
        filepath: Path of the log file to read.

    Returns:
        A HashableDict mapping each field label found to its converted value.
    """
    # (label, converter) pairs.
    # NOTE: a "Use Preconditioner" boolean field was parsed here previously;
    # it is currently disabled.
    fields = (
        ("Problem Name", str),
        ("Drop Tolerance (in ILU)", float),
    )
    settings = HashableDict()
    with open(filepath, "r") as f:
        for line in f:
            for label, convert in fields:
                if not line.startswith(label):
                    continue
                # Skip the label itself plus the two characters after it.
                raw_value = line[len(label) + 2:]
                settings[label] = convert(raw_value.strip())
    return settings

def process(func):
    """Apply ``func`` to every run directory inside the experiment window.

    Entries of ``logs/`` whose name matches
    ``<timestamp>_<name>_<repeat>_<number>[_<number>]`` and whose timestamp
    lies between ``exp_start_time`` and ``exp_end_time`` (inclusive) are
    processed. Runs are grouped by the settings parsed from their ``run.log``.
    A warning is printed when the same (settings, repeat) pair is seen more
    than once; the duplicate is still included in the result. The number of
    distinct (settings, repeat) pairs is printed at the end.

    Args:
        func: Callable taking the run directory name (relative to ``logs/``).

    Returns:
        A ``defaultdict(list)`` mapping each settings dict to the list of
        ``func`` results of its runs.
    """
    grouped = collections.defaultdict(list)
    dup = {}  # (settings, repeat) -> timestamp of the latest run seen
    for entry in os.listdir("logs"):
        mobj = re.match(r"([\d-]+)_([-A-Za-z0-9]+)_(\d+)_(\d+)(_\d+)?", entry)
        if not mobj:
            continue

        date = time.strptime(mobj[1], "%Y-%m-%d-%H-%M-%S")
        repeat = int(mobj[3])
        if not (exp_start_time <= date <= exp_end_time):
            continue

        log = read_log(os.path.join("logs", entry, "run.log"))
        key = (log, repeat)
        if key in dup:
            print(f"\033[33mWarning: Duplicated Experiment {log} found. Repeat {repeat}\033[0m")
            print(f"\033[33m    Experiment 1: {dup[(log, repeat)]}\033[0m")
            print(f"\033[33m    Experiment 2: {mobj[1]}\033[0m")
        dup[key] = mobj[1]
        grouped[log].append(func(entry))
    print(len(dup))
    return grouped

def extract_errs(path):
    """Read the final error metrics of a run from its ``run.log``.

    The log is scanned from the last line backwards; the first line containing
    numbers in scientific notation (e.g. ``1.23e-04``) is taken as the result
    line. When that line holds five numbers (Loss, MAE, MSE, L1RE, L2RE) the
    leading loss is dropped, leaving MAE, MSE, L1RE, L2RE.

    Args:
        path: Run directory name, relative to ``logs/``.

    Returns:
        A list of floats parsed from the result line. If no such line exists a
        warning is printed and four NaNs are returned.
    """
    with open(os.path.join("logs", path, "run.log"), "r") as f:
        lines = f.readlines()

    for line in reversed(lines):  # search from the last line
        results = re.findall(r"([\d\.]+e[-\+\d]+)", line)
        if not results:
            continue
        if len(results) == 5:
            results = results[1:]  # drop the leading Loss value
        return list(map(float, results))  # MAE, MSE, L1RE, L2RE

    print(f"\033[33mWarning: Failed in finding errors in {path}\033[0m")
    # Four NaNs, not five: a failed run must have the same width as a
    # successful one, otherwise the rows of one configuration cannot be
    # stacked into a rectangular array.
    return [np.nan] * 4

def extract_time(path):
    """Read the training time, in seconds, of a run from its ``run.log``.

    The log is scanned from the end; the last line that starts with
    ``Training costs: <number>s`` provides the value.

    Args:
        path: Run directory name, relative to ``logs/``.

    Returns:
        The training time as a float, or NaN (after printing a warning) when
        no matching line is present.
    """
    log_file = os.path.join("logs", path, "run.log")
    with open(log_file, "r") as f:
        lines = f.readlines()

    for line in reversed(lines):  # newest entries first
        found = re.match(r"Training costs: ([\d\.]+)s", line)
        if found is not None:
            return float(found[1])

    print(f"\033[33mWarning: Failed in finding training time in {path}\033[0m")
    return np.nan

def get_mean_and_std(arr):
    """Return ``(mean, std)`` of ``arr``, computed with NaN entries ignored."""
    mean = np.nanmean(arr)
    std = np.nanstd(arr)
    return mean, std

def gettrain(export_path="result.csv"):
    """Aggregate all runs in the experiment window and export them as CSV.

    For every distinct settings dict returned by ``process`` the mean and
    standard deviation (NaN-ignoring) of the training time and of the four
    error metrics are computed. Rows are ordered by the position of the
    problem name in ``exp_cases``.

    Args:
        export_path: Destination of the CSV file.
    """
    # NOTE: a 'use_preconditioner' column was previously emitted between 'pde'
    # and 'drop_tolerance'; it is currently disabled.
    columns = ['pde', 'drop_tolerance', 'run_time_mean', 'run_time_std']
    for metric in ('mae', 'mse', 'l1re', 'l2re'):
        columns += [metric + '_mean', metric + '_std']

    errs = process(extract_errs)
    times = process(extract_time)

    rows = []
    for log, err_rows in errs.items():
        err_matrix = np.array(err_rows)  # one row per run, one column per metric
        stats = list(get_mean_and_std(times[log]))
        for col in range(4):  # MAE, MSE, L1RE, L2RE
            stats.extend(get_mean_and_std(err_matrix[:, col]))
        rows.append([log['Problem Name'], log['Drop Tolerance (in ILU)'], *stats])

    # Present the problems in the canonical order given by exp_cases.
    rows.sort(key=lambda row: exp_cases.index(row[0]))

    table = pandas.DataFrame(rows, columns=columns)
    table.to_csv(export_path)


if __name__ == "__main__":
    # Echo the configured experiment window, then aggregate the logs and
    # write the summary to the default export path of gettrain().
    print(f"Experiment Start Time Set: {time.strftime('%Y-%m-%d-%H-%M-%S', exp_start_time)}")
    print(f"Experiment End Time Set: {time.strftime('%Y-%m-%d-%H-%M-%S', exp_end_time)}")
    gettrain()