import os
from PIL import Image
from tqdm.auto import tqdm
import hdf5plugin
import h5py
import numpy as np
import matplotlib.pyplot as plt
import random
import torch
import lovely_tensors as lt

from chip.datasets.tomogram_dataset import TomogramDataset

# Lift Pillow's maximum image size so that very large images can be opened.
# NOTE: setting the limit to None also switches off Pillow's
# decompression-bomb check, so only open files from a trusted source.
Image.MAX_IMAGE_PIXELS = None  # None removes the limit entirely
# Alternatively, set a specific limit, e.g. 300 million pixels:
# Image.MAX_IMAGE_PIXELS = 300000000

def create_circular_mask(h, w, center=None, radius=None):
    """Build a boolean mask that is True inside a circle.

    Args:
        h: Mask height in pixels (number of rows).
        w: Mask width in pixels (number of columns).
        center: ``(x, y)`` centre of the circle; defaults to the middle of
            the mask.
        radius: Circle radius; defaults to the distance from the centre to
            the nearest edge of the mask.

    Returns:
        Boolean array of shape ``(h, w)``; True where the pixel lies within
        ``radius`` of ``center`` (the boundary is included).
    """
    if center is None:
        # Default to the middle of the mask.
        center = (int(w / 2), int(h / 2))
    cx, cy = center[0], center[1]

    if radius is None:
        # Largest radius that still fits: distance to the closest edge.
        radius = min(cx, cy, w - cx, h - cy)

    # Open grids broadcast to an (h, w) array of distances.
    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2)
    return distance <= radius


def random_crop(image, crop_size=(512, 512), mask_radius=256):
    """Randomly crop an image and apply a circular mask.

    Args:
        image: PIL image to crop from. It must be at least ``crop_size``
            in both dimensions.
        crop_size: ``(width, height)`` of the crop in pixels.
        mask_radius: Radius in pixels of the circular region that is kept.
            The circle is centred on the crop.

    Returns:
        NumPy array holding the cropped pixels, with every pixel outside
        the circle set to 0.

    Raises:
        ValueError: If the image is smaller than ``crop_size`` in either
            dimension.
    """
    crop_width, crop_height = crop_size[0], crop_size[1]
    width, height = image.size

    # Fail with a clear message instead of the "empty range" error that
    # random.randint would raise for a negative upper bound.
    if width < crop_width or height < crop_height:
        raise ValueError(
            f'Image of size {width}x{height} is smaller than the requested '
            f'crop size {crop_width}x{crop_height}'
        )

    # Top-left corner of the crop; both bounds of randint are inclusive.
    x = random.randint(0, width - crop_width)
    y = random.randint(0, height - crop_height)
    cropped_image = image.crop((x, y, x + crop_width, y + crop_height))

    # The mask is indexed (row, column), hence height first.
    mask = create_circular_mask(crop_height, crop_width, radius=mask_radius)
    cropped_data = np.array(cropped_image)  # np.array copies, so it is writable
    cropped_data[~mask] = 0  # Zero out everything outside the circle

    return cropped_data


def process_images(folder_path, num_crops, h5_file):
    """Fill an HDF5 file with random masked crops of every image in a folder.

    Every ``.tif`` file in ``folder_path`` contributes ``num_crops`` random
    512x512 crops (see ``random_crop``). All crops are stored in a single
    dataset named ``'images'`` of shape ``(n_files * num_crops, 512, 512)``,
    chunked one crop per chunk and compressed with Bitshuffle/LZ4.

    Args:
        folder_path: Directory containing the source ``.tif`` images.
        num_crops: Number of random crops to take from each image.
        h5_file: Path of the output HDF5 file. It is opened in ``'w'`` mode,
            so an existing file is overwritten.

    Raises:
        ValueError: If ``folder_path`` contains no ``.tif`` files or
            ``num_crops`` is smaller than 1.
    """
    crop_width, crop_height = 512, 512
    chunk_size = (1, crop_height, crop_width)

    # List the inputs before touching the output so that a bad folder does
    # not truncate an existing HDF5 file. Sorting makes the layout of the
    # dataset independent of the order os.listdir happens to return.
    filenames = sorted(
        name for name in os.listdir(folder_path) if name.endswith('.tif')
    )
    if not filenames:
        raise ValueError(f'No .tif files found in {folder_path!r}')
    if num_crops < 1:
        raise ValueError(f'num_crops must be at least 1, got {num_crops}')

    with h5py.File(h5_file, 'w') as h5f:
        dataset = h5f.create_dataset(
            name='images',
            shape=(len(filenames) * num_crops, crop_height, crop_width),
            chunks=chunk_size,
            **hdf5plugin.Bitshuffle(nelems=0, cname='lz4')
        )

        row = 0  # Next free index along the first axis of the dataset
        # Iterating the list directly lets tqdm show the total.
        for filename in tqdm(filenames):
            image_path = os.path.join(folder_path, filename)
            # Close each source file as soon as its crops are written.
            with Image.open(image_path) as image:
                for _ in range(num_crops):
                    crop_data = random_crop(
                        image, crop_size=(crop_width, crop_height)
                    )
                    dataset[row] = crop_data.astype(float)
                    row += 1


# Example usage: build the crop dataset.
# NOTE: these statements are at module level, so they run whenever this file
# is executed or imported.
folder_path = 'data/DATASET_G7_170um_10nm_rect'  # Replace with your folder path
num_crops = 20000  # Number of random crops per image
h5_file = 'data/tomograms_blueprint.h5'  # Output H5 file
process_images(folder_path, num_crops, h5_file)

# Apply the lovely_tensors patch; presumably this is what makes the tensors
# printed further down show a compact summary -- confirm against its docs.
lt.monkey_patch()


def display_random_images(h5_file, num_images=5):
    """Display randomly chosen images from an H5 file.

    Picks ``num_images`` random indices (with replacement) from the
    ``'images'`` dataset, prints each image as a tensor and shows them side
    by side in a single figure row.

    Args:
        h5_file: Path of the HDF5 file containing an ``'images'`` dataset.
        num_images: Number of images to draw and display.
    """
    with h5py.File(h5_file, 'r') as h5f:
        images = h5f['images']
        # The upper bound of np.random.randint is exclusive, so passing
        # len(images) makes every index, including the last, selectable.
        random_keys = np.random.randint(0, len(images), (num_images,))
        # squeeze=False always yields a 2-D array of axes, so a single
        # image (num_images=1) is handled like any other count.
        _, axes = plt.subplots(1, num_images, figsize=(20, 4), squeeze=False)
        for ax, key in zip(axes[0], random_keys):
            image = images[key]
            print(torch.tensor(image))
            ax.imshow(image, cmap='gray')
            ax.axis('off')
            ax.set_title(f'Image: {key}')

        plt.show()


# Example usage: preview a few stored crops, then build the train/test subsets.
display_random_images(h5_file, num_images=5)
ds = TomogramDataset(h5_file)
# Sequential 90/10 split by index (no shuffling): the first 90% of the
# samples go to training, the remainder to testing.
trainSet = torch.utils.data.Subset(ds, range(0, round(0.9 * len(ds))))
testSet = torch.utils.data.Subset(ds, range(round(0.9 * len(ds)), len(ds)))


# Old version, kept for reference: it split the raw HDF5 'images' dataset
# first and wrapped each subset in TomogramDataset afterwards.
#
# h5_tomogram = h5py.File(h5_file, 'r')
# ds = h5_tomogram.get('images')
#
# trainSet_raw = torch.utils.data.Subset(ds, range(0, round(0.9 * len(ds))))
# testSet_raw = torch.utils.data.Subset(ds, range(round(0.9 * len(ds)), len(ds)))
# trainSet = TomogramDataset(trainSet_raw)
# testSet = TomogramDataset(testSet_raw)

# Sanity check: print the first training sample.
print(trainSet[0])
