import time
import numpy as np
import torch
from PIL import Image
import glob
import sys
import argparse
import datetime
import json
from pathlib import Path

"""
Implementation of PRSLogger and hook management tools

This module contains a `PRSLogger` class and a `hook_prs_logger` method for registering the logging tool.
It is primarily used to record the projected representation of residual streams in vision models, 
and to perform normalization and standardization operations.

Feature overview:
1. **PRSLogger class**
   - Provides the ability to record and standardize attention weights and MLP outputs from different layers of the model.
   - Supports both spatial and non-spatial attention logging.
   - Captures layer outputs through hooks for further analysis of model behavior.

2. **hook_prs_logger method**
   - Registers the `PRSLogger` to the model, dynamically capturing intermediate results from different layers.

Design purpose:
- Support explainability analysis of deep learning models.
- Provide normalized projections of residual streams for diagnosing model performance or further research.
"""


class PRSLogger(object):
    """Collect cls-token residual-stream contributions and project them through ``ln_post``/``proj``.

    Hook callbacks (see ``hook_prs_logger``) append per-layer outputs while the model
    runs a forward pass. ``finalize`` stacks them and spreads the final layer norm over
    the individual contributions. ``reinit`` clears the state before the next batch.

    Args:
        model: Model exposing ``visual.transformer.resblocks[i].attn.out_proj.bias``,
            ``visual.ln_post.weight``, ``visual.ln_post.bias`` and ``visual.proj``.
        device: Device on which ``finalize`` performs its computation.
        spatial: Stored on the instance but not consulted by this class; only the
            spatial attention path is implemented.
    """

    def __init__(self, model, device, spatial: bool = True):
        self.current_layer = 0  # index of the resblock whose attention output arrives next
        self.device = device
        self.attentions = []  # list of per-layer tensors until finalize() stacks them
        self.mlps = []  # ln_pre output and per-layer MLP outputs, in hook-call order
        self.spatial = spatial
        self.post_ln_std = None
        self.post_ln_mean = None
        self.model = model

    @torch.no_grad()
    def compute_attentions_spatial(self, ret):
        """Record the cls-token slice of one layer's spatial attention output.

        ``ret`` must be 5-D; presumably [batch, query tokens, key tokens, heads, dim]
        (TODO confirm against the hook producer). Index 0 of dim 1 is taken as the cls
        token. The layer's ``out_proj`` bias, when present, is split evenly over the
        remaining (dim-1, dim-2) entries, so summing them recovers the bias exactly once.

        Returns:
            ``ret`` unchanged, so the hook does not alter the forward pass.
        """
        assert len(ret.shape) == 5, "Verify that you use method=`head` and not method=`head_no_spatial`"
        bias_term = self.model.visual.transformer.resblocks[self.current_layer].attn.out_proj.bias
        self.current_layer += 1
        return_value = ret[:, 0].detach().cpu()  # This is only for the cls token
        if bias_term is not None:
            return_value = return_value + bias_term[None, None, None].cpu() / (
                return_value.shape[1] * return_value.shape[2]
            )
        self.attentions.append(return_value)
        return ret

    @torch.no_grad()
    def compute_mlps(self, ret):
        """Record the cls-token slice (index 0 of dim 1) of an MLP / ln_pre output; return ``ret``."""
        self.mlps.append(ret[:, 0].detach().cpu())  # [b, d]
        return ret

    @torch.no_grad()
    def log_post_ln_mean(self, ret):
        """Store the ``ln_post`` mean logged by the hook (expected [b, 1]); return ``ret``."""
        self.post_ln_mean = ret.detach().cpu()
        return ret

    @torch.no_grad()
    def log_post_ln_std(self, ret):
        """Store the ``ln_post`` std (sqrt of variance) logged by the hook (expected [b, 1]); return ``ret``."""
        self.post_ln_std = ret.detach().cpu()
        return ret

    def _apply_post_ln(self, components, num_pieces):
        """Apply a 1/num_pieces share of ``ln_post`` to each component, then project with ``visual.proj``.

        The mean and bias are divided by ``num_pieces``, so that summing all pieces
        reproduces ``ln_post`` applied to their sum. The logged statistics get one
        singleton axis inserted after dim 1 per extra dimension of ``components``
        (beyond the first two), so they broadcast against it.
        """
        index = (slice(None), slice(None)) + (None,) * (components.dim() - 2)
        visual = self.model.visual
        mean = self.post_ln_mean[index].to(self.device)
        std = self.post_ln_std[index].to(self.device)
        mean_centered = components - mean / num_pieces
        weighted_mean_centered = visual.ln_post.weight.to(self.device) * mean_centered
        weighted_mean_by_std = weighted_mean_centered / std
        post_ln = weighted_mean_by_std + visual.ln_post.bias.to(self.device) / num_pieces
        return post_ln @ visual.proj.to(self.device)

    def _normalize_mlps(self):
        """Project the stacked MLP contributions; every one is one of ``len_intermediates`` pieces."""
        len_intermediates = self.attentions.shape[1] + self.mlps.shape[1]
        return self._apply_post_ln(self.mlps, len_intermediates)

    def _normalize_attentions_spatial(self):
        """Project the stacked attention contributions.

        Each layer's share is further split over the dim-2 x dim-3 entries of the stack.
        """
        len_intermediates = self.attentions.shape[1] + self.mlps.shape[1]
        normalization_term = self.attentions.shape[2] * self.attentions.shape[3]
        return self._apply_post_ln(self.attentions, len_intermediates * normalization_term)

    @torch.no_grad()
    def finalize(self, representation):
        """Stack the logged outputs, apply the distributed ``ln_post`` + ``proj``, and normalize.

        Args:
            representation: Final model output; the norm of its last dimension is used
                to normalize every projected contribution.

        Returns:
            ``(projected_attentions, projected_mlps)``. The shapes are the stacked shapes
            (attentions with a layer axis at dim 1, MLPs with one entry per logged call),
            with the last dimension replaced by ``visual.proj``'s output dimension.

        Raises:
            RuntimeError: if ``finalize`` was already called without ``reinit``, if no
                attention/MLP outputs were logged, or if the ``ln_post`` statistics are missing.
        """
        if torch.is_tensor(self.attentions) or torch.is_tensor(self.mlps):
            raise RuntimeError("finalize() was already called; call reinit() before logging another batch.")
        if not self.attentions or not self.mlps:
            raise RuntimeError("No attention/MLP outputs were logged; run a forward pass with the hooks registered.")
        if self.post_ln_mean is None or self.post_ln_std is None:
            raise RuntimeError("ln_post mean/std were not logged; check that the ln_post hooks are registered.")
        self.attentions = torch.stack(self.attentions, dim=1).to(self.device)
        self.mlps = torch.stack(self.mlps, dim=1).to(self.device)
        # Bring the norm onto self.device so the division works even when the
        # representation was produced on a different device.
        norm = representation.norm(dim=-1).detach().to(self.device)

        projected_attentions = self._normalize_attentions_spatial() / norm.view(-1, 1, 1, 1, 1)
        projected_mlps = self._normalize_mlps() / norm.view(-1, 1, 1)
        return projected_attentions, projected_mlps

    def reinit(self):
        """Reset all logged state so the logger can be reused for the next batch."""
        self.current_layer = 0
        self.attentions = []
        self.mlps = []
        self.post_ln_mean = None
        self.post_ln_std = None
        torch.cuda.empty_cache()

def hook_prs_logger(model, device, spatial: bool = True):
    """Hooks a projected residual stream logger to the model.

    Registers ``PRSLogger`` callbacks on ``model.hook_manager`` for:
    the per-layer attention output, the per-layer MLP ``c_proj`` output, the
    ``ln_pre`` output, and the ``ln_post`` mean / sqrt-variance.

    Args:
        model: Model exposing a ``hook_manager`` with a ``register(name, fn)`` method.
        device: Device passed through to the ``PRSLogger`` for ``finalize``.
        spatial: Must be True; only spatial attention logging is implemented.

    Returns:
        The ``PRSLogger`` whose callbacks were registered.

    Raises:
        NotImplementedError: if ``spatial`` is False.
    """
    if not spatial:
        # Fail fast: without an attention hook, finalize() would have nothing to stack.
        raise NotImplementedError("Non-spatial attention is not yet supported")
    prs = PRSLogger(model, device, spatial=spatial)
    model.hook_manager.register(
        "visual.transformer.resblocks.*.attn.out.post", prs.compute_attentions_spatial  # extract MSA(Z)cls
    )
    model.hook_manager.register(
        "visual.transformer.resblocks.*.mlp.c_proj.post", prs.compute_mlps     # extract MLP(LN(Z))cls
    )
    model.hook_manager.register("visual.ln_pre_post", prs.compute_mlps)         # extract LP(Z0)cls
    model.hook_manager.register("visual.ln_post.mean", prs.log_post_ln_mean)
    model.hook_manager.register("visual.ln_post.sqrt_var", prs.log_post_ln_std)
    return prs