import torch
from typing import Sequence, List, Dict
from collections import defaultdict
from pytorch_lightning import seed_everything

# Default compute device for this module: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@torch.no_grad()
def image_encode(
    samples: Sequence[Dict],
    processor,
    image_name,
    label_name,
):
    """Collate dataset samples into a batched ``{"images", "labels"}`` dict.

    Every sample's ``image_name`` entry is first converted with
    ``.convert("RGB")`` (presumably a PIL image — confirm against the dataset).
    Each converted image is then passed individually to
    ``processor(images=..., return_tensors="pt")``. The resulting
    ``"pixel_values"`` tensors are concatenated along dim 0.

    Args:
        samples: Sequence of mapping-like samples.
        processor: Callable image processor returning a mapping with a
            ``"pixel_values"`` tensor.
        image_name: Key of the image in each sample.
        label_name: Key of the label in each sample.

    Returns:
        Dict with ``"images"`` (concatenated pixel values) and ``"labels"``
        (``torch.tensor`` of the per-sample labels, in sample order).
    """
    # Phase 1: convert every image to RGB before any processing happens.
    rgb_images = []
    for sample in samples:
        rgb_images.append(sample[image_name].convert("RGB"))

    # Phase 2: run the processor per image and stack the results.
    pixel_values = torch.cat(
        [processor(images=img, return_tensors="pt")["pixel_values"] for img in rgb_images],
        dim=0,
    )

    label_values = [sample[label_name] for sample in samples]
    return {"images": pixel_values, "labels": torch.tensor(label_values)}


def count_parameters(model):
    """Return the total number of scalar elements in *model*'s trainable parameters.

    Only parameters with ``requires_grad`` set contribute; frozen parameters
    are skipped. Returns ``0`` for a model without trainable parameters.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


def convert_parameters(num_parameters):
    """Format a parameter count as a short human-readable string.

    Counts of at least one thousand are scaled to ``K``/``M``/``B`` with two
    decimals (e.g. ``1_500_000 -> "1.50M"``). Smaller counts are returned
    unchanged via ``str``.

    When rounding to two decimals would display ``1000.00`` in a unit (e.g.
    ``999_999 -> "1000.00K"``), the value is promoted to the next larger unit
    instead (``"1.00M"``). The largest unit, ``B``, is never promoted.

    Args:
        num_parameters: Number of parameters (typically an int).

    Returns:
        The formatted string.
    """
    units = ((1_000_000_000, "B"), (1_000_000, "M"), (1_000, "K"))
    for index, (scale, suffix) in enumerate(units):
        if num_parameters >= scale:
            value = num_parameters / scale
            # round(x, 2) rounds the same way as the ".2f" format, so this
            # detects exactly the cases that would print "1000.00".
            if index > 0 and round(value, 2) >= 1000:
                larger_scale, larger_suffix = units[index - 1]
                return f"{num_parameters / larger_scale:.2f}{larger_suffix}"
            return f"{value:.2f}{suffix}"
    return str(num_parameters)


def extract_specific_layers(model, max_samples, dataloader, layers_to_skip, seed=0):
    """Collect hidden states for every layer index referenced in ``layers_to_skip``.

    Seeds RNGs with ``seed_everything(seed)``. Moves ``model`` to the
    module-level ``device`` (in place) and sets it to eval mode. Then runs it
    under ``torch.no_grad()`` over ``dataloader`` until ``max_samples``
    samples have been stored.

    Args:
        model: Module called as ``model(images)``. Its output must expose a
            ``hidden_states`` sequence. Entry 0 is dropped, so layer index
            ``i`` here corresponds to ``outputs.hidden_states[i + 1]``.
            Presumably a Hugging Face model run with
            ``output_hidden_states=True`` — confirm against the caller.
        max_samples: Maximum number of samples stored per layer.
        dataloader: Iterable of batches, each a mapping with an ``"images"``
            tensor whose dim 0 is the batch dimension.
        layers_to_skip: Iterable of iterables (e.g. pairs) of layer indices.
            Every index appearing anywhere in it is collected. Entries may have
            different lengths.
        seed: Seed passed to ``seed_everything``.

    Returns:
        ``defaultdict(list)`` mapping each requested layer index to a CPU
        tensor of the stored samples, concatenated along dim 0.

    Raises:
        AssertionError: if the collected layers differ from the requested ones.
            For example, an index exceeds the model's depth, or no samples were
            collected (empty dataloader or ``max_samples <= 0``).
    """
    seed_everything(seed)

    # Computed once rather than per layer per batch. int() normalises
    # tensor/NumPy scalars so set membership matches the original list-based
    # equality test.
    wanted_layers = {int(layer) for pair in layers_to_skip for layer in pair}

    stored_samples = 0
    layer_embeddings = defaultdict(list)

    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch in dataloader:
            images = batch["images"].to(device)

            outputs = model(images)
            # Drop entry 0 (the input to the first layer in HF convention —
            # TODO confirm for the model in use).
            hidden_states = outputs.hidden_states[1:]

            num_to_add = min(max_samples - stored_samples, images.size(0))

            if num_to_add > 0:
                for layer_idx, layer_output in enumerate(hidden_states):
                    if layer_idx in wanted_layers:
                        layer_embeddings[layer_idx].append(layer_output[:num_to_add].cpu())

                stored_samples += num_to_add

            if stored_samples >= max_samples:
                break

    for layer_idx in layer_embeddings:
        layer_embeddings[layer_idx] = torch.cat(layer_embeddings[layer_idx], dim=0)

    # Keys are only ever added when they belong to wanted_layers, so set
    # equality is the same condition as the original length comparison.
    assert set(layer_embeddings) == wanted_layers, (
        f"collected layers {sorted(layer_embeddings)} do not match "
        f"requested layers {sorted(wanted_layers)}"
    )

    return layer_embeddings


def extract_all_layers(model, max_samples, dataloader, only_cls, num_layers=12):
    """Collect the hidden states of every layer for up to ``max_samples`` samples.

    Moves ``model`` to the module-level ``device`` (in place) and sets it to
    eval mode. Then runs it under ``torch.no_grad()`` over ``dataloader``
    until ``max_samples`` samples have been stored.

    Args:
        model: Module called as ``model(images)``. Its output must expose a
            ``hidden_states`` sequence. Entry 0 is dropped, so layer index
            ``i`` here corresponds to ``outputs.hidden_states[i + 1]``.
            Presumably a Hugging Face model run with
            ``output_hidden_states=True`` — confirm against the caller.
        max_samples: Maximum number of samples stored per layer.
        dataloader: Iterable of batches, each a mapping with an ``"images"``
            tensor whose dim 0 is the batch dimension.
        only_cls: If true, keep only position 0 of dim 1 for each sample
            (``layer_output[:, 0, :]``). This assumes a (batch, seq, hidden)
            layout with the CLS token first — TODO confirm for the model in
            use.
        num_layers: Expected number of collected layers. Defaults to 12, the
            value this function previously hard-coded. Pass ``None`` to
            validate against the number of hidden states the model actually
            returned, so models of any depth are accepted.

    Returns:
        ``defaultdict(list)`` mapping layer index to a CPU tensor of the
        stored samples, concatenated along dim 0.

    Raises:
        AssertionError: if the number of collected layers does not match the
            expected count. For example, with the default, a model whose depth
            is not 12, or no samples were collected.
    """
    stored_samples = 0
    observed_layers = 0

    layer_embeddings = defaultdict(list)

    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch in dataloader:
            images = batch["images"].to(device)

            outputs = model(images)
            hidden_states = outputs.hidden_states[1:]
            observed_layers = len(hidden_states)

            num_to_add = min(max_samples - stored_samples, images.size(0))

            if num_to_add > 0:
                for layer_idx, layer_output in enumerate(hidden_states):
                    if only_cls:
                        kept = layer_output[:num_to_add, 0, :]
                    else:
                        kept = layer_output[:num_to_add]
                    layer_embeddings[layer_idx].append(kept.cpu())
                stored_samples += num_to_add

            if stored_samples >= max_samples:
                break

    for layer_idx in layer_embeddings:
        layer_embeddings[layer_idx] = torch.cat(layer_embeddings[layer_idx], dim=0)

    expected_layers = observed_layers if num_layers is None else num_layers
    assert len(layer_embeddings) == expected_layers, (
        f"collected {len(layer_embeddings)} layers, expected {expected_layers}"
    )

    return layer_embeddings
