import random
import torch
import math
from torch.utils.data import DataLoader, Sampler
import torch.distributed as dist
from torchvision import transforms
from datasets import load_dataset
import numpy as np
from torchvision import datasets
import os
from PIL import Image
from misc.imagenet_class_names import IMAGENET2012_CLASSES


def set_global_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: Value handed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.

    Side effects:
        Sets ``torch.backends.cudnn.deterministic`` to ``True`` and
        ``torch.backends.cudnn.benchmark`` to ``False``, then prints the
        seed that was applied.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f'Global seed set to {seed}')


class ClassWiseSampler(Sampler):
    """Batch sampler in which every batch holds indices of a single class.

    Args:
        class_indices: Mapping from class label to the list of sample
            indices belonging to that class.
        batch_size: Maximum number of indices per batch; the last batch of a
            class may be shorter.
        shuffle: When true, indices are shuffled within each class and the
            resulting batches are shuffled too, using the global ``random``
            module state. Batches are rebuilt on every iteration.
    """

    def __init__(self, class_indices, batch_size, shuffle=True):
        self.class_indices = class_indices
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batches = self._create_batches()

    def _create_batches(self):
        """Chunk each class's indices into batches and return them all."""
        step = self.batch_size
        collected = []
        for members in self.class_indices.values():
            if self.shuffle:
                # Shuffle a copy so the caller's index lists stay untouched.
                members = members.copy()
                random.shuffle(members)
            collected += [
                members[lo:lo + step]
                for lo in range(0, len(members), step)
            ]
        if self.shuffle:
            random.shuffle(collected)
        return collected

    def __iter__(self):
        if self.shuffle:
            # Fresh ordering every time the sampler is iterated.
            self.batches = self._create_batches()
        yield from self.batches

    def __len__(self):
        return len(self.batches)


class DistributedClassWiseSampler(Sampler):
    """Distributed batch sampler whose batches each contain a single class.

    All single-class batches are built, optionally shuffled, padded so that
    they divide evenly across replicas, and then the contiguous slice that
    belongs to ``rank`` is kept.

    Shuffling is driven exclusively by a ``torch.Generator`` seeded with
    ``seed + epoch``. Every replica therefore computes the same global batch
    order regardless of its process-local ``random`` state, which is what
    makes the per-rank slices disjoint.

    Args:
        dataset: Stored on the instance; not otherwise used by the sampler.
        class_indices: Mapping from class label to the list of sample
            indices of that class. Must be identical (including key order)
            on every replica.
        batch_size: Maximum number of indices per batch (must be positive).
        num_replicas: Number of participating processes. Defaults to
            ``dist.get_world_size()``.
        rank: Index of the current process. Defaults to ``dist.get_rank()``.
        shuffle: Whether to shuffle indices within classes and the batches.
        seed: Base seed combined with the epoch for shuffling.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` must be inferred but
            the distributed package is unavailable.
        ValueError: If ``batch_size``, ``num_replicas`` or ``rank`` is out
            of range.
    """

    def __init__(self, dataset, class_indices, batch_size, num_replicas=None, rank=None, shuffle=True, seed=0):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError('Requires distributed package to be available')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError('Requires distributed package to be available')
            rank = dist.get_rank()
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        if num_replicas <= 0:
            raise ValueError(f'num_replicas must be positive, got {num_replicas}')
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f'Invalid rank {rank}, expected a value in [0, {num_replicas - 1}]')
        self.dataset = dataset
        self.class_indices = class_indices
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        self.batches = self._create_batches()

    def _create_batches(self):
        """Build the list of batches assigned to this replica for ``self.epoch``."""
        generator = None
        if self.shuffle:
            # One generator for the whole pass: classes get independent
            # permutations, and the result depends only on seed and epoch.
            generator = torch.Generator()
            generator.manual_seed(self.seed + self.epoch)

        batches = []
        for indices in self.class_indices.values():
            indices = list(indices)
            if self.shuffle:
                order = torch.randperm(len(indices), generator=generator).tolist()
                indices = [indices[i] for i in order]
            batches.extend(
                indices[i:i + self.batch_size]
                for i in range(0, len(indices), self.batch_size)
            )

        if self.shuffle:
            # Must not use the global `random` module here: its state can
            # differ between processes, which would desynchronise the ranks.
            order = torch.randperm(len(batches), generator=generator).tolist()
            batches = [batches[i] for i in order]

        if not batches:
            return []

        per_replica = math.ceil(len(batches) / self.num_replicas)
        shortfall = per_replica * self.num_replicas - len(batches)
        if shortfall > 0:
            # Pad cyclically so every rank gets exactly `per_replica`
            # batches even when there are fewer batches than the shortfall.
            repeats = math.ceil(shortfall / len(batches))
            batches += (batches * repeats)[:shortfall]

        start = self.rank * per_replica
        return batches[start:start + per_replica]

    def set_epoch(self, epoch):
        """Set the epoch and rebuild this replica's batches for it."""
        self.epoch = epoch
        self.batches = self._create_batches()

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)


def get_class_names(args):
    """Return the human-readable class names for the configured dataset.

    Args:
        args: Object exposing ``dataset`` and, for ImageNet, ``subset``.

    Returns:
        A list of class-name strings. For CIFAR-10 the names are read from
        ``./misc/cifar10_class_names.txt``. For ImageNet subsets the entries
        of the subset file are used as keys into ``IMAGENET2012_CLASSES``;
        for ``imagenet1k`` all values of that mapping are returned.

    Raises:
        ValueError: If the dataset or the ImageNet subset is not recognised.
    """
    dataset_name = args.dataset.lower()

    if dataset_name in ('cifar10', 'uoft-cs/cifar10'):
        with open('./misc/cifar10_class_names.txt', 'r') as fp:
            return [line.strip() for line in fp.readlines()]

    if dataset_name != 'imagenet':
        raise ValueError(f'Unsupported dataset: {args.dataset}')

    subset = args.subset
    if subset == 'imagenet1k':
        return list(IMAGENET2012_CLASSES.values())

    subset_files = {
        'woof': './misc/class_woof.txt',
        'nette': './misc/class_nette.txt',
        'imagenet100': './misc/class100.txt',
    }
    class_file = subset_files.get(subset)
    if class_file is None:
        raise ValueError('Invalid subset')

    with open(class_file, 'r') as fp:
        selected_keys = [line.strip() for line in fp.readlines()]
    return [IMAGENET2012_CLASSES[key] for key in selected_keys]


def collate_fn(batch):
    """Collate a list of ``{'image', 'label'}`` items into two tensors.

    Args:
        batch: Sequence of mappings, each with an ``'image'`` tensor and a
            ``'label'`` value.

    Returns:
        A tuple ``(images, labels)`` where ``images`` is the result of
        ``torch.stack`` over the item images and ``labels`` is a tensor
        built from the item labels.
    """
    stacked_images = torch.stack([sample['image'] for sample in batch])
    label_tensor = torch.tensor([sample['label'] for sample in batch])
    return (stacked_images, label_tensor)


def center_crop_arr(pil_image, image_size):
    """Resize a PIL image and return its central square crop.

    The image is halved with a BOX filter while its shorter side is at
    least twice ``image_size``, then rescaled with BICUBIC so the shorter
    side is ``image_size``, and finally cropped around the centre.

    Args:
        pil_image: Input PIL image.
        image_size: Side length of the square crop.

    Returns:
        A PIL image built from the ``image_size`` x ``image_size`` central
        region of the resized pixel array.
    """
    # Cheap repeated 2x downsampling before the final bicubic resize.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    ratio = image_size / min(*pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top:top + image_size, left:left + image_size]
    return Image.fromarray(window)


# Supported ImageNet subsets -> file listing the selected classes, one per line.
_IMAGENET_SUBSET_FILES = {
    'woof': './misc/class_woof.txt',
    'nette': './misc/class_nette.txt',
    'imagenet100': './misc/class100.txt',
    'imagenet1k': './misc/imagenet_class_indices.txt',
    'class_idc': './misc/class_idc.txt',
}


def _read_stripped_lines(path):
    """Return every line of ``path`` with surrounding whitespace stripped."""
    with open(path, 'r') as fp:
        return [line.strip() for line in fp]


def _build_imagenet_transform(args):
    """Build the image transform selected by ``args.model_type``."""
    from functools import partial

    if args.model_type.lower() == 'dit':
        return transforms.Compose([
            # A partial of a module-level function (rather than a lambda)
            # keeps the transform picklable for spawn-based DataLoader workers.
            transforms.Lambda(partial(center_crop_arr, image_size=args.image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
    return transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])


def prepare_imagenet_dataloader(args, accelerator):
    """Create a class-wise, rank-sharded DataLoader over an ImageNet subset.

    Args:
        args: Configuration object; see
            ``build_imagenet_filtered_dataset_and_class_indices`` for the
            attributes read while building the dataset. ``batch_size`` is
            read here for the sampler.
        accelerator: Object exposing ``num_processes`` and ``process_index``,
            used as the sampler's replica count and rank.

    Returns:
        A ``DataLoader`` whose batches are produced by a
        ``DistributedClassWiseSampler`` over the filtered dataset.

    Raises:
        ValueError: If ``args.subset`` is not a supported subset.
    """
    filtered_dataset, class_indices = build_imagenet_filtered_dataset_and_class_indices(args)
    sampler = DistributedClassWiseSampler(
        dataset=filtered_dataset,
        class_indices=class_indices,
        batch_size=args.batch_size,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=True,
    )
    return DataLoader(filtered_dataset, batch_sampler=sampler, num_workers=4, pin_memory=True)


def build_imagenet_filtered_dataset_and_class_indices(args):
    """Load the ImageNet train split restricted to the selected subset.

    Args:
        args: Configuration object. Reads ``subset``, ``model_type``,
            ``image_size``, ``imagenet_dir`` and, if present, ``seed``
            (defaults to 42).

    Returns:
        A tuple ``(filtered_dataset, class_indices)``. ``filtered_dataset``
        is a ``Subset`` of the ``ImageFolder`` holding only samples of the
        selected classes. ``class_indices`` maps each selected label to the
        positions *within the subset* of that class's samples; selected
        classes without samples map to an empty list.

    Raises:
        ValueError: If ``args.subset`` is not a supported subset, or a
            selected class is missing from the full class list.
    """
    set_global_seed(getattr(args, 'seed', 42))

    # Validate the subset before touching the filesystem or the dataset.
    file_list = _IMAGENET_SUBSET_FILES.get(args.subset)
    if file_list is None:
        raise ValueError('Invalid subset')

    transform = _build_imagenet_transform(args)

    all_classes = _read_stripped_lines('./misc/imagenet_class_indices.txt')
    sel_classes = _read_stripped_lines(file_list)
    # NOTE(review): a class's label is taken to be its line position in
    # imagenet_class_indices.txt; this assumes that order matches the labels
    # ImageFolder assigns to the class directories - confirm.
    class_labels = [all_classes.index(sel_class) for sel_class in sel_classes]

    dataset = datasets.ImageFolder(root=os.path.join(args.imagenet_dir, 'train'), transform=transform)

    # Set membership keeps the scan linear in the number of samples instead
    # of samples x classes.
    wanted = set(class_labels)
    targets = dataset.targets
    indices = [i for i, label in enumerate(targets) if label in wanted]
    filtered_dataset = torch.utils.data.Subset(dataset, indices)

    # Group subset positions by label in one pass, keeping the key order of
    # `class_labels`.
    class_indices = {label: [] for label in class_labels}
    for position, sample_index in enumerate(indices):
        class_indices[targets[sample_index]].append(position)

    return (filtered_dataset, class_indices)


def prepare_dataloader(args, accelerator):
    """Seed the RNGs and build the DataLoader for ``args.dataset``.

    Args:
        args: Configuration object; ``dataset`` selects the loader and
            ``seed`` (default 42 when absent) is applied globally first.
        accelerator: Forwarded unchanged to the dataset-specific builder.

    Returns:
        The DataLoader returned by ``prepare_imagenet_dataloader``.

    Raises:
        ValueError: If ``args.dataset`` is anything other than ImageNet
            (compared case-insensitively).
    """
    seed = getattr(args, 'seed', 42)
    set_global_seed(seed)

    if args.dataset.lower() != 'imagenet':
        raise ValueError(f'Unsupported dataset: {args.dataset}')
    return prepare_imagenet_dataloader(args, accelerator)