# We use iBot code [https://github.com/bytedance/ibot] as the base of this code.

import os
os.chdir(os.path.abspath(os.path.dirname(__file__)))


import argparse
import os
import sys
import datetime
import time
import math
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torchvision import transforms
from tensorboardX import SummaryWriter

import utils
import models
from models.head import CRISPHead
from loader import ImageFolderMask
from evaluation.unsupervised.unsup_cls import eval_pred

def get_args_parser():
    """Build the command-line parser for CRISP pre-training.

    The parser is created with ``add_help=False`` so that it can be used as a
    parent of another ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser exposing model, temperature,
        optimisation, multi-crop and miscellaneous options.
    """
    parser = argparse.ArgumentParser('CRISP', add_help=False)

    # Model parameters
    parser.add_argument('--arch', default='vit_small', type=str,
        choices=['vit_tiny', 'vit_small', 'vit_base', 'vit_large'], help="Name of architecture.")
    parser.add_argument('--patch_size', default=16, type=int, help="Size in pixels of input square patches.")
    parser.add_argument('--out_dim', default=8192, type=int, help="""Dimensionality of output for [CLS] token.""")
    parser.add_argument('--patch_out_dim', default=8192, type=int, help="""Dimensionality of output for patch/region tokens.""")
    parser.add_argument('--batch_size_per_gpu', default=8, type=int, help="Per-GPU batch-size")

    parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag, help="Whether or not to weight normalize the last layer of the head.")
    # The four options below are read by the training code (args.norm_in_head,
    # args.act_in_head, args.shared_head, args.shared_head_teacher) and must
    # therefore exist on the parsed namespace.
    # NOTE(review): the accepted values for norm/act are defined by CRISPHead,
    # which is not visible here; the defaults follow the iBOT code base this
    # script is derived from -- confirm against models/head.py.
    parser.add_argument('--norm_in_head', default=None, type=str,
        help="Normalization layer forwarded to the projection head as `norm` (Default: None).")
    parser.add_argument('--act_in_head', default='gelu', type=str,
        help="Activation forwarded to the projection head as `act` (Default: gelu).")
    # In this script the two flags below only select the width of the
    # patch-level centering buffer (out_dim when either is set, patch_out_dim
    # otherwise); they are not forwarded to the heads.
    parser.add_argument('--shared_head', default=False, type=utils.bool_flag,
        help="If set, the patch-level center uses out_dim instead of patch_out_dim (Default: False).")
    parser.add_argument('--shared_head_teacher', default=False, type=utils.bool_flag,
        help="If set, the patch-level center uses out_dim instead of patch_out_dim (Default: False).")
    parser.add_argument('--momentum_teacher', default=0.996, type=float, help="Base EMA parameter for teacher update. ")
    parser.add_argument('--use_masked_im_modeling', default=True, type=utils.bool_flag,
        help="Whether to use masked image modeling (mim) in backbone (Default: True)")
    parser.add_argument('--pred_ratio', default=0.7, type=float, nargs='+', help="""Ratio of partial prediction.
        If a list of ratio is specified, one of them will be randomly choosed for each patch.""")
    parser.add_argument('--pred_ratio_var', default=0, type=float, nargs='+', help="""Variance of partial prediction
        ratio. Length should be indentical to the length of pred_ratio. 0 for disabling. """)
    parser.add_argument('--pred_shape', default='rand', type=str, help="""Shape of partial prediction.""")
    parser.add_argument('--pred_start_epoch', default=0, type=int, help="Start epoch to perform masked image prediction.")
    # Help text fixed: the actual default is 1.0, not 2.0.
    parser.add_argument('--lambda1', default=1.0, type=float, help="""loss weight for [CLS] tokens (Default: 1.0)""")
    parser.add_argument('--lambda2', default=1.0, type=float, help="""loss weight for patch (Default: 1.0)""")
    parser.add_argument('--lambda3', default=1.0, type=float, help="""loss weight for region (Default: 1.0)""")

    # Temperature teacher parameters
    parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
        help="""Initial value for the teacher temperature: 0.04 works well in most cases.
        Try decreasing it if the training loss does not decrease.""")
    parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
        of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
        starting with the default value of 0.04 and increase this slightly if needed.""")
    parser.add_argument('--warmup_teacher_patch_temp', default=0.04, type=float, help="""See 
        `--warmup_teacher_temp`""")
    # Help text fixed: a fourth opening quote used to leak a stray '"' into the message.
    parser.add_argument('--teacher_patch_temp', default=0.07, type=float, help="""See `--teacher_temp`""")
    parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int,
        help='Number of warmup epochs for the teacher temperature (Default: 30).')

    # Training/Optimization parameters
    parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
        to use half precision for training. Improves training time and memory requirements,
        but can provoke instability and slight decay of performance. We recommend disabling
        mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
    parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
        weight decay. With ViT, a smaller value at the beginning of training works well.""")
    parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")
    parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
        gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
        help optimization for larger ViT architectures. 0 for disabling.""")

    parser.add_argument('--epochs', default=800, type=int, help='Number of epochs of training.')
    parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
        during which we keep the output layer fixed. Typically doing so during
        the first epoch helps training. Try increasing this value if the loss does not decrease.""")
    parser.add_argument("--lr", default=1e-5, type=float, help="""Learning rate at the end of
        linear warmup (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.""")
    parser.add_argument("--warmup_epochs", default=10, type=int,
        help="Number of epochs for the linear learning-rate warm up.")
    parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
        end of optimization. We use a cosine LR schedule with linear warmup.""")
    parser.add_argument('--optimizer', default='adamw', type=str,
        choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
    parser.add_argument('--load_from', default="../checkpoints/vit_small/pretrained/checkpoint.pth", help="""Path to load checkpoints to resume training.""")
    parser.add_argument('--drop_path', type=float, default=0.1, help="""Drop path rate for student network.""")

    # Multi-crop parameters
    parser.add_argument('--global_crops_number', type=int, default=2, help="""Number of global
        views to generate. Default is to use two global crops. """)
    parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.25, 1.),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
        recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
    parser.add_argument('--local_crops_number', type=int, default=2, help="""Number of small
        local views to generate. Set this parameter to 0 to disable multi-crop training.
        When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
    parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.25),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for small local view cropping of multi-crop.""")

    # Misc
    parser.add_argument('--data_path', default='', type=str, help='Please specify path to the ImageNet training data.')
    parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.')
    parser.add_argument('--saveckp_freq', default=10, type=int, help='Save checkpoint every x epochs.')
    parser.add_argument('--seed', default=0, type=int, help='Random seed.')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")

    # Reference-image / concept-mask parameters
    parser.add_argument("--ref_img_sz", default=224, type=int, help="size of the reference image.")
    parser.add_argument("--ref_avg_blks", default=4, type=int, help="number of M blocks.")
    parser.add_argument("--ref_before_norm", default=1, type=int, help="..")
    parser.add_argument("--sim_thresh", default=0.75, type=float, help="Similarity Threshold.")
    parser.add_argument("--cncpt_blks", default=1, type=int, help="Number of L blocks (aggregation module).")
    parser.add_argument('--wepochs', default=1, type=int, help='Number of epochs of warmup where backbone is being frozen.')
    return parser



def train_crisp(args):
    """Entry point: set up the process, then run the two training stages.

    Initialises distributed mode and the random seeds, prints the git
    revision and the sorted argument list, enables cuDNN autotuning, and
    finally runs ``train_CGSSL`` first in ``'warm'`` and then in ``'full'``
    mode.

    Args:
        args: parsed command-line namespace (see ``get_args_parser``).
    """
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)

    print("git:\n  {}\n".format(utils.get_sha()))
    arg_lines = ["%s: %s" % (name, str(value)) for name, value in sorted(vars(args).items())]
    print("\n".join(arg_lines))

    cudnn.benchmark = True

    # Stage order matters: the warm stage runs before the full stage.
    for stage in ('warm', 'full'):
        train_CGSSL(args, train_type=stage)
    
    
    
def train_CGSSL(args, train_type='full'):
    """Run one training stage of CRISP.

    Builds the data pipeline, the student/teacher networks, the loss, the
    optimizer and the schedules, restores any available checkpoints and then
    trains epoch by epoch, writing a checkpoint and a log line after each one.

    Args:
        args: parsed command-line namespace (see ``get_args_parser``).
        train_type: ``'warm'`` freezes the student backbone (except
            ``rgstr_tokens``) together with ``head.mlp`` and
            ``head.last_layer`` and trains for ``args.wepochs`` epochs past
            the epoch restored from ``args.load_from``; any other value
            trains everything up to ``args.epochs``.
    """

    # ============ preparing data  ============
    transform = DataAugmentationSMR(
        args.global_crops_scale,
        args.local_crops_scale,
        args.global_crops_number,
        args.local_crops_number,
        args.ref_img_sz)
    pred_size = args.patch_size
    dataset = ImageFolderMask(
        args.data_path,
        transform=transform,
        patch_size=pred_size,
        pred_ratio=args.pred_ratio,
        pred_ratio_var=args.pred_ratio_var,
        pred_aspect_ratio=(0.3, 1/0.3),
        pred_shape=args.pred_shape,
        pred_start_epoch=args.pred_start_epoch)

    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ============
    student = models.__dict__[args.arch](
        patch_size=args.patch_size,
        drop_path_rate=args.drop_path,
        return_all_tokens=True,
        masked_im_modeling=args.use_masked_im_modeling)
    teacher = models.__dict__[args.arch](patch_size=args.patch_size, return_all_tokens=True)
    embed_dim = student.embed_dim

    student = FullPipeline(
        student,
        CRISPHead(
            embed_dim,
            args.out_dim,
            patch_out_dim=args.patch_out_dim,
            norm=args.norm_in_head,
            act=args.act_in_head,
            norm_last_layer=args.norm_last_layer),
        cncpt_blks=args.cncpt_blks)
    teacher = FullPipeline(
        teacher,
        CRISPHead(
            embed_dim,
            args.out_dim,
            patch_out_dim=args.patch_out_dim,
            norm=args.norm_in_head,
            act=args.act_in_head),
        cncpt_blks=args.cncpt_blks)

    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    teacher_without_ddp = teacher

    if train_type == 'warm':  # freeze backbone
        for p in student.backbone.parameters():
            p.requires_grad = False
        # the register tokens stay trainable during the warm stage
        student.backbone.rgstr_tokens.requires_grad = True

        for p in student.head.mlp.parameters():
            p.requires_grad = False
        for p in student.head.last_layer.parameters():
            p.requires_grad = False

    # Freezing must happen before the DDP wrap so that DDP only tracks the
    # parameters that will actually receive gradients.
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict(), strict=False)
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    same_dim = args.shared_head or args.shared_head_teacher
    crisp_loss = CRISPLoss(
        args.out_dim,
        args.out_dim if same_dim else args.patch_out_dim,
        args.global_crops_number,
        args.local_crops_number,
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_patch_temp,
        args.teacher_patch_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
        lambda1=args.lambda1,
        lambda2=args.lambda2,
        lambda3=args.lambda3,
        mim_start_epoch=args.pred_start_epoch,
    ).cuda()

    # Tensorboard configuration (main process only)
    writer = None
    if utils.is_main_process():
        local_runs = os.path.join(args.output_dir, 'tf_logs')
        writer = SummaryWriter(logdir=local_runs)

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(params_groups)  # to use with convnet and large batches
    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear scaling rule
        args.min_lr,
        args.epochs, len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs, len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))

    print("Loss, optimizer and schedulers ready.")

    # ============ optionally initialise from a pretrained checkpoint ... ============
    start_epoch = 0
    to_restore = {"epoch": start_epoch}
    if args.load_from:
        # The loss state is looked up under the key 'ibot_loss' here.
        # NOTE(review): presumably because --load_from points at an iBOT
        # checkpoint -- confirm.
        utils.restart_from_checkpoint(
            args.load_from,
            run_variables=to_restore,
            student=student,
            teacher=teacher,
            optimizer=optimizer,
            fp16_scaler=fp16_scaler,
            ibot_loss=crisp_loss,
        )
        start_epoch = to_restore["epoch"]

        # Initialise the third output layer (and its norm, when present) from
        # the first one, for both networks.
        student.module.head.last_layer3.load_state_dict(student.module.head.last_layer.state_dict())
        if student.module.head.last_norm is not None:
            student.module.head.last_norm3.load_state_dict(student.module.head.last_norm.state_dict())

        teacher.head.last_layer3.load_state_dict(teacher.head.last_layer.state_dict())
        if teacher.head.last_norm is not None:
            teacher.head.last_norm3.load_state_dict(teacher.head.last_norm.state_dict())

        # norm_cls starts as a copy of the backbone's norm
        student.module.backbone.norm_cls.load_state_dict(student.module.backbone.norm.state_dict())
        teacher.backbone.norm_cls.load_state_dict(teacher.backbone.norm.state_dict())

        # Seed the group center with the patch center. Copy the values instead
        # of rebinding the buffer, so the two registered buffers never share
        # storage (an in-place update of one would silently change the other).
        crisp_loss.center_grp.data.copy_(crisp_loss.center2.data)

        # register tokens 4..9 start from the [CLS] token
        student.module.backbone.rgstr_tokens[:, 4:10].data.copy_(student.module.backbone.cls_token.data.repeat(1, 6, 1))
        teacher.backbone.rgstr_tokens[:, 4:10].data.copy_(teacher.backbone.cls_token.data.repeat(1, 6, 1))

    if train_type == 'warm':  # freeze shared and backbone
        crisp_loss.center.requires_grad = False

    to_restore = {"epoch": start_epoch}
    # remember the epoch of the pretrained checkpoint; the warm stage is
    # measured relative to it
    args.start_from = start_epoch

    # ============ resume from this run's own checkpoint, if any ============
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, 'checkpoint.pth'),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        crisp_loss=crisp_loss,
    )

    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting CRISP training!")

    train_to = (args.start_from + args.wepochs) if train_type == 'warm' else args.epochs
    try:
        for epoch in range(start_epoch, train_to):
            data_loader.sampler.set_epoch(epoch)
            data_loader.dataset.set_epoch(epoch)

            # ============ training one epoch of CRISP ... ============
            train_stats = train_one_epoch(student, teacher, teacher_without_ddp, crisp_loss,
                data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
                epoch, fp16_scaler, args)

            # ============ writing logs ... ============
            save_dict = {
                'student': student.state_dict(),
                'teacher': teacher.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1,
                'args': args,
                'crisp_loss': crisp_loss.state_dict(),
            }
            if fp16_scaler is not None:
                save_dict['fp16_scaler'] = fp16_scaler.state_dict()
            utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
            if args.saveckp_freq and (epoch % args.saveckp_freq == 0) and epoch:
                utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch}
            if utils.is_main_process():
                with (Path(args.output_dir) / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")
                    for k, v in train_stats.items():
                        writer.add_scalar(k, v, epoch)
    finally:
        # Flush and release the event file even if training aborts; this
        # function is called once per stage, each time with a fresh writer.
        if writer is not None:
            writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

def map_to_global(X1: torch.Tensor, crop_size, H_orig, W_orig, ref_img_sz, out_size: int = 224):
    """Resample a reference-image map onto the pixel grid of a crop.

    Every pixel ``(r, c)`` of the ``out_size x out_size`` crop is mapped back
    to the original image through the crop box, then rescaled into the
    coordinate frame of ``X1`` and read out with bilinear interpolation.
    Locations that fall outside ``X1`` yield zeros.

    Args:
        X1: tensor of shape ``(B, C, H_full, W_full)`` to sample from.
        crop_size: tuple ``(y0, x0, h2, w2)`` of per-sample tensors of shape
            ``(B,)``: top, left, height and width of the crop box in
            original-image pixels.
        H_orig: per-sample tensor ``(B,)``; it is divided into ``ref_img_sz``
            to rescale row coordinates into the reference frame.
        W_orig: same as ``H_orig`` for the column coordinates.
        ref_img_sz: side length of the reference image the original image
            was resized to.
        out_size: side length of the returned map (default 224, the value
            that used to be hard-coded).

    Returns:
        Tensor of shape ``(B, C, out_size, out_size)`` with the dtype of ``X1``.
    """
    device = X1.device
    y0, x0, h2, w2 = (t.to(device) for t in crop_size)
    H_orig = H_orig.to(device)
    W_orig = W_orig.to(device)

    B, _, H_full, W_full = X1.shape

    # 1) Pixel indices of the crop, broadcast to (B, out_size, out_size).
    coords = torch.arange(out_size, device=device, dtype=torch.float32)
    r_3_grid = coords.view(1, out_size, 1).expand(B, out_size, out_size)
    c_3_grid = coords.view(1, 1, out_size).expand(B, out_size, out_size)

    # 2) Crop pixel -> original image coordinates:
    #    R = y0 + r_3 * (h2 / out_size),  C = x0 + c_3 * (w2 / out_size)
    R = y0[:, None, None] + r_3_grid * (h2 / float(out_size))[:, None, None]
    C = x0[:, None, None] + c_3_grid * (w2 / float(out_size))[:, None, None]

    # 3) Original image -> reference image coordinates.
    r_1 = R * (ref_img_sz / H_orig[:, None, None])
    c_1 = C * (ref_img_sz / W_orig[:, None, None])

    # 4) Normalise to [-1, 1] for grid_sample with align_corners=False
    #    (half-pixel alignment): x_norm = 2 * ((x + 0.5) / width) - 1
    c_1_norm = 2.0 * ((c_1 + 0.5) / W_full) - 1.0
    r_1_norm = 2.0 * ((r_1 + 0.5) / H_full) - 1.0

    # 5) grid_sample expects (x, y) in the last dim and the same dtype as its
    #    input; crop metadata may arrive as int64 or float64, so cast explicitly.
    grid = torch.stack([c_1_norm, r_1_norm], dim=-1).to(dtype=X1.dtype)

    # 6) Bilinear read-out; out-of-range locations are zero padded.
    return F.grid_sample(
        X1,
        grid,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False,
    )



def _overlay_mask(image, token_mask, alpha=0.7):
    """Tint `image` (3, H, W) red wherever the 2-D patch-level `token_mask` is set."""
    _, height, width = image.shape
    pixel_mask = F.interpolate(
        token_mask.float()[None, None],  # [1, 1, Ht, Wt]
        size=(height, width),
        mode='nearest',
    )[0].bool()  # [1, H, W]
    tint = torch.tensor([1.0, 0.0, 0.0], device=image.device, dtype=image.dtype).view(3, 1, 1)
    return torch.where(pixel_mask, alpha * tint + (1 - alpha) * image, image)


def _mark_seed_token(image, token_index, grid_size, patch_size):
    """Paint a blue disc on `image` at the centre of patch `token_index` of a square grid."""
    _, height, width = image.shape
    row, col = divmod(int(token_index), grid_size)
    center_y = row * patch_size + patch_size // 2
    center_x = col * patch_size + patch_size // 2
    radius = patch_size // 2
    yy, xx = torch.meshgrid(
        torch.arange(height, device=image.device),
        torch.arange(width, device=image.device),
        indexing='ij')
    disc = ((yy - center_y) ** 2 + (xx - center_x) ** 2) <= radius ** 2
    dot = torch.tensor([0.0, 0.0, 1.0], device=image.device, dtype=image.dtype).view(3, 1, 1)
    return torch.where(disc.unsqueeze(0), dot, image)


def _save_concept_preview(path, images, binary_masks, indices, crop_masks,
                          ref_grid_sz, patch_size, n_samples):
    """Save a three-row image grid: reference masks, then the masks projected onto the two global crops."""
    ref_tiles = []
    for i in range(n_samples):
        tile = _overlay_mask(images[0][i], binary_masks[i].reshape(ref_grid_sz, ref_grid_sz))
        ref_tiles.append(_mark_seed_token(tile, indices[i], ref_grid_sz, patch_size))
    # bring the reference previews to the crop resolution so the rows can be concatenated
    rows = [F.interpolate(torch.stack(ref_tiles), size=tuple(images[1].shape[-2:]))]

    for crop_images, crop_mask in zip(images[1:3], crop_masks):
        grid = crop_mask.shape[-1]
        rows.append(torch.stack([
            _overlay_mask(crop_images[i], crop_mask[i].reshape(grid, grid))
            for i in range(n_samples)
        ]))

    torchvision.utils.save_image(torch.cat(rows, dim=0), path, nrow=n_samples, normalize=True)


def train_one_epoch(student, teacher, teacher_without_ddp, crisp_loss, data_loader,
                    optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch,
                    fp16_scaler, args):
    """Train the student for one epoch and EMA-update the teacher.

    Per batch: the teacher processes the reference image, a random token per
    sample is picked and every token whose cosine similarity to it exceeds
    ``args.sim_thresh`` forms a binary concept mask; that mask is projected
    onto the two global crops, both networks are run, the CRISP loss is
    computed and the student is optimised. On the main process the first
    batch of the epoch is also dumped as a preview image.

    Args:
        student: DDP-wrapped student network.
        teacher: teacher network used for the forward passes.
        teacher_without_ddp: teacher module whose parameters receive the EMA update.
        crisp_loss: loss module.
        data_loader: yields ``((images, crop_sz), labels, masks)``.
        optimizer: optimizer over the student parameters.
        lr_schedule, wd_schedule, momentum_schedule: per-iteration schedules,
            indexed by the global iteration.
        epoch: current epoch index.
        fp16_scaler: ``GradScaler`` or ``None`` (disables autocast).
        args: parsed command-line namespace.

    Returns:
        dict: global averages of all logged meters plus ``nmi``, ``ari``,
        ``fscore`` and ``adjacc``.
    """
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)

    # Parameters present (by name) in both networks take part in the EMA update.
    student_named = list(student.module.named_parameters())
    teacher_named = list(teacher_without_ddp.named_parameters())
    shared_names = {name for name, _ in student_named} & {name for name, _ in teacher_named}
    params_q = [param for name, param in student_named if name in shared_names]
    params_k = [param for name, param in teacher_named if name in shared_names]

    save_recon = os.path.join(args.output_dir, 'reconstruction_samples')
    Path(save_recon).mkdir(parents=True, exist_ok=True)
    plot_ = True  # only the first batch of the epoch is visualised

    pred_labels, real_labels = [], []
    for it, (images_crop_sz, labels, masks) in enumerate(metric_logger.log_every(data_loader, 100, header)):
        # update weight decay and learning rate according to their schedule
        it = len(data_loader) * epoch + it  # global training iteration
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        images, crop_sz = images_crop_sz

        # move images to gpu
        images = [im.cuda(non_blocking=True) for im in images]
        masks = [msk.cuda(non_blocking=True) for msk in masks]

        # Random horizontal flip of each of the two global crops; the same
        # flips are applied to the projected concept masks further down.
        flipping_flags = torch.rand(2) > 0.5
        if flipping_flags[0]:
            images[1] = torchvision.transforms.functional.hflip(images[1])
        if flipping_flags[1]:
            images[2] = torchvision.transforms.functional.hflip(images[2])

        with torch.cuda.amp.autocast(fp16_scaler is not None):

            # Reference Image
            _ref_output, ftrs = teacher(images[0], get_attention=args.ref_avg_blks)

            B, T, D = ftrs.shape

            # Step 1: Randomly select one token per batch element
            indices = torch.randint(0, T, (B,), dtype=torch.long, device=ftrs.device)  # [B]

            # Step 2: Gather the selected tokens
            selected = ftrs[torch.arange(B, device=ftrs.device), indices]  # [B, D]

            # Step 3: Normalize for cosine similarity
            x_norm = F.normalize(ftrs, dim=-1)
            selected_norm = F.normalize(selected, dim=-1).unsqueeze(1)

            # Step 4: Cosine similarity of every token to the selected one
            cos_sim = (x_norm * selected_norm).sum(dim=-1)

            binary_masks = (cos_sim > args.sim_thresh).float()

            # The token axis is laid out as a ref_grid_sz x ref_grid_sz patch
            # grid (this reshape requires T == ref_grid_sz ** 2).
            ref_grid_sz = args.ref_img_sz//args.patch_size
            global_clusters_masks = F.interpolate(binary_masks.float().reshape(B, ref_grid_sz, ref_grid_sz).unsqueeze(1), size=args.ref_img_sz)

            # Get correspondence between global and crops: project the mask
            # into each crop, then average-pool 16x16 pixel blocks down to a
            # 14x14 patch grid.
            outcome_crop1 = map_to_global(global_clusters_masks, crop_sz[1], crop_sz[0][2], crop_sz[0][3], args.ref_img_sz)
            outcome_crop1 = outcome_crop1.reshape(outcome_crop1.shape[0], outcome_crop1.shape[1], 14, 16, 14, 16).mean(dim=(3, 5))

            outcome_crop2 = map_to_global(global_clusters_masks.float(), crop_sz[2], crop_sz[0][2], crop_sz[0][3], args.ref_img_sz)
            outcome_crop2 = outcome_crop2.reshape(outcome_crop2.shape[0], outcome_crop2.shape[1], 14, 16, 14, 16).mean(dim=(3, 5))

            outcome_crop1 = outcome_crop1>0.2 # give a bit of slack to consider boundary
            outcome_crop2 = outcome_crop2>0.2 # give a bit of slack to consider boundary

            if flipping_flags[0]:
                outcome_crop1 = torchvision.transforms.functional.hflip(outcome_crop1)
            if flipping_flags[1]:
                outcome_crop2 = torchvision.transforms.functional.hflip(outcome_crop2)

            # get global views
            teacher_output, t_smr, _, _ = teacher(images[1:(1+args.global_crops_number)], concept_masks=torch.cat((outcome_crop1, outcome_crop2), dim=0))
            student_output, s_smr, disc_out, bal_loss = student(images[1:(1+args.global_crops_number)], mask=masks[:args.global_crops_number], concept_masks=torch.cat((outcome_crop1, outcome_crop2), dim=0), disc=False)

            # get local views (masked image modeling is switched off for them)
            student.module.backbone.masked_im_modeling = False
            student_local_cls = student(images[(args.global_crops_number+1):], lcl=True) if args.local_crops_number > 0 else None
            student.module.backbone.masked_im_modeling = args.use_masked_im_modeling

            # 1.0 for every crop whose projected concept mask covers more than two patches
            aug_masks = torch.cat((outcome_crop1, outcome_crop2), dim=0).squeeze(1).flatten(1)
            aug_masks = (aug_masks.sum(dim=1) > 2).float()

            all_loss = crisp_loss(student_output, teacher_output, student_local_cls, masks,
                                  s_smr, t_smr, aug_masks, epoch)

            loss = all_loss.pop('loss') + 0.1*bal_loss

        if plot_ and utils.is_main_process():
            plot_ = False
            print_out = save_recon + '/epoch_' + str(epoch).zfill(5) + '_it_' + str(it).zfill(5) + '.jpg'
            # `indices` address patch tokens directly (there is no CLS slot on
            # this axis), so the seed marker uses them without any offset.
            _save_concept_preview(
                print_out, images, binary_masks, indices,
                (outcome_crop1, outcome_crop2),
                ref_grid_sz, args.patch_size,
                min(8, args.batch_size_per_gpu))

        if not math.isfinite(loss.item()):
            # NOTE(review): `force=` is not a keyword of the builtin print;
            # presumably utils.init_distributed_mode replaces print with a
            # rank-aware version -- confirm.
            print("Loss is {}, stopping training".format(loss.item()), force=True)
            sys.exit(1)

        # log statistics: agreement between the teacher's arg-max on the first
        # global crop and the student's arg-max on the second one
        probs1 = teacher_output[0].chunk(args.global_crops_number)
        probs2 = student_output[0].chunk(args.global_crops_number)
        pred1 = utils.concat_all_gather(probs1[0].max(dim=1)[1])
        pred2 = utils.concat_all_gather(probs2[1].max(dim=1)[1])
        acc = (pred1 == pred2).sum() / pred1.size(0)
        pred_labels.append(pred1)
        real_labels.append(utils.concat_all_gather(labels.to(pred1.device)))

        # student update
        optimizer.zero_grad()
        param_norms = None
        if fp16_scaler is None:
            loss.backward()
            if args.clip_grad:
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            if args.clip_grad:
                fp16_scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # EMA update for the teacher
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
            for param_q, param_k in zip(params_q, params_k):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # logging
        torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        for key, value in all_loss.items():
            metric_logger.update(**{key: value.item()})
        metric_logger.update(bal_loss=bal_loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
        metric_logger.update(acc=acc)

    pred_labels = torch.cat(pred_labels).cpu().detach().numpy()
    real_labels = torch.cat(real_labels).cpu().detach().numpy()
    nmi, ari, fscore, adjacc = eval_pred(real_labels, pred_labels, calc_acc=False)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("NMI: {}, ARI: {}, F: {}, ACC: {}".format(nmi, ari, fscore, adjacc))
    print("Averaged stats:", metric_logger)
    return_dict = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return_dict.update({"nmi": nmi, "ari": ari, "fscore": fscore, "adjacc": adjacc})
    return return_dict


class CRISPLoss(nn.Module):
    """Teacher/student distillation loss with [CLS], masked-patch and group terms.

    Three cross-entropy terms between sharpened, centered teacher
    distributions and student log-probabilities are computed:

    * ``cls``   - cross-view term on the [CLS] outputs, weighted by ``lambda1``;
    * ``patch`` - same-view term on patch outputs, averaged over the positions
      selected by ``student_mask``, weighted by ``lambda2``;
    * ``grp``   - cross-view term on the group outputs, weighted by the
      product of the two views' ``aug_masks``, weighted by ``lambda3``.

    Running centers of the teacher outputs are kept as buffers and updated at
    the end of every forward pass.

    Args:
        out_dim: last dimension of the [CLS] outputs (size of ``center``).
        patch_out_dim: last dimension of the patch and group outputs
            (size of ``center2`` and ``center_grp``).
        ngcrops: number of chunks the teacher / patch / group tensors are
            split into.
        nlcrops: number of additional student-only [CLS] chunks.
        warmup_teacher_temp, teacher_temp: start and final teacher temperature
            for the [CLS] term.
        warmup_teacher_temp2, teacher_temp2: start and final teacher
            temperature for the patch and group terms.
        warmup_teacher_temp_epochs: length of the linear temperature ramp.
        nepochs: total length of the temperature schedules.
        student_temp: temperature dividing every student output.
        center_momentum: EMA momentum of the [CLS] center.
        center_momentum2: EMA momentum of the patch and group centers.
        lambda1, lambda2, lambda3: weights of the cls / patch / grp terms.
        mim_start_epoch: when non-zero, the second temperature schedule is
            held at ``warmup_teacher_temp2`` for this many epochs before its
            ramp starts.
    """

    def __init__(self, out_dim, patch_out_dim, ngcrops, nlcrops, warmup_teacher_temp, 
                 teacher_temp, warmup_teacher_temp2, teacher_temp2, 
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1, 
                 center_momentum=0.9, center_momentum2=0.9,
                 lambda1=1.0, lambda2=1.0, mim_start_epoch=0, lambda3=1.0):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.center_momentum2 = center_momentum2
        self.ngcrops = ngcrops
        self.nlcrops = nlcrops
        self.ncrops = ngcrops + nlcrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        self.register_buffer("center2", torch.zeros(1, 1, patch_out_dim))
        self.register_buffer("center_grp", torch.zeros(1, 1, patch_out_dim))
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        # forward() scales the group term by this weight; it has to exist on
        # the instance, otherwise the first forward pass raises AttributeError.
        self.lambda3 = lambda3

        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))
        if mim_start_epoch == 0:
            self.teacher_temp2_schedule = np.concatenate((
                np.linspace(warmup_teacher_temp2,
                            teacher_temp2, warmup_teacher_temp_epochs),
                np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp2
            ))
        else:
            # hold the warm-up temperature until `mim_start_epoch`, then ramp
            self.teacher_temp2_schedule = np.concatenate((
                np.ones(mim_start_epoch) * warmup_teacher_temp2,
                np.linspace(warmup_teacher_temp2,
                            teacher_temp2, warmup_teacher_temp_epochs),
                np.ones(nepochs - warmup_teacher_temp_epochs - mim_start_epoch) * teacher_temp2
            ))

    def forward(self, student_output, teacher_output, student_local_cls, student_mask, s_grp, t_grp, aug_masks, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.

        Args:
            student_output: pair ``(student_cls, student_patch)``.
            teacher_output: pair ``(teacher_cls, teacher_patch)``.
            student_local_cls: extra student [CLS] outputs concatenated after
                ``student_cls``, or ``None``.
            student_mask: indexable per view; entry ``v`` is flattened over its
                last two dims and selects the patch positions that count for
                the patch term (assumed to match the patch grid - TODO confirm).
            s_grp, t_grp: student / teacher group outputs.
            aug_masks: per-view weights for the group term; split into
                ``ngcrops`` chunks like the group outputs.
            epoch: index into the teacher temperature schedules.

        Returns:
            dict with the weighted terms ``cls``, ``patch``, ``grp`` and their
            sum under ``loss``.
        """
        student_cls, student_patch = student_output
        teacher_cls, teacher_patch = teacher_output
        
        if student_local_cls is not None:
            student_cls = torch.cat([student_cls, student_local_cls])

        # [CLS] and patch for global patches
        student_cls = student_cls / self.student_temp
        student_cls_c = student_cls.chunk(self.ncrops)
        student_patch = student_patch / self.student_temp
        student_patch_c = student_patch.chunk(self.ngcrops)
        
        s_grp = s_grp / self.student_temp
        s_grp_c = s_grp.chunk(self.ngcrops)
        aug_masks = aug_masks.detach().chunk(self.ngcrops)
        
        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        temp2 = self.teacher_temp2_schedule[epoch]
        teacher_cls_c = F.softmax((teacher_cls - self.center) / temp, dim=-1)
        teacher_cls_c = teacher_cls_c.detach().chunk(self.ngcrops)
        teacher_patch_c = F.softmax((teacher_patch - self.center2) / temp2, dim=-1)
        teacher_patch_c = teacher_patch_c.detach().chunk(self.ngcrops)
        
        t_grp_c = F.softmax((t_grp - self.center_grp) / temp2, dim=-1)
        t_grp_c = t_grp_c.detach().chunk(self.ngcrops)

        total_loss1, n_loss_terms1 = 0, 0
        total_loss2, n_loss_terms2 = 0, 0
        total_loss3, n_loss_terms3 = 0, 0
        for q in range(len(teacher_cls_c)):
            for v in range(len(student_cls_c)):
                if v == q:
                    # same view: patch-level term, averaged over masked positions
                    loss2 = torch.sum(-teacher_patch_c[q] * F.log_softmax(student_patch_c[v], dim=-1), dim=-1)
                    mask = student_mask[v].flatten(-2, -1)
                    loss2 = torch.sum(loss2 * mask.float(), dim=-1) / mask.sum(dim=-1).clamp(min=1.0)
                    total_loss2 += loss2.mean()
                    n_loss_terms2 += 1
                else:
                    # different views: [CLS]-level term
                    loss1 = torch.sum(-teacher_cls_c[q] * F.log_softmax(student_cls_c[v], dim=-1), dim=-1)
                    total_loss1 += loss1.mean()
                    n_loss_terms1 += 1
                    
                    if v < 2:
                        # across-view group term; only the first two student
                        # views take part (hard-coded limit)
                        loss3 = torch.sum(-t_grp_c[q] * F.log_softmax(s_grp_c[v], dim=-1), dim=-1)
                        loss3 = loss3.squeeze() * aug_masks[q] * aug_masks[v]
                        total_loss3 += loss3.sum() / (aug_masks[q] * aug_masks[v]).sum().clamp(min=1.0)
                        n_loss_terms3 += 1

        total_loss1 = total_loss1 / n_loss_terms1 * self.lambda1
        total_loss2 = total_loss2 / n_loss_terms2 * self.lambda2
        total_loss3 = total_loss3 / n_loss_terms3 * self.lambda3
        total_loss = dict(cls=total_loss1, patch=total_loss2, grp=total_loss3, loss=total_loss1 + total_loss2 + total_loss3)
        self.update_center(teacher_cls, teacher_patch, t_grp)                  
        return total_loss

    @staticmethod
    def _global_mean(summed, n_local):
        """Turn a per-process sum into a mean over all processes.

        When ``torch.distributed`` is unavailable or no process group has been
        initialised, the local sum is simply divided by ``n_local``.
        """
        world_size = 1
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(summed)
            world_size = dist.get_world_size()
        return summed / (n_local * world_size)

    @torch.no_grad()
    def update_center(self, teacher_cls, teacher_patch, t_grp):
        """
        Update center used for teacher output.

        Each center is an exponential moving average of the batch mean of the
        corresponding teacher output (patch and group outputs are first
        averaged over dim 1).
        """
        cls_center = self._global_mean(
            torch.sum(teacher_cls, dim=0, keepdim=True), len(teacher_cls))
        self.center = self.center * self.center_momentum + cls_center * (1 - self.center_momentum)

        patch_center = self._global_mean(
            torch.sum(teacher_patch.mean(1), dim=0, keepdim=True), len(teacher_patch))
        self.center2 = self.center2 * self.center_momentum2 + patch_center * (1 - self.center_momentum2)
        
        grp_center = self._global_mean(
            torch.sum(t_grp.mean(1), dim=0, keepdim=True), len(t_grp))
        self.center_grp = self.center_grp * self.center_momentum2 + grp_center * (1 - self.center_momentum2)

class DataAugmentationSMR(object):
    """Multi-crop augmentation that also reports the global crop windows.

    Calling an instance on an image returns ``(crops, crops_sz)``:

    * ``crops`` holds, in order, the whole image resized to
      ``ref_img_sz x ref_img_sz``, ``global_crops_number`` random 224-pixel
      crops and ``local_crops_number`` random 96-pixel crops, each run
      through its own photometric pipeline and normalized;
    * ``crops_sz`` holds one ``(i, j, h, w)`` window for the whole image and
      for every global crop. Local crops get no entry.
    """

    def __init__(self, global_crops_scale, local_crops_scale, global_crops_number, local_crops_number, ref_img_sz=448):
        self.global_crops_number = global_crops_number
        self.local_crops_number = local_crops_number

        # photometric jitter shared by every pipeline below
        jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
        color_jitter = transforms.Compose([
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ])

        # tensor conversion + channel-wise normalization, also shared
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # whole image, resized to the reference resolution
        self.global_transfo = transforms.Compose([
            transforms.Resize((ref_img_sz, ref_img_sz), interpolation=Image.BICUBIC),
            color_jitter,
            normalize,
        ])

        # first global crop: cropping is kept separate from the photometric
        # part so the crop window can be returned to the caller
        self.global_transfo1_crop = RandomResizedCropWithParams(
            224, scale=global_crops_scale, interpolation=Image.BICUBIC)
        self.global_transfo1 = transforms.Compose([
            color_jitter,
            utils.GaussianBlur(1.0),
            normalize,
        ])

        # remaining global crops
        self.global_transfo2_crop = RandomResizedCropWithParams(
            224, scale=global_crops_scale, interpolation=Image.BICUBIC)
        self.global_transfo2 = transforms.Compose([
            color_jitter,
            utils.GaussianBlur(0.1),
            utils.Solarization(0.2),
            normalize,
        ])

        # local crops
        self.local_transfo = transforms.Compose([
            transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            color_jitter,
            utils.GaussianBlur(p=0.5),
            normalize,
        ])

    def _global_view(self, cropper, pipeline, image):
        """Crop ``image`` with ``cropper``, then apply ``pipeline``; return (view, window)."""
        patch, (top, left, height, width) = cropper(image)
        return pipeline(patch), (top, left, height, width)

    def __call__(self, image):
        # entry 0 is the full image; its window spans the whole picture
        crops = [self.global_transfo(image)]
        crops_sz = [(0, 0, image.size[1], image.size[0])]

        # one crop with the first recipe, the rest with the second one
        recipes = [(self.global_transfo1_crop, self.global_transfo1)]
        recipes += [(self.global_transfo2_crop, self.global_transfo2)] * (self.global_crops_number - 1)
        for cropper, pipeline in recipes:
            view, window = self._global_view(cropper, pipeline, image)
            crops.append(view)
            crops_sz.append(window)

        # local crops carry no window information
        crops.extend(self.local_transfo(image) for _ in range(self.local_crops_number))

        return crops, crops_sz

import torchvision
class RandomResizedCropWithParams(transforms.RandomResizedCrop):
    """``RandomResizedCrop`` variant that also reports the sampled window."""

    def __call__(self, img):
        """
        Sample a crop window, crop and resize ``img`` to ``self.size``.

        Returns:
            ``(cropped_img, (i, j, h, w))`` where ``(i, j, h, w)`` is the
            window produced by ``get_params``.
        """
        # sample the window with the parent's sampler
        top, left, height, width = self.get_params(img, self.scale, self.ratio)

        # crop + resize in one call, with the configured output size and interpolation
        resized_crop = torchvision.transforms.functional.resized_crop
        result = resized_crop(img, top, left, height, width, self.size, self.interpolation)

        return result, (top, left, height, width)


class ViT_Concept(nn.Module):
    """Summarise the tokens selected by a mask into a single token per sample.

    With ``depth == 0`` the module has no parameters and returns the masked
    average of the input tokens. With ``depth > 0`` a learnable concept token
    is prepended and run through ``depth`` transformer blocks in which every
    query may only attend to the concept token and to the tokens kept by the
    mask; the normalized concept token is returned.
    """

    def __init__(self, embed_dim, num_heads=12, depth=4, mlp_ratio=4., qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm):
        super().__init__()
        self.depth = depth

        # nothing to build for the parameter-free averaging mode
        if not depth > 0:
            return

        self.concept_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        block_cfg = dict(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer)
        blocks = []
        for _ in range(depth):
            blocks.append(models.vision_transformer.Block(**block_cfg))
        self.concept_blocks = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dim)

        utils.trunc_normal_(self.concept_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights / zero bias for Linear, identity init for LayerNorm."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            utils.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, attn_mask):
        """
        Args:
            x: token tensor with three dims (batch, tokens, channels).
            attn_mask: non-zero where a token is kept; flattened from dim 2
                on, so presumably (batch, 1, H, W) with H*W == tokens
                - TODO confirm against the caller.

        Returns:
            Tensor of shape (batch, 1, channels).
        """
        if self.depth == 0:
            # masked mean over the token axis; the clamp avoids a division by
            # zero when a sample keeps no token
            keep = attn_mask.bool().flatten(2).transpose(1, 2).squeeze(2).unsqueeze(-1)
            summed = (x * keep).sum(dim=1)
            count = keep.sum(dim=-2).clamp(min=1.0)
            return (summed / count).unsqueeze(1)

        batch, n_tokens, _ = x.shape

        # per-sample key visibility: the concept slot is always visible,
        # followed by the flattened patch mask -> (batch, n_tokens + 1)
        concept_visible = torch.ones_like(x[:, :1, 0])
        patch_visible = attn_mask.bool().flatten(2).transpose(1, 2).squeeze(2)
        visible = torch.cat((concept_visible, patch_visible), dim=1)

        # same key row for every query -> (batch, n_tokens + 1, n_tokens + 1),
        # turned into an additive mask with a large negative value on hidden keys
        hidden = visible.unsqueeze(1).expand(-1, n_tokens + 1, -1).logical_not()
        additive_mask = hidden.float() * -1e9

        # prepend the concept token to the input tokens
        tokens = torch.cat((self.concept_token.expand(batch, -1, -1), x), dim=1)

        for block in self.concept_blocks:
            tokens = block(tokens, attn_mask=additive_mask)

        tokens = self.norm(tokens)
        return tokens[:, :1]
    
    
class FullPipeline(nn.Module):
    """Wrap a backbone, a concept block and a projection head in one module.

    Args:
        backbone: feature extractor. Its ``fc`` and ``head`` attributes are
            replaced by identities; ``embed_dim`` and ``num_heads`` are read
            to size the concept block. It is called as
            ``backbone(x, n_interm_layers=..., **kwargs)`` and must return
            four values.
        head: projection head; an identity is used when ``None``.
        cncpt_blks: depth of the ``ViT_Concept`` block.
    """

    def __init__(self, backbone, head=None, cncpt_blks=4):
        super(FullPipeline, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = nn.Identity() if head is None else head

        self.concept_block = ViT_Concept(backbone.embed_dim, backbone.num_heads, depth=cncpt_blks)

        # NOTE: no discriminator is created here. forward(disc=True) only
        # works when a module has been attached as `self.discriminator`.

    def forward(self, x, mask=None, lcl=False, get_attention=0, concept_masks=None, disc=False, **kwargs):
        """Run the backbone once and return what the selected mode asks for.

        Args:
            x: a tensor or a list of tensors; lists are concatenated along
                dim 0 before the backbone call.
            mask: optional tensor / list of tensors, concatenated like ``x``
                and forwarded to the backbone as ``mask``.
            lcl: return only the head output of the first token.
            get_attention: forwarded to the backbone as ``n_interm_layers``;
                when positive, the mean of the returned hidden states is
                returned next to the head output.
            concept_masks: token mask handed to the concept block.
            disc: additionally apply ``self.discriminator`` to the first token.
            **kwargs: forwarded to the backbone.

        Returns:
            * ``get_attention > 0``: ``(head(out), mean_hidden_states)``
            * ``lcl``: ``head(out[:, 0])``
            * ``disc`` with ``concept_masks is None``: discriminator output
            * otherwise: ``((output_cls, output_patch), smr_rep, disc_ftrs, bal_loss)``

        Raises:
            AttributeError: ``disc=True`` but no discriminator is attached.
        """
        # convert to list
        if not isinstance(x, list):
            x = [x]
            mask = [mask] if mask is not None else None
 
        inp_x = torch.cat(x)

        if mask is not None:
            kwargs["mask"] = torch.cat(mask)

        _out, _out_before_norm, hidden_states, bal_loss = self.backbone(inp_x, n_interm_layers=get_attention, **kwargs)
        
        if get_attention > 0:
            return self.head(_out), torch.stack(hidden_states, dim=0).mean(dim=0)
        if lcl:
            return self.head(_out[:, 0])
        
        disc_ftrs = None
        if disc:
            discriminator = getattr(self, "discriminator", None)
            if discriminator is None:
                # fail with an actionable message instead of a bare
                # missing-attribute error from nn.Module
                raise AttributeError(
                    "FullPipeline.forward was called with disc=True, but no "
                    "`discriminator` module is attached to this FullPipeline.")
            disc_ftrs = discriminator(_out[:, 0])
            
            if concept_masks is None:
                return disc_ftrs
        
        embed_tokens = self.concept_block(_out_before_norm, concept_masks)

        output_cls, _output_patch, smr_rep = self.head(_out, embed_tokens)
 
        return (output_cls, _output_patch), smr_rep, disc_ftrs, bal_loss

if __name__ == '__main__':
    # Script entry point: extend the shared CRISP argument parser, make sure
    # the output directory exists, then start training.
    parser = argparse.ArgumentParser('CRISP', parents=[get_args_parser()])
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_crisp(args)
