import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, Subset
from PIL import Image
import os
import pickle
import numpy as np


def unpickle(file):
    """Load an ImageNet-32 pickle batch, compatible with Python 2 and Python 3 pickles.

    Python 3's ``pickle.load`` decodes Python-2 8-bit ``str`` objects as ASCII by
    default, which raises ``UnicodeDecodeError`` for non-ASCII bytes. In that case
    the same file is re-read from the start with ``encoding='latin1'``, which maps
    every byte 1:1.

    Args:
        file: path of the pickle file.

    Returns:
        The unpickled object. ``ImageNet32Dataset`` reads its ``'data'`` and
        ``'labels'`` keys.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(file, 'rb') as fo:
        try:
            return pickle.load(fo)
        except UnicodeDecodeError:
            fo.seek(0)  # restart from the beginning with the latin1 fallback
            return pickle.load(fo, encoding='latin1')


class ImageNet32Dataset(Dataset):
    """In-memory loader for the pickled ImageNet-32 batches.

    ``split="train"`` reads ``train_data_batch_1`` ... ``train_data_batch_10`` from
    ``data_folder``; missing batch files are skipped. ``"val"``, ``"eval"`` and
    ``"test"`` all read ``val_data``.

    Images are held as one float32 array of shape (N, 3, 32, 32) scaled to
    [0, 1]. Labels are shifted from the 1-based values stored in the pickles to
    0-based int64.
    """

    def __init__(self, data_folder, split="train", transform=None):
        """
        Args:
            data_folder: directory containing the pickle batch files.
            split: "train", or one of "val" / "eval" / "test" (all read ``val_data``).
            transform: optional callable applied in ``__getitem__``.

        Raises:
            FileNotFoundError: no training batch / no ``val_data`` file was found.
            ValueError: unsupported ``split``.
        """
        self.transform = transform

        if split == "train":
            print("Loading ImageNet-32 training data...")
            data_chunks, label_chunks = [], []
            for i in range(1, 11):
                batch_file = os.path.join(data_folder, f'train_data_batch_{i}')
                if not os.path.exists(batch_file):
                    continue  # missing batches are skipped
                print(f"  Loading batch {i}...")
                images, labels = self._load_batch(batch_file)
                data_chunks.append(images)
                label_chunks.append(labels)
            if not data_chunks:
                # Without this, np.concatenate([]) fails with an opaque ValueError.
                raise FileNotFoundError(f"No ImageNet-32 training batches found in {data_folder}")
            self.data = np.concatenate(data_chunks, axis=0)
            self.labels = np.concatenate(label_chunks, axis=0)

        elif split in ["val", "eval", "test"]:
            print("Loading ImageNet-32 validation data...")
            val_file = os.path.join(data_folder, 'val_data')
            if not os.path.exists(val_file):
                raise FileNotFoundError(f"Validation file not found: {val_file}")
            self.data, self.labels = self._load_batch(val_file)

        else:
            raise ValueError(f"Unsupported split: {split}. Must be one of ['train', 'val', 'eval', 'test']")

        print(f"ImageNet-32 {split} loaded: {len(self.data)} images, shape: {self.data.shape}")

    @classmethod
    def _load_batch(cls, path):
        """Read one pickle batch and return ``(images, labels)``.

        ``images`` is float32 (N, 3, 32, 32) in [0, 1]; ``labels`` is 0-based int64.
        """
        batch = unpickle(path)
        images = cls._flat_to_chw(batch['data'])
        labels = np.asarray(batch['labels'], dtype=np.int64) - 1  # stored 1-based
        return images, labels

    @staticmethod
    def _flat_to_chw(flat, img_size=32):
        """Convert flat pixel rows to a float32 (N, 3, img_size, img_size) array in [0, 1].

        Each row stores the R, G and B planes back to back, each plane row-major.
        A plain C-order reshape therefore performs the per-channel split, with no
        Python loop and no intermediate list of per-image arrays.
        """
        images = np.asarray(flat).astype(np.float32)  # fresh copy, safe to modify in place
        images /= 255.0
        return images.reshape(-1, 3, img_size, img_size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a float32 (3, 32, 32) tensor and an int label."""
        image = torch.tensor(self.data[idx], dtype=torch.float32)  # (3, 32, 32), copied
        label = int(self.labels[idx])

        if self.transform is not None:
            # NOTE(review): the transform receives an (H, W, C) tensor, and any 3-D
            # result is permuted back to (C, H, W). torchvision tensor transforms
            # normally expect (C, H, W); confirm this matches the transforms used.
            transformed = self.transform(image.permute(1, 2, 0))
            if len(transformed.shape) == 3:
                image = transformed.permute(2, 0, 1)
            else:
                image = transformed

        return image, label


class BSDS500Crops(Dataset):
    """Grayscale BSDS500 crops stored as ``*.png`` files under ``<root>/<split>/``.

    Every sample is returned as ``(image, 0)``. The constant 0 is a dummy label
    that keeps the ``(image, label)`` format consistent with the other datasets.
    """

    def __init__(self, root, split="train", transform=None):
        """
        Args:
            root: directory that contains one sub-directory per split
                (``PairDataset`` passes ``<data_root>/BSDS500/crop/images``).
            split: name of the sub-directory to read.
            transform: optional callable applied to the grayscale PIL image.

        Raises:
            RuntimeError: the split directory contains no ``*.png`` files.
        """
        from pathlib import Path
        self.split_dir = Path(root) / split
        self.transform = transform
        self.files = sorted(self.split_dir.glob("*.png"))  # sorted for a stable index order
        if not self.files:
            raise RuntimeError(f"No files found in {self.split_dir}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        # Context manager closes the file deterministically; convert() returns a
        # new, fully loaded image, so it stays valid after the file is closed.
        with Image.open(path) as pil_img:
            img = pil_img.convert("L")  # grayscale
        if self.transform is not None:
            img = self.transform(img)
        label = 0  # dummy label, keeps the (image, label) format consistent
        return img, label

class PairDataset(Dataset):
    """Dataset of ``(input, target, label)`` triples for image-restoration tasks.

    Wraps MNIST, CIFAR-100, BSDS500 or ImageNet-32 and builds pairs according to
    ``mode``:

    * ``'super'``: the input is the target bicubically downsampled by
      ``down_scale``. Pairs are cached on disk under ``<root>/downsampled_cache``,
      except for ImageNet-32, which is downsampled on the fly.
    * ``'denoising'``: the input is the target plus Gaussian noise of std
      ``noise_std``, clamped to [0, 1]. Eval/test inputs are cached under
      ``<root>/denoising_cache``; training noise is re-drawn on every access.
    * ``'class_to_prototype'`` (MNIST only): the target is the first image of the
      same class found in the split.

    Optional flip augmentation (``use_augmentation`` with ``aug_num`` 2 or 4)
    multiplies the number of samples in ``'super'`` mode and for denoising
    training. Input and target are returned flattened with ``view(-1)``.
    """

    # Samples kept for 'train' from the seeded permutation of the official
    # training set; the remainder becomes the 'eval' split.
    _TRAIN_SPLIT_SIZES = {'mnist': 50000, 'cifar100': 40000}

    # Flip ids per aug_num, in index order:
    # 0 = original, 1 = horizontal, 2 = vertical, 3 = horizontal + vertical.
    _AUG_TYPES = {4: (0, 1, 2, 3), 2: (0, 3)}

    def __init__(self, dataset_name='mnist', root='./data',
                 split='train', mode=None,
                 noise_std=0.0, down_scale=2.0,
                 use_augmentation=False, aug_num=None,
                 random_seed=42,
                 # Paired intensity augmentation (only supported in 'super' mode)
                 use_scaling_aug=False,
                 scaling_alpha_range=(0.8, 1.2),
                 scaling_bias_range=(-0.05, 0.05),
                 scaling_prob=0.80,
                 scaling_per_channel=False):
        """
        Args:
            dataset_name: 'mnist', 'cifar100', 'bsds500' or 'imagenet32'
                (case-insensitive).
            root: directory holding the raw datasets and the on-disk caches.
            split: 'train', 'eval' or 'test'.
            mode: 'super', 'denoising' or 'class_to_prototype'.
            noise_std: Gaussian noise std for 'denoising' (must be non-zero there).
            down_scale: downsampling factor for 'super' (must be > 1).
            use_augmentation, aug_num: enable flip augmentation with 2 or 4 variants.
            random_seed: seed for the train/eval (and ImageNet-32 eval/test) index
                permutations and for generating cached eval/test noise.
            use_scaling_aug: apply a random paired intensity map
                ``clamp(alpha * x + bias, 0, 1)`` to HR and LR ('super' mode,
                train split only).
            scaling_alpha_range, scaling_bias_range: uniform sampling ranges
                for alpha and bias.
            scaling_prob: probability of applying the intensity map.
            scaling_per_channel: sample alpha and bias independently per channel.

        Raises:
            ValueError: unknown split.
            NotImplementedError: unknown dataset, or use_scaling_aug outside 'super'.
            AssertionError: invalid noise_std / down_scale for the chosen mode.
        """
        self.mode = mode
        self.noise_std = noise_std
        self.down_scale = down_scale
        self.dataset_name = dataset_name.lower()
        self.root = root
        self.split = split
        self.random_seed = random_seed
        self.use_augmentation = use_augmentation
        self.aug_num = aug_num
        self.use_scaling_aug = use_scaling_aug
        self.scaling_alpha_range = scaling_alpha_range
        self.scaling_bias_range = scaling_bias_range
        self.scaling_prob = scaling_prob
        self.scaling_per_channel = scaling_per_channel

        if split not in ['train', 'eval', 'test']:
            raise ValueError(f"split must be one of ['train', 'eval', 'test'], got {split}")

        if self.dataset_name not in ['mnist', 'cifar100', 'bsds500', 'imagenet32']:
            raise NotImplementedError(f"Dataset currectly only support for MNIST, CIFAR-100, BSDS500, and ImageNet-32, got {dataset_name}")

        # NOTE: these checks use `assert`, so they are skipped under `python -O`.
        if self.mode == 'denoising':
            assert self.noise_std != 0.0, "noise_std must be non-zero for denoising mode"
            if self.use_scaling_aug:
                raise NotImplementedError("use_scaling_aug is only supported in 'super' mode (disabled for denoising).")
        elif self.mode == 'super':
            # Check the type first, so a non-numeric value reports the intended
            # message rather than a TypeError from the comparison.
            assert isinstance(self.down_scale, (int, float)), "down_scale must be a number"
            assert self.down_scale > 1.0, "down_scale must be > 1.0 for super_resolution mode"
        elif self.use_scaling_aug:
            raise NotImplementedError("use_scaling_aug is only supported in 'super' mode.")

        self.dataset = self._load_base_dataset(T.Compose([T.ToTensor()]))

        sample_img, _ = self.dataset[0]
        self.C, self.H, self.W = sample_img.shape
        self.input_dim = self.C * self.H * self.W

        if self.mode == 'super':
            self._setup_super_resolution()
        elif self.mode == 'denoising' and self.split in ('eval', 'test'):
            # Only eval/test noise is cached; training noise is drawn per access.
            self._setup_denoising_cache()
        elif self.mode == 'class_to_prototype':
            self._setup_prototypes()

    # ------------------------------------------------------------------ setup

    def _load_base_dataset(self, transform):
        """Build the underlying ``(image, label)`` dataset for ``self.split``."""
        if self.dataset_name in ('mnist', 'cifar100'):
            if self.dataset_name == 'mnist':
                dataset_cls = torchvision.datasets.MNIST
            else:
                dataset_cls = torchvision.datasets.CIFAR100
            if self.split == 'test':
                return dataset_cls(root=self.root, train=False, download=True, transform=transform)
            # 'train' and 'eval' are both carved out of the official training set.
            full_train_dataset = dataset_cls(root=self.root, train=True, download=True, transform=transform)
            return self._create_train_eval_split(full_train_dataset)

        if self.dataset_name == 'bsds500':
            bsds_root = os.path.join(self.root, "BSDS500", "crop", "images")
            return BSDS500Crops(root=bsds_root, split=self.split, transform=transform)

        if self.dataset_name == 'imagenet32':
            imagenet32_root = os.path.join(self.root, "ImageNet-32-source")
            if self.split == 'train':
                # ImageNet32Dataset already returns float tensors; no transform needed.
                train_folder = os.path.join(imagenet32_root, "train")
                return ImageNet32Dataset(train_folder, split="train", transform=None)
            # 'eval' and 'test' are the two halves of the official validation set.
            val_folder = os.path.join(imagenet32_root, "val")
            full_val_dataset = ImageNet32Dataset(val_folder, split="val", transform=None)
            return self._create_imagenet32_eval_test_split(full_val_dataset)

        raise ValueError(f"Unsupported dataset: {self.dataset_name}")

    def _setup_super_resolution(self):
        """Set the low-resolution geometry and prepare the HR/LR pairs.

        ImageNet-32 is downsampled online in ``__getitem__``. Every other dataset
        gets an on-disk cache of ``(hr, lr)`` tensor pairs plus an index mapping
        of ``(base_index, label, aug_type)`` tuples.
        """
        self.low_res_H = int(self.H // self.down_scale)
        self.low_res_W = int(self.W // self.down_scale)
        self.low_res_input_dim = self.C * self.low_res_H * self.low_res_W
        print(f"Super resolution: {self.H}x{self.W} -> {self.low_res_H}x{self.low_res_W} (scale: {self.down_scale})")

        if self.dataset_name == 'imagenet32':
            # Too large to cache: downsample on the fly instead.
            print("Using online downsampling for ImageNet-32 (no caching)")
            self.use_online_downsampling = True
            return

        self.use_online_downsampling = False
        aug_suffix = f"_with_aug_{self.aug_num}" if self.use_augmentation else ""
        cache_dir = os.path.join(self.root, "downsampled_cache",
                                 f"{self.dataset_name}_scale{self.down_scale}_{self.split}{aug_suffix}")
        os.makedirs(cache_dir, exist_ok=True)
        self.cache_file = os.path.join(cache_dir, "downsampled_data.pkl")
        self.index_file = os.path.join(cache_dir, "index_mapping.pkl")

        if self._cache_exists():
            print("Loading cached downsampled images...")
            self.downsampled_images, self.index_mapping = self._load_cache()
            print(f"Loaded {len(self.downsampled_images)} cached downsampled images")
            expected_shape = (self.C, self.low_res_H, self.low_res_W)
            if not self.downsampled_images:
                return
            cached_shape = tuple(self.downsampled_images[0][1].shape)
            if cached_shape == expected_shape:
                return
            # A cache whose LR size was computed differently would silently
            # disagree with low_res_input_dim, so rebuild it.
            print(f"Cached low-res shape {cached_shape} != expected {expected_shape}; rebuilding cache...")

        print("Creating and caching downsampled images...")
        aug_types = self._aug_types()
        self.downsampled_images = []
        self.index_mapping = []
        for i in range(len(self.dataset)):
            img, label = self.dataset[i]
            # Interleaved layout: each base image is directly followed by its flips.
            for aug_type in aug_types:
                if aug_type == 0:
                    hr_img = img
                else:
                    hr_img = self._apply_augmentation(img, aug_type=aug_type)
                self.downsampled_images.append((hr_img.clone(), self._downsample_image(hr_img)))
                self.index_mapping.append((i, label, aug_type))
        self._save_cache(self.downsampled_images, self.index_mapping)
        print(f"Cached {len(self.downsampled_images)} downsampled images to {cache_dir}")

    def _setup_denoising_cache(self):
        """Load or build the cached noisy images for the eval/test split."""
        print(f"Denoising mode with noise_std={self.noise_std}")

        cache_dir = os.path.join(self.root, "denoising_cache",
                                 f"{self.dataset_name}_noise{self.noise_std}_{self.split}")
        os.makedirs(cache_dir, exist_ok=True)
        self.cache_file = os.path.join(cache_dir, "noisy_data.pkl")
        self.index_file = os.path.join(cache_dir, "index_mapping.pkl")

        if self._cache_exists():
            print("Loading cached noisy images...")
            self.noisy_images, self.index_mapping = self._load_cache()
            print(f"Loaded {len(self.noisy_images)} cached noisy images")
            return

        print("Creating and caching noisy images...")
        # Dedicated seeded generator: a freshly built eval/test set is reproducible
        # and independent of the global RNG. The per-split offset keeps eval and
        # test from receiving identical noise tensors.
        split_offset = 1 if self.split == 'eval' else 2
        generator = torch.Generator().manual_seed(self.random_seed + split_offset)
        self.noisy_images = []
        self.index_mapping = []
        for i in range(len(self.dataset)):
            img, label = self.dataset[i]
            self.noisy_images.append(self._add_noise_to_image(img, generator=generator))
            self.index_mapping.append((i, label, 0))
        self._save_cache(self.noisy_images, self.index_mapping)
        print(f"Cached {len(self.noisy_images)} noisy images to {cache_dir}")

    def _setup_prototypes(self, num_classes=10):
        """Store the first image of each class ``0..num_classes-1`` (MNIST only)."""
        self.prototypes = {}
        if self.dataset_name != 'mnist':
            raise NotImplementedError()
        # Single pass over the split; stop as soon as every class has a prototype.
        for img, lbl in self.dataset:
            if 0 <= lbl < num_classes and lbl not in self.prototypes:
                self.prototypes[lbl] = img.clone()
                if len(self.prototypes) == num_classes:
                    break

    # ------------------------------------------------------------ cache files

    def _cache_exists(self):
        return os.path.exists(self.cache_file) and os.path.exists(self.index_file)

    def _load_cache(self):
        """Return ``(payload, index_mapping)`` from the two cache pickles.

        NOTE: pickle can execute code on load; only point ``root`` at trusted caches.
        """
        with open(self.cache_file, 'rb') as f:
            payload = pickle.load(f)
        with open(self.index_file, 'rb') as f:
            index_mapping = pickle.load(f)
        return payload, index_mapping

    def _save_cache(self, payload, index_mapping):
        """Write both cache pickles, each atomically.

        Each file is written under a temporary name and then renamed into place,
        so an interrupted run never leaves a truncated pickle that would later be
        "found" and fail to load.
        """
        for obj, path in ((payload, self.cache_file), (index_mapping, self.index_file)):
            tmp_path = path + ".tmp"
            with open(tmp_path, 'wb') as f:
                pickle.dump(obj, f)
            os.replace(tmp_path, path)

    # ----------------------------------------------------------------- splits

    def _create_train_eval_split(self, full_dataset):
        """Split an official training set into 'train' / 'eval' with a seeded shuffle."""
        if self.dataset_name not in self._TRAIN_SPLIT_SIZES:
            raise ValueError(f"Unsupported dataset: {self.dataset_name}")
        n_train = self._TRAIN_SPLIT_SIZES[self.dataset_name]
        indices = np.random.RandomState(self.random_seed).permutation(len(full_dataset))

        if self.split == 'train':
            selected_indices = indices[:n_train]
        elif self.split == 'eval':
            selected_indices = indices[n_train:]
        else:
            raise ValueError(f"_create_train_eval_split expects split 'train' or 'eval', got {self.split}")
        print(f"Created {self.split} split with {len(selected_indices)} samples")
        return Subset(full_dataset, selected_indices)

    def _create_imagenet32_eval_test_split(self, full_val_dataset):
        """Split the ImageNet-32 validation set 50/50 into 'eval' / 'test'.

        The shuffle is seeded with ``self.random_seed``, so the split is fixed for
        a given seed.
        """
        total_size = len(full_val_dataset)
        indices = np.random.RandomState(self.random_seed).permutation(total_size)
        half_size = total_size // 2

        if self.split == 'eval':
            selected_indices = indices[:half_size]
        elif self.split == 'test':
            selected_indices = indices[half_size:]
        else:
            raise ValueError(f"ImageNet-32 val can only be split into 'eval'/'test', got {self.split}")
        print(f"Created ImageNet-32 {self.split} split with {len(selected_indices)} samples")
        return Subset(full_val_dataset, selected_indices)

    # -------------------------------------------------------- image transforms

    def _downsample_image(self, img):
        """Bicubically downsample a (C, H, W) image by ``self.down_scale``.

        The output size is ``int(H // down_scale) x int(W // down_scale)``, the
        same formula used for ``low_res_H`` / ``low_res_W``. Truncating
        ``down_scale`` to an int first would give e.g. 16x16 instead of 12x12 for
        a 32x32 image at scale 2.5. The output is not clamped; bicubic kernels can
        overshoot the input range slightly.
        """
        _, H, W = img.shape
        target_H = int(H // self.down_scale)
        target_W = int(W // self.down_scale)
        downsampled = F.interpolate(
            img.unsqueeze(0),  # interpolate expects a batch dimension: (1, C, H, W)
            size=(target_H, target_W),
            mode='bicubic',
            align_corners=False,
            antialias=False,
        )
        return downsampled.squeeze(0)

    def _add_noise_to_image(self, img, generator=None):
        """Return ``clamp(img + N(0, noise_std^2), 0, 1)``.

        Args:
            img: image tensor.
            generator: optional ``torch.Generator`` for reproducible noise. If
                omitted, the global RNG is used.
        """
        if generator is None:
            noise = torch.randn_like(img)
        else:
            noise = torch.randn(img.shape, generator=generator, dtype=img.dtype, device=img.device)
        return torch.clamp(img + noise * self.noise_std, 0., 1.)

    def _apply_augmentation(self, img, aug_type=None):
        """Flip a (C, H, W) image: 1 = horizontal, 2 = vertical, 3 = both."""
        if aug_type == 1:
            return torch.flip(img, dims=[2])  # horizontal (width axis)
        elif aug_type == 2:
            return torch.flip(img, dims=[1])  # vertical (height axis)
        elif aug_type == 3:
            return torch.flip(img, dims=[1, 2])  # horizontal + vertical
        else:
            raise ValueError(f"Unsupported aug_type: {aug_type}, must be one of [1, 2, 3]")

    def _apply_paired_affine_intensity(self, hr_img, lr_img):
        """Apply one random intensity map ``clamp(alpha * x + bias, 0, 1)`` to an HR/LR pair.

        With probability ``1 - scaling_prob``, the pair is returned unchanged.
        ``alpha`` and ``bias`` are drawn uniformly from ``scaling_alpha_range`` and
        ``scaling_bias_range``. A single value is shared by all channels, unless
        ``scaling_per_channel`` is set, in which case one value is drawn per
        channel. Both images are assumed to be (C, H, W) with the same C.
        """
        if self.scaling_prob < 1.0 and torch.rand(1).item() >= self.scaling_prob:
            return hr_img, lr_img

        a_low, a_high = self.scaling_alpha_range
        b_low, b_high = self.scaling_bias_range

        if self.scaling_per_channel:
            shape = (hr_img.size(0), 1, 1)  # broadcasts one value per channel
        else:
            shape = (1,)  # one value shared by every pixel
        alpha = torch.empty(shape, device=hr_img.device).uniform_(a_low, a_high)
        bias = torch.empty(shape, device=hr_img.device).uniform_(b_low, b_high)

        def adjust(x):
            return torch.clamp(alpha * x + bias, 0., 1.)

        return adjust(hr_img), adjust(lr_img)

    # --------------------------------------------------------------- indexing

    def _aug_types(self):
        """Augmentation ids in index order (``(0,)`` when augmentation is off)."""
        if not self.use_augmentation:
            return (0,)
        aug_types = self._AUG_TYPES.get(self.aug_num)
        if aug_types is None:
            raise NotImplementedError(f"Unsupported aug_num: {self.aug_num}, only support 2 or 4")
        return aug_types

    def _get_online_item(self, idx):
        """Map a flat index to ``(image, label, base_index, aug_type)`` for online augmentation.

        With augmentation, the index space is laid out block-wise: indices
        ``[k*n, (k+1)*n)`` hold augmentation ``aug_types[k]`` of base image
        ``idx - k*n``, where ``n = len(self.dataset)``.
        """
        if not self.use_augmentation:
            img, label = self.dataset[idx]
            return img, label, idx, 0

        aug_types = self._aug_types()
        n = len(self.dataset)
        total = n * len(aug_types)
        if idx < 0:
            idx += total  # sequence-style negative indexing
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} out of range for dataset of size {total}")
        block, original_idx = divmod(idx, n)
        aug_type = aug_types[block]
        img, label = self.dataset[original_idx]
        if aug_type != 0:
            img = self._apply_augmentation(img, aug_type=aug_type)
        return img, label, original_idx, aug_type

    def __len__(self):
        if self.mode == 'super' and not getattr(self, 'use_online_downsampling', False):
            return len(self.downsampled_images)
        if self.mode == 'denoising' and self.split in ('eval', 'test'):
            return len(self.noisy_images)
        # Online augmentation (super on ImageNet-32, denoising train) stores
        # aug_num copies of the base dataset back to back.
        if self.mode in ('super', 'denoising') and self.use_augmentation:
            return len(self.dataset) * self.aug_num
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(input_flat, target_flat, label)`` for sample ``idx``."""
        if self.mode == 'super':
            if getattr(self, 'use_online_downsampling', False):
                img, label, _, _ = self._get_online_item(idx)
                target_img = img.clone()                 # HR
                input_img = self._downsample_image(img)  # LR
            else:
                hr_img, lr_img = self.downsampled_images[idx]
                _, label, _ = self.index_mapping[idx]
                # Clone so that later in-place edits never touch the cache.
                target_img = hr_img.clone()
                input_img = lr_img.clone()

            # Intensity augmentation is applied online and only for training.
            if self.split == 'train' and self.use_scaling_aug:
                target_img, input_img = self._apply_paired_affine_intensity(target_img, input_img)

        elif self.mode == 'denoising':
            if self.split in ('eval', 'test'):
                original_idx, label, _ = self.index_mapping[idx]
                img, _ = self.dataset[original_idx]
                input_img = self.noisy_images[idx].clone()  # do not hand out a view of the cache
            else:
                img, label, _, _ = self._get_online_item(idx)
                input_img = self._add_noise_to_image(img)  # fresh noise per access
            target_img = img.clone()

        elif self.mode == 'class_to_prototype':
            img, label = self.dataset[idx]
            input_img = img.clone()
            target_img = self.prototypes[label].clone()

        else:
            raise ValueError(f"Unsupported mode: {self.mode}")

        return input_img.view(-1), target_img.view(-1), label