# %%
from collections import OrderedDict
import glob
import json
import sys
import traceback
import re
import logging
from time import sleep
from einops import repeat
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import cortex
from matplotlib.pyplot import cm

from config_utils import flatten_dict

from IPython.display import display, HTML, clear_output

from datamodule import build_dm

# Use matplotlib's "dark_background" style for any figures drawn in this session.
plt.style.use("dark_background")
# %%
# Output directory for the merged state dict (created on demand before saving).
save_dir = "/data/script_base/xdab/"
# Root directory that is walked recursively to find entries named "*run_tune*".
exp_dir = "/data/results/xdaa/dino_bb/"
# roi_prefix = "htt"

# NOTE(review): subject_data_dir and device are not referenced in the visible
# part of this script — confirm whether they are still needed.
subject_data_dir = "/data/VWET/"
device = "cpu"


# %%
def read_one_dir(exp_dir):
    """List the "run_tune" entries directly inside ``exp_dir``.

    Args:
        exp_dir: Path of the directory to inspect (not searched recursively).

    Returns:
        A list of ``exp_dir``-joined paths, ordered by entry name, for every
        entry whose name contains the substring "run_tune".
    """
    return [
        os.path.join(exp_dir, entry)
        for entry in sorted(os.listdir(exp_dir))
        if "run_tune" in entry
    ]


# Gather the "run_tune" entries found in exp_dir and in every directory below it.
runs = []
for dirpath, _subdirs, _files in os.walk(exp_dir):
    runs.extend(read_one_dir(dirpath))
print(len(runs))

# %%
# Merge the "neuron_projectors" entries of every run's greedy-soup checkpoint
# into one flat dict; a key that appears again in a later run replaces the
# value stored from an earlier one.
full_state_dict = {}
for run in runs:
    soup_path = os.path.join(run, "greedy_soup.pth")
    state_dict = torch.load(soup_path, map_location="cpu")
    # Unwrap checkpoints that keep their weights under a "state_dict" entry.
    state_dict = state_dict.get("state_dict", state_dict)
    full_state_dict.update(
        (name, tensor)
        for name, tensor in state_dict.items()
        if "neuron_projectors" in name
    )
# %%
# Show each collected entry's name and shape, followed by the entry count.
for name in full_state_dict:
    print(name, full_state_dict[name].shape)
print(len(full_state_dict))
# %%
# Make sure the output directory exists, then write the merged dict into it.
os.makedirs(save_dir, exist_ok=True)
merged_path = os.path.join(save_dir, "state_dict.pth.dinobb")
torch.save(full_state_dict, merged_path)
# %%
