import os
from os.path import join
import os
from torch.utils.data import DataLoader
import torch
from transformers import AutoModel, AutoImageProcessor
from tqdm import tqdm
import pickle
import argparse

from DinoExtractor import DinoWithRegistersExtractor, ImageDataset, CKA

def parse_args(argv=None):
    """Parse the command-line options for the register-model analysis.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``image_dir`` (str) and ``batch_size`` (int).
    """
    parser = argparse.ArgumentParser(
        description='Collect DINOv2-with-registers statistics over a set of images.'
    )

    parser.add_argument(
        '--image_dir',
        help='Path to the directory containing test images',
        required=True
    )

    # type=int converts (and validates) the value at parse time instead of
    # handing a string to the DataLoader setup.
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='Number of images per forward pass (default: 32)'
    )

    return parser.parse_args(argv)

def _squeeze_keep_batch(tensor):
    """Return *tensor* with every size-1 dimension removed except dim 0.

    A bare ``tensor.squeeze()`` also drops the batch axis when a batch holds a
    single image (e.g. a final partial batch), after which ``torch.cat`` over
    the collected batches fails on mismatched ranks. Keeping dim 0 avoids that
    and yields the same shape as ``squeeze()`` for batches of two or more.
    """
    kept = [size for size in tensor.shape[1:] if size != 1]
    return tensor.reshape(tensor.shape[0], *kept)


def _collect_features(extractor, dataloader):
    """Run *extractor* over every batch and concatenate its outputs along dim 0.

    Returns a dict mapping each output name to a CPU tensor whose first
    dimension indexes images. For the two hidden-state outputs only a single
    token is kept: the one with the largest L2 norm averaged over the batch.
    That token is chosen separately for each batch.
    """
    results = {
        'cls_output_global': [],
        'cls_output_local': [],
        'cls_output_total': [],
        'attn_weight': [],
        'pre_norm_hidden_states': [],
        'post_norm_hidden_states': []
    }
    # Features are only analysed, never backpropagated through; skipping
    # autograd bookkeeping keeps memory bounded for the larger models.
    with torch.no_grad():
        for inputs in tqdm(dataloader):
            outputs = extractor.get_features(inputs, include_skip_connection=False)
            # Hidden states are indexed as (batch, token, feature): pick the
            # token whose norm, averaged over the batch, is largest.
            pre_norm = outputs['pre_norm_hidden_states']
            token_id = pre_norm.norm(dim=-1).mean(0).argmax()
            outputs['pre_norm_hidden_states'] = pre_norm[:, token_id, :]
            outputs['post_norm_hidden_states'] = outputs['post_norm_hidden_states'][:, token_id, :]
            for key, val in outputs.items():
                results[key].append(_squeeze_keep_batch(val.detach().cpu()))

    return {key: torch.cat(val) for key, val in results.items()}


def _summarize(results):
    """Reduce the collected per-image features to the statistics that get saved."""
    cos = torch.nn.functional.cosine_similarity
    summary = {}

    # attn_weight is presumably (image, head, query, key); query 0 is
    # presumably CLS and keys :5 presumably CLS + 4 register tokens — TODO confirm.
    summary['mean_attention_global'] = results['attn_weight'][:, :, 0, :5].sum(-1).mean(-1)
    summary['mean_attention_local'] = results['attn_weight'][:, :, 0, 5:].sum(-1).mean(-1)

    cls_global = results['cls_output_global']
    cls_local = results['cls_output_local']
    cls_total = results['cls_output_total']
    summary['global_local_output_sim'] = cos(cls_global, cls_local)
    summary['global_total_output_sim'] = cos(cls_global, cls_total)
    summary['local_total_output_sim'] = cos(cls_local, cls_total)

    summary['cka_total_local'] = CKA(cls_total, cls_local)
    summary['cka_total_global'] = CKA(cls_total, cls_global)

    # Top-100 feature dimensions of the highest-norm token, ranked by the
    # absolute value of their mean pre-norm activation; the post-norm values
    # are reported for the same dimensions.
    mean_pre = results['pre_norm_hidden_states'].mean(0).abs()
    mean_post = results['post_norm_hidden_states'].mean(0).abs()
    _, order = mean_pre.sort(descending=True)
    top_dims = order[:100]
    summary['pre_norm_hidden_states'] = mean_pre[top_dims]
    summary['post_norm_hidden_states'] = mean_post[top_dims]

    # Mean pairwise cosine similarity of the highest-norm tokens across images.
    hidden = results['pre_norm_hidden_states']
    unit = hidden / hidden.norm(dim=-1, keepdim=True)
    gram = unit @ unit.T
    # Strict upper triangle: each distinct pair once, self-similarity excluded.
    upper = torch.triu(torch.ones_like(gram), diagonal=1).bool()
    summary['highest_token_similarity'] = gram[upper].mean()

    return summary


def main(args):
    """Analyse every DINOv2-with-registers model size and pickle the statistics.

    Sizes are processed largest first ('giant', 'large', 'base', 'small'). Each
    summary is written to ``results/register_models/<size>_results.pkl``,
    relative to the current working directory.

    Args:
        args: Namespace with ``image_dir`` and ``batch_size`` (an int or an
            int-convertible string), as produced by ``parse_args``.
    """
    # The image set does not depend on the model, so build the loader once.
    dataset = ImageDataset(args.image_dir)
    dataloader = DataLoader(dataset, batch_size=int(args.batch_size), shuffle=False)

    out_dir = join('results', 'register_models')
    # Create the output folder up front so a missing directory cannot discard
    # a long extraction run at the final ``open``.
    os.makedirs(out_dir, exist_ok=True)

    for model_size in ['giant', 'large', 'base', 'small']:
        print('Working on {} model.'.format(model_size))
        checkpoint = "facebook/dinov2-with-registers-" + model_size
        processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)
        model = AutoModel.from_pretrained(checkpoint)
        extractor = DinoWithRegistersExtractor(processor, model)

        results = _collect_features(extractor, dataloader)
        summary = _summarize(results)

        with open(join(out_dir, model_size + '_results.pkl'), 'wb') as fp:
            pickle.dump(summary, fp)

if __name__ == '__main__':
    # Script entry point: read the CLI options and run the full analysis.
    main(parse_args())