# %%
import json
import sys
import traceback
import re
import logging
from time import sleep
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm

from config_utils import flatten_dict

from IPython.display import display, HTML, clear_output

# Apply matplotlib's built-in "dark_background" style sheet to every figure
# created after this point.
plt.style.use("dark_background")


def set_display():
    """Configure pandas display options for wide, long tables.

    Floats are formatted with thousands separators and four decimals, and
    the column-width, row-count and column-count limits are each set to 1000.
    """
    display_options = {
        "display.float_format": "{:,.4f}".format,
        "display.max_colwidth": 1000,
        "display.max_rows": 1000,
        "display.max_columns": 1000,
    }
    for option_name, option_value in display_options.items():
        pd.set_option(option_name, option_value)


def pretty_print(df):
    """Render ``df`` as an HTML table through IPython's ``display``.

    Every literal backslash-``n`` sequence in the generated HTML is replaced
    with a ``<br>`` tag (presumably so that multi-line cell text, such as the
    pretty-printed config strings, is shown on separate lines).

    Args:
        df: pandas DataFrame to show.

    Returns:
        The return value of ``display`` for the generated HTML object.
    """
    # No Styler is involved: ``df.style.set_properties(...)`` returns a new
    # Styler object instead of modifying ``df``, so calling it and discarding
    # the result has no effect on ``df.to_html()``.
    table_html = df.to_html().replace("\\n", "<br>")
    return display(HTML(table_html))


def read_config(run):
    """Return the content of a run's ``hparams.yaml`` as a JSON string.

    Args:
        run: Path of the run directory; the file that is read is
            ``<run>/lightning_logs/hparams.yaml``.

    Returns:
        The parsed YAML content serialised with ``json.dumps``, or an empty
        string (after logging a warning) when the file does not exist.
    """
    cfg_path = os.path.join(run, "lightning_logs/hparams.yaml")
    if not os.path.exists(cfg_path):
        logging.warning(f"cfg file not found: {cfg_path}")
        return ""

    # The context manager closes the file handle even if parsing fails.
    # NOTE(review): FullLoader is kept to preserve behaviour; it should only
    # be pointed at trusted, locally produced files.
    with open(cfg_path, "r") as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
    return json.dumps(cfg)


def read_short_config(run, pretty=True):
    """Return the content of a run's ``params.json`` as a string.

    Args:
        run: Path of the run directory; the file that is read is
            ``<run>/params.json``.
        pretty: When true, turn the JSON text into a compact multi-line
            listing: commas become ``" \\n"``, braces, brackets and double
            quotes are removed, and four spaces are added after every colon.

    Returns:
        The (optionally prettified) JSON text, or an empty string (after
        logging a warning) when the file does not exist.
    """
    json_path = os.path.join(run, "params.json")
    if not os.path.exists(json_path):
        logging.warning(f"cfg file not found: {json_path}")
        return ""

    # The context manager closes the file handle even if parsing fails.
    with open(json_path, "r") as json_file:
        cfg = json.load(json_file)
    cfg_string = json.dumps(cfg)

    if pretty:
        # One translate pass is equivalent to chaining the replacements:
        # none of the inserted text contains a character that is itself
        # replaced or removed.
        pretty_table = str.maketrans(
            {
                ",": " \n",
                "{": None,
                "}": None,
                "[": None,
                "]": None,
                '"': None,
                ":": ":    ",
            }
        )
        cfg_string = cfg_string.translate(pretty_table)
    return cfg_string


def read_score_df(run):
    """Load the metrics table stored for a run.

    Args:
        run: Path of the run directory; the file that is read is
            ``<run>/lightning_logs/metrics.csv``.

    Returns:
        The CSV parsed by ``pd.read_csv``, or ``None`` (after logging a
        warning) when the file does not exist.
    """
    csv_path = os.path.join(run, "lightning_logs/metrics.csv")
    if not os.path.exists(csv_path):
        logging.warning(f"csv file not found: {csv_path}")
        return None
    return pd.read_csv(csv_path)
# %%
def list_runs_from_exp_names(exp_names, exp_dir="/data/ray_results/"):
    """List the run directories of one or more experiments.

    Args:
        exp_names: Iterable of experiment directory names located directly
            under ``exp_dir``.
        exp_dir: Directory that contains the experiment directories.

    Returns:
        Sorted list of paths (``exp_dir/exp_name/run``) of every
        sub-directory found in the given experiments; plain files are
        skipped.

    Raises:
        OSError: Propagated from ``os.listdir`` when an experiment directory
            cannot be listed (e.g. it does not exist).
    """
    runs = []
    for exp_name in exp_names:
        i_dir = os.path.join(exp_dir, exp_name)
        # Handle each experiment's entries on their own. Filtering and
        # joining the accumulated list would re-join paths from earlier
        # experiments with this experiment's directory, which drops them
        # whenever ``exp_dir`` is a relative path.
        candidates = (os.path.join(i_dir, entry) for entry in os.listdir(i_dir))
        runs.extend(path for path in candidates if os.path.isdir(path))
    return sorted(runs)

# %%
# Collect the run directories of the experiments to inspect, then print the
# list followed by its length.
names = ["sync_ablation"]
runs = list_runs_from_exp_names(names)
print(runs, len(runs), sep="\n")
# %%
# Plot the validation PearsonCorrCoef curves of every run, one panel per run.
hp_k = "^(VAL)/PearsonCorrCoef/.*"
metric_pattern = re.compile(hp_k)  # compiled once, not once per column
mean_key = "VAL/PearsonCorrCoef/mean"

# Panels are laid out in two rows and filled column by column. The grid is
# 2x2 for up to four runs and gains columns beyond that, so indexing `axs`
# can never run past the grid.
n_cols = max(2, (len(runs) + 1) // 2)
fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)
for i, run in enumerate(runs):
    ax = axs[i % 2, i // 2]

    short_config_string = read_short_config(run)
    df = read_score_df(run)
    if df is None:
        # read_score_df already logged the missing file; leave the panel
        # empty instead of failing on `df.keys()`.
        ax.set_title(short_config_string + "\n" + "metrics not found")
        continue

    matches = [
        column
        for column in df.keys()
        if metric_pattern.match(column) and "challenge" not in column
    ]

    for m in matches:
        # Coerce non-numeric entries (e.g. the string "None") to NaN so the
        # column is numeric before the missing values are dropped.
        df[m] = pd.to_numeric(df[m], errors="coerce")
        values = df[m].dropna().values
        ax.plot(values, alpha=0.9, label=m.replace("PearsonCorrCoef/", ""))

    # Last logged (non-NaN) value of the mean metric; NaN when the column is
    # missing or holds no numeric value.
    if mean_key in df:
        mean_values = pd.to_numeric(df[mean_key], errors="coerce").dropna().values
    else:
        mean_values = np.array([])
    mean_score = mean_values[-1] if len(mean_values) else np.nan

    ax.set_ylim(0, 0.55)
    ax.legend()
    ax.set_title(short_config_string + "\n" + f"mean score: {mean_score:.3f}")
plt.show()
# %%
