# --------------------------------------------------------
# BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366)
# Github source: https://github.com/microsoft/unilm/tree/master/beitv2
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Zhiliang Peng
# Based on BEiT, timm, DeiT and DINO code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------

import os
import sys
import argparse

import torch
from torch import nn
from torchvision import transforms as pth_transforms
from timm.models import create_model

from PIL import Image

import utils
import modeling_vqkd 

def get_code(args):
    """Tokenize a single image with a pretrained VQ-KD model and print its code indices.

    Args:
        args: namespace providing
            ``model`` -- model name passed to timm's ``create_model``;
            ``pretrained_weights`` -- forwarded to ``create_model`` as ``pretrained_weight``;
            ``img_path`` -- path of the image file to encode.

    Returns:
        The value of ``model.get_codebook_indices`` for the image (it is also printed).
    """
    # ============ preparing data ... ============
    # Resize (interpolation=3 is the PIL bicubic code), center-crop to 224, convert to tensor.
    transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        # pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # Normalize in pre-process of vqkd
    ])
    print(f"Image transforms: {transform}")

    # Force 3-channel RGB so grayscale / palette / RGBA files still produce a
    # 3-channel tensor, and close the file handle once the pixels are loaded.
    with Image.open(args.img_path) as img:
        images = transform(img.convert('RGB')).unsqueeze(0)  # add a leading batch dimension

    # ============ building network ... ============
    model = create_model(
            args.model,
            pretrained=True,
            pretrained_weight=args.pretrained_weights,
            # NOTE(review): spelling kept as-is; presumably matches the keyword
            # accepted by the model factory in modeling_vqkd -- confirm before renaming.
            as_tokenzer=True,
        ).eval()

    # Inference only: no gradients are needed to look up codebook indices.
    with torch.no_grad():
        input_ids = model.get_codebook_indices(images)
    print(input_ids)
    return input_ids


if __name__ == '__main__':
    # Command-line entry point: parse the options below and tokenize one image.
    cli = argparse.ArgumentParser('Get code for VQ-KD')

    # (flag, default value, help text) -- every option is a plain string.
    option_specs = (
        ('--model',
         'vqkd_encoder_base_decoder_1x768x12_clip',
         "model"),
        ('--pretrained_weights',
         'https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/vqkd_encoder_base_decoder_1x768x12_clip-d93179da.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D',
         "Path to pretrained weights to evaluate."),
        ('--img_path',
         'demo/ILSVRC2012_val_00031649.JPEG',
         "image path."),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, type=str, help=help_text)

    args = cli.parse_args()
    get_code(args)
    
    
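# Sample printed output: a (1, 196) tensor of codebook indices
# (recorded with the tensor on a CUDA device, hence device='cuda:0').
# tensor([[3812, 7466, 1913, 1913, 1903, 1913, 1903, 1913, 3812, 7820, 6337, 2189,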
#          7466, 7466, 2492, 3743, 5268, 3481, 5268, 4987,  445, 8009, 3501, 5268,
#          7820, 7831, 4816, 2189, 7549, 7549, 5548, 4987,  445, 4198,  445, 5216,
#          4987, 5268, 3278, 5203, 6337, 1799,  847, 6454, 4527, 5302, 8009, 3743,
#          5216, 4678, 3743, 4858, 5203, 4816, 7831, 2189, 7549, 5386, 6628, 5004,
#          2779, 7131, 7131, 7131, 4928, 3743,  119,  445, 1903, 7466, 4527, 5386,
#          5398, 5704, 2104, 5398, 2779, 7258, 7989,  624, 7131, 1186, 5216, 7466,
#          8015, 5004,  452, 7243, 3145, 6690, 7017, 2104, 5398, 4198, 7989, 7131,
#          3717, 7466,  580, 5004, 5004, 6202, 6202, 6202, 1826, 7521, 1473, 5722,
#          2486, 5663, 4928, 3941,  580, 5548, 7983, 7983, 7983, 2104, 5004, 2063,
#          2637, 1822, 3100, 3100, 1405, 1637, 8187, 5433, 2779, 5398, 5004, 5004,
#          1107, 3469, 3469, 5302, 2590, 6381, 3100, 4194, 3717,  356, 7131, 7688,
#          5104, 3081, 3812, 3950, 1186, 7131, 7131, 3717, 4399, 1186, 2221, 6501,
#          7131, 5433, 3014, 3950, 3278, 2812, 7131, 1186, 7036, 6947, 7036, 4648,
#          2812, 7131, 3014, 5295, 7266, 5180, 4123, 3792, 4648, 8009, 4648, 4816,
#          1511, 7036,  375, 2221, 5813, 5698,  168, 7131, 3792, 5698, 5698, 2667,
#          5698, 4648, 4171, 6501]], device='cuda:0')