from calendar import c
import os
import sys
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
from transformers.utils.hub import key
sys.path.append('generative_models/')
import argparse
import numpy as np
from tqdm import tqdm
import gc
import wandb
import inspect
import open_clip
import torch
import torch.nn as nn
from accelerate import Accelerator
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder2 # bigG embedder from OpenCLIP
from model_variants.VCFlowModel import (Neurons,Fusion,fMRIBackbone,BrainNetwork,RedistributionHead ,PriorNetwork, BrainDiffusionPrior,
                                             CLIPProj, TextDecoder, TextDrivenDecoder, MotionProj, MultiLabelClassifier)
torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn.functional as F
import utils
import json
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from model_variants.VCFlow_dataset import CC2017_Dataset
from animatediff.data.dataset_DIRGOD import DIR_GOD_ImageDataset

def mse_loss_for_consistent_embeddings(clip_vision_embeds):
    """
    Penalise how far each entry along dim 0 lies from the mean over dim 0.

    Every entry is flattened to a vector, the mean vector across dim 0 is
    computed, and the mean squared error between each entry and that mean is
    returned.

    input:
        clip_vision_embeds (torch.Tensor): tensor whose first dimension indexes
            the entries to pull together, e.g. [subj, 2, 256, 1664]. Permuted
            (non-contiguous) tensors are accepted.

    return:
        loss (torch.Tensor): scalar MSE Loss
    """
    # reshape instead of view: view raises on non-contiguous tensors such as
    # the permuted results of einops.rearrange.
    flattened_embeds = clip_vision_embeds.reshape(clip_vision_embeds.shape[0], -1)

    mean_embed = flattened_embeds.mean(dim=0, keepdim=True)

    # expand_as broadcasts the mean without copying and keeps both operands
    # the same shape for mse_loss.
    return F.mse_loss(flattened_embeds, mean_embed.expand_as(flattened_embeds))


def pairwise_distance_loss(clip_vision_embeds):
    """
    Average ``utils.mixco_nce`` over every pair of adjacent entries along dim 0.

    For each index i >= 1, entries i-1 and i are flattened to [shape[0], -1],
    L2-normalised along the last dimension and passed to ``utils.mixco_nce``;
    the results are averaged over the (n - 1) adjacent pairs.

    input:
        clip_vision_embeds (torch.Tensor): stack of embeddings, e.g.
            [subj, 2, 256, 1664]; dim 0 indexes the entries being compared.
    return:
        loss (torch.Tensor): Pairwise Distance Loss. A zero scalar is returned
            when dim 0 holds fewer than two entries, because no pair exists.
    """
    num_entries = clip_vision_embeds.shape[0]
    if num_entries < 2:
        # Without this guard the average below would divide by zero.
        return clip_vision_embeds.new_zeros(())

    loss = 0
    for batch_idx in range(1, num_entries):
        previous = clip_vision_embeds[batch_idx - 1]
        current = clip_vision_embeds[batch_idx]

        # reshape tolerates non-contiguous slices where view would raise.
        flattened_prior = F.normalize(previous.reshape(previous.shape[0], -1), dim=-1)
        flattened = F.normalize(current.reshape(current.shape[0], -1), dim=-1)

        loss += utils.mixco_nce(
            flattened_prior,
            flattened,
        )

    return loss / (num_entries - 1)

def log_weight(epoch, batch, batches_per_epoch, start_epoch, period):
    """Return a weight in [1, 10] that follows |sin| over a window of epochs.

    The position inside the window is the number of batches elapsed since
    ``start_epoch`` divided by the ``period * batches_per_epoch`` batches the
    window spans. The weight is ``1 + 9 * |sin(pi * position)|``: 1 at the
    window start, 10 at its midpoint, and back to 1 at its end.
    """
    batches_elapsed = (epoch - start_epoch) * batches_per_epoch + batch
    window_batches = period * batches_per_epoch
    phase = batches_elapsed / window_batches * np.pi
    return 1 + 9 * np.abs(np.sin(phase))

def get_loss_weights(total_epochs, epoch, batch, batches_per_epoch):
    """Return four loss weights, one per staggered scheduling window.

    Each window lasts ``period = total_epochs // 5 * 2`` epochs and window
    ``k`` starts at epoch ``k * period // 2``, so consecutive windows overlap.
    Inside its window a weight is given by ``log_weight``; outside it is the
    constant 1.
    """
    period = total_epochs // 5 * 2
    window_starts = (k * period // 2 for k in range(4))
    return [
        log_weight(epoch, batch, batches_per_epoch, start, period)
        if start <= epoch < start + period
        else 1
        for start in window_starts
    ]

def count_params_(model):
    """Print a per-module parameter table and return the trainable total.

    Every module yielded by ``model.named_modules()`` is visited, including
    ``model`` itself, so parameters registered directly on the root are part
    of the totals. Each module contributes only the parameters it owns
    directly (``recurse=False``), which keeps a parameter from being counted
    again through its parent modules.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        int: number of parameter elements with ``requires_grad=True``.
    """
    print("Parameter count per submodule:\n" + "-" * 40)
    total, trainable = 0, 0

    for name, module in model.named_modules():
        own_params = list(module.parameters(recurse=False))
        module_total = sum(p.numel() for p in own_params)
        module_trainable = sum(p.numel() for p in own_params if p.requires_grad)

        if module_total > 0:
            # named_modules() reports the root module under the empty name.
            label = name or "(root)"
            print(f"{label:<30}: {module_total:>10,} total, {module_trainable:>10,} trainable")

        total += module_total
        trainable += module_trainable

    print("-" * 40)
    print("Model total param counts:")
    print(f"{total:>10,} total\n{trainable:>10,} trainable")
    return trainable


def save_ckpt(tag, epoch, model, optimizer, lr_scheduler, losses, test_losses, lrs):
    """Write a training checkpoint to ``{outdir}/{tag}.pth`` from the main process.

    Relies on the module-level ``accelerator`` and ``outdir``. Processes other
    than the main one return without writing or printing, so the "saved"
    message appears only where a file was actually written.

    Args:
        tag: file stem of the checkpoint.
        epoch: epoch index stored under ``'epoch'``.
        model: model, possibly wrapped by accelerate; it is unwrapped before
            its state dict is taken.
        optimizer: optimizer whose state dict is stored.
        lr_scheduler: scheduler whose state dict is stored.
        losses: stored under ``'train_losses'``.
        test_losses: stored under ``'test_losses'``.
        lrs: stored under ``'lrs'``.
    """
    if not accelerator.is_main_process:
        return

    ckpt_path = outdir+f'/{tag}.pth'
    unwrapped_model = accelerator.unwrap_model(model)
    torch.save({
        'epoch': epoch,
        'model_state_dict': unwrapped_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'train_losses': losses,
        'test_losses': test_losses,
        'lrs': lrs,
        }, ckpt_path)
    print(f"---saved {outdir}/{tag} ckpt!---")


def prepare_data(args):
    """Build the CC2017 train and test dataloaders.

    Training fMRI comes from the subjects in 1..3 other than ``args.subj``;
    test fMRI comes from ``args.subj`` and is averaged over dim 1. All tensors
    are loaded onto the CPU from paths under ``args.root_dir``. Relies on the
    module-level ``num_devices``.

    Args:
        args: namespace providing ``batch_size``, ``subj`` and ``root_dir``.

    Returns:
        tuple: ``(num_iterations_per_epoch, train_dl, test_dl)``. The iteration
        count is derived from the hard-coded 4320 samples per epoch rather than
        from ``len(train_dl)``.
    """
    num_samples_per_epoch = (4320) // num_devices
    num_iterations_per_epoch = num_samples_per_epoch // (args.batch_size)
    print("batch_size =", args.batch_size, "num_iterations_per_epoch =", num_iterations_per_epoch, "num_samples_per_epoch =",
          num_samples_per_epoch)

    train_subj_list = [i for i in range(1, 4) if i != args.subj]
    test_subj_list = [args.subj]

    # Only the first `seq_len` entries along dim 1 of the training fMRI are kept.
    seq_len = 1

    voxel_train_list = [
        torch.load(f'{args.root_dir}/origin_data/fmri_vc_new/subject{subj}_train_fmri_vc.pt', map_location='cpu')[:,:seq_len, :, :]
        for subj in train_subj_list
    ]
    voxel_test_list = []
    for subj in test_subj_list:
        voxel_test = torch.load(f'{args.root_dir}/origin_data/fmri_vc_new/subject{subj}_test_fmri_vc.pt', map_location='cpu')
        # Collapse dim 1 to its mean but keep it as a singleton dimension.
        voxel_test = torch.mean(voxel_test, dim=1).unsqueeze(1)
        voxel_test_list.append(voxel_test)

    train_images = torch.load(f'{args.root_dir}/GT_train_3fps.pt', map_location='cpu')
    test_images = torch.load(f'{args.root_dir}/GT_test_3fps.pt', map_location='cpu')
    train_text = torch.load(f'{args.root_dir}/qwen_annotation/GT_train_caption_qwen.pt', map_location='cpu')
    train_text_emb = torch.load(f'{args.root_dir}/qwen_annotation/GT_train_caption_qwen_emb.pt', map_location='cpu')
    test_text = torch.load(f'{args.root_dir}/qwen_annotation/GT_test_caption_qwen.pt', map_location='cpu')
    test_text_emb = torch.load(f'{args.root_dir}/qwen_annotation/GT_test_caption_qwen_emb.pt', map_location='cpu')

    # Context managers close the JSON files deterministically.
    with open(f'{args.root_dir}/masks/key_objects_info_train.json') as fh:
        key_objects_categories = json.load(fh)
    key_objects_masks = torch.load(f'{args.root_dir}/masks/key_objects_masks_train.pt', map_location='cpu')

    with open(f'{args.root_dir}/qwen_annotation/qwen_train_caption_tag_category_id.json') as fh:
        cls_id_json = json.load(fh)
    with open(f'{args.root_dir}/qwen_annotation/qwen_test_caption_tag_category_id.json') as fh:
        test_cls_id_json = json.load(fh)

    train_dataset = CC2017_Dataset(voxel_train_list, train_images, train_text_emb, train_text, mask=key_objects_masks,
                                   cls_id=cls_id_json, key_obj_cls=key_objects_categories, is_train=True)
    test_dataset = CC2017_Dataset(voxel_test_list, test_images, test_text_emb, test_text, cls_id=test_cls_id_json, is_val=True)

    train_dl = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
    )
    # NOTE(review): the test loader is shuffled here, whereas
    # prepare_data_pretrain uses shuffle=False — confirm this is intended.
    test_dl = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=100,
        shuffle=True,
        drop_last=False,
    )

    return num_iterations_per_epoch, train_dl, test_dl



def prepare_data_pretrain(args):
    """Build the dataloaders used with the DIR_GOD image dataset.

    The training set is a ``DIR_GOD_ImageDataset`` over subjects 1..8 (and NSD
    subjects 1..7), built with phase ``'pretrain'`` when ``args.pretrain`` is
    set and ``'train'`` otherwise. The test set is the CC2017 test split of
    ``args.subj``, with the fMRI averaged over dim 1.

    Args:
        args: namespace providing ``subj``, ``pretrain``, ``batch_size`` and
            ``root_dir``.

    Returns:
        tuple: ``(train_dl, test_dl)``.
    """
    # prepare data for pretraining
    train_subs = np.arange(1, 9)
    train_nsd_subs = np.arange(1, 8)

    val_subs = [args.subj]

    phase = 'pretrain' if args.pretrain else 'train'
    train_dataset = DIR_GOD_ImageDataset(subjs=train_subs, subjs_nsd=train_nsd_subs, image_norm=True, phase=phase)
    train_dl = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=args.batch_size, shuffle=True, drop_last=True)

    voxel_test_list = []
    for subj in val_subs:
        # NOTE(review): this reads from 'fmri_vc' while prepare_data reads from
        # 'fmri_vc_new' — confirm which directory is intended.
        voxel_test = torch.load(f'{args.root_dir}/origin_data/fmri_vc/subject{subj}_test_fmri_vc.pt', map_location='cpu')
        # Collapse dim 1 to its mean but keep it as a singleton dimension.
        voxel_test = torch.mean(voxel_test, dim=1).unsqueeze(1)
        voxel_test_list.append(voxel_test)
    test_images = torch.load(f'{args.root_dir}/GT_test_3fps.pt', map_location='cpu')
    test_text = torch.load(f'{args.root_dir}/qwen_annotation/GT_test_caption_qwen.pt', map_location='cpu')
    test_text_emb = torch.load(f'{args.root_dir}/qwen_annotation/GT_test_caption_qwen_emb.pt', map_location='cpu')
    # Context manager closes the JSON file deterministically.
    with open(f'{args.root_dir}/qwen_annotation/qwen_test_caption_tag_category_id.json') as fh:
        test_cls_id_json = json.load(fh)

    test_dataset = CC2017_Dataset(voxel_test_list, test_images, test_text_emb, test_text, cls_id=test_cls_id_json, is_val=True)
    test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=0, drop_last=False)

    return train_dl, test_dl

def prepare_masks(return_shape=False):
    """Load the ROI label image and split it into low/mid/high boolean masks.

    The label image is read from ``NSD/vc_masks.npy``. ROI ids are the 1-based
    positions of the names in ``roi_names``; each mask marks the voxels whose
    label belongs to the corresponding group of ROI names.

    Args:
        return_shape: when truthy, return the number of selected voxels per
            mask instead of the masks themselves.

    Returns:
        ``(mask_low, mask_mid, mask_high)`` as ``torch.bool`` tensors shaped
        like the label image, or ``(n_low, n_mid, n_high)`` as Python ints when
        ``return_shape`` is truthy.
    """
    roi_names = [
        'V1','V2','V3','V3A','V3B','V3CD','V4','LO1','LO2','LO3','PIT','V4t',
        'V6','V6A','V7','V8','PH','FFC','IP0','MT','MST','FST','VVC','VMV1',
        'VMV2','VMV3','PHA1','PHA2','PHA3','TE2p','IPS1'
    ]
    roi_id = {n: i for i, n in enumerate(roi_names, 1)}

    # visual areas in different levels
    low_names = ['V1'] + ['V2','V3','V4']
    mid_names = ['V3A','V3B','V6','V6A','V7','IPS1'] + [
        'LO1', 'LO2', 'LO3', 'FST', 'MT', 'MST', 'V3CD', 'V4t',
        'PH', 'IP0'
    ]
    high_names = ['FFC','PIT','V8','VMV1','VMV2','VMV3','VVC'] + [
        'PHA1', 'PHA2', 'PHA3',
        'TE2p'
    ]

    low_idx = [roi_id[n] for n in low_names]
    mid_idx = [roi_id[n] for n in mid_names]
    high_idx = [roi_id[n] for n in high_names]

    label_img = np.load("NSD/vc_masks.npy")

    if label_img.ndim == 3:
        # A 3-D array is reduced to labels by taking the arg-max over axis 0;
        # +1 shifts the result onto the 1-based ROI ids.
        # NOTE(review): a voxel whose entries are all equal (e.g. all zero)
        # gets arg-max 0 and therefore label 1 ('V1') — confirm the file has
        # no such background voxels.
        label_img = np.argmax(label_img, 0)
        label_img = label_img + 1

    # generate masks
    mask_low = torch.tensor(np.isin(label_img, low_idx), dtype=torch.bool)
    mask_mid = torch.tensor(np.isin(label_img, mid_idx), dtype=torch.bool)
    mask_high = torch.tensor(np.isin(label_img, high_idx), dtype=torch.bool)

    if return_shape:
        # Tensor.sum() counts in one vectorised call; the builtin sum() would
        # iterate over every voxel in Python.
        return int(mask_low.sum()), int(mask_mid.sum()), int(mask_high.sum())
    return mask_low, mask_mid, mask_high

def add_hook(clip_img_embedder,hook_layers = (14, 26, 44)):
    """Register forward hooks that cache intermediate visual-transformer tokens.

    The block list is taken from ``visual.blocks`` if present, otherwise from
    ``visual.transformer.resblocks``, otherwise from ``visual.transformer``
    itself. The i-th entry of ``hook_layers`` is cached under the i-th name of
    ("low", "mid", "high") in ``clip_img_embedder.mid_feats``; the hook handles
    are kept in ``clip_img_embedder._hooks``.

    Raises:
        RuntimeError: if ``visual`` has neither ``blocks`` nor ``transformer``.
    """
    visual = clip_img_embedder.model.visual

    if hasattr(visual, "blocks"):
        block_list = visual.blocks
    elif hasattr(visual, "transformer"):
        transformer = visual.transformer
        # Use the resblocks container when it exists, else index the
        # transformer object directly.
        block_list = transformer.resblocks if hasattr(transformer, "resblocks") else transformer
    else:
        raise RuntimeError("Did not find visual transformer layer list")

    clip_img_embedder.mid_feats = {}
    clip_img_embedder._hooks = []
    level_names = ["low", "mid", "high"]

    def _make_recorder(level):
        def _record(module, _inputs, output):
            # Tuple outputs carry the tokens in position 1.
            hidden = output[1] if isinstance(output, tuple) else output
            # Drop the first entry along dim 0, swap dims 0 and 1, and detach
            # from the autograd graph.
            clip_img_embedder.mid_feats[level] = hidden[1:].permute(1, 0, 2).detach()
        return _record

    for position, layer_index in enumerate(hook_layers):
        handle = block_list[layer_index].register_forward_hook(_make_recorder(level_names[position]))
        clip_img_embedder._hooks.append(handle)



def _build_clip_txt_embedder(args):
    """Create the ViT-bigG-14 ``FrozenOpenCLIPEmbedder2`` and move it to ``device``."""
    clip_txt_embedder = FrozenOpenCLIPEmbedder2(
        arch="ViT-bigG-14",
        version="laion2b_s39b_b160k",
        layer="last",
        legacy=False,
        always_return_pooled=True,
        cache_dir=args.weights_dir
    )
    clip_txt_embedder.to(device)
    return clip_txt_embedder


def prepare_models(args):
    """Build the ``Neurons`` model and the helper networks for the chosen stage.

    The stage is selected by two flags on ``args``:

    * ``neurons_decoupler``: backbone, distribution head, fusion modules and
      diffusion prior are created, ``brain_model.pth`` is loaded, the decoder
      heads are added, and only the modules in ``modules_to_train`` remain
      trainable. A text embedder and the VAE are also built.
    * ``pretrain`` (decoupler off): backbone, an 8-class distribution head and
      ``clipproj`` are created and initialised from ``last.pth`` and
      ``coco_tokens_avg_proj.pth``; everything is trainable.
    * neither: backbone, a 2-class distribution head and ``clipproj`` are
      created and initialised from ``brain_model_pretrain.pth`` with the
      ``distribution_head`` weights left out; everything is trainable.

    Relies on the module-level ``device``.

    Args:
        args: namespace providing ``weights_dir``, ``pretrain``,
            ``neurons_decoupler``, ``n_frames``, ``pretrained_model_path``,
            ``exp_dir`` and ``root_dir``.

    Returns:
        tuple: ``(model, clip_img_embedder, clip_txt_embedder, vae)``;
        ``clip_txt_embedder`` and ``vae`` are ``None`` when the stage does not
        build them.
    """
    clip_img_embedder = FrozenOpenCLIPImageEmbedder(
        arch="ViT-bigG-14",
        version="laion2b_s39b_b160k",
        output_tokens=True,
        only_tokens=True,
        cache_dir=args.weights_dir
    )
    # Cache intermediate visual tokens in clip_img_embedder.mid_feats.
    add_hook(clip_img_embedder)
    clip_img_embedder.to(device)

    clip_txt_embedder = None
    if args.pretrain:
        clip_txt_embedder = _build_clip_txt_embedder(args)

    vae = None

    clip_seq_dim = 256
    clip_emb_dim = 1664
    clip_txt_emb_dim = 1280

    model = Neurons()

    if args.neurons_decoupler:

        low_shape, mid_shape, high_shape = prepare_masks(return_shape=True)
        model.backbone = fMRIBackbone(
                            dim = 1024,
                            vision_dim = clip_emb_dim,
                            clip_txt_emb_dim = clip_txt_emb_dim,
                            emb_dropout = 0.20
                        )
        model.distribution_head = RedistributionHead(domain_classes=2)

        # One fusion module per voxel group; voxel_len is that group's voxel count.
        model.fusion_low = Fusion(voxel_len=low_shape)
        model.fusion_high = Fusion(voxel_len=high_shape)
        model.fusion_motion = Fusion(voxel_len=mid_shape)

        # setup diffusion prior network
        out_dim = clip_emb_dim
        depth = args.n_frames
        dim_head = 52
        heads = clip_emb_dim//52 # heads * dim_head = clip_emb_dim
        timesteps = 100

        prior_network = PriorNetwork(
                dim=out_dim,
                depth=depth,
                dim_head=dim_head,
                heads=heads,
                causal=False,
                num_tokens = clip_seq_dim,
                learned_query_mode="pos_emb",
            )
        model.diffusion_prior = BrainDiffusionPrior(
            net=prior_network,
            image_embed_dim=out_dim,
            condition_on_text_encodings=False,
            timesteps=timesteps,
            cond_drop_prob=0.2,
            image_embed_scale=None,
        )

        # Reuse the embedder built for args.pretrain instead of constructing a
        # second, identically configured copy.
        if clip_txt_embedder is None:
            clip_txt_embedder = _build_clip_txt_embedder(args)

        vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, cache_dir=args.weights_dir,
                                            subfolder="vae").to(device)

        print(f"\033[92m autoenc loaded \033[0m")

        vae.eval()
        vae.requires_grad_(False)

        # load pretrained weights
        checkpoint = torch.load(f'{args.exp_dir}/checkpoints/brain_model.pth', map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        del checkpoint
        # NOTE(review): clipproj and the decoder heads are attached after
        # load_state_dict, so matching keys in brain_model.pth are not
        # restored into them — confirm this is intended.
        model.clipproj = CLIPProj()

        model.text_seg_dec = TextDrivenDecoder(clip_emb_dim, clip_txt_emb_dim)
        model.text_dec = TextDecoder(clip_txt_emb_dim)
        model.motion_proj = MotionProj(n_frames=args.n_frames, clip_size=clip_emb_dim)
        model.classifier = MultiLabelClassifier(in_channel_img=clip_emb_dim, in_channel_text=clip_txt_emb_dim, seq_len=clip_seq_dim, class_num=51)

    else:
        model.backbone = fMRIBackbone(
                            dim = 1024,
                            vision_dim = clip_emb_dim,
                            clip_txt_emb_dim = clip_txt_emb_dim,
                            emb_dropout = 0.1
                        )

        if args.pretrain:
            model.distribution_head = RedistributionHead(domain_classes=8)
        else:
            model.distribution_head = RedistributionHead(domain_classes=2)
        model.clipproj = CLIPProj()
        if args.pretrain:
            print("---resuming from backbone.pth ckpt---")
            # You can choose to load the pre-trained backbone from MindEye2, which will accelerate your neuroclips' convergence.
            checkpoint = torch.load(f'{args.weights_dir}/last.pth', map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
            del checkpoint

            checkpoint = torch.load(f'{args.root_dir}/coco_tokens_avg_proj.pth')
            model.clipproj.load_state_dict(checkpoint)
            del checkpoint
        else:
            print("---resuming from brain_model_pretrain.pth ckpt---")
            checkpoint = torch.load(f'{args.exp_dir}/checkpoints/brain_model_pretrain.pth', map_location='cpu')
            state_dict = checkpoint['model_state_dict']
            # discard distribution_head in checkpoint: the head built above has
            # 2 domain classes while the pretraining stage builds one with 8.
            keys_to_remove = [k for k in state_dict if k.startswith('distribution_head')]
            for k in keys_to_remove:
                print(f"Skipping {k} due to mismatch or exclusion.")
                state_dict.pop(k)
            # Apply the remaining weights; strict=False tolerates the keys
            # removed above. Without this call the checkpoint would be read
            # and then thrown away.
            model.load_state_dict(state_dict, strict=False)
            del state_dict
            del checkpoint

    if args.neurons_decoupler:
        # freeze all parameters except the specified modules
        for param in model.parameters():
            param.requires_grad_(False)

        # only train the specified modules
        modules_to_train = [
            model.diffusion_prior,
            model.text_dec,
            model.text_seg_dec,
            model.motion_proj,
            model.classifier,
            model.fusion_low,
            model.fusion_high,
            model.fusion_motion,
            model.clipproj,
        ]

        for module in modules_to_train:
            for param in module.parameters():
                param.requires_grad_(True)
    else:
        for param in model.parameters():
            param.requires_grad_(True)

    utils.count_params(model)
    return model, clip_img_embedder, clip_txt_embedder, vae


def trainable_modules_check(is_main_process, model):
    """Print, on the main process only, whether each parameter is trainable.

    After a three-line green banner, every entry of
    ``model.named_parameters()`` is listed as ``Trainable`` (red) or
    ``Frozen`` (blue) according to its ``requires_grad`` flag. Nothing is
    printed when ``is_main_process`` is falsy.
    """
    if not is_main_process:
        return

    rule = "\033[92m================================== \033[0m"
    print(rule)
    print("\033[92m Checking ... \033[0m")
    print(rule)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"\033[91m Trainable: {name} \033[0m")
        else:
            print(f"\033[94m Frozen: {name} \033[0m")

def get_video_targets(video_tensor, clip_img_embedder):
    """Embed every frame of a batch of clips and regroup the result per clip.

    Args:
        video_tensor: 5-D tensor unpacked as ``(b, f, c, h, w)``.
        clip_img_embedder: callable applied to the ``(b * f, c, h, w)`` frame
            stack; it must return a 3-D tensor ``(b * f, N, C)``.

    Returns:
        Tensor of shape ``(b, f, N, C)``, computed under ``torch.no_grad()``.
    """
    n_clips, n_frames, channels, height, width = video_tensor.shape
    # Fold the frame axis into the batch axis so all frames go through the
    # embedder in one call.
    frame_stack = video_tensor.view(n_clips * n_frames, channels, height, width)
    with torch.no_grad():
        embedded = clip_img_embedder(frame_stack)
    _, n_tokens, feat_dim = embedded.shape
    return embedded.view(n_clips, n_frames, n_tokens, feat_dim)



def train(args):
    if args.pretrain:
        train_dl, test_dl = prepare_data_pretrain(args)
        num_iterations_per_epoch = len(train_dl)
    else:
        _, train_dl, test_dl = prepare_data(args)
        num_iterations_per_epoch = len(train_dl)
    model, clip_img_embedder, clip_txt_embedder, vae = prepare_models(args)


    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue 
        if any(nd in name for nd in ['bias', 'LayerNorm', 'layernorm', 'ln', 'embedding']):
            no_decay.append(param)
        else:
            decay.append(param)

    param_groups = [
        {'params': decay, 'weight_decay': 0.01},
        {'params': no_decay, 'weight_decay': 0.0}
    ]

    optimizer = torch.optim.AdamW(
        param_groups,
        lr=args.max_lr,
        betas=(0.9, 0.999)
    )

    if args.lr_scheduler_type == 'linear':
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            total_iters=int(np.floor(args.num_epochs*num_iterations_per_epoch)),
            last_epoch=-1
        )
    elif args.lr_scheduler_type == 'cycle':
        total_steps=int(np.floor(args.num_epochs*num_iterations_per_epoch))
        print("total_steps", total_steps)
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.max_lr,
            total_steps=total_steps,
            final_div_factor=1e4,
            div_factor=25,
            last_epoch=-1, pct_start=0.1
        )
    else:
        total_steps = int(np.floor(args.num_epochs * num_iterations_per_epoch))
        print("total_steps", total_steps)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=2, T_mult=2
        )

    epoch = 0
    losses, test_losses, lrs = [], [], []
    best_metric = 0
    loss_video = 0
    torch.cuda.empty_cache()
    train_dls = [train_dl]

    model, optimizer, train_dl, lr_scheduler = accelerator.prepare(model, optimizer, train_dl, lr_scheduler)

    DiceLoss = utils.DiceLoss().cuda()
    loss_ce = torch.nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
    loss_cls = nn.BCEWithLogitsLoss()
    l1 = nn.L1Loss()
    soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, args.num_epochs - int(args.mixup_pct * args.num_epochs))
    global_step = 0

    if num_devices > 1 and distributed:
        model = model.module


    trainable_modules_check(accelerator.is_main_process, model)



    if args.resume_from_ckpt is not None:
        checkpoint = torch.load(args.resume_from_ckpt)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        epoch = checkpoint['epoch'] + 1
        print(f"\033[92m ************ Load from checkpoint at epoch {epoch} \033[0m")
        del checkpoint



    for epoch in tqdm(range(epoch, args.num_epochs), disable=(local_rank!=0)):
        model.train()

        train_acc_text_gen = []
        test_acc_text_gen = []

        voxel_low_mask, voxel_mid_mask, voxel_high_mask = prepare_masks()
        
        for iter, batch in enumerate(tqdm(train_dl, disable=(local_rank!=0))):
            with torch.cuda.amp.autocast(dtype=data_type):
                optimizer.zero_grad()
                loss=0.

                if not args.pretrain:
                    voxel, video, text, cls_labels = batch['voxel'], batch['pixel_values'], batch['text'], batch['cls_label']
                    clip_tokens = batch['clip_tokens']

                    key_obj_masks = batch['key_obj_masks'].detach()
                    key_obj_cls = batch['key_obj_cls']
                    if not args.neurons_decoupler:
                        image = video[:, 2 + epoch % 2, :, :, :].float()
                        key_obj_mask = key_obj_masks[:, 2 + epoch % 2, :, :].float()
                        voxel = voxel
                    else:
                        image = video[:, 2, :, :, :].float()
                        key_obj_mask = key_obj_masks[:, 2, :, :].float()
                        voxel = voxel

                    voxel = voxel.to(device)
                    image = image.to(device)
                    video = video.to(device)
                    text = text.to(device)
                    key_obj_masks = key_obj_masks.to(device)
                    clip_tokens = clip_tokens.to(device)
                    cls_labels = cls_labels.to(device)

                    subj_lbl = list(range(0,2)) * voxel.shape[0]
                    subj_lbl = torch.tensor(subj_lbl, device=device).long()

                    B,S,T,H,W = voxel.shape
                    voxel = rearrange(voxel, "b s t h w -> (b s) t h w")
                    if not args.neurons_decoupler:
                        voxel, perm, betas, select = utils.mixco(voxel)

                    
                    voxel_low = voxel[:,:,voxel_low_mask.bool()]
                    voxel_mid = voxel[:,:,voxel_mid_mask.bool()]
                    voxel_high = voxel[:,:,voxel_high_mask.bool()]

                    clip_vision_target = clip_img_embedder(image)
                    clip_vision_low_target = clip_img_embedder.mid_feats['low']
                    clip_vision_target_norm = nn.functional.normalize(clip_vision_target.flatten(1), dim=-1)
                    clip_vision_low_target_norm = nn.functional.normalize(clip_vision_low_target.flatten(1), dim=-1)
                    clip_text_target_norm = nn.functional.normalize(text.flatten(1), dim=-1)
                
                else:
                    fMRIs = batch["fMRIs"]
                    text = batch["txt"]
                    image = batch["gt_image"]

                    fMRIs = fMRIs.to(device)

                    subj_lbl = list(range(0, 8)) * fMRIs.shape[0]
                    subj_lbl = torch.tensor(subj_lbl, device=device).long()
                    image = image.to(device)

                    B, C, S, H, W = fMRIs.shape
                    voxel = rearrange(fMRIs, 'b c s h w -> (b s) c h w')

                    voxel_low_mask, voxel_mid_mask, voxel_high_mask = prepare_masks()
                    voxel_low = voxel[:,:, voxel_low_mask.bool()]
                    voxel_mid = voxel[:,:, voxel_mid_mask.bool()]
                    voxel_high = voxel[:,:, voxel_high_mask.bool()]

                    clip_vision_target = clip_img_embedder(image)
                    clip_vision_low_target = clip_img_embedder.mid_feats['low']
                    clip_vision_target_norm = nn.functional.normalize(clip_vision_target.flatten(1), dim=-1)
                    clip_vision_low_target_norm = nn.functional.normalize(clip_vision_low_target.flatten(1), dim=-1)
                    _, clip_text_target = clip_txt_embedder(text)
                    clip_text_target_norm = nn.functional.normalize(clip_text_target.flatten(1), dim=-1)



                generic_brain_reps, individual_brain_reps, clip_vision_embeds_raw = model.backbone(voxel,istraining=True)
                clip_vision_embeds,pred_subj_cls = model.distribution_head(clip_vision_embeds_raw)
                clip_text_embeds = model.clipproj(clip_vision_embeds)


                assert not torch.any(torch.isnan(clip_vision_target))
                clip_vision_embeds_norm = nn.functional.normalize(clip_vision_embeds.flatten(1), dim=-1)
                clip_text_embeds_norm = nn.functional.normalize(clip_text_embeds.flatten(1), dim=-1)


                if args.pretrain:

                    '''============ Generic Embeds Consistency ============'''
                    generic_brain_reps = rearrange(generic_brain_reps, '(b s) n c -> s b n c', b=B)
                    loss_generic = mse_loss_for_consistent_embeddings(generic_brain_reps)


                    clip_vision_embeds_generic = rearrange(clip_vision_embeds, '(b s) n c -> s b n c', b=B)
                    clip_text_embeds_generic   = rearrange(clip_text_embeds, '(b s) c -> s b c', b=B)

                    loss_generic_vision = pairwise_distance_loss(clip_vision_embeds_generic)
                    loss_generic_text= pairwise_distance_loss(clip_text_embeds_generic)

                    '''============ Subject Classification ============'''
                    loss_subj_cls = F.cross_entropy(
                        pred_subj_cls,
                        subj_lbl.long()
                    )

                    '''============ Individual Brain Recon ============'''
                    individual_brain_reps = individual_brain_reps.view(B, S, -1)
                    loss_indiv_recon = F.mse_loss(
                        individual_brain_reps,
                        voxel.view(B, S, -1)
                    )

                    '''============ Vision Embeds Align ============'''
                    clip_vision_target_norm = repeat(clip_vision_target_norm, 'b n -> (b s) n', s=S)
                    loss_clip_vision = utils.mixco_nce(
                        clip_vision_embeds_norm,
                        clip_vision_target_norm,
                    )

                    '''============ Text Embeds Align ============'''
                    clip_text_target_norm = repeat(clip_text_target_norm, 'b n -> (b s) n', s=S)
                    loss_clip_txt = utils.mixco_nce(clip_text_embeds_norm, clip_text_target_norm)
                    utils.check_loss(loss_clip_txt)

                    '''============ Overall Loss ============'''
             
                    loss += loss_generic_vision + loss_clip_vision + \
                            loss_subj_cls + loss_indiv_recon + loss_generic

                elif not args.neurons_decoupler and not args.pretrain:


                    '''============ Generic Embeds Consistency ============'''
                    generic_brain_reps = rearrange(generic_brain_reps, '(b s) n c -> s b n c', b=B)
                    loss_generic = mse_loss_for_consistent_embeddings(generic_brain_reps)

                    clip_vision_embeds_generic = rearrange(clip_vision_embeds, '(b s) n c -> s b n c', b=B)
                    clip_text_embeds_generic   = rearrange(clip_text_embeds, '(b s) c -> s b c', b=B)

                    loss_generic_vision = pairwise_distance_loss(clip_vision_embeds_generic)
                    loss_generic_text= pairwise_distance_loss(clip_text_embeds_generic)

                    '''============ Subject Classification ============'''
                    loss_subj_cls = F.cross_entropy(
                        pred_subj_cls,
                        subj_lbl.long()
                    )

                    '''============ Individual Brain Recon ============'''
                    individual_brain_reps = individual_brain_reps.view(B, S, -1)
                    loss_indiv_recon = F.mse_loss(
                        individual_brain_reps,
                        voxel.view(B, S, -1)
                    )

                    '''============ Vision Embeds Align ============'''
                    clip_vision_target_norm = repeat(clip_vision_target_norm, 'b n -> (b s) n', s=S)
                    loss_clip_vision = utils.mixco_nce(
                        clip_vision_embeds_norm,
                        clip_vision_target_norm,
                    )

                    
                    '''============ Text Embeds Align ============'''
                    clip_text_target_norm = repeat(clip_text_target_norm, 'b n -> (b s) n', s=S)
                    loss_clip_txt = utils.mixco_nce(clip_text_embeds_norm, clip_text_target_norm)
                    utils.check_loss(loss_clip_txt)

                    '''============ Overall Loss ============'''
                    loss += loss_generic_vision + loss_clip_vision + \
                            loss_subj_cls + loss_indiv_recon + loss_generic


                elif args.neurons_decoupler:
                    M = S * T
                    # repeat
                    video = repeat(video, 'b f c h w -> (m b) f c h w', m=M)
                    text = repeat(text, 'b n -> (m b) n', m=M)
                    key_obj_masks = repeat(key_obj_masks, 'b f h w -> (m b) f h w', m=M)
                    key_obj_mask = repeat(key_obj_mask, 'b h w -> (m b) h w', m=M)
                    clip_tokens = repeat(clip_tokens, 'b n -> (m b) n', m=M)
                    cls_labels = repeat(cls_labels, 'b n -> (m b) n', m=M)
                    key_obj_cls = key_obj_cls* M

                    clip_vision_embeds = rearrange(clip_vision_embeds, '(b s) n c -> (s b) n c', b=B)

                    clip_video_target = get_video_targets(video, clip_img_embedder)

                    '''============ Prior Train ============'''
                    clip_vision_target = repeat(clip_vision_target, 'b n c -> (m b) n c', m=M)
                    loss_prior, prior_out = model.diffusion_prior(text_embed=clip_vision_embeds,
                                                                  image_embed=clip_vision_target)


                    prior_out_low = model.fusion_low(prior_out, voxel_low)
                    prior_out_high = model.fusion_high(prior_out, voxel_high)
                    prior_out_motion = model.fusion_motion(prior_out_low,voxel_mid)

                    '''============ Gen Motion Embeddings ============'''
                    motion_embeds = model.motion_proj(prior_out_motion)

                    '''============ Vision Embeds Align (Low) ============'''
                    clip_vision_low_target_norm = repeat(clip_vision_low_target_norm, 'b n -> (m b) n', m=M)
                    prior_out_low_norm = nn.functional.normalize(prior_out_low.flatten(1), dim=-1)
                    loss_clip_vision_low = utils.mixco_nce(
                        prior_out_low_norm,
                        clip_vision_low_target_norm,
                    )


                    '''============ Vision Embeds Align (High) ============'''
                    clip_vision_target_norm = nn.functional.normalize(clip_video_target.mean(1).flatten(1), dim=-1)


                    prior_out_high_norm = nn.functional.normalize(prior_out_high.flatten(1), dim=-1)
                    loss_clip_vision = utils.mixco_nce(
                        prior_out_high_norm,
                        clip_vision_target_norm,
                    )

                    '''============ Text Embeds Align ============'''
                    clip_text_target_norm = repeat(clip_text_target_norm, 'b n -> (m b) n', m=M)
                    pred_text_norm = nn.functional.normalize(model.clipproj(prior_out_high).flatten(1), dim=-1)
                    loss_clip_txt = utils.mixco_nce(
                        pred_text_norm, 
                        clip_text_target_norm)


                    '''============ Key Obj Seg ============'''
                    _, key_obj_text_embed = clip_txt_embedder(key_obj_cls)  # [B, D]

                    low_res_masks = model.text_seg_dec(prior_out_low, key_obj_text_embed.detach(), time=args.batch_size)

                    key_obj_mask = F.interpolate(key_obj_mask.unsqueeze(1), low_res_masks.shape[-2:], mode="nearest")

                    loss_key_obj_seg = DiceLoss(low_res_masks.float(), key_obj_mask.float())

                    '''============ Multi Label Classification ============'''
                    cls_pred = model.classifier(prior_out_high.mean(1))
                    loss_multi_cls = loss_cls(cls_pred.float(), cls_labels.float())


                    '''============ Scene Description ============'''


                    logits = model.text_dec(pred_text_norm.float(), clip_tokens)
                    logits = logits.logits[:, :-1]
                    clip_tokens = clip_tokens.flatten()
                    logits = logits.reshape(-1, logits.shape[-1])
                    loss_text_gen = loss_ce(logits, clip_tokens)
                    utils.check_loss(loss_text_gen)
                    acc_text_gen = ((logits.argmax(1) == clip_tokens) * (clip_tokens > 0)).sum() / (
                                clip_tokens > 0).sum().cpu()
                    train_acc_text_gen.append(acc_text_gen.cpu().numpy())


                    '''============ Blurry Video Recon ============'''
                    video_vae = video.reshape(len(video) * args.n_frames, 3, 224, 224)
                    voxel_enc = vae.encode(2 * video_vae - 1).latent_dist.mode() * 0.18215
                    vae_embeds = model.text_seg_dec(rearrange(motion_embeds, "b f n c -> (b f) n c"),
                                                    model.clipproj(motion_embeds.mean(1)),
                                                    time=args.batch_size * args.n_frames, is_seg=False)
                    vae_embeds = F.interpolate(vae_embeds, voxel_enc.shape[-2:], mode="nearest")
                    loss_recon_video = l1(vae_embeds, voxel_enc)


                    '''============ Progressive Learning ============'''
                    weights = get_loss_weights(args.num_epochs, epoch, iter, num_iterations_per_epoch)
                    loss = loss_prior * args.prior_scale + loss_clip_vision + loss_clip_txt \
                            + 0.7*loss_clip_vision_low \
                           + loss_key_obj_seg * weights[0] \
                           + loss_multi_cls * weights[1] \
                           + loss_text_gen * weights[2] \
                           + loss_recon_video * weights[3]


                utils.check_loss(loss)
                accelerator.backward(loss)
                optimizer.step()

                losses.append(loss.item())
                lrs.append(optimizer.param_groups[0]['lr'])

                if args.lr_scheduler_type is not None:
                    lr_scheduler.step()
                global_step += 1

                if args.use_wandb and accelerator.is_main_process:
                    wandb.log({"lr": lr_scheduler.get_last_lr()[0]}, step=global_step)
                    wandb.log({"loss": loss.item()}, step=global_step)
                    if not args.neurons_decoupler:
                        wandb.log({"loss_generic": loss_generic.item()}, step=global_step)
                        wandb.log({"loss_generic_vision": loss_generic_vision.item()}, step=global_step)
                        wandb.log({"loss_generic_text": loss_generic_text.item()}, step=global_step)
                        wandb.log({"loss_indiv_recon": loss_indiv_recon.item()}, step=global_step)
                        wandb.log({"loss_clip_vision": loss_clip_vision.item()}, step=global_step)
                        wandb.log({"loss_clip_txt": loss_clip_txt.item()}, step=global_step)
                        wandb.log({"loss_subj_cls": loss_subj_cls.item()}, step=global_step)
                    else:
                        wandb.log({"loss_clip_vision": loss_clip_vision.item()}, step=global_step)
                        wandb.log({"loss_clip_txt": loss_clip_txt.item()}, step=global_step)
                        wandb.log({"loss_clip_vision_low": loss_clip_vision_low.item()}, step=global_step)

                        wandb.log({"loss_prior": loss_prior.item()}, step=global_step)
                        wandb.log({"loss_key_obj_seg": loss_key_obj_seg.item()}, step=global_step)
                        wandb.log({"loss_text_gen": loss_text_gen.item()}, step=global_step)
                        wandb.log({"loss_recon_video": loss_recon_video.item()}, step=global_step)
                        wandb.log({"train_acc_text_gen": np.mean(train_acc_text_gen)}, step=global_step)

                        wandb.log({"weights_0": weights[0],
                                   "weights_1": weights[1],
                                   "weights_2": weights[2],
                                   "weights_3": weights[3],
                                   }, step=global_step)


        # ==================================================================================
        # Test begin
        # ==================================================================================
        model.eval()

        test_fwd_percent_correct = []
        test_bwd_percent_correct = []
        text_fwd_percent_correct = []

        if accelerator.is_main_process:
            with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
                for test_i, batch in enumerate(test_dl):
                    # if not args.pretrain:
                    test_voxel, test_video, test_text = batch['voxel'], batch['pixel_values'], batch['text']
                    test_clip_tokens = batch['clip_tokens']

                    test_voxel = test_voxel[:,0]
                    test_image = test_video[:,2,:,:,:].cpu()

                    test_voxel = test_voxel.to(device)
                    test_image = test_image.to(device)
                    test_text = test_text.to(device)
                    test_clip_tokens = test_clip_tokens.to(device)


                    clip_vision_target = clip_img_embedder(test_image.float())
                    voxel_low_mask, voxel_mid_mask, voxel_high_mask = prepare_masks()
                    voxel_low = test_voxel[:,:, voxel_low_mask.bool()]
                    voxel_mid = test_voxel[:,:, voxel_mid_mask.bool()]
                    voxel_high = test_voxel[:,:, voxel_high_mask.bool()]
                    voxel_all = torch.cat([voxel_low,voxel_mid, voxel_high], dim=2)

                    clip_vision_embeds_raw = model.backbone(test_voxel)
                    clip_vision_embeds,pred_subj_cls = model.distribution_head(clip_vision_embeds_raw)



                    clip_vision_embeds = clip_vision_embeds.to(device)

                    clip_vision_embeds_norm = nn.functional.normalize(clip_vision_embeds.flatten(1), dim=-1)
                    clip_vision_target_norm = nn.functional.normalize(clip_vision_target.flatten(1), dim=-1)


                    if not args.neurons_decoupler:
                        pred_text_norm = nn.functional.normalize(model.clipproj(clip_vision_embeds).flatten(1), dim=-1)
                    else:

                        _, prior_out = model.diffusion_prior(text_embed=clip_vision_embeds, image_embed=clip_vision_target)
                        
                        prior_out_low = model.fusion_low(prior_out, voxel_low)
                        prior_out_high = model.fusion_high(prior_out, voxel_high)
                        prior_out_motion = model.fusion_motion(prior_out_low, voxel_mid)

                        motion_embeds = model.motion_proj(prior_out_motion)

                        clip_vision_embeds_norm = nn.functional.normalize(prior_out_high.flatten(1), dim=-1)
                        pred_text_norm = nn.functional.normalize(model.clipproj(prior_out_high).flatten(1), dim=-1)

                        logits = model.text_dec(pred_text_norm.float(), test_clip_tokens)
                        logits = logits.logits[:, :-1]
                        test_clip_tokens = test_clip_tokens.flatten()
                        logits = logits.reshape(-1, logits.shape[-1])
                        acc_text_gen = ((logits.argmax(1) == test_clip_tokens) * (test_clip_tokens > 0)).sum() / (test_clip_tokens > 0).sum()
                        test_acc_text_gen.append(acc_text_gen.cpu().numpy())


                    target_text_norm = nn.functional.normalize(test_text.flatten(1), dim=-1)
                    labels = torch.arange(len(pred_text_norm)).to(pred_text_norm.device)
                    text_fwd_percent_correct.append(
                        utils.topk(utils.batchwise_cosine_similarity(pred_text_norm, target_text_norm), labels, k=5).item())

                    labels = torch.arange(len(clip_vision_embeds_norm)).to(clip_vision_embeds_norm.device)
                    test_fwd_percent_correct.append(utils.topk(utils.batchwise_cosine_similarity(clip_vision_embeds_norm, clip_vision_target_norm), labels, k=1).item())
                    test_bwd_percent_correct.append(utils.topk(utils.batchwise_cosine_similarity(clip_vision_target_norm, clip_vision_embeds_norm), labels, k=1).item())

                print(f'\033[92m Evaluating Epoch {epoch} ... \033[0m')
                print(f'\033[92m \ttest_fwd_percent_correct: {np.mean(test_fwd_percent_correct)} \033[0m')
                print(f'\033[92m \ttest_bwd_percent_correct: {np.mean(test_bwd_percent_correct)} \033[0m')
                print(f'\033[92m \ttext_fwd_percent_correct: {np.mean(text_fwd_percent_correct)} \033[0m')
                if args.neurons_decoupler:
                    print(f'\033[92m \ttest_acc_text_gen       : {np.mean(test_acc_text_gen)} \033[0m')
                if args.use_wandb:
                    wandb.log({"test_fwd_percent_correct": np.mean(test_fwd_percent_correct)}, step=global_step)
                    wandb.log({"test_bwd_percent_correct": np.mean(test_bwd_percent_correct)}, step=global_step)
                    wandb.log({"text_fwd_percent_correct": np.mean(text_fwd_percent_correct)}, step=global_step)
                    if args.neurons_decoupler:
                        wandb.log({"test_acc_text_gen": np.mean(test_acc_text_gen)}, step=global_step)

            if not args.neurons_decoupler:
                metric = np.mean(test_fwd_percent_correct) + np.mean(test_bwd_percent_correct) + np.mean(text_fwd_percent_correct)
            else:
                metric = np.mean(test_fwd_percent_correct) + np.mean(test_bwd_percent_correct) + np.mean(test_acc_text_gen)

            # Save model checkpoint and reconstruct
            if metric > best_metric:
                best_metric = metric
                best_epoch = epoch
                print(f"\033[92m New best test metric: {best_metric} \033[0m")
                if args.pretrain:
                    save_ckpt(f'brain_model_pretrain', epoch, model, optimizer, lr_scheduler, losses, test_losses, lrs)
                elif not args.neurons_decoupler:
                    save_ckpt(f'brain_model', epoch, model, optimizer, lr_scheduler, losses, test_losses, lrs)
                else:
                    save_ckpt(f'brain_model_prior', epoch, model, optimizer, lr_scheduler, losses, test_losses, lrs)

            else:
                print(f"\033[91m Current metric: {metric}, best metric loss is {best_metric} in Epoch {best_epoch} \033[0m")

        # wait for other GPUs to catch up if needed
        accelerator.wait_for_everyone()
        torch.cuda.empty_cache()
        gc.collect()

    if args.ckpt_saving:
        if args.pretrain:
            save_ckpt(f'brain_model_pretrain_last', epoch, model, optimizer, lr_scheduler, losses, test_losses, lrs)
        elif not args.neurons_decoupler:
            save_ckpt(f'brain_model_last', epoch, model, optimizer, lr_scheduler, losses, test_losses, lrs)
        else:
            save_ckpt(f'brain_model_prior_last', epoch, model, optimizer, lr_scheduler, losses, test_losses, lrs)
    print("\n===Finished!===\n")


if __name__ == "__main__":
    # Script entry point: set up the Accelerate runtime, parse CLI options,
    # prepare the checkpoint directory and (optionally) Weights & Biases, then
    # hand over to train(args).
    #
    # Everything here is intentionally kept at module scope: train() reads
    # several of these names as globals (e.g. `accelerator`, `device`,
    # `data_type`, and the rebound `print`).

    ### Multi-GPU config ###
    # NOTE(review): the value comes from the RANK environment variable but is
    # reported as the local rank; confirm LOCAL_RANK is not what is wanted on
    # multi-node launches.
    local_rank = int(os.getenv('RANK', 0))
    print("LOCAL RANK ", local_rank)

    # Autocast dtype; keep in sync with the Accelerator mixed_precision below.
    data_type = torch.float16  # change depending on your mixed_precision

    accelerator = Accelerator(split_batches=False, mixed_precision="fp16")

    print("PID of this process =", os.getpid())
    device = accelerator.device
    # device = 'cuda:0'
    print("device:", device)
    world_size = accelerator.state.num_processes
    distributed = not (accelerator.state.distributed_type == 'NO')

    # Count visible GPUs once; fall back to 1 on CPU-only or non-distributed
    # runs. The value is also reused as the worker count.
    num_devices = torch.cuda.device_count()
    if num_devices == 0 or not distributed:
        num_devices = 1
    num_workers = num_devices
    print(accelerator.state)

    print("distributed =", distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =",
          world_size, "data_type =", data_type)
    # From here on, print() is silenced on non-main processes.
    print = accelerator.print  # only print if local_rank=0

    parser = argparse.ArgumentParser(description="Model Training Configuration")
    parser.add_argument(
        "--model_name", type=str, default="testing",
        help="name of model, used for ckpt saving and wandb logging (if enabled)",
    )
    parser.add_argument(
        "--subj", type=int, default=1, choices=[1, 2, 3],
        help="Validate on which subject?",
    )
    parser.add_argument(
        "--pretrain", action=argparse.BooleanOptionalAction, default=False,
        help="whether to pretrain on DIR and GOD datasets (True) or train on CC2017 dataset (False)",
    )
    parser.add_argument(
        "--neurons_decoupler", action=argparse.BooleanOptionalAction, default=False,
        help="whether to train diffusion prior (True) or just rely on retrieval part of the pipeline (False)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=10,
        help="Batch size can be increased by 10x if only training retreival submodule and not diffusion prior",
    )
    parser.add_argument(
        "--mixup_pct", type=float, default=.33,
        help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
    )
    parser.add_argument(
        "--prior_scale", type=float, default=30,
        help="multiply diffusion prior loss by this",
    )
    parser.add_argument(
        "--num_epochs", type=int, default=150,
        help="number of epochs of training",
    )
    parser.add_argument("--n_blocks", type=int, default=4)
    # Frames per clip; train() uses it to unfold videos into single frames.
    parser.add_argument("--n_frames", type=int, default=6)
    parser.add_argument("--hidden_dim", type=int, default=4096)
    parser.add_argument(
        "--lr_scheduler_type", type=str, default='cycle', choices=['cycle', 'linear', 'cosine'],
    )
    parser.add_argument("--root_dir", type=str, default='./cc2017_dataset')
    parser.add_argument("--weights_dir", type=str, default='./pretrained_weights')
    # Checkpoints are written to <exp_dir>/checkpoints (created below).
    parser.add_argument("--exp_dir", type=str, default='./saved_weights_ours')
    # Gates the final "*_last" checkpoint written at the end of train().
    parser.add_argument("--ckpt_saving", action=argparse.BooleanOptionalAction, default=True)
    # Note the dashes: argparse exposes this as args.pretrained_model_path.
    parser.add_argument("--pretrained-model-path", type=str, default="runwayml/stable-diffusion-v1-5")
    parser.add_argument("--resume_from_ckpt", type=str, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_lr", type=float, default=3e-4)
    # Only the literal string "true" (any casing) enables wandb; every other
    # value parses to False.
    parser.add_argument(
        "--use_wandb",
        type=lambda x: x.lower() == "true",
        default=False,
    )
    args = parser.parse_args()

    # seed all random functions
    utils.seed_everything(args.seed)

    os.makedirs(f'{args.exp_dir}/checkpoints/', exist_ok=True)
    outdir = os.path.abspath(f'{args.exp_dir}/checkpoints')

    if args.use_wandb and accelerator.is_main_process:
        # The run-name prefix follows the training stage, using the same
        # precedence as the checkpoint naming in train(): pretrain first, then
        # the plain brain model, then the decoupler/prior stage.
        if args.pretrain:
            run_prefix = "brain_pretrain"
        elif not args.neurons_decoupler:
            run_prefix = "brain"
        else:
            run_prefix = "decoupler"
        # Text after the last 'exp_' in exp_dir (the whole path if absent).
        exp_tag = args.exp_dir.split('exp_')[-1]
        wandb.init(project="VCFlow", name=f"{run_prefix}--exp_{exp_tag}")

    train(args)