# %%
import glob
import json
import sys
import traceback
import re
import logging
from time import sleep
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm

from config_utils import flatten_dict, load_from_yaml

from IPython.display import display, HTML, clear_output

plt.style.use("dark_background")


def set_display():
    """Configure pandas console/notebook display options.

    Floats are shown with thousands separators and 4 decimals.  Column width,
    row count and column count limits are all raised to 1000.
    """
    pd.set_option(
        "display.float_format", "{:,.4f}".format,
        "display.max_colwidth", 1000,
        "display.max_rows", 1000,
        "display.max_columns", 1000,
    )


def pretty_print(df):
    """Render ``df`` as an HTML table in the notebook with multi-line cells.

    pandas' default cell formatter writes a newline inside a string cell as
    the two characters backslash + ``n``.  Those are swapped for ``<br>`` so
    multi-line values (such as the pretty strings built by
    ``read_short_config``) appear on separate lines.

    Returns the result of IPython's ``display`` call.
    """
    # The former `df.style.set_properties(...)` built a Styler that was never
    # used, so it had no effect on the output; it has been removed.
    html = df.to_html().replace("\\n", "<br>")
    return display(HTML(html))


def read_config(run):
    """Return ``<run>/lightning_logs/hparams.yaml`` re-serialised as a JSON string.

    NOTE(review): a second ``read_config`` is defined further down this file
    and replaces this one at module level, so this version is never called
    once the file has run.

    Args:
        run: path of a single run directory.

    Returns:
        The parsed YAML dumped with ``json.dumps``.  Returns ``""`` (and logs
        a warning) when the file does not exist.
    """
    cfg_path = os.path.join(run, "lightning_logs/hparams.yaml")
    if not os.path.exists(cfg_path):
        logging.warning(f"cfg file not found: {cfg_path}")
        return ""
    # Context manager so the handle is closed deterministically.
    with open(cfg_path, "r") as f:
        # FullLoader kept for compatibility with existing hparams files; do
        # not point this at untrusted input.
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    return json.dumps(cfg)


def read_config(run):
    """Load the ``hparams.yaml`` found under ``run`` via ``load_from_yaml``.

    The search is recursive.  A file directly in ``run`` is preferred.
    Otherwise the shallowest match wins, with ties broken by path so the
    choice is deterministic.

    Args:
        run: path of a single run directory.

    Returns:
        Whatever ``config_utils.load_from_yaml`` returns for that file.

    Raises:
        FileNotFoundError: no ``hparams.yaml`` exists anywhere under ``run``.
    """
    # Escape `run`: directory names may contain glob metacharacters such as
    # "[" (e.g. from list-valued hyperparameters), which would otherwise be
    # interpreted as a character class and silently match nothing.
    pattern = os.path.join(glob.escape(run), "**/hparams.yaml")
    matches = sorted(
        glob.glob(pattern, recursive=True),
        key=lambda p: (p.count(os.sep), p),
    )
    if not matches:
        raise FileNotFoundError(f"hparams.yaml not found under: {run}")
    return load_from_yaml(matches[0])


def read_short_config(run, pretty=True):
    """Return ``<run>/params.json`` as a string, optionally flattened for display.

    NOTE(review): a second ``read_short_config(run)`` is defined further down
    this file and replaces this one at module level.

    Args:
        run: path of a single run directory.
        pretty: when True, make the JSON text human readable:
            - every comma becomes ``" \\n"``;
            - braces, brackets and double quotes are deleted;
            - every colon is followed by four extra spaces.
            When False, the compact ``json.dumps`` text is returned unchanged.

    Returns:
        The (possibly prettified) string.  Returns ``""`` (and logs a warning)
        when ``params.json`` does not exist.
    """
    json_path = os.path.join(run, "params.json")
    if not os.path.exists(json_path):
        logging.warning(f"cfg file not found: {json_path}")
        return ""
    with open(json_path, "r") as f:
        cfg_string = json.dumps(json.load(f))
    if pretty:
        # Single-pass equivalent of the chained str.replace calls: none of
        # the inserted text contains a character that another rule rewrites.
        table = str.maketrans(
            {
                ",": " \n",
                ":": ":    ",
                "{": None,
                "}": None,
                "[": None,
                "]": None,
                '"': None,
            }
        )
        cfg_string = cfg_string.translate(table)
    return cfg_string


def read_short_config(run):
    """Load and return the ``params.json`` found under ``run`` as a dict.

    The search is recursive.  A file directly in ``run`` is preferred.
    Otherwise the shallowest match wins, with ties broken by path so the
    choice is deterministic.

    Args:
        run: path of a single run directory.

    Returns:
        The decoded JSON object.

    Raises:
        FileNotFoundError: no ``params.json`` exists anywhere under ``run``.
    """
    # Escape `run` so glob metacharacters in directory names (e.g. "[") are
    # matched literally instead of silently matching nothing.
    pattern = os.path.join(glob.escape(run), "**/params.json")
    matches = sorted(
        glob.glob(pattern, recursive=True),
        key=lambda p: (p.count(os.sep), p),
    )
    if not matches:
        raise FileNotFoundError(f"params.json not found under: {run}")
    with open(matches[0], "r") as f:
        return json.load(f)


def read_score_df(run):
    """Read ``<run>/lightning_logs/metrics.csv`` into a DataFrame.

    Args:
        run: path of a single run directory.

    Returns:
        The parsed CSV, or None (after logging a warning) if the file is
        missing.
    """
    metrics_file = os.path.join(run, "lightning_logs/metrics.csv")
    if not os.path.exists(metrics_file):
        logging.warning(f"csv file not found: {metrics_file}")
        return None
    return pd.read_csv(metrics_file)


# %%
def list_runs_from_exp_names(exp_names, exp_dir="/data/ray_results/"):
    """Return the sorted paths of all run directories of the given experiments.

    For each name in ``exp_names``, the immediate sub-directories of
    ``os.path.join(exp_dir, name)`` are collected.  Plain files in the
    experiment directory are skipped.

    Args:
        exp_names: iterable of experiment directory names.
        exp_dir: parent directory holding the experiments (absolute or
            relative).

    Returns:
        Sorted list of run directory paths, each prefixed with its own
        experiment directory.
    """
    runs = []
    for exp_name in exp_names:
        exp_path = os.path.join(exp_dir, exp_name)
        # Filter and prefix only this experiment's entries.  The previous
        # version re-processed the accumulated list on every iteration, which
        # re-prefixed (and then dropped) earlier runs when exp_dir is relative.
        runs.extend(
            os.path.join(exp_path, entry)
            for entry in os.listdir(exp_path)
            if os.path.isdir(os.path.join(exp_path, entry))
        )
    return sorted(runs)


# %%
# Experiments whose Ray results make up the ablation table.
names = ["topyneck_ablation"]
runs = list_runs_from_exp_names(names)
print(runs)
print(len(runs))
# %%
def _last_valid_score(score_df, column):
    """Return the last non-NaN value of ``column`` in ``score_df``, or None.

    Rows of metrics.csv can leave a column empty (read as NaN), presumably
    rows logged by other steps/stages, so those are skipped.  Logs a warning
    and returns None when the column is missing or has no values.
    """
    if column not in score_df.columns:
        logging.warning(f"metric column not found: {column}")
        return None
    values = score_df[column].dropna()
    if values.empty:
        logging.warning(f"metric column has no values: {column}")
        return None
    return values.iloc[-1]


datas = []
for run in runs:
    params = read_short_config(run)
    fn = params["fn"]
    subject = params["subject"][0]
    # Subject label used in the final table ("EEG" is renamed to "xeeg").
    subject_label = subject.replace("EEG", "xeeg")
    score_df = read_score_df(run)
    if score_df is None:
        # read_score_df has already logged the missing metrics.csv.
        continue

    # "all" reads the .../mean column; NSD subjects additionally get one
    # column per ROI group.
    metric_columns = {"all": "TEST/PearsonCorrCoef/mean"}
    if "NSD" in subject:
        for roi in ["early", "mid", "late"]:
            metric_columns[roi] = f"TEST/PearsonCorrCoef/{subject}/{roi}"

    for roi, column in metric_columns.items():
        score = _last_valid_score(score_df, column)
        if score is not None:
            datas.append([fn, roi, subject_label, score])

df = pd.DataFrame(datas, columns=["fn", "roi", "subject", "score"])
# # %%
# # mean over fn-subject-roi
# df = df.groupby(["fn", "roi", "subject"]).mean()
# df = df.pivot_table(index=["fn", "subject"], columns="roi", values="score")
# # replace nan with -
# df = df.fillna("-")
# # rename fn
# df = df.rename(
#     index={
#         "row1": "FullTopyNeck",
#         "row2": "FrozenNeuronProjector",
#         "row3": "w/o AvgMaxPooling",
#         "row4": "FrozenLayerSelector",
#         "row5": "NoRegLayerSelector",
#     }
# )
# # %%
# # format to .4f
# df = df.applymap(lambda x: f"{x:.4f}" if isinstance(x, float) else x)
# # %%
# df
# # %%
# df.to_csv("/workspace/figs/topyneck_ablation.csv")
# %%

# Rows: ablation variant (fn); columns: (subject, roi).  pivot_table's default
# aggfunc averages scores when several runs land in the same cell.
df = df.pivot_table(index=["fn"], columns=["subject", "roi"], values="score")
# Map internal variant ids to the names used in the paper table.
df = df.rename(
    index={
        "row1": "FullTopyNeck",
        "row2": "FrozenNeuProjector",
        "row3": "w/o AvgMaxPooling",
        "row4": "FrozenLayerSelector",
        "row5": "NoRegLayerSelector",
    }
)
# Format scores to 3 decimals and mark (subject, roi) cells without a run as
# "-".  Per-column Series.map replaces DataFrame.applymap, which is deprecated
# since pandas 2.1.
df = df.apply(lambda col: col.map(lambda x: "-" if pd.isna(x) else f"{x:.3f}"))
df
# %%
out_path = "/workspace/figs/topyneck_ablation.xlsx"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
df.to_excel(out_path)
# %%
