import torch
import dnnlib
import pickle
import numpy as np
import torch.nn.functional as F
from scipy import linalg

# Preprocessing 

def from_sdf_to_mask(x: torch.Tensor) -> torch.Tensor:
    """Binarize a signed distance field: 1.0 where ``x < 0``, else 0.0.

    The result is always float32, whatever the input dtype; NaN entries map
    to 0.0 because the comparison is False for them.
    """
    inside = torch.lt(x, 0.0)
    return inside.to(torch.float32)

def ensure_nchw_1ch(x: torch.Tensor) -> torch.Tensor:
    """Normalize a single-channel batch to NCHW layout.

    - (B,H,W): a channel axis is inserted at dim 1 -> (B,1,H,W).
    - (B,1,H,W): returned unchanged (the same tensor object).

    Raises:
        ValueError: for any other shape.
    """
    rank = x.dim()
    already_nchw = rank == 4 and x.size(1) == 1
    if already_nchw:
        return x
    if rank == 3:
        return x.unsqueeze(1)
    raise ValueError(f"Expected (B,H,W) or (B,1,H,W), got {tuple(x.shape)}")

def upsample_to_64(x_1ch: torch.Tensor, upsample_module=None) -> torch.Tensor:
    """Apply ``upsample_module`` to ``x_1ch`` when given; otherwise pass through.

    No target size is enforced here: the output shape is whatever
    ``upsample_module`` produces (presumably 64x64 per the name -- TODO confirm).
    """
    return x_1ch if upsample_module is None else upsample_module(x_1ch)



def preprocess_to_inception_uint8(
    x_1ch_float_01: torch.Tensor,
    interpolation: str = "bilinear",
    antialias: bool = False,
) -> torch.Tensor:
    """Resize float images in [0, 1] to 299x299 and quantize to 3-channel uint8.

    Pipeline:
      - resize to 299x299 with ``F.interpolate(mode=interpolation,
        antialias=antialias)``. ``align_corners=False`` is passed only for the
        interpolating modes (linear / bilinear / bicubic / trilinear), because
        ``F.interpolate`` rejects any ``align_corners`` value for modes such as
        "nearest", "nearest-exact" and "area";
      - if the input has 1 channel, repeat it to 3 channels;
      - ``(x * 255).clamp(0, 255)`` then cast to uint8 (the cast truncates
        rather than rounds).

    Args:
        x_1ch_float_01: Float tensor (float16/32/64) of shape (B,1,H,W) or
            (B,3,H,W) with every value in [0, 1] (1e-6 tolerance).
        interpolation: ``mode`` forwarded to ``F.interpolate``.
        antialias: ``antialias`` flag forwarded to ``F.interpolate``.

    Returns:
        uint8 tensor of shape (B,3,299,299).

    Raises:
        ValueError: if the dtype is not floating point, the shape is not
            (B,1,H,W)/(B,3,H,W), any value is NaN/Inf, or values fall outside
            [0, 1].
    """
    x = x_1ch_float_01
    if x.dtype not in (torch.float16, torch.float32, torch.float64):
        raise ValueError(f"Expected float tensor, got {x.dtype}")
    if x.ndim != 4 or x.shape[1] not in (1, 3):
        raise ValueError(f"Expected (B,1,H,W) or (B,3,H,W), got {tuple(x.shape)}")

    # Strict range check: the x*255 scaling below assumes [0, 1]. NaN/Inf are
    # rejected explicitly because they would become undefined values after the
    # uint8 cast. An empty batch has no min/max, so it skips the check.
    if x.numel() > 0:
        if not torch.isfinite(x).all():
            raise ValueError("Expected finite values in x, got NaN/Inf")
        mn = float(x.min().item())
        mx = float(x.max().item())
        if mn < -1e-6 or mx > 1.0 + 1e-6:
            raise ValueError(f"Expected x in [0,1], got min={mn}, max={mx}")

    # align_corners is only accepted by the interpolating modes; other modes
    # (e.g. "nearest", "area") require it to be None.
    interpolating_modes = ("linear", "bilinear", "bicubic", "trilinear")
    align_corners = False if interpolation in interpolating_modes else None

    x = F.interpolate(
        x,
        size=(299, 299),
        mode=interpolation,
        align_corners=align_corners,
        antialias=antialias,
    )

    # gray -> rgb
    if x.shape[1] == 1:
        x = x.repeat(1, 3, 1, 1)

    # Bicubic resizing can overshoot [0, 1]; clamp before the uint8 cast.
    return (x * 255.0).clamp(0, 255).to(torch.uint8)




# Feature detector 
# Process-wide memo of loaded detectors keyed by (url, device), so repeated
# calls with the same arguments reuse one in-memory model.
_feature_detector_cache = dict()

def get_feature_detector(url, cache_dir=None, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Load a pickled feature-detector network, or return the cached one.

    The object is unpickled from ``url`` via ``dnnlib.util.open_url``, moved
    with ``.to(device)`` and memoized in ``_feature_detector_cache`` under the
    key ``(url, device)``.

    Args:
        url: Location of the pickled detector, passed to ``dnnlib.util.open_url``.
        cache_dir: Forwarded to ``dnnlib.util.open_url`` as ``cache_dir``.
        device: Device the detector is moved to; also part of the cache key.
        num_gpus: Number of participating processes; the distributed barriers
            are only used when this is > 1.
        rank: This process's rank; must satisfy ``0 <= rank < num_gpus``.
        verbose: Forwarded to ``open_url``, but only honored on rank 0.

    Returns:
        The cached detector object for ``(url, device)``.

    Raises:
        AssertionError: if ``rank`` is outside ``[0, num_gpus)`` (the check is
            an ``assert`` and disappears under ``python -O``).

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only pass URLs that point to trusted files.
    """
    assert 0 <= rank < num_gpus
    key = (url, device)
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        # The two barriers are order-sensitive: non-leader ranks block here
        # until rank 0 has finished loading (presumably so only rank 0 fills
        # cache_dir on first download), and only then load themselves.
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier() # leader goes first
        with dnnlib.util.open_url(url, cache_dir=cache_dir, verbose=(verbose and is_leader)) as f:
            _feature_detector_cache[key] = pickle.load(f).to(device)
        # Rank 0 releases the waiting ranks only after its own load completes.
        if is_leader and num_gpus > 1:
            torch.distributed.barrier() # others follow
    return _feature_detector_cache[key]




# Eager module-level setup: pick the compute device, then load the Inception
# feature detector through get_feature_detector and switch it to eval mode.
# NOTE(review): this fetches/loads the detector as a side effect of importing
# this module.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

detector = get_feature_detector(
    "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl",
    cache_dir="./fid_cache",
    device=device,
    num_gpus=1,
    rank=0,
    verbose=True,
)
detector.eval()



# Using the real-data loader and the detector, collect features of the real dataset

@torch.no_grad()
def collect_real_features(
    detector,
    real_loader,
    device,
    num_real: int,
    upsample_module=None,
    interpolation="bilinear",
    antialias=False,
    print_every=1000,          # print every this many real samples
):
    """Extract detector features for the first ``num_real`` real SDF samples.

    Each ``(sdf_imgs, labels)`` batch from ``real_loader`` goes through these
    steps (labels are ignored):
      1. Trim to the samples still needed, then move to ``device``.
      2. Apply ``upsample_module`` if one is given.
      3. Threshold at 0 with ``from_sdf_to_mask`` (values < 0 -> 1.0).
      4. Convert with ``preprocess_to_inception_uint8``.
      5. Run ``detector(x, return_features=True)``.

    Args:
        detector: Feature network called as ``detector(x_uint8, return_features=True)``.
        real_loader: Iterable of ``(sdf_imgs, labels)`` pairs.
        device: Device the images are moved to before processing.
        num_real: Number of samples to collect.
        upsample_module: Optional callable applied to each SDF batch.
        interpolation: Resize mode forwarded to ``preprocess_to_inception_uint8``.
        antialias: Antialias flag forwarded to ``preprocess_to_inception_uint8``.
        print_every: Progress is printed roughly every this many samples.

    Returns:
        CPU tensor with at most ``num_real`` feature rows. It holds fewer rows,
        and a warning is printed, if the loader runs out first.

    Raises:
        RuntimeError: if no features were collected (empty loader or
            ``num_real <= 0``).
    """
    feats = []
    seen = 0
    last_print = 0

    for sdf_imgs, _labels in real_loader:
        if seen >= num_real:
            break

        # Drop samples beyond what is still needed *before* the expensive
        # upsample/resize/detector work (only trims the final batch).
        sdf_imgs = sdf_imgs[: num_real - seen]
        if sdf_imgs.shape[0] == 0:
            continue
        sdf_imgs = sdf_imgs.to(device)  # assumed (B,1,H,W) -- TODO confirm loader layout

        if upsample_module is not None:
            sdf_imgs = upsample_module(sdf_imgs)

        mask = from_sdf_to_mask(sdf_imgs)

        x_uint8 = preprocess_to_inception_uint8(
            mask,
            interpolation=interpolation,
            antialias=antialias,
        )  # (B,3,299,299) uint8

        f = detector(x_uint8, return_features=True)
        feats.append(f.cpu())
        seen += f.shape[0]

        if seen - last_print >= print_every or seen >= num_real:
            print(f"[real FID] processed {min(seen, num_real)}/{num_real}")
            last_print = seen

    if not feats:
        raise RuntimeError(
            f"[real FID] no features collected (num_real={num_real}); "
            "is real_loader empty?"
        )
    if seen < num_real:
        print(
            f"[real FID] warning: requested {num_real} samples but "
            f"real_loader only provided {seen}"
        )
    return torch.cat(feats, dim=0)[:num_real]

# Using the sample generator and the detector, collect features of generated samples

@torch.no_grad()
def collect_gen_features(
    detector,
    gen_sampler_fn,               # returns (B,64,64)
    device,
    num_gen: int,
    interpolation="bilinear",
    antialias=False,
    print_every=1000,             # print every this many samples
):
    """Extract detector features for ``num_gen`` generated SDF samples.

    ``gen_sampler_fn()`` is called repeatedly; each SDF batch goes through
    these steps:
      1. Trim to the samples still needed.
      2. Bring to (B,1,H,W) with ``ensure_nchw_1ch``, which accepts (B,H,W) or
         (B,1,H,W), and move to ``device``.
      3. Threshold at 0 with ``from_sdf_to_mask`` (values < 0 -> 1.0).
      4. Convert with ``preprocess_to_inception_uint8``.
      5. Run ``detector(x, return_features=True)``.

    Args:
        detector: Feature network called as ``detector(x_uint8, return_features=True)``.
        gen_sampler_fn: Zero-argument callable returning one batch of SDFs.
        device: Device the samples are moved to before processing.
        num_gen: Number of samples to collect.
        interpolation: Resize mode forwarded to ``preprocess_to_inception_uint8``.
        antialias: Antialias flag forwarded to ``preprocess_to_inception_uint8``.
        print_every: Progress is printed roughly every this many samples.

    Returns:
        CPU tensor with ``num_gen`` feature rows.

    Raises:
        RuntimeError: if ``gen_sampler_fn`` returns an empty batch (which would
            otherwise loop forever) or no features were collected
            (``num_gen <= 0``).
        ValueError: if a sampled batch is not (B,H,W) or (B,1,H,W).
    """
    feats = []
    seen = 0
    last_print = 0

    while seen < num_gen:
        gen_imgs = gen_sampler_fn()
        if gen_imgs.shape[0] == 0:
            # Without this guard `seen` never advances -> infinite loop.
            raise RuntimeError("[gen FID] gen_sampler_fn returned an empty batch")

        # Trim to what is still needed before the expensive detector pass.
        gen_imgs = ensure_nchw_1ch(gen_imgs[: num_gen - seen]).to(device)

        mask = from_sdf_to_mask(gen_imgs)                # (B,1,H,W)

        x_uint8 = preprocess_to_inception_uint8(
            mask,
            interpolation=interpolation,
            antialias=antialias,
        )                                                # (B,3,299,299) uint8

        f = detector(x_uint8, return_features=True)
        feats.append(f.cpu())
        seen += f.shape[0]

        if seen - last_print >= print_every or seen >= num_gen:
            print(f"[gen FID] processed {min(seen, num_gen)}/{num_gen}")
            last_print = seen

    if not feats:
        raise RuntimeError(
            f"[gen FID] no features collected (num_gen={num_gen})"
        )
    return torch.cat(feats, dim=0)[:num_gen]



def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between two multivariate Gaussians (NumPy).

    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2):

        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2))

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Mean activation vector of the first distribution.
    -- sigma1: Covariance matrix of the first distribution's activations.
    -- mu2   : Mean activation vector of the second distribution.
    -- sigma2: Covariance matrix of the second distribution's activations.
    -- eps   : Value added to both covariance diagonals when sqrt(C_1*C_2)
               is not finite (near-singular product).
    Returns:
    --   : The Frechet distance (squared form shown above).
    Raises:
    -- AssertionError: mean or covariance shapes differ.
    -- ValueError: the matrix square root keeps a non-negligible imaginary
       part on its diagonal.
    """
    mean_a = np.atleast_1d(mu1)
    mean_b = np.atleast_1d(mu2)
    cov_a = np.atleast_2d(sigma1)
    cov_b = np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, \
        'Training and test mean vectors have different lengths'
    assert cov_a.shape == cov_b.shape, \
        'Training and test covariances have different dimensions'

    mean_gap = mean_a - mean_b

    # sqrt(C_1 C_2); the product may be nearly singular, so check the result
    # and retry with a small ridge on both diagonals if it is not finite.
    root_prod, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(root_prod).all():
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        ridge = np.eye(cov_a.shape[0]) * eps
        root_prod = linalg.sqrtm((cov_a + ridge).dot(cov_b + ridge))

    # Round-off can leave a small imaginary part; only the diagonal's
    # imaginary part is checked before it is discarded.
    if np.iscomplexobj(root_prod):
        if not np.allclose(np.diagonal(root_prod).imag, 0, atol=1e-3):
            worst = np.max(np.abs(root_prod.imag))
            raise ValueError('Imaginary component {}'.format(worst))
        root_prod = root_prod.real

    trace_root = np.trace(root_prod)
    return (mean_gap.dot(mean_gap) + np.trace(cov_a) +
            np.trace(cov_b) - 2 * trace_root)




def fid_from_features(real_features: torch.Tensor, gen_features: torch.Tensor) -> float:
    """Compute the Frechet Inception Distance between two feature sets.

    Each input is a 2-D tensor of shape (N, D), one feature row per sample.
    Steps:
      1. Detach each tensor, move it to CPU and convert it to float64, so
         CUDA or grad-tracking tensors are accepted.
      2. Estimate the mean and the covariance (``np.cov(..., rowvar=False)``,
         N-1 normalization) of each set.
      3. Pass both sets of statistics to ``calculate_frechet_distance``.

    Args:
        real_features: Reference features, shape (N_real, D).
        gen_features: Generated-sample features, shape (N_gen, D).

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: if an input is not 2-D or has fewer than 2 rows (a sample
            covariance needs at least two samples; np.cov would yield NaN).
    """
    stats = []
    for name, feats in (("real_features", real_features), ("gen_features", gen_features)):
        if feats.ndim != 2:
            raise ValueError(f"{name} must be 2-D (N, D), got shape {tuple(feats.shape)}")
        if feats.shape[0] < 2:
            raise ValueError(
                f"{name} needs at least 2 samples for a covariance, got {feats.shape[0]}"
            )
        arr = feats.detach().cpu().to(torch.float64).numpy()
        stats.append((np.mean(arr, axis=0), np.cov(arr, rowvar=False)))

    (m0, s0), (m, s) = stats
    return float(calculate_frechet_distance(m0, s0, m, s))




def compute_fid_end2end(
    real_loader,
    gen_sampler_fn,
    device,
    num_real=10000,
    num_gen=5000,
    upsample_module=None,
    detector_cache_dir="./fid_cache",
    interpolation="bilinear",
    antialias=False,
    detector_url="https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl",
):
    """Run the full FID pipeline: load the detector, extract features, score.

    Steps:
      1. Load the feature detector from ``detector_url`` through
         ``get_feature_detector`` (single process, rank 0) and put it in eval
         mode.
      2. Collect ``num_real`` real features via ``collect_real_features``.
      3. Collect ``num_gen`` generated features via ``collect_gen_features``.
      4. Check that all features are finite.
      5. Compute the FID with ``fid_from_features``.

    Args:
        real_loader: Iterable of ``(sdf_imgs, labels)`` batches of real SDFs.
        gen_sampler_fn: Zero-argument callable returning a batch of generated SDFs.
        device: Device used for the detector and preprocessing.
        num_real: Number of real samples to use.
        num_gen: Number of generated samples to use.
        upsample_module: Optional callable applied to real SDF batches only.
        detector_cache_dir: ``cache_dir`` passed to ``get_feature_detector``.
        interpolation: Resize mode used for both real and generated images.
        antialias: Antialias flag used for both real and generated images.
        detector_url: Location of the pickled detector. Defaults to the
            StyleGAN3 Inception pickle used before this parameter existed.

    Returns:
        Tuple ``(fid, real_features, gen_features)``.

    Raises:
        AssertionError: if either feature set contains NaN/Inf.
    """
    detector = get_feature_detector(
        url=detector_url,
        cache_dir=detector_cache_dir,
        device=device,
        num_gpus=1,
        rank=0,
        verbose=True,
    )
    detector.eval()

    real_features = collect_real_features(
        detector,
        real_loader,
        device=device,
        num_real=num_real,
        upsample_module=upsample_module,
        interpolation=interpolation,
        antialias=antialias,
    )

    gen_features = collect_gen_features(
        detector,
        gen_sampler_fn,
        device=device,
        num_gen=num_gen,
        interpolation=interpolation,
        antialias=antialias,
    )

    # Sanity checks use explicit raises (not `assert`) so they still run under
    # `python -O`; AssertionError is kept so callers see the same exception type.
    if not torch.isfinite(real_features).all():
        raise AssertionError("real features contain NaN/Inf")
    if not torch.isfinite(gen_features).all():
        raise AssertionError("gen features contain NaN/Inf")

    fid = fid_from_features(real_features, gen_features)
    return fid, real_features, gen_features
