from collections import Counter

import torch

from sc_perturb.samplers import euler_maruyama_sampler


def process_latents_through_vae(
    latents,
    vae,
    latent_size,
    resolution,
    latents_bias,
    latents_scale,
    C,
    latent_channels=4,
):
    """
    Decode a batch of stacked latents into single-channel-per-slot images.

    The latents are split into ``B * C`` independent latent images, each is
    de-normalized and decoded by the VAE, the decoder's output channels are
    averaged into one, and the result is regrouped to ``C`` channels per
    batch item and rescaled/clamped to ``[0, 1]``.

    Args:
        latents: Tensor whose first dimension is the batch size ``B`` and
            which holds ``B * C * latent_channels * latent_size ** 2``
            elements. It does not need to be contiguous.
        vae: Object exposing ``decode(x)`` whose result has a ``.sample``
            tensor attribute.
        latent_size: Spatial side length of one latent image.
        resolution: Spatial side length of one decoded image.
        latents_bias: Value subtracted from the latents before decoding.
        latents_scale: Value the shifted latents are divided by before
            decoding.
        C: Number of image channels packed into each batch item.
        latent_channels: Number of latent channels per image channel
            (defaults to 4, the value that was previously hard-coded).

    Returns:
        Tensor of shape ``(B, C, resolution, resolution)`` with values
        clamped to ``[0, 1]``.
    """
    batch = latents.shape[0]

    # reshape (not view) so non-contiguous sampler output is accepted too;
    # for contiguous input this is still a zero-copy view.
    per_channel = latents.reshape(
        batch * C, latent_channels, latent_size, latent_size
    )

    decoded = vae.decode((per_channel - latents_bias) / latents_scale).sample

    # Collapse the decoder's output channels into a single intensity map.
    decoded = decoded.mean(dim=1, keepdim=True)
    decoded = decoded.reshape(batch, C, resolution, resolution)

    # Shift/scale then clamp; presumably the decoder emits values in roughly
    # [-1, 1] -- TODO confirm against the VAE in use.
    return ((decoded + 1) / 2).clamp(0, 1)


def generate_and_process_samples(
    model,
    noise,
    class_ids,
    cell_type_ids,
    vae,
    latent_size,
    resolution,
    latents_bias,
    latents_scale,
    path_type,
    C,
    device,
    heun=False,
):
    """
    Sample latents with the diffusion model and decode them through the VAE.

    Args:
        model: Model handed to ``euler_maruyama_sampler``.
        noise: Initial noise tensor handed to the sampler.
        class_ids: Class conditioning handed to the sampler.
        cell_type_ids: Cell-type conditioning, passed to the sampler as
            ``cts``.
        vae: Decoder object exposing ``decode(x).sample``.
        latent_size: Spatial side length of one latent image.
        resolution: Spatial side length of one decoded image.
        latents_bias: Value subtracted from the latents before decoding.
        latents_scale: Value the shifted latents are divided by before
            decoding.
        path_type: Forwarded unchanged to the sampler.
        C: Number of image channels packed into each batch item.
        device: Unused here; kept so existing callers keep working.
        heun: Forwarded unchanged to the sampler.

    Returns:
        Tensor of shape ``(B, C, resolution, resolution)`` with values
        clamped to ``[0, 1]``, where ``B`` is the sampler output's batch size.
    """
    # Sampling uses a fixed 50 steps with cfg_scale=0.0 (presumably no
    # classifier-free guidance -- confirm against the sampler).
    latents = euler_maruyama_sampler(
        model,
        noise,
        class_ids,
        chs=None,
        cts=cell_type_ids,
        num_steps=50,
        cfg_scale=0.0,
        guidance_low=0,
        guidance_high=1,
        path_type=path_type,
        heun=heun,
    ).to(torch.float32)

    # Decode with the shared helper instead of duplicating its pipeline here.
    return process_latents_through_vae(
        latents, vae, latent_size, resolution, latents_bias, latents_scale, C
    )


def generate_and_process_samples_multi_celltype(
    model,
    noise,
    class_ids,
    cell_type_ids_list,
    vae,
    latent_size,
    resolution,
    latents_bias,
    latents_scale,
    path_type,
    C,
    device,
    heun=False,
):
    """
    Generate decoded samples for several cell types from the same noise.

    For every entry of ``cell_type_ids_list`` the sampler is run with the
    same ``noise`` and ``class_ids`` but a cell-type tensor filled with that
    entry, and the resulting latents are decoded through the VAE.

    Args:
        model: Model handed to ``euler_maruyama_sampler``.
        noise: Initial noise tensor; its first dimension is the batch size.
        class_ids: Class conditioning handed to the sampler.
        cell_type_ids_list: Iterable of cell-type ids, one sampling run each.
        vae: Decoder object exposing ``decode(x).sample``.
        latent_size: Spatial side length of one latent image.
        resolution: Spatial side length of one decoded image.
        latents_bias: Value subtracted from the latents before decoding.
        latents_scale: Value the shifted latents are divided by before
            decoding.
        path_type: Forwarded unchanged to the sampler.
        C: Number of image channels packed into each batch item.
        device: Device on which the per-run cell-type tensor is created.
        heun: Forwarded unchanged to the sampler.

    Returns:
        Tensor stacking the per-cell-type results along a new leading
        dimension: ``[num_cell_types, B, C, resolution, resolution]``.
        ``torch.stack`` raises if ``cell_type_ids_list`` is empty.
    """
    batch = noise.shape[0]
    decoded_per_type = []

    for cell_type_id in cell_type_ids_list:
        # One conditioning entry per batch item, all set to this cell type.
        cell_type_tensor = torch.tensor([cell_type_id] * batch, device=device)

        latents = euler_maruyama_sampler(
            model,
            noise,
            class_ids,
            chs=None,
            cts=cell_type_tensor,
            num_steps=50,
            cfg_scale=0.0,
            guidance_low=0,
            guidance_high=1,
            path_type=path_type,
            heun=heun,
        ).to(torch.float32)

        decoded_per_type.append(
            process_latents_through_vae(
                latents,
                vae,
                latent_size,
                resolution,
                latents_bias,
                latents_scale,
                C,
            )
        )

    return torch.stack(decoded_per_type)


def generate_perturbation_matched_samples(
    model,
    perturbation_id,
    metadata,
    vae,
    latent_size,
    resolution,
    latents_bias,
    latents_scale,
    path_type,
    device,
    num_steps=50,
    cfg_scale=4.0,
    guidance_low=0.0,
    guidance_high=1.0,
    batch_size=8,
):
    """
    Generate samples for one perturbation, matching the metadata's cell types.

    The number of generated samples per cell type equals the number of
    ``metadata`` entries carrying that ``"cell_type_id"``. Samples are
    produced in chunks of at most ``batch_size`` and moved to the CPU as soon
    as they are decoded.

    Args:
        model: Model handed to ``euler_maruyama_sampler``.
        perturbation_id: Id used to fill the class-conditioning tensor.
        metadata: Sequence of dicts, each with a ``"cell_type_id"`` key.
        vae: Decoder object exposing ``decode(x).sample``.
        latent_size: Spatial side length of one latent image.
        resolution: Spatial side length of one decoded image.
        latents_bias: Value subtracted from the latents before decoding.
        latents_scale: Value the shifted latents are divided by before
            decoding.
        path_type: Forwarded unchanged to the sampler.
        device: Device on which noise and conditioning tensors are created.
        num_steps: Forwarded unchanged to the sampler.
        cfg_scale: Forwarded unchanged to the sampler.
        guidance_low: Forwarded unchanged to the sampler.
        guidance_high: Forwarded unchanged to the sampler.
        batch_size: Maximum number of samples generated per sampler call.

    Returns:
        ``(samples, generated_metadata)`` where ``samples`` is a CPU tensor
        concatenating all chunks along dim 0 and ``generated_metadata`` is a
        list with one dict per sample (``"perturbation_id"``,
        ``"cell_type_id"``, ``"is_generated"``). Returns ``(None, None)`` when
        ``metadata`` is empty or ``None``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (the chunking loop
            could otherwise never make progress).
    """
    if not metadata:
        print(f"No metadata available for perturbation {perturbation_id}")
        return None, None

    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    # 6 image channels, each decoded from 4 latent channels (the channel
    # count the VAE helper in this file reshapes to) -> 24 noise channels.
    num_image_channels = 6
    latent_channels = 4

    # Counter keeps first-seen order, so cell types are processed in the
    # order they appear in the metadata.
    cell_type_counts = Counter(meta["cell_type_id"] for meta in metadata)

    all_samples = []
    all_meta = []
    for cell_type, total_needed in cell_type_counts.items():
        remaining = total_needed
        while remaining > 0:
            cur_bs = min(batch_size, remaining)  # size of this chunk

            # --- Random noise ---------------------------------------------
            noise = torch.randn(
                cur_bs,
                num_image_channels * latent_channels,
                latent_size,
                latent_size,
                device=device,
                dtype=torch.float32,
            )

            # --- Conditioning tensors -------------------------------------
            class_ids = torch.full(
                (cur_bs,), perturbation_id, dtype=torch.long, device=device
            )
            cell_type_ids = torch.full(
                (cur_bs,), cell_type, dtype=torch.long, device=device
            )

            # --- Diffusion sampling ---------------------------------------
            # Cast to float32 before decoding, as the other generation
            # helpers in this file do, so the VAE never receives whatever
            # higher-precision dtype the sampler may work in.
            latents = euler_maruyama_sampler(
                model,
                noise,
                class_ids,
                cts=cell_type_ids,
                num_steps=num_steps,
                cfg_scale=cfg_scale,
                guidance_low=guidance_low,
                guidance_high=guidance_high,
                path_type=path_type,
            ).to(torch.float32)

            # --- Decode through VAE ---------------------------------------
            imgs = process_latents_through_vae(
                latents,
                vae,
                latent_size,
                resolution,
                latents_bias,
                latents_scale,
                C=num_image_channels,
            )
            # Move each chunk to the CPU right away so accumulated results
            # do not stay on the generation device.
            all_samples.append(imgs.cpu())

            # --- Matching metadata, one entry per generated sample --------
            all_meta.extend(
                {
                    "perturbation_id": perturbation_id,
                    "cell_type_id": cell_type,
                    "is_generated": True,
                }
                for _ in range(cur_bs)
            )

            remaining -= cur_bs

    return torch.cat(all_samples, dim=0), all_meta


def metadata_match(perturbation_metadata, generation_metadata):
    """
    Score how closely two metadata lists agree in cell-type distribution.

    Each list is reduced to per-cell-type counts (keyed by the entries'
    ``"cell_type_id"``). The score is one minus half the summed absolute
    difference between the two lists' per-cell-type fractions, so identical
    distributions give ``1.0`` and disjoint ones give ``0.0``. An empty list
    contributes a fraction of ``0`` for every cell type.

    Returns:
        ``(match_score, perturbation_counts, generation_counts)`` where the
        two count values are plain dicts mapping cell type to occurrences.
    """

    def _tally(records):
        # Occurrences of each "cell_type_id", in first-seen order.
        counts = {}
        for record in records:
            key = record["cell_type_id"]
            counts[key] = counts.get(key, 0) + 1
        return counts

    pert_counts = _tally(perturbation_metadata)
    gen_counts = _tally(generation_metadata)

    n_pert = len(perturbation_metadata)
    n_gen = len(generation_metadata)

    # Accumulate |fraction difference| over every cell type seen in either
    # list.
    total_abs_diff = 0
    for key in set([*pert_counts, *gen_counts]):
        if n_pert > 0:
            frac_pert = pert_counts.get(key, 0) / n_pert
        else:
            frac_pert = 0

        if n_gen > 0:
            frac_gen = gen_counts.get(key, 0) / n_gen
        else:
            frac_gen = 0

        total_abs_diff += abs(frac_pert - frac_gen)

    # Halving maps the summed difference (at most 2) onto [0, 1].
    score = 1.0 - (total_abs_diff / 2.0)
    return score, pert_counts, gen_counts


def _trace_sqrt_product(cov_a, cov_b):
    """Return Tr((cov_a @ cov_b)^(1/2)) for symmetric PSD matrices."""
    # PSD square root of cov_a from its symmetric eigendecomposition; tiny
    # negative eigenvalues caused by round-off are clamped to zero.
    eigvals_a, eigvecs_a = torch.linalg.eigh(cov_a)
    sqrt_a = (eigvecs_a * eigvals_a.clamp(min=0).sqrt()) @ eigvecs_a.T

    # sqrt_a @ cov_b @ sqrt_a is symmetric and has the same eigenvalues as
    # cov_a @ cov_b, so the trace of the square root is the sum of the
    # square roots of its eigenvalues.
    inner = sqrt_a @ cov_b @ sqrt_a
    inner = (inner + inner.T) / 2  # remove round-off asymmetry
    return torch.linalg.eigvalsh(inner).clamp(min=0).sqrt().sum()


def calculate_metrics(
    perturbation_samples,
    perturbation_metadata,
    generated_samples,
    generation_metadata,
    feature_extractor,
):
    """
    Calculate comparison metrics between real and generated samples.

    Computes a Frechet distance between the feature distributions of the two
    sample sets, ``||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2))``,
    plus the cell-type distribution match between the two metadata lists.

    Args:
        perturbation_samples: Real samples, passed to ``feature_extractor``.
        perturbation_metadata: Metadata dicts for the real samples.
        generated_samples: Generated samples, passed to
            ``feature_extractor``.
        generation_metadata: Metadata dicts for the generated samples.
        feature_extractor: Callable returning a feature tensor whose first
            dimension indexes samples; remaining dimensions are flattened.

    Returns:
        Dict with ``"fid"`` (float), ``"match_score"``,
        ``"cell_counts_real"`` and ``"cell_counts_gen"``. ``"fid"`` is NaN
        when either set holds fewer than two samples, because the sample
        covariance is then undefined.
    """
    with torch.no_grad():
        features_real = feature_extractor(perturbation_samples)
        features_gen = feature_extractor(generated_samples)

    # One row per sample; float64 keeps the eigendecompositions stable.
    features_real = features_real.reshape(features_real.shape[0], -1).to(
        torch.float64
    )
    features_gen = features_gen.reshape(features_gen.shape[0], -1).to(
        torch.float64
    )

    mean_real = features_real.mean(dim=0)
    mean_gen = features_gen.mean(dim=0)

    # torch.cov expects variables in rows; atleast_2d covers the
    # single-feature case, where torch.cov returns a 0-d tensor.
    cov_real = torch.atleast_2d(torch.cov(features_real.T))
    cov_gen = torch.atleast_2d(torch.cov(features_gen.T))

    diff_means = torch.sum((mean_real - mean_gen) ** 2)

    # torch.matrix_power only accepts integer exponents, so the trace of the
    # matrix square root is computed from eigenvalues instead.
    trace_term = (
        torch.trace(cov_real)
        + torch.trace(cov_gen)
        - 2 * _trace_sqrt_product(cov_real, cov_gen)
    )

    fid = diff_means + trace_term

    match_score, cell_counts_real, cell_counts_gen = metadata_match(
        perturbation_metadata, generation_metadata
    )

    return {
        "fid": fid.item(),
        "match_score": match_score,
        "cell_counts_real": cell_counts_real,
        "cell_counts_gen": cell_counts_gen,
    }
