from torchvision import transforms

from torch.utils.data import Dataset
import torch
import os
from PIL import Image
import numpy as np
from einops import rearrange


class DinoWithRegistersExtractor:
    """Probe the CLS token of the last layer of a DINOv2-with-registers model.

    The final encoder layer is re-implemented step by step so that the CLS
    token's attention readout can be decomposed in one of three ways:
    - by source token (CLS + registers vs. patches);
    - into residual vs. attention contributions;
    - recomputed with inverted attention.

    NOTE(review): this reaches directly into Hugging Face ``Dinov2WithRegisters``
    internals:
    - ``embeddings`` and ``encoder.layer``;
    - ``norm1``/``norm2``;
    - ``attention.attention.{query,key,value,transpose_for_scores,all_head_size}``;
    - ``attention.output``, ``layer_scale1/2``, ``drop_path`` and ``mlp``.
    Verify these against the installed ``transformers`` version.
    """

    # Fallback when the model exposes no ``config.num_register_tokens``; it
    # reproduces a split after 5 tokens (1 CLS + 4 registers).
    _DEFAULT_NUM_REGISTER_TOKENS = 4

    def __init__(self, processor, model, device='cuda'):
        """
        :param processor: image processor, called with ``images=...`` and
            ``return_tensors="pt"``. It must return a dict that contains
            ``'pixel_values'``.
        :param model: backbone; moved to ``device`` and switched to eval mode.
        :param device: device for the model and its inputs.
        """
        self.processor = processor
        self.model = model.to(device)
        self.model.eval()
        self.device = device

    def _num_prefix_tokens(self):
        """Return the number of leading non-patch tokens (CLS + registers)."""
        config = getattr(self.model, 'config', None)
        num_registers = getattr(config, 'num_register_tokens', self._DEFAULT_NUM_REGISTER_TOKENS)
        return 1 + num_registers

    @staticmethod
    def _cls_readout(attn_weight, value, start=0, stop=None):
        """Return the CLS query's attention-weighted sum over key tokens ``start:stop``.

        :param attn_weight: (b, heads, tokens, tokens) attention probabilities.
        :param value: (b, heads, tokens, head_dim) value vectors.
        :return: (b, 1, heads * head_dim) tensor, with heads concatenated in
            order. This is the layout of ``rearrange(t, 'b h 1 f -> b 1 (h f)')``.
        """
        cls_weights = attn_weight[:, :, :1, start:stop]   # (b, h, 1, s)
        mixed = cls_weights @ value[:, :, start:stop, :]  # (b, h, 1, f)
        return mixed.transpose(1, 2).flatten(2)           # (b, 1, h * f)

    def get_features(self, x, include_skip_connection=True, probe_registers=True, invert_attn=False):
        """Run the model and decompose the CLS output of the last encoder layer.

        Everything after the range check runs under ``torch.no_grad()``.

        :param x: image batch with values <= 1. It is passed to the processor
            with resizing, center-cropping and rescaling disabled.
        :param include_skip_connection: if False, the residual that is added to
            the CLS attention output is zeroed. Only honored when
            ``probe_registers`` is True.
        :param probe_registers: split the CLS readout into two parts:
            'global' (attention to the first ``1 + num_register_tokens`` tokens)
            and 'local' (attention to the remaining tokens).
        :param invert_attn: only used when ``probe_registers`` is False. If set,
            the CLS output is also computed with attention ``softmax(-scores)``.
        :return: dict of tensors whose keys depend on the mode:
            - probe_registers: 'cls_output_global', 'cls_output_local',
              'cls_output_total', 'attn_weight', 'pre_norm_hidden_states',
              'post_norm_hidden_states'.
            - not probe_registers, not invert_attn: 'cls_output_residual'
              (residual only), 'cls_output_patch' (attention only),
              'cls_output_total', 'residual_norm', 'patch_norm'.
            - not probe_registers, invert_attn: 'cls_output_standard',
              'cls_output_inverted'.
        """
        assert x.max() <= 1, 'Images should be in [0,1] when entering TransformerBackbone'
        with torch.no_grad():
            inputs = self.processor(images=x, return_tensors="pt", do_center_crop=False, do_resize=False,
                                    do_rescale=False)
            # move inputs to model device. Need not be done elsewhere
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # The embeddings are inside no_grad as well: nothing returned needs gradients.
            hidden_states = self.model.embeddings(inputs['pixel_values'], bool_masked_pos=None)
            # Every layer but the last runs normally; the last one is re-implemented below.
            for layer_module in self.model.encoder.layer[:-1]:
                hidden_states = layer_module(hidden_states, None, False)[0]

            last_layer = self.model.encoder.layer[-1]
            # norm1 is computed once and shared by the query, key and value projections.
            normed = last_layer.norm1(hidden_states)
            pre_norm_hidden_states = hidden_states.clone()
            post_norm_hidden_states = normed.clone()

            # Manual self-attention for the last layer, keeping the attention probabilities.
            self_attn = last_layer.attention.attention
            query = self_attn.transpose_for_scores(self_attn.query(normed))
            key = self_attn.transpose_for_scores(self_attn.key(normed))
            value = self_attn.transpose_for_scores(self_attn.value(normed))

            scale_factor = 1 / np.sqrt(query.size(-1))
            attn_scores = query @ key.transpose(-2, -1) * scale_factor
            attn_weight = torch.softmax(attn_scores, dim=-1)

            # Output projection of the attention block. Each branch passes it the
            # same second argument as before; presumably that argument is unused.
            project = last_layer.attention.output
            cls_residual = hidden_states[:, 0, :].unsqueeze(1)  # (b, 1, dim)

            if probe_registers:
                # Split the keys into prefix tokens and patch tokens. This assumes
                # the sequence is laid out as [CLS, registers..., patches].
                num_prefix = self._num_prefix_tokens()
                cls_val_total = project(self._cls_readout(attn_weight, value), normed)
                cls_val_global = project(self._cls_readout(attn_weight, value, stop=num_prefix), normed)
                cls_val_local = project(self._cls_readout(attn_weight, value, start=num_prefix), normed)

                residual = cls_residual if include_skip_connection else torch.zeros_like(cls_residual)
                return {
                    'cls_output_global': self.post_attention_processing(cls_val_global, residual)[0],
                    'cls_output_local': self.post_attention_processing(cls_val_local, residual)[0],
                    'cls_output_total': self.post_attention_processing(cls_val_total, residual)[0],
                    'attn_weight': attn_weight,
                    'pre_norm_hidden_states': pre_norm_hidden_states,
                    'post_norm_hidden_states': post_norm_hidden_states,
                }

            # Standard CLS readout over all tokens, shared by both remaining modes.
            cls_val_total = project(self._cls_readout(attn_weight, value), hidden_states)

            if not invert_attn:
                # Separate the residual stream from the attention (patch) contribution.
                return {
                    'cls_output_residual': self.post_attention_processing(
                        torch.zeros_like(cls_val_total), cls_residual)[0],
                    'cls_output_patch': self.post_attention_processing(
                        cls_val_total, torch.zeros_like(cls_residual))[0],
                    'cls_output_total': self.post_attention_processing(cls_val_total, cls_residual)[0],
                    'residual_norm': hidden_states[:, 0, :].norm(dim=-1),
                    'patch_norm': last_layer.layer_scale1(cls_val_total).norm(dim=-1),
                }

            # Compare the standard output with one that uses inverted attention
            # softmax(-scores). This gauges how much the attention pattern matters.
            attn_weight_inverted = torch.softmax(-attn_scores, dim=-1)
            cls_val_inverted = project(self._cls_readout(attn_weight_inverted, value), hidden_states)
            return {
                'cls_output_standard': self.post_attention_processing(cls_val_total, cls_residual)[0],
                'cls_output_inverted': self.post_attention_processing(cls_val_inverted, cls_residual)[0],
            }

    def post_attention_processing(self, attention_output, hidden_states):
        '''
        Apply the rest of the last encoder layer after its attention projection.

        The steps are: layer_scale1, the first residual add, norm2, mlp,
        layer_scale2, and the second residual add.

        :param attention_output: [B, tokens, f] projected attention output.
        :param hidden_states: [B, tokens, f] residual stream to add (may be zeros).
        :return: 1-tuple ``(layer_output,)``.
        '''
        last_layer = self.model.encoder.layer[-1]
        attention_output = last_layer.layer_scale1(attention_output)

        # first residual connection
        hidden_states = last_layer.drop_path(attention_output) + hidden_states

        # in Dinov2WithRegisters, layernorm is also applied after self-attention
        layer_output = last_layer.norm2(hidden_states)
        layer_output = last_layer.mlp(layer_output)
        layer_output = last_layer.layer_scale2(layer_output)

        # second residual connection
        layer_output = last_layer.drop_path(layer_output) + hidden_states

        outputs = (layer_output,)
        return outputs



class ImageDataset(Dataset):
    """Images from one flat directory, returned as 224x224 tensors.

    Files are matched by extension (case-insensitive) and sorted by name. This
    keeps indices stable across runs and platforms, because ``os.listdir``
    order is arbitrary.
    """

    def __init__(self, directory, extensions=('.png', '.jpg', '.jpeg')):
        """
        :param directory: folder that contains the images (not searched recursively).
        :param extensions: filename suffix, or iterable of suffixes, to include;
            compared case-insensitively.
        """
        self.directory = directory
        self.transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so that index -> file is deterministic.
        self.image_files = sorted(f for f in os.listdir(directory) if f.lower().endswith(suffixes))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``self.transform`` if it is set."""
        img_path = os.path.join(self.directory, self.image_files[idx])
        # Use a context manager so the file handle is closed right away.
        # convert() returns a fully loaded copy that outlives the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return image

import torch


def HSIC(K, L):
    """Biased linear-kernel HSIC estimate between two paired feature matrices.

    :param K: (n, d1) tensor of features, one row per sample.
    :param L: (n, d2) tensor of features; row i is paired with row i of ``K``.
    :return: scalar tensor ``tr(Kc Kc^T Lc Lc^T) / (n - 1)^2``, where Kc and Lc
        are the column-centered inputs.

    The trace equals ``||Kc^T Lc||_F^2``. That form costs about n*d1*d2
    operations. The Gram-matrix form ``sum((Kc Kc^T) * (Lc Lc^T))`` costs about
    n^2*(d1+d2). The cheaper of the two is chosen, so that neither an n x n nor
    a d1 x d2 intermediate is built needlessly.
    """
    n = K.shape[0]
    # Center the features (equivalent to centering the linear Gram matrices).
    K = K - K.mean(dim=0, keepdim=True)
    L = L - L.mean(dim=0, keepdim=True)

    d1, d2 = K.shape[1], L.shape[1]
    if d1 * d2 <= n * (d1 + d2):
        cross = K.T @ L
        total = (cross * cross).sum()
    else:
        # Both Gram matrices are symmetric, so tr(A @ B) == sum(A * B).
        total = ((K @ K.T) * (L @ L.T)).sum()
    return total / (n - 1) ** 2

def CKA(K, L):
    """Linear centered kernel alignment between two paired feature matrices.

    CKA = HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L)).

    The denominator is computed as sqrt(a) * sqrt(b) rather than sqrt(a * b),
    so that large self-similarities cannot overflow the intermediate product in
    low precision. The result is NaN if either input has zero variance
    (the denominator is 0).

    :param K: (n, d1) features, one row per sample.
    :param L: (n, d2) features, rows paired with ``K``.
    :return: scalar tensor.
    """
    return HSIC(K, L) / (torch.sqrt(HSIC(K, K)) * torch.sqrt(HSIC(L, L)))
