import pickle
import argparse
from os.path import join
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from DinoExtractor import DinoWithRegistersExtractor
from Imagenet import get_imagenet_loaders
from tqdm import tqdm
from transformers import AutoModel, AutoImageProcessor


def parse_args():
    """Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace with two attributes:
            imagenet_dir: path given via ``--imagenet_dir`` (required).
            batch_size: integer given via ``--batch_size`` (default 16).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--imagenet_dir',
        help='Path to the directory containing test images',
        required=True
    )

    # Without ``type=int`` a value passed on the command line would arrive as
    # a string, while the default would be an int.
    parser.add_argument(
        '--batch_size',
        type=int,
        default=16,
        help='Number of images per batch (default: 16)'
    )

    return parser.parse_args()

# Keys read from the extractor output for every batch ('cls_id' is added by
# the collection loop itself and holds the labels yielded by the loader).
_REPRESENTATION_KEYS = (
    'cls_output_global',
    'cls_output_local',
    'cls_output_total',
    'attn_weight',
    'cls_id',
)


def _collect_representations(extractor, loader):
    """Run the extractor over a loader and concatenate the per-batch outputs.

    Args:
        extractor: object exposing ``get_features(x, include_skip_connection=...)``
            that returns a dict containing the keys in ``_REPRESENTATION_KEYS``
            (except 'cls_id').
        loader: iterable yielding ``(x, y)`` batches; ``y`` is stored as 'cls_id'.

    Returns:
        Dict mapping every key in ``_REPRESENTATION_KEYS`` to one CPU tensor
        obtained by concatenating the per-batch tensors along dim 0.
    """
    reps = {key: [] for key in _REPRESENTATION_KEYS}

    for x, y in tqdm(loader):
        outputs = extractor.get_features(x, include_skip_connection=False)
        outputs['cls_id'] = y

        # ``squeeze()`` drops *every* size-1 dim, including the batch dim when
        # a batch holds a single sample (e.g. the last, incomplete batch).
        # Restore it in that case so ``torch.cat`` below stays well-defined.
        # Assumes one label per sample — TODO confirm against the loader.
        single_sample = y.numel() == 1

        for key in reps:
            rep = outputs[key].squeeze().detach().cpu()
            if single_sample:
                rep = rep.unsqueeze(0)
            reps[key].append(rep)

    return {key: torch.cat(chunks) for key, chunks in reps.items()}


def _topk_accuracy(test_mat, train_mat, test_labels, train_labels, k=5, chunk_size=1024):
    """Return the top-k nearest-neighbour accuracy under cosine similarity.

    A test sample counts as correct when at least one of its ``k`` most
    similar train samples carries the same label.

    Args:
        test_mat: (n_test, dim) tensor of query representations.
        train_mat: (n_train, dim) tensor of reference representations.
        test_labels: (n_test,) tensor of labels of the queries.
        train_labels: (n_train,) tensor of labels of the references.
        k: number of neighbours to inspect (capped at n_train).
        chunk_size: number of queries scored at once, which bounds the size
            of the similarity matrix held in memory.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if either matrix contains no samples.
    """
    n_test, n_train = test_mat.shape[0], train_mat.shape[0]
    if n_test == 0 or n_train == 0:
        raise ValueError('Cannot compute accuracy without test and train samples')

    # ``normalize`` clamps the norm by a small eps, so all-zero rows give
    # zeros instead of NaN.
    test_mat = torch.nn.functional.normalize(test_mat, dim=-1)
    train_mat = torch.nn.functional.normalize(train_mat, dim=-1)
    k = min(k, n_train)

    n_correct = 0
    for start in range(0, n_test, chunk_size):
        stop = start + chunk_size
        sim_mat = test_mat[start:stop] @ train_mat.T
        # Only the k best matches are needed, so avoid a full sort.
        nearest = sim_mat.topk(k, dim=-1).indices
        # ``nearest`` indexes train samples, so the neighbour labels must be
        # looked up in the train labels (not in the test labels).
        neighbour_labels = train_labels[nearest]
        hits = (neighbour_labels == test_labels[start:stop].unsqueeze(1)).any(dim=1)
        n_correct += int(hits.sum().item())

    return n_correct / n_test


def _plot_results(results, output_path):
    """Draw a grouped bar chart of register/patch accuracies and save it.

    Args:
        results: dict mapping model size to a dict with (at least) the keys
            'register' and 'patch'; insertion order is reversed for plotting.
        output_path: file the figure is written to.
    """
    # Set default font size to 8pt for all text
    matplotlib.rcParams['font.size'] = 8
    plt.rcParams['axes.titlesize'] = 8

    models = list(results.keys())[::-1]
    # Plot only register and patch tokens; selected by name so the plot does
    # not depend on the insertion order of the per-model dicts.
    metrics = ['register', 'patch']
    labels = ['Register tokens', 'Patch tokens']

    # 2D array of shape (n_models, n_metrics)
    data = np.array([[float(results[model][metric]) for metric in metrics]
                     for model in models])

    n_models, n_metrics = data.shape
    indices = np.arange(n_models)   # positions of the groups on the x-axis
    bar_width = 0.8 / n_metrics     # total group width of 0.8

    fig_width = 7 / 2.54   # 7 cm, i.e. about 2.76 inches
    fig_height = 2         # inches
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # Reference line for chance-level accuracy (5 guesses out of 1000;
    # presumably top-5 over the 1000 ImageNet classes — TODO confirm).
    chance_level = 5 / 1000
    ax.axhline(y=chance_level, linestyle='--', linewidth=1, c='black', label='Chance level')

    for i, label in enumerate(labels):
        ax.bar(
            indices + i * bar_width,
            data[:, i],
            width=bar_width,
            label=label
        )

    # Labeling and aesthetics
    ax.set_xlabel('Model')
    ax.set_ylabel('Accuracy')
    ax.set_title('One-shot classification accuracy')
    ax.set_xticks(indices + bar_width * (n_metrics - 1) / 2)
    ax.set_xticklabels(models)
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_path)


def main(args):
    """Evaluate nearest-neighbour classification for each DINOv2-with-registers model.

    For every model size the train and test representations are extracted,
    top-5 nearest-neighbour accuracy is computed for the 'total', 'global'
    and 'local' CLS outputs, and the results are pickled and plotted under
    ``results/register_models``.

    Args:
        args: namespace with ``imagenet_dir`` and ``batch_size`` attributes.
    """
    output_dir = join('results', 'register_models')
    # Create the output directory up front so a long run cannot fail at the
    # very end because the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    results = {}
    for model_size in ['small', 'base', 'large', 'giant'][::-1]:
        print('Working on {} model...'.format(model_size))
        processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-" + model_size, use_fast=True)
        model = AutoModel.from_pretrained("facebook/dinov2-with-registers-" + model_size)

        extractor = DinoWithRegistersExtractor(processor, model)

        trainloader, testloader = get_imagenet_loaders(args.imagenet_dir, batch_size=args.batch_size)

        # Compute all representations
        # NOTE(review): no ``torch.no_grad()`` here because it is not visible
        # whether the extractor relies on autograd — confirm and add if safe.
        print('Getting train representations...')
        train_reps = _collect_representations(extractor, trainloader)

        print('Getting test representations...')
        test_reps = _collect_representations(extractor, testloader)

        # Every query variant is matched against the *total* train
        # representation, as in the original evaluation.
        train_mat = train_reps['cls_output_total']
        test_labels = test_reps['cls_id']
        train_labels = train_reps['cls_id']

        total_accuracy = _topk_accuracy(test_reps['cls_output_total'], train_mat, test_labels, train_labels)
        register_accuracy = _topk_accuracy(test_reps['cls_output_global'], train_mat, test_labels, train_labels)
        patch_accuracy = _topk_accuracy(test_reps['cls_output_local'], train_mat, test_labels, train_labels)

        results[model_size] = {'global': total_accuracy, 'register': register_accuracy, 'patch': patch_accuracy}

    # Accuracies are plain floats, so the pickle does not depend on torch.
    with open(join(output_dir, 'classification_results.pkl'), 'wb') as fp:
        pickle.dump(results, fp)

    _plot_results(results, join(output_dir, 'patch_register_classification.pdf'))

# Script entry point: parse the command line and run the evaluation.
if __name__ == '__main__':
    main(parse_args())