import glob
import torch
import torch.utils.data as data
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from PIL import Image
import os, sys
import pandas as pd
import lmdb
import csv
import random
import numpy as np
from PIL import Image
from torch.utils.data import Subset
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets import folder, ImageFolder
import logging

from random import triangular
from typing import Callable, Union
import os
from pathlib import Path
from typing import Callable, Dict, Optional, Sequence, Set, Tuple

from torch.utils.data import TensorDataset, DataLoader

# Raise the PIL logger's threshold to WARNING so its DEBUG/INFO records are dropped.
logging.getLogger("PIL").setLevel(logging.WARNING)

# Module-wide default device string: 'cuda' when torch reports CUDA available, else 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def _load_dataset(
    dataset: Dataset,
    n_examples: Optional[int] = None,
    batch_size: Optional[int] = 1,
    data_seed: Optional[int] = 0) -> data.DataLoader:
    
    if n_examples and n_examples > 0:
        partition_idx = np.random.RandomState(data_seed).choice(
            len(dataset), n_examples, replace=False)
        dataset = Subset(dataset, partition_idx)
    
    test_loader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=4)
    
    return test_loader
    
class Read_Dataset(Dataset):
    """Image-classification dataset laid out as ``root_dir/<mode>/<class_name>/<file>``.

    Class names are read from ``root_dir/label.txt`` (one per line; blank lines
    and duplicates ignored), sorted, and numbered ``0..K-1`` in sorted order.
    Classes with no matching directory under ``root_dir/<mode>`` are skipped.
    Every regular file inside a class directory is treated as a sample.

    Args:
        root_dir: dataset root containing ``label.txt`` and the ``<mode>`` folder.
        mode: split sub-directory name, e.g. ``'train'`` or ``'val'``.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, root_dir, mode='train', transform=None):
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        # Public attribute kept for compatibility; this class never fills it.
        self.labels = {}
        with open(os.path.join(self.root_dir, 'label.txt'), 'r') as fp:
            # An empty class name would shift every label index and make
            # class_dir resolve to root_dir/mode itself, so blank lines are dropped.
            # The set removes duplicate names, which would otherwise load a class twice.
            self.label_texts = sorted({line.strip() for line in fp if line.strip()})
        self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}
        self.samples = []
        for class_name in self.label_texts:
            class_dir = os.path.join(self.root_dir, mode, class_name)
            if not os.path.isdir(class_dir):
                continue
            label = self.label_text_to_number[class_name]
            # os.listdir order is filesystem-dependent; sorting makes sample order
            # (and any index-based subset drawn from it) reproducible.
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                if os.path.isfile(file_path):
                    self.samples.append((file_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where image is RGB (then ``transform``-ed if set)."""
        file_path, label = self.samples[idx]
        # The context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(file_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
    
def remove_prefix(s, prefix):
    """Return ``s`` with a leading ``prefix`` stripped; ``s`` is returned as-is if it lacks it."""
    return s[len(prefix):] if s.startswith(prefix) else s
    
class ImageDataset(VisionDataset):
    """
    Image-folder dataset (``root/<class_name>/<image>``) with a cached file listing.

    Modified from: https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
    Uses the cached directory listing ``<root>.txt`` (one ``relative_path;class_index``
    per line) if available rather than walking the directory. Otherwise the
    directory is walked with ``folder.make_dataset`` and a non-empty listing is
    written to ``<root>.txt`` for next time.
    NOTE(review): the cache is never invalidated; delete ``<root>.txt`` after
    adding or removing images.

     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root, loader=folder.default_loader,
                 extensions=folder.IMG_EXTENSIONS, transform=None,
                 target_transform=None, is_valid_file=None, return_path=False):
        super(ImageDataset, self).__init__(root, transform=transform,
                                           target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        cache = self.root.rstrip('/') + '.txt'
        from_cache = os.path.isfile(cache)
        if from_cache:
            print("Using directory list at: %s" % cache)
            samples = self._read_cache(cache)
        else:
            print("Walking directory: %s" % self.root)
            samples = folder.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        if len(samples) == 0:
            # extensions is None when the caller filters with is_valid_file instead.
            raise (RuntimeError(
                "Found 0 files in subfolders of: " + self.root + "\nSupported extensions are: " + ",".join(extensions or ())))

        # Persist only after the emptiness check, so an empty walk is retried
        # next time instead of being cached.
        if not from_cache:
            self._write_cache(cache, samples)

        self.loader = loader
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # Listed in the class docstring (as in ImageFolder).
        self.targets = [target for _, target in samples]
        self.return_path = return_path

    def _read_cache(self, cache):
        """Parse ``relative_path;class_index`` lines into ``(absolute_path, index)`` tuples."""
        samples = []
        with open(cache) as f:
            for line in f:
                line = line.rstrip('\n')
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                # rsplit: the path may itself contain ';', the trailing index never does.
                path, idx = line.rsplit(';', 1)
                samples.append((os.path.join(self.root, path), int(idx)))
        return samples

    def _write_cache(self, cache, samples):
        """Best-effort write of the listing to ``cache``; an OSError is reported, not raised."""
        tmp = cache + '.tmp'
        try:
            with open(tmp, 'w') as f:
                for path, label in samples:
                    f.write('%s;%d\n' % (remove_prefix(path, self.root).lstrip('/'), label))
            # Rename into place so a reader never sees a half-written listing.
            os.replace(tmp, cache)
        except OSError as err:
            print("Could not write directory list to %s: %s" % (cache, err))
            try:
                os.remove(tmp)
            except OSError:
                pass

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.
        Returns (sorted class names, {class_name: index in that order}).
        Ensures:
            No class is a subdirectory of another.
        """
        classes = sorted(d.name for d in os.scandir(dir) if d.is_dir())
        class_to_idx = {class_name: i for i, class_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """Return ``(sample, target)``, or ``(sample, target, path)`` if ``return_path``."""
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.return_path:
            return sample, target, path
        return sample, target

    def __len__(self):
        return len(self.samples)
    
def make_table(root):
    """Build the CelebA attribute table for the CelebA-HQ images in ``root/images``.

    Every entry of ``root/images`` is listed, in sorted order. Names ending in
    ``.png`` have their extension swapped to ``.jpg`` to match the keys of
    ``list_attr_celeba.txt``; other names are kept unchanged (nothing is filtered out).
    Attribute values of -1 are replaced by 0, and a ``partition`` column is
    added from ``list_eval_partition.txt``.

    Args:
        root: directory containing ``images/``, ``list_attr_celeba.txt`` and
            ``list_eval_partition.txt``.

    Returns:
        pandas.DataFrame indexed by filename, one row per entry of
        ``root/images``, attribute columns plus ``partition`` (int).

    Raises:
        KeyError: if an image has no entry in ``list_eval_partition.txt``.
    """
    filenames = sorted(os.listdir(f'{root}/images'))
    # Swap only the extension; str.replace('png', 'jpg') would also rewrite 'png'
    # occurring inside the file stem.
    celebahq = [os.path.splitext(f)[0] + '.jpg' if f.endswith('.png') else f
                for f in filenames]
    # sep=r'\s+' is the documented equivalent of the deprecated delim_whitespace=True.
    attr_gt = pd.read_csv(f'{root}/list_attr_celeba.txt',
                          skiprows=1, sep=r'\s+', index_col=0)
    attr_celebahq = attr_gt.reindex(index=celebahq).replace(-1, 0)

    # get the train/test/val partitions
    partitions = {}
    with open(f'{root}/list_eval_partition.txt') as f:
        for line in f:
            # split() tolerates repeated spaces/tabs; blank lines are skipped.
            fields = line.split()
            if not fields:
                continue
            filename, part = fields
            partitions[filename] = int(part)

    attr_celebahq['partition'] = [partitions[fname] for fname in attr_celebahq.index]
    return attr_celebahq

class CelebAHQDataset(Dataset):
    """Single-attribute classification over the CelebA-HQ images under ``root``.

    Images come from ``ImageDataset(root)``. Labels are one column of the table
    built by ``make_table(root)``, where -1 has been mapped to 0.
    NOTE(review): row ``i`` of that table is assumed to describe sample ``i`` of
    ``ImageDataset(root)``. Both are presumably sorted by filename, which only
    lines up if ``images`` is the sole sub-directory of ``root``. Confirm this.

    Args:
        partition: ``'train'``, ``'val'`` or ``'test'`` (KeyError otherwise).
        attribute: attribute column name, e.g. ``'Smiling'``.
        root: dataset root (see ``make_table``).
        transform: passed to ``ImageDataset`` and applied to each image.
    """

    def __init__(self, partition, attribute, root, transform):
        self.root = root
        self.transform = transform
        self.dset = ImageDataset(root=self.root, transform=self.transform)

        # make table
        attr_celebahq = make_table(root)

        # convert from train/val/test to partition numbers
        part_to_int = dict(train=0, val=1, test=2)

        def get_partition_indices(part):
            return np.where(attr_celebahq['partition'] == part_to_int[part])[0]

        partition_idx = get_partition_indices(partition)

        self.dset = Subset(self.dset, partition_idx)
        attr_subset = attr_celebahq.iloc[partition_idx]
        # pandas Series indexed by filename; positions align with self.dset.
        self.attr_subset = attr_subset[attribute]

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, idx):
        """Return ``(image, attribute_label, *extras)``; the folder class index is dropped."""
        item = self.dset[idx]
        # The Series is indexed by filename, so the positional lookup must be
        # explicit: pandas has deprecated the implicit positional fallback of
        # series[int].
        label = self.attr_subset.iloc[idx]
        return (item[0], label, *item[2:])
    
def shuffle_labels(label, num_classes=None):
    """Return random class labels that differ element-wise from ``label``.

    Each entry is drawn uniformly from ``[0, n)``. Where it collides with the
    true label, it is bumped to the next class mod ``n``. ``n`` is
    ``num_classes`` if given, else ``label.max() + 1``; when inferred from the
    batch, classes above the batch maximum are never produced. With ``n == 1``
    the result necessarily equals ``label``.

    Args:
        label: integer tensor of class indices.
        num_classes: optional total number of classes.

    Returns:
        int64 tensor with ``label``'s shape, on ``label``'s device.
    """
    if label.numel() == 0:
        # torch.max raises on an empty tensor; there is nothing to shuffle.
        return torch.empty_like(label, dtype=torch.long)
    n = int(num_classes) if num_classes is not None else int(torch.max(label).item()) + 1
    # Drawn with the CPU generator, then moved to label's device so the
    # comparison below never mixes devices.
    shuffled = torch.randint(0, n, label.size()).to(label.device)
    clash = label == shuffled
    shuffled[clash] = (shuffled[clash] + 1) % n
    return shuffled

def extract_subset(dataset, num_subset :int, random_subset :bool):
    """Return a ``Subset`` of ``dataset`` containing ``num_subset`` examples.

    Args:
        dataset: any sized dataset.
        num_subset: number of examples to keep; clamped to ``len(dataset)``.
        random_subset: if True, draw distinct indices with a private
            ``random.Random(0)``, which is repeatable and leaves the global
            ``random`` state untouched. If False, take the first ``num_subset``
            examples.
    """
    k = min(num_subset, len(dataset))
    if random_subset:
        # Same sequence as random.seed(0); random.sample(...), without
        # resetting the module-level RNG used by the rest of the program.
        indices = random.Random(0).sample(range(len(dataset)), k)
    else:
        indices = list(range(k))
    return Subset(dataset, indices)


def load_dataset_example(dataset: str, train: bool = False, n_examples: Optional[int] = None, batch_size: Optional[int] = None, seed: Optional[int] = 0):
    """Build a non-shuffling DataLoader for one of the supported datasets.

    Args:
        dataset: ``"Imagenette"``, ``"CIFAR10"``, ``"CIFAR100"``, ``"ImageNet"``
            or ``"CelebA-HQ"``.
        train: if True load the training split with its train transform,
            otherwise the evaluation split (``val`` or test, per dataset).
        n_examples: optional size of a seeded random subset (see ``_load_dataset``).
        batch_size: forwarded to ``_load_dataset``. The default ``None`` reaches
            ``DataLoader`` as-is, which disables automatic batching.
        seed: seed for the subset selection.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    # Common transformations
    common_train_transforms = [
        transforms.ToTensor(),
    ]

    common_test_transforms = [
        transforms.ToTensor(),
    ]

    # Dataset-specific transformations and constructors. The constructors are
    # lambdas so only the requested split of the requested dataset is built.
    dataset_params = {
        "Imagenette": {
            "train_transform": transforms.Compose([
                transforms.RandomCrop(160)
            ] + common_train_transforms),
            "test_transform": transforms.Compose([
                transforms.CenterCrop(160)
            ] + common_test_transforms),
            "train_loader": lambda: Read_Dataset(root_dir='./data/imagenette2-160', mode='train', transform=dataset_params["Imagenette"]["train_transform"]),
            "test_loader": lambda: Read_Dataset(root_dir='./data/imagenette2-160', mode='val', transform=dataset_params["Imagenette"]["test_transform"]),
        },
        "CIFAR10": {
            # RandomCrop(32) without padding is an identity on 32x32 CIFAR images;
            # padding=4 matches CIFAR100 here and create_dataloader.
            "train_transform": transforms.Compose([
                transforms.RandomCrop(32, padding=4)
            ] + common_train_transforms),
            "test_transform": transforms.Compose(common_test_transforms),
            "train_loader": lambda: torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=dataset_params["CIFAR10"]["train_transform"]),
            "test_loader": lambda: torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=dataset_params["CIFAR10"]["test_transform"]),
        },
        "CIFAR100": {
            "train_transform": transforms.Compose([
                transforms.RandomCrop(32, padding=4)
            ] + common_train_transforms),
            "test_transform": transforms.Compose(common_test_transforms),
            "train_loader": lambda: torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=dataset_params["CIFAR100"]["train_transform"]),
            "test_loader": lambda: torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=dataset_params["CIFAR100"]["test_transform"]),
        },
        "ImageNet": {
            "train_transform": transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
            ] + common_train_transforms),
            "test_transform": transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ] + common_test_transforms),
            "train_loader": lambda: torchvision.datasets.ImageNet(root='./data', split='train', transform=dataset_params["ImageNet"]["train_transform"]),
            "test_loader": lambda: torchvision.datasets.ImageNet(root='./data', split='val', transform=dataset_params["ImageNet"]["test_transform"]),
        },
        "CelebA-HQ": {
            "train_transform": transforms.Compose([
                transforms.Resize(256),
            ] + common_train_transforms),
            # Same Resize(256) as training, so evaluation images have the same size.
            "test_transform": transforms.Compose([
                transforms.Resize(256),
            ] + common_test_transforms),
            # 'Eyeglasses' 'Smiling'
            "train_loader": lambda: CelebAHQDataset('train', 'Smiling', root='./data', transform=dataset_params["CelebA-HQ"]["train_transform"]),
            "test_loader": lambda: CelebAHQDataset('val', 'Smiling', root='./data', transform=dataset_params["CelebA-HQ"]["test_transform"]),
        },
    }

    # Explicit raise: an assert would be stripped under `python -O`.
    if dataset not in dataset_params:
        raise ValueError(f"Unknown dataset: {dataset}")

    split_key = "train_loader" if train else "test_loader"
    dataset_instance = dataset_params[dataset][split_key]()

    return _load_dataset(dataset_instance, n_examples, batch_size, seed)

def load_imagenet_1k(img_size=256, batch_size=50, num_images=None):
    """Load the images listed in ``./data/imagenet/images.csv`` into memory.

    For each CSV row, ``ImageId`` names ``./data/imagenet/images/<ImageId>.png``.
    ``TrueLabel`` is 1-based and is shifted to 0-based. Images are converted to
    RGB, resized to ``img_size`` x ``img_size``, center-cropped to 224 and
    turned into tensors.

    Args:
        img_size: side of the square resize applied before the 224 crop.
        batch_size: DataLoader batch size.
        num_images: keep only the first ``num_images`` rows. A non-negative
            value also stops reading early; a negative value keeps Python
            slice semantics (all rows are read, then sliced).

    Returns:
        A non-shuffling DataLoader over a TensorDataset of ``(image, label)``.
    """
    csv_filename = './data/imagenet/images.csv'
    input_path = './data/imagenet/images/'
    images = []
    targets = []

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])

    # With a non-negative limit we can stop decoding once enough images are read.
    limit = num_images if num_images is not None and num_images >= 0 else None

    with open(csv_filename) as csvfile:
        reader = csv.DictReader(csvfile, delimiter=',')
        for row in reader:
            if limit is not None and len(images) >= limit:
                break
            image_id = row['ImageId']
            target_label = int(row['TrueLabel']) - 1

            img_path = f"{input_path}{image_id}.png"
            # convert('RGB') guarantees 3 channels (PNGs may be RGBA / grayscale),
            # which torch.stack below requires; `with` closes the file handle.
            with Image.open(img_path) as image:
                image_tensor = transform(image.convert('RGB'))

            images.append(image_tensor)
            targets.append(target_label)

    # torch.stack rejects an empty list (e.g. num_images=0 or an empty CSV).
    images = torch.stack(images) if images else torch.empty((0, 3, 224, 224))
    targets = torch.tensor(targets, dtype=torch.long)

    if num_images is not None:
        # No-op for a non-negative limit; applies slice semantics for negative ones.
        images = images[:num_images]
        targets = targets[:num_images]
        print(f"Returning only the first {num_images} images")

    print(f"Loaded ImageNet-1k dataset with {len(images)} images")

    # Create a TensorDataset and DataLoader
    dataset = torch.utils.data.TensorDataset(images, targets)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return dataloader

def create_dataloader(dataset, Norm):
    """Build ``(train_loader, test_loader)`` for ImageNet, CIFAR10 or CIFAR100.

    Train pipelines use random-crop and horizontal-flip augmentation. Test
    pipelines are deterministic. When ``Norm`` is truthy, both end with
    ``transforms.Normalize`` using the dataset's per-channel mean/std.

    Args:
        dataset: ``"ImageNet"``, ``"CIFAR10"`` or ``"CIFAR100"``.
        Norm: whether to append per-channel normalization.

    Returns:
        ``(train_loader, test_loader)``: train shuffled with batch size 128,
        test unshuffled with batch size 100. Returns ``None`` for any other
        dataset name.
    """
    # Per-channel (mean, std) applied after ToTensor when Norm is set.
    norm_stats = {
        "ImageNet": ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        "CIFAR10": ((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
        "CIFAR100": ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    }
    if dataset not in norm_stats:
        return None

    if dataset == "ImageNet":
        train_ops = [
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        test_ops = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]
    else:
        train_ops = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        test_ops = [
            transforms.ToTensor(),
        ]

    if Norm:
        # Normalize operates on tensors, so it must follow ToTensor in every
        # pipeline (the ImageNet test pipeline previously lacked ToTensor).
        mean, std = norm_stats[dataset]
        train_ops.append(transforms.Normalize(mean, std))
        test_ops.append(transforms.Normalize(mean, std))

    transform_train = transforms.Compose(train_ops)
    transform_test = transforms.Compose(test_ops)

    if dataset == "ImageNet":
        train_dataset = torchvision.datasets.ImageNet(root='./data', split='train', transform=transform_train)
        test_dataset = torchvision.datasets.ImageNet(root='./data', split='val', transform=transform_test)
    else:
        cifar_cls = torchvision.datasets.CIFAR10 if dataset == "CIFAR10" else torchvision.datasets.CIFAR100
        train_dataset = cifar_cls(root='./data', train=True, download=True, transform=transform_train)
        test_dataset = cifar_cls(root='./data', train=False, download=True, transform=transform_test)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=4)
    return train_loader, test_loader

def load_ground_truth(csv_filename):
    """Read per-image attack metadata from a CSV file.

    The CSV must have ``ImageId``, ``TrueLabel`` and ``TargetClass`` columns.
    Both label columns are 1-based in the file and returned 0-based.

    Returns:
        ``(image_ids, true_labels, target_labels)`` as three parallel lists.
    """
    with open(csv_filename) as handle:
        records = [
            (entry['ImageId'], int(entry['TrueLabel']) - 1, int(entry['TargetClass']) - 1)
            for entry in csv.DictReader(handle, delimiter=',')
        ]

    if not records:
        return [], [], []

    ids, true_labels, target_labels = map(list, zip(*records))
    return ids, true_labels, target_labels

