from pathlib import Path
from typing import Callable, Optional, Sequence, Tuple

import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

from robustbench.data import CORRUPTIONS, PREPROCESSINGS
import random
from PIL import Image
import numpy as np

def load_tiny_imagenet_train(
    n_examples: Optional[int] = None,
    data_dir: str = './data/dir',
    size: int = 64,
    shuffle: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a single batch from the Tiny-ImageNet training split.

    Images are read with ``ImageFolder`` from
    ``<data_dir>/tiny-imagenet-200/train``, resized to ``size x size`` and
    converted to tensors; labels are the ImageFolder class indices.

    Args:
        n_examples: number of examples to return. ``None`` returns every
            image in the training folder. Values larger than the folder
            are capped at the folder size (one DataLoader batch).
        data_dir: root directory containing ``tiny-imagenet-200``.
        size: side length the images are resized to.
        shuffle: shuffle the samples before taking the batch.

    Returns:
        ``(x, y)``: stacked image tensor and class-index tensor.
    """
    prepr = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])
    train_dir = Path(data_dir) / 'tiny-imagenet-200' / 'train'
    tiny_imagenet_train = ImageFolder(train_dir, prepr)

    # Default to the actual folder size rather than a hard-coded count so
    # partial or custom copies of the split are handled as well.
    if n_examples is None:
        n_examples = len(tiny_imagenet_train)

    loader = data.DataLoader(tiny_imagenet_train,
                             batch_size=n_examples,
                             shuffle=shuffle,
                             num_workers=0)

    x_train, y_train = next(iter(loader))
    return x_train, y_train

def load_tiny_imagenet(
    n_examples: Optional[int] = None,
    data_dir: str = '/data/dir',
    size: int = 64,
    shuffle: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a single batch from ``<data_dir>/Tiny-ImageNet-C/val``.

    Images are read with ``ImageFolder``, resized to ``size x size`` and
    converted to tensors; labels are the ImageFolder class indices. If the
    folder holds fewer than ``n_examples`` images it is repeated.

    Args:
        n_examples: number of examples to return; ``None`` means 10000.
        data_dir: root directory containing ``Tiny-ImageNet-C``.
        size: side length the images are resized to.
        shuffle: shuffle the (possibly repeated) samples before batching.

    Returns:
        ``(x, y)`` tensors with ``n_examples`` rows.
    """
    prepr = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])
    val_dir = Path(data_dir) / 'Tiny-ImageNet-C' / 'val'
    tiny_imagenet = ImageFolder(val_dir, prepr)

    # Resolve the default before n_examples is used in arithmetic below.
    if n_examples is None:
        n_examples = 10000

    # Repeat the folder just enough to cover n_examples (ceil division).
    # An extra, unneeded copy would let a shuffled batch contain duplicates
    # while dropping other images even when the folder alone suffices.
    n_images = len(tiny_imagenet)
    repeats = max(1, (n_examples + n_images - 1) // n_images)
    repeated_tiny_imagenet = data.ConcatDataset([tiny_imagenet] * repeats)

    loader = data.DataLoader(repeated_tiny_imagenet,
                             batch_size=n_examples,
                             shuffle=shuffle,
                             num_workers=0)

    x_test, y_test = next(iter(loader))
    return x_test, y_test

def load_tiny_imagenet_c(
    n_examples: Optional[int] = 10000,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 64,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load one batch of Tiny-ImageNet-C for a single corruption/severity.

    Images are read with ``ImageFolder`` from
    ``<data_dir>/Tiny-ImageNet-C/<corruption>/<severity>``, resized to
    ``size x size`` and converted to tensors; labels are the ImageFolder
    class indices. If the folder holds fewer than ``n_examples`` images it
    is repeated.

    Args:
        n_examples: number of examples to return (at most 10000); ``None``
            means 10000.
        severity: corruption severity, used as the sub-directory name.
        data_dir: root directory containing ``Tiny-ImageNet-C``.
        size: side length the images are resized to.
        shuffle: shuffle the (possibly repeated) samples before batching.
        corruptions: must contain exactly one corruption name.

    Returns:
        ``(x, y)`` tensors with ``n_examples`` rows.

    Raises:
        ValueError: if ``n_examples`` exceeds 10000.
        AssertionError: if ``corruptions`` does not have exactly one entry.
    """
    # Resolve the Optional default before comparing against the limit.
    if n_examples is None:
        n_examples = 10000
    if n_examples > 10000:
        raise ValueError(
            'The evaluation is currently possible on at most 10000 points.')

    assert len(
        corruptions
    ) == 1, "so far only one corruption is supported (that's how this function is called in eval.py)"
    # TODO: generalize this (although this would probably require writing a function similar to `load_corruptions_cifar`
    #  or alternatively creating yet another CustomImageFolder class that fetches images from multiple corruption types
    #  at once -- perhaps this is a cleaner solution)

    prepr = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])

    data_folder_path = Path(data_dir) / 'Tiny-ImageNet-C' / corruptions[0] / str(severity)
    tiny_imagenet_c = ImageFolder(data_folder_path, prepr)

    # Repeat the folder just enough to cover n_examples (ceil division).
    # An extra, unneeded copy would let a shuffled batch contain duplicates
    # while dropping other images even when the folder alone suffices.
    n_images = len(tiny_imagenet_c)
    repeats = max(1, (n_examples + n_images - 1) // n_images)
    repeated_tiny_imagenet_c = data.ConcatDataset([tiny_imagenet_c] * repeats)

    test_loader = data.DataLoader(repeated_tiny_imagenet_c,
                                  batch_size=n_examples,
                                  shuffle=shuffle,
                                  num_workers=2)

    x_test, y_test = next(iter(test_loader))
    return x_test, y_test

def load_imagenet_o(
    n_examples: Optional[int] = None,
    batch_size: int = 100,
    data_dir: str = '/data/dir',
    shuffle: bool = False,
    size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a single batch from ``<data_dir>/ImageNet-O-64x64-C/val``.

    Images are read with ``ImageFolder``, resized to ``size x size`` and
    converted to tensors; labels are the ImageFolder class indices. If the
    folder holds fewer than ``n_examples`` images it is repeated.

    Args:
        n_examples: number of examples to return; ``None`` returns every
            image in the folder once.
        batch_size: unused; the whole result is produced as one batch of
            ``n_examples``. Kept for call-site compatibility.
        data_dir: root directory containing ``ImageNet-O-64x64-C``.
        shuffle: shuffle the (possibly repeated) samples before batching.
        size: side length the images are resized to.

    Returns:
        ``(x, y)`` tensors with ``n_examples`` rows.
    """
    prepr = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])
    val_dir = Path(data_dir) / 'ImageNet-O-64x64-C' / 'val'
    imagenet_o = ImageFolder(val_dir, prepr)

    n_images = len(imagenet_o)
    # Resolve the default before n_examples is used in arithmetic below.
    if n_examples is None:
        n_examples = n_images

    # Repeat the folder just enough to cover n_examples (ceil division).
    # An extra, unneeded copy would let a shuffled batch contain duplicates
    # while dropping other images even when the folder alone suffices.
    repeats = max(1, (n_examples + n_images - 1) // n_images)
    repeated_imagenet_o = data.ConcatDataset([imagenet_o] * repeats)

    loader = data.DataLoader(repeated_imagenet_o,
                             batch_size=n_examples,
                             shuffle=shuffle,
                             num_workers=0)

    x_test, y_test = next(iter(loader))
    return x_test, y_test

def load_imagenet_o_c(
    n_examples: Optional[int] = 10000,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 64,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load one batch of ImageNet-O-64x64-C for a single corruption/severity.

    Images are read with ``ImageFolder`` from
    ``<data_dir>/ImageNet-O-64x64-C/<corruption>/<severity>``, resized to
    ``size x size`` and converted to tensors; labels are the ImageFolder
    class indices. If the folder holds fewer than ``n_examples`` images it
    is repeated.

    Args:
        n_examples: number of examples to return (at most 10000); ``None``
            means 10000.
        severity: corruption severity, used as the sub-directory name.
        data_dir: root directory containing ``ImageNet-O-64x64-C``.
        size: side length the images are resized to.
        shuffle: shuffle the (possibly repeated) samples before batching.
        corruptions: must contain exactly one corruption name.

    Returns:
        ``(x, y)`` tensors with ``n_examples`` rows.

    Raises:
        ValueError: if ``n_examples`` exceeds 10000.
        AssertionError: if ``corruptions`` does not have exactly one entry.
    """
    # Resolve the Optional default before comparing against the limit.
    if n_examples is None:
        n_examples = 10000
    if n_examples > 10000:
        raise ValueError(
            'The evaluation is currently possible on at most 10000 points.')

    assert len(
        corruptions
    ) == 1, "so far only one corruption is supported (that's how this function is called in eval.py)"
    # TODO: generalize this (although this would probably require writing a function similar to `load_corruptions_cifar`
    #  or alternatively creating yet another CustomImageFolder class that fetches images from multiple corruption types
    #  at once -- perhaps this is a cleaner solution)

    prepr = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])

    data_folder_path = Path(data_dir) / 'ImageNet-O-64x64-C' / corruptions[0] / str(severity)
    imagenet_o_64x64_c = ImageFolder(data_folder_path, prepr)

    # Repeat the folder just enough to cover n_examples (ceil division).
    # An extra, unneeded copy would let a shuffled batch contain duplicates
    # while dropping other images even when the folder alone suffices.
    n_images = len(imagenet_o_64x64_c)
    repeats = max(1, (n_examples + n_images - 1) // n_images)
    repeated_imagenet_o_64x64_c = data.ConcatDataset([imagenet_o_64x64_c] * repeats)

    test_loader = data.DataLoader(repeated_imagenet_o_64x64_c,
                                  batch_size=n_examples,
                                  shuffle=shuffle,
                                  num_workers=2)

    x_test, y_test = next(iter(test_loader))
    return x_test, y_test

def load_Places(
    n_examples: int,
    data_dir: str = '/data/dir',
    shuffle: bool = False,
    size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load Places original dataset

    Reads image files directly from ``<data_dir>/Places/images`` (a flat
    directory, no class sub-folders), resizes them to ``size x size`` and
    converts them to tensors. All labels are -1 (OOD). If fewer than
    ``n_examples`` images can be read, the loaded images are cycled to
    reach ``n_examples``.

    Args:
        n_examples: number of examples to return.
        data_dir: root directory containing ``Places/images``.
        shuffle: shuffle the file list (Python ``random``) before loading.
        size: side length the images are resized to.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(n_examples, 3, size, size)`` and
        ``y`` is a ``torch.long`` tensor filled with -1.

    Raises:
        FileNotFoundError: if no image file is found or none can be decoded.
    """
    image_dir = Path(data_dir) / 'Places/images'

    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor()
    ])

    # Match extensions case-insensitively in a single directory pass.
    # Globbing both '*.jpg' and '*.JPG' lists every file twice where
    # Path.glob is case-insensitive (Windows). Sorting makes the unshuffled
    # order independent of the filesystem's listing order.
    suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    image_files = []
    if image_dir.is_dir():
        image_files = sorted(
            p for p in image_dir.iterdir()
            if p.is_file() and p.suffix.lower() in suffixes
        )

    if not image_files:
        raise FileNotFoundError(f"No image files found in {image_dir}")

    if shuffle:
        random.shuffle(image_files)

    # Keep reading past unreadable files until n_examples images are loaded.
    images = []
    for img_path in image_files:
        if len(images) >= n_examples:
            break
        try:
            # The context manager releases the file handle PIL opens.
            with Image.open(img_path) as img:
                images.append(transform(img.convert('RGB')))
        except Exception as e:  # best effort: skip corrupt/unsupported files
            print(f"Error loading {img_path}: {e}")

    if not images:
        raise FileNotFoundError(f"No valid images found in {image_dir}")

    # Cycle the loaded images if fewer than n_examples were available; trim
    # before stacking so no unused tail is allocated.
    repeats = n_examples // len(images) + 1
    images = (images * repeats)[:n_examples]

    x_test = torch.stack(images)
    # all labels are -1 (OOD)
    y_test = torch.full((x_test.size(0),), -1, dtype=torch.long)
    return x_test, y_test

def load_Places_c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 64,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load corrupted OOD dataset by directly loading image files

    Images are read from
    ``<data_dir>/Places/images-C_{size}/<corruption>/<severity>`` and are
    NOT resized: ``size`` only selects the directory. The files are
    presumably pre-rendered at that resolution (TODO confirm); images of
    differing sizes would make ``torch.stack`` fail. Up to
    ``ceil(n_examples / len(corruptions))`` images are read per corruption,
    and the pool is cycled if it holds fewer than ``n_examples`` images.

    Args:
        n_examples: total number of examples to return
        severity: corruption severity (1-5)
        data_dir: root directory containing ``Places/images-C_{size}``
        size: image size; selects the ``images-C_{size}`` directory
        shuffle: shuffle each corruption's file list and the final order
        corruptions: list of corruption types to include (non-empty)

    Returns:
        tuple of torch tensors (images, labels); labels are all -1 (OOD)

    Raises:
        ValueError: if ``corruptions`` is empty.
        FileNotFoundError: if no image could be loaded for any corruption.
    """
    assert 1 <= severity <= 5
    if not corruptions:
        raise ValueError("corruptions must contain at least one corruption type")

    root = Path(data_dir) / f'Places/images-C_{size}'
    to_tensor = transforms.ToTensor()
    suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    per_corruption = int(np.ceil(n_examples / len(corruptions)))

    images = []
    for corruption in corruptions:
        corruption_dir = root / corruption / str(severity)

        if not corruption_dir.is_dir():
            print(f"Warning: {corruption_dir} doesn't exist. Skipping.")
            continue

        # Case-insensitive suffix match in one pass: globbing '*.jpg' and
        # '*.JPG' separately lists every file twice where Path.glob is
        # case-insensitive (Windows). Sorting gives a deterministic order.
        image_files = sorted(
            p for p in corruption_dir.iterdir()
            if p.is_file() and p.suffix.lower() in suffixes
        )
        if not image_files:
            print(f"Warning: No image files found in {corruption_dir}. Skipping.")
            continue

        if shuffle:
            random.shuffle(image_files)

        # Keep reading past unreadable files until the quota is reached.
        n_loaded = 0
        for img_path in image_files:
            if n_loaded >= per_corruption:
                break
            try:
                # The context manager releases the file handle PIL opens.
                with Image.open(img_path) as img:
                    images.append(to_tensor(img.convert('RGB')))
            except Exception as e:  # best effort: skip corrupt/unsupported files
                print(f"Error loading {img_path}: {e}")
                continue
            n_loaded += 1

    if not images:
        raise FileNotFoundError(f"No valid images found in {root}")

    if shuffle:
        # Permute the unique images *before* cycling them. Permuting a
        # cycled list and truncating it would put duplicates in the batch
        # while dropping other images, even when enough images were loaded.
        order = torch.randperm(len(images)).tolist()
        images = [images[i] for i in order]

    # Cycle the pool if it is smaller than n_examples; trim before stacking.
    repeats = n_examples // len(images) + 1
    images = (images * repeats)[:n_examples]

    x_test = torch.stack(images)
    # all labels are -1 (OOD)
    y_test = torch.full((x_test.size(0),), -1, dtype=torch.long)
    return x_test, y_test

def load_iNaturalist(
    n_examples: int,
    data_dir: str = '/data/dir',
    shuffle: bool = False,
    size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load iNaturalist original dataset

    Reads image files directly from ``<data_dir>/iNaturalist/images`` (a
    flat directory, no class sub-folders), resizes them to ``size x size``
    and converts them to tensors. All labels are -1 (OOD). If fewer than
    ``n_examples`` images can be read, the loaded images are cycled to
    reach ``n_examples``.

    Args:
        n_examples: number of examples to return.
        data_dir: root directory containing ``iNaturalist/images``.
        shuffle: shuffle the file list (Python ``random``) before loading.
        size: side length the images are resized to.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(n_examples, 3, size, size)`` and
        ``y`` is a ``torch.long`` tensor filled with -1.

    Raises:
        FileNotFoundError: if no image file is found or none can be decoded.
    """
    image_dir = Path(data_dir) / 'iNaturalist/images'

    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor()
    ])

    # Match extensions case-insensitively in a single directory pass.
    # Globbing both '*.jpg' and '*.JPG' lists every file twice where
    # Path.glob is case-insensitive (Windows). Sorting makes the unshuffled
    # order independent of the filesystem's listing order.
    suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    image_files = []
    if image_dir.is_dir():
        image_files = sorted(
            p for p in image_dir.iterdir()
            if p.is_file() and p.suffix.lower() in suffixes
        )

    if not image_files:
        raise FileNotFoundError(f"No image files found in {image_dir}")

    if shuffle:
        random.shuffle(image_files)

    # Keep reading past unreadable files until n_examples images are loaded.
    images = []
    for img_path in image_files:
        if len(images) >= n_examples:
            break
        try:
            # The context manager releases the file handle PIL opens.
            with Image.open(img_path) as img:
                images.append(transform(img.convert('RGB')))
        except Exception as e:  # best effort: skip corrupt/unsupported files
            print(f"Error loading {img_path}: {e}")

    if not images:
        raise FileNotFoundError(f"No valid images found in {image_dir}")

    # Cycle the loaded images if fewer than n_examples were available; trim
    # before stacking so no unused tail is allocated.
    repeats = n_examples // len(images) + 1
    images = (images * repeats)[:n_examples]

    x_test = torch.stack(images)
    # all labels are -1 (OOD)
    y_test = torch.full((x_test.size(0),), -1, dtype=torch.long)
    return x_test, y_test

def load_iNaturalist_c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 64,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load corrupted OOD dataset by directly loading image files

    Images are read from
    ``<data_dir>/iNaturalist/images-C_{size}/<corruption>/<severity>`` and
    are NOT resized: ``size`` only selects the directory. The files are
    presumably pre-rendered at that resolution (TODO confirm); images of
    differing sizes would make ``torch.stack`` fail. Up to
    ``ceil(n_examples / len(corruptions))`` images are read per corruption,
    and the pool is cycled if it holds fewer than ``n_examples`` images.

    Args:
        n_examples: total number of examples to return
        severity: corruption severity (1-5)
        data_dir: root directory containing ``iNaturalist/images-C_{size}``
        size: image size; selects the ``images-C_{size}`` directory
        shuffle: shuffle each corruption's file list and the final order
        corruptions: list of corruption types to include (non-empty)

    Returns:
        tuple of torch tensors (images, labels); labels are all -1 (OOD)

    Raises:
        ValueError: if ``corruptions`` is empty.
        FileNotFoundError: if no image could be loaded for any corruption.
    """
    assert 1 <= severity <= 5
    if not corruptions:
        raise ValueError("corruptions must contain at least one corruption type")

    root = Path(data_dir) / f'iNaturalist/images-C_{size}'
    to_tensor = transforms.ToTensor()
    suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    per_corruption = int(np.ceil(n_examples / len(corruptions)))

    images = []
    for corruption in corruptions:
        corruption_dir = root / corruption / str(severity)

        if not corruption_dir.is_dir():
            print(f"Warning: {corruption_dir} doesn't exist. Skipping.")
            continue

        # Case-insensitive suffix match in one pass: globbing '*.jpg' and
        # '*.JPG' separately lists every file twice where Path.glob is
        # case-insensitive (Windows). Sorting gives a deterministic order.
        image_files = sorted(
            p for p in corruption_dir.iterdir()
            if p.is_file() and p.suffix.lower() in suffixes
        )
        if not image_files:
            print(f"Warning: No image files found in {corruption_dir}. Skipping.")
            continue

        if shuffle:
            random.shuffle(image_files)

        # Keep reading past unreadable files until the quota is reached.
        n_loaded = 0
        for img_path in image_files:
            if n_loaded >= per_corruption:
                break
            try:
                # The context manager releases the file handle PIL opens.
                with Image.open(img_path) as img:
                    images.append(to_tensor(img.convert('RGB')))
            except Exception as e:  # best effort: skip corrupt/unsupported files
                print(f"Error loading {img_path}: {e}")
                continue
            n_loaded += 1

    if not images:
        raise FileNotFoundError(f"No valid images found in {root}")

    if shuffle:
        # Permute the unique images *before* cycling them. Permuting a
        # cycled list and truncating it would put duplicates in the batch
        # while dropping other images, even when enough images were loaded.
        order = torch.randperm(len(images)).tolist()
        images = [images[i] for i in order]

    # Cycle the pool if it is smaller than n_examples; trim before stacking.
    repeats = n_examples // len(images) + 1
    images = (images * repeats)[:n_examples]

    x_test = torch.stack(images)
    # all labels are -1 (OOD)
    y_test = torch.full((x_test.size(0),), -1, dtype=torch.long)
    return x_test, y_test

def load_SUN(
    n_examples: int,
    data_dir: str = '/data/dir',
    shuffle: bool = False,
    size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load SUN original dataset

    Reads image files directly from ``<data_dir>/SUN/images`` (a flat
    directory, no class sub-folders), resizes them to ``size x size`` and
    converts them to tensors. All labels are -1 (OOD). If fewer than
    ``n_examples`` images can be read, the loaded images are cycled to
    reach ``n_examples``.

    Args:
        n_examples: number of examples to return.
        data_dir: root directory containing ``SUN/images``.
        shuffle: shuffle the file list (Python ``random``) before loading.
        size: side length the images are resized to.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(n_examples, 3, size, size)`` and
        ``y`` is a ``torch.long`` tensor filled with -1.

    Raises:
        FileNotFoundError: if no image file is found or none can be decoded.
    """
    image_dir = Path(data_dir) / 'SUN/images'

    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor()
    ])

    # Match extensions case-insensitively in a single directory pass.
    # Globbing both '*.jpg' and '*.JPG' lists every file twice where
    # Path.glob is case-insensitive (Windows). Sorting makes the unshuffled
    # order independent of the filesystem's listing order.
    suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    image_files = []
    if image_dir.is_dir():
        image_files = sorted(
            p for p in image_dir.iterdir()
            if p.is_file() and p.suffix.lower() in suffixes
        )

    if not image_files:
        raise FileNotFoundError(f"No image files found in {image_dir}")

    if shuffle:
        random.shuffle(image_files)

    # Keep reading past unreadable files until n_examples images are loaded.
    images = []
    for img_path in image_files:
        if len(images) >= n_examples:
            break
        try:
            # The context manager releases the file handle PIL opens.
            with Image.open(img_path) as img:
                images.append(transform(img.convert('RGB')))
        except Exception as e:  # best effort: skip corrupt/unsupported files
            print(f"Error loading {img_path}: {e}")

    if not images:
        raise FileNotFoundError(f"No valid images found in {image_dir}")

    # Cycle the loaded images if fewer than n_examples were available; trim
    # before stacking so no unused tail is allocated.
    repeats = n_examples // len(images) + 1
    images = (images * repeats)[:n_examples]

    x_test = torch.stack(images)
    # all labels are -1 (OOD)
    y_test = torch.full((x_test.size(0),), -1, dtype=torch.long)
    return x_test, y_test

def load_SUN_c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = '/data/dir',
    size: int = 64,
    shuffle: bool = False,
    corruptions: Sequence[str] = CORRUPTIONS,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load corrupted OOD dataset by directly loading image files

    Images are read from
    ``<data_dir>/SUN/images-C_{size}/<corruption>/<severity>`` and are NOT
    resized: ``size`` only selects the directory. The files are presumably
    pre-rendered at that resolution (TODO confirm); images of differing
    sizes would make ``torch.stack`` fail. Up to
    ``ceil(n_examples / len(corruptions))`` images are read per corruption,
    and the pool is cycled if it holds fewer than ``n_examples`` images.

    Args:
        n_examples: total number of examples to return
        severity: corruption severity (1-5)
        data_dir: root directory containing ``SUN/images-C_{size}``
        size: image size; selects the ``images-C_{size}`` directory
        shuffle: shuffle each corruption's file list and the final order
        corruptions: list of corruption types to include (non-empty)

    Returns:
        tuple of torch tensors (images, labels); labels are all -1 (OOD)

    Raises:
        ValueError: if ``corruptions`` is empty.
        FileNotFoundError: if no image could be loaded for any corruption.
    """
    assert 1 <= severity <= 5
    if not corruptions:
        raise ValueError("corruptions must contain at least one corruption type")

    root = Path(data_dir) / f'SUN/images-C_{size}'
    to_tensor = transforms.ToTensor()
    suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    per_corruption = int(np.ceil(n_examples / len(corruptions)))

    images = []
    for corruption in corruptions:
        corruption_dir = root / corruption / str(severity)

        if not corruption_dir.is_dir():
            print(f"Warning: {corruption_dir} doesn't exist. Skipping.")
            continue

        # Case-insensitive suffix match in one pass: globbing '*.jpg' and
        # '*.JPG' separately lists every file twice where Path.glob is
        # case-insensitive (Windows). Sorting gives a deterministic order.
        image_files = sorted(
            p for p in corruption_dir.iterdir()
            if p.is_file() and p.suffix.lower() in suffixes
        )
        if not image_files:
            print(f"Warning: No image files found in {corruption_dir}. Skipping.")
            continue

        if shuffle:
            random.shuffle(image_files)

        # Keep reading past unreadable files until the quota is reached.
        n_loaded = 0
        for img_path in image_files:
            if n_loaded >= per_corruption:
                break
            try:
                # The context manager releases the file handle PIL opens.
                with Image.open(img_path) as img:
                    images.append(to_tensor(img.convert('RGB')))
            except Exception as e:  # best effort: skip corrupt/unsupported files
                print(f"Error loading {img_path}: {e}")
                continue
            n_loaded += 1

    if not images:
        raise FileNotFoundError(f"No valid images found in {root}")

    if shuffle:
        # Permute the unique images *before* cycling them. Permuting a
        # cycled list and truncating it would put duplicates in the batch
        # while dropping other images, even when enough images were loaded.
        order = torch.randperm(len(images)).tolist()
        images = [images[i] for i in order]

    # Cycle the pool if it is smaller than n_examples; trim before stacking.
    repeats = n_examples // len(images) + 1
    images = (images * repeats)[:n_examples]

    x_test = torch.stack(images)
    # all labels are -1 (OOD)
    y_test = torch.full((x_test.size(0),), -1, dtype=torch.long)
    return x_test, y_test
