# %%
import argparse
import copy
import fnmatch
from functools import partial
import glob
import operator
import os
import sys
from typing import Dict

import numpy as np
import pytorch_lightning as pl
import torch

import shutil

from tqdm import tqdm

from config import AutoConfig
from config_utils import load_from_yaml
from datamodule import build_dm, AllDatamodule
from models import VEModel, DarkVEModel
from topyneck import VoxelOutBlock

parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="random_m_gen2_darkfull", help="name of the experiment")

# parse_known_args() lets --name be set from the command line while ignoring any
# extra argv injected by a Jupyter/IPython kernel when this file is run cell by cell.
# (The former parse_args("") parsed an empty argument list, so --name could never
# be overridden.)
args, _unknown_args = parser.parse_known_args()

debug = False

# submission_dir = "/data/algonauts_2023_challenge_submission/clip_gen1"
# save_postfix = "veroi_m_gen2_darkfull"
save_postfix = args.name
# save_postfix = "veroi_m_gen2n_darkgt_darkfull"
# save_postfix = 'darkdebug'
# Output directory of the final submission (one subdirectory per subject is added later).
submission_dir = f"/data/algonauts_2023_challenge_submission/{save_postfix}"
os.makedirs(submission_dir, exist_ok=True)
# %%
# Root directory searched recursively for experiment directories whose basename
# contains ``exp_name``.
exp_dir = "/data/results/xfaa/mkv/dino_mania/"
# exp_name = "veroi_m_gen2_darkfull"
# exp_name = "veroi_m_gen3"
# exp_name = "veroi_m_darkgt_darkfull"
exp_name = "random_m_gen2"


def read_one_dir(exp_dir):
    """Return full paths of the entries in ``exp_dir`` whose name contains "run_tune".

    Only the direct children of ``exp_dir`` are inspected (no recursion); the
    result is ordered by entry name.
    """
    return [
        os.path.join(exp_dir, entry)
        for entry in sorted(os.listdir(exp_dir))
        if "run_tune" in entry
    ]


# Gather the run directories of every experiment directory (at any depth below
# ``exp_dir``) whose basename contains ``exp_name`` -- a substring match.
runs = []
for dirpath, _subdirs, _files in os.walk(exp_dir):
    if exp_name in os.path.basename(dirpath):
        runs.extend(read_one_dir(dirpath))
print(f"num_runs: {len(runs)}")
print(runs)

# %%
from config_utils import load_from_yaml
# Base config/datamodule; used here to obtain the subject list and the per-subject
# voxel counts that the gather step below sizes its output arrays with.
# cfg = load_from_yaml("/workspace/configs/crn_base.yaml")
cfg = load_from_yaml("/workspace/configs/dino_mania.yaml")
fdm = build_dm(cfg)
fdm.setup()
num_voxels_dict, subject_list = fdm.num_voxel_dict, fdm.subject_list
# %%
def _find_unique_file(run_dir, basename):
    """Return the only file called ``basename`` anywhere below ``run_dir``.

    Raises:
        FileNotFoundError: if no such file exists.
        RuntimeError: if more than one exists (ambiguous run directory).
    """
    matches = glob.glob(os.path.join(run_dir, "**", basename), recursive=True)
    if not matches:
        raise FileNotFoundError(f"no {basename!r} found under {run_dir!r}")
    if len(matches) > 1:
        raise RuntimeError(
            f"expected exactly one {basename!r} under {run_dir!r}, "
            f"found {len(matches)}: {matches}"
        )
    return matches[0]


def load_model_dm(run_dir):
    """Rebuild the datamodule and ``DarkVEModel`` of one run and load its weights.

    Args:
        run_dir: run directory; exactly one ``hparams.yaml`` and exactly one
            ``soup.pth`` must exist somewhere beneath it.

    Returns:
        ``(model, dm, cfg)``: the model in eval mode with ``predict_vi_dict``
        replaced by ``dark_gt_vis``, the set-up datamodule, and the run's config.

    Raises:
        FileNotFoundError / RuntimeError: if either file is missing or ambiguous.
    """
    # run_dir = os.path.join(run_dir, "stage_2")
    # run_dir = os.path.join(run_dir, "r1s1_1")
    cfg = load_from_yaml(_find_unique_file(run_dir, "hparams.yaml"))

    # Logged for inspection only; the ROI override below is intentionally disabled.
    print(cfg.DATASET.ROIS, cfg.LOSS.DARK.GT_ROIS)
    # cfg.DATASET.ROIS = cfg.LOSS.DARK.GT_ROIS

    dm: AllDatamodule = build_dm(cfg)
    dm.setup()

    model = DarkVEModel(
        cfg,
        dm.num_voxel_dict,
        dm.roi_dict,
        dm.neuron_coords_dict,
        dm.noise_ceiling_dict,
    )

    # NOTE(review): torch.load unpickles the checkpoint -- only point this at trusted runs.
    state_dict = torch.load(_find_unique_file(run_dir, "soup.pth"), map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # Downstream code reads ``predict_vi_dict[subject]`` as the voxel selection to
    # export; swap in the dark-GT selection instead of the model's default.
    model.predict_vi_dict = model.dark_gt_vis

    return model, dm, cfg
# %%
@torch.no_grad()
def get_outs(model, trainer, dataloader):
    """Run ``trainer.predict`` and stack all per-item outputs into one array.

    ``trainer.predict`` returns one entry per batch; every entry is iterated and
    its items are stacked along a new leading axis.

    Returns:
        ``np.ndarray`` (float16) whose first axis indexes the stacked items.
    """
    batch_outs = trainer.predict(model, dataloader)
    # Flatten in one pass; ``sum(batch_outs, [])`` re-copied the accumulated list
    # for every batch, which is quadratic in the number of batches.
    items = [item for batch in batch_outs for item in batch]
    outs = torch.stack(items)
    # outs = outs[:, vi]
    return outs.cpu().half().numpy()

# trainer = pl.Trainer(accelerator='cuda', devices=[0], enable_progress_bar=True, precision=16)
# (Each run creates its own Trainer inside process_one_run.)

# Keep only the NSD subjects.
subject_list = [s for s in subject_list if "NSD" in s]
# subject_list = ["HCP"]
# subject_list = ["NSD_01"]

# Stages to run inference on. The submission cell at the end of this file reads
# dark_y_dict["predict"], so "predict" must be listed here to produce a submission.
# stages = ["train", "val", "test", "predict"]
stages = ["val"]
    
def _stage_dataloader(dm, subject, stage, train_dl):
    """Return the dataloader of ``subject`` for ``stage``, reusing ``train_dl`` for "train"."""
    if stage == "train":
        return train_dl
    if stage == "val":
        return dm.val_dataloader(subject=subject)
    if stage == "test":
        return dm.test_dataloader(subject=subject)
    if stage == "predict":
        return dm.predict_dataloader(subject=subject)
    raise KeyError(stage)


def process_one_run(run_dir):
    """Run inference for one run directory and dump the results to ``/tmp``.

    For every subject in ``subject_list`` and stage in ``stages`` this records:
    the model predictions (``get_outs``), the global voxel indices they belong to
    (``voxel_index[model.predict_vi_dict[subject]]``) and, for every stage except
    "predict", the target file paths and their integer basenames. All results are
    keyed by ``roi = cfg.LOSS.DARK.GT_ROIS[0]`` and saved with ``torch.save`` to
    ``/tmp/{roi}_vi_dict.pth``, ``/tmp/{roi}_stage_out_dict.pth``,
    ``/tmp/{roi}_stage_y_idx_dict.pth`` and ``/tmp/{roi}_stage_y_path_dict.pth``.
    Inference runs on CUDA device 0 with 16-bit precision.
    """
    vi_dict = {subject: {} for subject in subject_list}
    stage_out_dict = {stage: {subject: {} for subject in subject_list} for stage in stages}
    stage_y_idx_dict = {stage: {subject: {} for subject in subject_list} for stage in stages}
    stage_y_path_dict = {stage: {subject: {} for subject in subject_list} for stage in stages}

    trainer = pl.Trainer(accelerator="cuda", devices=[0], enable_progress_bar=False, precision=16)
    # Debug alternative:
    # trainer = pl.Trainer(accelerator="cpu", enable_progress_bar=True)

    model, dm, cfg = load_model_dm(run_dir)
    roi = cfg.LOSS.DARK.GT_ROIS[0]
    for subject in subject_list:
        print(subject)
        # The train loader is always needed: its dataset's voxel_index maps the
        # model's voxel selection onto global voxel indices.
        train_dl = dm.train_dataloader(subject=subject, shuffle=False)
        whole_vi = train_dl.dataset.voxel_index.numpy()
        vi = whole_vi[model.predict_vi_dict[subject]]
        vi_dict[subject][roi] = vi

        for stage in stages:
            print(run_dir, subject, stage)
            # Only NSD subjects have a predict split.
            if stage == "predict" and "NSD" not in subject:
                continue

            # Other loaders are built lazily, only for the stages requested.
            dl = _stage_dataloader(dm, subject, stage, train_dl)
            if len(vi) > 0:
                outs = get_outs(model, trainer, dl)
            else:
                # Nothing selected for this subject: store an empty block with the
                # same float16 dtype get_outs produces.
                outs = np.zeros((len(dl.dataset), 0), dtype=np.float16)
            stage_out_dict[stage][subject][roi] = outs

            if stage != "predict":
                y_paths = dl.dataset.y_paths
                stage_y_idx_dict[stage][subject] = [
                    int(os.path.basename(yp).split(".")[0]) for yp in y_paths
                ]
                stage_y_path_dict[stage][subject] = y_paths

    # NOTE(review): output files are keyed only by ``roi``; two runs whose first
    # GT ROI is the same would overwrite each other here. Confirm ROIs are unique.
    torch.save(vi_dict, f"/tmp/{roi}_vi_dict.pth")
    torch.save(stage_out_dict, f"/tmp/{roi}_stage_out_dict.pth")
    torch.save(stage_y_idx_dict, f"/tmp/{roi}_stage_y_idx_dict.pth")
    torch.save(stage_y_path_dict, f"/tmp/{roi}_stage_y_path_dict.pth")


def run_tune(tune_dict):
    """Ray Tune trainable: process the single run named by ``tune_dict["run_dir"]``."""
    process_one_run(tune_dict["run_dir"])
    
# %%
from ray import tune
import ray

if debug:
    ray.init(num_gpus=1)

# Remove stale dumps from earlier executions so the merge step only sees this
# execution's results. Only the files process_one_run writes are removed; the
# former ``rm /tmp/*.pth`` also deleted unrelated .pth files in the shared /tmp.
for _suffix in ("vi_dict", "stage_out_dict", "stage_y_idx_dict", "stage_y_path_dict"):
    for _stale in glob.glob(f"/tmp/*_{_suffix}.pth"):
        try:
            os.remove(_stale)
        except OSError as err:
            # Best effort, like the shell rm: report and keep going.
            print(f"warning: could not remove stale file {_stale}: {err}")

# One trial per run directory, each with 1 CPU and 1 GPU.
tune_dict = {"run_dir": tune.grid_search(runs[:])}
ana = tune.run(
    run_tune, config=tune_dict, resources_per_trial={"cpu": 1, "gpu": 1}, verbose=True
)

# %%
def load_and_merge(pattern):
    """Load every file matching ``pattern`` and deep-merge their nested dicts.

    Each file is read with ``torch.load``, flattened with ``flatten_dict.flatten``,
    merged into one flat dict, and finally unflattened. Files are processed in
    sorted path order, so when two files define the same leaf the result is
    deterministic: the lexicographically later path wins.
    """
    from flatten_dict import flatten, unflatten

    merged = {}
    # glob's order is filesystem-dependent; sort for a reproducible merge.
    for path in sorted(glob.glob(pattern)):
        # NOTE(review): torch.load unpickles and /tmp is world-writable -- only
        # load files this script produced.
        leaves = flatten(torch.load(path))
        print(leaves.keys())
        merged.update(leaves)
    return unflatten(merged)

# Merge the per-ROI dumps written by every Ray trial into single nested dicts.
stage_out_dict, stage_y_idx_dict, stage_y_path_dict, vi_dict = (
    load_and_merge("/tmp/*_stage_out_dict.pth"),
    load_and_merge("/tmp/*_stage_y_idx_dict.pth"),
    load_and_merge("/tmp/*_stage_y_path_dict.pth"),
    load_and_merge("/tmp/*_vi_dict.pth"),
)
# %%
# gather
# For each stage/subject, scatter every ROI's predictions into a zero-initialised
# (num_outs, num_voxels) float16 array. Keep non-train arrays in ``dark_y_dict``,
# and for every stage except "predict" write one ``.{save_postfix}.npy`` file next
# to each target file.
dark_y_dict = {stage: {} for stage in stages}
for stage in stages:
    for subject in subject_list:
        if stage == "predict" and "NSD" not in subject:
            continue

        roi_out_dict = stage_out_dict[stage][subject]
        rois = list(roi_out_dict)
        if not rois:
            raise ValueError(f"no ROI outputs for stage={stage!r}, subject={subject!r}")

        num_outs = roi_out_dict[rois[0]].shape[0]
        # Voxels not covered by any ROI stay 0.
        full_outs = np.zeros((num_outs, num_voxels_dict[subject]), dtype=np.float16)
        print(f"num_outs: {num_outs}")

        for roi in rois:
            full_outs[:, vi_dict[subject][roi]] = roi_out_dict[roi]

        if stage != "train":
            dark_y_dict[stage][subject] = full_outs

        if stage != "predict":
            y_paths = stage_y_path_dict[stage][subject]
            if len(y_paths) != num_outs:
                raise ValueError(
                    f"stage={stage!r}, subject={subject!r}: {num_outs} outputs "
                    f"but {len(y_paths)} target paths"
                )
            for y_i, orig_save_path in zip(full_outs, y_paths):
                # Replace only the trailing extension. The former str.replace rewrote
                # every ".npy" in the path (directories included) and silently
                # dropped the postfix when the path had no ".npy" at all.
                if orig_save_path.endswith(".npy"):
                    stem = orig_save_path[: -len(".npy")]
                else:
                    stem = orig_save_path
                save_path = f"{stem}.{save_postfix}.npy"
                np.save(save_path, y_i.astype(np.float16))

print("done")
# %%
# save submission
# %%
# Split each NSD subject's "predict" outputs into left/right hemisphere arrays
# (the first num_lh columns are lh, the rest rh) and save them as float32.
if "predict" not in dark_y_dict:
    # With the current ``stages`` there is nothing to submit; the former code
    # crashed here with a KeyError.
    print('skipping submission: "predict" is not in `stages`')
else:
    for i in range(1, 9):
        subject = f"NSD_{i:02d}"
        outs = dark_y_dict["predict"][subject]

        mask_dir = f"/data/algonauts2023/subj{i:02d}/roi_masks"
        lh_mask = np.load(os.path.join(mask_dir, "lh.streams_challenge_space.npy"))
        rh_mask = np.load(os.path.join(mask_dir, "rh.streams_challenge_space.npy"))
        num_lh = lh_mask.shape[0]
        num_rh = rh_mask.shape[0]
        if num_lh + num_rh != outs.shape[1]:
            raise ValueError(
                f"{subject}: lh ({num_lh}) + rh ({num_rh}) voxels != "
                f"predicted voxels ({outs.shape[1]})"
            )

        lh_outs = outs[:, :num_lh].astype(np.float32)
        rh_outs = outs[:, num_lh:].astype(np.float32)

        subject_dir = os.path.join(submission_dir, f"subj{i:02d}")
        os.makedirs(subject_dir, exist_ok=True)
        np.save(os.path.join(subject_dir, "lh_pred_test.npy"), lh_outs)
        np.save(os.path.join(subject_dir, "rh_pred_test.npy"), rh_outs)

# # %%
# # a = np.load('/data/algonauts2023/subj01/training_split/training_fmri/lh_training_fmri.npy')
# # %%
# # a.shape
# # %%
# cfg = load_from_yaml("/workspace/configs/dino_base.yaml")
# dm = build_dm(cfg)
# dm.setup()
# # %%
# # eval
# val_y_dict = {}
# test_y_dict = {}
# nc_dict = {}
# for i in range(1, 9):
#     subject = f"NSD_{i:02d}"
#     val_dl = dm.val_dataloader(subject=subject)
#     test_dl = dm.test_dataloader(subject=subject)
    
#     nc = val_dl.dataset.noise_ceiling
#     nc_dict[subject] = nc.numpy()
    
#     for dl, stage in zip([val_dl, test_dl], ["val", "test"]):
#         y_paths = dl.dataset.y_paths
#         ys = []
#         for y_path in y_paths:
#             y = np.load(y_path)
#             ys.append(y)
#         ys = np.stack(ys)
#         if stage == "val":
#             val_y_dict[subject] = ys
#         if stage == "test":
#             test_y_dict[subject] = ys
# # %%
# def challenge_metric(y, y_pred, nc):
#     y = y.astype(np.float32)
#     y_pred = y_pred.astype(np.float32)
#     from metrics import vectorized_correlation
#     p = vectorized_correlation(y, y_pred)
#     s = p ** 2 / (nc + 1e-5)
#     # s = np.nanmedian(s)
#     return s
# # %%
# ss = []
# for subject in nc_dict.keys():
#     y = val_y_dict[subject]
#     y_pred = dark_y_dict["val"][subject]
#     nc = nc_dict[subject]
#     s = challenge_metric(y, y_pred, nc)
#     ss.append(s)
# ss = np.concatenate(ss)
# val_score = np.nanmedian(ss)
# print(f"val_score: {val_score}")
# # %%
# ss = []
# for subject in nc_dict.keys():
#     y = test_y_dict[subject]
#     y_pred = dark_y_dict["test"][subject]
#     nc = nc_dict[subject]
#     s = challenge_metric(y, y_pred, nc)
#     ss.append(s)
# ss = np.concatenate(ss)
# test_score = np.nanmedian(ss)
# print(f"test_score: {test_score}")
# # %%
# ss = []
# for subject in nc_dict.keys():
#     y = np.concatenate([val_y_dict[subject], test_y_dict[subject]], axis=0)[:]
#     y_pred = np.concatenate([dark_y_dict["val"][subject], dark_y_dict["test"][subject]], axis=0)[:]
#     nc = nc_dict[subject]
#     s = challenge_metric(y, y_pred, nc)
#     ss.append(s)
# ss = np.concatenate(ss)
# val_score = np.nanmedian(ss)
# print(f"both: {val_score}")
# # %%
# import matplotlib.pyplot as plt
# ss[ss>1] = 1
# plt.hist(ss, bins=100)
# plt.show()
# # %%
# for i in range(8):
#     start = i * 100
#     end = (i + 1) * 100
#     print(f"start: {start}, end: {end}")
#     ss = []
#     for subject in nc_dict.keys():
#         y = np.concatenate([val_y_dict[subject], test_y_dict[subject]], axis=0)[start:end]
#         y_pred = np.concatenate([dark_y_dict["val"][subject], dark_y_dict["test"][subject]], axis=0)[start:end]
#         nc = nc_dict[subject]
#         s = challenge_metric(y, y_pred, nc)
#         ss.append(s)
#     ss = np.concatenate(ss)
#     val_score = np.nanmedian(ss)
#     print(f"both: {val_score}")
# # %%
# for i in range(2):
#     start = i * 400
#     end = (i + 1) * 400
#     print(f"start: {start}, end: {end}")
#     ss = []
#     for subject in nc_dict.keys():
#         y = np.concatenate([val_y_dict[subject], test_y_dict[subject]], axis=0)[start:end]
#         y_pred = np.concatenate([dark_y_dict["val"][subject], dark_y_dict["test"][subject]], axis=0)[start:end]
#         nc = nc_dict[subject]
#         s = challenge_metric(y, y_pred, nc)
#         ss.append(s)
#     ss = np.concatenate(ss)
#     val_score = np.nanmedian(ss)
#     print(f"both: {val_score}")
# # %%

# %%
