import os
import numpy as np
import torch
import tifffile
import glob
from PIL import Image
from pathlib import Path

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn import functional as F

class CollateFn:
    """Collate callable that crops/resizes every sample to one size and stacks them.

    Presumably written as a class (instead of a closure) so it can be pickled
    into DataLoader worker processes.

    Args:
        target_size: ``(H, W)`` forwarded unchanged to ``random_crop``.
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, batch):
        """Return a ``(B, C, H, W)`` tensor built from the ``(C, h, w)`` samples in *batch*."""
        size = self.target_size
        resized = tuple(random_crop(sample, size) for sample in batch)
        return torch.stack(resized, dim=0)

class Pretrain_Dataset(Dataset):
    """Single-image dataset over a flat directory of ``.tiff`` / ``.npy`` files.

    Each item is a float32 tensor of shape ``(1, H', W')``:

    * 3-D arrays are reduced to their central slice along axis 0;
    * intensities are min-max normalized to ``[0, 1]`` (a constant image
      becomes all zeros rather than NaN);
    * every spatial axis shorter than ``img_size`` is bilinearly upsampled to
      the target length, while longer axes are left untouched.

    Args:
        root_dir: Directory scanned (non-recursively) for sample files.
        img_size: Target ``(H, W)``.
    """

    def __init__(self, root_dir, img_size):
        self.root_dir = root_dir
        self.img_size = img_size  # (H, W)
        self.samples = self._create_sample_list()

    def _create_sample_list(self):
        """Return the sorted names of all ``.tiff`` / ``.npy`` files in ``root_dir``."""
        # os.listdir order is arbitrary; sorting gives every process (e.g. every
        # DistributedSampler rank) the same index -> file mapping.
        return sorted(f for f in os.listdir(self.root_dir) if f.endswith(('.tiff', '.npy')))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, normalize and (if needed) upsample sample *idx*.

        Raises:
            ValueError: for an unsupported extension or an array with more
                than three dimensions.
        """
        sample_path = os.path.join(self.root_dir, self.samples[idx])

        # Read the preprocessed images
        if sample_path.endswith('.tiff'):
            raw_image = tifffile.imread(sample_path)
        elif sample_path.endswith('.npy'):
            raw_image = np.load(sample_path)
        else:
            raise ValueError(f"Unsupported images: {sample_path}")

        if raw_image.ndim == 3:
            # Volumetric input: keep only the central slice along axis 0.
            raw_image = raw_image[raw_image.shape[0] // 2]
        elif raw_image.ndim > 3:
            raise ValueError(f"Unexpected image dimensions: {raw_image.shape}")

        # Cast in NumPy first: torch.from_numpy rejects some dtypes (e.g. uint16
        # on older PyTorch releases) and non-native byte orders.
        image_tensor = torch.from_numpy(np.ascontiguousarray(raw_image, dtype=np.float32)).unsqueeze(0)
        image_tensor = self._normalize(image_tensor)
        image_tensor = self.upsample(image_tensor)

        return image_tensor

    @staticmethod
    def _normalize(image_tensor):
        """Min-max scale to ``[0, 1]``; a constant tensor maps to zeros instead of NaN."""
        lo, hi = image_tensor.min(), image_tensor.max()
        if hi == lo:
            return torch.zeros_like(image_tensor)
        return (image_tensor - lo) / (hi - lo)

    def upsample(self, image_tensor):
        """Bilinearly enlarge a ``(C, h, w)`` tensor so each side is at least ``img_size``.

        Axes already at or above the target are kept at their original length,
        so the output is ``(C, max(h, th), max(w, tw))``.
        """
        h, w = image_tensor.shape[1:]
        th, tw = self.img_size
        # Equivalent to h * max(1, th / h) but exact (no float truncation to th - 1).
        new_h, new_w = max(h, int(th)), max(w, int(tw))

        if (new_h, new_w) != (h, w):
            image_tensor = F.interpolate(image_tensor.unsqueeze(0),
                                         size=(new_h, new_w),
                                         mode='bilinear',
                                         align_corners=False).squeeze(0)

        return image_tensor

def random_crop(image_tensor, target_size):
    """Randomly crop a ``(C, H, W)`` tensor, then resize it to exactly *target_size*.

    Any spatial axis longer than its target is cut to a window of the target
    length at a random offset drawn with ``torch.randint`` (height first, then
    width). Axes that are already short enough are kept whole. The result is
    then bilinearly resized to ``target_size``, which stretches any axis that
    was shorter than its target.
    """
    _, height, width = image_tensor.shape
    out_h, out_w = target_size

    if height > out_h:
        top = torch.randint(0, height - out_h + 1, (1,)).item()
        image_tensor = image_tensor[:, top:top + out_h, :]
    if width > out_w:
        left = torch.randint(0, width - out_w + 1, (1,)).item()
        image_tensor = image_tensor[..., left:left + out_w]

    batched = image_tensor.unsqueeze(0)
    resized = F.interpolate(batched, size=target_size, mode='bilinear', align_corners=False)
    return resized.squeeze(0)

def get_dataloader(config, is_train=True):
    """Build the pretraining DataLoader over ``config.data_dir``.

    Training uses a ``DistributedSampler`` built with default arguments. It
    therefore reads world size and rank from ``torch.distributed``, so the
    process group must already be initialized. Callers should invoke
    ``sampler.set_epoch(epoch)`` each epoch to reshuffle. Evaluation uses
    sequential order.

    Args:
        config: Object exposing ``data_dir``, ``img_size``, ``batch_size`` and
            ``num_workers``.
        is_train: Whether to build the training (distributed) loader.

    Returns:
        A ``DataLoader`` whose batches are ``(B, 1, H, W)`` tensors produced by
        ``CollateFn(config.img_size)``.
    """
    dataset = Pretrain_Dataset(config.data_dir, img_size=config.img_size)
    sampler = DistributedSampler(dataset) if is_train else None

    collate = CollateFn(config.img_size)

    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=(sampler is None and is_train),
        sampler=sampler,
        num_workers=config.num_workers,
        pin_memory=True,
        collate_fn=collate,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only keep workers alive when there are some.
        persistent_workers=config.num_workers > 0,
    )

class ISOLiverDataset(Dataset):
    """Paired ``source`` / ``target`` slices stored as ``.npz`` files.

    Files read relative to ``config.data_dir``:

    * ``dataset_info.npy`` – a pickled dict containing ``original_train_shape``
      and ``original_test_shape``;
    * ``train/slice_*.npz`` (training) or ``test/xz_slice_*.npz`` (testing),
      each holding ``source`` and ``target`` arrays; test files additionally
      hold ``position``.

    Training items are ``(source, target)`` resized to ``config.img_size``.
    Test items are ``(source, target, position)`` at native resolution.

    Raises:
        FileNotFoundError: if no slice files match the split's pattern.
    """

    def __init__(self, config, is_train=True):
        self.data_dir = config.data_dir
        self.is_train = is_train
        self.img_size = config.img_size

        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
        # dataset_info.npy from trusted sources.
        info = np.load(os.path.join(self.data_dir, 'dataset_info.npy'), allow_pickle=True).item()
        self.original_shape = info['original_train_shape' if is_train else 'original_test_shape']

        if is_train:
            pattern = os.path.join(self.data_dir, 'train', 'slice_*.npz')
        else:
            pattern = os.path.join(self.data_dir, 'test', 'xz_slice_*.npz')
        self.file_list = sorted(glob.glob(pattern))

        print(f"{'Train' if is_train else 'Test'} dataset initialized with {len(self.file_list)} samples")

        if not self.file_list:
            raise FileNotFoundError(f"No files match {pattern}")

        # NpzFile keeps the archive open; the context manager closes it promptly.
        with np.load(self.file_list[0]) as sample_data:
            source_shape = sample_data['source'].shape
        print(f"Original data shape: {source_shape}")
        print(f"Target resize shape: {self.img_size}")

    def resize_2d(self, img):
        """Bilinearly resize a ``(C, H, W)`` tensor to ``(C, *self.img_size)``.

        Raises:
            ValueError: if *img* is not 3-D.
        """
        if img.dim() != 3:
            raise ValueError(f"Expected a (C, H, W) tensor, got shape {tuple(img.shape)}")
        return F.interpolate(
            img.unsqueeze(0),
            size=self.img_size,
            mode='bilinear',
            align_corners=False
        ).squeeze(0)

    def __getitem__(self, idx):
        with np.load(self.file_list[idx]) as data:
            source = torch.from_numpy(data['source']).float()  # expected (C, H, W)
            target = torch.from_numpy(data['target']).float()  # expected (C, H, W)
            position = None if self.is_train else data['position']

        if self.is_train:
            return self.resize_2d(source), self.resize_2d(target)
        return source, target, position

    def __len__(self):
        return len(self.file_list)
    
class ProjFlywingDataset(Dataset):
    """Projection-Flywing low-/high-resolution pairs.

    Training pairs are the ``X`` / ``Y`` arrays of
    ``<data_dir>/train_data/my_training_data.npz``. Test pairs are TIFF files
    under ``<data_dir>/test_data/{Input,GT}/C<condition>/``. Every item is
    ``(lr, hr)`` as float32 tensors multiplied by ``config.rgb_range``.
    Indices wrap modulo the dataset length.
    """

    def __init__(self, config, is_train=True, condition=0):
        self.config = config
        self.is_train = is_train
        self.condition = condition
        self.iso = ['Projection_Flywing']

        if self.is_train:
            self._load_train_data()
        else:
            self._load_test_data()

        print(f"{'Train' if is_train else 'Test'} dataset initialized with {self.lenth} samples")

    def _load_train_data(self):
        """Load the whole training archive into ``nm_lr`` / ``nm_hr`` arrays."""
        datapath = f"{self.config.data_dir}/train_data/my_training_data.npz"
        self.nm_lr, self.nm_hr = self._load_npz_data(datapath)
        self.lenth = len(self.nm_lr)

    def _load_test_data(self):
        """Collect sorted input / ground-truth TIFF paths for ``self.condition``."""
        dir_lr = f"{self.config.data_dir}/test_data/"
        self.nm_lr = sorted(glob.glob(f"{dir_lr}Input/C{self.condition}/*.tif"))
        self.nm_hr = sorted(glob.glob(f"{dir_lr}GT/C{self.condition}/*.tif"))
        self.lenth = len(self.nm_lr)

    def _load_npz_data(self, path):
        """Load the ``X`` and ``Y`` arrays from an ``.npz`` archive.

        Per the original note the training data is (784, 1, 50, 128, 128),
        SCZYX. This is not checked here.
        """
        # The context manager closes the underlying zip file once both arrays are read.
        with np.load(path) as data:
            return data['X'], data['Y']

    def _split_patches(self, X, Y, patch_size=64):
        """Tile each sample into non-overlapping squares over its last two (Y, X) axes.

        Patches are emitted row-major per sample. If ``patch_size`` does not
        divide the spatial size, edge patches are smaller and ``np.array`` of
        the ragged list will fail.
        """
        X_patches, Y_patches = [], []
        for n in range(len(X)):
            x, y = X[n], Y[n]
            # Slice the trailing spatial axes; leading (channel / Z) axes are kept whole.
            for i in range(0, x.shape[-2], patch_size):
                for j in range(0, x.shape[-1], patch_size):
                    X_patches.append(x[..., i:i + patch_size, j:j + patch_size])
                    Y_patches.append(y[..., i:i + patch_size, j:j + patch_size])
        return np.array(X_patches), np.array(Y_patches)

    def _load_training_data(self, file, axes):
        """Return ``(X, Y)`` from an ``.npz`` file; *axes* is accepted but unused."""
        return self._load_npz_data(file)

    def __getitem__(self, idx):
        idx = idx % self.lenth
        if self.is_train:
            lr, hr = self.nm_lr[idx], self.nm_hr[idx]
        else:
            # TIFF pairs are read with tifffile; GT gets a leading channel axis.
            lr = np.float32(tifffile.imread(self.nm_lr[idx]))
            hr = np.expand_dims(np.float32(tifffile.imread(self.nm_hr[idx])), 0)

        scale = self.config.rgb_range
        lr = torch.from_numpy(np.ascontiguousarray(lr * scale)).float()
        hr = torch.from_numpy(np.ascontiguousarray(hr * scale)).float()

        return lr, hr

    def __len__(self):
        return self.lenth

class BSAFusionDataset2D(Dataset):
    """Paired two-modality 2-D PNG slices for image fusion.

    Reads ``<data_dir>/<modularities>/<train|test>/<modality>/*.png``, where
    ``modularities`` is e.g. ``"CT-MRI"``. Both modality folders must contain
    the same file names. Each item is ``(source1, source2, target)``: each
    source is a ``(1, H, W)`` float32 tensor min-max normalized to ``[0, 1]``
    (all zeros for a constant image), and ``target`` is ``source2`` as a
    placeholder.

    Raises:
        ValueError: if the two modality folders differ in file count or names.
    """

    def __init__(self, config, is_train=True):
        data_path = Path(config.data_dir)
        self.modularities = config.modularities  # Can be "CT-MRI", "PET-MRI", and "SPECT-MRI"
        parts = self.modularities.split('-')
        modularity1, modularity2 = parts[0], parts[1]

        image_dir = data_path / self.modularities / ("train" if is_train else "test")
        mod1_images = sorted(glob.glob(str(image_dir / modularity1 / "*.png")))
        mod2_images = sorted(glob.glob(str(image_dir / modularity2 / "*.png")))

        # Explicit raises rather than `assert`, which is stripped under `python -O`.
        if len(mod1_images) != len(mod2_images):
            raise ValueError("The number of images in two modalities should be the same.")
        for path1, path2 in zip(mod1_images, mod2_images):
            if Path(path1).name != Path(path2).name:
                raise ValueError("The name of images in two modalities should be the same.")

        self.mod1_images = mod1_images
        self.mod2_images = mod2_images

        self.length = len(mod1_images)

    def __len__(self):
        return self.length

    @staticmethod
    def _min_max(arr):
        """Min-max scale *arr* to ``[0, 1]``; a constant array maps to zeros."""
        lo, hi = arr.min(), arr.max()
        return (arr - lo) / (hi - lo) if hi != lo else np.zeros_like(arr)

    def __getitem__(self, idx):
        # Context managers close the PNG file handles as soon as pixels are read.
        with Image.open(self.mod1_images[idx]) as img1:
            source1 = np.array(img1.convert('L'))  # (H, W)
        with Image.open(self.mod2_images[idx]) as img2:
            # NOTE(review): unlike source1 this is not converted to 'L', so a
            # multi-channel PNG would yield (H, W, C) here.
            source2 = np.array(img2)

        source1 = torch.tensor(self._min_max(source1), dtype=torch.float32).unsqueeze(0)  # (1, H, W)
        source2 = torch.tensor(self._min_max(source2), dtype=torch.float32).unsqueeze(0)  # (1, H, W)

        target = source2  # Fake target, not used
        return source1, source2, target
    

class UnifmirSRDataset2D(Dataset):
    """2-D super-resolution pairs concatenated over several structure types.

    ``config.data_types`` lists the structure types to include, in order
    (a subset of ``'CCPs'``, ``'F-actin'``, ``'ER'``, ``'Microtubules'``).
    Per-type sample counts are hard-coded in ``data_length``. A global index
    is mapped to ``(type, local_index)`` by walking the types in that order.

    Files read:

    * train: ``<data_dir>/train/<type>/preprocessed/{X,Y}_<i>.npy``
    * test:  ``<data_dir>/test/<type>/{LR,GT}/im<i+1>_{LR,GT}.tif``

    Each item is ``(source, target)``, both ``(1, *img_size)`` float32 tensors
    min-max normalized to ``[0, 1]`` (all zeros for a constant image) and then
    bilinearly resized.
    """

    def __init__(self, config, is_train=True):
        self.data_dir = Path(config.data_dir)
        self.data_types = config.data_types  # A list, contains ['CCPs', 'F-actin', 'ER', 'Microtubules']

        self.is_train = is_train
        if self.is_train:
            self.data_length = {'CCPs': 19440, 'F-actin': 19584, 'ER': 19584, 'Microtubules': 19800}
        else:
            self.data_length = {'CCPs': 100, 'F-actin': 100, 'ER': 100, 'Microtubules': 100}

        self.img_size = config.img_size

    def resize_2d(self, img):
        """Bilinearly resize an ``(H, W)`` tensor to ``self.img_size``.

        Indexing ``[0, 0]`` instead of ``squeeze()`` keeps a target dimension
        of size 1.

        Raises:
            ValueError: if *img* is not 2-D.
        """
        if img.dim() != 2:
            raise ValueError(f"Expected an (H, W) tensor, got shape {tuple(img.shape)}")
        img_resized = F.interpolate(
            img[None, None],
            size=self.img_size,
            mode='bilinear',
            align_corners=False
        )
        return img_resized[0, 0]

    def __len__(self):
        # Also cached on self.length, as before, for callers that read it.
        self.length = sum(self.data_length[data_type] for data_type in self.data_types)
        return self.length

    def _locate(self, idx):
        """Map a global (possibly negative) index to ``(data_type, local_index)``.

        Raises:
            IndexError: if *idx* is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        local = idx + total if idx < 0 else idx
        if not 0 <= local < total:
            raise IndexError(f"index {idx} out of range for dataset of length {total}")
        for data_type in self.data_types:
            count = self.data_length[data_type]
            if local < count:
                return data_type, local
            local -= count
        raise IndexError(f"index {idx} out of range for dataset of length {total}")  # unreachable

    @staticmethod
    def _normalize(arr):
        """Min-max scale *arr* to ``[0, 1]``; a constant array maps to zeros."""
        lo, hi = arr.min(), arr.max()
        return (arr - lo) / (hi - lo) if hi != lo else np.zeros_like(arr)

    def __getitem__(self, idx):
        data_type, idx = self._locate(idx)
        if self.is_train:
            data_path = self.data_dir / 'train' / data_type
            source = np.load(data_path / f"preprocessed/X_{idx}.npy").squeeze()  # LR input (original note: 128x128)
            target = np.load(data_path / f"preprocessed/Y_{idx}.npy").squeeze()  # HR target (original note: 256x256)
        else:
            data_path = self.data_dir / 'test' / data_type
            source = tifffile.imread(data_path / 'LR' / f"im{idx+1}_LR.tif").squeeze()
            target = tifffile.imread(data_path / 'GT' / f"im{idx+1}_GT.tif").squeeze()

        source = torch.tensor(self._normalize(source), dtype=torch.float32)
        target = torch.tensor(self._normalize(target), dtype=torch.float32)

        source = self.resize_2d(source).unsqueeze(0)
        target = self.resize_2d(target).unsqueeze(0)

        return source, target