"""
Diagnostic script to test each distance calculation algorithm individually.
"""
import numpy as np
import ot
from utils.pre_processing import read_dotmark_image, downscale_grayscale_images, calculate_costs, noise_and_split_image

# Test-image selection: resolution_og is handed to the reader, resolution to
# the downscaler and to the cost-matrix builder
category = 'MicroscopyImages'
resolution_og = 32
resolution = 32

# Read one image (third argument 1), wrap it in a list for the downscaling
# helper, and keep the single downscaled result as the working image
loaded_images = [read_dotmark_image(category, resolution_og, 1)]
image = downscale_grayscale_images(loaded_images, resolution)[0]

# Cost matrix for a (resolution, resolution) grid, Euclidean metric, cyclic=True
grid_shape = (resolution, resolution)
cost_matrix = calculate_costs(grid_shape, metric='euclidean', cyclic=True)

# Noise level used for this diagnostic run
noise_std = 0.01
print(f"Testing with noise_std = {noise_std}")
image_total = image.sum()
print(f"Image shape: {image.shape}, sum: {image_total:.6f}")

# Perturb the image; the helper returns three values, of which only pos and
# neg are used further down
noisy_image, pos, neg = noise_and_split_image(image, noise_std)
print("\nAfter noise_and_split_image:")
# One summary line (sum / min / max) per returned part
for part_name, part in (("pos", pos), ("neg", neg)):
    print(f"  {part_name} sum: {part.sum():.6f}, min: {part.min():.6f}, max: {part.max():.6f}")

# Flatten the two inputs that every solver below is called with
dist1 = (image + neg).flatten()
dist2 = pos.flatten()

# Per-distribution summary: total, range, and whether any entry is negative
for heading, values in (
    ("Distribution 1 (image + neg)", dist1),
    ("Distribution 2 (pos)", dist2),
):
    print(f"\n{heading}:")
    print(f"  sum: {values.sum():.6f}, min: {values.min():.6f}, max: {values.max():.6f}")
    print(f"  any negative: {(values < 0).any()}")

# Exact-zero counts for both distributions, printed after the second summary
for dist_name, values in (("dist1", dist1), ("dist2", dist2)):
    print(f"  number of zeros in {dist_name}: {(values == 0).sum()}")

# Banner separating the setup output from the per-algorithm results
start_rule = "=" * 60
print("\n" + start_rule)
print("TESTING ALGORITHMS")
print(start_rule)

# 1. Standard EMD (W1)
print("\n1. Standard EMD (W1):")
try:
    # ot.emd2 on the plain cost matrix; the exponent 1/1 leaves the cost as is
    emd_cost_w1 = ot.emd2(dist1, dist2, cost_matrix)
    w1_emd = emd_cost_w1 ** 1.0
    print(f"   cost_p = {emd_cost_w1:.10f}")
    print(f"   W1 = {w1_emd:.10f}")
    print("   ✓ SUCCESS")
except Exception as err:
    # Any failure (solver or formatting) is reported instead of aborting the run
    print(f"   ✗ ERROR: {err}")

# 2. Standard EMD (W2)
print("\n2. Standard EMD (W2):")
try:
    # The solver receives the squared cost matrix; the square root of the
    # returned cost is what gets reported as W2
    squared_cost = cost_matrix ** 2
    emd_cost_w2 = ot.emd2(dist1, dist2, squared_cost)
    w2_emd = emd_cost_w2 ** 0.5
    print(f"   cost_p = {emd_cost_w2:.10f}")
    print(f"   W2 = {w2_emd:.10f}")
    print("   ✓ SUCCESS")
except Exception as err:
    # Any failure (solver or formatting) is reported instead of aborting the run
    print(f"   ✗ ERROR: {err}")

# 3. Partial Wasserstein (POT) - W1
# The raw and the normalized attempts run in separate try blocks. Per the POT
# documentation the partial solver raises when m exceeds min(sum(a), sum(b));
# a failure on the raw inputs is exactly the case where the normalized
# comparison is informative, so it must not be skipped.
print("\n3. Partial Wasserstein (POT) - W1:")

# Attempt 1: dist1 / dist2 as-is (not normalized)
raw_ok = False  # becomes True only once the raw result was computed and printed
try:
    # Test if distributions need to be normalized for POT
    print(f"   dist1 sum: {dist1.sum():.6f}, dist2 sum: {dist2.sum():.6f}")
    print(f"   Testing with m=0.95...")
    cost_p = ot.partial.partial_wasserstein2(dist1, dist2, cost_matrix, m=0.95)
    result = cost_p ** (1 / 1)
    print(f"   cost_p = {cost_p:.10f}")
    print(f"   W1_pot = {result:.10f}")
    raw_ok = True
except Exception as e:
    print(f"   ✗ ERROR: {e}")

# Attempt 2: each distribution rescaled to unit sum, same m
try:
    dist1_norm = dist1 / dist1.sum()
    dist2_norm = dist2 / dist2.sum()
    cost_p_norm = ot.partial.partial_wasserstein2(dist1_norm, dist2_norm, cost_matrix, m=0.95)
    print(f"   With normalized: cost_p_norm = {cost_p_norm:.10f}")
except Exception as e:
    print(f"   ✗ ERROR (normalized): {e}")

# Verdict for the raw result only; printed last so the output order on the
# all-success path is unchanged, and omitted when the raw attempt failed
if raw_ok:
    if result == 0.0:
        print(f"   ⚠ WARNING: Result is exactly zero!")
    else:
        print(f"   ✓ SUCCESS")

# 4. Partial Wasserstein (POT) - W2
print("\n4. Partial Wasserstein (POT) - W2:")
try:
    # Squared cost matrix in, square root of the returned cost out
    squared_cost = cost_matrix ** 2
    partial_cost_w2 = ot.partial.partial_wasserstein2(dist1, dist2, squared_cost, m=0.95)
    w2_pot = partial_cost_w2 ** 0.5
    print(f"   cost_p = {partial_cost_w2:.10f}")
    print(f"   W2_pot = {w2_pot:.10f}")
    # An exactly-zero distance is flagged rather than counted as a success
    verdict = "   ⚠ WARNING: Result is exactly zero!" if w2_pot == 0.0 else "   ✓ SUCCESS"
    print(verdict)
except Exception as err:
    # Any failure (solver or formatting) is reported instead of aborting the run
    print(f"   ✗ ERROR: {err}")

# 5. Sinkhorn - W1
print("\n5. Sinkhorn - W1:")
try:
    # ot.sinkhorn2 on the plain cost matrix with reg=1e-2; exponent 1/1 is a no-op
    sinkhorn_cost_w1 = ot.sinkhorn2(dist1, dist2, cost_matrix, reg=1e-2)
    w1_sinkhorn = sinkhorn_cost_w1 ** 1.0
    print(f"   cost_p = {sinkhorn_cost_w1:.10f}")
    print(f"   W1_sinkhorn = {w1_sinkhorn:.10f}")
    # An exactly-zero distance is flagged rather than counted as a success
    verdict = "   ⚠ WARNING: Result is exactly zero!" if w1_sinkhorn == 0.0 else "   ✓ SUCCESS"
    print(verdict)
except Exception as err:
    # Any failure (solver or formatting) is reported instead of aborting the run
    print(f"   ✗ ERROR: {err}")

# 6. Sinkhorn - W2
print("\n6. Sinkhorn - W2:")
try:
    # Squared cost matrix in (reg=1e-2), square root of the returned cost out
    squared_cost = cost_matrix ** 2
    sinkhorn_cost_w2 = ot.sinkhorn2(dist1, dist2, squared_cost, reg=1e-2)
    w2_sinkhorn = sinkhorn_cost_w2 ** 0.5
    print(f"   cost_p = {sinkhorn_cost_w2:.10f}")
    print(f"   W2_sinkhorn = {w2_sinkhorn:.10f}")
    # An exactly-zero distance is flagged rather than counted as a success
    verdict = "   ⚠ WARNING: Result is exactly zero!" if w2_sinkhorn == 0.0 else "   ✓ SUCCESS"
    print(verdict)
except Exception as err:
    # Any failure (solver or formatting) is reported instead of aborting the run
    print(f"   ✗ ERROR: {err}")

# 7. Unbalanced OT (UOT) - W1
print("\n7. Unbalanced OT (UOT) - W1:")
try:
    # Unbalanced Sinkhorn on the plain cost matrix with reg=1e-2, reg_m=1e-1;
    # exponent 1/1 is a no-op
    uot_cost_w1 = ot.unbalanced.sinkhorn_unbalanced2(
        dist1, dist2, cost_matrix, reg=1e-2, reg_m=1e-1
    )
    w1_uot = uot_cost_w1 ** 1.0
    print(f"   cost_p = {uot_cost_w1:.10f}")
    print(f"   W1_uot = {w1_uot:.10f}")
    # An exactly-zero distance is flagged rather than counted as a success
    verdict = "   ⚠ WARNING: Result is exactly zero!" if w1_uot == 0.0 else "   ✓ SUCCESS"
    print(verdict)
except Exception as err:
    # Any failure (solver or formatting) is reported instead of aborting the run
    print(f"   ✗ ERROR: {err}")

# 8. Unbalanced OT (UOT) - W2
print("\n8. Unbalanced OT (UOT) - W2:")
try:
    # Squared cost matrix in (reg=1e-2, reg_m=1e-1), square root of the
    # returned cost out
    squared_cost = cost_matrix ** 2
    uot_cost_w2 = ot.unbalanced.sinkhorn_unbalanced2(
        dist1, dist2, squared_cost, reg=1e-2, reg_m=1e-1
    )
    w2_uot = uot_cost_w2 ** 0.5
    print(f"   cost_p = {uot_cost_w2:.10f}")
    print(f"   W2_uot = {w2_uot:.10f}")
    # An exactly-zero distance is flagged rather than counted as a success
    verdict = "   ⚠ WARNING: Result is exactly zero!" if w2_uot == 0.0 else "   ✓ SUCCESS"
    print(verdict)
except Exception as err:
    # Any failure (solver or formatting) is reported instead of aborting the run
    print(f"   ✗ ERROR: {err}")

# Closing banner marking the end of the diagnostic output
end_rule = "=" * 60
print("\n" + end_rule)
print("DIAGNOSIS COMPLETE")
print(end_rule)
