# %%
import argparse
import copy
import json
import sys
import traceback
import re
import logging
from time import sleep
import numpy as np

import torch
import os
import pandas as pd
import ray
from ray import tune

import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from PIL import Image, ImageDraw

import pytorch_lightning as pl

import cortex
from matplotlib.pyplot import cm
from config import AutoConfig

from config_utils import flatten_dict

from IPython.display import display, HTML, clear_output

from datamodule import AllDatamodule, build_dm
from models import VEModel

# Switch matplotlib to its dark theme globally for any figures drawn in this notebook.
plt.style.use("dark_background")


# %%
def challenge_metric(y, y_pred, nc):
    """Noise-normalized squared correlation between targets and predictions.

    Args:
        y: target responses; cast to float32 before scoring.
        y_pred: predicted responses in the same layout as ``y``; cast to float32.
        nc: noise ceiling used to normalize the squared correlation (callers in
            this file pass ``dl.dataset.noise_ceiling``).

    Returns:
        ``r**2 / (nc + 1e-5)``, where ``r`` is the result of
        ``metrics.vectorized_correlation``. No reduction is applied; callers
        aggregate the result themselves (e.g. with ``np.nanmedian``).
    """
    targets = y.astype(np.float32)
    preds = y_pred.astype(np.float32)
    from metrics import vectorized_correlation

    r = vectorized_correlation(targets, preds)
    # The small epsilon keeps the division finite when the noise ceiling is 0.
    normalized = r**2 / (nc + 1e-5)
    return normalized

def p_metric(y, y_pred):
    """Correlation between targets and predictions, computed in float32.

    Thin wrapper around ``metrics.vectorized_correlation``. It casts both
    inputs to float32 first and returns the correlation unreduced, so callers
    choose the aggregation (e.g. ``np.nanmean``).
    """
    y32, y_pred32 = (arr.astype(np.float32) for arr in (y, y_pred))
    from metrics import vectorized_correlation

    return vectorized_correlation(y32, y_pred32)


# %%
from config_utils import load_from_yaml

# Build the datamodule used by the evaluation cells below. DARK_POSTFIX
# presumably selects which stored "dark" prediction files the datamodule
# attaches to each batch; confirm in datamodule.build_dm. Alternatives from
# earlier experiments are kept commented out for quick switching.
# cfg = load_from_yaml("/workspace/configs/crn_base.yaml")
cfg: AutoConfig = load_from_yaml("/workspace/configs/dino_mania.yaml")
# cfg.DATASET.DARK_POSTFIX = ".mania_veroi_m_gen2_darkfull"
# cfg.DATASET.DARK_POSTFIX = ".mania_veroi_m_gen3_darkfull"
# Plain string: there is nothing to interpolate, so no f-string prefix.
cfg.DATASET.DARK_POSTFIX = ".mania_veroi_m_gen1"
# cfg.DATASET.DARK_POSTFIX = ".random_m_gen1"
# cfg.DATASET.DARK_POSTFIX = ".veroi_m_gen2n_darkgt_darkfull"
# cfg.LOSS.DARK.USE = True
dm: AllDatamodule = build_dm(cfg)
dm.setup()
num_voxels_dict = dm.num_voxel_dict
subject_list = dm.subject_list
# subject_list = ["NSD_08"]
# Evaluation splits to score; use the commented list to also score "val".
# stages = ["test", "val"]
stages = ["test"]
for stage in stages:
    # Correlation between targets and dark predictions: mean per subject, then
    # mean over all subjects' values pooled together.
    ss = []
    for subject in subject_list:
        if stage == "test":
            dl = dm.test_dataloader(subject=subject)
        elif stage == "val":
            dl = dm.val_dataloader(subject=subject)
        else:
            raise ValueError("stage must be test or val")

        ys, darks = [], []
        for batch in dl:
            # batch[1] holds the targets and batch[-1] the dark predictions.
            # Each is presumably a list of per-sample tensors (it is flattened
            # as a list below); confirm against the datamodule's collate_fn.
            y = batch[1]
            dark = batch[-1]
            ys.append(y)
            darks.append(dark)
        # Flatten the per-batch lists in one linear pass. sum(lists, []) copies
        # the growing accumulator on every batch, so it is quadratic.
        ys = torch.stack([t for chunk in ys for t in chunk]).numpy()
        dark = torch.stack([t for chunk in darks for t in chunk]).numpy()

        s = p_metric(ys, dark)
        score = np.nanmean(s)
        print(f"subject: {subject}, stage: {stage}, score: {score}")
        ss.append(s)
    if not ss:
        # np.concatenate([]) raises ValueError; report the empty stage instead.
        print(f"stage: {stage}, no subjects to score")
        continue
    ss = np.concatenate(ss)
    score = np.nanmean(ss)
    print(f"stage: {stage}, score: {score}")
# %%
for stage in stages:
    # Noise-normalized challenge score (challenge_metric) for NSD subjects only:
    # median per subject, then median over all subjects' values pooled together.
    ss = []
    for subject in subject_list:
        if "NSD" not in subject:
            continue
        if stage == "test":
            dl = dm.test_dataloader(subject=subject)
        elif stage == "val":
            dl = dm.val_dataloader(subject=subject)
        else:
            raise ValueError("stage must be test or val")

        ys, darks = [], []
        for batch in dl:
            # batch[1] holds the targets and batch[-1] the dark predictions.
            # Each is presumably a list of per-sample tensors; confirm against
            # the datamodule's collate_fn.
            y = batch[1]
            dark = batch[-1]
            ys.append(y)
            darks.append(dark)
        # Linear flatten of the per-batch lists (sum(lists, []) is quadratic).
        ys = torch.stack([t for chunk in ys for t in chunk]).numpy()
        dark = torch.stack([t for chunk in darks for t in chunk]).numpy()

        nc = dl.dataset.noise_ceiling

        s = challenge_metric(ys, dark, nc)
        score = np.nanmedian(s)
        print(f"subject: {subject}, stage: {stage}, score: {score}")
        ss.append(s)
    if not ss:
        # No NSD subjects were scored; np.concatenate([]) would raise here.
        print(f"stage: {stage}, no subjects to score")
        continue
    ss = np.concatenate(ss)
    score = np.nanmedian(ss)
    print(f"stage: {stage}, score: {score}")
# %%

# %%
# `df` is never assigned in this script, so an unconditional `df` expression
# raises NameError when the file runs top to bottom. Show it only when it has
# been defined interactively in the kernel.
if "df" in globals():
    display(df)
# %%
# Load the per-subject dark predictions from disk; the export cell indexes this
# mapping by subject id ("NSD_01" ... "NSD_08").
# NOTE(review): torch.load unpickles its input; only point it at trusted files.
predict_dark_dict = torch.load("/data/results/xobf/dark_y_dict_predict.pth")
# %%
# Sanity check: print each entry's key and array shape.
for k in predict_dark_dict:
    v = predict_dark_dict[k]
    print(k, v.shape)
# %%
# Export the dark predictions for subjects NSD_01..NSD_08. Each subject gets
# its own directory containing left- and right-hemisphere float32 .npy files.
submission_dir = "/data/algonauts_2023_challenge_submission/htroi_gen1_sepmerge"
for i in range(1, 9):
    subject = f"NSD_{i:02d}"
    outs = predict_dark_dict[subject]

    # The per-hemisphere masks are used only for their length. The prediction
    # columns are split as [first num_lh columns | remaining columns].
    mask_dir = f"/data/algonauts2023/subj{i:02d}/roi_masks"
    lh_mask = os.path.join(mask_dir, "lh.streams_challenge_space.npy")
    lh_mask = np.load(lh_mask)
    rh_mask = os.path.join(mask_dir, "rh.streams_challenge_space.npy")
    rh_mask = np.load(rh_mask)
    num_lh = lh_mask.shape[0]
    num_rh = rh_mask.shape[0]
    # Use an explicit check rather than `assert`, which is stripped under
    # `python -O`. A mismatch would silently write mis-split hemisphere files.
    if num_lh + num_rh != outs.shape[1]:
        raise ValueError(
            f"{subject}: predictions have {outs.shape[1]} columns but the masks "
            f"expect {num_lh} (lh) + {num_rh} (rh) = {num_lh + num_rh}"
        )

    lh_outs = outs[:, :num_lh].astype(np.float32)
    rh_outs = outs[:, num_lh:].astype(np.float32)

    subject_dir = os.path.join(submission_dir, f"subj{i:02d}")
    os.makedirs(subject_dir, exist_ok=True)
    np.save(os.path.join(subject_dir, "lh_pred_test.npy"), lh_outs)
    np.save(os.path.join(subject_dir, "rh_pred_test.npy"), rh_outs)
# %%
