import numpy as np
import torch
from torch.nn.functional import normalize
import h5py
from typing import Any, Dict, Tuple

def clip_base_embedding(f: h5py.File, start_idx: int, end_idx: int, device: torch.device, args: Any) -> torch.Tensor:
    """
    Return the precomputed CLIP image embeddings stored for a range of rows.

    Args:
        f: Open HDF5 file exposing an ``"image_embeddings"`` dataset
        start_idx: First row to read (inclusive)
        end_idx: Row at which reading stops (exclusive, standard slice semantics)
        device: Unused here (all extractors in this module share this signature)
        args: Unused here (all extractors in this module share this signature)

    Returns:
        torch.Tensor: The selected rows, wrapped with ``torch.from_numpy``, so the
        tensor shares memory with the array produced by slicing the dataset
    """
    rows = slice(start_idx, end_idx)
    stored = f["image_embeddings"]
    return torch.from_numpy(stored[rows])

def direct_effect_embedding(f: h5py.File, start_idx: int, end_idx: int, device: torch.device, args: Any) -> torch.Tensor:
    """
    Build "direct effect" image embeddings by adding the stored attention and
    MLP contributions for a range of rows.

    Each dataset slice is summed over axis 1, and the two sums are added.

    NOTE(review): ``text_based_decomposition_embedding`` indexes ``"attentions"``
    as 4-D (sample, layer, head, dim) and sums axes (1, 2). Summing only axis 1
    here presumably expects a file whose heads are already collapsed (3-D
    attentions). With 4-D attentions, the addition below would broadcast
    incorrectly or fail. Confirm against the file that produces ``f``.

    Args:
        f: Open HDF5 file exposing ``"attentions"`` and ``"mlps"`` datasets
        start_idx: First row to read (inclusive)
        end_idx: Row at which reading stops (exclusive)
        device: Unused here (all extractors in this module share this signature)
        args: Unused here (all extractors in this module share this signature)

    Returns:
        torch.Tensor: ``attentions[start_idx:end_idx].sum(axis=1)
        + mlps[start_idx:end_idx].sum(axis=1)``
    """
    window = slice(start_idx, end_idx)
    attention_total = f["attentions"][window].sum(axis=1)
    mlp_total = f["mlps"][window].sum(axis=1)
    return torch.from_numpy(attention_total + mlp_total)

# Hand-picked attention heads, as (layer, head) pairs, that
# ``text_based_decomposition_embedding`` mean-ablates for each supported model.
# Each entry is (geo_heads, setting_heads). The lists are kept verbatim,
# duplicates included (ViT-L-14 lists (22, 12) in both); the function dedupes.
_HeadList = Tuple[Tuple[int, int], ...]

# ViT-B-16 and ViT-B-32 use the same selection.
_VIT_B_ABLATIONS: Tuple[_HeadList, _HeadList] = (
    ((11, 6), (11, 0)),
    ((11, 3), (10, 11), (10, 10), (9, 8), (9, 6)),
)

_TEXT_DECOMPOSITION_ABLATIONS: Dict[str, Tuple[_HeadList, _HeadList]] = {
    "ViT-H-14": (
        ((31, 8), (30, 15), (30, 12), (30, 6), (29, 14), (29, 8)),
        ((31, 12), (30, 11), (29, 4)),
    ),
    "ViT-L-14": (
        ((21, 1), (22, 12), (22, 13), (21, 11), (21, 14), (23, 6)),
        (
            (21, 3),
            (21, 6),
            (21, 8),
            (21, 13),
            (22, 2),
            (22, 12),
            (22, 15),
            (23, 1),
            (23, 3),
            (23, 5),
        ),
    ),
    "ViT-B-16": _VIT_B_ABLATIONS,
    "ViT-B-32": _VIT_B_ABLATIONS,
}

# Number of final layers whose attention heads keep their per-sample values
# (except the hand-picked heads above). Every head in earlier layers is ablated.
_KEPT_FINAL_LAYERS = 4


def text_based_decomposition_embedding(f0: h5py.File, start_idx: int, end_idx: int, device: torch.device, args: Any) -> torch.Tensor:
    """
    Extract image embeddings using the text-based decomposition method, which
    mean-ablates model-specific attention heads and layers.

    The following are replaced by their mean over *all* rows of the dataset:
    the hand-picked heads for ``args.model``, every head outside the last
    four layers, and every MLP layer. The remaining components keep their
    per-sample values. The ablated attention and MLP outputs are then summed
    into one vector per sample. The means are recomputed from the full
    datasets on every call.

    Args:
        f0: HDF5 file with ``"attentions"`` (indexed as sample, layer, head,
            dim) and ``"mlps"`` (indexed as sample, layer, ...) datasets
        start_idx: First row to extract (inclusive)
        end_idx: Row at which extraction stops (exclusive)
        device: Unused here (all extractors in this module share this signature)
        args: Object whose ``model`` attribute names the architecture
            (``"ViT-H-14"``, ``"ViT-L-14"``, ``"ViT-B-16"`` or ``"ViT-B-32"``)

    Returns:
        torch.Tensor: ``attentions.sum(axis=(1, 2)) + mlps.sum(axis=1)``
        computed on the ablated batch

    Raises:
        ValueError: If the model architecture is not supported. This is
            checked before any data is read.
    """
    ablation = _TEXT_DECOMPOSITION_ABLATIONS.get(args.model)
    if ablation is None:
        raise ValueError('Model not analyzed or not supported')
    geo_heads, setting_heads = ablation

    attentions = f0["attentions"]
    mlps = f0["mlps"]

    # np.array always copies. h5py slicing already returns a fresh array, but an
    # in-memory ndarray would return a view. The in-place ablations below would
    # then overwrite the caller's data and skew the dataset means.
    attentions_batch = np.array(attentions[start_idx:end_idx])
    mlps_batch = np.array(mlps[start_idx:end_idx])
    num_layers, num_heads = attentions_batch.shape[1], attentions_batch.shape[2]

    early_heads = [
        (layer, head)
        for layer in range(num_layers - _KEPT_FINAL_LAYERS)
        for head in range(num_heads)
    ]
    # dict.fromkeys dedupes while keeping order, so each (layer, head) column
    # is read from the dataset and averaged at most once per call.
    heads_to_ablate = dict.fromkeys([*geo_heads, *setting_heads, *early_heads])

    # Replace each selected head's per-sample output with its dataset-wide mean.
    for layer, head in heads_to_ablate:
        attentions_batch[:, layer, head, :] = np.mean(
            attentions[:, layer, head, :], axis=0, keepdims=True
        )

    # Every MLP layer is mean-ablated the same way.
    for layer in range(mlps_batch.shape[1]):
        mlps_batch[:, layer] = np.mean(mlps[:, layer], axis=0, keepdims=True)

    # Collapse layers and heads (attention) and layers (MLP) into one vector per sample.
    M_image = attentions_batch.sum(axis=(1, 2)) + mlps_batch.sum(axis=1)
    return torch.from_numpy(M_image)

