import torch
import numpy as np
from scipy import linalg
from torch.nn.functional import interpolate
from torchvision.models import inception_v3
from torchvision import datasets, transforms
import os


def load_inception_model(device):
    """Build the Inception V3 feature extractor.

    The pretrained torchvision network is created, its final fully
    connected layer is replaced by an identity so that a forward pass
    returns the activations that would have fed the classifier, and the
    network is put into evaluation mode.

    Args:
        device: torch device (or device string) the model is moved to.

    Returns:
        The modified Inception V3 module, in eval mode, on ``device``.
    """
    extractor = inception_v3(pretrained=True, transform_input=False)
    # Identity head: the classification layer becomes a pass-through.
    extractor.fc = torch.nn.Identity()
    return extractor.eval().to(device)


def get_inception_features(images, model, device, batch_size=32):
    """
    Extract features from Inception model.

    Preprocessing (range rescale, resize to 299x299, ImageNet normalization)
    is applied one batch at a time, so only ``batch_size`` upsampled images
    live on ``device`` at once instead of the whole dataset.

    Args:
        images: Tensor of shape (N, C, H, W) with values in [-1, 1] or [0, 1]
        model: callable mapping a preprocessed batch to a 2D feature tensor
        device: torch device
        batch_size: batch size for processing

    Returns:
        features: numpy array of shape (N, D), where D is whatever ``model``
                  outputs per image (2048 for the Inception V3 extractor,
                  per the original docstring)
    """
    # Decide the input range once, over the full tensor: a single batch might
    # contain no negative pixel even though the data is in [-1, 1].
    needs_rescale = bool(images.min() < 0)

    # Check height AND width; checking only the last dimension would skip the
    # resize for inputs such as (N, C, 64, 299).
    needs_resize = tuple(images.shape[-2:]) != (299, 299)

    # ImageNet normalization constants, created on the target device
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    all_features = []

    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].to(device)

            # Map [-1, 1] to [0, 1]
            if needs_rescale:
                batch = (batch + 1) / 2

            # Resize to 299x299 for Inception
            if needs_resize:
                batch = interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)

            batch = (batch - mean) / std

            all_features.append(model(batch).cpu().numpy())

    return np.concatenate(all_features, axis=0)


def calculate_inception_score(images, model, device, splits=10, batch_size=32):
    """
    Calculate Inception Score.

    The score of one split is exp(mean_x KL(p(y|x) || p(y))), where p(y) is
    the mean prediction over that split. Images beyond
    ``splits * (N // splits)`` are not used.

    Args:
        images: Tensor of shape (N, C, H, W) with values in [-1, 1] or [0, 1]
        model: unused. This function always builds its own pretrained
               Inception V3 with the classification head, because class
               probabilities are required. The parameter is kept so existing
               callers keep working.
        device: torch device
        splits: number of splits for computing mean/std; reduced to N when
                fewer than ``splits`` images are given
        batch_size: batch size for processing

    Returns:
        mean_score: Mean IS across splits
        std_score: Std IS across splits
    """
    # We need the full model with classifier for IS
    full_model = inception_v3(pretrained=True, transform_input=False)
    full_model.eval()
    full_model = full_model.to(device)

    # Decide the input range once, over the full tensor
    needs_rescale = bool(images.min() < 0)

    # Check both spatial dimensions, not only the width
    needs_resize = tuple(images.shape[-2:]) != (299, 299)

    # ImageNet normalization constants, created on the target device
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    # Get predictions. Preprocessing happens per batch so that only
    # ``batch_size`` upsampled images are held on the device at a time.
    all_preds = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].to(device)
            if needs_rescale:
                batch = (batch + 1) / 2
            if needs_resize:
                batch = interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
            batch = (batch - mean) / std

            probs = torch.nn.functional.softmax(full_model(batch), dim=1)
            all_preds.append(probs.cpu().numpy())

    # float64 keeps the log/exp arithmetic below accurate
    preds = np.concatenate(all_preds, axis=0).astype(np.float64)

    # With fewer images than splits every split would be empty (score = nan)
    splits = min(splits, len(preds))
    split_size = len(preds) // splits

    # Floor for the logarithms: a probability that underflowed to exactly 0
    # would otherwise give 0 * -inf = nan. With the floor its term is 0.
    tiny = np.finfo(np.float64).tiny

    scores = []
    for i in range(splits):
        part = preds[i * split_size:(i + 1) * split_size]
        marginal = np.mean(part, axis=0, keepdims=True)
        log_ratio = np.log(np.maximum(part, tiny)) - np.log(np.maximum(marginal, tiny))
        kl_divergence = np.mean(np.sum(part * log_ratio, axis=1))
        scores.append(np.exp(kl_divergence))

    return np.mean(scores), np.std(scores)


def calculate_fid(real_features, fake_features, eps=1e-6):
    """
    Calculate Fréchet Inception Distance.

    FID = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr(sqrt(sigma1 sigma2))

    Args:
        real_features: numpy array of real image features (N, D)
        fake_features: numpy array of fake image features (M, D)
        eps: regularization term added to both covariance diagonals for
             numerical stability

    Returns:
        fid_score: FID value
    """
    real = np.asarray(real_features, dtype=np.float64)
    fake = np.asarray(fake_features, dtype=np.float64)

    # Calculate mean and covariance. np.cov returns a 0-d array when there is
    # a single feature column; atleast_2d keeps the matrix code below valid.
    mu1 = np.mean(real, axis=0)
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))

    mu2 = np.mean(fake, axis=0)
    sigma2 = np.atleast_2d(np.cov(fake, rowvar=False))

    # Add small epsilon to diagonal for numerical stability
    sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
    sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps

    # Calculate squared difference of means
    diff = mu1 - mu2
    ssdiff = np.sum(diff ** 2)

    # Calculate sqrt of product of covariances. Called without ``disp`` so
    # the call does not depend on that argument; the returned matrix is the
    # same either way.
    covmean = linalg.sqrtm(sigma1 @ sigma2)

    # Check for imaginary numbers (indicates numerical issues)
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            print(f"Warning: Imaginary component in covmean with max: {np.max(np.abs(covmean.imag))}")
        covmean = covmean.real

    # Calculate trace
    tr_covmean = np.trace(covmean)

    # Final FID calculation
    fid = ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

    return fid


def load_cifar10_real_images(num_images=5000, img_size=32, batch_size=64, device='cuda', train=True, class_label=3):
    """
    Load real CIFAR-10 images for FID calculation.

    The dataset is downloaded on first use into the directory named by the
    ``DATA_ROOT`` environment variable (default ``./cifar10_data``).

    Args:
        num_images: Number of real images to load
        img_size: Image size (default 32 for CIFAR-10)
        batch_size: Batch size for loading
        device: 'cuda' or 'cpu'. Currently unused: the returned tensor is
                whatever ``torch.cat`` of the loader batches yields.
        train: Whether to use training set (True) or test set (False)
        class_label: Which class to load (0-9). Default 3 = cats
                     CIFAR-10 classes: 0=airplane, 1=automobile, 2=bird, 3=cat, 4=deer,
                                       5=dog, 6=frog, 7=horse, 8=ship, 9=truck
                     Set to None to load all classes

    Returns:
        Tensor of real images (N, C, H, W) normalized to [-1, 1]
    """

    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    if class_label is not None:
        print(f'Loading {num_images} real CIFAR-10 "{class_names[class_label]}" images (label {class_label}) from {"train" if train else "test"} set...')
    else:
        print(f'Loading {num_images} real CIFAR-10 images (all classes) from {"train" if train else "test"} set...')

    # Setup CIFAR-10 directory
    cifar_dir = os.environ.get("DATA_ROOT", "./cifar10_data")
    os.makedirs(cifar_dir, exist_ok=True)

    # Load dataset
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ]

    # Add resize if needed
    if img_size != 32:
        transform_list.insert(0, transforms.Resize([img_size, img_size]))

    dataset = datasets.CIFAR10(
        cifar_dir,
        train=train,
        download=True,
        transform=transforms.Compose(transform_list),
    )

    # Filter by class if specified
    if class_label is not None:
        # Read labels from the dataset's target list. Iterating the dataset
        # itself would decode and transform every image just to get its label.
        indices = [i for i, label in enumerate(dataset.targets) if label == class_label]
        dataset = torch.utils.data.Subset(dataset, indices)
        print(f"Found {len(dataset)} images of class '{class_names[class_label]}' in the dataset")

        # Check if we have enough images
        if len(dataset) < num_images:
            print(f"Warning: Only {len(dataset)} images available for class {class_label}, loading all of them")
            num_images = len(dataset)

    # Load images
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # Set to False for reproducibility
        num_workers=0,
        # Batches are concatenated into a new tensor below, so pinning the
        # per-batch memory would not speed up any later transfer.
        pin_memory=False
    )

    real_images = []
    total_loaded = 0

    # Stop as soon as enough images were collected
    for imgs, _ in dataloader:
        real_images.append(imgs)
        total_loaded += imgs.shape[0]
        if total_loaded >= num_images:
            break

    # The last batch may overshoot; trim to the requested count
    real_images = torch.cat(real_images, dim=0)[:num_images]
    print(f"Loaded {real_images.shape[0]} real images with shape {real_images.shape}")

    return real_images


def _as_image_batch(images):
    """Convert a list / numpy array / tensor of images to a 4D tensor.

    Lists go through ``np.array``; numpy arrays become float tensors; a 5D
    input has its first two dimensions merged. Tensors that are already 4D
    are returned unchanged.
    """
    if isinstance(images, list):
        images = np.array(images)
    if isinstance(images, np.ndarray):
        images = torch.from_numpy(images).float()
    if len(images.shape) == 5:
        images = images.reshape(-1, *images.shape[2:])
    return images


def compute_inception_scores_CIFAR10(generated_images, real_images=None, device='cuda', batch_size=32, class_label=3):
    """
    Compute Inception Score and FID for CIFAR-10 generated images.

    Args:
        generated_images: Tensor, numpy array, or list of generated images
        real_images: Tensor, numpy array, or list of real images
                     (N, C, H, W) - required for FID
                     Can also be 'auto' to automatically load from CIFAR-10 dataset
        device: 'cuda' or 'cpu'; falls back to 'cpu' when CUDA is unavailable
        batch_size: batch size for processing
        class_label: Which class to load if using 'auto' (0-9, or None for all classes)
                     Default 3 = cats

    Returns:
        Dictionary with keys 'inception_score_mean', 'inception_score_std',
        and 'fid' (only if real_images provided)
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    generated_images = _as_image_batch(generated_images)

    print(f"\nGenerated images - shape: {generated_images.shape}, range: [{generated_images.min():.3f}, {generated_images.max():.3f}]")

    # Calculate Inception Score.
    # calculate_inception_score builds its own network with the classifier
    # head and does not use its ``model`` argument, so the headless feature
    # extractor is not loaded here; it is only needed for FID below.
    print("Calculating Inception Score...")
    is_mean, is_std = calculate_inception_score(generated_images, None, device, batch_size=batch_size)

    results = {
        'inception_score_mean': is_mean,
        'inception_score_std': is_std
    }

    print(f"Inception Score: {is_mean:.3f} ± {is_std:.3f}")

    # Calculate FID if real images provided
    if real_images is not None:
        print("Calculating FID...")

        # Auto-load real images if requested
        if isinstance(real_images, str) and real_images.lower() == 'auto':
            num_samples = min(5000, len(generated_images))  # Use 5000 for single class
            real_images = load_cifar10_real_images(
                num_images=num_samples,
                img_size=generated_images.shape[-1],
                device=device,
                train=True,  # Use training set by default
                class_label=class_label  # Load only specified class (default: cats)
            )

        # Same coercion as for the generated images (lists included)
        real_images = _as_image_batch(real_images)

        print(f"Real images - shape: {real_images.shape}, range: [{real_images.min():.3f}, {real_images.max():.3f}]")

        # Load the feature extractor only now that FID is actually requested
        print("Loading Inception model...")
        model = load_inception_model(device)

        # Get features
        print("Extracting features from generated images...")
        fake_features = get_inception_features(generated_images, model, device, batch_size)
        print(f"Fake features - shape: {fake_features.shape}, mean: {fake_features.mean():.3f}, std: {fake_features.std():.3f}")

        print("Extracting features from real images...")
        real_features = get_inception_features(real_images, model, device, batch_size)
        print(f"Real features - shape: {real_features.shape}, mean: {real_features.mean():.3f}, std: {real_features.std():.3f}")

        # Calculate FID
        fid_score = calculate_fid(real_features, fake_features)
        results['fid'] = fid_score

        print(f"FID: {fid_score:.3f}")

    return results


# Example usage:
if __name__ == "__main__":
    # Demo: fetch the real CIFAR-10 cat images that serve as the FID reference.
    #
    # CIFAR-10 class labels:
    #   0=airplane, 1=automobile, 2=bird, 3=cat, 4=deer,
    #   5=dog, 6=frog, 7=horse, 8=ship, 9=truck
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Option 1: load only the cat class (label 3) at native 32x32 resolution.
    # num_images=5000 matches the cat count noted for the CIFAR-10 train set.
    real_cat_images = load_cifar10_real_images(
        class_label=3,
        train=True,
        num_images=5000,
        img_size=32,
        batch_size=64,
        device=device,
    )

    # To score a generator, produce fake images of shape (N, 3, 32, 32) with
    # values in [-1, 1], e.g.
    #
    #     fake_images = your_generator(z)
    #
    # and pass them in together with the real images loaded above:
    #
    #     results = compute_inception_scores_CIFAR10(
    #         fake_images,
    #         real_images=real_cat_images,  # Or use 'auto' with class_label=3
    #         device=device,
    #         class_label=3,  # Only used if real_images='auto'
    #     )
    #
    # Alternative: let the function load the cat images itself:
    #
    #     results = compute_inception_scores_CIFAR10(
    #         fake_images,
    #         real_images='auto',  # Automatically loads cats
    #         device=device,
    #         class_label=3,  # 3 = cats
    #     )
    #
    # Printing the results:
    #
    #     print(f"IS: {results['inception_score_mean']:.3f} ± {results['inception_score_std']:.3f}")
    #     if 'fid' in results:
    #         print(f"FID: {results['fid']:.3f}")