import os
import yaml
import argparse
import json
import numpy as np
import pandas as pd
import scipy.io as sio
import random
from copy import deepcopy
from tqdm import tqdm

import torch
import utils
import nibabel as nib
import h5py
from scipy import stats

# Numerical backend setup: allow TF32 matmuls (faster than standard float32)
# and pin cuDNN to deterministic, non-autotuned kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = True

# The process rank is taken from the RANK environment variable (0 when unset).
# NOTE(review): the value is stored as `local_rank` although it is read from
# RANK rather than LOCAL_RANK -- confirm this is intended for multi-node runs.
local_rank = int(os.getenv('RANK', 0))
print("LOCAL RANK ", local_rank)

# Command-line interface.
# Options the script cannot run without (they feed path construction, model
# construction or the final retrieval loop and have no usable default) are
# marked required so that a missing flag fails at startup with a clear message
# instead of an obscure TypeError later on.
parser = argparse.ArgumentParser(description="test model")
parser.add_argument("--name", type=str, required=True,
                    help="run name; the checkpoint is read from <checkpoints_dir>/<name>/")
parser.add_argument("--checkpoints_dir", type=str, default="checkpoints/",
                    help="directory that contains the per-run checkpoint folders")
parser.add_argument("--seed", type=int, default=42,
                    help="seed passed to utils.seed_everything")
parser.add_argument("--batch_size", type=int,
                    help="batch size of the test DataLoaders")
parser.add_argument("--subj_list", nargs='+', type=int, required=True,
                    help="1-based subject ids to evaluate")
parser.add_argument("--nsddir", type=str, required=True,
                    help="root directory of the NSD data")
parser.add_argument("--space", type=str, required=True,
                    help="first sub-directory of the beta files (also used in the mean/std file names)")
parser.add_argument("--func", type=str, required=True,
                    help="second sub-directory of the beta files (also used in the mean/std file names)")
parser.add_argument("--clip_model", type=str,
                    help="key of the CLIP snapshot to load; unknown values fall back to CLIP-ViT-H-14")
parser.add_argument("--norm_nii", action='store_true',
                    help="normalize volumes with the per-subject mean/std files")
parser.add_argument("--num_blocks", type=int, required=True,
                    help="number of hidden layers of the fMRI vision transformer")
parser.add_argument("--patch_size", type=int, required=True,
                    help="edge length of the cubic 3D patches")
parser.add_argument("--patch_type", type=str, default="conv",
                    help="patch embedding type forwarded to ViTEmbeddings3D")
parser.add_argument("--patch_drop", type=float, default=0.0,
                    help="dropout used by the patch embeddings (hidden_dropout_prob)")
parser.add_argument("--attn_drop", type=float, default=0.0,
                    help="attention dropout of the transformer")
parser.add_argument("--block_drop", type=float, default=0.0,
                    help="dropout of the transformer blocks")
parser.add_argument("--nii_mask", type=str, default="brain",
                    help="'brain' or 'roi' to restrict tokens to a mask; any other value disables masking")
parser.add_argument("--tag", type=str, required=True,
                    help="checkpoint file stem: <tag>.pth and <tag>.json are loaded")
parser.add_argument("--topk_list", nargs='+', type=int, required=True,
                    help="values of k for the top-k retrieval evaluation")
args = parser.parse_args()
utils.seed_everything(args.seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("device:",device)
num_devices = 1

# Local snapshot directories of the supported CLIP checkpoints, keyed by the
# value accepted for --clip_model.
clip_model = {
    'CLIP-ViT-H-14': "/opt/data/private/huggingface/models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K",
    'CLIP-ViT-L-14': "/opt/data/private/huggingface/models--openai--clip-vit-large-patch14",
    "CLIP-ViT-bigG-14": "/opt/data/private/huggingface/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k",
}

# A missing or unknown --clip_model resolves to the ViT-H-14 snapshot.
clip_path = clip_model.get(args.clip_model, clip_model['CLIP-ViT-H-14'])
print("Loading CLIP model from ", clip_path)
from transformers import CLIPConfig, CLIPModel, CLIPProcessor, CLIPImageProcessor
processor = CLIPProcessor.from_pretrained(clip_path)

# The reference CLIP model is only used for inference: eval mode, no gradients.
clip = CLIPModel.from_pretrained(clip_path).to(device)
clip.eval()
clip.requires_grad_(False)

from torch.utils.data import DataLoader, Dataset

class NSDDataset_New_all_triggle_image_text_int16(Dataset):
    """Map-style dataset yielding fMRI volumes, a stimulus image and one caption.

    Every entry of ``data_list`` is a dict that provides ``'nii'`` (a list of
    tensor file paths), ``'stim_idx'`` (index into ``stimulus``), ``'cocoid'``
    (key into ``tokenized_captions`` once converted with ``str``) and
    ``'subj'``.
    """

    def __init__(self, data_list, stimulus=None, tokenized_captions=None, image_processor=None):
        super().__init__()
        self.image_processor = image_processor
        self.tokenized_captions = tokenized_captions
        self.stimulus = stimulus
        self.data_list = data_list

    def __len__(self):
        """Return the number of samples."""
        return len(self.data_list)

    @staticmethod
    def _stack_volumes(volumes):
        """Stack the loaded volumes along a new leading dimension.

        One volume is repeated three times, two volumes are arranged as
        (first, second, first), any other count is stacked unchanged.
        """
        count = len(volumes)
        if count == 1:
            return torch.stack(volumes).repeat(3, 1, 1, 1)
        if count == 2:
            first, second = volumes
            return torch.stack([first, second, first])
        return torch.stack(volumes)

    def __getitem__(self, idx):
        """Return ``(nii, stim, input_ids, attention_mask, subj)`` for sample ``idx``."""
        record = self.data_list[idx]
        nii = self._stack_volumes([torch.load(path) for path in record['nii']])

        stim = torch.from_numpy(self.stimulus[record['stim_idx']])
        image_processor = self.image_processor
        if image_processor is not None:
            processed = image_processor(stim, return_tensors="pt", do_rescale=False)
            stim = processed.pixel_values.squeeze(0)

        # One of the tokenized captions of this image is drawn at random.
        caption_entry = self.tokenized_captions[str(record['cocoid'])]
        all_input_ids = caption_entry['input_ids']
        pick = random.randint(0, len(all_input_ids) - 1)
        return nii, stim, all_input_ids[pick], caption_entry['attention_mask'][pick], record['subj']

# Template of a single-trial beta file. The placeholders are filled, in order,
# with: space, func, subject id, session number, trial index inside the session.
nii_path = os.path.join(
    args.nsddir, 'nsddata_betas', 'ppdata_split_pth',
    '{:s}', '{:s}', 'subj{:02d}_betas_session{:02d}', '{:03d}.pth',
)

# Experiment design (.mat) and stimulus metadata (.csv) live in the same folder.
nsd_exp_dir = os.path.join(args.nsddir, 'nsddata', 'experiments', 'nsd')
nsd_info = sio.loadmat(os.path.join(nsd_exp_dir, 'nsd_expdesign.mat'))
stim_info = pd.read_csv(os.path.join(nsd_exp_dir, 'nsd_stim_info_merged.csv'), index_col=0)
cocoId = stim_info['cocoId']

# Stimulus index of every trial, later indexed as stim_sort[subj-1].
# The "-1" offsets presumably convert MATLAB 1-based indices to 0-based ones.
trial_order = nsd_info['masterordering'] - 1
stim_sort = (nsd_info['subjectim'][:, trial_order] - 1).squeeze()

# Per-subject session bound: trials located at or beyond bound*750 are dropped
# when the test lists are built.
up_sess_bounds = [40, 40, 32, 30, 40, 32, 40, 30]

# Captions used for the text retrieval targets.
captions_file = os.path.join(args.nsddir, 'nsddata_stimuli/stimuli/nsd/annotations', 'qwen2_5_vl_captions_all.json')
with open(captions_file, 'r') as f:
    captions = json.load(f)
    
# Build the per-subject test lists: one entry per shared stimulus, holding the
# paths of all retained trials on which the subject saw that stimulus.

# A trial position p maps to session (p // 750) + 1 and in-session index p % 750.
TRIALS_PER_SESSION = 750
# Loop-invariant: stimulus ids (after the -1 offset) that form the test set.
shared_stim_ids = nsd_info['sharedix'] - 1

test_dict = {}
if args.norm_nii:
    norm_dict = {}
for subj in args.subj_list:
    subj_trials = stim_sort[subj-1]
    # Trials at or beyond this position are discarded for this subject.
    trial_limit = up_sess_bounds[subj-1] * TRIALS_PER_SESSION

    # unique() keeps first-appearance order, which fixes the test-set order.
    stim_unique = pd.Series(subj_trials).unique()
    # Vectorized membership test instead of scanning sharedix once per stimulus;
    # non-shared stimuli are skipped before any per-trial search is done.
    shared_seen = stim_unique[np.isin(stim_unique, shared_stim_ids)]

    test_list = []
    for stim_id in shared_seen:
        nii_loc = np.where(subj_trials == stim_id)[0]
        nii_loc = nii_loc[nii_loc < trial_limit]
        if len(nii_loc) == 0:
            continue
        coco_id = cocoId[stim_id]
        test_list.append({
            'subj': subj,
            'stim_idx': stim_id,
            'nii': [
                nii_path.format(args.space, args.func, subj, (nloc // TRIALS_PER_SESSION) + 1, nloc % TRIALS_PER_SESSION)
                for nloc in nii_loc
            ],
            'text': captions[str(coco_id)],
            'cocoid': coco_id,
        })
    print(f"subj {subj} test set size: {len(test_list)}")
    test_dict['subj{:02d}'.format(subj)] = test_list
    if args.norm_nii:
        # Note: the dict keys are zero-padded, the file names are not.
        norm_dict['subj{:02d}_mean'.format(subj)] = torch.load(f'{args.nsddir}/nsddata_betas/mean_std/subj{subj}_{args.space}_{args.func}_mean.pth')
        norm_dict['subj{:02d}_std'.format(subj)] = torch.load(f'{args.nsddir}/nsddata_betas/mean_std/subj{subj}_{args.space}_{args.func}_std.pth')

# Prefer the cached tokenization; when it cannot be loaded, tokenize the
# captions with the CLIP tokenizer instead.
# NOTE(review): the cache is assumed to correspond to the captions file loaded
# above -- confirm when switching caption sets.
try:
    tokenized_captions = torch.load('tools/tokenized_captions_qwen25vl.pt')
except Exception as err:
    # `Exception` (rather than a bare except) lets KeyboardInterrupt/SystemExit
    # propagate, and the reason for the fallback is reported instead of hidden.
    print(f"could not load cached tokenized captions ({err!r}); tokenizing captions instead")
    tokenized_captions = {
        coco_key: processor.tokenizer(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
        for coco_key, texts in tqdm(captions.items())
    }

# The HDF5 file is deliberately left open: the dataset indexes `stimulus`
# lazily while batches are produced.
# NOTE(review): this handle is shared with the DataLoader worker processes --
# confirm h5py copes with that on the target platform.
stimulus = h5py.File(os.path.join(args.nsddir,'nsddata_stimuli','stimuli','nsd','coco_images_224_float16.hdf5'), 'r')['images']


def _build_test_loader(records):
    """Wrap one subject's sample list in a dataset and a sequential DataLoader."""
    dataset = NSDDataset_New_all_triggle_image_text_int16(
        data_list=records,
        stimulus=stimulus,
        tokenized_captions=tokenized_captions,
        image_processor=processor.image_processor,
    )
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, pin_memory=True, num_workers=4)


# One loader per subject key (e.g. the keys of test_dict), in insertion order.
test_dls = {subj_key: _build_test_loader(records) for subj_key, records in test_dict.items()}

def cal_padding_list(patch_size, base_size=(91, 109, 91)):
    """Compute padding that makes every dimension a multiple of ``patch_size``.

    The result is laid out for ``torch.nn.functional.pad``: pairs of
    (before, after) amounts, starting with the LAST dimension of
    ``base_size``. When the required padding is odd, the extra element goes
    to the "after" side.

    Args:
        patch_size: positive integer patch edge length.
        base_size: sizes of the dimensions to pad, in tensor order. Defaults
            to the (91, 109, 91) volume shape used by this script.

    Returns:
        A flat list with ``2 * len(base_size)`` non-negative integers.

    Raises:
        ValueError: if ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")
    padding_list = []
    for dim in reversed(base_size):
        # Smallest amount that lifts `dim` to the next multiple of patch_size.
        pad_size = -dim % patch_size
        before = pad_size // 2
        padding_list.extend([before, pad_size - before])
    return padding_list
# Padding (F.pad layout) that makes the volume divisible by the patch size.
padding_list = cal_padding_list(args.patch_size)

# Select which patch tokens are kept. For 'brain' and 'roi' a mask volume is
# loaded, padded like the data, and reduced to one count per non-overlapping
# patch (conv3d with an all-ones kernel and stride == patch size); token_ids
# are the flat indices of patches whose count is non-zero. Any other value of
# --nii_mask disables the selection (token_ids = None).
# NOTE(review): `.byte()` casts the per-patch count to uint8; when
# patch_size**3 >= 256 a count that is a multiple of 256 may wrap to 0 and drop
# that patch. This presumably has to match the training pipeline, so confirm
# there before changing it.
# NOTE(review): `.astype(np.uint8)` assumes the mask volumes hold small
# non-negative values -- verify for the ROI file.
if args.nii_mask=='brain':
    # Brain mask located under <nsddir>/fsl_tmp.
    mask = torch.from_numpy(nib.load(os.path.join("{:s}", "fsl_tmp", "MNI152_T1_2mm_brain_mask.nii.gz").format(args.nsddir)).get_fdata().astype(np.uint8))
    mask = torch.nn.functional.pad(mask, pad=padding_list, mode='constant', value=0)
    token_mask = torch.nn.functional.conv3d(mask[None,None].float(), torch.ones((1, 1, args.patch_size, args.patch_size, args.patch_size)).float(), stride=(args.patch_size,args.patch_size,args.patch_size)).byte().flatten(2).squeeze()
    token_ids = torch.where(token_mask!=0)[0].tolist()
elif args.nii_mask=='roi':
    # ROI mask located under <nsddir>/nsddata/ppdata/all/<space>/roi.
    mask = torch.from_numpy(nib.load("{:s}/nsddata/ppdata/all/{:s}/roi/nsdgeneral.nii.gz".format(args.nsddir, args.space)).get_fdata().astype(np.uint8))
    mask = torch.nn.functional.pad(mask, pad=padding_list, mode='constant', value=0)
    token_mask = torch.nn.functional.conv3d(mask[None,None].float(), torch.ones((1, 1, args.patch_size, args.patch_size, args.patch_size)).float(), stride=(args.patch_size,args.patch_size,args.patch_size)).byte().flatten(2).squeeze()
    token_ids = torch.where(token_mask!=0)[0].tolist()
else:
    token_ids = None

# Hyper-parameters of the fMRI encoder, taken from the command line.
PATCH_SIZE = args.patch_size
NUM_BLOCKS = args.num_blocks
PATCH_TYPE = args.patch_type
ATTN_DROP = args.attn_drop
BLOCK_DROP = args.block_drop
PATCH_DROP = args.patch_drop

from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
from vit3d import ViTEmbeddings3D, ViTConfig, calc_image_size

# Transformer configuration: layer widths are copied from the vision tower of
# the reference CLIP model, depth and dropout come from the command line.
ref_vision_cfg = clip.config.vision_config
clip_cfg = CLIPVisionConfig(
    model_type="clip_vision_model",
    patch_size=14,
    num_hidden_layers=NUM_BLOCKS,
    attention_dropout=ATTN_DROP,
    dropout=BLOCK_DROP,
    initializer_factor=1.0,
    initializer_range=0.02,
    hidden_act=ref_vision_cfg.hidden_act,
    hidden_size=ref_vision_cfg.hidden_size,
    intermediate_size=ref_vision_cfg.intermediate_size,
    layer_norm_eps=ref_vision_cfg.layer_norm_eps,
    num_attention_heads=ref_vision_cfg.num_attention_heads,
    projection_dim=ref_vision_cfg.projection_dim,
)

# Configuration of the 3D patch embeddings: a copy of the transformer config
# with the volume-specific fields overridden.
model_args = deepcopy(clip_cfg)
embedding_overrides = {
    'image_size': calc_image_size(PATCH_SIZE),
    'patch_size': (PATCH_SIZE, PATCH_SIZE, PATCH_SIZE),
    'hidden_dropout_prob': PATCH_DROP,
    'num_channels': 1,
    'token_ids': token_ids,
    'patch_type': PATCH_TYPE,
}
for field_name, field_value in embedding_overrides.items():
    setattr(model_args, field_name, field_value)

# CLIP vision model whose 2D patch embeddings are replaced by the 3D ones.
clip_fmri = CLIPVisionModelWithProjection(clip_cfg)
clip_fmri.vision_model.embeddings = ViTEmbeddings3D(model_args)
utils.count_params(clip_fmri)

# ================= load ckpt ================ #
# Weights come from <checkpoints_dir>/<name>/<tag>.pth, the training state
# (epoch index) from the JSON file with the same stem.
checkpoint_dir = os.path.join(args.checkpoints_dir, args.name)
print(f"\n---loading {checkpoint_dir}/{args.tag}.pth ckpt---\n")

state_dict = torch.load(os.path.join(checkpoint_dir, f'{args.tag}.pth'), map_location='cpu')
# strict=False tolerates key mismatches; report them so that a checkpoint that
# does not fit the model is not evaluated silently with untrained weights.
load_result = clip_fmri.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model key(s) missing from the checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} checkpoint key(s) not used by the model: {load_result.unexpected_keys}")

json_save_path = os.path.join(checkpoint_dir, f'{args.tag}.json')
with open(json_save_path, 'r') as json_file:
    state = json.load(json_file)
epoch_idx = state['epoch_idx']
print("Epoch", epoch_idx)
clip_fmri = clip_fmri.to(device)

def cal_norm_nii(nii, mean, std):
    """Standardize ``nii`` as ``(nii - mean) / (std + 1e-5)``.

    ``mean`` and ``std`` are moved to the device of ``nii`` first; the small
    constant keeps the division finite where ``std`` is zero.
    """
    target_device = nii.device
    centered = nii - mean.to(target_device)
    return centered / (std.to(target_device) + 1e-5)

def _mean_with_ci(scores):
    """Return the mean of ``scores`` and its 95% normal interval (mean +/- z * standard error)."""
    mean = np.mean(scores)
    std_err = np.std(scores) / np.sqrt(len(scores))
    return mean, stats.norm.interval(0.95, loc=mean, scale=std_err)


def _report_retrieval(target_name, fmri_emb, target_emb, k):
    """Run top-k retrieval between fMRI and target embeddings and print both directions."""
    print(f"===================== retrieval Top {k} {target_name} =====================")
    # Re-seed before every evaluation so each result is reproducible on its own.
    utils.seed_everything(args.seed)
    percent_correct_fwds, percent_correct_bwds = utils.retrieval_score_topk(fmri_emb, target_emb, k=k)
    for direction, scores in (("fwd", percent_correct_fwds), ("bwd", percent_correct_bwds)):
        mean, (ci_low, ci_high) = _mean_with_ci(scores)
        print(f"{direction} percent_correct: {mean:.4f} 95% CI: [{ci_low:.4f},{ci_high:.4f}]")


clip_fmri.eval()

# Per subject: embed every test sample with the reference CLIP model (image and
# caption) and with the fMRI encoder, then score retrieval for each requested k.
for subj, test_dataloader in test_dls.items():
    print(f"\n===================== {subj} =====================\n")
    if args.norm_nii:
        norm_mean = norm_dict[f'{subj}_mean']
        norm_std = norm_dict[f'{subj}_std']

    clip_fmri_emb = []
    clip_image_emb = []
    clip_text_emb = []
    with torch.no_grad():
        for batch in tqdm(test_dataloader, total=len(test_dataloader), desc=f"Test: ", dynamic_ncols=True):
            nii_batch, image_batch, input_ids_batch, attention_mask_batch, _ = batch
            nii_batch = nii_batch.to(device)
            image_batch = image_batch.to(device)
            input_ids_batch = input_ids_batch.to(device)
            attention_mask_batch = attention_mask_batch.to(device)

            # Stored volumes are rescaled by 1/300 before optional normalization.
            nii_batch = nii_batch.float()/300.
            if args.norm_nii:
                nii_batch = cal_norm_nii(nii_batch, norm_mean, norm_std)
            nii_batch = torch.nn.functional.pad(nii_batch, pad=padding_list, mode='constant', value=0)

            # Average the stacked volumes of a sample into a single channel.
            nii_batch = nii_batch.mean(dim=1, keepdim=True)
            image_batch = image_batch.to(nii_batch.dtype)

            output = clip(input_ids=input_ids_batch, attention_mask=attention_mask_batch, pixel_values=image_batch)
            clip_text_emb.append(output.text_embeds)
            clip_image_emb.append(output.image_embeds)
            clip_fmri_emb.append(clip_fmri(nii_batch).image_embeds)

    if not clip_fmri_emb:
        # torch.cat fails on an empty list; nothing to score for this subject.
        print(f"{subj}: empty test set, skipping retrieval")
        continue

    clip_text_emb = torch.cat(clip_text_emb)
    clip_image_emb = torch.cat(clip_image_emb)
    clip_fmri_emb = torch.cat(clip_fmri_emb)

    for k in args.topk_list:
        _report_retrieval("image", clip_fmri_emb, clip_image_emb, k)
        _report_retrieval("text", clip_fmri_emb, clip_text_emb, k)
