#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os
import math
import random
from pathlib import Path
from typing import List, Tuple, Optional

import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import inception_v3, Inception_V3_Weights, InceptionOutputs

from tqdm import tqdm
from scipy import linalg


# ----------------------------
# Data utilities
# ----------------------------

# File suffixes (leading dot included) accepted as images. Candidates are
# lower-cased before the membership test, so matching is case-insensitive.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}


class ImageFolderFlat(Dataset):
    """Unlabelled dataset over every image file found recursively under a directory.

    Each item is one image tensor: the file is opened, converted to RGB, resized
    to ``image_size x image_size`` with bilinear interpolation, turned into a
    tensor and normalized with the ImageNet channel mean/std.
    """

    def __init__(self, root: str, image_size: int = 299):
        """Index the image files under ``root`` and build the preprocessing pipeline.

        Raises:
            FileNotFoundError: if ``root`` does not exist.
            RuntimeError: if no file with a suffix in ``IMG_EXTS`` is found.
        """
        self.paths = self._gather_images(root)
        if not self.paths:
            raise RuntimeError(f"No images found in: {root}")

        target_hw = (image_size, image_size)
        resize = transforms.Resize(target_hw, interpolation=transforms.InterpolationMode.BILINEAR)
        to_tensor = transforms.ToTensor()  # [0,1] float32
        imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        self.transform = transforms.Compose([resize, to_tensor, imagenet_norm])

    @staticmethod
    def _gather_images(root: str) -> List[str]:
        """Return the sorted paths (as strings) of all image files below ``root``."""
        base = Path(root)
        if not base.exists():
            raise FileNotFoundError(f"Path does not exist: {base}")
        # Sort the full listing first so the resulting order is deterministic.
        return [
            str(entry)
            for entry in sorted(base.rglob("*"))
            if entry.is_file() and entry.suffix.lower() in IMG_EXTS
        ]

    def __len__(self):
        """Number of indexed image files."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load, preprocess and return the image at position ``idx``."""
        with Image.open(self.paths[idx]) as handle:
            # convert() yields an independent RGB image, so it stays usable
            # after the file handle is closed.
            rgb = handle.convert("RGB")
        return self.transform(rgb)


# ----------------------------
# Inception feature/logit extractor
# ----------------------------

class InceptionExtractor(nn.Module):
    """
    Wraps torchvision Inception v3 to provide:
      - logits (for IS)
      - pool3 features (2048-d, for FID/KID)

    Inputs are expected to be ImageNet-normalized (mean/std) image batches.
    """
    def __init__(self, device: torch.device):
        """Load the pretrained network onto ``device`` in eval mode and attach the feature hook."""
        super().__init__()
        weights = Inception_V3_Weights.IMAGENET1K_V1
        # IMPORTANT: aux_logits must be True when using weights in recent torchvision.
        # transform_input=True lets the model rescale ImageNet-normalized inputs to
        # the (x - 0.5) / 0.5 range the pretrained weights were trained with; with
        # False, ImageNet-normalized batches reach the network in the wrong scale.
        self.model = inception_v3(weights=weights, aux_logits=True, transform_input=True)
        self.model.eval()
        self.model.to(device)
        # Filled by the forward hook with the avgpool output (N, 2048, 1, 1).
        self._features: Optional[torch.Tensor] = None

        def hook_fn(module, inp, out):
            self._features = out

        self.model.avgpool.register_forward_hook(hook_fn)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one batch through Inception v3.

        Returns:
          logits: (N, 1000)
          feats: (N, 2048)

        Raises:
          RuntimeError: if the avgpool hook did not fire during this call.
        """
        # Clear the previous batch's activations so the check below reflects
        # this call only, and stale features can never be returned.
        self._features = None
        out = self.model(x)
        if isinstance(out, InceptionOutputs):
            logits = out.logits  # ignore aux_logits
        else:
            logits = out  # some versions may return tensor directly
        feats = self._features
        self._features = None  # do not keep the activation alive between calls
        if feats is None:
            raise RuntimeError("Inception avgpool hook did not run.")
        return logits, torch.flatten(feats, 1)  # (N, 2048)


# ----------------------------
# Metric computations
# ----------------------------

def compute_inception_score(logits: np.ndarray, splits: int = 10) -> Tuple[float, float]:
    """
    Inception Score from raw classifier logits.

    logits: (N, num_classes) raw logits from Inception v3 for generated images.
    splits: number of groups the samples are divided into; if it is not in
            ``1..N`` it falls back to ``min(10, N)``.
    Returns: (IS_mean, IS_std) over the per-split scores (population std).
    Raises: ValueError if ``logits`` is not a non-empty 2-D array.
    """
    logits = np.asarray(logits, dtype=np.float64)
    if logits.ndim != 2 or logits.shape[0] == 0:
        raise ValueError("logits must be a non-empty 2-D array of shape (N, num_classes).")

    probs = softmax_np(logits)  # (N, num_classes)
    N = probs.shape[0]
    if splits <= 0 or splits > N:
        splits = max(1, min(10, N))

    scores = []
    # array_split spreads the remainder over the first splits, so every sample
    # contributes (a fixed N // splits chunk size would drop the tail).
    for part in np.array_split(probs, splits, axis=0):
        p_y = np.mean(part, axis=0, keepdims=True)  # marginal class distribution of the split
        kl = part * (np.log(part + 1e-10) - np.log(p_y + 1e-10))  # (m, num_classes)
        kl = np.sum(kl, axis=1)  # KL(p(y|x) || p(y)) per sample
        scores.append(np.exp(np.mean(kl)))
    return float(np.mean(scores)), float(np.std(scores))


def softmax_np(x: np.ndarray) -> np.ndarray:
    """Row-wise softmax of a 2-D array, shifted by the row max to avoid overflow."""
    x = x - np.max(x, axis=1, keepdims=True)
    ex = np.exp(x)
    return ex / np.sum(ex, axis=1, keepdims=True)


def compute_fid(feats_real: np.ndarray, feats_fake: np.ndarray, eps: float = 1e-6) -> float:
    """
    Frechet distance between Gaussians fitted to two feature sets.

    feats_real, feats_fake: (N, d) feature arrays (e.g. Inception pool3, d=2048).
    eps: diagonal regularizer, used only if the plain matrix square root of the
         covariance product contains non-finite values.
    Returns: ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2)).
    Raises: ValueError if the inputs are not 2-D, have different feature
            dimensions, or either set has fewer than 2 samples.
    """
    real = np.asarray(feats_real, dtype=np.float64)
    fake = np.asarray(feats_fake, dtype=np.float64)
    if real.ndim != 2 or fake.ndim != 2:
        raise ValueError("Features must be 2-D arrays of shape (N, d).")
    if real.shape[1] != fake.shape[1]:
        raise ValueError("Real and fake features must have the same dimensionality.")
    if real.shape[0] < 2 or fake.shape[0] < 2:
        raise ValueError("Need at least 2 samples in both real and fake to estimate covariances.")

    mu1 = np.mean(real, axis=0)
    mu2 = np.mean(fake, axis=0)
    # atleast_2d: np.cov returns a 0-d array when d == 1.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake, rowvar=False))
    diff = mu1 - mu2

    # Try the unregularized product first: always adding eps*I to both
    # covariances biases the result (identical inputs would score about
    # -2 * eps * d instead of 0).
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Product is (numerically) singular; regularize the diagonals and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Real part in case of tiny imaginary residuals
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    return float(fid)


def polynomial_mmd2_unbiased(X: np.ndarray, Y: np.ndarray, degree: int = 3,
                             gamma: Optional[float] = None, coef0: float = 1.0) -> float:
    """
    Unbiased MMD^2 estimate under the polynomial kernel
    k(x, y) = (gamma * x^T y + coef0)^degree.

    X, Y: sample matrices of shape (m, d) and (n, d).
    gamma: kernel scale; defaults to 1 / d when None.
    Returns the estimate as a Python float.
    """
    xs = X.astype(np.float64, copy=False)
    ys = Y.astype(np.float64, copy=False)
    m, d = xs.shape
    n, _ = ys.shape
    scale = 1.0 / d if gamma is None else gamma

    def poly_kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Gram matrix of the polynomial kernel between the rows of a and b.
        return ((scale * (a @ b.T)) + coef0) ** degree

    def off_diagonal_mean(gram: np.ndarray, count: int) -> float:
        # Mean over i != j entries; the diagonal (self-similarity) is excluded
        # to keep the within-set term unbiased. A single sample contributes 0.
        if count <= 1:
            return 0.0
        return (np.sum(gram) - np.sum(np.diag(gram))) / (count * (count - 1))

    within_x = off_diagonal_mean(poly_kernel(xs, xs), m)
    within_y = off_diagonal_mean(poly_kernel(ys, ys), n)
    across = np.mean(poly_kernel(xs, ys))

    return float(within_x + within_y - 2.0 * across)


def compute_kid(feats_real: np.ndarray, feats_fake: np.ndarray,
                subsets: int = 50, subset_size: int = 1000,
                degree: int = 3, gamma: Optional[float] = None, coef0: float = 1.0,
                seed: int = 123) -> Tuple[float, float]:
    """
    KID = mean and sample std (ddof=1) of the unbiased polynomial-kernel MMD^2
    over ``subsets`` random subsets drawn without replacement.

    feats_real, feats_fake: feature arrays indexed by sample along axis 0.
    subset_size: samples per subset; capped at the smaller of the two set sizes.
    degree, gamma, coef0: forwarded to ``polynomial_mmd2_unbiased``.
    seed: seed of the NumPy generator used for subset sampling.
    Returns: (KID_mean, KID_std); the std is 0.0 when only one subset is used.
    Raises: ValueError if ``subsets`` < 1 or fewer than 2 samples are available
            in either set.
    """
    if subsets < 1:
        # Without this guard an empty list of estimates would silently yield NaN.
        raise ValueError("subsets must be at least 1 to compute KID.")

    rng = np.random.default_rng(seed)
    n_r, n_f = feats_real.shape[0], feats_fake.shape[0]
    m = min(subset_size, n_r, n_f)
    if m < 2:
        raise ValueError("Not enough samples for KID. Need at least 2 in both real and fake.")

    vals = np.empty(subsets, dtype=np.float64)
    for i in range(subsets):
        # Draw order (real first, then fake) is kept fixed so results stay
        # reproducible for a given seed.
        idx_r = rng.choice(n_r, size=m, replace=False)
        idx_f = rng.choice(n_f, size=m, replace=False)
        vals[i] = polynomial_mmd2_unbiased(
            feats_real[idx_r], feats_fake[idx_f],
            degree=degree, gamma=gamma, coef0=coef0,
        )

    kid_std = float(np.std(vals, ddof=1)) if subsets > 1 else 0.0
    return float(np.mean(vals)), kid_std


# ----------------------------
# Feature/logit extraction
# ----------------------------

@torch.no_grad()
def extract_logits_and_feats(dataloader: DataLoader, extractor: InceptionExtractor,
                             device: torch.device, want_logits: bool, want_feats: bool
                             ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Run every batch of ``dataloader`` through ``extractor`` and collect the outputs.

    Each batch is moved to ``device`` first. Only the outputs requested via
    ``want_logits`` / ``want_feats`` are copied to CPU and kept.
    Returns: (logits, feats) as NumPy arrays concatenated along axis 0; an
             entry is None when it was not requested.
    """
    collected_logits = []
    collected_feats = []

    progress = tqdm(dataloader, desc="Inception pass", leave=False)
    for images in progress:
        images = images.to(device, non_blocking=True)
        batch_logits, batch_feats = extractor(images)
        if want_logits:
            collected_logits.append(batch_logits.cpu().numpy())
        if want_feats:
            collected_feats.append(batch_feats.cpu().numpy())

    logits_arr = None
    if want_logits:
        logits_arr = np.concatenate(collected_logits, axis=0)

    feats_arr = None
    if want_feats:
        feats_arr = np.concatenate(collected_feats, axis=0)

    return logits_arr, feats_arr


def make_loader(path: Optional[str], batch_size: int, num_workers: int, image_size: int) -> Optional[DataLoader]:
    """
    Build a sequential (unshuffled, keep-last-batch) DataLoader over the images in ``path``.

    Returns None when ``path`` is None.
    """
    if path is None:
        return None

    loader_options = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": num_workers,
        "pin_memory": True,
        "drop_last": False,
    }
    dataset = ImageFolderFlat(path, image_size=image_size)
    return DataLoader(dataset, **loader_options)


# ----------------------------
# CLI and main
# ----------------------------

def main():
    """
    CLI entry point: compute IS on --fake-dir and, unless --skip-fid-kid is
    given, FID and KID between --real-dir and --fake-dir. Results are printed.

    Raises:
        ValueError: if FID/KID are requested but --real-dir is missing.
    """
    parser = argparse.ArgumentParser(description="Compute IS, FID, and KID for generated images.")
    parser.add_argument("--real-dir", type=str, default=None, help="Directory with real images.")
    parser.add_argument("--fake-dir", type=str, required=True, help="Directory with generated images.")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--image-size", type=int, default=299, help="Resize images to this square size for Inception.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--is-splits", type=int, default=10)
    parser.add_argument("--kid-subsets", type=int, default=50)
    parser.add_argument("--kid-subset-size", type=int, default=1000)
    parser.add_argument("--skip-fid-kid", action="store_true", help="Only compute IS on fake-dir.")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Fail fast: check the argument combination before loading the model or
    # running the (potentially long) pass over the fake images.
    compute_fid_kid = not args.skip_fid_kid
    if compute_fid_kid and args.real_dir is None:
        raise ValueError("To compute FID/KID, provide --real-dir or use --skip-fid-kid.")

    # Reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Build the loaders first so a missing or empty directory is reported
    # before the Inception weights are loaded.
    fake_loader = make_loader(args.fake_dir, args.batch_size, args.num_workers, args.image_size)
    real_loader = (
        make_loader(args.real_dir, args.batch_size, args.num_workers, args.image_size)
        if compute_fid_kid else None
    )

    device = torch.device(args.device)
    extractor = InceptionExtractor(device)

    # Logits are always needed for IS; features only when FID/KID are computed.
    print("Collecting logits/features from fake images...")
    fake_logits, fake_feats = extract_logits_and_feats(
        fake_loader, extractor, device, want_logits=True, want_feats=compute_fid_kid
    )

    # Inception Score (fake only)
    is_mean, is_std = compute_inception_score(fake_logits, splits=args.is_splits)
    print(f"Inception Score (IS): {is_mean:.6f} ± {is_std:.6f} (splits={args.is_splits}, N={fake_logits.shape[0]})")

    if not compute_fid_kid:
        return

    print("Collecting features from real images...")
    _, real_feats = extract_logits_and_feats(real_loader, extractor, device, want_logits=False, want_feats=True)

    # FID
    fid = compute_fid(real_feats, fake_feats)
    print(f"FID: {fid:.6f} (N_real={real_feats.shape[0]}, N_fake={fake_feats.shape[0]})")

    # KID (mean ± std) using default polynomial kernel settings
    kid_mean, kid_std = compute_kid(
        real_feats, fake_feats,
        subsets=args.kid_subsets, subset_size=args.kid_subset_size,
        degree=3, gamma=None, coef0=1.0, seed=args.seed
    )
    print(f"KID (MMD^2, poly kernel): {kid_mean:.8f} ± {kid_std:.8f} "
          f"(subsets={args.kid_subsets}, subset_size={args.kid_subset_size})")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()