# %%
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # to disable parallelism warning from transformers
import yaml
import argparse
import logging
import json
import numpy as np
import pandas as pd
import scipy.io as sio
from copy import deepcopy
from tqdm import tqdm

import torch
import accelerate
import utils
import nibabel as nib
import h5py

from train import train

# Global torch numerics / determinism settings.
# TF32 matmuls are faster than standard float32 (at reduced mantissa precision).
torch.backends.cuda.matmul.allow_tf32 = True
# Disable cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Limit intra-op CPU threads.
torch.set_num_threads(4)
# Alternatives kept for debugging:
# torch.backends.cudnn.benchmark = True # fixes Conv3D if used
# torch.autograd.set_detect_anomaly(True)
# torch.set_default_device('cpu')

# %%
### Multi-GPU config ###
# local_rank = os.getenv('RANK')
# if local_rank is None: 
#     local_rank = 0
# else:
#     local_rank = int(local_rank)
# print("LOCAL RANK ", local_rank)  

# When running interactively (e.g. in a notebook) there is no real command
# line, so the argparse arguments are supplied here instead.
if utils.is_interactive():
# if True:
    # One "--flag [value]" entry per line. They are echoed, then split into
    # argv-style tokens for parser.parse_args in the next cell.
    # (A former "--subj 1" entry was dropped: --subj is not a defined option and
    # only parsed because argparse prefix-matches it to --subj_list, which is
    # already given explicitly below.)
    jupyter_args = "".join(f"{opt} " for opt in [
        "--name train_ipynb",
        "--mixed_precision fp16",
        "--log_dir ../logs/",
        "--log_step 50",
        "--checkpoints_dir ../checkpoints/",
        "--save_step 999",
        "--val_step 1",
        "--seed 42",
        "--num_epochs 150",
        "--lr 3e-4",
        "--batch_size 128",
        "--gradient_accumulation_steps 1",
        "--subj_list 1",
        "--nsddir ../nsd",
        "--space MNI_2mm",
        "--func betas_fithrf_GLMdenoise_RR",
        "--clip_model CLIP-ViT-H-14",
        "--norm_nii",
        "--text_scale 0",
        "--mixup_pct 1",
        "--patch_size 14",
        "--num_blocks 12",
        "--patch_drop 0",
        "--attn_drop 0",
        "--block_drop 0",
        "--nii_mask brain",
        "--patch_type conv",
        "--buffer_size 4096",
    ])

    print(jupyter_args)
    jupyter_args = jupyter_args.split()

# %%
# Command-line interface. Options marked required=True have no usable default:
# omitting them previously crashed much later with an opaque TypeError
# (e.g. range(None), path.format(None), AdamW(lr=None)).
parser = argparse.ArgumentParser(description="train model")
parser.add_argument("--name", type=str, help="Name of the experiment")
parser.add_argument("--mixed_precision", type=str, default="no")
parser.add_argument("--log_dir", type=str, default="logs/")
parser.add_argument("--log_step", type=int, default=1)
parser.add_argument("--checkpoints_dir", type=str, default="checkpoints/")
parser.add_argument("--save_step", type=int, default=999)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--val_step", type=int, default=1)
parser.add_argument("--use_resume", action='store_true')
# Not read by this script (presumably accepted for launcher compatibility).
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument("--num_epochs", type=int, required=True)
parser.add_argument("--lr", type=float, required=True)
parser.add_argument("--batch_size", type=int, required=True, help="Per-device batch size")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--subj_list", nargs='+', type=int, required=True, help="NSD subject ids (1-8)")
parser.add_argument("--nsddir", type=str, required=True)
parser.add_argument("--space", type=str, required=True)
parser.add_argument("--func", type=str, required=True)
parser.add_argument("--clip_model", type=str, help="Unknown or missing names fall back to CLIP-ViT-H-14")
parser.add_argument("--norm_nii", action='store_true')
parser.add_argument("--text_scale", type=float)
parser.add_argument("--mixup_pct", type=float)
parser.add_argument("--patch_size", type=int, required=True)
parser.add_argument("--num_blocks", type=int, required=True)
parser.add_argument("--patch_drop", type=float, default=0.0)
parser.add_argument("--attn_drop", type=float, default=0.0)
parser.add_argument("--block_drop", type=float, default=0.0)
parser.add_argument("--nii_mask", type=str, default="brain")
parser.add_argument("--patch_type", type=str, default="conv")
parser.add_argument("--use_image_aug", action='store_true')
parser.add_argument("--mixin", action='store_true')
parser.add_argument("--buffer_size", type=int, default=0)
parser.add_argument("--local_loss", action='store_true')
parser.add_argument("--local_loss2", action='store_true')
parser.add_argument("--gather_with_grad", action='store_true')

if utils.is_interactive():
# if True:
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

utils.seed_everything(args.seed)


# %%
accelerator = accelerate.Accelerator(
    mixed_precision=args.mixed_precision,
    log_with="tensorboard",
    project_dir=args.log_dir,
)
# Create the run directory before anything is written into it (config.yaml
# below, info.log via logging.basicConfig). It used to be created only after
# config.yaml was opened, which relied on init_trackers having created it.
os.makedirs(f"{args.log_dir}/{args.name}", exist_ok=True)
if accelerator.is_main_process:
    accelerator.init_trackers(args.name)
    # save config
    with open(f"{args.log_dir}/{args.name}/config.yaml", 'w') as file:
        yaml.safe_dump(vars(args), file, default_flow_style=False)

logging.basicConfig(
    filename=f"{args.log_dir}/{args.name}/info.log",
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = accelerate.logging.get_logger(__name__)

logger.info(accelerator.state, main_process_only=False)
print(accelerator.state)

device = accelerator.device
print("device:",device)
world_size = accelerator.state.num_processes
distributed = accelerator.state.distributed_type != 'NO'
# num_devices = torch.cuda.device_count() if distributed else 1
num_devices = world_size if distributed else 1
if num_devices==0: num_devices = 1
num_workers = num_devices
# args.batch_size is per device; this is the effective batch across processes.
global_batch_size = args.batch_size * num_devices
print("PID of this process =",os.getpid())

print("distributed =",distributed, "num_devices =", num_devices, "local rank =", accelerator.local_process_index, "world size =", world_size)
logger.info(f"distributed = {distributed} num_devices = {num_devices} local rank = {accelerator.local_process_index} world size = {world_size}", main_process_only=False)
print = accelerator.print # only print if local_rank=0

# %%
# Local checkpoints for the supported CLIP variants.
clip_model = {
    'CLIP-ViT-H-14': "/opt/data/private/huggingface/models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K",
    'CLIP-ViT-L-14': "/opt/data/private/huggingface/models--openai--clip-vit-large-patch14",
    "CLIP-ViT-bigG-14": "/opt/data/private/huggingface/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k",
}
# Unknown or missing --clip_model names fall back to ViT-H-14.
clip_path = clip_model.get(args.clip_model, clip_model['CLIP-ViT-H-14'])
print("Loading CLIP model from ", clip_path)
from transformers import CLIPConfig, CLIPModel, CLIPProcessor, CLIPImageProcessor
image_processor = CLIPImageProcessor.from_pretrained(clip_path, use_fast=True)
# Frozen fp16 CLIP on the training device: eval mode, no gradients.
clip = (
    CLIPModel.from_pretrained(clip_path, torch_dtype=torch.float16)
    .to(device)
    .eval()
    .requires_grad_(False)
)
# %%
from torch.utils.data import DataLoader
from dataset import NSDDataset_New_1_triggle_image_text_int16

# Template for a single pre-split beta volume:
# <nsddir>/nsddata_betas/ppdata_split_pth/<space>/<func>/subjXX_betas_sessionYY/NNN.pth
nii_path = os.path.join(args.nsddir,'nsddata_betas','ppdata_split_pth','{:s}','{:s}','subj{:02d}_betas_session{:02d}', '{:03d}.pth')
nsd_info = sio.loadmat(os.path.join(args.nsddir,'nsddata','experiments','nsd','nsd_expdesign.mat'))
stim_info = pd.read_csv(os.path.join(args.nsddir,'nsddata','experiments','nsd', 'nsd_stim_info_merged.csv'), index_col=0)
cocoId = stim_info['cocoId']
# Zero-based stimulus index per (subject, trial), in presentation order.
stim_sort = (nsd_info['subjectim'][:,nsd_info['masterordering']-1]-1).squeeze()
# Number of sessions available per subject (subject s -> up_sess_bounds[s-1]).
up_sess_bounds = [40, 40, 32, 30, 40, 32, 40, 30] 
# with open(os.path.join(args.nsddir, 'nsddata_stimuli/stimuli/nsd/annotations', 'nsd_captions.json'), 'r') as f:
with open(os.path.join(args.nsddir, 'nsddata_stimuli/stimuli/nsd/annotations', 'qwen2_5_vl_captions_all.json'), 'r') as f:
# with open(os.path.join(args.nsddir, 'nsddata_stimuli/stimuli/nsd/annotations', 'llava13b_captions_all.json'), 'r') as f:
    captions = json.load(f)

# Zero-based indices of the shared images, which form the validation split.
# Built once as a set: the previous per-trial `x in nsd_info['sharedix']-1`
# re-allocated the array and scanned it linearly on every trial.
shared_stim = set((nsd_info['sharedix'] - 1).ravel().tolist())

train_dict = {}
val_dict = {}
if args.norm_nii:
    norm_dict = {}
for subj in args.subj_list:
    subj_label = args.subj_list.index(subj)
    train_list = []
    val_list = []
    for session in range(1, up_sess_bounds[subj-1]+1):
        for n in range(750):  # 750 trials per session
            stim_idx = stim_sort[subj-1, 750*(session-1)+n]
            cocoid = cocoId[stim_idx]
            trial = {
                'subj': subj,
                'subj_label': subj_label,
                'stim_idx': stim_idx,
                'nii': nii_path.format(args.space,args.func,subj,session,n),
                'text': captions[str(cocoid)],
                'cocoid': cocoid,
            }
            if stim_idx in shared_stim:
                val_list.append(trial)
            else:
                train_list.append(trial)
    train_dict['subj{:02d}'.format(subj)] = train_list
    val_dict['subj{:02d}'.format(subj)] = val_list
    if args.norm_nii:
        norm_dict.update({'subj{:02d}_mean'.format(subj): torch.load(f'{args.nsddir}/nsddata_betas/mean_std/subj{subj}_{args.space}_{args.func}_mean.pth')})
        norm_dict.update({'subj{:02d}_std'.format(subj): torch.load(f'{args.nsddir}/nsddata_betas/mean_std/subj{subj}_{args.space}_{args.func}_std.pth')})

# %%
# On-disk cache of CLIP-tokenized captions, keyed like `captions`.
# Alternatives for other caption sets:
# tc_file = 'tools/tokenized_captions.pt'
# tc_file = 'tools/tokenized_captions_llava13b.pt'
tc_file = 'tools/tokenized_captions_qwen25vl.pt'
try:
    # The cache stores tokenizer outputs (transformers BatchEncoding objects),
    # which the restricted weights_only unpickler rejects. That made the load
    # fail, and the old bare `except:` silently rebuilt the cache every run.
    # The file is produced locally below, so full unpickling is used.
    tokenized_captions = torch.load(tc_file, weights_only=False)
except Exception as err:  # missing/unreadable cache -> rebuild it
    print(f"Rebuilding tokenized caption cache {tc_file!r} ({err!r})")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(clip_model.get(args.clip_model, clip_model['CLIP-ViT-H-14']))
    tokenized_captions = {
        k: tokenizer(v, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
        for k, v in tqdm(captions.items())
    }
    os.makedirs(os.path.dirname(tc_file) or '.', exist_ok=True)
    torch.save(tokenized_captions, tc_file)
# Image array read lazily from HDF5; the file handle stays open for the datasets.
stimulus = h5py.File(os.path.join(args.nsddir,'nsddata_stimuli','stimuli','nsd','coco_images_224_float16.hdf5'), 'r', swmr=True)['images']

# %%
# Flatten the per-subject trial lists into one list for the training split.
train_set_list = [trial for subj_trials in train_dict.values() for trial in subj_trials]
train_set = NSDDataset_New_1_triggle_image_text_int16(
    train_set_list,
    stimulus=stimulus,
    tokenized_captions=tokenized_captions,
    image_processor=image_processor,
)
# Training: shuffled, ragged final batch dropped.
train_dataloader = DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=8,
)

# Same for validation: ordered, every sample kept.
val_set_list = [trial for subj_trials in val_dict.values() for trial in subj_trials]
val_set = NSDDataset_New_1_triggle_image_text_int16(
    val_set_list,
    stimulus=stimulus,
    tokenized_captions=tokenized_captions,
    image_processor=image_processor,
)
val_dataloader = DataLoader(
    val_set,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
    num_workers=8,
)

# The loaders are not yet sharded by accelerator.prepare at this point, so
# divide by the device count to get per-process iterations.
num_iterations_per_epoch = len(train_dataloader) // num_devices
print("batch_size =", args.batch_size, "num_iterations_per_epoch =", num_iterations_per_epoch)
logger.info(f"batch_size = {args.batch_size} num_iterations_per_epoch = {num_iterations_per_epoch}")

# %%
def cal_padding_list(patch_size, base_size=(91, 109, 91)):
    """Return ``torch.nn.functional.pad`` amounts that make every dim a multiple of *patch_size*.

    Args:
        patch_size: Edge length of the cubic patch; must be a positive int.
        base_size: Spatial shape of the volume to pad. Defaults to
            ``(91, 109, 91)``, the shape this script has always assumed.

    Returns:
        A flat list ``[left_last, right_last, left_prev, right_prev, ...]``.
        It is ordered last dimension first, as ``F.pad`` expects. When a
        dimension needs an odd amount of padding, the extra voxel goes on the
        right.

    Raises:
        ValueError: if *patch_size* is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    padding_list = []
    for dim in reversed(base_size):  # F.pad lists the last dimension first
        pad_size = -dim % patch_size  # smallest pad making dim divisible
        left = pad_size // 2
        padding_list.extend([left, pad_size - left])
    return padding_list
padding_list = cal_padding_list(args.patch_size)

# Pick the voxel mask that decides which 3D patches become transformer tokens.
if args.nii_mask=='brain':
    mask_file = os.path.join(args.nsddir, "fsl_tmp", "MNI152_T1_2mm_brain_mask.nii.gz")
elif args.nii_mask=='roi':
    # NOTE(review): if this ROI volume encodes invalid voxels as negative values,
    # astype(np.uint8) below turns them into non-zero entries -- confirm the
    # file only contains 0/1.
    mask_file = "{:s}/nsddata/ppdata/all/{:s}/roi/nsdgeneral.nii.gz".format(args.nsddir, args.space)
else:
    # Any other --nii_mask value: no token selection.
    mask_file = None

if mask_file is None:
    token_ids = None
else:
    mask = torch.from_numpy(nib.load(mask_file).get_fdata().astype(np.uint8))
    # Pad with the same padding_list that is handed to train(), so the mask's
    # patch grid lines up with the model's.
    mask = torch.nn.functional.pad(mask, pad=padding_list, mode='constant', value=0)
    # Count mask voxels in each non-overlapping patch (all-ones kernel,
    # stride equal to the kernel size).
    patch = args.patch_size
    voxel_counts = torch.nn.functional.conv3d(
        mask[None, None].float(),
        torch.ones((1, 1, patch, patch, patch)),
        stride=(patch, patch, patch),
    )
    # Threshold while still float. The former `.byte()` cast before `!= 0`
    # overflowed for counts above 255: on typical CPU builds counts that are
    # multiples of 256 became 0, silently dropping those patches. flatten()
    # also keeps a 1-D result when there is a single patch.
    # NOTE(review): this can add tokens relative to checkpoints trained with
    # the old mask.
    token_mask = voxel_counts.flatten() > 0
    token_ids = torch.where(token_mask)[0].tolist()
# %%
# Model hyperparameters copied out of the CLI args for the model-building cells.
(PATCH_SIZE, NUM_BLOCKS, PATCH_TYPE,
 ATTN_DROP, BLOCK_DROP, PATCH_DROP) = (
    args.patch_size, args.num_blocks, args.patch_type,
    args.attn_drop, args.block_drop, args.patch_drop,
)
# %%
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
from vit3d import ViTEmbeddings3D, ViTConfig, calc_image_size

# Vision-transformer config for the fMRI encoder. Width, heads, MLP size,
# activation, norm eps and projection size are copied from the loaded CLIP
# image tower; depth and dropouts come from the CLI.
_vision = clip.config.vision_config
clip_cfg = CLIPVisionConfig(
    model_type="clip_vision_model",
    num_hidden_layers=NUM_BLOCKS,
    attention_dropout=ATTN_DROP,
    dropout=BLOCK_DROP,
    initializer_factor=1.0,
    initializer_range=0.02,
    patch_size=14,
    # image_size=224,
    # num_channels=1,
    **{key: getattr(_vision, key) for key in (
        "hidden_act",
        "hidden_size",
        "intermediate_size",
        "layer_norm_eps",
        "num_attention_heads",
        "projection_dim",
    )},
)

# Copy of the same config with the 3D-specific fields used by ViTEmbeddings3D.
model_args = deepcopy(clip_cfg)
model_args.image_size = calc_image_size(PATCH_SIZE)
model_args.patch_size = (PATCH_SIZE,) * 3
model_args.hidden_dropout_prob = PATCH_DROP
model_args.num_channels = 1
model_args.token_ids = token_ids
model_args.patch_type = PATCH_TYPE

# %%
# Build a CLIP vision transformer from the 2D config, then replace its patch
# embedding with the 3D embedding configured by model_args (1 input channel,
# cubic patches). The dead notebook-leftover expression `clip_fmri.vision_model`
# was removed.
clip_fmri = CLIPVisionModelWithProjection(clip_cfg)
clip_fmri.vision_model.embeddings = ViTEmbeddings3D(model_args)
# %%
if args.use_image_aug:
    # Optional kornia image augmentation pipeline; only imported and built when requested.
    import kornia
    from kornia.augmentation.container import AugmentationSequential
    aug = kornia.augmentation
    img_augment = AugmentationSequential(
        aug.RandomResizedCrop((224, 224), (0.6, 1), p=0.3),
        aug.Resize((224, 224)),
        aug.RandomHorizontalFlip(p=0.5),
        aug.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
        aug.RandomGrayscale(p=0.3),
        data_keys=["input"],
    )
# %%
# AdamW with two parameter groups. Parameters whose name contains any of
# no_decay get no weight decay; everything else gets 1e-2.
no_decay = ['bias', 'norm', 'LayerNorm']
decay_params, no_decay_params = [], []
for param_name, param in clip_fmri.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
opt_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 1e-2},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=args.lr)
# Computed from the not-yet-prepared (unsharded) loader.
# NOTE(review): assumes train() steps the scheduler once per optimizer update.
total_steps=len(train_dataloader)//args.gradient_accumulation_steps*args.num_epochs

# Fraction of the cycle spent increasing the LR. A former
# `pct_start=2/args.num_epochs` assignment was dead (immediately overwritten)
# and was removed.
pct_start=0.01
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, pct_start=pct_start, total_steps=total_steps, final_div_factor=1000, last_epoch=-1)
# %%
checkpoint_dir = os.path.join(args.checkpoints_dir, args.name)
os.makedirs(checkpoint_dir, exist_ok=True)

# %%
torch.cuda.empty_cache()
## train
# Let accelerate wrap the model, optimizer, scheduler and loaders for this process.
(clip_fmri, optimizer, lr_scheduler,
 train_dataloader, val_dataloader) = accelerator.prepare(
    clip_fmri, optimizer, lr_scheduler, train_dataloader, val_dataloader,
)

# Scheduler descriptor passed to train() (interval 'step', frequency 1).
scheduler = dict(scheduler=lr_scheduler, interval='step', frequency=1)

# ================= resume ================ #
# With --use_resume, restore the accelerate state saved in <checkpoint_dir>/state_last
# and continue from the epoch after the 'epoch_idx' recorded in last.json.
start_epoch = 0
if args.use_resume:
    accelerator.load_state(os.path.join(checkpoint_dir, 'state_last'))
    with open(os.path.join(checkpoint_dir, 'last.json'), 'r') as json_file:
        resumed_epoch = json.load(json_file)['epoch_idx']
    start_epoch = resumed_epoch + 1
# %%
train(
    # runtime / distribution
    accelerator=accelerator,
    distributed=distributed,
    logger=logger,
    # models
    clip_fmri=clip_fmri,
    clip=clip,
    # optimisation
    optimizer=optimizer,
    scheduler=scheduler,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    use_image_aug=args.use_image_aug,
    img_augment=img_augment if args.use_image_aug else None,
    norm_nii=args.norm_nii,
    norm_dict=norm_dict if args.norm_nii else None,
    padding_list=padding_list,
    # schedule / bookkeeping
    num_epochs=args.num_epochs,
    start_epoch=start_epoch,
    num_iterations_per_epoch=num_iterations_per_epoch,
    checkpoint_dir=checkpoint_dir,
    save_step=args.save_step,
    log_step=args.log_step,
    val_step=args.val_step,
    best_value=None,
    # loss options
    text_scale=args.text_scale,
    mixup_pct=args.mixup_pct,
    mixin=args.mixin,
    buffer_size=args.buffer_size,
    local_loss=args.local_loss,
    local_loss2=args.local_loss2,
    gather_with_grad=args.gather_with_grad,
)


