import numpy as np
import torch
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize

# Per-channel mean / std passed to torchvision's Normalize in
# preprocess_raw_image for encoder types containing "clip".
# NOTE(review): presumably the published CLIP statistics in RGB order,
# rounded to six decimals -- confirm against the encoder's own config.
CLIP_DEFAULT_MEAN = (0.481455, 0.457827, 0.408211)
CLIP_DEFAULT_STD = (0.26863, 0.261303, 0.275777)


def preprocess_before_vae(x, resolution):
    """Resize a batch of images with bicubic interpolation before VAE encoding.

    Args:
        x: Image tensor; the trailing spatial dimensions are resized.
        resolution: Target spatial size, forwarded unchanged as the ``size``
            argument of ``torch.nn.functional.interpolate``.

    Returns:
        The resized tensor.
    """
    return torch.nn.functional.interpolate(x, size=resolution, mode="bicubic")


def preprocess_raw_image(x, enc_type):
    """Apply the input preprocessing expected by a given encoder family.

    The encoder family is chosen by substring match on ``enc_type``; the
    checks run in a fixed order and the first match wins:

    * ``"clip"``: rescale, resize, then normalize with the CLIP statistics.
    * ``"mocov3"`` / ``"mae"``: rescale, then ImageNet normalization.
    * ``"dinov2"``: rescale, ImageNet normalization, then resize.
    * ``"dinov1"``: rescale, then ImageNet normalization.
    * ``"jepa"``: rescale, ImageNet normalization, then resize.
    * ``"openphenom"``: rescale only.

    Any other ``enc_type`` returns ``x`` untouched.

    Args:
        x: Image batch. Rescaling computes ``(x + 1) / 2``.
            NOTE(review): this assumes values in [-1, 1] -- confirm with callers.
        enc_type: Encoder identifier string.

    Returns:
        The preprocessed tensor.
    """
    # Only the last dimension (width) drives the resize target.
    side = x.shape[-1]

    def to_unit_range(img):
        return (img + 1) / 2

    def resize_for_encoder(img):
        # 224 pixels per 256 input pixels (256 -> 224, 512 -> 448).
        # NOTE(review): for widths below 256 the integer division gives a
        # target size of 0.
        return torch.nn.functional.interpolate(
            img, 224 * (side // 256), mode="bicubic"
        )

    def imagenet_normalize(img):
        return Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(img)

    # Branch order is significant when enc_type contains several tags.
    if "clip" in enc_type:
        resized = resize_for_encoder(to_unit_range(x))
        return Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(resized)
    if "mocov3" in enc_type or "mae" in enc_type:
        return imagenet_normalize(to_unit_range(x))
    if "dinov2" in enc_type:
        return resize_for_encoder(imagenet_normalize(to_unit_range(x)))
    if "dinov1" in enc_type:
        return imagenet_normalize(to_unit_range(x))
    if "jepa" in enc_type:
        return resize_for_encoder(imagenet_normalize(to_unit_range(x)))
    if "openphenom" in enc_type:
        return to_unit_range(x)
    return x


def prepare_ground_truth_samples(
    train_dataloader,
    class_ids,
    cell_type_ids,
    device,
    resolution,
    latent_size,
    vae,
    logger,
    accelerator=None,
):
    """Find one reference image in the training data and encode it with the VAE.

    The dataloader is scanned batch by batch for the first item whose
    ``class_id`` and ``cell_type_id`` equal ``class_ids[0]`` and
    ``cell_type_ids[0]``. That single image is resized, encoded, and its
    sampled latent is scaled.

    Args:
        train_dataloader: Iterable of ``(x, y)`` batches, where ``y`` is
            indexable per item and each item exposes ``"class_id"`` and
            ``"cell_type_id"`` entries supporting ``.item()``.
        class_ids: Indexable of scalars; only the first entry is used.
        cell_type_ids: Indexable of scalars; only the first entry is used.
        device: Device for the latent constants and the selected image.
        resolution: Target size passed to ``preprocess_before_vae``.
        latent_size: Not used by this function.
        vae: Object whose ``encode(x).latent_dist.sample()`` yields latents.
        logger: Optional logger; skipped when falsy.
        accelerator: Optional object whose ``gather`` is applied to the
            selected image before resizing.

    Returns:
        ``(latents, sample_idx)``, where ``sample_idx`` is the position of the
        match inside its batch, or ``(None, None)`` when nothing matched.
    """
    # NOTE(review): 0.18215 is presumably the VAE latent scaling factor --
    # confirm it matches the VAE in use.
    scale = torch.tensor([0.18215] * 4).view(1, 4, 1, 1).to(device)
    bias = torch.tensor([0] * 4).view(1, 4, 1, 1).to(device)

    sample_idx = 0
    sample_cell_type = cell_type_ids[0].item()
    class_to_sample = class_ids[0].item()

    # Scan batches until one contains the requested (class, cell type) pair.
    matched = False
    for x, y in train_dataloader:
        hit = next(
            (
                position
                for position, label in enumerate(y)
                if label["class_id"].item() == class_to_sample
                and label["cell_type_id"].item() == sample_cell_type
            ),
            None,
        )
        if hit is not None:
            sample_idx = hit
            matched = True
            break

    if not matched:
        if logger:
            logger.warning(
                f"Could not find sample with class ID {class_to_sample} and cell type ID {sample_cell_type}"
            )
        return None, None

    if logger:
        logger.info(
            f"Found sample with class ID {class_to_sample} and cell type ID {sample_cell_type}"
        )

    # Slice (rather than index) so the batch dimension is kept.
    x_sample = x[sample_idx : sample_idx + 1]
    if accelerator:
        # NOTE(review): gather is typically a cross-process collective --
        # confirm every process reaches this branch together.
        x_sample = accelerator.gather(x_sample)
    x_sample = preprocess_before_vae(x_sample, resolution).to(device)

    # Encode without tracking gradients, then apply the latent scale / bias.
    with torch.no_grad():
        gt_xs = vae.encode(x_sample).latent_dist.sample() * scale + bias

    return gt_xs, sample_idx


def process_perturbation_samples(
    data_module,
    perturbation_id,
    num_samples=100,
    device=None,
    accelerator=None,
    cell_type=None,
):
    """Collect exactly ``num_samples`` images for one perturbation.

    The dataset returned by ``data_module.filter_samples`` is read in full.
    With too few items, randomly chosen duplicates (drawn with replacement)
    are appended after the originals. With too many, a random subset
    (drawn without replacement) is kept. Randomness comes from NumPy's
    global generator.

    Args:
        data_module: Object exposing
            ``filter_samples(perturbation_id=..., cell_type_id=...)`` that
            returns an indexable, sized dataset of ``(sample, y, ct)`` items.
        perturbation_id: Perturbation to filter on.
        num_samples: Number of samples to return.
        device: Target device for the stacked tensor.
        accelerator: Not used by this function.
        cell_type: Forwarded to ``filter_samples`` as ``cell_type_id``.

    Returns:
        ``(samples_tensor, metadata)``, where ``metadata`` is a list of
        ``{"perturbation_id": ..., "cell_type_id": ...}`` dicts aligned with
        the first dimension of ``samples_tensor``. Returns ``(None, None)``
        when the filtered dataset is empty.
    """
    subset = data_module.filter_samples(
        perturbation_id=perturbation_id, cell_type_id=cell_type
    )

    # Keep each image next to its metadata so both are resampled together.
    # Access stays index-based: only __len__ / __getitem__ are relied on.
    records = []
    for position in range(len(subset)):
        image, label, ct = subset[position]
        records.append((image, {"perturbation_id": label, "cell_type_id": ct}))

    if not records:
        return None, None

    available = len(records)
    if available < num_samples:
        # Too few: pad with random repeats of the existing records.
        extra = np.random.choice(available, num_samples - available, replace=True)
        records.extend([records[k] for k in extra])
    elif available > num_samples:
        # Too many: keep a random subset of distinct records.
        keep = np.random.choice(available, num_samples, replace=False)
        records = [records[k] for k in keep]

    samples_tensor = torch.stack([image for image, _ in records]).to(device)
    metadata = [meta for _, meta in records]

    return samples_tensor, metadata
