#Acknowledgement: This repository is built using the timm library, the DeiT repository and the Dino repository.

# Per-channel mean and standard deviation triples. The same numeric values are
# passed as literals to transforms.Normalize inside DataAugmentationSiT; within
# the code visible in this file these two names are not referenced elsewhere.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

import argparse
import os
import sys
import datetime
import time
import math
import json
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision import models as torchvision_models
from numpy.random import randint
from datasets import prepare_datasets

import utils
import vision_transformer as vits
from vision_transformer import PatchHead, RECHead

# Sorted names of every lowercase, non-dunder, callable entry exported by
# torchvision.models.
torchvision_archs = sorted(
    arch_name
    for arch_name, arch_obj in torchvision_models.__dict__.items()
    if arch_name.islower() and not arch_name.startswith("__") and callable(arch_obj)
)

def get_args_parser():
    """Build the command-line parser for MCSSL pre-training.

    The parser is created with ``add_help=False`` so that it can be passed as
    a ``parents=[...]`` entry to another ``argparse.ArgumentParser`` (this is
    how the ``__main__`` section of this file uses it).

    Returns:
        argparse.ArgumentParser: parser holding the model, self-supervision,
        multi-crop, optimisation, dataset and miscellaneous options.
    """
    parser = argparse.ArgumentParser('MCSSL', add_help=False)

    # Model parameters
    parser.add_argument('--arch', default='vit_small', type=str, help='Transformer architecture')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch size')
    parser.add_argument('--img_size', default=224, type=int, help='Image size')

    # meta-parameters
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--epochs', default=100, type=int, help='Number of training epochs.')
    parser.add_argument('--out_dim_data', default=1024, type=int, help='Output of the patch classification head')
    parser.add_argument('--drop_perc', type=float, default=0.3, help='Drop percentage')

    # selfsupervised Tasks
    parser.add_argument('--applyReconstruction', default=True, type=utils.bool_flag, help='Train the network to reconstruct the images as well.')

    parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
        during which we keep the output layer fixed. Typically doing so during
        the first epoch helps training. Try increasing this value if the loss does not decrease.""")

    parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
        parameter for teacher update. The value is increased to 1 during training with cosine schedule.
        We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")

    # Temperature teacher parameters
    parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
        help="""Initial value for the teacher temperature: 0.04 works well in most cases. Try decreasing it if the training loss does not decrease.""")
    parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
        of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
        starting with the default value of 0.04 and increase this slightly if needed.""")
    # The help text previously advertised a default of 30 while the actual default is 0.
    parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int,
        help='Number of warmup epochs for the teacher temperature (Default: 0).')

    # Multi-crop parameters
    parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
        recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
    parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
        local views to generate. Set this parameter to 0 to disable multi-crop training.
        When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
    parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for small local view cropping of multi-crop.""")

    # Training/Optimization parameters
    parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
        to use half precision for training. Improves training time and memory requirements,
        but can provoke instability and slight decay of performance. We recommend disabling
        mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
    parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
        weight decay. With ViT, a smaller value at the beginning of training works well.""")
    parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")
    parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
        gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
        help optimization for larger ViT architectures. 0 for disabling.""")

    parser.add_argument("--lr", default=0.0005, type=float, help="""Learning rate at the end of
        linear warmup (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.""")
    parser.add_argument("--warmup_epochs", default=10, type=int,
        help="Number of epochs for the linear learning-rate warm up.")
    parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
        end of optimization. We use a cosine LR schedule with linear warmup.""")
    # NOTE(review): train_MCSSL in this file always constructs torch.optim.AdamW
    # and never reads args.optimizer, so 'sgd' / 'lars' currently have no effect.
    parser.add_argument('--optimizer', default='adamw', type=str,
        choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
    parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate")

    # Dataset
    parser.add_argument('--data_set', default='CIFAR10', type=str,
                        choices=['CIFAR10', 'CIFAR100', 'Cars', 'Flowers', 'ImageNet', 'VisualGenome500', 'PASCALVOC', 'MSCOCO'],
                        help='Name of the dataset.')
    parser.add_argument('--ImageNet_trainingFile', default='TrainFiles_10percent_shuffled.csv', type=str, help='path to the ImageNet data.')
    parser.add_argument('--data_location', default='path/to/data', type=str, help='path to the datasetfiles.')

    # Misc
    parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.')
    parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.')
    parser.add_argument('--seed', default=0, type=int, help='Random seed.')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""set up distributed training""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    return parser


def train_MCSSL(args):
    """Run the full MCSSL pre-training job described by ``args``.

    Steps performed, in order: distributed/seed initialisation through
    ``utils``, construction of the augmented dataset and a distributed data
    loader, construction of a student and a teacher network (backbone from
    ``vits`` plus patch and reconstruction heads), loss / optimizer / schedule
    setup, checkpoint restore, and the epoch loop that trains, checkpoints and
    appends one JSON line per epoch to ``log.txt`` in ``args.output_dir``.

    Args:
        args: parsed namespace produced by ``get_args_parser`` (attributes such
            as ``arch``, ``batch_size``, ``epochs``, ``output_dir`` are read
            here; ``args.gpu`` is also read and is not defined by the parser,
            so it is presumably set by ``utils.init_distributed_mode`` --
            TODO confirm).

    Returns:
        None. Results are written to disk and printed.
    """
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n  {}\n".format(utils.get_sha()))
    # dump every argument, sorted by name, for reproducibility
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = DataAugmentationSiT(args)    
    # second return value of build_dataset is discarded here
    dataset , _ = prepare_datasets.build_dataset(args, True, trnsfrm=transform)
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, batch_size=args.batch_size,
        num_workers=args.num_workers, pin_memory=True, drop_last=True)
    print(f"Data loaded: there are {len(dataset)} images.")


    # ============ building student and teacher networks ... ============
    # only the student is built with a drop-path rate
    student = vits.__dict__[args.arch](patch_size=args.patch_size, img_size=[args.img_size], drop_path_rate=args.drop_path_rate)
    teacher = vits.__dict__[args.arch](patch_size=args.patch_size, img_size=[args.img_size])
    embed_dim = student.embed_dim


    # create the full pipeline: backbone + patch-classification head + reconstruction head
    # (the reconstruction head is an Identity when reconstruction is disabled)
    RecHead_s, RecHead_t = nn.Identity(), nn.Identity()
    if args.applyReconstruction:
         RecHead_s, RecHead_t = RECHead(embed_dim), RECHead(embed_dim)
         
    student = FullpiplineSiT(student, PatchHead(embed_dim, args.out_dim_data), RecHead_s)
    teacher = FullpiplineSiT(teacher, PatchHead(embed_dim, args.out_dim_data), RecHead_t)
    
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # the teacher is never wrapped in DDP, so both names refer to the same module
    teacher_without_ddp = teacher
        
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
        
    print("Student and Teacher are built.")


    # ============ preparing loss ... ============
    clssf_loss = DATATOKENLoss(args.out_dim_data, args.warmup_teacher_temp, args.teacher_temp, 
        args.warmup_teacher_temp_epochs, args.epochs).cuda()
    recons_loss = torch.nn.L1Loss().cuda()
    
    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    # NOTE(review): args.optimizer is not consulted; AdamW is always used.
    optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs

    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    # peak LR is args.lr scaled by (global batch size / 256)
    lr_schedule = utils.cosine_scheduler(args.lr * (args.batch_size * utils.get_world_size()) / 256.,  
        args.min_lr, args.epochs, len(data_loader), warmup_epochs=args.warmup_epochs)
    wd_schedule = utils.cosine_scheduler(args.weight_decay,
        args.weight_decay_end, args.epochs, len(data_loader))
    
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1, args.epochs, len(data_loader))

    # ============ resume training ... ============
    # presumably restores the listed objects in place and overwrites
    # to_restore["epoch"] when the checkpoint file exists -- confirm in utils
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore, student=student,
        teacher=teacher, optimizer=optimizer,
        fp16_scaler=fp16_scaler, clssf_loss=clssf_loss)
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting training !")
    for epoch in range(start_epoch, args.epochs):
        # give the distributed sampler the epoch index (used for its shuffling)
        data_loader.sampler.set_epoch(epoch)

        # ================ training ... ================
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp, clssf_loss, recons_loss, data_loader, 
                                      optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch, fp16_scaler, args)

        # ============ writing logs ... ============
        # 'epoch' stores the NEXT epoch to run, so a resumed job continues after this one
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'args': args,
            'clssf_loss': clssf_loss.state_dict(),
        }
        
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
            
        # rolling checkpoint, overwritten after every epoch
        utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
        
        # additional numbered snapshot whenever epoch is a multiple of saveckp_freq
        # (this includes epoch 0); a saveckp_freq of 0 disables the snapshots
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch}
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
                
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


def train_one_epoch(student, teacher, teacher_without_ddp, clssf_loss, recons_loss, data_loader,
                    optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch, fp16_scaler, args):
    """Train the student for one epoch and EMA-update the teacher.

    For every batch the learning rate / weight decay are set from their
    per-iteration schedules, the corrupted crops are further distorted with
    ``distortImages``, the teacher sees the clean crops and the student the
    corrupted ones, and the loss is the data-token loss plus (optionally) the
    L1 reconstruction loss. On the first iteration of the epoch the main
    process also writes a grid of clean / corrupted / reconstructed images to
    ``<output_dir>/reconstruction_samples``.

    Args:
        student: DDP-wrapped student (``student.module`` is accessed).
        teacher: teacher network used for the forward pass.
        teacher_without_ddp: module whose parameters receive the EMA update.
        clssf_loss: callable ``(student_out, teacher_out, epoch) -> loss``.
        recons_loss: callable ``(prediction, target) -> loss``.
        data_loader: yields ``((clean_crops, corrupted_crops), _)`` batches.
        optimizer: optimizer whose first param group is the regularised one.
        lr_schedule, wd_schedule, momentum_schedule: sequences indexed by the
            global iteration ``len(data_loader) * epoch + batch_index``.
        epoch: index of the current epoch.
        fp16_scaler: ``GradScaler`` or ``None``.
        args: parsed command-line namespace.

    Returns:
        dict mapping each logged meter name to its ``global_avg``.
    """
    save_recon = os.path.join(args.output_dir, 'reconstruction_samples')
    Path(save_recon).mkdir(parents=True, exist_ok=True)
    bz = args.batch_size
    n_show = min(15, bz)  # number of images per row in the saved sample grid
    save_imgs = True  # only the first iteration of the epoch dumps a sample grid

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
    for batch_idx, ((clean_crops, corrupted_crops), _) in enumerate(metric_logger.log_every(data_loader, 100, header)):
        # update weight decay and learning rate according to their schedule
        it = len(data_loader) * epoch + batch_idx  # global training iteration
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        # move images to gpu
        clean_crops = [im.cuda(non_blocking=True) for im in clean_crops]

        # drop replace: patches of each corrupted image are swapped with
        # patches from another image of the same batch
        corrupted_crops = [im.cuda(non_blocking=True) for im in corrupted_crops]
        corrupted_crops = distortImages(corrupted_crops, drop_perc=args.drop_perc)

        # teacher and student forward passes + compute the losses
        # NOTE(review): autocast is hard-disabled here, so the forward pass runs
        # in full precision even when args.use_fp16 created a GradScaler --
        # confirm this is intentional.
        with torch.cuda.amp.autocast(enabled=False):
            t_datatokens_g, _, t_datatokens_l, _ = teacher(clean_crops)
            s_datatokens_g, s_recons_g, s_datatokens_l, s_recons_l = student(corrupted_crops)

            recloss = 0.
            if args.applyReconstruction:
                recloss = recons_loss(s_recons_g, torch.cat(clean_crops[0:2]))
                if args.local_crops_number > 0:
                    recloss += recons_loss(s_recons_l, torch.cat(clean_crops[2:]))

                if save_imgs and utils.is_main_process():
                    save_imgs = False
                    # validating: check the reconstructed images
                    print_out = save_recon + '/epoch_' + str(epoch).zfill(5) + '.jpg'
                    imagesToPrint = torch.cat([clean_crops[0][0:n_show].cpu(),
                                               corrupted_crops[0][0:n_show].cpu(),
                                               s_recons_g[0:n_show].cpu()], dim=0)
                    # NOTE(review): newer torchvision releases appear to name this
                    # keyword `value_range` instead of `range` -- confirm against
                    # the installed torchvision version.
                    torchvision.utils.save_image(imagesToPrint, print_out, nrow=n_show, normalize=True, range=(-1, 1))

            dtloss = clssf_loss(s_datatokens_g, t_datatokens_g, epoch)
            if args.local_crops_number > 0:
                dtloss += clssf_loss(s_datatokens_l, t_datatokens_l, epoch)
                dtloss /= 2.0

            loss = dtloss + recloss

        if not math.isfinite(loss.item()):
            # presumably utils.init_distributed_mode replaces the builtin print
            # with a variant accepting `force`; the plain builtin would raise
            # TypeError here -- verify in utils.
            print("Loss is {}, stopping training".format(loss.item()), force=True)
            sys.exit(1)

        # student update
        optimizer.zero_grad()
        param_norms = None
        if fp16_scaler is None:
            loss.backward()
            if args.clip_grad:
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            if args.clip_grad:
                fp16_scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # EMA update for the teacher
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
            for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # logging
        torch.cuda.synchronize()
        metric_logger.update(dtloss=dtloss.item())

        # recloss stays the plain float 0. when reconstruction is disabled
        recloss_value = recloss.item() if hasattr(recloss, 'item') else recloss
        metric_logger.update(recloss=recloss_value)

        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}



class DATATOKENLoss(nn.Module):
    """Cross-entropy between centred/sharpened teacher tokens and student tokens.

    The teacher output is centred with a running mean (``center`` buffer of
    shape ``(1, 1, out_dim)``), divided by a per-epoch temperature and passed
    through a softmax; the student output is divided by ``student_temp`` and
    passed through a log-softmax. Both outputs are reduced over their first
    two dimensions when the centre is updated, so they are expected to be
    3-D with ``out_dim`` as the last dimension.

    Args:
        out_dim: size of the last dimension of the head outputs.
        warmup_teacher_temp: teacher temperature at epoch 0.
        teacher_temp: teacher temperature after the warm-up.
        warmup_teacher_temp_epochs: number of epochs of linear warm-up.
        nepochs: total number of epochs (length of the temperature schedule).
        student_temp: temperature applied to the student output.
        center_momentum: EMA factor used when updating ``center``.
    """

    def __init__(self, out_dim, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1, center_momentum=0.9):
        super().__init__()

        self.student_temp = student_temp
        self.center_momentum = center_momentum

        # registered as a buffer so that it is part of state_dict()
        self.register_buffer("center", torch.zeros(1, 1, out_dim))

        # one temperature per epoch: linear warm-up followed by a constant
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student data tokens.

        Also updates the running centre from ``teacher_output`` as a side effect.
        Returns the loss averaged over all remaining dimensions (a scalar tensor).
        """
        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach()

        loss = torch.sum(-teacher_out * F.log_softmax(student_output / self.student_temp, dim=-1), dim=-1)
        total_loss = loss.mean()

        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.

        The mean of ``teacher_output`` over its first two dimensions (and over
        all processes when a process group is initialised) is blended into
        ``center`` with momentum ``center_momentum``.
        """
        sz = teacher_output.size()

        batch_center = torch.sum(torch.sum(teacher_output, dim=0, keepdim=True), dim=1, keepdim=True)

        # Without an initialised process group (single-process use, unit tests)
        # all_reduce / get_world_size would raise; fall back to world size 1,
        # which is exactly what a one-process group would compute.
        world_size = 1
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(batch_center)
            world_size = dist.get_world_size()
        batch_center = batch_center / (sz[0] * sz[1] * world_size)

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


def distortImages(samples, drop_perc=0.3):
    """Mix random patches between images of the same batch.

    ``samples`` is a list of batched tensors (a single tensor is wrapped into
    a one-element list). For every view, each image picks a random partner
    from the same view; when the partner is a different image, random blocks
    of the image are overwritten with the partner's pixels through
    ``drop_rand_patches``. Each list entry is replaced by its distorted copy
    and the (same) list is returned; the original tensors are not modified.
    """
    if not isinstance(samples, list):
        samples = [samples]

    # taken from the first view; every view is indexed with this batch size
    batch = samples[0].size(0)

    for view_idx, view in enumerate(samples):
        mixed = view.detach().clone()
        for row in range(batch):
            donor = randint(0, batch)
            if donor == row:
                # drawing itself as partner leaves the image untouched
                continue
            mixed[row] = drop_rand_patches(mixed[row], view[donor], max_drop=drop_perc)
        samples[view_idx] = mixed

    return samples

def drop_rand_patches(X, X_rep=None, max_drop=0.3, max_block_sz=0.25, tolr=0.05):
    """Overwrite random rectangular blocks of an image, in place.

    A target pixel budget is drawn uniformly from ``[0, max_drop) * H * W`` and
    random blocks are overwritten until the summed block areas reach it
    (overlapping blocks are counted more than once, so the number of distinct
    pixels changed can be smaller than the budget).

    Args:
        X: tensor of shape ``(C, H, W)``; it is MODIFIED IN PLACE.
        X_rep: tensor indexed like ``X``; blocks are copied from the same
            location of ``X_rep``. If ``None``, blocks are filled with
            standard-normal noise.
        max_drop: upper bound on the fraction of the image to be dropped.
        max_block_sz: maximum block height/width as a fraction of ``H``/``W``.
        tolr: minimum block height/width as a fraction of ``H``/``W``.

    Returns:
        ``X`` itself (the same tensor object that was passed in).
    """
    C, H, W = X.size()
    n_drop_pix = np.random.uniform(0, max_drop) * H * W

    min_blk_h, min_blk_w = int(tolr * H), int(tolr * W)
    # numpy's randint requires low < high. On small images the truncated
    # bounds can collapse (e.g. both 0), so force a non-empty range. For the
    # usual crop sizes these max() calls change nothing.
    max_blk_h = max(int(H * max_block_sz), min_blk_h + 1)
    max_blk_w = max(int(W * max_block_sz), min_blk_w + 1)
    max_top = max(H - min_blk_h, 1)
    max_left = max(W - min_blk_w, 1)

    total_pix = 0
    while total_pix < n_drop_pix:
        # get a random block by selecting a random row, column, height, width
        top = randint(0, max_top)
        left = randint(0, max_left)
        # every block covers at least one pixel; a zero-area block would make
        # no progress and could keep this loop running forever
        blk_h = max(randint(min_blk_h, max_blk_h), 1)
        blk_w = max(randint(min_blk_w, max_blk_w), 1)
        # clip the block to the image borders
        bottom = min(top + blk_h, H)
        right = min(left + blk_w, W)

        if X_rep is None:
            X[:, top:bottom, left:right] = torch.empty(
                (C, bottom - top, right - left), dtype=X.dtype, device=X.device).normal_()
        else:
            X[:, top:bottom, left:right] = X_rep[:, top:bottom, left:right]

        total_pix = total_pix + (bottom - top) * (right - left)

    return X


class FullpiplineSiT(nn.Module):
    """Backbone followed by a data-token head and a reconstruction head.

    The backbone's ``fc`` and ``head`` attributes are replaced by identity
    modules. Both heads are applied to the backbone output with the first
    token (index 0 along dimension 1) removed.
    """

    def __init__(self, backbone, head_datatokens, head_recons):
        super().__init__()

        # neutralise whatever classifier attributes the backbone carries
        backbone.fc = nn.Identity()
        backbone.head = nn.Identity()

        self.backbone = backbone
        self.head_datatokens = head_datatokens
        self.head_recons = head_recons

    def _embed_and_project(self, crops):
        """Run the concatenated crops through the backbone and both heads."""
        features = self.backbone(torch.cat(crops))
        return self.head_datatokens(features[:, 1:]), self.head_recons(features[:, 1:])

    def forward(self, x, global_crops=2):
        """Return ``(data_global, recons_global, data_local, recons_local)``.

        ``x`` is a sequence of batched crops; the first ``global_crops``
        entries are processed together as global views and the remaining
        ones, if any, as local views. The two local outputs are ``None``
        when there are no local views.
        """
        data_global, recons_global = self._embed_and_project(x[0:global_crops])

        if len(x) <= global_crops:
            return data_global, recons_global, None, None

        data_local, recons_local = self._embed_and_project(x[global_crops:])
        return data_global, recons_global, data_local, recons_local


class DataAugmentationSiT:
    """Produce paired clean / corrupted multi-crop views of one image.

    Calling the object returns two lists of equal length: the clean crops and
    the corrupted crops. Each list holds two global crops (resized to
    ``args.img_size``) followed by ``args.local_crops_number`` local crops
    (resized to 96). A clean and a corrupted view always share the same
    spatial crop; they differ in colour jitter, blur, solarization and in the
    random noise blocks added to the corrupted view.
    """

    def __init__(self, args):
        self.local_crops_number = args.local_crops_number
        self.drop_perc = args.drop_perc

        def crop_and_flip(size, scale):
            # random resized crop (bicubic) followed by a random horizontal flip
            return transforms.Compose([
                transforms.RandomResizedCrop(size, scale=scale, interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5)])

        self.rand_resize_flip = crop_and_flip(args.img_size, args.global_crops_scale)
        self.rand_resize_flip_local = crop_and_flip(96, args.local_crops_scale)

        # light color jittering for the teacher
        light_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01)
        self.color_jitter1 = transforms.Compose([
            transforms.RandomApply([light_jitter], p=0.3),
        ])

        # harsh color jittering for the student
        harsh_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
        self.color_jitter2 = transforms.Compose([
            transforms.RandomApply([harsh_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ])

        # tensor conversion + per-channel normalisation, shared by both pipelines
        to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # clean crop
        self.clean_transfo = transforms.Compose([
            self.color_jitter1,
            utils.GaussianBlur(0.1),
            to_normalized_tensor,
        ])

        # corrupted crop
        self.corrupt_transfo = transforms.Compose([
            self.color_jitter2,
            utils.GaussianBlur(1.0),
            utils.Solarization(0.2),
            to_normalized_tensor,
        ])

    def _augment_pair(self, crop):
        """Return the (clean, corrupted) tensors derived from one spatial crop."""
        clean = self.clean_transfo(crop)
        corrupted = self.corrupt_transfo(crop)
        # no replacement image is given, so the dropped blocks are filled with noise
        corrupted = drop_rand_patches(corrupted, max_drop=self.drop_perc)
        return clean, corrupted

    def __call__(self, image):
        # two global views first, then the requested number of local views
        croppers = [self.rand_resize_flip] * 2
        croppers += [self.rand_resize_flip_local] * self.local_crops_number

        clean_crops, corrupted_crops = [], []
        for cropper in croppers:
            clean, corrupted = self._augment_pair(cropper(image))
            clean_crops.append(clean)
            corrupted_crops.append(corrupted)

        return clean_crops, corrupted_crops


if __name__ == '__main__':
    # Script entry point: parse the command line, make sure the output
    # directory exists, then start training.
    parser = argparse.ArgumentParser('MCSSL', parents=[get_args_parser()])
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    train_MCSSL(args)
