from pathlib import Path
from typing import Callable, Optional, Sequence, Tuple
import os

import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

from robustbench.data import _load_dataset, CORRUPTIONS, PREPROCESSINGS
import random

class FlatOODImageDataset(Dataset):
    """
    Dataset that loads OOD images from a single, flat folder.

    Every regular file directly inside ``directory`` whose extension
    (compared case-insensitively) is a known image extension is included,
    in sorted path order. Every item carries the label -1 (OOD).
    """
    # Lower-case extensions accepted by _get_image_files.
    _IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm'})

    def __init__(self, directory, transform=None):
        """
        Args:
            directory: folder scanned (non-recursively) for image files.
            transform: optional callable applied to each RGB PIL image.
        """
        self.directory = directory
        self.transform = transform
        self.image_files = self._get_image_files(directory)

    def _get_image_files(self, directory):
        """Return the sorted full paths of the image files directly inside ``directory``."""
        candidates = (os.path.join(directory, name) for name in os.listdir(directory))
        return sorted(
            path for path in candidates
            # isfile() skips sub-directories that merely carry an image-like suffix.
            if os.path.splitext(path)[1].lower() in self._IMAGE_EXTENSIONS
            and os.path.isfile(path)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, -1)``, where ``image`` is transformed if a transform was given."""
        img_path = self.image_files[idx]
        # The context manager releases the file handle promptly; convert()
        # returns a new image, so it remains usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # All items are out-of-distribution.
        return image, -1

def load_svhn(
    n_examples: Optional[int] = None,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load the clean SVHN test set.

    If ``<data_dir>/SVHN-C/original.npy`` exists, images are read from it and
    labels from ``labels.npy`` in the same folder. Otherwise the torchvision
    SVHN test split is downloaded into ``data_dir``. That split is loaded
    through robustbench's ``_load_dataset``.

    Args:
        n_examples: number of examples to return; ``None`` returns all of them.
        data_dir: root data directory.
        size: side length images are resized to on the torchvision path.
            NOTE(review): the ``.npy`` path does not resize.
        shuffle: randomly permute the returned examples (both paths).

    Returns:
        ``(x_test, y_test)`` tensors; ``x_test`` is float32 scaled to [0, 1].
    """
    svhn_dir = Path(data_dir) / 'SVHN-C'
    if (svhn_dir / 'original.npy').exists():
        images = np.load(svhn_dir / 'original.npy')
        labels = np.load(svhn_dir / 'labels.npy')

        # Move axis 3 to the front and axis 2 to second place. This yields
        # (N, C, H, W) if the file uses the SVHN .mat layout (H, W, C, N).
        # NOTE(review): confirm the on-disk layout.
        images = np.transpose(images, (3, 2, 0, 1))

        if n_examples is not None:
            images, labels = images[:n_examples], labels[:n_examples]

        x_test = torch.tensor(images.astype(np.float32) / 255)
        y_test = torch.tensor(labels)
    else:
        prepr = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])
        # The transform is required: without it the dataset yields PIL images,
        # and the DataLoader used by _load_dataset cannot batch them.
        dataset = datasets.SVHN(root=data_dir,
                                split='test',
                                transform=prepr,
                                download=True)
        x_test, y_test = _load_dataset(dataset, n_examples)

    if shuffle:
        rand_idx = np.random.permutation(len(x_test))
        x_test, y_test = x_test[rand_idx], y_test[rand_idx]

    return x_test, y_test

def load_svhn_c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
    _: Callable = PREPROCESSINGS[None]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load the corrupted SVHN test set from ``<data_dir>/SVHN-C``.

    Each ``<corruption>.npy`` holds the five severities back to back, with
    ``n_total_svhn`` images per severity. For every corruption, the first
    ``ceil(n_examples / len(corruptions))`` images of the requested severity
    are taken. They are paired with the first labels of ``labels.npy``.

    Args:
        n_examples: total number of examples to return.
        severity: corruption severity, 1-5.
        data_dir: root data directory.
        size: image size; only 32 is supported.
        shuffle: randomly permute the examples before trimming to n_examples.
        corruptions: corruption names, each matching a ``<name>.npy`` file.
        _: unused here; presumably kept so the signature matches other
            robustbench-style loaders.

    Returns:
        ``(x_test, y_test)``; ``x_test`` is float32 scaled to [0, 1]. The
        (0, 3, 1, 2) transpose assumes the files store (N, H, W, C).

    Raises:
        ValueError: if ``corruptions`` is empty.
    """
    assert 1 <= severity <= 5
    assert size == 32
    if not corruptions:
        raise ValueError("corruptions must contain at least one corruption name")
    n_total_svhn = 26032

    data_root_dir = Path(data_dir) / 'SVHN-C'
    labels = np.load(data_root_dir / 'labels.npy')

    # Per-corruption share, rounded up; the surplus is trimmed at the end.
    n_img = int(np.ceil(n_examples / len(corruptions)))
    sev_start, sev_stop = (severity - 1) * n_total_svhn, severity * n_total_svhn

    x_test_list, y_test_list = [], []
    for corruption in corruptions:
        # Memory-map the file so only the rows sliced out below are read from
        # disk, instead of loading all five severities for every corruption.
        images_all = np.load(data_root_dir / (corruption + '.npy'), mmap_mode='r')
        x_test_list.append(images_all[sev_start:sev_stop][:n_img])
        # The same labels are reused for every corruption.
        y_test_list.append(labels[:n_img])

    # concatenate allocates a new array, copying the sliced rows into memory.
    x_test, y_test = np.concatenate(x_test_list), np.concatenate(y_test_list)
    if shuffle:
        rand_idx = np.random.permutation(np.arange(len(x_test)))
        x_test, y_test = x_test[rand_idx], y_test[rand_idx]

    # Make it in the PyTorch format
    x_test = np.transpose(x_test, (0, 3, 1, 2))
    # Make it compatible with our models
    x_test = x_test.astype(np.float32) / 255
    # Make sure that we get exactly n_examples but not a few samples more
    x_test = torch.tensor(x_test)[:n_examples]
    y_test = torch.tensor(y_test)[:n_examples]

    return x_test, y_test

# Lower-case file suffixes treated as images by the OOD folder loaders below.
_OOD_IMAGE_SUFFIXES = frozenset({'.jpg', '.jpeg', '.png', '.bmp'})

def _list_image_files(directory: Path) -> list:
    """
    Return the image files directly inside ``directory``, sorted by path.

    Suffixes are compared case-insensitively, so ``.JPG``/``.Jpg`` are found.
    No file is listed twice, whereas globbing for both ``*.jpg`` and ``*.JPG``
    double-counts on case-insensitive filesystems. Sorting makes an
    un-shuffled selection independent of directory iteration order.
    Returns an empty list if ``directory`` is not an existing directory.
    """
    if not directory.is_dir():
        return []
    return sorted(
        path for path in directory.iterdir()
        if path.suffix.lower() in _OOD_IMAGE_SUFFIXES and path.is_file()
    )

def _load_images(files, transform, limit: int) -> list:
    """
    Load images from ``files``, in order, until ``limit`` have been collected.

    Each file is converted to RGB and passed through ``transform``. A file
    that fails to open or transform is reported and skipped. The next file
    in the list is then tried in its place.
    """
    images = []
    for path in files:
        if len(images) >= limit:
            break
        try:
            # Context manager so file handles are released promptly.
            with Image.open(path) as img:
                images.append(transform(img.convert('RGB')))
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole load.
            print(f"Error loading {path}: {e}")
    return images

def _tile_ood(images: list, n_examples: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Stack ``images`` into a batch of exactly ``n_examples`` rows.

    Cycles through ``images`` when there are fewer than ``n_examples``. The
    batch is paired with labels that are all -1 (OOD).
    """
    repeats = -(-n_examples // len(images))  # ceiling division
    # Trim the (cheap) list of references before stacking, so no surplus
    # tensors are copied into the batch.
    x_test = torch.stack((images * repeats)[:n_examples])
    y_test = torch.full((x_test.size(0),), -1, dtype=torch.long)
    return x_test, y_test

def _load_flat_ood(
    image_dir: Path,
    n_examples: int,
    size: int,
    shuffle: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Shared body of the clean OOD loaders.

    Reads images from ``image_dir`` and resizes them to ``size`` x ``size``.
    Returns ``n_examples`` of them, repeating images if too few load.

    Raises:
        FileNotFoundError: if the folder holds no image files, or none load.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor()
    ])

    image_files = _list_image_files(image_dir)
    if not image_files:
        raise FileNotFoundError(f"No image files found in {image_dir}")

    if shuffle:
        random.shuffle(image_files)

    images = _load_images(image_files, transform, n_examples)
    if not images:
        raise FileNotFoundError(f"No valid images found in {image_dir}")

    return _tile_ood(images, n_examples)

def _load_corrupted_ood(
    root: Path,
    n_examples: int,
    severity: int,
    shuffle: bool,
    corruptions: Sequence[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Shared body of the corrupted OOD loaders.

    Reads images from ``root/<corruption>/<severity>/``, taking up to
    ``ceil(n_examples / len(corruptions))`` per corruption. Missing or empty
    corruption folders are reported and skipped.

    Raises:
        AssertionError: if ``severity`` is outside 1..5.
        ValueError: if ``corruptions`` is empty.
        FileNotFoundError: if no image could be loaded from any corruption folder.
    """
    assert 1 <= severity <= 5
    if not corruptions:
        raise ValueError("corruptions must contain at least one corruption name")

    # The folder name encodes the image size (images-C_<size>). The images are
    # presumably pre-rendered at that size, so they are only converted to
    # tensors, not resized.
    transform = transforms.ToTensor()
    per_corruption = -(-n_examples // len(corruptions))  # ceiling division

    all_images = []
    for corruption in corruptions:
        corruption_dir = root / corruption / str(severity)
        if not corruption_dir.exists():
            print(f"Warning: {corruption_dir} doesn't exist. Skipping.")
            continue

        image_files = _list_image_files(corruption_dir)
        if not image_files:
            print(f"Warning: No image files found in {corruption_dir}. Skipping.")
            continue

        if shuffle:
            random.shuffle(image_files)

        all_images.extend(_load_images(image_files, transform, per_corruption))

    if not all_images:
        raise FileNotFoundError(f"No valid images found in {root}")

    x_test, y_test = _tile_ood(all_images, n_examples)
    if shuffle:
        # Permute only after tiling/trimming to n_examples. That way, when images
        # must be repeated, every loaded image still appears at least once.
        indices = torch.randperm(x_test.size(0))
        x_test, y_test = x_test[indices], y_test[indices]
    return x_test, y_test

def load_Places365(
    n_examples: int,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load clean Places365 OOD images from ``<data_dir>/places365/images``.

    Images are resized to ``size`` x ``size`` and repeated if fewer than
    ``n_examples`` load. All labels are -1 (OOD).
    """
    return _load_flat_ood(Path(data_dir) / 'places365/images', n_examples, size, shuffle)

def load_Places365_c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load corrupted Places365 OOD images from
    ``<data_dir>/places365/images-C_<size>/<corruption>/<severity>/``.

    Args:
        n_examples: total number of examples to load
        severity: corruption severity (1-5)
        data_dir: root data directory
        size: image size; selects the ``images-C_<size>`` folder
        shuffle: shuffle the file selection and the returned examples
        corruptions: list of corruption types to include

    Returns:
        tuple of torch tensors (images, labels); all labels are -1 (OOD)
    """
    root = Path(data_dir) / f'places365/images-C_{size}'
    return _load_corrupted_ood(root, n_examples, severity, shuffle, corruptions)

def load_LSUN(
    n_examples: int,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load clean LSUN (resized) OOD images from ``<data_dir>/LSUN_resize/images``.

    Images are resized to ``size`` x ``size`` and repeated if fewer than
    ``n_examples`` load. All labels are -1 (OOD).
    """
    return _load_flat_ood(Path(data_dir) / 'LSUN_resize/images', n_examples, size, shuffle)

def load_LSUN_c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load corrupted LSUN OOD images from
    ``<data_dir>/LSUN_resize/images-C_<size>/<corruption>/<severity>/``.

    Args:
        n_examples: total number of examples to load
        severity: corruption severity (1-5)
        data_dir: root data directory
        size: image size; selects the ``images-C_<size>`` folder
        shuffle: shuffle the file selection and the returned examples
        corruptions: list of corruption types to include

    Returns:
        tuple of torch tensors (images, labels); all labels are -1 (OOD)
    """
    root = Path(data_dir) / f'LSUN_resize/images-C_{size}'
    return _load_corrupted_ood(root, n_examples, severity, shuffle, corruptions)

def load_iSUN(
    n_examples: int,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load clean iSUN OOD images from ``<data_dir>/iSUN/images``.

    Images are resized to ``size`` x ``size`` and repeated if fewer than
    ``n_examples`` load. All labels are -1 (OOD).
    """
    return _load_flat_ood(Path(data_dir) / 'iSUN/images', n_examples, size, shuffle)

def load_iSUN_c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load corrupted iSUN OOD images from
    ``<data_dir>/iSUN/images-C_<size>/<corruption>/<severity>/``.

    Args:
        n_examples: total number of examples to load
        severity: corruption severity (1-5)
        data_dir: root data directory
        size: image size; selects the ``images-C_<size>`` folder
        shuffle: shuffle the file selection and the returned examples
        corruptions: list of corruption types to include

    Returns:
        tuple of torch tensors (images, labels); all labels are -1 (OOD)
    """
    root = Path(data_dir) / f'iSUN/images-C_{size}'
    return _load_corrupted_ood(root, n_examples, severity, shuffle, corruptions)

def load_textures(
    n_examples: int,
    data_dir: str = '/data/dir',
    size: int = 32,
    shuffle: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load clean Textures OOD images from ``<data_dir>/textures/images``.

    Images are resized to ``size`` x ``size`` and repeated if fewer than
    ``n_examples`` load. All labels are -1 (OOD).
    """
    return _load_flat_ood(Path(data_dir) / 'textures/images', n_examples, size, shuffle)

def load_textures_c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = './data/dir',
    size: int = 32,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load corrupted Textures OOD images from
    ``<data_dir>/textures/images-C_<size>/<corruption>/<severity>/``.

    NOTE(review): the default ``data_dir`` is relative ('./data/dir'), unlike
    the absolute '/data/dir' used by the other loaders. It is kept unchanged
    for compatibility.

    Args:
        n_examples: total number of examples to load
        severity: corruption severity (1-5)
        data_dir: root data directory
        size: image size; selects the ``images-C_<size>`` folder
        shuffle: shuffle the file selection and the returned examples
        corruptions: list of corruption types to include

    Returns:
        tuple of torch tensors (images, labels); all labels are -1 (OOD)
    """
    root = Path(data_dir) / f'textures/images-C_{size}'
    return _load_corrupted_ood(root, n_examples, severity, shuffle, corruptions)