import os
import random
import numpy as np

import torch
import torchvision
from torchvision import transforms

    
def find_subclasses(spec, nclass, phase=0):
    """Select one phase worth of class names from a class-list file.

    The class list is read from ``./misc/class_woof.txt`` when ``spec`` is
    ``'woof'``, from ``./misc/class_100.txt`` when it is ``'im100'``, and
    from ``./misc/class_indices.txt`` otherwise (one class name per line).

    Args:
        spec: Selector for the class-list file.
        nclass: Number of classes per phase.
        phase: Zero-based index of the ``nclass``-sized slice to return.

    Returns:
        A ``(classes, class_to_idx)`` tuple: the selected class names and a
        dict mapping each name to its position within the selection.
    """
    first = nclass * phase
    last = nclass * (phase + 1)

    # Default list; overridden below for the named subsets.
    list_path = './misc/class_indices.txt'
    if spec == 'woof':
        list_path = './misc/class_woof.txt'
    elif spec == 'im100':
        list_path = './misc/class_100.txt'

    with open(list_path, 'r') as handle:
        # Keep everything before the line terminator.
        names = [line.partition('\n')[0] for line in handle]

    selected = names[first:last]
    index_of = {}
    for position, name in enumerate(selected):
        index_of[name] = position

    return selected, index_of


def find_original_classes(spec, classes):
    """Map class names to their line positions in ``./misc/class_indices.txt``.

    Args:
        spec: Accepted for interface compatibility; not used by this function.
        classes: Iterable of class names to look up.

    Returns:
        A list with, for each entry of ``classes`` in order, the zero-based
        index of its first occurrence in the master class list.

    Raises:
        ValueError: If a name in ``classes`` is not in the master list.
    """
    with open('./misc/class_indices.txt', 'r') as handle:
        master = [line.partition('\n')[0] for line in handle]
    return [master.index(name) for name in classes]


class ImageFolder(torchvision.datasets.ImageFolder):
    """Image-folder dataset restricted to one phase of a class list.

    After the parent constructor runs, ``classes`` and ``class_to_idx`` are
    replaced by the subset returned from ``find_subclasses``. ``samples``
    then holds only file paths (or the already-loaded images when ``mem`` is
    true) and ``targets`` holds the labels, indexed relative to the subset.

    Args:
        nclass: Number of classes per phase.
        ipc: Images kept per class; ``-1`` keeps every image.
        mem: If true, load every selected image with ``self.loader`` up front.
        spec: Class-list selector forwarded to ``find_subclasses``.
        phase: Zero-based index of the ``nclass``-sized slice of the class list.
        **kwargs: Forwarded to ``torchvision.datasets.ImageFolder``
            (e.g. ``root``, ``transform``).
    """

    def __init__(self, nclass, ipc, mem=False, spec='none', phase=0, **kwargs):
        super().__init__(**kwargs)
        self.mem = mem
        self.spec = spec
        self.classes, self.class_to_idx = find_subclasses(
            spec=spec, nclass=nclass, phase=phase
        )
        # Positions of the selected classes in the master class list.
        self.original_classes = find_original_classes(spec=self.spec, classes=self.classes)
        self.samples, self.targets = self.load_subset(ipc=ipc)
        if self.mem:
            # Trade memory for speed: decode every image once, here.
            self.samples = [self.loader(path) for path in self.samples]

    def load_subset(self, ipc=-1):
        """Collect the paths and labels of the selected classes.

        Args:
            ipc: Maximum number of images per class; ``-1`` keeps all of them.

        Returns:
            A ``(samples, targets)`` pair. Both are numpy arrays when
            ``ipc == -1`` and plain lists otherwise.
        """
        all_samples = torchvision.datasets.folder.make_dataset(
            self.root, self.class_to_idx, self.extensions
        )
        samples = np.array([item[0] for item in all_samples])
        targets = np.array([item[1] for item in all_samples])

        if ipc == -1:
            return samples, targets

        sub_samples = []
        sub_targets = []
        for class_idx in range(len(self.classes)):
            c_indices = np.where(targets == class_idx)[0]
            # No shuffling: the first `ipc` entries in make_dataset order are
            # kept, so the subset is the same on every run.
            sub_samples.extend(samples[c_indices[:ipc]])
            sub_targets.extend(targets[c_indices[:ipc]])
        return sub_samples, sub_targets

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``.

        The sample is taken from memory when ``mem`` is true, otherwise it is
        loaded from disk. ``transform`` is applied when one was provided;
        ``target_transform`` is not applied by this class.
        """
        item = self.samples[index]
        sample = item if self.mem else self.loader(item)
        # `transform` defaults to None in the parent class; calling it
        # unconditionally would raise TypeError for such datasets.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.targets[index]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.targets)


def transform_imagenet(args):
    """Build the training and evaluation transform pipelines.

    Args:
        args: Object providing ``input_size``, ``factor`` and
            ``max_scale_crops``.

    Returns:
        ``(train_transform, test_transform)``. The training pipeline converts
        to a tensor, shuffles patches, takes a random resized crop and a
        random horizontal flip, then normalizes. The evaluation pipeline
        resizes to ``input_size // 7 * 8``, center-crops to ``input_size``,
        converts to a tensor and normalizes.
    """
    size = args.input_size

    eval_geometry = [
        transforms.Resize(size // 7 * 8),
        transforms.CenterCrop(size),
    ]
    to_tensor = [transforms.ToTensor()]

    crop = transforms.RandomResizedCrop(
        size=size,
        scale=(1 / args.factor, args.max_scale_crops),
        antialias=True,
    )
    train_augment = [ShufflePatches(args.factor), crop, transforms.RandomHorizontalFlip()]

    channel_norm = [
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]

    return (
        transforms.Compose(to_tensor + train_augment + channel_norm),
        transforms.Compose(eval_geometry + to_tensor + channel_norm),
    )


# Module-level side effect: importing this file switches torch.multiprocessing
# to the "file_system" sharing strategy. The same value is re-applied inside
# each DataLoader worker by set_worker_sharing_strategy.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# with many workers — confirm.
sharing_strategy = "file_system"
torch.multiprocessing.set_sharing_strategy(sharing_strategy)


def set_worker_sharing_strategy(worker_id: int) -> None:
    """Apply the module-level ``sharing_strategy`` in the calling process.

    Passed as ``worker_init_fn`` to the DataLoaders built in ``load_data``.

    Args:
        worker_id: Argument required by the ``worker_init_fn`` signature;
            not used here.
    """
    torch.multiprocessing.set_sharing_strategy(sharing_strategy)


def load_data(args, coreset=False, resize_only=False, mem_flag=True, trainset_only=False):
    """Create the training dataset and, optionally, train/val DataLoaders.

    Args:
        args: Object providing ``data_dir`` (sequence of one or two paths),
            ``nclass``, ``ipc``, ``spec``, ``phase``, ``batch_size`` and
            ``workers``, plus whatever ``transform_imagenet`` reads.
        coreset: If true (and ``resize_only`` is false), the training set
            uses the evaluation transform.
        resize_only: If true, the training transform is only a resize to
            512x512; takes precedence over ``coreset``.
        mem_flag: Forwarded as ``mem`` to both datasets.
        trainset_only: If true, return just the training dataset.

    Returns:
        The training dataset when ``trainset_only`` is true, otherwise
        ``(train_dataset, train_loader, val_loader)``.
    """
    train_transform, test_transform = transform_imagenet(args)

    data_dirs = args.data_dir
    if len(data_dirs) == 1:
        # Single root containing both splits.
        root = data_dirs[0]
        train_dir = os.path.join(root, 'train')
        val_dir = os.path.join(root, 'val')
    else:
        # First entry is used directly as the training directory; validation
        # images come from the second root's 'val' folder.
        train_dir = data_dirs[0]
        val_dir = os.path.join(data_dirs[1], 'val')

    if resize_only:
        train_transform = transforms.Compose([transforms.Resize((512, 512))])
    elif coreset:
        train_transform = test_transform

    dataset_common = dict(
        nclass=args.nclass,
        mem=mem_flag,
        spec=args.spec,
        phase=args.phase,
    )

    train_dataset = ImageFolder(
        ipc=args.ipc,
        root=train_dir,
        transform=train_transform,
        **dataset_common,
    )
    if trainset_only:
        return train_dataset

    loader_common = dict(
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        worker_init_fn=set_worker_sharing_strategy,
    )

    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_common)

    # The validation split always keeps every image (ipc=-1).
    val_dataset = ImageFolder(
        ipc=-1,
        root=val_dir,
        transform=test_transform,
        **dataset_common,
    )
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_common)
    print("load data successfully")

    return train_dataset, train_loader, val_loader


class ShufflePatches(torch.nn.Module):
    """Randomly permute strips of an image tensor along both spatial axes.

    The image is cut into ``factor`` strips along its last dimension, the
    strips are shuffled with ``random.shuffle`` and concatenated back; the
    same is then done along the other spatial dimension. The output has the
    same shape as the input.

    Args:
        factor: Number of strips per axis.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def shuffle_weight(self, img, factor):
        """Shuffle ``factor`` strips of ``img`` along its last dimension.

        Every strip is ``width // factor`` wide except the last one, which
        also takes the remainder so that no column is dropped or duplicated.

        Args:
            img: Tensor whose last dimension is split.
            factor: Number of strips.

        Returns:
            A tensor of the same shape as ``img`` with its strips reordered.
        """
        width = img.shape[-1]
        strip = width // factor
        patches = []
        for idx in range(factor):
            # Keep the strip index and the pixel offset separate: the
            # last-strip check must compare the index, not the offset.
            start = idx * strip
            if idx != factor - 1:
                patches.append(img[..., start:start + strip])
            else:
                patches.append(img[..., start:])
        random.shuffle(patches)
        return torch.cat(patches, -1)

    def forward(self, img):
        """Shuffle strips along the last axis, then along the second-to-last.

        Args:
            img: 3-D tensor; ``permute(0, 2, 1)`` swaps its last two
                dimensions so the same routine handles both axes.

        Returns:
            The shuffled tensor, with the same shape as ``img``.
        """
        img = self.shuffle_weight(img, self.factor)
        img = img.permute(0, 2, 1)
        img = self.shuffle_weight(img, self.factor)
        img = img.permute(0, 2, 1)
        return img
