from prettytable import PrettyTable
import torch
from tqdm import tqdm
from utils.datasets import imagenet_classes, imagenet_templates

def print_attention_summary(vit_model):
    """Print a per-layer table describing each block's self-attention module.

    Walks ``vit_model.encoder.layers`` and, for every block, reads
    ``block.self_attention`` to report:

    * the layer index,
    * ``num_heads``,
    * a "QKV Shape" and an "Out Shape" string, both built from the
      ``in_features`` / ``out_features`` of the module's ``out_proj``,
    * the number of parameters of the attention module, in millions.

    The table is written to stdout; nothing is returned.
    """
    summary = PrettyTable()
    summary.field_names = ["Layer", "# Heads", "QKV Shape", "Out Shape", "# Params"]

    for layer_idx, layer in enumerate(vit_model.encoder.layers):
        attn = layer.self_attention
        heads = attn.num_heads

        # Both shape columns are derived from the output projection's sizes.
        projection = attn.out_proj
        in_dim = projection.in_features
        out_dim = projection.out_features

        n_params = 0
        for weight in attn.parameters():
            n_params += weight.numel()

        summary.add_row([
            layer_idx,
            heads,
            f"{3} x {in_dim} x {out_dim}",
            f"{in_dim} x {out_dim}",
            f"{n_params / 1e6:.2f}M",
        ])

    print(summary)

def print_detailed_vit_summary(vit_model):
    """Print a per-layer table covering both attention and MLP of each block.

    For every block in ``vit_model.encoder.layers`` the row contains the layer
    index, the attention head count, QKV / output shape strings derived from
    ``self_attention.embed_dim``, the weight shapes of ``mlp[0]`` and
    ``mlp[3]``, and the combined attention + MLP parameter count in millions.

    The table is written to stdout after a heading line; nothing is returned.
    """
    summary = PrettyTable()
    summary.field_names = ["Layer", "# Heads", "QKV Shape", "Out Shape", "MLP Shapes", "Total Params (M)"]

    for layer_idx, layer in enumerate(vit_model.encoder.layers):
        attn, ffn = layer.self_attention, layer.mlp

        dim = attn.embed_dim
        heads = attn.num_heads

        # NOTE(review): indices 0 and 3 are presumably the two linear layers
        # of the MLP (e.g. Linear, activation, dropout, Linear) - confirm
        # against the model definition.
        mlp_shapes = ", ".join(
            f"{w.shape[0]} x {w.shape[1]}" for w in (ffn[0].weight, ffn[3].weight)
        )

        # Attention parameters are counted first, then the MLP parameters.
        total_params = sum(
            p.numel() for module in (attn, ffn) for p in module.parameters()
        )

        summary.add_row([
            layer_idx,
            heads,
            f"{3 * dim} x {dim}",
            f"{dim} x {dim}",
            mlp_shapes,
            f"{total_params/1e6:.2f}",
        ])

    print("\nViT Layer-wise Summary:")
    print(summary)

def get_text_logits(
        prompts: list,
        model,
        tokenizer,
        device
        ) -> torch.Tensor:
    """Encode text prompts into a matrix with one column per prompt/class.

    Two input layouts are supported, selected by the type of ``prompts[0]``:

    * **Ensemble** - ``prompts`` is a sequence of lists/tuples of strings, one
      group per class. Each group is tokenized and encoded, every embedding is
      L2-normalized, the group is averaged over dim 0, and the mean is
      normalized again. The per-class vectors are stacked along dim 1.
    * **Flat** - ``prompts`` is a sequence of strings. They are tokenized and
      encoded in one call with ``normalize=True`` and the result is
      transposed.

    Everything runs under ``torch.no_grad()``.

    Args:
        prompts: Non-empty sequence of strings, or of lists/tuples of strings.
        model: Object exposing ``encode_text(tokens)``; the flat path also
            passes ``normalize=True`` (assumes an open_clip-style signature -
            TODO confirm for other model types).
        tokenizer: Callable mapping a list of strings to an object with a
            ``.to(device)`` method that ``model.encode_text`` accepts.
        device: Device passed to ``.to(...)`` for the tokens and the result.

    Returns:
        Tensor of text features; presumably shaped (embed_dim, num_classes)
        when ``encode_text`` returns (num_prompts, embed_dim) - verify against
        the model in use.

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    # Fail early with a clear message instead of an IndexError on prompts[0].
    if not prompts:
        raise ValueError("`prompts` must contain at least one prompt.")

    with torch.no_grad():
        # Tuples are accepted alongside lists so tuple-based template groups
        # take the ensemble path rather than being sent to the tokenizer as-is.
        if isinstance(prompts[0], (list, tuple)):
            class_vectors = []
            for class_prompts in tqdm(prompts, desc="Tokenizing prompts", unit="class"):
                tokens = tokenizer(list(class_prompts)).to(device)
                embeddings = model.encode_text(tokens)
                # Out-of-place division: never mutate the tensor owned by the model.
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
                mean_embedding = embeddings.mean(dim=0)
                class_vectors.append(mean_embedding / mean_embedding.norm())
            return torch.stack(class_vectors, dim=1).to(device)

        tokens = tokenizer(prompts).to(device)
        return model.encode_text(tokens, normalize=True).T

def get_prompts(
        dataset = "imagenet1k",
        ensemble = False
        ):
    """Build the text prompts for every class label of ``dataset``.

    Args:
        dataset: Name of the dataset; only ``"imagenet1k"`` is recognised.
        ensemble: When falsy, one prompt ``"a photo of a <label>."`` is
            produced per class. When truthy, each class gets a list holding
            every pre-defined template formatted with its label.

    Returns:
        A list of strings (``ensemble`` falsy) or a list of lists of strings
        (``ensemble`` truthy), in class-label order.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    datasets_available = ["imagenet1k"]

    # Guard clause: reject unknown datasets before touching any label data.
    if dataset != "imagenet1k":
        raise ValueError(f"Dataset {dataset} not supported. Only {datasets_available} is available.")

    class_labels, templates = imagenet_classes, imagenet_templates

    if not ensemble:
        return [f"a photo of a {c}." for c in class_labels]

    return [
        [template.format(c) for template in templates]
        for c in class_labels
    ]
