#!/usr/bin/env python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import numpy as np

import segment.datasets as datasets
import vits_seg_hrch_v3_vis_attn


# Architectures this script accepts; main_worker looks the chosen name up in
# the vits_seg_hrch_v3_vis_attn module.
model_names = ['vit_conv_small']

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
# argparse does not check a default against `choices`, so the default has to
# be one of the listed architectures itself; otherwise running without
# -a/--arch only fails later, when the model is looked up by name.
parser.add_argument('-a', '--arch', metavar='ARCH', default=model_names[0],
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: ' + model_names[0] + ')')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')

# additional configs:
parser.add_argument('--pretrained', default='', type=str,
                    help='path to moco pretrained checkpoint')
parser.add_argument('--save_dir', default='', type=str,
                    help='path to saved attention.')

# NOTE(review): not read anywhere in the visible code; kept in case other
# modules reference it.
best_acc1 = 0


def main():
    """Parse the command-line arguments and hand them to ``main_worker``."""
    main_worker(parser.parse_args())


def main_worker(args):
    """Build the model, load the pre-trained encoder weights and run ``validate``.

    The architecture named by ``args.arch`` is instantiated from
    ``vits_seg_hrch_v3_vis_attn``. The checkpoint at ``args.pretrained`` is
    loaded on the CPU, its position embeddings are passed through
    ``interpolate_pos_embed``, and only the entries whose key starts with
    ``module.base_encoder`` are kept (with that prefix stripped) before being
    loaded non-strictly. The model is then moved to CUDA and evaluated on a
    ``ReturnIndexDataset`` built from ``args.data``.

    Args:
        args: parsed command-line namespace. Reads ``arch``, ``pretrained``,
            ``data`` and ``workers``; ``start_epoch`` is set to 0 on it.
    """
    # create model
    print("=> creating model '{}'".format(args.arch))
    build_model = vits_seg_hrch_v3_vis_attn.__dict__[args.arch]
    model = build_model()

    # load the pre-trained checkpoint (kept on the CPU until the model is moved)
    print("=> loading checkpoint '{}'".format(args.pretrained))
    checkpoint = torch.load(args.pretrained, map_location="cpu")
    state_dict = checkpoint['state_dict']
    interpolate_pos_embed(model, state_dict)

    # rename moco pre-trained keys: keep only the base encoder, minus its prefix
    strip_len = len("module.base_encoder.")
    for old_key in list(state_dict.keys()):
        # every original key is removed; encoder weights are re-added renamed
        weight = state_dict.pop(old_key)
        if old_key.startswith('module.base_encoder'):
            state_dict[old_key[strip_len:]] = weight

    args.start_epoch = 0
    load_result = model.load_state_dict(state_dict, strict=False)
    print(load_result.missing_keys)
    print("=> loaded pre-trained model '{}'".format(args.pretrained))

    model.cuda()

    cudnn.benchmark = True

    # Data loading code
    valdir = os.path.join(args.data)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    image_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    blur = transforms.GaussianBlur(3, sigma=(1.0, 2.0))

    val_dataset = ReturnIndexDataset(
        valdir,
        image_transform,
        normalize=normalize,
        n_segments=196,
        compactness=10.0,
        blur_ops=blur,
        slic_scale_factor=1.0,
    )

    # one image per batch, in dataset order
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    validate(val_loader, val_dataset, model, args)


def _model_device(model):
    """Return the device of ``model``'s first parameter (CUDA if it has none)."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        # Parameter-free model: fall back to the previously hard-coded device.
        return torch.device('cuda')


def validate(val_loader, val_dataset, model, args):
    """Run the model over ``val_loader`` and save per-image attention maps.

    For each of the four levels, the first query row of
    ``outputs['cls_block_attn<level>']`` (the cls_token-to-nodes attention) is
    soft-maxed, its first column is dropped, and the remaining per-node values
    are scattered back to pixels through the segment map. The argmax of
    ``outputs['logit<level>']`` maps each original segment to its node at the
    next level.

    One ``.npy`` file is written per image into ``args.save_dir``. It holds a
    pickled dict with keys ``level1``..``level4``, ``image_path`` and
    ``slic_seg``. Every array keeps a leading dimension of size 1, so files
    written with ``batch_size=1`` are unchanged, while larger batches now
    produce one correctly sliced file per sample instead of failing.

    Args:
        val_loader: iterable yielding ``(images, segments, img_ids)`` batches.
        val_dataset: dataset whose ``samples[idx][0]`` is the image path.
        model: network exposing ``forward_features(images, segments)``.
        args: namespace providing ``save_dir``; an empty string saves into
            the current working directory.
    """
    # switch to evaluate mode
    model.eval()

    # Run on whatever device the model lives on (CUDA in the script's flow).
    device = _model_device(model)

    # Create the output directory once instead of once per batch. An empty
    # save_dir means "current directory", which os.makedirs would reject.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)

    with torch.no_grad():
        for images, segments, img_ids in val_loader:
            images = images.to(device)
            segments = segments.to(device)

            # compute output
            outputs = model.forward_features(images, segments)

            batch_size = segments.shape[0]
            height, width = segments.shape[-2], segments.shape[-1]
            # Segment id of every pixel, flattened to (batch, 1, H*W).
            pixel_inds = segments.unsqueeze(1).flatten(2, 3)

            # pool_inds[b, s] is the node that original segment s belongs to
            # at the current level; at the first level it is the identity.
            num_segments = int(segments.max()) + 1
            pool_inds = torch.arange(num_segments, dtype=torch.long, device=device)
            pool_inds = pool_inds.view(1, -1).repeat(batch_size, 1)

            level_attns = {}
            for level in [1, 2, 3, 4]:
                # Visualize cls_token-to-nodes attention.
                attn = outputs['cls_block_attn{:d}'.format(level)][:, :, 0]
                attn = torch.softmax(attn, dim=-1)
                attn = attn[..., 1:]
                num_heads = attn.shape[1]

                # node -> original segment -> pixel
                attns = torch.gather(attn, 2, pool_inds.unsqueeze(1).expand(-1, num_heads, -1))
                attns = torch.gather(attns, 2, pixel_inds.expand(-1, num_heads, -1))
                attns = attns.view(attn.shape[0], num_heads, height, width)
                level_attns['level{:d}'.format(level)] = attns.detach().cpu().numpy()

                # Follow each segment to its node at the next level.
                pool_inds = torch.gather(torch.argmax(outputs['logit{:d}'.format(level)], dim=-1), 1, pool_inds)

            segments_np = segments.detach().cpu().numpy()
            for sample_ind, img_id in enumerate(img_ids.tolist()):
                img_path = val_dataset.samples[img_id][0]
                basename = os.path.basename(img_path).replace('.JPEG', '.npy')
                attn_path = os.path.join(args.save_dir, basename)

                # Slice (not index) so each saved array keeps its batch axis.
                sample = slice(sample_ind, sample_ind + 1)
                record = {name: maps[sample] for name, maps in level_attns.items()}
                record['image_path'] = img_path
                record['slic_seg'] = segments_np[sample]
                np.save(attn_path, record)


def interpolate_pos_embed(model, checkpoint_model):
    """Resize checkpoint position embeddings in place to match ``model``.

    Every entry of ``checkpoint_model`` whose key contains ``'pos_embed'`` is
    examined. The number of leading extra tokens is taken from the model
    (``model.pos_embed.shape[-2] - model.patch_embed.num_patches``). When the
    square grid side implied by the checkpoint differs from the model's, the
    patch-position part is bicubically resampled to the new grid and the
    entry is replaced; the extra tokens are carried over unchanged.

    Args:
        model: object exposing ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model: mapping of parameter names to tensors; modified in
            place.
    """
    for key in checkpoint_model:
        if 'pos_embed' not in key:
            continue

        ckpt_embed = checkpoint_model[key]
        embed_dim = ckpt_embed.shape[-1]
        patch_count = model.patch_embed.num_patches
        extra_count = model.pos_embed.shape[-2] - patch_count
        # grid side (height == width) of the checkpoint and of the model
        old_side = int((ckpt_embed.shape[-2] - extra_count) ** 0.5)
        new_side = int(patch_count ** 0.5)
        if old_side == new_side:
            # already the right resolution, leave the entry untouched
            continue

        print("Position interpolate from %dx%d to %dx%d, with key name %s" % (old_side, old_side, new_side, new_side, key))
        # class_token and dist_token are kept unchanged
        kept_tokens = ckpt_embed[:, :extra_count]
        # only the position tokens are interpolated
        grid = ckpt_embed[:, extra_count:]
        grid = grid.reshape(-1, old_side, old_side, embed_dim)
        grid = grid.permute(0, 3, 1, 2)  # channels first for interpolate
        grid = torch.nn.functional.interpolate(
            grid, size=(new_side, new_side), mode='bicubic', align_corners=False)
        grid = grid.permute(0, 2, 3, 1).flatten(1, 2)  # back to (batch, tokens, dim)
        checkpoint_model[key] = torch.cat((kept_tokens, grid), dim=1)

class ReturnIndexDataset(datasets.ImageFolder):
    """``ImageFolder`` variant that yields the sample index instead of the label."""

    def __getitem__(self, idx):
        """Return ``(image, segmentation, idx)`` for the sample at ``idx``."""
        image, segmentation, _label = super().__getitem__(idx)
        return image, segmentation, idx


if __name__ == '__main__':
    main()

