import multiprocessing as mp
import pickle
import random
import re

import numpy as np


# Matches one bracketed value array, capturing its contents (no nested brackets).
_BRACKETED_ARRAY = re.compile(r"\[([^\[\]]*)\]")


def _join_wrapped_record(lines, start_idx):
    """Join ``lines[start_idx]`` with the following lines until the text ends with ``]``.

    Each physical line is stripped and the pieces are joined with one space.

    Raises:
        ValueError: if the input ends before the record is closed by ``]``.
    """
    record = lines[start_idx].strip()
    next_idx = start_idx + 1
    while not record.endswith("]"):
        if next_idx >= len(lines):
            raise ValueError(
                f"unterminated result record starting at line {start_idx + 1}"
            )
        record += " " + lines[next_idx].strip()
        next_idx += 1
    return record


def _parse_mse_mae(record):
    """Return ``(mse_values, mae_values)`` parsed from one joined result record.

    The first bracketed array is taken as the MSE values and the second as the
    MAE values. Values may be separated by whitespace and/or commas.

    Raises:
        ValueError: if the record does not contain exactly two bracketed arrays.
    """
    arrays = _BRACKETED_ARRAY.findall(record)
    if len(arrays) != 2:
        raise ValueError(
            f"expected two bracketed arrays (mse, mae), got {len(arrays)}: {record!r}"
        )
    mse_val, mae_val = (
        [float(tok) for tok in arr.replace(",", " ").split()] for arr in arrays
    )
    return mse_val, mae_val


def parse_evaluation_results(file_path):
    """Parse per-task MSE/MAE values from a long-term-forecast result log.

    The log consists of header lines starting with ``long_term_forecast``. The
    last ``_``-separated field of a header is a ``|``-separated list of integer
    task ids, e.g. ``..._0|3|5``. After a header come result records of the form
    ``per mse:[v0 v1 ...], per mae:[v0 v1 ...]``. A record may wrap over several
    physical lines and is complete once the joined text ends with ``]``. The i-th
    value of each array belongs to the i-th task id of the current header.

    Args:
        file_path: Path of the text log to read.

    Returns:
        dict mapping each task-id tuple to ``{task_id: {"mae": [...], "mse": [...]}}``.
        Every record under a header (including repeated headers for the same
        group) appends one value to each list.

    Raises:
        ValueError: if a record appears before any header, is unterminated at
            end of file, or does not hold exactly two bracketed arrays.
    """
    results = {}
    group_numbers = None
    with open(file_path, "r") as file:
        lines = file.readlines()

    for line_idx, raw_line in enumerate(lines):
        line = raw_line.strip()
        if line.startswith("long_term_forecast"):
            # The group's task ids live in the last "_" field, separated by "|".
            group_info = line.split("_")[-1]
            group_numbers = tuple(map(int, group_info.split("|")))
        elif line.startswith(("per mse:", "per mae")):
            if group_numbers is None:
                raise ValueError(
                    f"result record at line {line_idx + 1} appears before any "
                    "'long_term_forecast' header"
                )
            # Continuation lines of a wrapped record match neither prefix, so
            # the outer loop skips them after they are consumed here.
            record = _join_wrapped_record(lines, line_idx)
            mse_val, mae_val = _parse_mse_mae(record)
            group_res = results.setdefault(group_numbers, {})
            for i, task in enumerate(group_numbers):
                task_res = group_res.setdefault(task, {"mae": [], "mse": []})
                task_res["mae"].append(mae_val[i])
                task_res["mse"].append(mse_val[i])
    return results


def read_infos(task,):
    """Load the cached results of all task combinations for a benchmark.

    Args:
        task: One of ``'cop'``, ``'celeba'``, ``'ettm1'`` or ``'taskonomy'``.

    Returns:
        For ``'ettm1'``: the dict produced by ``parse_evaluation_results`` on
        the ETTm1 text log. For the other tasks: whatever object is stored in
        the corresponding pickle file.

    Raises:
        ValueError: if ``task`` is not one of the supported names.
    """
    # Paths are relative to the current working directory.
    pickle_paths = {
        "cop": "./collect_infos/cop/cop_res.pkl",
        "celeba": "./collect_infos/celeba/eval_results.pkl",
        "taskonomy": "collect_infos/taskonomy/taskonomy_res.pkl",
    }
    if task == "ettm1":
        return parse_evaluation_results(
            "./collect_infos/ettm1/result_long_term_forecast.txt"
        )
    if task in pickle_paths:
        # NOTE: unpickling can execute arbitrary code; these are expected to be
        # locally generated result files and must never come from untrusted input.
        # ``with`` closes the handle, which the bare ``open(...)`` leaked.
        with open(pickle_paths[task], "rb") as fh:
            return pickle.load(fh)
    raise ValueError("task only support cop, celeba, ettm1, taskonomy")


def get_grouping_test_res_celeba(all_comb_res, tasks, grouping_res):
    """Build per-group mean/std tables of ``100 * (1 - value)`` for CelebA.

    ``all_comb_res`` is keyed by the group's sorted 0-based task indices,
    shifted to 1-based and joined with ``-`` (e.g. group ``{0, 2}`` -> ``"1-3"``).
    Each entry maps a 1-based task index to a sequence of values. Those values
    are presumably accuracies in [0, 1], so the statistic is an error
    percentage; confirm against the data producer.

    Returns:
        ``(means, stds)``: two arrays with one row per group and one column per
        entry of ``tasks``. Tasks absent from a group get 100 in both tables.
    """
    means = []
    stds = []
    for group in grouping_res:
        key = "-".join(str(member + 1) for member in sorted(group))
        comb_res = all_comb_res[key]
        row_mean = []
        row_std = []
        for task_id, _ in enumerate(tasks, start=1):
            if task_id not in comb_res:
                row_mean.append(100)
                row_std.append(100)
                continue
            err = 100 * (1 - np.array(comb_res[task_id]))
            row_mean.append(np.mean(err))
            row_std.append(np.std(err))
        means.append(row_mean)
        stds.append(row_std)
    return np.array(means), np.array(stds)


def get_grouping_test_res_taskonomy(all_comb_res, tasks, grouping_res):
    """Build the per-group score table for Taskonomy.

    ``all_comb_res`` is keyed by the group's sorted task-index tuple. Its value
    is a sequence holding one score per member, in that sorted order. Scores
    are negated in the table. Presumably this turns higher-is-better metrics
    into lower-is-better ones, since ``print_grouping_res`` minimises and
    negates them back; confirm with the data producer.

    Returns:
        ``(table, ones)``: ``table`` has one row per group and one column per
        entry of ``tasks``, with 100 for tasks outside the group. ``ones`` is an
        all-ones array of the same shape. It is used as the "std" input, and
        ``print_grouping_res`` treats a std of 1 as a format selector.
    """
    # The original also built an unused per-group std list; it is dropped here.
    rows = []
    for group in grouping_res:
        task_comb = tuple(sorted(group))
        task_comb_res = all_comb_res[task_comb]
        row = []
        for i, _ in enumerate(tasks):
            if i in task_comb:
                row.append(-task_comb_res[task_comb.index(i)])
            else:
                row.append(100)
        rows.append(row)
    table = np.array(rows)
    return table, np.ones_like(table)


def get_grouping_test_res_cop(all_comb_res, tasks, grouping_res):
    """Build the per-group score table for the COP benchmark.

    ``all_comb_res`` is keyed by the group's sorted task-index tuple. Its value
    is a sequence holding one score per member, in that sorted order. Scores
    are copied unchanged.

    Returns:
        ``(table, zeros)``: ``table`` has one row per group and one column per
        entry of ``tasks``, with 100 for tasks outside the group. ``zeros`` is
        an all-zeros array of the same shape. It is used as the "std" input,
        and ``print_grouping_res`` treats a std of 0 as a format selector.
    """
    rows = []
    for group in grouping_res:
        members = tuple(sorted(group))
        scores = all_comb_res[members]
        rows.append(
            [
                scores[members.index(pos)] if pos in members else 100
                for pos, _ in enumerate(tasks)
            ]
        )
    table = np.array(rows)
    return table, np.zeros_like(table)


def get_grouping_test_res_ettm1(all_comb_res, tasks, grouping_res):
    """Build per-group MAE mean/std tables for ETTm1.

    ``all_comb_res`` has the layout returned by ``parse_evaluation_results``:
    ``{sorted_group_tuple: {task_id: {"mae": [...], "mse": [...]}}}``. The
    column position ``i`` of ``tasks`` is used directly as the task id.

    Returns:
        ``(mae_means, mae_stds)``: two arrays with one row per group and one
        column per entry of ``tasks``. Tasks absent from a group get 100 in both.
    """
    mean_table = []
    std_table = []
    for group in grouping_res:
        comb_res = all_comb_res[tuple(sorted(group))]
        means_row = []
        stds_row = []
        for task_idx, _ in enumerate(tasks):
            if task_idx not in comb_res:
                means_row.append(100)
                stds_row.append(100)
                continue
            mae = np.array(comb_res[task_idx]["mae"])
            means_row.append(np.mean(mae))
            stds_row.append(np.std(mae))
        mean_table.append(means_row)
        std_table.append(stds_row)
    return np.array(mean_table), np.array(std_table)


def _print_plain_row(label, mean_res, std_res, tasks, test_res_mean):
    """Print one plain-text row: label, ``mean (std)`` per task, then the best-sum.

    The best-sum is the sum over tasks of the column-wise minimum of
    ``test_res_mean``.
    """
    print(label, end=" ")
    for _, mean, std in zip(tasks, mean_res, std_res):
        print(f"{mean:.2f} ({std:.3f})", end=" ")
    print("{:.3f}".format(np.min(test_res_mean, axis=0).sum()))


def _format_cell(mean, std):
    """Return the LaTeX table cell for one (mean, std) pair of a group row.

    A mean of exactly 100 is the placeholder the table builders store for tasks
    outside the group and is rendered as ``-``. ``std`` doubles as a format
    selector: 0 gives a percentage, 1 gives the negated mean, and any other
    value gives ``mean +/- std``.
    """
    if mean == 100:
        return "& -"
    if std == 0:
        return f"& ${mean:.3f}\\%$"
    if std == 1:
        return f"& ${-mean:.3f}$"
    return f"& ${mean:.2f}\\pm{std:.3f}$"


def _format_total(test_res_mean, std, method):
    """Return the LaTeX cell contents for the summed score of the whole grouping.

    ``std`` selects the format as in ``_format_cell``. In the plain format,
    ``method == 'MTG'`` sums the diagonal ``test_res_mean[k, k]``. Every other
    case sums the column-wise minimum across groups.
    """
    if std == 0:
        return "{{${:.3f}\\%$}}".format(np.min(test_res_mean, axis=0).sum())
    if std == 1:
        return "{{${:.3f}$}}".format(-np.min(test_res_mean, axis=0).sum())
    if method == "MTG":
        n_groups = test_res_mean.shape[0]
        diag_sum = test_res_mean[range(n_groups), range(n_groups)].sum()
        return "{{${:.3f}$}}".format(diag_sum)
    return "{{${:.3f}$}}".format(np.min(test_res_mean, axis=0).sum())


def print_grouping_res(test_res_mean, test_res_std, tasks, method):
    """Print grouping results to stdout.

    - ``method == "stl"``: one plain row taken from the diagonal entries.
    - ``method == "mtl"``: one plain row taken from the first group row.
    - Any other method: one LaTeX table row per group. The first row carries a
      ``\\multirow`` method label and a ``\\multirow`` summed-score column.

    Args:
        test_res_mean: 2-D array, one row per group, one column per task.
        test_res_std: array of the same layout. Its values also select the
            LaTeX cell format (see ``_format_cell``).
        tasks: Task identifiers; only their count and order are used.
        method: Name printed in the table (``"MTG"`` changes the summed score).
    """
    if method == "stl":
        n = len(tasks)
        _print_plain_row(
            "STL ",
            test_res_mean[range(n), range(n)],
            test_res_std[range(n), range(n)],
            tasks,
            test_res_mean,
        )
        return
    if method == "mtl":
        _print_plain_row("MTL ", test_res_mean[0], test_res_std[0], tasks, test_res_mean)
        return

    # Rows with no task columns produce no output at all.
    if any(len(group_res) == 0 for group_res in test_res_mean):
        return

    # Escaped backslashes: "\m", "\%" and "\p" were invalid escape sequences
    # (a SyntaxWarning on Python 3.12+).
    multirow = "& \\multirow{{{}}}{{*}}".format(len(test_res_mean))
    for i, group_res in enumerate(test_res_mean):
        pairs = list(zip(group_res, test_res_std[i]))
        head = "{}{{{}}}".format(multirow, method) if i == 0 else "& "
        parts = [head] + [_format_cell(mean, std) for mean, std in pairs]
        if i == 0:
            # The format of the total follows the std of the last cell in the
            # first row (0 -> percentage, 1 -> negated, otherwise plain).
            last_std = pairs[-1][1] if pairs else None
            parts.append(multirow + _format_total(test_res_mean, last_std, method))
        else:
            # Placeholder cell under the multirow total.
            parts.append("& ")
        # Every row, including MTG's first row, ends with the LaTeX row break.
        parts.append("\\\\")
        print(" ".join(parts))


def show_res(all_comb_res, grouping_res, task, method):
    """Build the score tables for ``grouping_res`` and print them.

    Args:
        all_comb_res: Cached results of all task combinations, as returned by
            ``read_infos(task)``.
        grouping_res: Iterable of groups, each a collection of 0-based task indices.
        task: One of ``'cop'`` (6 tasks), ``'celeba'`` (9), ``'ettm1'`` (7) or
            ``'taskonomy'`` (5).
        method: Method name forwarded to ``print_grouping_res``.

    Raises:
        ValueError: if ``task`` is not supported.
    """
    num_tasks = {"cop": 6, "celeba": 9, "ettm1": 7, "taskonomy": 5}
    if task not in num_tasks:
        # Message now lists taskonomy, matching read_infos.
        raise ValueError("task only support cop, celeba, ettm1, taskonomy")
    getters = {
        "cop": get_grouping_test_res_cop,
        "celeba": get_grouping_test_res_celeba,
        "ettm1": get_grouping_test_res_ettm1,
        "taskonomy": get_grouping_test_res_taskonomy,
    }
    task_list = list(range(num_tasks[task]))
    test_res_mean, test_res_std = getters[task](all_comb_res, task_list, grouping_res)
    print_grouping_res(test_res_mean, test_res_std, task_list, method)