import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoImageProcessor
import argparse
from os.path import join

from DinoExtractor import ImageDataset


def parse_args():
    """Parse the command-line arguments for the layer-wise CLS similarity run.

    Returns:
        argparse.Namespace with ``image_dir`` (str, required) and
        ``batch_size`` (int, default 32).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--image_dir',
        help='Path to the directory containing test images',
        required=True
    )

    parser.add_argument(
        '--batch_size',
        # Without an explicit type, values passed on the command line would be
        # strings while the default is an int.
        type=int,
        default=32,
        help='Number of images per forward pass (default: 32)'
    )

    return parser.parse_args()

def _layerwise_cls_similarity(model_size, dataloader, device):
    """Return the mean cosine similarity of each layer's CLS token to the last layer's CLS token.

    Loads ``facebook/dinov2-<model_size>``, runs every batch of ``dataloader``
    through it, and averages the per-image similarities over the dataset.

    Args:
        model_size: DINOv2 variant suffix, e.g. 'small' or 'giant'.
        dataloader: Iterable of image batches accepted by the DINOv2 processor.
        device: Torch device string the model runs on.

    Returns:
        1-D CPU tensor with one entry per element of the model's
        ``hidden_states`` output.
    """
    ### Use the standard DINOv2 model ###
    checkpoint = "facebook/dinov2-" + model_size
    processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)
    model = AutoModel.from_pretrained(checkpoint).to(device)
    model.eval()

    batch_sims = []
    with torch.no_grad():
        for images in tqdm(dataloader, desc=model_size):
            inputs = processor(images=images, return_tensors="pt", do_center_crop=False, do_resize=False,
                               do_rescale=False)
            pixel_values = inputs['pixel_values'].to(device)

            # Attentions are never used; requesting them only costs (a lot of) memory.
            outputs = model(pixel_values, output_hidden_states=True)

            # Stack all hidden states and take token 0 (the CLS token) -> (layers, batch, features).
            # hidden_states[-1] is the last hidden state before the final layer norm.
            cls_tokens = torch.stack(outputs['hidden_states'])[:, :, 0, :]

            # normalize so the dot product below is a cosine similarity
            cls_tokens = cls_tokens / torch.norm(cls_tokens, dim=-1, keepdim=True)

            cos_sim = torch.einsum('lbf,bf->lb', cls_tokens, cls_tokens[-1])
            # Keep results on the CPU: bounds GPU memory and makes the pickle loadable anywhere.
            batch_sims.append(cos_sim.cpu())

    # Release this model's GPU memory before the next checkpoint is loaded.
    del model
    if device == 'cuda':
        torch.cuda.empty_cache()

    return torch.cat(batch_sims, dim=-1).mean(-1)


def _plot_layerwise_similarity(results, out_path):
    """Plot each model's similarity curve against depth normalized to [0, 1] and save it to ``out_path``."""
    # `plt` was used here but never imported anywhere in the file.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for key, vector in results.items():
        vector = vector.detach().cpu().numpy()
        # Create x values normalized from 0 to 1, regardless of vector length,
        # so models of different depth can be overlaid.
        x = np.linspace(0, 1, len(vector))
        ax.plot(x, vector, label=key)

    ax.set_title('Similarity of cls token to last layer cls token')
    ax.set_xlabel("Normalized Depth")
    ax.set_ylabel("Cosine similarity")
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)


def main(args):
    """Compare every DINOv2 layer's CLS token to the final layer's CLS token.

    For each DINOv2 size (giant, large, base, small), compute the mean cosine
    similarity between each hidden-state layer's CLS token and the final
    one over the images in ``args.image_dir``. Pickle the per-size curves to
    ``results/layerwise_sim_results.pkl`` and plot them to
    ``results/layerwise_sim_results.pdf``.

    Args:
        args: Namespace with ``image_dir`` and ``batch_size`` attributes.

    Raises:
        ValueError: if the image directory yields no images.
    """
    import os

    # Prefer the GPU, but still run on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The dataset does not depend on the model, so build it once.
    dataset = ImageDataset(args.image_dir)
    if len(dataset) == 0:
        raise ValueError(f'No images found in {args.image_dir!r}')
    dataloader = DataLoader(dataset, batch_size=int(args.batch_size), shuffle=False)

    results = {}
    for model_size in ['small', 'base', 'large', 'giant'][::-1]:
        results[model_size] = _layerwise_cls_similarity(model_size, dataloader, device)

    os.makedirs('results', exist_ok=True)
    with open(join('results', 'layerwise_sim_results.pkl'), 'wb') as fp:
        pickle.dump(results, fp)

    ### Make the plot
    _plot_layerwise_similarity(results, join('results', 'layerwise_sim_results.pdf'))


if __name__ == '__main__':
    # Script entry point: parse the CLI flags and run the experiment.
    main(parse_args())