import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
import skimage.io
import rasterio

class RandomGenerator(object):
    """Random augmentation pipeline for paired image/label samples.

    Geometric augmentations (rot90 + flip, or a small-angle rotation) are
    applied to the image and label together; photometric augmentations (RGB
    shift, brightness/contrast) touch the image only. Both arrays are then
    resized to ``output_size`` and converted to torch tensors.
    """

    def __init__(self, output_size):
        """
        Randomly transforms images and labels.

        Args:
            output_size (tuple): Desired size of the first two (spatial) axes of
                the transformed image and label.
        """
        self.output_size = output_size

    def __call__(self, sample):
        """
        Apply random transformations to the input sample.

        Args:
            sample (dict): A dictionary containing 'image', 'label', 'geotransform', and 'idx'.
                'image' and 'label' are numpy arrays whose first two axes are spatial.

        Returns:
            dict: Transformed sample containing 'image' (float32 tensor with a new
            leading axis), 'label' (int64 tensor), and the untouched 'geotransform'
            and 'idx'.
        """
        image, label, geotransform, idx = sample['image'], sample['label'], sample['geotransform'], sample['idx']

        # Geometric augmentation: 50% rot90+flip; otherwise a further 50%
        # chance (25% overall) of a small-angle rotation; else nothing.
        if random.random() > 0.5:
            image, label = self.random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = self.random_rotate(image, label)
        # Photometric augmentation is applied to the image only.
        image = self.rgb_shift(image)
        image = self.random_brightness_contrast(image)

        x, y = image.shape[:2]
        if x != self.output_size[0] or y != self.output_size[1]:
            image = self._resize(image, order=3)
            # Nearest-neighbour (order=0) so label values are never blended.
            label = self._resize(label, order=0)

        # NOTE(review): for an (H, W, C) image, unsqueeze(0) yields (1, H, W, C)
        # rather than channel-first (C, H, W); confirm the model/training loop
        # expects this layout before changing it.
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32))
        return {'image': image, 'label': label.long(), 'geotransform': geotransform, 'idx': idx}

    def _resize(self, array, order):
        """Resize the first two axes of ``array`` to ``output_size``.

        Any trailing axes (e.g. colour channels) get a zoom factor of 1, because
        ``ndimage.zoom`` requires exactly one factor per input axis.
        """
        in_x, in_y = array.shape[:2]
        factors = (self.output_size[0] / in_x, self.output_size[1] / in_y) + (1,) * (array.ndim - 2)
        return ndimage.zoom(array, factors, order=order)

    def random_rot_flip(self, image, label):
        """
        Apply the same random 90-degree rotation and axis flip to image and label.

        Args:
            image (numpy.ndarray): Input image.
            label (numpy.ndarray): Input label.

        Returns:
            tuple: Transformed image and label.
        """
        k = np.random.randint(0, 4)  # number of quarter turns, 0..3
        image = np.rot90(image, k)
        label = np.rot90(label, k)
        axis = np.random.randint(0, 2)  # flip along axis 0 or axis 1
        # .copy() materialises the views returned by rot90/flip.
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis).copy()
        return image, label

    def random_rotate(self, image, label):
        """
        Rotate image and label by the same random angle in [-20, 20] degrees.

        The output keeps the input shape (``reshape=False``); both arrays use
        nearest-neighbour interpolation and scipy's default constant fill.

        Args:
            image (numpy.ndarray): Input image.
            label (numpy.ndarray): Input label.

        Returns:
            tuple: Transformed image and label.
        """
        # randint's upper bound is exclusive; 21 makes the range symmetric.
        angle = np.random.randint(-20, 21)
        image = ndimage.rotate(image, angle, order=0, reshape=False)
        label = ndimage.rotate(label, angle, order=0, reshape=False)
        return image, label

    def rgb_shift(self, image, r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5):
        """
        Apply random RGB channel shifts to the input image.

        Args:
            image (numpy.ndarray): Input image with 3 channels on the last axis.
            r_shift_limit (int): Maximum shift limit for the red channel.
            g_shift_limit (int): Maximum shift limit for the green channel.
            b_shift_limit (int): Maximum shift limit for the blue channel.
            p (float): Probability of applying the transformation.

        Returns:
            numpy.ndarray: Transformed image, clipped to [0, 255] and with the
            same dtype as the input.
        """
        if np.random.rand() < p:
            r_shift = np.random.randint(-r_shift_limit, r_shift_limit + 1)
            g_shift = np.random.randint(-g_shift_limit, g_shift_limit + 1)
            b_shift = np.random.randint(-b_shift_limit, b_shift_limit + 1)
            shifted = image + np.array([r_shift, g_shift, b_shift], dtype=np.float32)
            # Preserve the input dtype so later steps never see a surprise uint8 array.
            image = np.clip(shifted, 0, 255).astype(image.dtype)
        return image

    def random_brightness_contrast(self, image, brightness_limit=0.3, contrast_limit=0.3, p=0.5):
        """
        Apply random brightness and contrast adjustments to the input image.

        Contrast is scaled around mid-grey (128) by ``1 + U(-contrast_limit,
        contrast_limit)``; brightness is an additive shift of
        ``U(-brightness_limit, brightness_limit) * 255``.

        Args:
            image (numpy.ndarray): Input image.
            brightness_limit (float): Maximum brightness shift, as a fraction of 255.
            contrast_limit (float): Maximum relative contrast change.
            p (float): Probability of applying the transformation.

        Returns:
            numpy.ndarray: Transformed image, clipped to [0, 255] and with the
            same dtype as the input.
        """
        if np.random.rand() < p:
            brightness_shift = np.random.uniform(-brightness_limit, brightness_limit) * 255.0
            contrast_factor = 1.0 + np.random.uniform(-contrast_limit, contrast_limit)
            # Compute in float: `uint8_array - 128` would wrap around instead of going negative.
            work = image.astype(np.float32)
            adjusted = (work - 128.0) * contrast_factor + 128.0 + brightness_shift
            image = np.clip(adjusted, 0, 255).astype(image.dtype)
        return image

class CoralDataset(Dataset):
    """Geospatial dataset of overwater reef imagery.

    Each item is a dict holding the first three bands of the image (float32),
    the rasterio transform of the image file wrapped in a list, and an ``idx``
    equal to the image file name without its extension. Outside "test" mode the
    item also holds the mask under ``'label'`` and is passed through
    ``transform`` when one is set.
    """

    def __init__(self, image_dir, image_names, mode="train", mask_dir=None, mask_names=None, transform=None):
        """
        Geospatial dataset of overwater reef imagery.

        Args:
            image_dir (str): Directory with all the images.
            image_names (list): List of image names.
            mode (str, optional): Dataset mode ("train", "val", "test").
            mask_dir (str, optional): Directory with mask images. Required unless mode is "test".
            mask_names (list, optional): List of mask names, paired with
                ``image_names`` by index. Required unless mode is "test".
            transform (callable, optional): Optional transform applied to
                non-test samples.

        Raises:
            ValueError: If mode is not "test" and masks are missing or
                ``mask_names`` has a different length than ``image_names``.
        """
        self.image_dir = image_dir
        self.image_names = image_names
        self.mode = mode

        # Always define the mask attributes so they can be inspected in any mode.
        self.mask_dir = None
        self.mask_names = None
        if mode != "test":
            if mask_dir is None or mask_names is None:
                raise ValueError(f"mask_dir and mask_names are required in mode {mode!r}")
            if len(mask_names) != len(image_names):
                raise ValueError(
                    f"mask_names has {len(mask_names)} entries but image_names has {len(image_names)}"
                )
            self.mask_dir = mask_dir
            self.mask_names = mask_names
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def load_geospatial_info(self, img_path):
        """
        Load geospatial information from the image file.

        Args:
            img_path (str): Path to the image file.

        Returns:
            list: A one-element list holding the rasterio ``transform`` of the file.
        """
        with rasterio.open(img_path) as src:
            return [src.transform]

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_names[idx])
        # Keep only the first three bands before converting, so only the
        # retained bands are copied to float32.
        image = skimage.io.imread(img_name)[:, :, :3].astype(np.float32)
        geotransform = self.load_geospatial_info(img_name)
        image_idx = os.path.splitext(self.image_names[idx])[0]

        if self.mode == "test":
            # Test samples carry no label and are returned without `transform`.
            return {'image': image, 'geotransform': geotransform, 'idx': image_idx}

        mask_name = os.path.join(self.mask_dir, self.mask_names[idx])
        mask = skimage.io.imread(mask_name).astype(int)
        sample = {'image': image, 'label': mask, 'geotransform': geotransform, 'idx': image_idx}
        if self.transform:
            sample = self.transform(sample)
        return sample