"""
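ViT loader with activation hooks for KFAC.

Loads a pretrained vision transformer (DeiT-Tiny from timm by default, or
ViT-B/16 from torchvision), freezes its weights, and can register forward and
backward hooks on its Linear layers to capture input activations and output
gradients.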
"""

import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
from typing import Dict, List, Tuple, Optional, Callable
from collections import OrderedDict
from .logging_utils import get_logger, log_tensor_stats, log_dict

# Module-level logger obtained from the project helper imported from .logging_utils.
logger = get_logger("ifc_vit.vit")


class ViTWithHooks(nn.Module):
    """
    Frozen vision-transformer wrapper that captures per-layer statistics for KFAC.

    By default this loads DeiT-Tiny (``deit_tiny_patch16_224``) from timm; pass
    ``use_tiny=False`` to load torchvision's ViT-B/16 instead. All parameters are
    frozen on construction (use ``enable_full_grads`` to unfreeze them).

    After ``register_kfac_hooks`` is called, hooks on ``nn.Linear`` layers record:
    - Input activations (a_l) in ``self.activations[name]``, for computing
      A_l = E[a_l a_l^T]
    - Output gradients (g_l) in ``self.gradients[name]``, for computing
      G_l = E[g_l g_l^T]

    Only the most recent forward/backward pass is kept per layer. Backward hooks
    only fire when a backward pass actually reaches a hooked layer, i.e. when the
    input or some parameters require grad.
    """

    def __init__(self, pretrained: bool = True, device: str = "cuda", use_tiny: bool = True):
        """
        Args:
            pretrained: Whether to load pretrained weights.
            device: Device the wrapped model is moved to.
            use_tiny: If True, load timm DeiT-Tiny; otherwise torchvision ViT-B/16.
        """
        super().__init__()

        self.use_tiny = use_tiny
        self.device = device

        if use_tiny:
            # Imported lazily so the torchvision path does not require timm.
            import timm
            model_name = "deit_tiny_patch16_224"
            logger.info(f"Loading {model_name} (pretrained={pretrained}, device={device})")
            self.model = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=1000,
            )
        else:
            logger.info(f"Loading ViT-B/16 (pretrained={pretrained}, device={device})")
            weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
            self.model = vit_b_16(weights=weights)

        self.model.to(device)
        self.model.eval()

        # Freeze all parameters
        for p in self.model.parameters():
            p.requires_grad = False

        # Hook outputs, keyed by module name as produced by named_modules()
        self.activations = {}
        self.gradients = {}
        self.hooks = []  # handles returned by register_*_hook, used by clear_hooks()
        self.hooked_layers = OrderedDict()

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        logger.info(f"Model loaded: {total_params:,} total params, {trainable_params:,} trainable")

    def register_kfac_hooks(self, layers_to_hook: Optional[List[str]] = None):
        """
        Register forward and backward hooks on Linear layers.

        Any previously registered hooks are removed first.

        Args:
            layers_to_hook: Module names (as produced by ``self.model.named_modules()``)
                to hook. A single string is treated as one name. If None, hooks all
                Linear layers. Requested names that are not Linear layers are
                reported with a warning.
        """
        self.clear_hooks()

        if layers_to_hook is None:
            wanted = None
        elif isinstance(layers_to_hook, str):
            wanted = {layers_to_hook}
        else:
            wanted = set(layers_to_hook)  # O(1) membership tests in the loop below

        linear_layer_count = 0
        for name, module in self.model.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            linear_layer_count += 1
            if wanted is not None and name not in wanted:
                continue

            self.hooked_layers[name] = module
            self.hooks.append(module.register_forward_hook(self._make_forward_hook(name)))
            self.hooks.append(module.register_full_backward_hook(self._make_backward_hook(name)))

        if wanted is not None:
            missing = sorted(wanted.difference(self.hooked_layers))
            if missing:
                logger.warning(f"Requested layers not found among Linear layers: {missing}")

        logger.info(f"Registered KFAC hooks on {len(self.hooked_layers)}/{linear_layer_count} Linear layers")
        logger.debug(f"Hooked layers: {list(self.hooked_layers.keys())}")

    def _make_forward_hook(self, layer_name: str) -> Callable:
        """Build a forward hook that stores the layer's detached input activation."""
        def hook(module, inputs, output):
            # inputs is the tuple of positional args to forward(); Linear takes one.
            self.activations[layer_name] = inputs[0].detach()
        return hook

    def _make_backward_hook(self, layer_name: str) -> Callable:
        """Build a full backward hook that stores the detached gradient w.r.t. the layer output."""
        def hook(module, grad_input, grad_output):
            grad = grad_output[0]
            # Entries of grad_output can be None if the output received no gradient.
            if grad is not None:
                self.gradients[layer_name] = grad.detach()
        return hook

    def clear_hooks(self):
        """Remove all registered hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        self.hooked_layers.clear()

    def clear_activations(self):
        """Clear stored activations and gradients."""
        self.activations.clear()
        self.gradients.clear()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the model."""
        return self.model(x)

    def get_penultimate_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the penultimate (pre-head) CLS embedding, shape (B, D).

        - timm DeiT/ViT: ``model.forward_features(x)``. Depending on the timm version
          this is either pooled (B, D) or the full token sequence (B, N, D); in the
          latter case the CLS token at index 0 is taken.
        - torchvision ViT: replicates ``VisionTransformer.forward`` up to (but
          excluding) ``heads``. torchvision's ``Encoder.forward`` already ends with its
          final LayerNorm (``encoder.ln``), so the CLS token of the encoder output is
          returned as-is. Applying ``encoder.ln`` again would double-normalize.

        Puts the model in eval mode and runs under ``torch.no_grad()``.
        """
        self.model.eval()
        with torch.no_grad():
            if self.use_tiny:
                # timm path (DeiT-Tiny)
                if not hasattr(self.model, "forward_features"):
                    raise RuntimeError("Expected timm model to have forward_features().")
                feats = self.model.forward_features(x)
                if feats.dim() == 3:
                    feats = feats[:, 0]
                return feats

            # torchvision path (ViT-B/16); _process_input is a private torchvision API.
            x = self.model._process_input(x)
            n = x.shape[0]
            cls = self.model.class_token.expand(n, -1, -1)
            x = torch.cat([cls, x], dim=1)
            x = self.model.encoder(x)  # output is already passed through encoder.ln
            return x[:, 0]

    def get_all_embeddings(self, dataloader, max_samples: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract penultimate embeddings for all samples in dataloader.

        Args:
            dataloader: Iterable of ``(images, labels)`` batches
            max_samples: Maximum number of samples to return (None means no limit;
                0 returns empty tensors after running a single batch)

        Returns:
            embeddings: (N, D) tensor of embeddings (on CPU)
            labels: (N,) tensor of labels

        Raises:
            ValueError: If max_samples is negative.
            RuntimeError: If the dataloader yields no batches.
        """
        if max_samples is not None and max_samples < 0:
            raise ValueError(f"max_samples must be non-negative, got {max_samples}")

        embeddings_list = []
        labels_list = []
        total = 0

        self.model.eval()
        with torch.no_grad():
            for images, labels in dataloader:
                emb = self.get_penultimate_features(images.to(self.device))
                embeddings_list.append(emb.cpu())
                labels_list.append(labels)

                total += len(labels)
                # `is not None` so that max_samples=0 is honored instead of meaning "no limit".
                if max_samples is not None and total >= max_samples:
                    break

        if not embeddings_list:
            raise RuntimeError("dataloader yielded no batches; cannot build embeddings")

        embeddings = torch.cat(embeddings_list, dim=0)
        labels = torch.cat(labels_list, dim=0)

        if max_samples is not None:
            # The last batch may overshoot the limit; trim it.
            embeddings = embeddings[:max_samples]
            labels = labels[:max_samples]

        return embeddings, labels

    def get_param_vector(self) -> torch.Tensor:
        """Flatten all parameters (detached) into a single vector in ``parameters()`` order."""
        return torch.cat([p.detach().reshape(-1) for p in self.model.parameters()])

    def get_grad_vector(self) -> torch.Tensor:
        """
        Flatten all parameter gradients into a single vector.

        Parameters whose ``.grad`` is None contribute zeros. The result is written into
        a cached buffer (default dtype, on the model's device) that is reused and
        overwritten on every call; clone it if you need to keep a snapshot.

        Returns:
            (num_params,) tensor of gradients in ``self.model.parameters()`` order.
        """
        total = self.num_params
        device = next(self.model.parameters()).device
        buf = getattr(self, "_flat_grad_buf", None)
        # Reallocate if the parameter count changed or the model moved to another device.
        if buf is None or buf.numel() != total or buf.device != device:
            buf = torch.empty(total, device=device)
            self._flat_grad_buf = buf

        offset = 0
        for p in self.model.parameters():
            n = p.numel()
            dst = buf[offset:offset + n]
            if p.grad is None:
                dst.zero_()
            else:
                # reshape (not view) tolerates non-contiguous gradients
                dst.copy_(p.grad.reshape(-1))
            offset += n
        return buf

    def set_param_grad_from_vector(self, v: torch.Tensor):
        """
        Set parameter gradients from a flattened vector.

        Inverse of ``get_grad_vector``: consecutive slices of ``v`` are copied into
        each parameter's ``.grad`` in ``self.model.parameters()`` order, allocating
        ``.grad`` when it is None. ``copy_`` casts to each gradient's dtype/device.

        Raises:
            ValueError: If ``v`` does not have exactly ``self.num_params`` elements.
        """
        flat = v.reshape(-1)
        expected = self.num_params
        if flat.numel() != expected:
            raise ValueError(f"Expected a vector with {expected} elements, got {flat.numel()}")

        offset = 0
        for p in self.model.parameters():
            numel = p.numel()
            if p.grad is None:
                p.grad = torch.zeros_like(p.data)
            p.grad.copy_(flat[offset:offset + numel].view_as(p))
            offset += numel

    def enable_full_grads(self, enabled: bool = True):
        """Set ``requires_grad`` on every model parameter (True unfreezes, False re-freezes)."""
        for p in self.model.parameters():
            p.requires_grad_(enabled)

    @property
    def num_params(self) -> int:
        """Total number of parameters."""
        return sum(p.numel() for p in self.model.parameters())

    def get_layer_sizes(self) -> Dict[str, int]:
        """Get the element count of each named parameter tensor, in ``named_parameters()`` order."""
        sizes = OrderedDict()
        for name, param in self.model.named_parameters():
            sizes[name] = param.numel()
        return sizes


def load_vit(pretrained: bool = True, device: str = "cuda", use_tiny: bool = True) -> ViTWithHooks:
    """
    Load a frozen ViT wrapped in ViTWithHooks for KFAC computation.

    Args:
        pretrained: Whether to load pretrained weights
        device: Device to load model on
        use_tiny: If True (default), load timm DeiT-Tiny (deit_tiny_patch16_224);
            otherwise load torchvision ViT-B/16

    Returns:
        ViTWithHooks model instance. Hooks are not registered yet; call
        ``register_kfac_hooks()`` on the returned model to capture activations.
    """
    model = ViTWithHooks(pretrained=pretrained, device=device, use_tiny=use_tiny)
    logger.info(f"Loaded {'DeiT-Tiny' if use_tiny else 'ViT-B/16'} with {model.num_params:,} parameters")
    return model


if __name__ == "__main__":
    # Smoke test: load the model, hook every Linear layer, and run a forward pass.
    # Fall back to CPU so the script also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_vit(device=device)
    model.register_kfac_hooks()

    # Test forward pass
    x = torch.randn(4, 3, 224, 224, device=device)
    out = model(x)
    print(f"Output shape: {out.shape}")

    # Test penultimate features
    emb = model.get_penultimate_features(x)
    print(f"Embedding shape: {emb.shape}")

    print(f"Captured activations for {len(model.activations)} layers")
