# %%
import glob
import json
import sys
import traceback
import re
import logging
from time import sleep
from matplotlib import ticker
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm

from config_utils import flatten_dict, load_from_yaml

from IPython.display import display, HTML, clear_output

# plt.style.use("dark_background")

def load_cfg(run):
    """Load the ``hparams.yaml`` config stored somewhere below a run directory.

    The run directory is searched recursively. If several ``hparams.yaml``
    files exist, the lexicographically first path is used so the choice does
    not depend on (arbitrary) filesystem listing order.

    Args:
        run: Path to a run directory.

    Returns:
        Whatever ``config_utils.load_from_yaml`` returns for the chosen file
        (used elsewhere in this script as ``cfg.DATASET.SUBJECT_LIST``).

    Raises:
        FileNotFoundError: If no ``hparams.yaml`` exists below *run*.
    """
    paths = sorted(glob.glob(os.path.join(run, "**", "hparams.yaml"), recursive=True))
    print(paths)
    if not paths:
        # Previously an opaque IndexError from paths[0].
        raise FileNotFoundError(f"no hparams.yaml found under {run!r}")
    return load_from_yaml(paths[0])


def _natural_sort_key(name):
    """Sort key that compares digit runs numerically ("x_2" < "x_10"), ties broken by name."""
    parts = re.split(r"(\d+)", name)
    # re.split with a capturing group puts the digit runs at the odd indices.
    return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)], name


def read_neuron_location(run):
    """Load the saved neuron positions (``mu``) and gate weights for one run.

    Reads the ``*.pt`` files in ``<run>/neuron_location``:

    * files whose name contains ``"mu"`` contribute ``data[key]["layer2"]``;
    * otherwise, files whose name contains ``"gate"`` contribute ``data[key]``;
    * any other ``.pt`` file is ignored (and not loaded).

    Each loaded file must hold a dict with exactly one key (presumably the
    subject id). Files are ordered with a natural sort, so ``mu_2.pt`` comes
    before ``mu_10.pt`` even without zero padding. Callers index the returned
    lists by epoch, so this order matters.

    Args:
        run: Path to a run directory.

    Returns:
        ``(mus, gates)``: two lists of tensors in file order (loaded on CPU),
        or ``(None, None)`` when the directory holds no ``.pt`` files.

    Raises:
        FileNotFoundError: If ``<run>/neuron_location`` does not exist.
        ValueError: If a loaded file's dict does not have exactly one key.
    """
    dir_path = os.path.join(run, "neuron_location")
    pt_files = sorted(
        (f for f in os.listdir(dir_path) if f.endswith(".pt")),
        key=_natural_sort_key,
    )
    if not pt_files:
        return None, None

    mus, gates = [], []
    for fname in pt_files:
        is_mu = "mu" in fname  # "mu" takes precedence, as before
        if not is_mu and "gate" not in fname:
            continue  # unrelated checkpoint file
        # map_location="cpu": the tensors are only used for plotting, and this
        # lets files saved from a GPU process load on CPU-only machines.
        data = torch.load(os.path.join(dir_path, fname), map_location="cpu")
        if len(data) != 1:
            raise ValueError(
                f"{fname}: expected exactly one top-level key, got {list(data)}"
            )
        (payload,) = data.values()
        if is_mu:
            mus.append(torch.tensor(payload["layer2"]))
        else:
            gates.append(torch.tensor(payload))

    return mus, gates

def scatter_plot_gate_mu(mu, gate, argmax=True, *, vmin=1, vmax=4, cmap="gist_rainbow"):
    """Scatter neuron positions on the current pyplot axes, coloured by gate.

    Args:
        mu: Positions indexable as ``mu[:, 0]`` (x) and ``mu[:, 1]`` (y).
        gate: 2-D array-like ``(n_points, n_gates)`` of gate values.
        argmax: If True, colour each point by the 1-based index of its largest
            gate value. Otherwise colour by the gate-weighted index
            ``sum_k gate[:, k] * (k + 1)``.
        vmin, vmax: Colour-scale limits. The defaults keep the original fixed
            1..4 range; pass ``vmax=gate.shape[1]`` when there are not exactly
            4 gates so the colours are not clipped.
        cmap: Matplotlib colormap name.

    Returns:
        The artist returned by ``plt.scatter``.
    """
    if argmax:
        labels = np.argmax(gate, axis=1) + 1
    else:
        # Row vector of 1-based gate indices, broadcast across all points.
        idx = (np.arange(gate.shape[1]) + 1).reshape(1, -1)
        labels = np.sum(gate * idx, axis=1)

    return plt.scatter(
        mu[:, 0],
        mu[:, 1],
        s=1,
        c=labels,
        alpha=0.5,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        rasterized=True,  # keep vector outputs (PDF) small with many points
    )
    
# %%
# Every tuning run of the experiment, in sorted (deterministic) order.
exp_dir = "/data/results/xdaa/dino_mania/suplong/"
runs = sorted(glob.glob(exp_dir + "/run_tune*"))
# %%
# Subject of each run (first entry of its DATASET.SUBJECT_LIST), aligned with `runs`.
subs = [load_cfg(run).DATASET.SUBJECT_LIST[0] for run in runs]

# # %%
# for run, sub in zip(runs, subs):
#     print(f"{sub}: {run}")
#     mus, gates = read_neuron_location(run)
#     print(len(mus), len(gates))
#     for i in range(0, 61, 10):
#         scatter_plot_gate_mu(mus[i], gates[i])
#         plt.title(f"{sub} {i}")
#         plt.show()
        

#     break
# %%
# %%
# One row per run (labelled with its subject), one column per sampled epoch.
# Grid size follows the data instead of a hard-coded 12 x 11.
epochs = list(range(0, 101, 10))
fig, axs = plt.subplots(
    len(runs),
    len(epochs),
    figsize=(2 * len(epochs), 2 * len(runs)),
    squeeze=False,  # always a 2-D array, even for a single run
)
for row, (run, sub) in enumerate(zip(runs, subs)):
    mus, gates = read_neuron_location(run)
    for col, ep in enumerate(epochs):
        ax = axs[row, col]
        plt.sca(ax)  # scatter_plot_gate_mu draws on the current axes
        scatter_plot_gate_mu(mus[ep], gates[ep])
        ax.set_xlim(-1.0, 1.0)
        ax.set_ylim(-1.0, 1.0)
        # Dashed grid every 0.125 units; tick marks and tick labels are hidden
        # so only the grid remains. tick_params avoids set_*ticklabels([]) on a
        # non-fixed locator, which matplotlib warns about.
        ax.xaxis.set_major_locator(ticker.MultipleLocator(0.125))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(0.125))
        ax.grid(axis="both", linestyle="--", alpha=0.5)
        ax.tick_params(
            axis="both",
            which="both",
            length=0,
            labelbottom=False,
            labeltop=False,
            labelleft=False,
            labelright=False,
        )
        ax.set_facecolor('#535353')

        if col == 0:
            ax.set_ylabel(sub, fontsize=20)
        if row == 0:
            ax.set_title(f"{ep}" if ep != 0 else "Epoch 0", fontsize=20)

plt.tight_layout()
plt.savefig("/workspace/figs/supfig2_retinamap.pdf")
plt.close(fig)
# %%
# make GIF
def make_gif_fn(image_paths, gif_path, duration=1000):
    """Assemble image frames into a looping GIF, drawing a progress bar on each.

    Each frame's step is parsed from the last ``_``-separated part of its file
    name, up to the first ``.`` (e.g. ``frame_042.png`` -> 42). Its progress
    bar is filled to ``step / max_step``. The bar spans the top 2% of the frame.

    Args:
        image_paths: Frame image paths in playback order.
        gif_path: Output GIF path.
        duration: Per-frame duration passed to Pillow's GIF writer
            (milliseconds, per Pillow's documentation).

    Returns:
        *gif_path*.

    Raises:
        ValueError: If *image_paths* is empty or a file name has no integer
            step suffix.
    """
    if not image_paths:
        raise ValueError("make_gif_fn needs at least one image path")

    def step_of(path):
        # Parse the basename only, so "_" or "." in directory names cannot be
        # mistaken for the step suffix.
        return int(os.path.basename(path).split("_")[-1].split(".")[0])

    def draw_progress_bar(draw, progress, width, height, bg="grey", fg="red"):
        # Full-width background track, then the filled fraction on top.
        draw.rectangle((0, 0, width, height), fill=bg)
        draw.rectangle((0, 0, width * progress, height), fill=fg)
        return draw

    steps = [step_of(path) for path in image_paths]
    # Normalise by the largest step (not the last) so the bar never overflows.
    total_steps = max(steps)

    images = []
    for path, step in zip(image_paths, steps):
        # copy() loads the pixels, so the file can be closed right away instead
        # of keeping one open handle per frame until garbage collection.
        with Image.open(path) as src:
            image = src.copy()
        # Avoid ZeroDivisionError when every frame is step 0.
        progress = step / total_steps if total_steps else 1.0
        draw_progress_bar(ImageDraw.Draw(image), progress, image.width, image.height * 0.02)
        images.append(image)

    images[0].save(
        gif_path,
        save_all=True,
        append_images=images[1:],
        optimize=True,
        duration=duration,
        loop=0,  # loop forever
    )
    return gif_path

def make_video_fn(gif_path, video_path, overwrite=True):
    """Convert a GIF into a video file with ffmpeg.

    Frame width/height are padded up to even numbers
    (``pad=ceil(iw/2)*2:ceil(ih/2)*2``), presumably because the ``yuv420p``
    pixel format needs even dimensions.

    Args:
        gif_path: Input GIF path.
        video_path: Output video path; the container is chosen by ffmpeg from
            the extension.
        overwrite: If False and *video_path* already exists, return without
            running ffmpeg.

    Returns:
        *video_path*.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable is not found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    import subprocess

    if os.path.exists(video_path) and not overwrite:
        return video_path
    cmd = [
        "ffmpeg", "-y",
        "-i", gif_path,
        "-pix_fmt", "yuv420p",
        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
        video_path,
    ]
    # Argument list, no shell: paths with spaces or quotes are passed verbatim
    # and cannot inject shell syntax. check=True surfaces ffmpeg failures
    # instead of silently returning a path that was never written.
    subprocess.run(cmd, check=True)
    return video_path

# %%
# Per-subject cache of the (mus, gates) lists returned by read_neuron_location.
mu_dict, gate_dict = {}, {}
for run, sub in zip(runs, subs):
    mu_dict[sub], gate_dict[sub] = read_neuron_location(run)
# %%
# Render one PNG per epoch with every subject in its own panel. These frames
# are assembled into the GIF below. The panel grid is sized from `subs`
# (4 columns) instead of a hard-coded 3 x 4.
n_cols = 4
n_rows = max(1, (len(subs) + n_cols - 1) // n_cols)  # ceil(len(subs) / n_cols)
png_paths = []
for i_ep in tqdm(range(0, 201, 1)):
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    axs = axs.flatten()
    for ax, sub in zip(axs, subs):
        plt.sca(ax)  # scatter_plot_gate_mu draws on the current axes
        scatter_plot_gate_mu(mu_dict[sub][i_ep], gate_dict[sub][i_ep])
        ax.set_xlim(-1.0, 1.0)
        ax.set_ylim(-1.0, 1.0)
        # Dashed grid every 0.125 units; tick marks and tick labels hidden via
        # tick_params (set_*ticklabels([]) on a non-fixed locator warns).
        ax.xaxis.set_major_locator(ticker.MultipleLocator(0.125))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(0.125))
        ax.grid(axis="both", linestyle="--", alpha=0.5)
        ax.tick_params(
            axis="both",
            which="both",
            length=0,
            labelbottom=False,
            labeltop=False,
            labelleft=False,
            labelright=False,
        )
        ax.set_facecolor('#535353')
        ax.set_title(sub, fontsize=20)

    # Hide panels left over when the subject count is not a multiple of n_cols.
    for ax in axs[len(subs):]:
        ax.set_visible(False)

    fig.suptitle(f"Epoch {i_ep:3d}", fontsize=20)
    fig.tight_layout()
    path = f"/tmp/supfig2_retinamap_{i_ep:03d}.png"
    fig.savefig(path, dpi=144)
    png_paths.append(path)
    plt.close(fig)
# %%
# Rebuild the frame list for epochs 0-100 from the same file-name pattern the
# render loop writes (only the first 101 of its frames go into the GIF).
png_paths = [f"/tmp/supfig2_retinamap_{i_ep:03d}.png" for i_ep in range(101)]
gif_out = "/workspace/figs/supfig2_hanabi.gif"
make_gif_fn(png_paths, gif_out, duration=100)
# %%
make_video_fn(gif_out, "/workspace/figs/supfig2_hanabi.mp4")
# %%
