import argparse
import os

import numpy as np
import torch
import torchvision
from dataset import CellDataModule, to_rgb
from metrics_utils import (
    OpenPhenomFeatureExtractor,
    TorchFidelityFeatureExtractor,
    calculate_metrics_from_features,
    calculate_metrics_from_scratch,
)
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm

try:
    import torch_fidelity

    TORCH_FIDELITY_AVAILABLE = True
except ImportError:
    TORCH_FIDELITY_AVAILABLE = False
    print("Warning: torch-fidelity not available. Some comparison tests won't run.")


def _save_inspection_samples(real_images, fake_images, output_dir):
    """Write a few real/fake images to *output_dir* for visual inspection.

    6-channel inputs are saved as RGB composites (via ``to_rgb``) plus one
    grayscale image per channel of the first sample; other inputs are saved
    directly as 2x2 grids.
    """
    os.makedirs(output_dir, exist_ok=True)

    if real_images.shape[1] != 6:
        save_image(
            real_images[:4], os.path.join(output_dir, "real.png"), nrow=2, normalize=True
        )
        save_image(
            fake_images[:4], os.path.join(output_dir, "fake.png"), nrow=2, normalize=True
        )
        return

    pairs = (("real", real_images), ("fake", fake_images))

    # RGB composites of the first four samples.
    for tag, images in pairs:
        save_image(
            to_rgb(images[:4].cpu()),
            os.path.join(output_dir, f"{tag}_rgb.png"),
            nrow=2,
            normalize=True,
        )

    # Each channel of the first sample, repeated to 3 channels so it saves as grayscale.
    for i in range(6):
        for tag, images in pairs:
            channel = images[0, i : i + 1].repeat(1, 3, 1, 1).cpu()
            save_image(
                channel,
                os.path.join(output_dir, f"{tag}_channel_{i}.png"),
                normalize=True,
            )


def test_metric_calculations(
    real_images, fake_images, batch_size=8, output_dir="test_samples"
):
    """
    Compute FID/KID for cell images and save a few samples for visual inspection.

    Metrics are computed with the OpenPhenom feature extractor on the full
    images. For 6-channel inputs, each channel is additionally scored on its
    own (repeated to 3 channels) using the Inception-v3 feature extractor.

    Args:
        real_images: Tensor of real images [B, C, H, W]
        fake_images: Tensor of generated images [B, C, H, W]
        batch_size: Batch size passed to the metric computation
        output_dir: Directory for the inspection PNGs (created if missing)

    Returns:
        Dictionary with key "openphenom" -> {"fid", "kid_mean", "kid_std"} and,
        for 6-channel inputs, "channels" -> {"channel_<i>": {...}} for i in 0..5.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    real_images = real_images.to(device)
    fake_images = fake_images.to(device)

    print(f"Real images shape: {real_images.shape}")
    print(f"Fake images shape: {fake_images.shape}")

    _save_inspection_samples(real_images, fake_images, output_dir)

    results = {}

    print("\nTesting with OpenPhenom feature extractor:")
    print("Computing metrics with OpenPhenom...")
    fid, kid_mean, kid_std = calculate_metrics_from_scratch(
        real_images,
        fake_images,
        batch_size=batch_size,
        feature_extractor="openphenom",
    )
    results["openphenom"] = {"fid": fid, "kid_mean": kid_mean, "kid_std": kid_std}

    print(f"OpenPhenom FID: {fid:.4f}")
    print(f"OpenPhenom KID: {kid_mean:.6f} ± {kid_std:.6f}")

    # Per-channel metrics only make sense for the 6-channel cell images.
    if real_images.shape[1] == 6:
        print("\nTesting individual channel metrics:")
        channel_results = {}

        for i in range(6):
            print(f"\nChannel {i}:")

            # Inception-v3 expects 3 channels, so repeat the single channel.
            real_channel = real_images[:, i : i + 1].repeat(1, 3, 1, 1)
            fake_channel = fake_images[:, i : i + 1].repeat(1, 3, 1, 1)

            fid, kid_mean, kid_std = calculate_metrics_from_scratch(
                real_channel,
                fake_channel,
                batch_size=batch_size,
                feature_extractor="inception_v3",
            )
            channel_results[f"channel_{i}"] = {
                "fid": fid,
                "kid_mean": kid_mean,
                "kid_std": kid_std,
            }

            print(f"Channel {i} FID: {fid:.4f}")
            print(f"Channel {i} KID: {kid_mean:.6f} ± {kid_std:.6f}")

        results["channels"] = channel_results

    return results


# This is a CLI helper, not a pytest test: its parameters are not fixtures, so
# stop pytest from collecting it because of the ``test_`` prefix.
test_metric_calculations.__test__ = False


def load_cifar_data(num_samples=1000, data_dir="./data", noise_std=0.1, seed=None):
    """
    Load a random subset of CIFAR-10 training images plus a noisy copy of them.

    The noisy copy stands in for "generated" images. It is close to the real
    distribution but not identical, so FID/KID should be small but non-zero.

    Args:
        num_samples: Number of images to draw (without replacement). Must be
            positive and no larger than the CIFAR-10 training split.
        data_dir: Directory where CIFAR-10 is stored / downloaded to.
        noise_std: Standard deviation of the additive Gaussian noise used to
            build the fake images.
        seed: Optional integer seed. When given, sample selection and noise come
            from local generators seeded with it, so the output is reproducible
            without touching global RNG state. When None, the global NumPy /
            PyTorch RNGs are used.

    Returns:
        Tuple (real_samples, fake_samples) of stacked tensors produced by
        ``transforms.ToTensor`` (values in [0, 1]); fake samples are clamped
        to [0, 1] after adding noise.

    Raises:
        ValueError: If num_samples is not positive or exceeds the dataset size.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    cifar_dataset = datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transforms.ToTensor()
    )

    dataset_size = len(cifar_dataset)
    if num_samples > dataset_size:
        raise ValueError(
            f"num_samples={num_samples} exceeds the CIFAR-10 training set size "
            f"({dataset_size})"
        )

    if seed is None:
        indices = np.random.choice(dataset_size, num_samples, replace=False)
        generator = None
    else:
        indices = np.random.default_rng(seed).choice(
            dataset_size, num_samples, replace=False
        )
        generator = torch.Generator().manual_seed(seed)

    # Each dataset item is (image, label); only the image is needed.
    real_samples = torch.stack([cifar_dataset[int(idx)][0] for idx in indices])

    noise = (
        torch.randn(real_samples.shape, generator=generator, dtype=real_samples.dtype)
        * noise_std
    )
    fake_samples = torch.clamp(real_samples + noise, 0, 1)

    return real_samples, fake_samples


def _compute_manual_cifar_metrics(real_images, fake_images, batch_size):
    """Run the project's own FID/KID pipeline with Inception-v3 features; print and return the results."""
    print("Computing metrics with manual implementation...")
    fid, kid_mean, kid_std = calculate_metrics_from_scratch(
        real_images,
        fake_images,
        batch_size=batch_size,
        feature_extractor="inception_v3",
    )
    print(f"Manual FID: {fid:.4f}")
    print(f"Manual KID: {kid_mean:.6f} ± {kid_std:.6f}")
    return {"fid": fid, "kid_mean": kid_mean, "kid_std": kid_std}


def _log_torch_fidelity_config():
    """Print the installed torch-fidelity version and its DEFAULTS dict (best effort)."""
    from importlib.metadata import PackageNotFoundError, version

    try:
        print(f"torch-fidelity version: {version('torch-fidelity')}")
    except PackageNotFoundError:
        print("Could not determine torch-fidelity version")

    try:
        from torch_fidelity.defaults import DEFAULTS
    except ImportError:
        print("Could not import torch-fidelity defaults")
        return

    print("Torch-fidelity defaults:")
    for key, value in DEFAULTS.items():
        print(f"  {key}: {value}")


def _write_png_dir(images, out_dir):
    """Write *images* to *out_dir* as 00000.png, 00001.png, ... after deleting stale PNGs."""
    os.makedirs(out_dir, exist_ok=True)
    # torch-fidelity reads every image in the directory, so PNGs left over from
    # an earlier run with more samples would otherwise leak into this one.
    for name in os.listdir(out_dir):
        path = os.path.join(out_dir, name)
        if name.endswith(".png") and os.path.isfile(path):
            os.remove(path)
    for i, img in enumerate(images.cpu()):
        # Default normalize=False: per-image min/max stretching would make the
        # PNGs differ from the tensors the manual implementation scores.
        save_image(img, os.path.join(out_dir, f"{i:05d}.png"))


def _print_check(passed, success_msg, failure_msg):
    """Print a ✅ or ❌ line depending on *passed*."""
    print(f"✅ {success_msg}" if passed else f"❌ {failure_msg}")


def test_metrics_with_cifar(real_images, fake_images, batch_size=50):
    """
    Compare this project's FID/KID implementation against torch-fidelity.

    Three measurements are made:

    1. "manual": ``calculate_metrics_from_scratch`` with Inception-v3 features.
    2. "torch_fidelity": ``torch_fidelity.calculate_metrics`` on the same images
       written to 8-bit PNG directories under ``test_samples/``.
    3. "exact_match": ``calculate_metrics_from_features`` on features from
       ``TorchFidelityFeatureExtractor``. Comparing this with torch-fidelity
       separates differences in the metric formulas from differences in
       feature extraction.

    Args:
        real_images: Tensor of real CIFAR images [B, C, H, W]. Assumed to hold
            values in [0, 1] (as produced by ``load_cifar_data``). They are
            written to PNG without rescaling.
        fake_images: Tensor of fake CIFAR images [B, C, H, W], same value range
        batch_size: Batch size for feature extraction

    Returns:
        Dictionary with "manual" and, when torch-fidelity is available,
        "torch_fidelity", "exact_match" and "comparison". Failed steps store
        {"error": message}. If the manual metrics failed, the manual-vs-library
        differences in "comparison" are NaN and their *_match flags are False.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    real_images = real_images.to(device)
    fake_images = fake_images.to(device)

    print(f"Real images shape: {real_images.shape}")
    print(f"Fake images shape: {fake_images.shape}")

    results = {}

    print("\nTesting with manual metric calculation:")
    manual = None
    try:
        manual = _compute_manual_cifar_metrics(real_images, fake_images, batch_size)
        results["manual"] = manual
    except Exception as e:
        print(f"Error testing manual metrics: {e}")
        results["manual"] = {"error": str(e)}

    if not TORCH_FIDELITY_AVAILABLE:
        print("Skipping torch-fidelity comparison as the library is not available")
        return results

    print("\nTesting with torch-fidelity:")
    try:
        print("Checking torch-fidelity configuration...")
        _log_torch_fidelity_config()

        os.makedirs("test_samples", exist_ok=True)

        # torch-fidelity rejects a KID subset larger than the number of samples.
        # Cap the library default (1000) and reuse the same value for the
        # feature-cache run so both KID computations share their settings.
        kid_subset_size = min(1000, len(real_images), len(fake_images))

        print("Extracting features with our feature extractor...")
        manual_extractor = TorchFidelityFeatureExtractor(device=device)
        with torch.no_grad():
            real_cache = manual_extractor.extract_features(
                real_images, batch_size=batch_size
            )
            fake_cache = manual_extractor.extract_features(
                fake_images, batch_size=batch_size
            )
        np.save(os.path.join("test_samples", "real_features.npy"), real_cache)
        np.save(os.path.join("test_samples", "fake_features.npy"), fake_cache)

        print("Saving images for torch-fidelity...")
        real_dir = os.path.join("test_samples", "fid_real")
        fake_dir = os.path.join("test_samples", "fid_fake")
        _write_png_dir(real_images, real_dir)
        _write_png_dir(fake_images, fake_dir)

        print("Computing metrics with torch-fidelity...")
        metrics_tf = torch_fidelity.calculate_metrics(
            input1=real_dir,
            input2=fake_dir,
            cuda=device.type == "cuda",  # do not request CUDA on CPU-only hosts
            fid=True,
            kid=True,
            kid_subset_size=kid_subset_size,
            verbose=False,
        )
        tf_fid = metrics_tf["frechet_inception_distance"]
        tf_kid_mean = metrics_tf["kernel_inception_distance_mean"]
        tf_kid_std = metrics_tf["kernel_inception_distance_std"]
        results["torch_fidelity"] = {
            "fid": tf_fid,
            "kid_mean": tf_kid_mean,
            "kid_std": tf_kid_std,
        }

        print(
            "\nComputing metrics with the same feature cache for both implementations..."
        )
        fid_exact, kid_mean_exact, kid_std_exact = calculate_metrics_from_features(
            real_cache, fake_cache, kid_subsets=100, kid_subset_size=kid_subset_size
        )
        results["exact_match"] = {
            "fid": fid_exact,
            "kid_mean": kid_mean_exact,
            "kid_std": kid_std_exact,
        }

        print(f"torch-fidelity FID: {tf_fid:.4f}")
        print(f"torch-fidelity KID: {tf_kid_mean:.6f} ± {tf_kid_std:.6f}")
        print(f"Exact match FID: {fid_exact:.4f}")
        print(f"Exact match KID: {kid_mean_exact:.6f} ± {kid_std_exact:.6f}")

        if manual is not None:
            fid_diff = abs(manual["fid"] - tf_fid)
            kid_diff = abs(manual["kid_mean"] - tf_kid_mean)
        else:
            # NaN fails every "< tolerance" check below, so a missing manual
            # result is reported as a mismatch instead of raising NameError.
            fid_diff = kid_diff = float("nan")
        fid_diff_exact = abs(fid_exact - tf_fid)
        kid_diff_exact = abs(kid_mean_exact - tf_kid_mean)

        print("\nDifference between implementations:")
        print(f"FID difference: {fid_diff:.4f}")
        print(f"KID difference: {kid_diff:.6f}")
        print(f"FID difference (exact match): {fid_diff_exact:.10f}")
        print(f"KID difference (exact match): {kid_diff_exact:.10f}")

        # Loose tolerance for independent pipelines, near-zero for shared features.
        fid_match = fid_diff < 1.0
        kid_match = kid_diff < 0.01
        fid_exact_match = fid_diff_exact < 1e-10
        kid_exact_match = kid_diff_exact < 1e-10

        results["comparison"] = {
            "fid_diff": fid_diff,
            "kid_diff": kid_diff,
            "fid_match": fid_match,
            "kid_match": kid_match,
            "fid_diff_exact": fid_diff_exact,
            "kid_diff_exact": kid_diff_exact,
            "fid_exact_match": fid_exact_match,
            "kid_exact_match": kid_exact_match,
        }

        _print_check(
            fid_match,
            "FID calculations match within acceptable tolerance",
            "FID calculations differ significantly",
        )
        _print_check(
            kid_match,
            "KID calculations match within acceptable tolerance",
            "KID calculations differ significantly",
        )
        _print_check(
            fid_exact_match,
            "FID calculations match exactly when using same features",
            "FID calculations differ even with same features",
        )
        _print_check(
            kid_exact_match,
            "KID calculations match exactly when using same features",
            "KID calculations differ even with same features",
        )

    except Exception as e:
        print(f"Error testing torch-fidelity metrics: {e}")
        results["torch_fidelity"] = {"error": str(e)}

    return results


# This is a CLI helper, not a pytest test: its parameters are not fixtures, so
# stop pytest from collecting it because of the ``test_`` prefix.
test_metrics_with_cifar.__test__ = False


def _load_cell_samples(config_path, num_samples, perturbation_id=None):
    """Load real cell images through CellDataModule, optionally filtered by perturbation ID."""
    from omegaconf import OmegaConf

    config = OmegaConf.load(config_path)
    data_module = CellDataModule(config)

    if perturbation_id is not None:
        print(f"Filtering samples for perturbation ID {perturbation_id}...")
        real_samples, _ = data_module.filter_samples(
            perturbation_id=perturbation_id, num_samples=num_samples
        )
        return real_samples

    print("Getting random samples from training set...")
    dataset = data_module.get_train_dataset()
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    # Each dataset item unpacks into three values; only the first (the image) is kept.
    return torch.stack([image for image, _, _ in (dataset[i] for i in indices)])


def _print_summary(results):
    """Print every section of *results*, expanding one level of nested dicts."""
    print("\n=== Test Summary ===")
    for section, metrics in results.items():
        print(f"\n{section.upper()} Results:")
        for name, value in metrics.items():
            if isinstance(value, dict):
                print(f"  {name}:")
                for subname, subvalue in value.items():
                    print(f"    {subname}: {subvalue}")
            else:
                print(f"  {name}: {value}")


def _print_verdict(comparison):
    """Print the overall pass/fail verdict of the manual-vs-torch-fidelity comparison."""
    print("\n=== OVERALL VERDICT ===")
    if comparison["fid_match"] and comparison["kid_match"]:
        print(
            "✅ SUCCESS: Manual calculations match torch-fidelity within acceptable tolerance"
        )
        return

    print("❌ FAIL: Manual calculations differ significantly from torch-fidelity")
    if not comparison["fid_match"]:
        print(f"  - FID difference: {comparison['fid_diff']:.4f}")
    if not comparison["kid_match"]:
        print(f"  - KID difference: {comparison['kid_diff']:.6f}")


def main():
    """
    Parse command-line arguments and run the CIFAR-10 or the cell-image metrics test.

    With --use_cifar, CIFAR-10 images and a noisy copy are compared via
    test_metrics_with_cifar, which also checks against torch-fidelity.
    Otherwise, cell images are loaded through CellDataModule using the YAML file
    given by --config. A noisy copy is made, and test_metric_calculations is
    run. A summary of all results is printed at the end. --seed makes sample
    selection and noise reproducible.
    """
    parser = argparse.ArgumentParser(description="Test metrics calculations")
    parser.add_argument(
        "--num_samples", type=int, default=500, help="Number of samples to use"
    )
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Batch size for processing"
    )
    parser.add_argument(
        "--perturbation_id",
        type=int,
        default=None,
        help="Optional perturbation ID to filter",
    )
    parser.add_argument(
        "--use_cifar", action="store_true", help="Use CIFAR-10 dataset for testing"
    )
    parser.add_argument(
        "--cifar_samples", type=int, default=1000, help="Number of CIFAR samples to use"
    )
    parser.add_argument(
        "--data_dir", type=str, default="./data", help="Directory to store datasets"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="./diffusion_sit.yaml",
        help="YAML config used to build the cell data module",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Optional seed for the global NumPy/PyTorch RNGs",
    )

    args = parser.parse_args()

    if args.seed is not None:
        # Sample selection and noise below use the global RNGs.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    if args.use_cifar:
        print(f"Loading CIFAR-10 dataset with {args.cifar_samples} samples...")
        real_samples, fake_samples = load_cifar_data(
            num_samples=args.cifar_samples, data_dir=args.data_dir
        )

        print(
            "Running CIFAR-10 test comparing manual calculations with torch-fidelity..."
        )
        results = test_metrics_with_cifar(
            real_samples, fake_samples, batch_size=args.batch_size
        )
    else:
        print("Loading cell image data module...")
        real_samples = _load_cell_samples(
            args.config, args.num_samples, args.perturbation_id
        )

        print("Generating fake samples by adding noise to real samples...")
        # NOTE(review): clamping to [0, 1] assumes the cell images are in that
        # range — TODO confirm against CellDataModule's normalization.
        noise = torch.randn_like(real_samples) * 0.1
        fake_samples = torch.clamp(real_samples + noise, 0, 1)

        results = test_metric_calculations(
            real_samples, fake_samples, batch_size=args.batch_size
        )

    _print_summary(results)

    # The verdict only exists when the CIFAR test reached the comparison step.
    if args.use_cifar and "comparison" in results:
        _print_verdict(results["comparison"])


if __name__ == "__main__":
    # Run as a script: parse CLI options and execute the selected metrics test.
    main()
