from prettytable import PrettyTable
import torch
from utils.datamodules import (
    imagenet_classes, cifar100_classes, cifar10_classes, 
    food101_classes, sun397_classes, oxford_pet_classes, 
    dtd_classes, eurosat_classes, caltech101_classes
    )

def print_attention_summary(vit_model):
    """Print a per-layer table describing the self-attention modules of a ViT.

    Iterates over ``vit_model.encoder.layers`` and, for each block, reads its
    ``self_attention`` module's ``num_heads`` and the ``in_features`` /
    ``out_features`` of its ``out_proj`` layer, then prints one row per block.

    Args:
        vit_model: Model exposing ``encoder.layers``; each layer must provide
            a ``self_attention`` attribute with ``num_heads``, ``out_proj``
            and ``parameters()``.
            (Presumably a torchvision-style VisionTransformer -- confirm.)

    Returns:
        None. The table is written to stdout.
    """
    summary = PrettyTable()
    summary.field_names = ["Layer", "# Heads", "QKV Shape", "Out Shape", "# Params"]

    for layer_idx, layer in enumerate(vit_model.encoder.layers):
        attn = layer.self_attention
        heads = attn.num_heads

        # Both dimensions are taken from the output projection layer.
        projection = attn.out_proj
        in_dim = projection.in_features
        out_dim = projection.out_features

        # Counts every parameter registered on the attention module.
        param_count = sum(weight.numel() for weight in attn.parameters())

        summary.add_row([
            layer_idx,
            heads,
            f"3 x {in_dim} x {out_dim}",
            f"{in_dim} x {out_dim}",
            f"{param_count / 1e6:.2f}M",
        ])

    print(summary)

def print_detailed_vit_summary(vit_model):
    """Print a per-layer table covering attention and MLP sizes of a ViT.

    For every block in ``vit_model.encoder.layers`` the table lists the number
    of heads, shapes derived from the attention ``embed_dim``, the weight
    shapes of ``mlp[0]`` and ``mlp[3]``, and the combined attention + MLP
    parameter count in millions.

    Args:
        vit_model: Model exposing ``encoder.layers``; each layer must provide
            ``self_attention`` (with ``embed_dim``, ``num_heads``) and an
            indexable ``mlp`` whose items 0 and 3 have a ``weight``.
            (Presumably indices 0 and 3 are the two linear layers of the
            MLP -- confirm against the model definition.)

    Returns:
        None. A heading and the table are written to stdout.
    """
    report = PrettyTable()
    report.field_names = ["Layer", "# Heads", "QKV Shape", "Out Shape", "MLP Shapes", "Total Params (M)"]

    for layer_idx, layer in enumerate(vit_model.encoder.layers):
        attn = layer.self_attention
        feed_forward = layer.mlp

        dim = attn.embed_dim
        heads = attn.num_heads

        # Weight shapes of the two indexed MLP sub-modules.
        first_shape = feed_forward[0].weight.shape
        second_shape = feed_forward[3].weight.shape
        mlp_desc = (
            f"{first_shape[0]} x {first_shape[1]}, "
            f"{second_shape[0]} x {second_shape[1]}"
        )

        attn_params = sum(weight.numel() for weight in attn.parameters())
        mlp_params = sum(weight.numel() for weight in feed_forward.parameters())
        combined = attn_params + mlp_params

        report.add_row([
            layer_idx,
            heads,
            f"{3 * dim} x {dim}",
            f"{dim} x {dim}",
            mlp_desc,
            f"{combined/1e6:.2f}",
        ])

    print("\nViT Layer-wise Summary:")
    print(report)

def get_text_logits(
        prompts: list,
        model,
        tokenizer,
        device
        ) -> torch.Tensor:
    """Tokenize and encode text prompts, returning the transposed encoding.

    Args:
        prompts: Text prompts passed as-is to ``tokenizer``.
        model: Object providing ``encode_text(tokens, normalize=True)``.
        tokenizer: Callable mapping ``prompts`` to an object with ``.to()``.
        device: Device the tokenized prompts are moved to before encoding.

    Returns:
        The transpose (``.T``) of whatever ``model.encode_text`` returns,
        computed with gradient tracking disabled.
        (Presumably shaped (embed_dim, num_prompts) -- confirm with the model.)
    """
    with torch.no_grad():
        tokens = tokenizer(prompts).to(device)
        features = model.encode_text(tokens, normalize=True)
        # Transposed so that prompts run along the last axis of the result.
        return features.T

def get_prompts(
        dataset = "imagenet-1k",
        ):
    """Build one ``"a photo of a <class>."`` prompt per class of a dataset.

    Args:
        dataset: Dataset name, matched case-insensitively against the
            supported names listed in ``datasets_available`` below.

    Returns:
        A list of prompt strings, one per class label of the chosen dataset,
        in the order of that dataset's class list.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    datasets_available = ["imagenet-1k", "cifar100", "cifar10", "food101", "sun397", "oxford_pet", "eurosat", "caltech101", "dtd"]

    key = dataset.lower()
    # Reject unknown names up front; the message echoes the caller's spelling.
    if key not in datasets_available:
        raise ValueError(f"Dataset {dataset} not supported. Only {datasets_available} is available.")

    labels_by_dataset = {
        "imagenet-1k": imagenet_classes,
        "cifar100": cifar100_classes,
        "cifar10": cifar10_classes,
        "food101": food101_classes,
        "sun397": sun397_classes,
        "oxford_pet": oxford_pet_classes,
        "dtd": dtd_classes,
        "eurosat": eurosat_classes,
        "caltech101": caltech101_classes,
    }

    return [f"a photo of a {label}." for label in labels_by_dataset[key]]
