import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib import gridspec

from utils.pre_processing import calculate_costs, downscale_grayscale_images, noise_image
from utils.calculations import measure_noise_effects_on_image_pair, run_image_pair_experiment

# Increase global font sizes for better readability
plt.rcParams.update({'font.size': 14})

# Input PNGs, resolved relative to the current working directory.
images_path = "results/rotation experiment/extracted_images_normalized"
results_dir = 'results/rotation experiment'  # Standardize results directory
resolution = 32  # target size for downscale_grayscale_images and the cost grid
num_exp = 50  # forwarded to run_image_pair_experiment(num_exp=...)
n_parallel = 20  # forwarded to run_image_pair_experiment(n_parallel=...)
force_eval = False  # Default to False to save time if data exists
full_path = os.path.join(os.getcwd(), images_path)

# Ensure results directory exists. exist_ok=True replaces the separate
# existence check, which could race with another process creating the dir.
os.makedirs(results_dir, exist_ok=True)

# Collect the PNG file names in full_path in sorted order.
png_names = sorted(name for name in os.listdir(full_path) if name.endswith('.png'))

# Decode each PNG as an 8-bit grayscale ('L' mode) numpy array; the context
# manager releases the file handle as soon as the pixels have been read.
images = []
for png_name in png_names:
    with Image.open(os.path.join(full_path, png_name)) as pil_image:
        images.append(np.array(pil_image.convert('L')))

# Downscale to the working resolution and keep every second image.
images_curr = downscale_grayscale_images(images, resolution)[::2]

# Shift image i diagonally by one pixel per index: zero-pad every side, then
# cut a resolution x resolution window whose top-left corner starts at
# (i + crop_offset, i + crop_offset) in padded coordinates. Each crop is
# divided by its total intensity so it sums to 1.
# A new list is built instead of writing back into images_curr: if
# images_curr were an integer ndarray, in-place assignment of the normalised
# floats would truncate them (to 0).
pad_width = 16  # zeros added on each side before cropping
crop_offset = 6  # window start (padded coordinates) for image 0
shifted_images = []
for i, image in enumerate(images_curr):
    padded = np.pad(image, ((pad_width, pad_width), (pad_width, pad_width)),
                    mode='constant', constant_values=0)
    start = i + crop_offset
    stop = start + resolution
    # Past the padded border NumPy would silently return a smaller crop,
    # which no longer matches the resolution x resolution cost matrix.
    if stop > padded.shape[0] or stop > padded.shape[1]:
        raise ValueError(
            f"Image {i}: crop window [{start}:{stop}] exceeds padded size "
            f"{padded.shape}; too many images for pad_width={pad_width}."
        )
    image_cropped = padded[start:stop, start:stop]
    mass = image_cropped.sum()
    if mass == 0:
        # Normalising would divide by zero and produce NaNs downstream.
        raise ValueError(f"Image {i} has zero total intensity after cropping.")
    shifted_images.append(image_cropped / mass)
images_curr = shifted_images

# Show all the images in a grid with 5 columns (4x5 for 16-20 images). The
# row count follows the number of images so none are silently dropped.
n_cols = 5
n_rows = max(1, (len(images_curr) + n_cols - 1) // n_cols)  # ceiling division
# squeeze=False keeps axs 2-D even for a single row.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 2 * n_rows), squeeze=False)
for i, ax in enumerate(axs.flat):
    if i < len(images_curr):
        ax.imshow(images_curr[i], cmap='gray')
        ax.set_title(f'Image {i}', fontsize=16)
    ax.axis('off')  # also hides the unused trailing cells
plt.tight_layout()

# Metrics evaluated per image pair. The names are also the suffixes of the
# columns read from run_image_pair_experiment's result frame
# ('original_<metric>' and 'noisy_vs_noisy_<metric>').
metrics = ['L2', 'W1', 'W2']
# Ground cost on the resolution x resolution pixel grid, as built by
# calculate_costs(metric='euclidean', cyclic=True).
cost_matrix = calculate_costs((resolution, resolution), metric='euclidean', cyclic=True)

noisy_images = []
noise_param = 0.01

# Paths for cached results.
# NOTE(review): the file names hard-code "noise_01" and encode neither
# noise_param, num_exp nor resolution. After changing any of these, delete the
# files or set force_eval=True, otherwise stale results are reused.
results_path = os.path.join(results_dir, 'distances_between_images_turning_and_moving_noise_01.npy')
results_noisy_path = os.path.join(results_dir, 'distances_between_noisy_images_turning_and_moving_noise_01.npy')

# A cache is reusable only if it covers exactly the current image set.
expected_shape = (len(images_curr), len(images_curr), len(metrics))

results = None
results_noisy = None
if not force_eval and os.path.exists(results_path) and os.path.exists(results_noisy_path):
    print(f"Loading existing results from {results_dir}...")
    results = np.load(results_path)
    results_noisy = np.load(results_noisy_path)
    if results.shape != expected_shape or results_noisy.shape != expected_shape:
        print(f"Cached results have shapes {results.shape} and {results_noisy.shape}, "
              f"expected {expected_shape}; recomputing...")
        results = None
        results_noisy = None

if results is None:
    print("Computing distances (force_eval=True or files missing)...")
    results = np.zeros(expected_shape)
    results_noisy = np.zeros(expected_shape)
    n_images = len(images_curr)

    for i, image in enumerate(images_curr):
        noised_image = noise_image(image, noise_param=noise_param)
        noisy_images.append(noised_image)
        print(f"Processing image {i + 1}/{n_images}...")
        # Distances are treated as symmetric: only the lower triangle
        # (j <= i, diagonal included) is evaluated and mirrored into (j, i).
        for j in range(i + 1):
            image2 = images_curr[j]
            # Cheap sanity check first, before the expensive experiment.
            if i != j and np.array_equal(image, image2):
                print("Same image, though different instances:", i, j)

            results_df = run_image_pair_experiment(image, image2,
                                                   cost_matrix=cost_matrix,
                                                   noise_std_values=[noise_param],
                                                   num_exp=num_exp,
                                                   n_parallel=n_parallel)

            for k, metric in enumerate(metrics):
                clean_value = results_df[f'original_{metric}'].values[0]
                noisy_value = results_df[f'noisy_vs_noisy_{metric}'].values[0]
                results[i, j, k] = results[j, i, k] = clean_value
                results_noisy[i, j, k] = results_noisy[j, i, k] = noisy_value

    # Save the results
    np.save(results_path, results)
    np.save(results_noisy_path, results_noisy)


# Distance heatmaps: one row per input condition (clean, then noisy) and one
# column per metric; each panel shows that metric's full pairwise matrix.
titles = ['L2', 'W1', 'W2']
heatmap_rows = [("Clean", results), ("Noisy", results_noisy)]

fig, axs = plt.subplots(2, 3, figsize=(15, 10))
for row_idx, (row_label, distance_matrix) in enumerate(heatmap_rows):
    first_panel = axs[row_idx, 0]
    # Rotated row label to the left of the row's first panel.
    first_panel.text(-0.25, 0.5, row_label, rotation=90, va='center', ha='center',
                     transform=first_panel.transAxes, fontsize=28)
    for col_idx, metric_title in enumerate(titles):
        panel = axs[row_idx, col_idx]
        panel.imshow(distance_matrix[:, :, col_idx], cmap='RdBu_r', interpolation='nearest')
        if row_idx == 0:
            # Metric names label only the top row.
            panel.set_title(metric_title, fontsize=28, pad=15)
        panel.axis('off')

plt.tight_layout()
plt.subplots_adjust(wspace=0.2, hspace=0.1)

file_name = os.path.join(results_dir, 'original_distances_vs_noisy_distances_heatmap.pdf')
plt.savefig(file_name, format='pdf', dpi=1200, bbox_inches='tight')
plt.show()


# Generate side-by-side visualization: clean images (top row) above a noised
# copy of each (bottom row), for up to 7 images. These noisy images are drawn
# here for display only; they are not passed to run_image_pair_experiment.
noisy_images = [noise_image(image, noise_param=noise_param) for image in images_curr]

# At most 7 columns; fewer when there are not enough images, which would
# otherwise raise IndexError below.
num_display = min(7, len(images_curr))
# squeeze=False keeps axs 2-D (rows x cols) even when num_display == 1.
fig, axs = plt.subplots(2, num_display, figsize=(14, 5), squeeze=False)

# Row titles with rotated text
axs[0, 0].text(-0.35, 0.5, "Clean", rotation=90, va='center', ha='center',
               transform=axs[0, 0].transAxes, fontsize=14)
axs[1, 0].text(-0.35, 0.5, "Noisy", rotation=90, va='center', ha='center',
               transform=axs[1, 0].transAxes, fontsize=14)

for j in range(num_display):
    # Top row: Clean
    axs[0, j].imshow(images_curr[j], cmap='gray')
    axs[0, j].axis('off')

    # Bottom row: Noisy
    axs[1, j].imshow(noisy_images[j], cmap='gray')
    axs[1, j].axis('off')

# Tighten both gaps from matplotlib's default of 0.2, keeping the horizontal
# gap (wspace=0.1) at twice the vertical gap (hspace=0.05).
plt.subplots_adjust(wspace=0.1, hspace=0.05)

file_name = os.path.join(results_dir, 'images_clean_vs_noisy_7x2.pdf')
plt.savefig(file_name, format='pdf', dpi=1200, bbox_inches='tight')
plt.show()
