import argparse
import os
import pickle
from os.path import join

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoImageProcessor

from DinoExtractor import DinoWithRegistersExtractor, ImageDataset, CKA


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: holds ``image_dir`` (required string) and
        ``batch_size`` (integer, 32 when the option is omitted).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--image_dir',
        help='Path to the directory containing test images',
        required=True
    )

    # type=int makes a value given on the command line an int, just like the
    # default; without it the option would be parsed as a string.
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='Number of images per batch (default: 32)'
    )

    return parser.parse_args()


def _squeeze_keep_batch(tensor):
    """Drop singleton dimensions of *tensor* while keeping the leading axis.

    A bare ``squeeze()`` also removes the leading axis when a batch holds a
    single sample (e.g. the last batch of a dataset), after which
    ``torch.cat`` over the collected batches fails or joins along the wrong
    axis. Assumes axis 0 is the batch axis, which is also what the later
    ``torch.cat`` over batches relies on.
    """
    squeezed = tensor.squeeze()
    if tensor.dim() > 0 and tensor.shape[0] == 1:
        squeezed = squeezed.unsqueeze(0)
    return squeezed


def main(args):
    """Compute patch/residual statistics for every DINOv2 model size.

    For each checkpoint ``facebook/dinov2-<size>`` (giant, large, base,
    small, in that order) the images in ``args.image_dir`` are fed through
    ``DinoWithRegistersExtractor.get_features``; the per-batch outputs are
    concatenated and reduced to norm ratios and CKA scores. The dictionary
    of per-model results is pickled to ``results/residual_results.pkl``.

    Args:
        args: namespace providing ``image_dir`` and ``batch_size``
            (``batch_size`` may be an int or a numeric string).
    """
    img_dir = args.image_dir
    batch_size = int(args.batch_size)

    # Batches with an index above this one are skipped: 101 batches, i.e.
    # ~3.2k samples at the default batch size of 32. This limits the memory
    # demand and only matters for large datasets such as ImageNet.
    last_batch_index = 100

    # Create the output directory before doing any work, so that a missing
    # folder cannot make the final write fail after all models were processed.
    os.makedirs('results', exist_ok=True)

    results = {}
    for model_size in ['giant', 'large', 'base', 'small']:
        print(f'Working on {model_size} model.')
        checkpoint = "facebook/dinov2-" + model_size
        processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)
        model = AutoModel.from_pretrained(checkpoint)

        extractor = DinoWithRegistersExtractor(processor, model)

        dataset = ImageDataset(img_dir)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        # The extractor output is expected to use exactly these keys; any
        # other key raises a KeyError in the loop below.
        model_results = {
            'cls_output_residual': [],
            'cls_output_patch': [],
            'cls_output_total': [],
            'residual_norm': [],
            'patch_norm': []
        }

        # The outputs are detached right away, so no autograd graph is needed.
        # NOTE(review): assumes get_features does not rely on gradients.
        with torch.no_grad():
            for i, inputs in enumerate(tqdm(dataloader)):
                # Forward pass
                outputs = extractor.get_features(inputs, probe_registers=False)
                for key, val in outputs.items():
                    model_results[key].append(
                        _squeeze_keep_batch(val.detach().cpu())
                    )

                if i == last_batch_index:
                    break

        for key, val in model_results.items():
            model_results[key] = torch.cat(val)

        patch_norm = model_results['patch_norm']
        residual_norm = model_results['residual_norm']
        total_norm = patch_norm + residual_norm
        cls_total = model_results['cls_output_total']

        results[model_size] = {
            # Share of the summed norm contributed by each branch.
            'patch_attention': patch_norm / total_norm,
            'residual_attention': residual_norm / total_norm,
            # CKA between the total CLS output and each of its components.
            'cka_patch': CKA(cls_total, model_results['cls_output_patch']),
            'cka_residual': CKA(cls_total, model_results['cls_output_residual']),
        }

        # Drop the references so the current model can be freed before the
        # next checkpoint is loaded.
        del extractor, model, processor

    with open(join('results', 'residual_results' + '.pkl'), 'wb') as fp:
        pickle.dump(results, fp)

if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the analysis.
    main(parse_args())