import argparse
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import wandb
from sklearn.manifold import TSNE
from torch.utils.data import TensorDataset
from tqdm import tqdm

import models_simmim
import models_vit
import models_capi
import models_more
import util.misc as misc
from util.datasets import build_dataset_v2
from util.misc import AMP_PRECISIONS
from util.pos_embed import interpolate_pos_embed
from models_vit import CLS_FT_CHOICES

from poolings.abmilp import ABMILPHead
from poolings.simpool import SimPool, SimPool_nolinears
from poolings.clip.attention_pool import AttentionPoolLatent
from poolings.clip.attention_pool2d import AttentionPool2d
from poolings.jepa.attentive_pooler import AttentivePooler
from poolings.aim import AttentionPoolingClassifier
from poolings.cbam import CbamPooling
from poolings.coca_pytorch import CrossAttention as CocaPooling
from poolings.other_pool import CAPooling, DinoViTBlockPooling
from poolings.dolg.dolg import SpatialAttention2d
from poolings.cae_att import CAEAttentiveBlock
from poolings.ep import EfficientProbing
from poolings.mhca import MHCA

from models_simmim import VisionTransformerSimMIM
from models_vit import VisionTransformer

import open_clip
from util.DiT.download import find_model as find_model_dit
from util.SiT.download import find_model as find_model_sit
from util.DiT.models import DiT_models
from util.SiT.models import SiT_models
from aim.v2.utils import load_pretrained
from diffusers.models import AutoencoderKL

from models_capi import enable_attention_capture
from util.patches.dino_attn_capture import enable_dino_attention_capture, get_all_block_attentions
from util.patches.clip_attn_capture import enable_clip_attention_capture, get_all_block_attentions_clip
from util.patches.timm_attn_capture import enable_timm_vit_attention_capture, get_all_block_attentions_timm

# os.environ["WANDB_ENTITY"] = 
# os.environ["WANDB_PROJECT"] = 
# os.environ["WANDB_API_KEY"] = 

def get_args_parser():
    """Build the argument parser for the attention-statistics script.

    The parser is created with ``add_help=False`` (presumably so it can also be
    reused as a parent parser); the ``__main__`` block calls ``parse_args`` on it
    directly.

    Returns:
        argparse.ArgumentParser: parser with model / backbone-specific options,
        pooling-head options, dataset options and runtime options.
    """
    parser = argparse.ArgumentParser('MAE attention statistics', add_help=False)
    parser.add_argument('--batch_size', default=512, type=int,
                        help='Batch size of the validation data loader')

    # Model parameters
    parser.add_argument('--model', default='vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of the model to analyze')
    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')
    parser.add_argument("--simmim", action="store_true", default=False)
    parser.add_argument("--openclip", action="store_true", default=False)
    parser.add_argument('--openclip_pretrain', default='openai', type=str, metavar='PRETRAIN',
                        help='Name of pretrain framework for openclip')
    parser.add_argument('--comp', default='mean', help='how to compute complementarity score', choices=['mean', 'max', 'mean_all', 'max_all'])

    # DINOv3 argument
    parser.add_argument('--dinov3_weights', type=str, metavar='DINOV3_WEIGHTS',
                        help='url or path to DINOv3 weights')

    # Franca argument
    parser.add_argument("--use_rasa_head", action="store_true", default=False,
                        help="Use debiased patch tokens from RASA head (only for Franca models)")

    # DiT/SiT arguments
    parser.add_argument("--dit_image_size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--dit_ckpt", type=str, default=None,
                        help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse")
    parser.add_argument("--cls_features",
                        choices=CLS_FT_CHOICES,
                        default="cls", help="cls token / positional tokens for classification")

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument("--checkpoint_key", default="model", type=str)
    parser.add_argument("--cca_bias", default="none")

    # Pre-trained pooling head
    parser.add_argument('--pretrained_head', default=None, help='load attentive pooling head from checkpoint')

    parser.set_defaults(global_pool=False)

    # EP
    parser.add_argument("--ep_queries", type=int, default=32, help="number of EfficientProbing queries")
    parser.add_argument("--d_out", type=int, default=1, help="Denominator of classifier dimensionality")
    # Other poolings
    parser.add_argument("--num_heads", type=int, default=16, help="number of other pooling methods heads")

    # Dataset parameters
    parser.add_argument('--dataset_name', default='imagenet1k', type=str,
                        help='dataset name')
    parser.add_argument('--data_path', default='/pfs/lustrep2/scratch/project_465001765/datasets/imagenet', type=Path,
                        help='dataset path')
    parser.add_argument('--nb_classes', default=1000, type=int,
                        help='number of the classification types')

    parser.add_argument('--output_dir', default=None,
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)

    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    # --pin_mem / --no_pin_mem toggle the same destination; pinning is on by default
    parser.set_defaults(pin_mem=True)

    parser.add_argument("--draw_2d_embeddings", action="store_true", default=False)
    parser.add_argument("--amp", default="float16", choices=list(AMP_PRECISIONS.keys()), type=str)

    return parser


def main(args):
    """Build the model selected by ``args`` and log its attention statistics to W&B.

    Steps: build the validation loader, instantiate the backbone wrapper,
    optionally load backbone weights from ``--finetune``, attach a
    classification head (optionally restored from ``--pretrained_head``) and,
    when a W&B run is active, log either raw backbone attention statistics
    (no pretrained head) or statistics of the pretrained pooling head.

    Nothing is logged when ``wandb.run`` is ``None``. Note that
    ``misc.maybe_setup_wandb`` is only called when ``--output_dir`` is given.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``; it is
            mutated (``dino_aug``, ``distributed``, ``gpu``).
    """
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    args.dino_aug = False  # hack: flag not exposed by the parser; presumably read by build_dataset_v2
    _, dataset_val = build_dataset_v2(args, is_pretrain=False)
    print(dataset_val)

    args.distributed = False
    args.gpu = 0

    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if args.output_dir is not None:
        misc.maybe_setup_wandb(args.output_dir, args=args, job_type="attn_stats")

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    model = _build_model(args, device)

    # hub / diffusion backbones come with their own weights
    if args.finetune and not args.simmim and not args.model.startswith(
            ("capi", "dinov2", "dinov3", "aimv2", "franca", "DiT", "SiT")):
        _load_finetune_checkpoint(model, args)

    _attach_classification_head(model, args)

    if args.pretrained_head is not None:
        ckpt_head = torch.load(args.pretrained_head, map_location='cpu')
        state = ckpt_head.get("model", None) or ckpt_head.get("state_dict", None) or ckpt_head
        msg = model.head.load_state_dict(state, strict=True)
        print(f"\nLoaded attentive pooling head from {args.pretrained_head}:\n{msg}")

    model.to(device)

    if wandb.run is None:
        return

    if args.pretrained_head is None:
        _log_backbone_attention_stats(model, data_loader_val, device, args)
    else:
        _log_pooled_head_stats(model, data_loader_val, device, args)


def _build_model(args, device):
    """Instantiate the backbone wrapper selected by ``args`` (hub backbone, DiT/SiT, OpenCLIP, SimMIM or local ViT)."""
    size_patch_kwargs = dict()
    if args.input_size != 224:
        # non-default resolutions keep a 16x16 patch grid by scaling the patch size
        assert args.input_size % 16 == 0, args.input_size
        size_patch_kwargs = dict(
            img_size=args.input_size,
            patch_size=args.input_size // 16
        )

    if args.model.startswith("capi"):
        capi_backbone = torch.hub.load('facebookresearch/capi:main', args.model)
        if args.amp == "bfloat16":
            capi_backbone.to(dtype=torch.bfloat16)
        # patch the model to capture attention
        enable_attention_capture(capi_backbone)
        return models_capi.CapiWrapper(
            capi_model=capi_backbone,
            num_classes=args.nb_classes,
            features=args.cls_features
        )

    if args.model.startswith("dinov2"):
        dinov2_backbone = torch.hub.load('facebookresearch/dinov2', args.model)
        enable_dino_attention_capture(dinov2_backbone)
        return models_more.DinoWrapper(
            dino_model=dinov2_backbone,
            num_classes=args.nb_classes,
            features=args.cls_features
        )

    if args.model.startswith("dinov3"):
        dinov3_backbone = torch.hub.load('facebookresearch/dinov3', args.model, weights=args.dinov3_weights)
        enable_dino_attention_capture(dinov3_backbone)
        return models_more.DinoWrapper(
            dino_model=dinov3_backbone,
            num_classes=args.nb_classes,
            features=args.cls_features
        )

    if args.model.startswith("aimv2"):
        aimv2_backbone = load_pretrained(args.model, backend="torch")
        return models_more.AIMv2Wrapper(
            aimv2_model=aimv2_backbone,
            num_classes=args.nb_classes,
            features=args.cls_features
        )

    if args.model.startswith("franca"):
        # The RASA head is always loaded here; whether its debiased tokens are
        # used is passed to the wrapper via ``use_rasa_head``.
        franca_backbone = torch.hub.load('valeoai/Franca', args.model, use_rasa_head=True)
        enable_dino_attention_capture(franca_backbone)
        return models_more.FrancaWrapper(
            franca_model=franca_backbone,
            num_classes=args.nb_classes,
            features=args.cls_features,
            use_rasa_head=args.use_rasa_head
        )

    if args.model.startswith("DiT"):
        if args.dit_ckpt is None:
            assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download."
            assert args.dit_image_size in [256, 512]
            assert args.nb_classes == 1000

        latent_size = args.dit_image_size // 8  # the VAE downsamples images by 8
        dit_backbone = DiT_models[args.model](
            input_size=latent_size,
            num_classes=args.nb_classes
        ).to(device)
        dit_ckpt_path = args.dit_ckpt or f"DiT-XL-2-{args.dit_image_size}x{args.dit_image_size}.pt"
        state_dict = find_model_dit(dit_ckpt_path)
        dit_backbone.load_state_dict(state_dict)

        vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device).eval()

        # ``--finetuning`` is not defined by get_args_parser(): reading it
        # unconditionally raised AttributeError. Forward it only when the caller
        # set it, otherwise rely on DiTWrapper's default (as the SiT branch does).
        dit_extra_kwargs = {"finetuning": args.finetuning} if hasattr(args, "finetuning") else {}
        return models_more.DiTWrapper(
            dit_model=dit_backbone,
            vae_model=vae,
            num_classes=args.nb_classes,
            features=args.cls_features,
            **dit_extra_kwargs
        )

    if args.model.startswith("SiT"):
        if args.dit_ckpt is None:
            assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
            assert args.dit_image_size == 256, "512x512 models are not yet available for auto-download."
            assert args.nb_classes == 1000

        latent_size = args.dit_image_size // 8
        sit_backbone = SiT_models[args.model](
            input_size=latent_size,
            num_classes=args.nb_classes
        ).to(device)
        dit_ckpt_path = args.dit_ckpt or f"SiT-XL-2-{args.dit_image_size}x{args.dit_image_size}.pt"
        state_dict = find_model_sit(dit_ckpt_path)
        sit_backbone.load_state_dict(state_dict)

        vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device).eval()

        return models_more.DiTWrapper(
            dit_model=sit_backbone,
            vae_model=vae,
            num_classes=args.nb_classes,
            features=args.cls_features
        )

    if args.openclip:
        backbone, _, _ = open_clip.create_model_and_transforms(args.model, pretrained=args.openclip_pretrain)
        vision_encoder = backbone.visual

        # timm-based visual towers keep their ViT in ``.trunk``
        if 'timm' in vision_encoder.__class__.__module__:
            enable_timm_vit_attention_capture(vision_encoder.trunk)
        else:
            enable_clip_attention_capture(vision_encoder)

        return models_more.CLIPWrapper(
            clip_model=vision_encoder,
            num_classes=args.nb_classes,
            features=args.cls_features
        )

    if args.simmim:
        return models_simmim.__dict__[args.model](
            checkpoint_path=args.finetune
        )

    model: models_vit.VisionTransformer = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        **size_patch_kwargs
    )
    return model


def _load_finetune_checkpoint(model, args):
    """Load backbone weights named by ``args.finetune`` (file path, ``hub*`` key or timm model name) into ``model``."""
    # architecture kwargs for timm's ViT builder when ``--finetune`` names a timm model
    model_to_kwargs = {
        "vit_tiny_patch16": dict(patch_size=16, embed_dim=192, depth=12, num_heads=12),
        "vit_small_patch16": dict(patch_size=16, embed_dim=384, depth=12, num_heads=12),
        "vit_base_patch16": dict(patch_size=16, embed_dim=768, depth=12, num_heads=12),
        "vit_large_patch16": dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16),
        "vit_huge_patch14": dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16),
    }

    if Path(args.finetune).exists():
        print("Interpreting", args.finetune, "as path")
        checkpoint_model = torch.load(args.finetune, map_location='cpu')[args.checkpoint_key]
    elif args.finetune.startswith("hub"):
        state_dict = torch.hub.load_state_dict_from_url(
            url=models_vit.HUB_KEY_TO_URL[args.finetune],
        )
        state_dict = state_dict['model']
        # Drop decoder / mask-token weights. Deleting in place (instead of
        # rebuilding a dict) keeps the loaded state-dict object itself.
        for k in list(state_dict.keys()):
            if k.startswith('decoder') or k.startswith('mask_token'):
                del state_dict[k]
        checkpoint_model = state_dict
    else:
        print("Interpreting", args.finetune, "as timm model")
        from timm.models.vision_transformer import _create_vision_transformer

        model_kwargs = model_to_kwargs[args.model]
        checkpoint_model = _create_vision_transformer(args.finetune, pretrained=True, **model_kwargs).state_dict()

    print("Load pre-trained checkpoint from: %s" % args.finetune)
    state_dict = model.state_dict()
    # a classifier trained for a different number of classes cannot be loaded
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    # interpolate position embedding
    interpolate_pos_embed(model, checkpoint_model)

    # load pre-trained model
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)

    # every transformer block must have been restored from the checkpoint
    assert not any(k.startswith("blocks") for k in msg.missing_keys)


def _attach_classification_head(model, args):
    """Replace ``model.head`` by a BatchNorm + linear probe, preceded by EfficientProbing for ``ep``/``ep_all``."""
    if args.cls_features == "ep" or args.cls_features == "ep_all":
        ep = EfficientProbing(dim=model.head.in_features, num_queries=args.ep_queries, d_out=args.d_out)
        new_classifier = torch.nn.Linear(model.head.in_features // args.d_out, args.nb_classes, bias=True)
        model.head = torch.nn.Sequential(
            ep,
            torch.nn.BatchNorm1d(model.head.in_features // args.d_out, affine=False, eps=1e-6),
            new_classifier
        )
    elif args.cls_features in ["cls", "pos"]:
        # keep the existing linear head, normalising its input first
        model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)
    else:
        raise NotImplementedError()


def _log_backbone_attention_stats(model, data_loader_val, device, args):
    """Log per-block backbone attention statistics and head complementarity to W&B."""
    with torch.cuda.amp.autocast(
            enabled=args.amp != "none",
            dtype=AMP_PRECISIONS[args.amp]
    ):
        _, _, A_test, complementarity = collect_features(
            model, data_loader_val, device,
            complementarity_reduce=args.comp,
            tqdm_desc="attention stats",
        )

    # A_test is [N, L, H, 8]: average over samples (dim 0) and heads (dim 2) -> [L, 8]
    mean_attn_stats = A_test.mean(dim=(0, 2))

    # metric name for each of the 8 columns, in the order collect_features produces them
    stat_names = (
        "cls_cls_attention",               # [CLS] -> [CLS]; high = CLS is "self-focused"
        "pos_self_attention",              # patch -> itself, averaged over patches; high = little token mixing
        "cls_cls_attention_adj_for_cls",   # as cls_cls_attention, divided by the attention mass on non-[CLS] keys
        "pos_self_attention_adj_for_cls",  # as pos_self_attention, divided by the attention mass on non-[CLS] keys
        "cls_pos_attention",               # [CLS] -> patch, averaged over patches; how much CLS reads the image
        "pos_cls_attention",               # patch -> [CLS], averaged over patches
        "cls_pos_entropy",                 # entropy of [CLS]'s attention over patches; high = broad aggregation
        "pos_pos_entropy",                 # entropy of each patch's attention over patches; high = global mixing
    )

    stats_pf = "test_attn"
    for b in range(len(mean_attn_stats)):
        record = {f"{stats_pf}/{name}": mean_attn_stats[b, col] for col, name in enumerate(stat_names)}
        record[f"{stats_pf}/vit_block"] = b
        wandb.log(record)

    stats_pf = "test_complementarity"
    if "all" in args.comp:
        per_block = complementarity.mean(dim=0)  # [L]
        for b in range(len(per_block)):
            # The block index is logged as a metric rather than via ``step=b``.
            # W&B ignores logs whose step is lower than the current step, and the
            # loop above has already advanced it.
            wandb.log({
                f"{stats_pf}/attn_complementarity": per_block[b].item(),
                f"{stats_pf}/vit_block": b,
            })
    else:
        wandb.log({f"{stats_pf}/attn_complementarity": complementarity.mean().item()})


def _log_pooled_head_stats(model, data_loader_val, device, args):
    """Log entropy and head-complementarity statistics of the pretrained pooling head to W&B."""
    assert args.pretrained_head and os.path.isfile(args.pretrained_head), f"Checkpoint file not found: {args.pretrained_head}"
    with torch.cuda.amp.autocast(
            enabled=args.amp != "none",
            dtype=AMP_PRECISIONS[args.amp]
    ):
        _, _, v_entropy, backbone_entropy, complementarity = collect_features_pooled(
            model, data_loader_val, device,
            complementarity_reduce=args.comp,
            tqdm_desc="attention stats",
        )

    stats_pf = "test_attn_pooled"
    cls_v_entropy = v_entropy.mean().item()
    cls_z_entropy = backbone_entropy.mean().item()
    wandb.log({f"{stats_pf}/cls_v_entropy": cls_v_entropy})
    wandb.log({f"{stats_pf}/cls_z_entropy": cls_z_entropy})

    stats_pf = "test_complementarity"
    complementarity = complementarity.mean().item()
    wandb.log({f"{stats_pf}/attn_complementarity": complementarity})

def calculate_attn_stuff(attn):
    """Compute per-token attention statistics for one transformer block.

    Args:
        attn: 4-D attention probabilities indexed ``[B, H, query, key]``, with
            token 0 treated as [CLS].

    Returns:
        Detached tensor of shape ``[5, B, 1, H, T]``. For each query token ``t``,
        the five channels are:
          0: ``attn[t, t]`` (self attention)
          1: ``attn[t, t] / (sum_{k>=1} attn[t, k] + 1e-6)`` (self attention
             rescaled by the mass on non-[CLS] keys)
          2: ``attn[0, t]`` ([CLS] -> token t)
          3: ``attn[t, 0]`` (token t -> [CLS])
          4: entropy of row ``t`` restricted to keys ``1..T-1`` after renormalisation
        The singleton dim 2 is the block axis that callers concatenate over.
    """
    _, _, _, n_tokens = attn.shape
    idx = torch.arange(n_tokens)

    # attention mass each query puts on the non-[CLS] keys (+ epsilon), [B, H, T, 1]
    non_cls_mass = attn[..., 1:].sum(dim=3, keepdim=True) + 1e-6

    self_attn = attn[..., idx, idx]
    self_attn_rescaled = (attn / non_cls_mass)[..., idx, idx]
    from_cls = attn[:, :, 0, ]
    to_cls = attn[:, :, :, 0]

    probs = attn[..., 1:] / non_cls_mass
    entropy = -(probs * (probs + 1e-6).log()).sum(dim=3)

    channels = [self_attn, self_attn_rescaled, from_cls, to_cls, entropy]
    return torch.stack(channels).unsqueeze(2).detach()

def _stack_block_attn_stats(block_attns) -> torch.Tensor:
    """Run ``calculate_attn_stuff`` on each block's attention map and concatenate along the block axis (dim 2)."""
    return torch.cat([calculate_attn_stuff(block_attn) for block_attn in block_attns], dim=2)


def _complementarity_from_attn_stats(attns: torch.Tensor, reduce: str) -> torch.Tensor:
    """Head complementarity of the [CLS] -> patch attention stored in channel 2 of ``attns``.

    Returns ``[B]`` (last block only) or, for the ``*_all`` reductions, ``[B, L]``
    with one column per block.
    """
    # channel 2 is the [CLS] query row attn[:, :, 0, :]; drop the [CLS] key itself
    cls_to_patch = attns[2, :, :, :, 1:]  # [B, L, H, T-1]
    if "all" in reduce:
        return torch.stack(
            [cls_head_complementarity(cls_to_patch[:, b], reduce=reduce) for b in range(cls_to_patch.shape[1])],
            dim=1,
        )
    return cls_head_complementarity(cls_to_patch[:, -1], reduce=reduce)


def _summarize_attn_stats(attns: torch.Tensor) -> torch.Tensor:
    """Reduce per-token stats ``[5, B, L, H, T]`` to 8 per-head summaries ``[B, L, H, 8]``.

    Column order (consumed positionally by ``main``):
      0 [CLS]->[CLS], 1 mean patch self-attention,
      2 [CLS]->[CLS] (adjusted), 3 mean patch self-attention (adjusted),
      4 mean [CLS]->patch, 5 mean patch->[CLS],
      6 [CLS] entropy over patches, 7 mean patch entropy over patches.
    """
    def cls_entry(channel):
        return attns[channel, :, :, :, :1]

    def patch_mean(channel):
        return attns[channel, :, :, :, 1:].mean(dim=3, keepdim=True)

    return torch.cat([
        cls_entry(0), patch_mean(0),
        cls_entry(1), patch_mean(1),
        patch_mean(2), patch_mean(3),
        cls_entry(4), patch_mean(4),
    ], dim=3)


def collect_features(
        model: models_vit.VisionTransformer, loader: torch.utils.data.DataLoader,
        device,
        complementarity_reduce: str = 'mean',
    tqdm_desc: str = None
):
    """Run the backbone over ``loader`` and gather features, labels and attention statistics.

    NOTE: reads the module-level ``args`` namespace (``amp``, ``model``,
    ``openclip``, ``cls_features``), which is only bound when this file runs as a
    script.

    Per-block attention is turned into a ``[5, B, L, H, T]`` stats tensor via
    ``calculate_attn_stuff``. For the ``model.forward_features`` path, the
    returned ``attns`` is assumed to follow the same layout — TODO confirm in
    models_vit.

    Args:
        model: wrapped backbone.
        loader: yields ``(images, labels)`` batches.
        device: device the images are moved to.
        complementarity_reduce: "mean"/"max" (last block) or "mean_all"/"max_all"
            (every block), forwarded to ``cls_head_complementarity``.
        tqdm_desc: progress-bar label.

    Returns:
        features: backbone features ``z`` concatenated over the dataset.
        labels: int64 labels.
        attn_stats: ``[N, L, H, 8]`` summaries (see ``_summarize_attn_stats``).
        complementarity: ``[N]`` for "mean"/"max", ``[N, L]`` for the ``*_all``
            reductions — now for every model family, not only the ViT path.
    """
    model.eval()
    features, labels, attns_list, complementarity_scores = [], [], [], []
    with torch.no_grad():
        for data, target in tqdm(loader, desc=tqdm_desc):
            images = data.to(device)
            with torch.cuda.amp.autocast(
                    enabled=args.amp != "none",
                    dtype=AMP_PRECISIONS[args.amp]
            ):
                # Extract features and per-block attention statistics from the backbone
                if args.model.startswith(("capi", "dinov2", "dinov3", "franca")):
                    _, z = model(images, return_backbone_features=True)
                    if args.model.startswith("capi"):
                        block_attns = [blk.residual1.fn._last_attn for blk in model.capi_model.encoder.blocks]
                    elif args.model.startswith("franca"):
                        block_attns = get_all_block_attentions(model.franca_model)
                    else:
                        block_attns = get_all_block_attentions(model.dino_model)
                    attns = _stack_block_attn_stats(block_attns)
                elif args.openclip:
                    _, z = model(images, return_backbone_features=True)
                    if model.is_timm:
                        block_attns = get_all_block_attentions_timm(model.clip_model.trunk)
                    else:
                        block_attns = get_all_block_attentions_clip(model.clip_model)
                    attns = _stack_block_attn_stats(block_attns)
                else:
                    z, attns, _magnitudes = model.forward_features(images, return_features=args.cls_features)

            # Same source for every model family, so the "*_all" reductions yield
            # [B, L] everywhere (previously capi/dino/franca/openclip returned [B],
            # which made main() crash).
            complementarity_scores.append(
                _complementarity_from_attn_stats(attns, complementarity_reduce).detach().cpu()
            )
            attns_list.append(_summarize_attn_stats(attns).detach().cpu())
            features.append(z.detach().cpu())
            labels.append(target.detach().short().cpu())

    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0).long()
    complementarity_scores = torch.cat(complementarity_scores, dim=0)
    attns_list = torch.cat(attns_list, dim=0)

    return features, labels, attns_list, complementarity_scores

def calc_entropy(cls_v: torch.Tensor) -> torch.Tensor:
    """Return the entropy of ``softmax(cls_v)`` over the last dimension.

    A ``1e-6`` offset inside the logarithm keeps zero probabilities finite.

    Args:
        cls_v: logits; the distribution is taken over the last dimension.

    Returns:
        Tensor with the last dimension reduced away (``[B]`` for ``[B, N]`` input).
    """
    probs = cls_v.softmax(dim=-1)
    log_probs = torch.log(probs + 1e-6)
    return (probs * log_probs).sum(dim=-1).neg()

def cls_head_complementarity(
    cls_attn: torch.Tensor,
    eps: float = 1e-8,
    reduce: str = "mean",
):
    """Score how differently the attention heads distribute [CLS] -> patch attention.

    Each head's row is turned into a distribution over tokens (divided by its
    sum plus ``eps``) and L2-normalised. It is then compared with every *other*
    head by cosine similarity. The score is ``1 - mean`` ("mean"/"mean_all") or
    ``1 - max`` ("max"/"max_all") of those off-diagonal similarities. A higher
    score means more diverse, complementary heads.

    Args:
        cls_attn: ``[B, H, N]`` attention, e.g. ``attn[:, :, 0, 1:]`` of one block.
        eps: added to each head's sum before normalising.
        reduce: one of "mean", "mean_all", "max", "max_all".

    Returns:
        ``[B]`` tensor with one score per sample.

    Raises:
        ValueError: for an unknown ``reduce``.
    """
    assert cls_attn.ndim == 3, f"expected [B,H,N], got {cls_attn.shape}"
    batch, heads, _ = cls_attn.shape

    dist = cls_attn / (cls_attn.sum(dim=-1, keepdim=True) + eps)
    unit = F.normalize(dist, dim=-1)
    cosine = unit @ unit.transpose(1, 2)  # [B, H, H]
    off_diagonal = ~torch.eye(heads, dtype=torch.bool, device=cosine.device)

    if reduce in ("mean", "mean_all"):
        use_max = False
    elif reduce in ("max", "max_all"):
        use_max = True
    else:
        raise ValueError(f"reduce={reduce} not implemented")

    pair_sims = cosine[:, off_diagonal].view(batch, -1)  # [B, H*(H-1)]
    closest = pair_sims.max(dim=1).values if use_max else pair_sims.mean(dim=1)
    return 1.0 - closest

def _squeeze_except_batch(t: torch.Tensor) -> torch.Tensor:
    """Like ``t.squeeze()``, but never drops dim 0, so a final batch of one sample keeps its batch axis."""
    return t.reshape(t.shape[0], *[size for size in t.shape[1:] if size != 1])


def collect_features_pooled(
        model: models_vit.VisionTransformer, loader: torch.utils.data.DataLoader,
        device,
        complementarity_reduce: str = 'mean',
    tqdm_desc: str = None
):
    """Collect features and pooling-head statistics of a model whose ``head[0]`` is a pretrained pooling module.

    NOTE: reads the module-level ``args`` namespace (``amp``, ``cls_features``),
    which is only bound when this file runs as a script.

    For each batch it computes:
      * head complementarity of ``model.head[0].get_attention`` over the backbone
        tokens (a leading extra token is dropped when present);
      * entropy of ``softmax(cls @ z_backbone^T)`` (query vs. backbone tokens);
      * entropy of ``softmax(cls_v)`` exposed by the pooling module.

    Singleton dims are squeezed while keeping the batch axis. A plain
    ``.squeeze()`` also removed it when the last batch held a single sample
    (``drop_last=False``), which broke the ``[B, H, N]`` check and ``torch.cat``.

    Returns:
        features, int64 labels, v_entropy ``[N]``, backbone_entropy ``[N]``,
        complementarity ``[N]``.
    """
    model.eval()
    features, labels = [], []
    v_entropy, backbone_entropy, complementarity_scores = [], [], []
    with torch.no_grad():
        for data, target in tqdm(loader, desc=tqdm_desc):
            images = data.to(device)
            with torch.cuda.amp.autocast(
                    enabled=args.amp != "none",
                    dtype=AMP_PRECISIONS[args.amp]
            ):
                if isinstance(model, (VisionTransformer, VisionTransformerSimMIM)):
                    z, z_backbone = model.forward(images, return_features=args.cls_features, return_backbone_features=True)
                else:
                    z, z_backbone = model.forward(images, return_backbone_features=True)

                pooling = model.head[0]

                ### Calculate complementarity score ###
                attn = pooling.get_attention
                if attn.shape[-1] != z_backbone.shape[-2]:
                    # attention covers one extra leading token (presumably [CLS]) absent from z_backbone
                    assert attn.shape[-1] == z_backbone.shape[-2] + 1, f"attention: {attn.shape}, features: {z_backbone.shape}"
                    attn = attn[:, :, :, 1:]
                complementarity_score = cls_head_complementarity(
                    _squeeze_except_batch(attn), reduce=complementarity_reduce
                )

                ### Calculate entropy of cls_v and cls_z ###
                ep_cls = pooling.cls
                # ep_cls @ patch token features from the backbone model
                cls_zbackbone = _squeeze_except_batch(ep_cls.unsqueeze(1) @ z_backbone.transpose(-2, -1))
                # ep_cls @ features from the pooling module (v)
                cls_v = pooling.cls_v
                cls_v_entropy = calc_entropy(cls_v)
                cls_z_entropy = calc_entropy(cls_zbackbone)

            complementarity_scores.append(complementarity_score.detach().cpu())
            v_entropy.append(cls_v_entropy.detach().cpu())
            backbone_entropy.append(cls_z_entropy.detach().cpu())
            features.append(z.detach().cpu())
            labels.append(target.detach().short().cpu())

    complementarity_scores = torch.cat(complementarity_scores, dim=0)
    v_entropy = torch.cat(v_entropy, dim=0)
    backbone_entropy = torch.cat(backbone_entropy, dim=0)
    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0).long()

    return features, labels, v_entropy, backbone_entropy, complementarity_scores

if __name__ == '__main__':
    # ``args`` must remain a module-level global: collect_features() and
    # collect_features_pooled() read it directly instead of taking it as a parameter.
    args = get_args_parser().parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
