import os
import subprocess
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Sampler
import math
from utils.utils import TwoTransform
from torch.utils.data import Dataset
from tqdm import tqdm
from torchvision import transforms
from torchvision.datasets import ImageFolder

def get_dataset(args):
    """Create the train, test and train-for-evaluation datasets described by ``args``.

    The statistics returned by ``get_statistics`` are also copied onto ``args``
    (``mean``, ``std``, ``image_size``, ``num_classes``).

    If ``./dataset/<data_name>`` does not exist, ``./dataset/<data_name>.sh`` is
    run. Then ``./<data_name>`` is moved into ``./dataset`` and
    ``./<data_name>_png.tar`` is removed. Command exit codes are not checked.

    Returns:
        ``(trainset, testset, trainset_test, num_classes)``:
        - ``trainset`` uses the training augmentation.
        - ``trainset_test`` is the same training split with the test transform.
    """
    mean, std, image_size, num_classes = get_statistics(args.data_name, args)
    args.mean, args.std, args.image_size, args.num_classes = mean, std, image_size, num_classes

    normalize = transforms.Normalize(mean=mean, std=std)
    if args.model_name == 'ViT' or args.data_name == 'imagenet':
        # 224px pipeline: random-resized crop for training, resize + center crop for eval.
        train_aug = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
        test_prep = [transforms.Resize(256), transforms.CenterCrop(224)]
    else:
        # Native-resolution pipeline: pad-and-crop augmentation, no eval resizing.
        train_aug = [transforms.RandomCrop(image_size, padding=4), transforms.RandomHorizontalFlip()]
        test_prep = []
    train_transform = transforms.Compose([*train_aug, transforms.ToTensor(), normalize])
    test_transform = transforms.Compose([*test_prep, transforms.ToTensor(), normalize])

    root = f'./dataset/{args.data_name}'
    if not os.path.exists(root):
        # Fetch the dataset with its companion shell script, then tidy up.
        subprocess.run([f'{root}.sh'])
        subprocess.run(['mv', f'./{args.data_name}', './dataset'])
        subprocess.run(['rm', f'./{args.data_name}_png.tar'])

    if args.test_mode == "sub_class":
        # CIFAR-100 with fine labels collapsed to the 20 superclasses.
        trainset = CIFAR100TO20(f'{root}/train', transform=train_transform)
        testset = CIFAR100TO20(f'{root}/test', transform=test_transform)
        trainset_test = CIFAR100TO20(f'{root}/train', transform=test_transform)
        return trainset, testset, trainset_test, num_classes

    # Any name containing 'imagenet' (tinyimagenet too) reads its held-out split from val/.
    eval_dir = 'val' if 'imagenet' in args.data_name else 'test'
    trainset = ImageFolder(root=f'{root}/train', transform=train_transform)
    testset = ImageFolder(root=f'{root}/{eval_dir}', transform=test_transform)
    trainset_test = ImageFolder(root=f'{root}/train', transform=test_transform)
    return trainset, testset, trainset_test, num_classes


def get_dataloader(trainset, testset, trainset_test, args):
    """Wrap the three base datasets in DataLoaders (4 worker processes each).

    Returns:
        ``(train_loader, test_loader, train_test_loader)``:
        - ``train_loader`` shuffles and uses ``args.remain_batch_size``.
        - ``test_loader`` is sequential and uses ``args.forget_batch_size``.
        - ``train_test_loader`` is sequential and uses ``args.remain_batch_size``.
    """
    def _make(dataset, batch_size, shuffle):
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

    return (
        _make(trainset, args.remain_batch_size, True),
        _make(testset, args.forget_batch_size, False),
        _make(trainset_test, args.remain_batch_size, False),
    )

def get_statistics(data_name, args):
    """Look up normalization statistics and geometry for a supported dataset.

    Args:
        data_name: one of 'cifar10', 'cifar100', 'tinyimagenet', 'imagenet',
            'cars' or 'flowers'.
        args: accepted for interface compatibility; not read.

    Returns:
        ``(mean, std, image_size, num_classes)``. ``mean`` and ``std`` are
        3-tuples of per-channel values.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset.
    """
    # ImageNet statistics, shared by ImageNet and the FGVC datasets.
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    table = {
        # CIFAR datasets
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), 32, 10),
        'cifar100': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761), 32, 100),
        "tinyimagenet": ((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262), 64, 200),
        'imagenet': (imagenet_mean, imagenet_std, 224, 1000),
        # FGVC datasets
        'cars': (imagenet_mean, imagenet_std, 224, 196),
        'flowers': (imagenet_mean, imagenet_std, 224, 102),
    }

    stats = table.get(data_name)
    if stats is None:
        raise ValueError('Invalid dataset name')
    return stats

def _adjacent_sub_classes(subclass_set, names):
    """Return the fine-class names that share a CIFAR-100 superclass with any of *names*.

    Args:
        subclass_set: a ``SubClassDataset``; its ``dataset.class_to_idx`` and
            ``coarse_map`` are read.
        names: fine-class names being forgotten.

    Returns:
        Sibling class names, grouped by superclass in the order of *names*.
        Every entry of *names* is removed and each sibling is listed once. This
        holds even when several forgotten classes share a superclass.
    """
    class_to_idx = subclass_set.dataset.class_to_idx
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    forget_ids = [class_to_idx[name] for name in names]

    taken = set(forget_ids)  # ids that must never appear in the result
    adjacent_ids = []
    for fine_id in forget_ids:
        members = next((m for m in subclass_set.coarse_map.values() if fine_id in m), ())
        for member in members:
            if member not in taken:
                taken.add(member)
                adjacent_ids.append(member)
    return [idx_to_class[member] for member in adjacent_ids]


def get_unlearn_loader(trainset, testset, trainset_test, args):
    """Build forget/remain datasets and DataLoaders for the configured unlearning setup.

    Modes (``args.test_mode``):

    * ``'class'``: forget the class indices
      ``range(args.class_idx, args.class_idx + args.class_idx_unlearn)``
      (``FilteredDataset``).
    * ``'sub_class'`` (CIFAR-100 only): forget the fine classes named in
      ``args.sub_class_name``; labels are collapsed to superclasses by
      ``SubClassDataset``. Also builds the "adjacent" split: the other fine
      classes of the same superclass(es).
    * ``'sample'``: the forget split is the first ``args.sample_unlearn_per_class``
      samples of each class (``SampleDataset``). No test-forget split is built.

    Sets named ``train_*_set`` use ``trainset.transform``. Test sets and
    ``*_test_set`` variants use ``testset.transform``. ``trainset_test`` is not
    used by this function.

    Returns:
        An 18-tuple, in order:
        - train_forget_set, train_remain_set, test_forget_set, test_remain_set,
          train_forget_test_set, train_remain_test_set;
        - the six matching loaders in the same order;
        - train_adjacent_set, test_adjacent_set, train_adjacent_test_set and
          their three loaders.
        The adjacent entries are ``None`` unless the mode is ``'sub_class'``.
        ``test_forget_set`` and ``test_forget_loader`` are ``None`` in ``'sample'`` mode.

    Raises:
        ValueError: if ``args.test_mode`` is not one of the modes above.
    """
    root = f'./dataset/{args.data_name}'

    def make_loader(dataset, batch_size, shuffle):
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

    train_adjacent_set, test_adjacent_set, train_adjacent_test_set = None, None, None
    train_adjacent_loader, test_adjacent_loader, train_adjacent_test_loader = None, None, None

    # NOTE(review): the train_*_test loaders shuffle (as before). This is harmless
    # for accuracy-style metrics; confirm nothing relies on a fixed order.
    if args.test_mode == 'class':
        def filtered(split_root, transform, exclude):
            # A fresh list per dataset, so no two datasets share one mutable ``classes``.
            classes = list(range(args.class_idx, args.class_idx + args.class_idx_unlearn))
            return FilteredDataset(split_root, classes, transform=transform, exclude=exclude)

        # Names containing 'imagenet' (tinyimagenet too) keep their held-out split in val/.
        test_root = f'{root}/val' if 'imagenet' in args.data_name else f'{root}/test'

        test_forget_set = filtered(test_root, testset.transform, exclude=False)
        test_remain_set = filtered(test_root, testset.transform, exclude=True)
        test_forget_loader = make_loader(test_forget_set, args.forget_batch_size, False)
        test_remain_loader = make_loader(test_remain_set, args.remain_batch_size, False)

        train_forget_set = filtered(f'{root}/train', trainset.transform, exclude=False)
        train_remain_set = filtered(f'{root}/train', trainset.transform, exclude=True)
        train_forget_loader = make_loader(train_forget_set, args.forget_batch_size, True)
        train_remain_loader = make_loader(train_remain_set, args.remain_batch_size, True)

        train_forget_test_set = filtered(f'{root}/train', testset.transform, exclude=False)
        train_remain_test_set = filtered(f'{root}/train', testset.transform, exclude=True)
        train_forget_test_loader = make_loader(train_forget_test_set, args.forget_batch_size, True)
        train_remain_test_loader = make_loader(train_remain_test_set, args.remain_batch_size, True)

    elif args.test_mode == "sub_class":
        assert args.data_name == 'cifar100', 'sub_class mode only supports cifar100'
        names = args.sub_class_name

        test_forget_set = SubClassDataset(f'{root}/test', names, transform=testset.transform, exclude=False)
        test_remain_set = SubClassDataset(f'{root}/test', names, transform=testset.transform, exclude=True)
        test_forget_loader = make_loader(test_forget_set, args.forget_batch_size, False)
        test_remain_loader = make_loader(test_remain_set, args.remain_batch_size, False)

        train_forget_set = SubClassDataset(f'{root}/train', names, transform=trainset.transform, exclude=False)
        train_remain_set = SubClassDataset(f'{root}/train', names, transform=trainset.transform, exclude=True)
        train_forget_loader = make_loader(train_forget_set, args.forget_batch_size, True)
        train_remain_loader = make_loader(train_remain_set, args.remain_batch_size, True)

        train_forget_test_set = SubClassDataset(f'{root}/train', names, transform=testset.transform, exclude=False)
        train_remain_test_set = SubClassDataset(f'{root}/train', names, transform=testset.transform, exclude=True)
        train_forget_test_loader = make_loader(train_forget_test_set, args.forget_batch_size, True)
        train_remain_test_loader = make_loader(train_remain_test_set, args.remain_batch_size, True)

        # Sibling fine classes of the forgotten ones, deduplicated and without the forget classes.
        adjacent_classes = _adjacent_sub_classes(train_forget_test_set, names)

        train_adjacent_set = SubClassDataset(f'{root}/train', adjacent_classes, transform=trainset.transform, exclude=False)
        test_adjacent_set = SubClassDataset(f'{root}/test', adjacent_classes, transform=testset.transform, exclude=False)
        train_adjacent_test_set = SubClassDataset(f'{root}/train', adjacent_classes, transform=testset.transform, exclude=False)

        train_adjacent_loader = make_loader(train_adjacent_set, args.remain_batch_size, True)
        test_adjacent_loader = make_loader(test_adjacent_set, args.remain_batch_size, False)
        train_adjacent_test_loader = make_loader(train_adjacent_test_set, args.remain_batch_size, True)

    elif args.test_mode == 'sample':
        # Sample-level forgetting has no held-out "forget" split; the whole test set is "remain".
        test_forget_set = None
        test_remain_set = testset
        test_forget_loader = None
        test_remain_loader = make_loader(test_remain_set, args.remain_batch_size, False)

        def sampled(transform, forget):
            return SampleDataset(f'{root}/train', args.num_classes,
                                 exclude_num_per_class=args.sample_unlearn_per_class,
                                 transform=transform, forget=forget)

        train_forget_set = sampled(trainset.transform, forget=True)
        train_remain_set = sampled(trainset.transform, forget=False)
        train_forget_loader = make_loader(train_forget_set, args.forget_batch_size, True)
        train_remain_loader = make_loader(train_remain_set, args.remain_batch_size, True)

        train_forget_test_set = sampled(testset.transform, forget=True)
        train_remain_test_set = sampled(testset.transform, forget=False)
        train_forget_test_loader = make_loader(train_forget_test_set, args.forget_batch_size, True)
        train_remain_test_loader = make_loader(train_remain_test_set, args.remain_batch_size, True)

    else:
        raise ValueError(f"Unsupported test_mode: {args.test_mode!r}")

    return train_forget_set, train_remain_set, test_forget_set, test_remain_set, train_forget_test_set, train_remain_test_set, \
        train_forget_loader, train_remain_loader, test_forget_loader, test_remain_loader, train_forget_test_loader, train_remain_test_loader, \
        train_adjacent_set, test_adjacent_set, train_adjacent_test_set, train_adjacent_loader, test_adjacent_loader, train_adjacent_test_loader

def split_class_data(dataset, forget_class, num_forget):
    """Split dataset positions into a forget part and a remain part.

    The first ``num_forget`` items whose target equals ``forget_class`` (in
    iteration order) go to the forget list. Every other position goes to the
    remain list.

    Note: the whole ``dataset`` is iterated, so every item is fully loaded.

    Args:
        dataset: iterable of ``(data, target)`` pairs.
        forget_class: target value to forget.
        num_forget: maximum number of matching items to put in the forget list.

    Returns:
        ``(forget_index, remain_index)``: lists of positions in ``dataset``.
    """
    forget_index = []
    remain_index = []
    for i, (_, target) in enumerate(dataset):
        # len(forget_index) is the running count; no separate counter (or shadowed `sum`) needed.
        if target == forget_class and len(forget_index) < num_forget:
            forget_index.append(i)
        else:
            remain_index.append(i)
    return forget_index, remain_index

class FilteredDataset(Dataset):
    """ImageFolder view restricted to, or excluding, a collection of class indices.

    Args:
        root_dir: directory passed to ``ImageFolder``.
        classes: class indices to keep (``exclude=False``) or drop (``exclude=True``).
        transform: transform handed to the underlying ``ImageFolder``, which
            applies it when an item is loaded.
        exclude: if truthy, keep samples whose class is *not* in ``classes``;
            otherwise keep only samples whose class is in ``classes``.
    """

    def __init__(self, root_dir, classes, transform=None, exclude=True):
        self.dataset = ImageFolder(root_dir, transform=transform)  # applies transform on load
        self.classes = classes
        self.exclude = exclude
        self.filtered_indices = self._filter_indices()
        # Exposed for callers that read ``.transform``; not used in __getitem__.
        self.transform = transform

    def _filter_indices(self):
        """Return the positions in ``self.dataset.samples`` that pass the class filter."""
        # Build the lookup set once; list membership would cost O(len(classes)) per sample.
        wanted = set(self.classes)
        keep_members = not self.exclude
        return [
            idx
            for idx, (_, class_idx) in enumerate(tqdm(self.dataset.samples))
            if (class_idx in wanted) == keep_members
        ]

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        # Map the filtered position back to the underlying dataset's index.
        image, label = self.dataset[self.filtered_indices[idx]]
        return image, label

class RandomDataset(Dataset):
    """Wrap a filtered subset, replacing every label with a random class index.

    Images come from ``subset.dataset`` through ``subset.filtered_indices``. Each
    ``__getitem__`` call draws a new label with ``torch.randint`` in
    ``[0, args.num_classes)``.
    """

    def __init__(self, subset, args):
        # Share the subset's storage and index mapping rather than copying them.
        self.dataset = subset.dataset
        self.classes = subset.classes
        self.filtered_indices = subset.filtered_indices
        self.transform = subset.transform  # informational only; not applied here
        self.num_classes = args.num_classes

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        source_idx = self.filtered_indices[idx]
        image, _ = self.dataset[source_idx]  # the true label is discarded
        random_label = torch.randint(0, self.num_classes, (1,)).item()
        return image, random_label

class SubClassDataset(Dataset):
    """CIFAR-100 ImageFolder filtered by fine-class names, yielding superclass labels.

    Args:
        root_dir: ImageFolder root directory.
        sub_class_name: iterable of fine-class names (keys of ``class_to_idx``)
            to keep or drop.
        transform: transform handed to the underlying ``ImageFolder``.
        exclude: if truthy, keep samples whose class is *not* listed; otherwise
            keep only the listed classes.
    """

    # Superclass index -> fine-class indices. Assumes ImageFolder's (sorted)
    # class indexing matches the canonical CIFAR-100 fine-label order; TODO confirm.
    COARSE_MAP = {
        0: [4, 30, 55, 72, 95],
        1: [1, 32, 67, 73, 91],
        2: [54, 62, 70, 82, 92],
        3: [9, 10, 16, 28, 61],
        4: [0, 51, 53, 57, 83],
        5: [22, 39, 40, 86, 87],
        6: [5, 20, 25, 84, 94],
        7: [6, 7, 14, 18, 24],
        8: [3, 42, 43, 88, 97],
        9: [12, 17, 37, 68, 76],
        10: [23, 33, 49, 60, 71],
        11: [15, 19, 21, 31, 38],
        12: [34, 63, 64, 66, 75],
        13: [26, 45, 77, 79, 99],
        14: [2, 11, 35, 46, 98],
        15: [27, 29, 44, 78, 93],
        16: [36, 50, 65, 74, 80],
        17: [47, 52, 56, 59, 96],
        18: [8, 13, 48, 58, 90],
        19: [41, 69, 81, 85, 89],
    }
    # Inverse lookup (fine -> coarse), built once instead of scanning COARSE_MAP per item.
    FINE_TO_COARSE = {fine: coarse for coarse, fines in COARSE_MAP.items() for fine in fines}

    def __init__(self, root_dir, sub_class_name, transform=None, exclude=True):
        self.dataset = ImageFolder(root_dir, transform=transform)  # applies transform on load
        # Fine-class names -> ImageFolder class indices.
        self.classes = [self.dataset.class_to_idx[name] for name in sub_class_name]
        self.exclude = exclude
        # Per-instance copy, kept because callers read ``.coarse_map`` directly.
        self.coarse_map = {coarse: list(fines) for coarse, fines in self.COARSE_MAP.items()}
        self.filtered_indices = self._filter_indices()
        self.transform = transform  # exposed for callers; not used in __getitem__

    def _filter_indices(self):
        """Return the positions in ``self.dataset.samples`` that pass the class filter."""
        wanted = set(self.classes)
        keep_members = not self.exclude
        return [
            idx
            for idx, (_, class_idx) in enumerate(self.dataset.samples)
            if (class_idx in wanted) == keep_members
        ]

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        image, label = self.dataset[self.filtered_indices[idx]]
        # Collapse the 100 fine labels to 20 superclasses; unmapped labels pass through.
        return image, self.FINE_TO_COARSE.get(label, label)

class SampleDataset(Dataset):
    """ImageFolder split per class into a "forget" prefix and a "remain" rest.

    For each class, the first ``exclude_num_per_class`` samples (in the order of
    ``ImageFolder.samples``) form the forget split. All later samples of that
    class form the remain split.

    Args:
        root_dir: ImageFolder root directory.
        num_class: number of classes. Class indices must lie in ``[0, num_class)``.
        exclude_num_per_class: size of the per-class forget prefix.
        transform: transform handed to the underlying ``ImageFolder``.
        forget: if truthy, expose the forget split; otherwise the remain split.

    After construction, ``self.classes[c]`` is the number of selected samples of class ``c``.
    """

    def __init__(self, root_dir, num_class, exclude_num_per_class=500, transform=None, forget=True):
        self.dataset = ImageFolder(root_dir, transform=transform)  # applies transform on load
        self.classes = [0] * num_class
        self.forget = forget
        self.exclude_num_per_class = exclude_num_per_class
        self.filtered_indices = self._filter_indices()
        self.transform = transform  # exposed for callers; not used in __getitem__

    @staticmethod
    def _select_indices(labels, num_class, exclude_num_per_class, forget):
        """Pick the positions in *labels* that belong to one split.

        Returns:
            ``(indices, counts)``: the selected positions, and the number of
            selected samples per class.
        """
        seen = [0] * num_class    # samples of each class encountered so far
        counts = [0] * num_class  # samples of each class selected so far
        indices = []
        want_forget = bool(forget)
        for idx, class_idx in enumerate(labels):
            in_forget_split = seen[class_idx] < exclude_num_per_class
            seen[class_idx] += 1
            if in_forget_split == want_forget:
                indices.append(idx)
                counts[class_idx] += 1
        return indices, counts

    def _filter_indices(self):
        """Return the selected sample positions and record per-class selected counts.

        Uses local counters, so repeated calls return the same result.
        """
        labels = (class_idx for _, class_idx in tqdm(self.dataset.samples))
        indices, self.classes = self._select_indices(
            labels, len(self.classes), self.exclude_num_per_class, self.forget)
        return indices

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        # Map the filtered position back to the underlying dataset's index.
        image, label = self.dataset[self.filtered_indices[idx]]
        return image, label

class CIFAR100TO20(Dataset):
    """CIFAR-100 ImageFolder whose fine labels are collapsed to the 20 superclasses.

    Args:
        root_dir: ImageFolder root directory.
        transform: transform handed to the underlying ``ImageFolder``.
    """

    # Superclass index -> fine-class indices. Assumes ImageFolder's (sorted)
    # class indexing matches the canonical CIFAR-100 fine-label order; TODO confirm.
    COARSE_MAP = {
        0: [4, 30, 55, 72, 95],
        1: [1, 32, 67, 73, 91],
        2: [54, 62, 70, 82, 92],
        3: [9, 10, 16, 28, 61],
        4: [0, 51, 53, 57, 83],
        5: [22, 39, 40, 86, 87],
        6: [5, 20, 25, 84, 94],
        7: [6, 7, 14, 18, 24],
        8: [3, 42, 43, 88, 97],
        9: [12, 17, 37, 68, 76],
        10: [23, 33, 49, 60, 71],
        11: [15, 19, 21, 31, 38],
        12: [34, 63, 64, 66, 75],
        13: [26, 45, 77, 79, 99],
        14: [2, 11, 35, 46, 98],
        15: [27, 29, 44, 78, 93],
        16: [36, 50, 65, 74, 80],
        17: [47, 52, 56, 59, 96],
        18: [8, 13, 48, 58, 90],
        19: [41, 69, 81, 85, 89],
    }
    # Inverse lookup (fine -> coarse), built once instead of scanning COARSE_MAP per item.
    FINE_TO_COARSE = {fine: coarse for coarse, fines in COARSE_MAP.items() for fine in fines}

    def __init__(self, root_dir, transform=None):
        self.dataset = ImageFolder(root_dir, transform=transform)  # applies transform on load
        # Per-instance copy, kept for callers that read ``.coarse_map`` directly.
        self.coarse_map = {coarse: list(fines) for coarse, fines in self.COARSE_MAP.items()}
        self.transform = transform  # exposed for callers; not used in __getitem__

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # 100 fine classes -> 20 superclasses; unmapped labels pass through unchanged.
        return image, self.FINE_TO_COARSE.get(label, label)

class IntraClassShuffleSampler(Sampler):
    """Yield dataset indices grouped by label, shuffled within each label.

    Labels are visited in order of first appearance in the dataset. The indices
    of each label are reshuffled with ``random.shuffle`` every time the sampler
    is iterated.

    Args:
        dataset: iterable dataset of ``(data, label)`` pairs. Grouping iterates
            it fully, loading every item.
        batch_size: stored on the sampler; not used by the sampling logic.
        switch_class: accepted for API compatibility; currently unused.
    """

    def __init__(self, dataset, batch_size, switch_class=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.switch_class = switch_class  # unused; kept for introspection
        self.class_indices = self._group_and_shuffle_indices()

    def _group_and_shuffle_indices(self):
        """Return ``{label: shuffled list of indices}``, in first-seen label order."""
        class_indices = {}
        for idx, (_, label) in enumerate(self.dataset):
            class_indices.setdefault(label, []).append(idx)
        for indices in class_indices.values():
            random.shuffle(indices)
        return class_indices

    def __iter__(self):
        # Regroup and reshuffle on every iteration, so the within-class order changes per epoch.
        self.class_indices = self._group_and_shuffle_indices()
        for label_indices in self.class_indices.values():
            yield from label_indices

    def __len__(self):
        return len(self.dataset)


class ClassMatchingDataset(Dataset):
    """Concatenate a forget set and a retain set, encoding forget labels as negative.

    Indices ``[0, len(forget_data))`` map to ``forget_data``; a label ``y`` is
    returned as ``-(y + 1)``. With non-negative labels, class 0 thus becomes -1,
    which cannot be confused with a retain label. Later indices map to
    ``retain_data`` (offset by ``len(forget_data)``) with labels unchanged.

    Args:
        forget_data: indexable dataset of ``(x, y)`` items.
        retain_data: indexable dataset of ``(x, y)`` items.
    """

    def __init__(self, forget_data, retain_data):
        super().__init__()
        self.forget_data = forget_data
        self.retain_data = retain_data
        self.forget_len = len(forget_data)
        self.retain_len = len(retain_data)

    def __len__(self):
        return self.retain_len + self.forget_len

    def __getitem__(self, index):
        if index < self.forget_len:
            item = self.forget_data[index]  # load once; items may carry a costly transform
            return item[0], -(item[1] + 1)
        # Retain items follow the forget items, so shift the index back into retain_data.
        item = self.retain_data[index - self.forget_len]
        return item[0], item[1]

