# %%
import argparse
import copy
import fnmatch
from functools import partial
import glob
import operator
import os
import sys
from typing import Dict
from filelock import FileLock

import numpy as np
import pytorch_lightning as pl
import torch

import shutil

from tqdm import tqdm

from config import AutoConfig
from config_utils import load_from_yaml
from datamodule import build_dm, AllDatamodule
from models import VEModel
from topyneck import VoxelOutBlock

# submission_dir = "/data/algonauts_2023_challenge_submission/clip_gen1"
# save_postfix = "mania_veroi_m_gen1"
# Tag for this export: it names the submission directory below and is inserted
# into the filename of every prediction array saved by the gather step at the
# end of the script.
save_postfix = "random_m_gen1"
# save_postfix = "veroi_m_gen2n_darkgt_darkfull"
# Only referenced by the commented-out submission export at the bottom of the
# file, but the directory is still created eagerly when the script starts.
submission_dir = f"/data/algonauts_2023_challenge_submission/{save_postfix}"
os.makedirs(submission_dir, exist_ok=True)
# %%
# Root directory that is walked to find experiment folders.
exp_dir = "/data/results/xeaa_mkv/dino_mania/"
# exp_name = "veroi_m"
# Substring that a folder's basename must contain for its runs to be collected.
exp_name = "random_m"
# exp_name = "veroi_m_darkgt_darkfull"


def read_one_dir(exp_dir):
    """List the run entries of one experiment directory.

    Args:
        exp_dir: Directory whose direct entries are inspected.

    Returns:
        Paths (``exp_dir`` joined with the entry name) of every entry whose
        name contains ``"run_tune"``, ordered by entry name.
    """
    return [
        os.path.join(exp_dir, entry)
        for entry in sorted(os.listdir(exp_dir))
        if "run_tune" in entry
    ]


# Gather run directories from every folder below exp_dir whose basename
# contains exp_name.
# NOTE(review): this is a substring match, so a name that is a prefix of
# another experiment's name also picks up that experiment's folders - confirm
# this is intended.
runs = []
for dirpath, _subdirs, _filenames in os.walk(exp_dir):
    if exp_name in os.path.basename(dirpath):
        runs.extend(read_one_dir(dirpath))
print(f"num_runs: {len(runs)}")
print(runs)

# %%
# Duplicate of the top-of-file import; presumably kept so this cell can be
# executed on its own.
from config_utils import load_from_yaml

# cfg = load_from_yaml("/workspace/configs/crn_base.yaml")
# Reference config: a datamodule is built from it to read the per-subject
# voxel counts and the list of subjects used by the rest of the script.
cfg = load_from_yaml("/workspace/configs/dino_bb.yaml")
dm = build_dm(cfg)
# setup() is called before the attributes below are read.
dm.setup()
num_voxels_dict = dm.num_voxel_dict
subject_list = dm.subject_list


# %%
def load_model_dm(run_dir):
    """Rebuild the datamodule and model of one run and load its weights.

    Args:
        run_dir: Run directory containing ``r1s1_1/cfg.pth`` (the stored
            config) and ``soup.pth`` (the state dict that is loaded).

    Returns:
        Tuple ``(model, dm, cfg)``: the model after ``eval()``, the datamodule
        after ``setup()``, and the config with the overrides applied below.
    """
    cfg_path = os.path.join(run_dir, "r1s1_1/cfg.pth")
    weights_path = os.path.join(run_dir, "soup.pth")

    cfg: AutoConfig = torch.load(cfg_path)
    # Overrides applied to the stored config before anything is built from it.
    cfg.DATASET.DARK_POSTFIX = ""
    cfg.MODEL.LAYER_GATE.SKIP = False

    dm: AllDatamodule = build_dm(cfg)
    dm.setup()

    model = VEModel(
        cfg,
        dm.num_voxel_dict,
        dm.roi_dict,
        dm.neuron_coords_dict,
        dm.noise_ceiling_dict,
    )
    model.load_state_dict(torch.load(weights_path))

    return model.eval(), dm, cfg


# %%
# Debug output: echo the discovered run directories once more.
print(runs)


# %%
@torch.no_grad()
def get_outs(model, trainer, dataloader):
    """Run prediction and return all outputs as one float16 NumPy array.

    Args:
        model: Model handed to ``trainer.predict``.
        trainer: Object exposing ``predict(model, dataloader)``; its result is
            treated as an iterable of per-batch sequences of tensors.
        dataloader: Dataloader passed through to ``trainer.predict``.

    Returns:
        ``np.ndarray`` of dtype float16 built by stacking every tensor from
        every batch along a new leading axis, in prediction order.
    """
    batches = trainer.predict(model, dataloader)
    # Flatten in a single linear pass; ``sum(batches, [])`` would re-copy the
    # growing list once per batch.
    tensors = [tensor for batch in batches for tensor in batch]
    stacked = torch.stack(tensors)
    return stacked.cpu().numpy().astype(np.float16)


# %%
# trainer = pl.Trainer(accelerator='cuda', devices=[0], enable_progress_bar=True)

# Splits evaluated for every run. The "predict" split is skipped for subjects
# whose name does not contain "NSD".
stages = ["train", "val", "test", "predict"]
# stages = ["test"]
# vi_dict = {subject: {} for subject in subject_list}
# subject_out_dict = {subject: {} for subject in subject_list}
# stage_out_dict = {stage: copy.deepcopy(subject_out_dict) for stage in stages}
# y_idx_dict = {subject: {} for subject in subject_list}
# stage_y_idx_dict = {stage: copy.deepcopy(y_idx_dict) for stage in stages}
# stage_y_path_dict = {stage: copy.deepcopy(y_idx_dict) for stage in stages}
# for run_dir in runs[::-1]:
#     print(run_dir)
#     model, dm, cfg = load_model_dm(run_dir)
#     roi = cfg.DATASET.ROIS[0]
#     for subject in subject_list[:]:
#         print(subject)
#         if subject not in dm.dss[0].keys():
#             vi = np.array([])
#             vi_dict[subject][roi] = vi
#             continue # skip subjects not in dataset
#         train_dl = dm.train_dataloader(subject=subject, shuffle=False)
#         val_dl = dm.val_dataloader(subject=subject)
#         test_dl = dm.test_dataloader(subject=subject)
#         predict_dl = dm.predict_dataloader(subject=subject) if "NSD" in subject else None
#         for dl, stage in zip([train_dl, val_dl, test_dl, predict_dl], stages):
#             # print(stage)
#             print(run_dir, subject, stage)
#             # debug
#             # if stage == "train":
#             #     continue

#             if stage == "predict":
#                 if "NSD" not in subject:
#                     continue
#             outs = get_outs(model, trainer, dl)
#             stage_out_dict[stage][subject][roi] = outs

#             if stage != "predict":
#                 y_paths = dl.dataset.y_paths
#                 y_idx = [int(os.path.basename(yp).split(".")[0]) for yp in y_paths]
#                 stage_y_idx_dict[stage][subject] = y_idx
#                 stage_y_path_dict[stage][subject] = y_paths


#         vi = train_dl.dataset.voxel_index.numpy()
#         vi_dict[subject][roi] = vi
#         # break
#     # break
# %%
def load_modify_save(path, keys, data):
    """Set one nested entry of the dict stored at ``path`` while holding a file lock.

    Args:
        path: File read with ``torch.load`` and rewritten with ``torch.save``;
            it is created holding an empty dict if it does not exist yet.
        keys: Sequence of keys describing the nested location to write.
        data: Value stored at that location.
    """
    from flatten_dict import flatten, unflatten

    key_path = tuple(keys)

    # The lock file lives next to the data file and guards the whole
    # read-modify-write cycle.
    with FileLock(path + ".lock"):
        if not os.path.exists(path):
            torch.save({}, path)
        flat = flatten(torch.load(path))
        flat[key_path] = data
        torch.save(unflatten(flat), path)

# os.system(f"rm /tmp/*.pth")

# %%
def process_one_run(run_dir):
    """Run inference for one run directory and dump the results under ``/tmp``.

    For every subject in the module-level ``subject_list`` that the run's
    datamodule contains, predictions are computed for each entry of ``stages``
    (the "predict" stage only for subjects whose name contains "NSD"). The
    target file paths and their integer ids are recorded for the non-predict
    stages, and the voxel index of the training dataset is recorded per
    subject. Subjects missing from the datamodule get an empty voxel index.

    Four files are written, all prefixed with the run's first ROI name:
    ``/tmp/{roi}_vi_dict.pth``, ``/tmp/{roi}_stage_out_dict.pth``,
    ``/tmp/{roi}_stage_y_idx_dict.pth`` and ``/tmp/{roi}_stage_y_path_dict.pth``.

    NOTE(review): the file names depend only on the ROI, so two runs sharing
    the same first ROI overwrite each other's files.

    Args:
        run_dir: Run directory passed to ``load_model_dm``.
    """

    def _empty_stage_table():
        # Fresh, independent {stage: {subject: {}}} nesting.
        return {stage: {subject: {} for subject in subject_list} for stage in stages}

    vi_dict = {subject: {} for subject in subject_list}
    stage_out_dict = _empty_stage_table()
    stage_y_idx_dict = _empty_stage_table()
    stage_y_path_dict = _empty_stage_table()

    trainer = pl.Trainer(accelerator="cuda", devices=[0], enable_progress_bar=False)

    model, dm, cfg = load_model_dm(run_dir)
    roi = cfg.DATASET.ROIS[0]

    for subject in list(subject_list):
        print(subject)
        if subject not in dm.dss[0].keys():
            # Subject is not part of this run's dataset: record an empty index.
            vi_dict[subject][roi] = np.array([])
            continue

        has_predict_split = "NSD" in subject
        loaders = {
            "train": dm.train_dataloader(subject=subject, shuffle=False),
            "val": dm.val_dataloader(subject=subject),
            "test": dm.test_dataloader(subject=subject),
            "predict": (
                dm.predict_dataloader(subject=subject) if has_predict_split else None
            ),
        }

        for stage in stages:
            loader = loaders[stage]
            print(run_dir, subject, stage)

            is_predict = stage == "predict"
            if is_predict and not has_predict_split:
                continue

            stage_out_dict[stage][subject][roi] = get_outs(model, trainer, loader)

            if is_predict:
                # No target files are recorded for the predict split.
                continue

            target_paths = loader.dataset.y_paths
            # The numeric id is the part of the file name before the first dot.
            stage_y_idx_dict[stage][subject] = [
                int(os.path.basename(target).split(".")[0]) for target in target_paths
            ]
            stage_y_path_dict[stage][subject] = target_paths

        vi_dict[subject][roi] = loaders["train"].dataset.voxel_index.numpy()

    torch.save(vi_dict, f"/tmp/{roi}_vi_dict.pth")
    torch.save(stage_out_dict, f"/tmp/{roi}_stage_out_dict.pth")
    torch.save(stage_y_idx_dict, f"/tmp/{roi}_stage_y_idx_dict.pth")
    torch.save(stage_y_path_dict, f"/tmp/{roi}_stage_y_path_dict.pth")


def run_tune(tune_dict):
    """Trial entry point: process the run directory stored under ``tune_dict["run_dir"]``."""
    process_one_run(tune_dict["run_dir"])

# %%
from ray import tune

# Clear intermediate results left in /tmp by earlier executions so that the
# merge step further down only sees files produced by this launch. Done
# in-process rather than by shelling out, so failures are reported explicitly.
# NOTE(review): this removes every *.pth file directly under /tmp, not only
# the ones this script writes.
for stale_path in glob.glob("/tmp/*.pth"):
    try:
        os.remove(stale_path)
    except OSError as err:
        # Best effort: a file that cannot be removed is reported, not fatal.
        print(f"could not remove {stale_path}: {err}")

# One trial per run directory (grid search over `runs`); each trial requests
# one CPU and one GPU.
tune_dict = {"run_dir": tune.grid_search(runs)}
ana = tune.run(
    run_tune, config=tune_dict, resources_per_trial={"cpu": 1, "gpu": 1}, verbose=True
)
# exit()
# %%
# torch.save(stage_out_dict, "/tmp/stage_out_dict.pth")
# torch.save(stage_y_idx_dict, "/tmp/stage_y_idx_dict.pth")
# torch.save(stage_y_path_dict, "/tmp/stage_y_path_dict.pth")
# torch.save(vi_dict, "/tmp/vi_dict.pth")
# %%
# Persist the reference voxel counts and subject list in /tmp as well.
torch.save(num_voxels_dict, "/tmp/num_voxels_dict.pth")
torch.save(subject_list, "/tmp/subject_list.pth")
# exit()
# %%
# stage_out_dict = torch.load("/tmp/stage_out_dict.pth")
# stage_y_idx_dict = torch.load("/tmp/stage_y_idx_dict.pth")
# stage_y_path_dict = torch.load("/tmp/stage_y_path_dict.pth")
# vi_dict = torch.load("/tmp/vi_dict.pth")
# Read the two objects back from the files written just above.
# NOTE(review): presumably so the cells below can be run from the files on
# disk without repeating the steps above - confirm.
num_voxels_dict = torch.load("/tmp/num_voxels_dict.pth")
subject_list = torch.load("/tmp/subject_list.pth")
# %%
def load_and_merge(pattern):
    """Load every file matching ``pattern`` and merge them into one nested dict.

    Each file is read with ``torch.load`` and flattened to a mapping from
    key-path tuples to values. The flattened mappings are merged with
    ``dict.update`` and the result is unflattened again.

    Args:
        pattern: Glob pattern selecting the files to merge.

    Returns:
        The merged nested dict; an empty dict when nothing matches.
    """
    from flatten_dict import flatten, unflatten

    merged = {}
    # glob.glob returns matches in arbitrary order. Sorting makes the merge
    # reproducible: on a duplicate key path the alphabetically last file wins.
    for file_path in sorted(glob.glob(pattern)):
        flat = flatten(torch.load(file_path))
        print(flat.keys())
        merged.update(flat)
    return unflatten(merged)

# Combine the per-ROI files written by process_one_run into single nested
# dicts, one per kind of result.
stage_out_dict = load_and_merge("/tmp/*_stage_out_dict.pth")
stage_y_idx_dict = load_and_merge("/tmp/*_stage_y_idx_dict.pth")
stage_y_path_dict = load_and_merge("/tmp/*_stage_y_path_dict.pth")
vi_dict = load_and_merge("/tmp/*_vi_dict.pth")

# %%
# stages = ["train", "val", "test", "predict"]
# %%
# gather
# For each stage and subject, scatter the per-ROI predictions into one
# zero-initialised float16 matrix with num_voxels_dict[subject] columns.
# Matrices of every stage except "train" are kept in dark_y_dict; for every
# stage except "predict" each row is also saved next to its original target
# file, with save_postfix inserted before the ".npy" extension.
dark_y_dict = {stage: {} for stage in stages}
for stage in stages:
    is_predict_stage = stage == "predict"
    for subject in subject_list[:]:
        # Subjects whose name lacks "NSD" have no predict-stage outputs.
        if is_predict_stage and "NSD" not in subject:
            continue

        roi_out_dict = stage_out_dict[stage][subject]
        rois = list(roi_out_dict)

        # The row count is taken from the first ROI's output.
        num_outs = roi_out_dict[rois[0]].shape[0]
        full_outs = np.zeros((num_outs, num_voxels_dict[subject]), dtype=np.float16)
        print(f"num_outs: {num_outs}")

        for roi in rois:
            # vi_dict holds, per ROI, the column indices its outputs occupy.
            full_outs[:, vi_dict[subject][roi]] = roi_out_dict[roi]

        if stage != "train":
            dark_y_dict[stage][subject] = full_outs

        if is_predict_stage:
            continue

        y_idxs = stage_y_idx_dict[stage][subject]
        y_paths = stage_y_path_dict[stage][subject]
        for row_idx, row in enumerate(full_outs):
            target_path = y_paths[row_idx].replace(".npy", f".{save_postfix}.npy")
            np.save(target_path, row.astype(np.float16))

#     #     break
#     # break
# # %%
# # save submission
# # %%
# for i in range(1, 9):
#     subject = f"NSD_{i:02d}"
#     outs = dark_y_dict["predict"][subject]

#     mask_dir = f"/data/algonauts2023/subj{i:02d}/roi_masks"
#     lh_mask = os.path.join(mask_dir, "lh.streams_challenge_space.npy")
#     lh_mask = np.load(lh_mask)
#     rh_mask = os.path.join(mask_dir, "rh.streams_challenge_space.npy")
#     rh_mask = np.load(rh_mask)
#     num_lh = lh_mask.shape[0]
#     num_rh = rh_mask.shape[0]
#     assert num_rh + num_lh == outs.shape[1]

#     lh_outs = outs[:, :num_lh]
#     rh_outs = outs[:, num_lh:]
#     lh_outs = lh_outs.astype(np.float32)
#     rh_outs = rh_outs.astype(np.float32)

#     subject_dir = os.path.join(submission_dir, f"subj{i:02d}")
#     os.makedirs(subject_dir, exist_ok=True)
#     np.save(os.path.join(subject_dir, "lh_pred_test.npy"), lh_outs)
#     np.save(os.path.join(subject_dir, "rh_pred_test.npy"), rh_outs)

# %%
