import os
import sys

import torch
import torchvision
from datasets.load import load_dataset
from torch.utils.data import DataLoader, Subset
# import datasets
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

from torchvision.datasets import ImageFolder
import os
import copy
class ImageFolderWithPaths(ImageFolder):
    """``ImageFolder`` whose items additionally carry the sample's file path.

    Each item is the tuple produced by the parent class with the path of the
    sample (``self.imgs[index][0]``) appended as the last element.
    """

    def __getitem__(self, index):
        base_item = super().__getitem__(index)
        sample_path = self.imgs[index][0]
        return base_item + (sample_path,)

def _find_label_attr(dataset) -> str:
    """Return the name of the per-sample label attribute of *dataset*.

    Probes ``targets``, ``labels`` and ``_labels`` in that order and returns the
    first one that exists.

    Raises:
        AttributeError: if none of the three attributes is present.
    """
    for name in ("targets", "labels", "_labels"):
        if hasattr(dataset, name):
            return name
    raise AttributeError(
        f"{type(dataset).__name__} has no 'targets', 'labels' or '_labels' attribute"
    )


def replace_indexes(
    dataset: torch.utils.data.Dataset, indexes, seed=0, only_mark: bool = False
):
    """Corrupt or mark the samples at *indexes*, modifying *dataset* in place.

    Args:
        dataset: dataset exposing one of ``targets`` / ``labels`` / ``_labels``
            and, when ``only_mark`` is False, a ``data`` attribute that can be
            indexed and assigned with an integer array.
        indexes: positions of the samples to affect.
        seed: seed of the donor-sample draw (unused when ``only_mark`` is True).
        only_mark: if True, leave ``data`` untouched and encode each selected
            label ``y`` as ``-y - 1`` (recoverable, since ``-(-y - 1) - 1 == y``);
            the label attribute is written back as a Python list.
            If False, overwrite each selected sample (data and label) with a
            copy of a randomly chosen sample outside *indexes*; donors are drawn
            with replacement, so one donor may be copied several times.

    Raises:
        AttributeError: if the dataset has no recognised label attribute.
        ValueError: if every sample is selected, so no donor sample exists.
    """
    idx = np.asarray(indexes, dtype=np.int64)
    # Resolve the label attribute first so nothing is mutated if it is missing.
    label_attr = _find_label_attr(dataset)

    if only_mark:
        marked = np.array(getattr(dataset, label_attr))
        marked[idx] = -marked[idx] - 1
        setattr(dataset, label_attr, marked.tolist())
        return

    # Same candidate list (and therefore the same RNG draw) as before.
    candidates = list(set(range(len(dataset))) - set(idx.tolist()))
    if len(idx) and not candidates:
        raise ValueError(
            "every sample is selected; no donor samples left to copy from"
        )
    rng = np.random.RandomState(seed)
    new_idx = rng.choice(candidates, size=len(idx))

    dataset.data[idx] = dataset.data[new_idx]
    labels = getattr(dataset, label_attr)
    if isinstance(labels, list):
        # Python lists do not support fancy indexing; round-trip through numpy
        # and keep the list type the dataset had.
        arr = np.array(labels)
        arr[idx] = arr[new_idx]
        setattr(dataset, label_attr, arr.tolist())
    else:
        labels[idx] = labels[new_idx]

def replace_class(
    dataset: torch.utils.data.Dataset,
    class_to_replace: int,
    num_indexes_to_replace: int = None,
    seed: int = 0,
    only_mark: bool = True,
):
    """Select the samples of one class (or of all classes) and pass them to ``replace_indexes``.

    Args:
        dataset: dataset exposing one of ``targets`` / ``labels`` / ``_labels``.
        class_to_replace: label to select; ``-1`` selects every sample.
        num_indexes_to_replace: if given, a random subset of this size is drawn
            (without replacement, seeded by *seed*) from the selected samples.
        seed: seed of the subset draw; also forwarded to ``replace_indexes``.
        only_mark: forwarded to ``replace_indexes``.

    Raises:
        AttributeError: if the dataset has none of ``targets``/``labels``/``_labels``.
        ValueError: if more samples are requested than the selection contains.
    """
    for attr in ("targets", "labels", "_labels"):
        if hasattr(dataset, attr):
            labels = np.asarray(getattr(dataset, attr))
            break
    else:
        raise AttributeError(
            f"{type(dataset).__name__} has no 'targets', 'labels' or '_labels' attribute"
        )

    if class_to_replace == -1:
        # Every position of the (flattened) label array.
        indexes = np.arange(labels.size)
    else:
        indexes = np.flatnonzero(labels == class_to_replace)

    if num_indexes_to_replace is not None:
        if num_indexes_to_replace > len(indexes):
            raise ValueError(
                f"Want to replace {num_indexes_to_replace} indexes but only {len(indexes)} samples in dataset"
            )
        rng = np.random.RandomState(seed)
        indexes = rng.choice(indexes, size=num_indexes_to_replace, replace=False)
    replace_indexes(dataset, indexes, seed, only_mark)

def _subset_or_none(dataset, indices):
    """Return ``Subset(dataset, indices)``, or ``None`` when *dataset* is ``None``."""
    return None if dataset is None else Subset(dataset, indices)


def _class_split(targets, class_to_replace):
    """Return ``(forget, retain)`` positions whose label equals / differs from *class_to_replace*."""
    labels = np.array(targets)
    return (
        np.flatnonzero(labels == class_to_replace),
        np.flatnonzero(labels != class_to_replace),
    )


def _random_split(len_train, percent, seed):
    """Draw ``int(len_train * percent / 100)`` distinct forget positions; the rest are retained.

    Returns ``(forget, retain)``; *retain* is in ascending order.
    """
    rng = np.random.RandomState(seed)
    # replace=False: sample without replacement, so no position is drawn twice.
    forget = rng.choice(
        range(len_train), size=int(len_train * (percent / 100)), replace=False
    )
    all_index = np.arange(len_train)
    retain = all_index[~np.isin(all_index, forget)]
    return forget, retain


def _mirror_lists(source, pairs):
    """Copy rows of ``source.targets`` / ``source.data`` onto subsets as Python lists.

    *pairs* is an iterable of ``(subset, indices)``; each subset receives
    ``.targets`` and ``.data`` holding the rows of *source* at *indices*.
    The numpy conversion of *source* is done once for all pairs.
    """
    targets = np.array(source.targets)
    data = np.array(source.data)
    for subset, indices in pairs:
        subset.targets = targets[indices].tolist()
        subset.data = data[indices].tolist()


def _stack_images(samples, transform):
    """Open every ``(path, _)`` of *samples*, apply *transform*, and stack the tensors."""
    from PIL import Image

    tensors = []
    for path, _ in samples:
        # Close each file deterministically instead of waiting for GC.
        with Image.open(path) as img:
            tensors.append(transform(img))
    return torch.stack(tensors)


def _imagenet_transforms():
    """Return ``(train_transform, eval_transform)`` producing 224x224 RGB tensors."""
    to_rgb = transforms.Lambda(lambda img: img.convert("RGB"))
    train_tf = transforms.Compose(
        [
            to_rgb,
            transforms.RandomResizedCrop((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    eval_tf = transforms.Compose(
        [
            to_rgb,
            transforms.Resize((256, 256)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
        ]
    )
    return train_tf, eval_tf


def _cifar10_transforms(arch):
    """Return ``(train_transform, test_transform)`` for CIFAR-10 (224px for ViT, native size otherwise)."""
    if arch == "vit":
        train_tf = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.RandomCrop(224, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        test_tf = transforms.Compose(
            [transforms.Resize((224, 224)), transforms.ToTensor()]
        )
    else:
        train_tf = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        test_tf = transforms.Compose([transforms.ToTensor()])
    return train_tf, test_tf


def _load_shifted_sets(adv, cor, cor_type, level, eval_split, transform):
    """Load the ``adv`` or ``cor`` train/eval ImageFolders, or ``(None, None)``.

    When neither root is given, prints ``Wrong`` (as before) and returns
    ``(None, None)``.
    """
    if adv is not None:
        train_dir = os.path.join(adv, "train_same_folder")
        eval_dir = os.path.join(adv, eval_split + "_same_folder")
    elif cor is not None:
        train_dir = os.path.join(cor, "train", cor_type, level)
        eval_dir = os.path.join(cor, eval_split, cor_type, level)
    else:
        print("Wrong")
        return None, None
    return (
        datasets.ImageFolder(train_dir, transform=transform),
        datasets.ImageFolder(eval_dir, transform=transform),
    )


def _hf_image_transform(transform):
    """Wrap a per-image *transform* as a batched Hugging Face ``set_transform`` callable."""

    def apply(examples):
        examples["image"] = [transform(img) for img in examples["image"]]
        return examples

    return apply


def _prepare_imagenet(data_path, batch_size, shuffle):
    """Load ImageNet-1k from the Hugging Face hub and wrap both splits in DataLoaders."""
    cache_dir = os.path.join(data_path, "huggingface")
    train_set = load_dataset(
        "imagenet-1k", use_auth_token=True, split="train", cache_dir=cache_dir
    )
    validation_set = load_dataset(
        "imagenet-1k", use_auth_token=True, split="validation", cache_dir=cache_dir
    )
    train_tf, eval_tf = _imagenet_transforms()
    train_set.set_transform(_hf_image_transform(train_tf))
    validation_set.set_transform(_hf_image_transform(eval_tf))
    return {
        "train": DataLoader(
            train_set, batch_size=batch_size, num_workers=4, shuffle=shuffle
        ),
        "val": DataLoader(
            validation_set, batch_size=batch_size, num_workers=4, shuffle=False
        ),
    }


def prepare_data(
    dataset,
    batch_size=512,
    shuffle=True,
    class_to_replace=None,
    indexes_to_replace=None,
    num_indexes_to_replace=None,
    train_subset_indices=None,
    val_subset_indices=None,
    only_mark: bool = True,
    seed: int = 0,
    adv: str = None,
    cor: str = None,
    cor_type: str = None,
    level: str = None,
    single: int = 0,
    phase: str = "train",
    unlearn: str = "retrain",
    data_path="...",
    percent: int = 10,
    arch: str = "resnet18",
    multi_classes_to_replace: int = 0,
    num_classes: int = 10,
):
    """Build the dataset splits used by the unlearning experiments.

    For ``"imagenet10"`` and ``"cifar10"`` the images are read with
    ``ImageFolder`` from ``data_path/train`` and ``data_path/validation`` (or
    ``data_path/test`` for CIFAR-10). Three modes are supported:

    * no replacement (both ``class_to_replace`` and ``num_indexes_to_replace``
      are ``None``): returns ``{"train", "train_for_test", "val"}``;
    * class-wise (``class_to_replace`` set): samples of that class form the
      forget sets; returns retain/forget splits of train, train-for-test and
      val, plus the matching splits of the ``adv``/``cor`` sets;
    * index-wise (``num_indexes_to_replace`` set): a random ``percent`` % of
      the training set (seeded by *seed*) is forgotten. Only the presence of
      ``num_indexes_to_replace`` matters, not its value.

    ``"*_for_test"`` sets hold the training images under the evaluation
    transform. For CIFAR-10, every image is additionally decoded up front into
    ``.data``, and the retain/forget subsets get ``.targets``/``.data`` copied
    as Python lists.

    For ``"imagenet"`` (Hugging Face ``imagenet-1k``) only the no-replacement
    mode exists; it returns ``{"train", "val"}`` DataLoaders built from
    *batch_size* and *shuffle*.

    ``adv`` / ``cor`` are roots of alternative image trees (``train_same_folder``
    / ``<eval>_same_folder`` for ``adv``; ``<split>/<cor_type>/<level>`` for
    ``cor``). When neither is given, their entries in the returned dict are
    ``None``.

    The parameters ``indexes_to_replace``, ``train_subset_indices``,
    ``val_subset_indices``, ``only_mark``, ``single``, ``phase``, ``unlearn``,
    ``multi_classes_to_replace`` and ``num_classes`` are accepted for
    compatibility but currently unused.

    Raises:
        NotImplementedError: unknown *dataset*, or replacement on ``"imagenet"``.
        ValueError: both ``class_to_replace`` and ``num_indexes_to_replace``
            given, or an unexpected training-set size in index-wise mode.
    """
    if dataset not in ("imagenet", "imagenet10", "cifar10"):
        raise NotImplementedError
    # Validate arguments before any (expensive) dataset loading.
    if class_to_replace is not None and num_indexes_to_replace is not None:
        raise ValueError(
            "Only one of `class_to_replace` and `indexes_to_replace` can be specified"
        )

    if dataset == "imagenet":
        if class_to_replace is not None or num_indexes_to_replace is not None:
            raise NotImplementedError(
                "class-wise / index-wise forgetting is not implemented for 'imagenet'"
            )
        return _prepare_imagenet(data_path, batch_size, shuffle)

    is_cifar = dataset == "cifar10"
    if is_cifar:
        train_tf, eval_tf = _cifar10_transforms(arch)
        eval_split, expected_len = "test", 50000
    else:
        train_tf, eval_tf = _imagenet_transforms()
        eval_split, expected_len = "validation", 13000

    train_dir = os.path.join(data_path, "train")
    train_set = datasets.ImageFolder(train_dir, transform=train_tf)
    train_set_for_test = datasets.ImageFolder(train_dir, transform=eval_tf)
    val_set = datasets.ImageFolder(
        os.path.join(data_path, eval_split), transform=eval_tf
    )

    if is_cifar:
        # Decode every image once into a stacked tensor stored as `.data`
        # (copied onto the subsets below). NOTE: the random train augmentation
        # is therefore sampled a single time here, not per epoch.
        train_set.data = _stack_images(train_set.samples, train_tf)
        train_set_for_test.data = _stack_images(train_set_for_test.samples, eval_tf)
        val_set.data = _stack_images(val_set.samples, eval_tf)

    train_set_adv, val_set_adv = _load_shifted_sets(
        adv, cor, cor_type, level, eval_split, eval_tf
    )

    if class_to_replace is None and num_indexes_to_replace is None:
        return {
            "train": train_set,
            "train_for_test": train_set_for_test,
            "val": val_set,
        }

    if class_to_replace is not None:
        forget_idx, retain_idx = _class_split(train_set.targets, class_to_replace)
        forget_idx_test, retain_idx_test = _class_split(
            train_set_for_test.targets, class_to_replace
        )
        forget_idx_val, retain_idx_val = _class_split(val_set.targets, class_to_replace)

        retain_set = Subset(train_set, retain_idx)
        forget_set = Subset(train_set, forget_idx)
        retain_set_for_test = Subset(train_set_for_test, retain_idx_test)
        forget_set_for_test = Subset(train_set_for_test, forget_idx_test)
        retain_set_val = Subset(val_set, retain_idx_val)
        forget_set_val = Subset(val_set, forget_idx_val)

        if is_cifar:
            _mirror_lists(train_set, [(retain_set, retain_idx), (forget_set, forget_idx)])
            _mirror_lists(
                train_set_for_test,
                [(forget_set_for_test, forget_idx_test), (retain_set_for_test, retain_idx_test)],
            )
            _mirror_lists(
                val_set,
                [(retain_set_val, retain_idx_val), (forget_set_val, forget_idx_val)],
            )

        return {
            "retain": retain_set,
            "forget": forget_set,
            "retain_for_test": retain_set_for_test,
            "forget_for_test": forget_set_for_test,
            "val": val_set,
            "val_retain": retain_set_val,
            "val_forget": forget_set_val,
            "retain_adv": _subset_or_none(train_set_adv, retain_idx),
            "forget_adv": _subset_or_none(train_set_adv, forget_idx),
            "val_adv": val_set_adv,
            "val_retain_adv": _subset_or_none(val_set_adv, retain_idx_val),
            "val_forget_adv": _subset_or_none(val_set_adv, forget_idx_val),
        }

    # Index-wise forgetting: a seeded random `percent` % of the training set.
    len_train = len(train_set)
    if len_train != expected_len:
        raise ValueError(
            f"The number for data-wise of {dataset} setting is wrong: "
            f"expected {expected_len} training samples, found {len_train}."
        )
    forget_idx, retain_idx = _random_split(len_train, percent, seed)

    retain_set = Subset(train_set, retain_idx)
    forget_set = Subset(train_set, forget_idx)
    forget_set_for_test = Subset(train_set_for_test, forget_idx)
    retain_set_for_test = Subset(train_set_for_test, retain_idx)

    if is_cifar:
        _mirror_lists(train_set, [(retain_set, retain_idx), (forget_set, forget_idx)])
        _mirror_lists(
            train_set_for_test,
            [(forget_set_for_test, forget_idx), (retain_set_for_test, retain_idx)],
        )
        # Kept from the original implementation: val_set's own attributes are
        # converted to Python lists in this mode.
        val_set.targets = np.array(val_set.targets).tolist()
        val_set.data = np.array(val_set.data).tolist()

    return {
        "retain": retain_set,
        "forget": forget_set,
        "retain_for_test": retain_set_for_test,
        "forget_for_test": forget_set_for_test,
        "val": val_set,
        "retain_adv": _subset_or_none(train_set_adv, retain_idx),
        "forget_adv": _subset_or_none(train_set_adv, forget_idx),
        "val_adv": val_set_adv,
    }
        
    
    


def get_x_y_from_data_dict(data, device):
    """Unpack a two-entry batch mapping into ``(inputs, labels)`` on *device*.

    *data* must have exactly two values, inputs first and labels second. When
    the inputs are a list, only the first element of each list is used.
    """
    inputs, labels = data.values()
    if isinstance(inputs, list):
        inputs, labels = inputs[0], labels[0]
    return inputs.to(device), labels.to(device)


if __name__ == "__main__":
    # Collect the label of every ImageNet sample (validation split first, then
    # train) and save each split's labels as a LongTensor.
    ys = {"train": [], "val": []}
    loaders = prepare_data(dataset="imagenet", batch_size=1, shuffle=False)
    for split in ("val", "train"):
        for batch in tqdm(loaders[split], ncols=100):
            _, label = get_x_y_from_data_dict(batch, "cpu")
            ys[split].append(label.item())
    for split in ("train", "val"):
        ys[split] = torch.Tensor(ys[split]).long()
    torch.save(ys["train"], "train_ys.pth")
    torch.save(ys["val"], "val_ys.pth")
