# %%
import glob
import json
import sys
import traceback
import re
import logging
from time import sleep
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm

from config_utils import flatten_dict, load_from_yaml

from IPython.display import display, HTML, clear_output

plt.style.use("dark_background")


def set_display():
    """Configure pandas so notebook tables are shown wide and untruncated.

    Floats are rendered with a thousands separator and four decimals, and
    the column-width, row-count and column-count limits are each set to 1000.
    """
    pd.set_option("display.float_format", "{:,.4f}".format)
    for limit in ("max_colwidth", "max_rows", "max_columns"):
        pd.set_option(f"display.{limit}", 1000)


def pretty_print(df):
    """Show ``df`` as an HTML table in the notebook output.

    The frame is converted with ``DataFrame.to_html`` and every literal
    backslash-``n`` sequence in the resulting markup is replaced by ``<br>``
    so multi-line cell text is rendered on separate lines.

    Args:
        df: the DataFrame to render.

    Returns:
        Whatever IPython's ``display`` returns for the HTML object.
    """
    # No Styler is built here: a Styler returned by ``df.style`` is a separate
    # object, and ``df.to_html()`` below never sees properties set on it.
    markup = df.to_html().replace("\\n", "<br>")
    return display(HTML(markup))


def read_config(run):
    """Return the contents of ``<run>/lightning_logs/hparams.yaml`` as JSON text.

    Note: the name ``read_config`` is rebound by a later definition in this
    file, so code running after that point calls the later version.

    Args:
        run: path of the run directory.

    Returns:
        The parsed YAML re-serialised with ``json.dumps``, or ``""`` (after
        logging a warning) when the file does not exist.
    """
    cfg_path = os.path.join(run, "lightning_logs/hparams.yaml")
    if not os.path.exists(cfg_path):
        logging.warning(f"cfg file not found: {cfg_path}")
        return ""
    # NOTE(security): FullLoader is kept for compatibility with existing
    # files; prefer yaml.safe_load if these files can come from an untrusted
    # source.
    # The context manager closes the handle even if parsing fails.
    with open(cfg_path, "r") as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
    return json.dumps(cfg)


def read_config(run):
    """Load the first ``hparams.yaml`` found anywhere below ``run``.

    The directory tree is searched recursively and the first path reported
    by ``glob`` is passed to ``load_from_yaml``; its result is returned.

    Raises:
        IndexError: if no ``hparams.yaml`` exists under ``run``.
    """
    pattern = os.path.join(run, "**/hparams.yaml")
    candidates = glob.glob(pattern, recursive=True)
    first_match = candidates[0]
    return load_from_yaml(first_match)


def read_short_config(run, pretty=True):
    """Return ``<run>/params.json`` as a string, optionally flattened for display.

    Note: the name ``read_short_config`` is rebound by a later definition in
    this file, so code running after that point calls the later version.

    Args:
        run: path of the run directory.
        pretty: when true, turn the JSON text into a "key:    value" listing
            (commas become " \\n", braces/brackets/quotes are dropped and
            every ":" is padded with four spaces).

    Returns:
        The (optionally prettified) JSON text, or ``""`` after logging a
        warning when ``params.json`` does not exist.
    """
    json_path = os.path.join(run, "params.json")
    if not os.path.exists(json_path):
        logging.warning(f"cfg file not found: {json_path}")
        return ""
    # The context manager closes the handle even if parsing fails.
    with open(json_path, "r") as json_file:
        cfg = json.load(json_file)
    cfg_string = json.dumps(cfg)
    if pretty:
        # Order matters only for the two substitutions; the deleted
        # characters never appear in the replacement texts.
        drop_punctuation = str.maketrans("", "", '{}[]"')
        cfg_string = (
            cfg_string.replace(",", " \n")
            .translate(drop_punctuation)
            .replace(":", ":    ")
        )
    return cfg_string


def read_short_config(run):
    """Load the first ``params.json`` found anywhere below ``run``.

    This definition replaces the earlier ``read_short_config(run, pretty=True)``
    in this file and returns the parsed object instead of a string.

    Args:
        run: path of the run directory; it is searched recursively.

    Returns:
        The object decoded from the JSON file.

    Raises:
        IndexError: if no ``params.json`` exists under ``run``.
    """
    json_path = glob.glob(os.path.join(run, "**/params.json"), recursive=True)[0]
    # The context manager closes the handle even if decoding fails.
    with open(json_path, "r") as json_file:
        return json.load(json_file)


def read_score_df(run):
    """Load the first ``metrics.csv`` found anywhere below ``run``.

    Args:
        run: path of the run directory; it is searched recursively.

    Returns:
        A DataFrame read from the CSV, or ``None`` (after logging a warning)
        when no ``metrics.csv`` exists under ``run``.
    """
    pattern = os.path.join(run, "**/metrics.csv")
    csv_paths = glob.glob(pattern, recursive=True)
    if not csv_paths:
        # glob only reports paths that exist, so a missing file shows up as an
        # empty result rather than as a non-existent path.
        logging.warning(f"csv file not found: {pattern}")
        return None
    csv_path = csv_paths[0]
    print(csv_path)
    return pd.read_csv(csv_path)


# %%
def list_runs_from_exp_names(exp_names, exp_dir="/data/ray_results/"):
    """List the run directories of the given experiments.

    Args:
        exp_names: iterable of experiment folder names located in ``exp_dir``.
        exp_dir: directory that contains the experiment folders.

    Returns:
        Sorted list of ``exp_dir/<exp_name>/<entry>`` paths, one for every
        entry of each experiment folder that is itself a directory.
    """
    runs = []
    for exp_name in exp_names:
        i_dir = os.path.join(exp_dir, exp_name)
        # Filter and join per experiment; re-joining previously collected
        # paths onto the current folder breaks them when exp_dir is relative.
        for entry in os.listdir(i_dir):
            run_path = os.path.join(i_dir, entry)
            if os.path.isdir(run_path):
                runs.append(run_path)
    return sorted(runs)


# %%
subjects = ["NSD_01", "NSD_08", "EEG2_01", "MEG1_01", "fMRI1_01", "B5K_01"]


def collect_test_scores(runs, label=None):
    """Build a long-format table of final TEST Pearson scores for ``runs``.

    For every run directory the metrics CSV is loaded with ``read_score_df``
    and, for each entry of the module-level ``subjects`` list, the last
    non-NaN value of the ``TEST/PearsonCorrCoef/<subject>/all`` column is
    recorded. Subjects whose column is absent or entirely NaN are skipped.

    Args:
        runs: iterable of run directories.
        label: value stored in the ``fn`` column. When ``None``, the first
            entry of ``DATASET.SUBJECT_LIST`` from the run's ``params.json``
            (read via ``read_short_config``) is used instead.

    Returns:
        DataFrame with the columns ``fn``, ``subject`` and ``score``.
    """
    rows = []
    for run in runs:
        if label is None:
            row_label = read_short_config(run)["DATASET.SUBJECT_LIST"][0]
        else:
            row_label = label
        metrics = read_score_df(run)
        if metrics is None:
            # read_score_df is written to return None for a missing CSV.
            continue
        for subject in subjects:
            column = f"TEST/PearsonCorrCoef/{subject}/all"
            if column not in metrics.columns:
                continue
            values = metrics[column].to_numpy()
            logged = values[~np.isnan(values)]
            if logged.size == 0:
                # Column exists but never received a value; nothing to report.
                continue
            rows.append([row_label, subject, logged[-1]])
    return pd.DataFrame(rows, columns=["fn", "subject", "score"])


runs = list_runs_from_exp_names(["supdata"])
print(runs)
print(len(runs))

# Rows are labelled by the first entry of each run's DATASET.SUBJECT_LIST.
df1 = collect_test_scores(runs)
# %%
df1
# %%
df1.pivot_table(index="fn", columns="subject", values="score")
# %%
runs = list_runs_from_exp_names(["supood"])
df2 = collect_test_scores(runs, label="AFO")
# %%
df2
# %%
runs = list_runs_from_exp_names(["supoodnm"])
df3 = collect_test_scores(runs, label="NM")
# %%
# Stack the three long-format score tables into one.
df = pd.concat([df1, df2, df3])
# %%
# One row per "fn" label, one column per evaluated subject.
df = df.pivot_table(index="fn", columns="subject", values="score")
# %%
# Fix the row order of the final table; labels missing from the pivot
# become all-NaN rows.
row_order = [
    "AFO",
    "NM",
    "NSD_01",
    "NSD_08",
    "all_nsd",
    "EEG2_01",
    "all_eeg2",
    "MEG1_01",
    "all_meg1",
    "fMRI1_01",
    "all_fmri1",
    "B5K_01",
    "all_b5k",
]
df = df.reindex(row_order)
# Fix the column order.
column_order = ["NSD_01", "NSD_08", "EEG2_01", "MEG1_01", "fMRI1_01", "B5K_01"]
df = df[column_order]
# %%
# Replace internal row labels with the names used in the exported table.
# NOTE(review): "EEG_01", "MEG_01" and "fMRI_01" map to themselves and do not
# occur in row_order, so they rename nothing -- confirm whether "EEG2_01",
# "MEG1_01" and "fMRI1_01" were the intended keys.
row_labels = {
    "all_nsd": "NSD_ALL",
    "all_eeg2": "EEG_ALL",
    "all_meg1": "MEG_ALL",
    "all_fmri1": "fMRI_ALL",
    "all_b5k": "BOLD5K_ALL",
    "EEG_01": "EEG_01",
    "MEG_01": "MEG_01",
    "fMRI_01": "fMRI_01",
    "B5K_01": "BOLD5K_01",
}
df = df.rename(index=row_labels)
# %%
# Missing scores are shown as a dash.
df.fillna("-", inplace=True)
# %%
# Format numeric cells with three decimals; the dash placeholders pass through.
df = df.applymap(lambda cell: f"{cell:.3f}" if isinstance(cell, float) else cell)
# %%
df
# %%
df.to_csv("/workspace/figs/suptabl_afo.csv")
# %%
