import math
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from sc_perturb.dataset import to_rgb
from scipy import linalg

# Import torch dataset-related modules
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# torch-fidelity is an optional dependency. Assume it is importable and clear
# the flag (with a console warning) if the import fails, so the rest of the
# module can still be imported without it.
TORCH_FIDELITY_AVAILABLE = True
try:
    import torch_fidelity
except ImportError:
    TORCH_FIDELITY_AVAILABLE = False
    print("Warning: torch-fidelity not available. Some metrics may not work.")

# OpenPhenom is an optional dependency. Assume it is importable and clear the
# flag (with a console warning) if the import fails, so the rest of the module
# can still be imported without it.
OPENPHENOM_AVAILABLE = True
try:
    from openphenom import OpenPhenomEncoder
except ImportError:
    OPENPHENOM_AVAILABLE = False
    print(
        "Warning: OpenPhenom not available. OpenPhenom feature extraction won't work."
    )


class OpenPhenomFeatureExtractor:
    """
    Extracts features using the OpenPhenom model which is specialized for cell images.
    This extractor works with 6-channel cell images (size 512x512) and outputs
    feature embeddings for metric calculations.
    """

    def __init__(self, device="cuda", feature_aggregation="mean"):
        """
        Initialize the OpenPhenom feature extractor.

        Args:
            device: Device to use for computation ('cuda' or 'cpu')
            feature_aggregation: How to aggregate features across crops.
                'mean' calls the encoder directly; any other value makes
                ``extract_features`` call ``model.forward_features`` instead.

        Raises:
            ImportError: If the OpenPhenom package could not be imported.
        """
        self.device = device
        self.feature_aggregation = feature_aggregation

        # Fail early with a clear message when the optional dependency is missing.
        if not OPENPHENOM_AVAILABLE:
            raise ImportError("OpenPhenom is not available. Please install it first.")

        # Build the encoder in inference mode on the requested device.
        self.model = (
            OpenPhenomEncoder(
                feature_dim=384,
                patch_dim=16,
                channels=6,
                crops=4,
                img_dim=512,
                return_channelwise_embeddings=False,
            )
            .to(device)
            .eval()
        )

        # Kept for callers that hold raw (non-tensor) images; extract_features
        # itself does not apply it because it already receives tensors.
        self.preprocess = transforms.Compose([transforms.ToTensor()])

    @torch.no_grad()
    def extract_features(self, images, batch_size=16):
        """
        Extract features from a batch of images using OpenPhenom.

        Args:
            images: tensor of shape [N, 6, 512, 512] in range [0, 1]
                (a tolerance of 1e-3 is allowed on both ends)
            batch_size: batch size for feature extraction

        Returns:
            features: numpy array with one row of features per input image,
                      stacked over batches with ``np.vstack``; the feature
                      dimension depends on the feature_aggregation method

        Raises:
            ValueError: If the tensor is not 4-D, is empty, does not have
                6 channels, is not 512x512, or lies outside [0, 1].
        """
        # Validate rank first so the shape checks below cannot hit an IndexError.
        if images.ndim != 4:
            raise ValueError(
                f"OpenPhenom expects a 4-D tensor [N, 6, 512, 512], got {images.ndim}-D input"
            )

        if images.shape[1] != 6:
            raise ValueError(
                f"OpenPhenom expects 6-channel images, got {images.shape[1]} channels"
            )

        if images.shape[2] != 512 or images.shape[3] != 512:
            raise ValueError(
                f"OpenPhenom expects 512x512 images, got {images.shape[2]}x{images.shape[3]}"
            )

        # min()/max() are undefined on an empty tensor, so reject it explicitly.
        if images.shape[0] == 0:
            raise ValueError("OpenPhenom feature extraction received an empty batch")

        # Range check with a small tolerance for numerical noise. The previous
        # implementation dropped into the debugger here, which stalls
        # non-interactive runs; we now only raise.
        eps = 1e-3
        lowest = images.min().item()
        highest = images.max().item()
        if highest > 1.0 + eps or lowest < 0.0 - eps:
            raise ValueError(
                f"Images should be in range [0, 1], got {lowest:.4f}-{highest:.4f}"
            )

        n_samples = images.size(0)
        # TensorDataset needs matching-length tensors; the index tensor is only
        # a placeholder and is ignored while iterating.
        dataset = TensorDataset(images, torch.arange(n_samples))
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=0
        )
        features = []
        for batch, _ in tqdm(loader, desc="Extracting OpenPhenom features"):
            batch_images = batch.to(self.device)
            if self.feature_aggregation == "mean":
                # Default encoder call (the aggregated embedding).
                batch_features = self.model(batch_images).cpu().numpy()
            else:
                # Any other aggregation mode uses the encoder's forward_features.
                batch_features = self.model.forward_features(batch_images).cpu().numpy()

            features.append(batch_features)

        return np.vstack(features)


class TorchFidelityFeatureExtractor:
    """
    Extracts features using the Inception V3 model via torch-fidelity.
    This is compatible with the standard FID calculation method.
    """

    def __init__(self, device="cuda"):
        """
        Initialize the Inception V3 feature extractor.

        Args:
            device: Device to use for computation ('cuda' or 'cpu')

        Raises:
            ImportError: If torch-fidelity could not be imported.
        """
        self.device = device

        # Fail early with a clear message when the optional dependency is missing.
        if not TORCH_FIDELITY_AVAILABLE:
            raise ImportError(
                "torch-fidelity is not available. Please install it first."
            )

        # Imported lazily so the module stays importable without torch-fidelity.
        from torch_fidelity.feature_extractor_inceptionv3 import (
            FeatureExtractorInceptionV3,
        )

        self.model = (
            FeatureExtractorInceptionV3(
                name="inception-v3-compat", features_list=["2048"]
            )
            .to(device)
            .eval()
        )

    @torch.no_grad()
    def extract_features(self, images, batch_size=16):
        """
        Extract features from a batch of images using Inception V3.

        Args:
            images: tensor of shape [N, C, H, W] in range [0, 1], with C equal
                to 3 (RGB) or 6 (converted to RGB with ``to_rgb``)
            batch_size: batch size for feature extraction

        Returns:
            features: numpy array of shape [N, 2048]

        Raises:
            ValueError: If the input is not 4-D or has neither 3 nor 6 channels.
        """
        if images.ndim != 4:
            raise ValueError(
                f"Expected a 4-D tensor [N, C, H, W], got {images.ndim}-D input"
            )

        n_channels = images.shape[1]
        if n_channels == 6:
            # Convert each 6-channel cell image to RGB on the CPU.
            # NOTE(review): assumes to_rgb returns float images in [0, 1] - confirm.
            images = torch.stack(
                [to_rgb(img.cpu()[None]).squeeze(0) for img in images]
            )
        elif n_channels != 3:
            raise ValueError(f"Expected 3 or 6 channels, got {images.shape[1]}")

        # Convert to the uint8 format required by torch-fidelity. Clamp first so
        # values slightly outside [0, 1] saturate instead of overflowing the
        # uint8 cast, and round so that e.g. 0.9999 maps to 255 rather than
        # being truncated to 254.
        images = (images.clamp(0.0, 1.0) * 255).round().to(torch.uint8)

        n_samples = images.size(0)
        # TensorDataset needs matching-length tensors; the index tensor is only
        # a placeholder and is ignored while iterating.
        dataset = TensorDataset(images, torch.arange(n_samples))
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=0
        )
        features = []

        for batch, _ in tqdm(loader, desc="Extracting Inception features"):
            batch_images = batch.to(self.device)
            # The model returns a tuple with the features as the first element
            batch_features = self.model(batch_images)[0].cpu().numpy()
            features.append(batch_features)

        return np.vstack(features)


def calculate_metrics_from_features(
    features_real: np.ndarray,
    features_fake: np.ndarray,
    kid_subsets: int = 100,
    kid_subset_size: int = 1000,
):
    """
    Compute FID and KID from two collections of feature vectors.

    Parameters
    ----------
    features_real, features_fake : array-like, shape (n_samples, dim)
        Activations produced by an Inception-style network for real and
        generated (fake) images. Both are converted to float64.
    kid_subsets : int, optional
        How many random subsets to draw when estimating KID (must be >= 1).
    kid_subset_size : int, optional
        Size of each subset (must be >= 2 and <= min(n_real, n_fake)).

    Returns
    -------
    fid              : float
        Fréchet Inception Distance.
    kid_mean, kid_sd : float, float
        Mean and (unbiased) standard deviation of the KID estimate across
        the sampled subsets. ``kid_sd`` is NaN when only one subset is drawn.

    Raises
    ------
    ValueError
        If the inputs are not 2-D, their feature dimensions differ,
        ``kid_subset_size`` is outside ``[2, min(n_real, n_fake)]``, or
        ``kid_subsets`` is smaller than 1.

    Notes
    -----
    Subsets are drawn with ``np.random.choice`` (NumPy's global RNG), so seed
    it with ``np.random.seed`` for reproducible KID values.
    """
    # --- basic checks -------------------------------------------------------
    features_real = np.asarray(features_real, dtype=np.float64)
    features_fake = np.asarray(features_fake, dtype=np.float64)

    if features_real.ndim != 2 or features_fake.ndim != 2:
        raise ValueError("Input arrays must be 2-D (n_samples, feature_dim).")
    if features_real.shape[1] != features_fake.shape[1]:
        raise ValueError("Real and fake feature dimensionalities must match.")
    if kid_subset_size > min(len(features_real), len(features_fake)):
        raise ValueError("kid_subset_size cannot exceed number of samples.")
    # The unbiased MMD estimate divides by m * (m - 1); m < 2 would silently
    # produce NaN/inf instead of a meaningful score.
    if kid_subset_size < 2:
        raise ValueError("kid_subset_size must be at least 2.")
    if kid_subsets < 1:
        raise ValueError("kid_subsets must be at least 1.")

    dim = features_real.shape[1]

    # -----------------------------------------------------------------------
    # 1) FID
    # -----------------------------------------------------------------------
    mu_r = features_real.mean(axis=0)
    mu_f = features_fake.mean(axis=0)

    # np.cov returns a 0-d array for one-dimensional features; promote it to
    # a 1x1 matrix so the matrix operations below keep working.
    cov_r = np.atleast_2d(np.cov(features_real, rowvar=False))
    cov_f = np.atleast_2d(np.cov(features_fake, rowvar=False))

    diff = mu_r - mu_f
    diff_sq = diff @ diff

    # Matrix square root of the covariance product. Called without `disp` so
    # the result is always a plain array regardless of the SciPy version.
    covmean = linalg.sqrtm(cov_r @ cov_f)

    # A (near-)singular product can yield non-finite entries; retry with a
    # small ridge on both covariances, as reference FID implementations do.
    if not np.isfinite(covmean).all():
        ridge = np.eye(cov_r.shape[0]) * 1e-6
        covmean = linalg.sqrtm((cov_r + ridge) @ (cov_f + ridge))

    # Guard against small imaginary parts due to numerical error
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = float(diff_sq + np.trace(cov_r + cov_f - 2.0 * covmean))

    # -----------------------------------------------------------------------
    # 2) KID
    # -----------------------------------------------------------------------
    # Polynomial kernel k(x, y) = (x·y / dim + 1)^3
    m = kid_subset_size
    off_diag_pairs = m * (m - 1)  # number of i != j pairs within one subset
    kid_values = np.empty(kid_subsets, dtype=np.float64)

    for i in range(kid_subsets):
        idx_r = np.random.choice(len(features_real), m, replace=False)
        idx_f = np.random.choice(len(features_fake), m, replace=False)

        X = features_real[idx_r]  # (m, dim)
        Y = features_fake[idx_f]  # (m, dim)

        gram_xx = (X @ X.T / dim + 1.0) ** 3
        gram_yy = (Y @ Y.T / dim + 1.0) ** 3
        gram_xy = (X @ Y.T / dim + 1.0) ** 3

        # Unbiased MMD² estimate: the within-set terms exclude the diagonal
        # (self-similarity), the cross term uses every pair.
        sum_xx = (gram_xx.sum() - np.trace(gram_xx)) / off_diag_pairs
        sum_yy = (gram_yy.sum() - np.trace(gram_yy)) / off_diag_pairs
        sum_xy = gram_xy.sum() / (m * m)

        kid_values[i] = sum_xx + sum_yy - 2.0 * sum_xy

    kid_mean = float(kid_values.mean())
    # The sample SD is undefined for a single subset; report NaN directly
    # rather than triggering a NumPy runtime warning.
    kid_sd = float(kid_values.std(ddof=1)) if kid_subsets > 1 else float("nan")

    return fid, kid_mean, kid_sd


def calculate_metrics_from_scratch(
    real_images,
    fake_images,
    batch_size=256,
    kid_subsets=100,
    kid_subset_size=500,
    feature_extractor="inception_v3",
):
    """
    Calculate FID and KID metrics from scratch for two sets of images.

    Features are extracted with the selected extractor (on CUDA when
    available, otherwise on CPU) and passed to
    ``calculate_metrics_from_features``.

    Args:
        real_images: tensor of shape [N, C, H, W] with real images in [0, 1]
        fake_images: tensor of shape [M, C, H, W] with fake images in [0, 1]
        batch_size: batch size for feature extraction
        kid_subsets: number of subsets for KID calculation
        kid_subset_size: size of each subset for KID calculation; it is capped
            at half the number of fake samples and at the number of real
            samples
        feature_extractor: which feature extractor to use ('inception_v3' or
            'openphenom', case-insensitive)

    Returns:
        fid: Fréchet Inception Distance
        kid_mean: Mean Kernel Inception Distance
        kid_std: Standard deviation of KID

    Raises:
        ValueError: If ``feature_extractor`` is not a supported name.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Select the appropriate feature extractor
    extractor_name = feature_extractor.lower()
    if extractor_name == "inception_v3":
        extractor = TorchFidelityFeatureExtractor(device=device)
    elif extractor_name == "openphenom":
        extractor = OpenPhenomFeatureExtractor(device=device)
    else:
        raise ValueError(
            f"Unsupported feature extractor: {feature_extractor}. "
            f"Choose from 'inception_v3' or 'openphenom'."
        )

    # Extract features
    print("Extracting features from real images...")
    real_features = extractor.extract_features(real_images, batch_size=batch_size)
    print("Extracting features from fake images...")
    fake_features = extractor.extract_features(fake_images, batch_size=batch_size)

    # Cap the KID subset size by BOTH sets. Previously only the fake set was
    # considered, so a real set smaller than the chosen size made the
    # downstream size check fail.
    effective_subset_size = min(
        kid_subset_size, len(fake_features) // 2, len(real_features)
    )

    return calculate_metrics_from_features(
        real_features,
        fake_features,
        kid_subsets=kid_subsets,
        kid_subset_size=effective_subset_size,
    )
