import torch
import numpy as np
from scipy import linalg
from torch.nn.functional import interpolate
from torchvision.models import inception_v3
from torchvision import datasets, transforms
import os


def load_inception_model(device):
    """
    Load a pretrained Inception V3 feature extractor.

    The final ``fc`` classification layer is replaced by ``Identity`` so the
    forward pass returns the pooled 2048-dim feature vector instead of class
    logits.

    ``transform_input=True`` is the setting torchvision itself selects for the
    pretrained weights: those weights expect inputs scaled to [-1, 1], and
    ``transform_input`` maps ImageNet-normalized input (mean 0.485/0.456/0.406,
    std 0.229/0.224/0.225) onto that range. Callers should therefore feed
    ImageNet-normalized batches, as ``get_inception_features`` does.

    Args:
        device: torch device (or device string) to place the model on.

    Returns:
        The feature extractor in eval mode on ``device``.
    """
    model = inception_v3(pretrained=True, transform_input=True)
    model.fc = torch.nn.Identity()  # Drop the classifier; keep the pooled features
    model.eval()  # Inference mode: dropout off, BatchNorm uses running statistics
    return model.to(device)


def get_inception_features(images, model, device, batch_size=32):
    """
    Extract Inception feature vectors for a set of images.

    Preprocessing (range shift, resize to 299x299, ImageNet normalization) is
    done batch by batch, so only one resized batch lives on ``device`` at a
    time. Resizing the whole set up front costs about 1 MB per image at
    299x299 float32, which exhausts GPU memory for typical FID sample sizes.

    Args:
        images: Tensor of shape (N, C, H, W) with values in [-1, 1] or [0, 1].
            If any value is negative, the whole set is treated as [-1, 1] and
            mapped to [0, 1].
        model: Inception model whose forward pass returns one feature vector
            per image (e.g. ``load_inception_model``, 2048-dim features).
        device: torch device to run the model on.
        batch_size: batch size for processing.

    Returns:
        features: numpy array of shape (N, D), where D is the model's output
            width (2048 for ``load_inception_model``).

    Raises:
        ValueError: if ``images`` contains no images.
    """
    num_images = images.shape[0]
    if num_images == 0:
        raise ValueError("get_inception_features needs at least one image")

    # Decide the input range once for the whole set so every batch gets the same mapping.
    from_signed = bool(images.min() < 0)

    # ImageNet normalization constants, created once on the target device.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    all_features = []
    with torch.no_grad():
        for start in range(0, num_images, batch_size):
            batch = images[start:start + batch_size].to(device).float()
            if from_signed:
                batch = (batch + 1) / 2
            # Check both spatial dims; a 299-wide but non-299-tall input still needs resizing.
            if tuple(batch.shape[-2:]) != (299, 299):
                batch = interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
            batch = (batch - mean) / std
            all_features.append(model(batch).cpu().numpy())

    return np.concatenate(all_features, axis=0)


def calculate_inception_score(images, model, device, splits=10, batch_size=32):
    """
    Calculate Inception Score.

    A separate ImageNet classifier is built here, because the feature
    extractor usually passed as ``model`` has its ``fc`` layer replaced by
    ``Identity`` and cannot produce class probabilities.

    Args:
        images: Tensor of shape (N, C, H, W) with values in [-1, 1] or [0, 1].
        model: accepted for API compatibility; not used.
        device: torch device
        splits: number of splits for computing mean/std; clamped to N so that
            every split holds at least one image.
        batch_size: batch size for processing

    Returns:
        mean_score: Mean IS across splits
        std_score: Std IS across splits

    Raises:
        ValueError: if ``images`` contains no images.
    """
    from scipy.special import rel_entr

    num_images = images.shape[0]
    if num_images == 0:
        raise ValueError("Inception Score needs at least one image")

    # Full model with classifier. transform_input=True is what torchvision uses
    # for these pretrained weights when fed ImageNet-normalized input (below).
    classifier = inception_v3(pretrained=True, transform_input=True)
    classifier.eval()
    classifier = classifier.to(device)

    # Decide the input range once for the whole set so every batch gets the same mapping.
    from_signed = bool(images.min() < 0)
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    # Preprocess per batch so only one 299x299 batch is resident on the device.
    prob_chunks = []
    with torch.no_grad():
        for start in range(0, num_images, batch_size):
            batch = images[start:start + batch_size].to(device).float()
            if from_signed:
                batch = (batch + 1) / 2
            if tuple(batch.shape[-2:]) != (299, 299):
                batch = interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
            batch = (batch - mean) / std
            probs = torch.nn.functional.softmax(classifier(batch), dim=1)
            prob_chunks.append(probs.cpu().numpy())

    preds = np.concatenate(prob_chunks, axis=0).astype(np.float64)

    # array_split keeps the remainder images instead of dropping them, and the
    # clamp guarantees no split is empty (an empty split would yield NaN).
    num_splits = max(1, min(splits, num_images))
    scores = []
    for part in np.array_split(preds, num_splits):
        marginal = part.mean(axis=0, keepdims=True)
        # rel_entr(p, q) = p * log(p / q), defined as 0 for p == 0, so classes whose
        # float32 softmax underflowed to 0 contribute 0 instead of NaN.
        kl = rel_entr(part, marginal).sum(axis=1).mean()
        scores.append(np.exp(kl))

    return np.mean(scores), np.std(scores)


def calculate_fid(real_features, fake_features, eps=1e-6):
    """
    Calculate Fréchet Inception Distance.

    FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))

    Args:
        real_features: numpy array of real image features (N, D)
        fake_features: numpy array of fake image features (M, D)
        eps: diagonal offset added to both covariances when the matrix square
            root of their product is not finite. This happens when a covariance
            is singular, e.g. with fewer samples than feature dimensions.

    Returns:
        fid_score: FID value as a Python float.

    Raises:
        ValueError: if either set has fewer than 2 samples, or if the feature
            dimensions of the two sets differ.
    """
    # Work in float64; a 1-D input is treated as N samples of one feature.
    real = np.asarray(real_features, dtype=np.float64)
    fake = np.asarray(fake_features, dtype=np.float64)
    real = real.reshape(real.shape[0], -1)
    fake = fake.reshape(fake.shape[0], -1)

    if real.shape[0] < 2 or fake.shape[0] < 2:
        raise ValueError("FID needs at least 2 samples per set to estimate a covariance")
    if real.shape[1] != fake.shape[1]:
        raise ValueError(
            f"Feature dimensions differ: {real.shape[1]} (real) vs {fake.shape[1]} (fake)"
        )

    mu1, mu2 = real.mean(axis=0), fake.mean(axis=0)
    # atleast_2d keeps the single-feature case (np.cov returns a 0-d array) valid for sqrtm.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake, rowvar=False))

    # Squared difference of means
    diff = mu1 - mu2
    ssdiff = float(diff @ diff)

    # Square root of the covariance product; regularize if it is ill-conditioned.
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error can leave a negligible imaginary component; discard it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return float(ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))


def _extract_celeba_images(celeba_subdir):
    """Unzip img_align_celeba.zip inside ``celeba_subdir`` if the image folder is absent."""
    if os.path.exists(os.path.join(celeba_subdir, 'img_align_celeba')):
        return
    print("Extracting CelebA dataset...")
    import zipfile
    archive_path = os.path.join(celeba_subdir, 'img_align_celeba.zip')
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(celeba_subdir)


def _fetch_missing_celeba_annotations(celeba_subdir):
    """Try to download CelebA annotation files absent from ``celeba_subdir``.

    Download failures are printed together with a manual-download hint and
    are otherwise ignored.
    """
    # Google Drive file ids, in the order the files are checked and downloaded.
    drive_ids = {
        'list_attr_celeba.txt': '0B7EVK8r0v71pblRyaVFSWGxPY0U',
        'identity_CelebA.txt': '1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS',
        'list_bbox_celeba.txt': '0B7EVK8r0v71pbThiMVRxWXZ4dU0',
        'list_landmarks_align_celeba.txt': '0B7EVK8r0v71pd0FJY3Blby1HUTQ',
        'list_eval_partition.txt': '0B7EVK8r0v71pY0NSMzRuSXJEVkk',
    }
    absent = [name for name in drive_ids
              if not os.path.exists(os.path.join(celeba_subdir, name))]
    if not absent:
        return

    print(f"Missing annotation files: {absent}")
    print("Downloading annotation files...")
    from torchvision.datasets.utils import download_url
    for name in absent:
        try:
            print(f"Downloading {name}...")
            download_url("https://drive.google.com/uc?export=download&id=" + drive_ids[name],
                         celeba_subdir, filename=name)
        except Exception as e:
            print(f"Failed to download {name}: {e}")
            print(f"Please download manually from Google Drive and place in {celeba_subdir}")


def load_celeba_real_images(num_images=10000, img_size=64, batch_size=64, device='cuda'):
    """
    Load a shuffled sample of real CelebA training images for FID calculation.

    The data root comes from the DATA_ROOT environment variable, with a
    hard-coded fallback path. The image zip is unpacked if needed, and any
    missing annotation files are downloaded on a best-effort basis before the
    torchvision CelebA train split is opened.

    Args:
        num_images: Number of real images to return (fewer if the split is smaller)
        img_size: Images are resized to img_size x img_size (default 64)
        batch_size: DataLoader batch size used while sampling
        device: accepted for interface compatibility; not used, and the
            returned tensor is on the CPU

    Returns:
        Tensor of real images (N, C, H, W) normalized to [-1, 1]
    """
    print(f'Loading {num_images} real CelebA images...')

    root = os.environ.get("DATA_ROOT", "/home/s2670758/celeba_data")
    os.makedirs(root, exist_ok=True)
    celeba_subdir = os.path.join(root, 'celeba')

    _extract_celeba_images(celeba_subdir)
    _fetch_missing_celeba_annotations(celeba_subdir)

    # Normalize with mean = std = 0.5 so pixel values land in [-1, 1].
    preprocess = transforms.Compose([
        transforms.Resize([img_size, img_size]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = datasets.CelebA(root, split='train', download=False, transform=preprocess)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

    # Pull shuffled batches until enough images are collected (or data runs out).
    collected = []
    count = 0
    for imgs, _ in loader:
        collected.append(imgs)
        count += imgs.shape[0]
        if count >= num_images:
            break

    sample = torch.cat(collected, dim=0)[:num_images]
    print(f"Loaded {sample.shape[0]} real images with shape {sample.shape}")
    return sample

def _as_image_batch(images):
    """Convert a tensor, ndarray, or list of images into a 4-D float tensor.

    A list made only of tensors is stacked with ``torch.stack`` after each
    element is detached and moved to the CPU. ``np.array`` would fail on CUDA
    or grad-requiring tensors. Any other list goes through ``np.array``.
    A 5-D result, e.g. a list of equally sized (B, C, H, W) batches, is
    flattened to (N, C, H, W).
    """
    if isinstance(images, list):
        if images and all(torch.is_tensor(x) for x in images):
            images = torch.stack([x.detach().cpu() for x in images])
        else:
            images = np.array(images)

    if isinstance(images, np.ndarray):
        images = torch.from_numpy(images)

    if torch.is_tensor(images):
        images = images.detach().float()

    # Ensure 4D tensor
    if images.dim() == 5:
        images = images.reshape(-1, *images.shape[2:])
    return images


def compute_inception_scores_CelebA(generated_images, real_images=None, device='cuda', batch_size=32):
    """
    Compute Inception Score and FID for CelebA generated images.

    Args:
        generated_images: Tensor, numpy array, or list of generated images
            (a list of tensors may live on any device or require grad)
        real_images: Tensor, numpy array, or list of real images (N, C, H, W),
            required for FID. Can also be 'auto' (case-insensitive) to load
            real images automatically from the CelebA dataset.
        device: 'cuda' or 'cpu' (falls back to CPU when CUDA is unavailable)
        batch_size: batch size for processing

    Returns:
        Dictionary with IS mean, IS std, and FID (if real_images provided)

    Raises:
        ValueError: if ``real_images`` is a string other than 'auto'. This is
            checked before any model is loaded.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # Validate up front so a bad argument fails before the expensive IS pass.
    if isinstance(real_images, str) and real_images.lower() != 'auto':
        raise ValueError(f"real_images must be a tensor/array/list or 'auto', got {real_images!r}")

    generated_images = _as_image_batch(generated_images)

    # Load model
    print("Loading Inception model...")
    model = load_inception_model(device)

    # Calculate Inception Score
    print("Calculating Inception Score...")
    is_mean, is_std = calculate_inception_score(generated_images, model, device, batch_size=batch_size)

    results = {
        'inception_score_mean': is_mean,
        'inception_score_std': is_std
    }

    print(f"Inception Score: {is_mean:.3f} ± {is_std:.3f}")

    if real_images is None:
        return results

    print("Calculating FID...")

    # Auto-load real images if requested (the string was validated above).
    if isinstance(real_images, str):
        real_images = load_celeba_real_images(
            num_images=min(10000, len(generated_images)),
            img_size=generated_images.shape[-1],
            device=device
        )

    real_images = _as_image_batch(real_images)

    # Get features
    fake_features = get_inception_features(generated_images, model, device, batch_size)
    real_features = get_inception_features(real_images, model, device, batch_size)

    # Calculate FID
    fid_score = calculate_fid(real_features, fake_features)
    results['fid'] = fid_score

    print(f"FID: {fid_score:.3f}")

    return results