import os
from data.base_dataset3d import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import torchio as tio
import torch
import random
from torch.utils.data import Dataset
from torchvision.transforms import InterpolationMode
from scipy.ndimage import rotate as scipy_rotate

def rotate(images, angle):
    """Rotate a volume (or a batch of volumes) in a randomly chosen spatial plane.

    The rotation plane is drawn from the *last three* axes of ``images``, which
    are assumed to be the spatial axes. Any leading axes (e.g. batch or channel)
    are left untouched, so separate samples/channels are never interpolated into
    each other. For a 2-D input the single available plane is used.

    Args:
        images: numpy array with at least 2 dimensions.
        angle: rotation angle in degrees.

    Returns:
        A new array with the same shape as ``images`` (``reshape=False``),
        linearly interpolated (``order=1``).

    Raises:
        ValueError: if ``images`` has fewer than 2 dimensions.
    """
    ndim = np.ndim(images)
    if ndim < 2:
        raise ValueError(f"rotate() needs an array with at least 2 dims, got {ndim}")
    # The previous axes=(0, k), k in {1, 2, 3}, was out of range for a plain 3-D
    # volume (k=3) and rotated through the batch/channel axis for 4-D/5-D input.
    spatial_axes = list(range(max(0, ndim - 3), ndim))
    if len(spatial_axes) == 3:
        # Keep one random spatial axis fixed; rotate in the plane of the other two.
        spatial_axes.remove(random.choice(spatial_axes))
    return scipy_rotate(images, angle, axes=tuple(spatial_axes), reshape=False, order=1)

def adjust_brightness(images, factor):
    """Multiply every intensity in ``images`` by ``factor``.

    Args:
        images: array-like supporting ``*`` (e.g. a numpy array).
        factor: multiplicative scale applied uniformly.

    Returns:
        The product ``images * factor`` (a new object; the input is not modified).
    """
    scaled = images * factor
    return scaled

def add_noise(images, scale):
    """Return ``images`` plus zero-mean Gaussian noise with standard deviation ``scale``.

    The noise has the same shape as ``images`` and is drawn from numpy's global
    RNG. It is cast to ``images.dtype`` before the addition, so the sum keeps
    the input's dtype.
    """
    target_shape, target_dtype = images.shape, images.dtype
    gaussian = np.random.normal(0, scale, target_shape)
    return images + gaussian.astype(target_dtype)

def adaptive_discriminator_augmentation(real_images, fake_images, current_loss, target_loss,
                                        max_strength=0.9):
    """Augment real and fake batches with a strength derived from the discriminator loss.

    The strength is ``current_loss / target_loss`` clamped to ``[0, max_strength]``.
    It is passed to ``apply_augmentation`` for both batches; each call samples
    its own random operations.

    Args:
        real_images: batch of real images (numpy array or torch tensor).
        fake_images: batch of generated images (numpy array or torch tensor).
        current_loss: latest discriminator loss; anything ``float()`` accepts
            (a Python number or a single-element torch tensor).
        target_loss: reference loss; must be strictly positive.
        max_strength: upper bound on the strength (defaults to the previous
            hard-coded 0.9).

    Returns:
        ``(augmented_real, augmented_fake)``.

    Raises:
        ValueError: if ``target_loss`` is not strictly positive (previously a
            zero target raised ``ZeroDivisionError``).
    """
    target = float(target_loss)
    if target <= 0:
        raise ValueError(f"target_loss must be > 0, got {target_loss!r}")
    ratio = float(current_loss) / target
    # Clamp below at 0 too: a negative loss would otherwise give a negative strength.
    aug_strength = min(max_strength, max(0.0, ratio))
    # NOTE(review): strength rises as the loss rises. StyleGAN2-ADA increases
    # augmentation when D overfits; confirm the intended direction here.
    augmented_real = apply_augmentation(real_images, strength=aug_strength)
    augmented_fake = apply_augmentation(fake_images, strength=aug_strength)
    return augmented_real, augmented_fake

def apply_augmentation(images, strength, ops=None):
    """Apply a random subset of augmentation ops, sized by ``strength``.

    ``max(1, int(len(ops) * strength))`` ops are applied, capped at
    ``len(ops)``. They are sampled without replacement and applied in random
    order. At least one op is applied whenever ``ops`` is non-empty, even when
    ``strength`` is 0.

    Torch tensors are detached and processed on the CPU as numpy arrays. They
    are then converted back to a tensor on the *original* device with the
    original dtype. Because of the detach, no gradient flows through the
    augmentation.

    Args:
        images: numpy array or torch tensor.
        strength: fraction of the available ops to apply.
        ops: optional sequence of callables ``ndarray -> ndarray``. The default
            is: random rotation (angle uniform in [-30, 30] degrees), brightness
            scaling (factor uniform in [0.7, 1.3]) and additive Gaussian noise
            (std uniform in [0, 0.1]).

    Returns:
        The augmented data, of the same kind (array or tensor) as ``images``.
    """
    is_tensor = isinstance(images, torch.Tensor)
    if is_tensor:
        orig_device, orig_dtype = images.device, images.dtype
        images = images.detach().cpu().numpy()

    if ops is None:
        ops = [
            lambda x: rotate(x, angle=random.uniform(-30, 30)),
            lambda x: adjust_brightness(x, factor=random.uniform(0.7, 1.3)),
            lambda x: add_noise(x, scale=random.uniform(0, 0.1)),
        ]

    # Cap at len(ops): strength > 1 would otherwise make random.sample raise.
    num_ops = min(len(ops), max(1, int(len(ops) * strength)))
    for op in random.sample(ops, k=num_ops):
        images = op(images)

    if is_tensor:
        arr = np.asarray(images)
        if not arr.flags.c_contiguous:
            # torch.from_numpy rejects negative strides; a compact copy is always safe.
            arr = arr.copy()
        # Return to the caller's device/dtype (previously always a CPU tensor,
        # which breaks when the other operands live on a GPU).
        images = torch.from_numpy(arr).to(device=orig_device, dtype=orig_dtype)

    return images


class AlignedOCT2OCTA3DDataset(BaseDataset):
    """Paired dataset of 3D volumes (OCT to OCTA).

    Volumes are read with ``np.load`` from two sibling directories,
    ``<dataroot>/<phase>/A`` and ``<dataroot>/<phase>/B``. Files are paired by
    position after sorting each directory's path list. The two directories
    must therefore hold the same number of files, with matching sort order.

    Each item is a float32 tensor with a leading singleton channel axis. It is
    min-max rescaled per volume to [-1, 1].

    The dataset also stores a discriminator loss (``current_loss``) and a
    ``target_loss``, which ``get_augmented_batch`` uses for adaptive augmentation.
    """

    def __init__(self, opt, phase):
        """Collect the paired A/B file lists under ``<opt.dataroot>/<phase>``."""
        BaseDataset.__init__(self, opt)
        self.dir_A = os.path.join(opt.dataroot, phase, 'A')  # A-side volume directory
        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))

        self.dir_B = os.path.join(opt.dataroot, phase, 'B')  # B-side volume directory
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))
        print("AB paths", self.dir_A, self.dir_B)
        print("AB paths length", len(self.A_paths), len(self.B_paths))
        # Pairing is positional, so both sides must have the same length.
        assert len(self.A_paths) == len(self.B_paths), (
            f"A/B file counts differ: {len(self.A_paths)} vs {len(self.B_paths)}")
        assert self.opt.load_size >= self.opt.crop_size, (
            "crop_size must not exceed load_size")
        # Channel counts follow the translation direction.
        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc

        # Loss values consumed by get_augmented_batch for adaptive augmentation.
        self.current_loss = 1.0  # initial value until update_discriminator_loss is called
        self.target_loss = 0.5   # reference discriminator loss

    @staticmethod
    def _rescale_to_minus_one_one(volume):
        """Min-max rescale ``volume`` to [-1, 1] and return it as float32.

        A constant volume (max == min) has no intensity range. It is mapped to
        all -1 instead of the NaNs that the plain 0/0 formula would produce.
        NaNs already present in the input still propagate.
        """
        volume = np.asarray(volume)
        lo = volume.min()
        span = volume.max() - lo
        if span == 0:
            unit = np.zeros_like(volume)
        else:
            unit = (volume - lo) / span
        return (2 * unit - 1).astype(np.float32)

    def __getitem__(self, index):
        """Return ``{'A', 'B', 'A_paths', 'B_paths'}`` for the pair at ``index``.

        'A' and 'B' are float32 tensors rescaled to [-1, 1]. 'A_paths' and
        'B_paths' are the file paths the volumes were loaded from.
        """
        A_path = self.A_paths[index]
        B_path = self.B_paths[index]
        A = np.load(A_path).astype(np.float32)  # convert to float32 right after loading
        B = np.load(B_path).astype(np.float32)

        # Add a leading channel axis.
        A = np.expand_dims(A, axis=0)
        B = np.expand_dims(B, axis=0)

        # Both transforms are built from the same params, presumably so A and B
        # receive matching random transforms (depends on get_params/get_transform).
        transform_params = get_params(self.opt, A.shape)
        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))

        A = self._rescale_to_minus_one_one(A_transform(A))
        B = self._rescale_to_minus_one_one(B_transform(B))

        return {'A': torch.from_numpy(A), 'B': torch.from_numpy(B),
                'A_paths': A_path, 'B_paths': B_path}

    def update_discriminator_loss(self, current_loss):
        """Record the latest discriminator loss for adaptive augmentation."""
        self.current_loss = current_loss

    def get_augmented_batch(self, real_batch, fake_batch):
        """Augment a real and a fake batch using the stored current/target losses.

        Returns the ``(augmented_real, augmented_fake)`` pair produced by
        ``adaptive_discriminator_augmentation``.
        """
        return adaptive_discriminator_augmentation(
            real_batch,
            fake_batch,
            self.current_loss,
            self.target_loss
        )

    def __len__(self):
        """Return the number of A/B pairs."""
        assert len(self.A_paths) == len(self.B_paths)
        return len(self.A_paths)