import os
import json
import random
import PIL.Image as Image
from copy import deepcopy
from tqdm import tqdm

from torch.utils.data import Dataset
from typing import *
from torchvision.transforms import ToPILImage

import torch
import numpy as np
from numpy.testing import assert_array_almost_equal

def unpickle(file):
    """Load and return the object stored in a pickle file.

    ``encoding='latin1'`` is the setting the ``pickle`` docs recommend for
    reading Python 2 pickles that contain NumPy arrays.

    Security: unpickling can execute arbitrary code -- only call this on
    files from a trusted source.

    :param file: path of the pickle file.
    :returns: the unpickled object.
    """
    import pickle  # the C accelerator (_pickle) is used automatically

    with open(file, 'rb') as fo:
        loaded = pickle.load(fo, encoding='latin1')
    return loaded

def _img_loader(path, mode='RGB'):
    assert mode in ['RGB', 'L']
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert(mode)

def _make_dataset(image_dir):
    """Collect ``(image_path, class_idx)`` pairs from a class-per-folder tree.

    Every immediate sub-directory of ``image_dir`` is a class; all files found
    (recursively) beneath it get that class's index.  Classes, directories and
    file names are visited in sorted order, so the result is deterministic.
    """
    class_names, class_indices = _find_classes(image_dir)

    samples = []
    for name in sorted(class_names):
        class_dir = os.path.join(image_dir, name)
        if not os.path.isdir(class_dir):
            continue
        label = class_indices[name]
        for dirpath, _subdirs, filenames in sorted(os.walk(class_dir)):
            samples.extend(
                (os.path.join(dirpath, fname), label) for fname in sorted(filenames)
            )
    return samples

def _find_classes(root):
    class_names = [d.name for d in os.scandir(root) if d.is_dir()]
    class_names.sort()
    classes_indices = {class_names[i]: i for i in range(len(class_names))}
    # print(classes_indices)
    return class_names, classes_indices  # 'class_name':index

class ImageDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``image_dir`` per class.

    Samples are the ``(image_path, class_idx)`` pairs built by
    ``_make_dataset``.  Items are ``(image, class_idx)``, where the image is
    loaded as RGB and passed through ``transform`` when one is given.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.samples = _make_dataset(self.image_dir)
        # Labels only, aligned index-by-index with self.samples.
        self.targets = [s[1] for s in self.samples]

    def __getitem__(self, index):
        image_path, target = self.samples[index]
        image = _img_loader(image_path, mode='RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, target

    def __len__(self):
        return len(self.samples)

    

class ImageFileDataset(Dataset):
    """Dataset described by a ``<file_dir>/<data_type>.txt`` list file.

    Each non-blank line holds ``<image_path> <label>`` separated by
    whitespace.  When ``path_prefix`` is given it is joined in front of every
    image path.  Items are ``(image, label)``; images are loaded as RGB and
    passed through ``transform`` when one is given.
    """

    def __init__(self, file_dir, data_type, path_prefix, transform=None):
        # NOTE: despite the name, this is the path of the list *file*.
        self.file_dir = os.path.join(file_dir, f"{data_type}.txt")
        self.samples = self._load_file(self.file_dir, path_prefix)
        self.transform = transform

    def __getitem__(self, index):
        image_path, target = self.samples[index]
        image = _img_loader(image_path, mode='RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, target

    def __len__(self):
        return len(self.samples)

    def _load_file(self, file_path, path_prefix):
        """Parse the list file into ``[(image_path, int_label), ...]``.

        Blank / whitespace-only lines (e.g. a trailing empty line) are skipped.
        """
        samples = []
        with open(file_path) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                image_path, label = fields[0], int(fields[1])
                if path_prefix is not None:
                    image_path = os.path.join(path_prefix, image_path)
                samples.append((image_path, label))

        return samples
    
class NoiseLabelDataset(Dataset):
    """CIFAR-10 / CIFAR-100 training set with (optionally) noisy labels.

    Noisy labels come from, in order of precedence:

    1. ``noise_file`` -- when it names an existing JSON file, the label list
       stored there is reused regardless of ``noise_mode``;
    2. ``noise_mode='symmetric'`` -- a random fraction ``r`` of samples gets a
       uniformly random class (which may equal the true one);
    3. ``noise_mode='asymmetric'`` -- CIFAR-10: a random fraction ``r`` of
       samples is mapped through ``self.transition``; CIFAR-100: inside each
       block of 5 consecutive class ids, one random pair of classes is swapped
       with probability ``r``;
    4. ``noise_mode='cifarn'`` -- labels loaded by :meth:`load_label`
       (``noise_type='clean'`` keeps the original labels).

    Synthetic labels (2, 3) are written to ``noise_file`` when it is non-empty.
    Items are ``(image, noisy_label)``; the image is a PIL image built from the
    stored HWC array and passed through ``transform`` when one is given.

    Attributes: ``train_data``, ``train_labels`` (clean),
    ``train_noisy_labels``, ``noise_or_not`` (bool mask, True where the noisy
    label differs from the clean one), ``actual_noise_rate`` and, for cifarn
    with a non-clean ``noise_type``, ``noise_prior``.
    """

    def __init__(self, dataset, noise_mode, noise_type, noise_path, root_dir, transform, noise_file='', r=0.2):
        """
        :param dataset: ``'cifar10'`` or ``'cifar100'``.
        :param noise_mode: ``'symmetric'``, ``'asymmetric'`` or ``'cifarn'``.
        :param noise_type: key of the label array inside ``noise_path``, or
            ``'clean'``; only used when ``noise_mode='cifarn'``.
        :param noise_path: ``torch.save`` file with human labels (cifarn only).
        :param root_dir: directory holding the extracted CIFAR python archives.
        :param transform: optional callable applied to each PIL image.
        :param noise_file: JSON cache of noisy labels -- read if it exists,
            otherwise written after synthetic noise is generated ('' or None
            disables saving).
        :param r: noise rate for synthetic noise.
        :raises ValueError: unknown ``noise_mode`` and no cached ``noise_file``.
        """
        assert dataset in ['cifar10', 'cifar100']
        self.dataset = dataset
        self.transform = transform
        self.noise_type = noise_type
        self.noise_path = noise_path
        self.noise_mode = noise_mode
        self.r = r
        # class transition (clean -> noisy) for cifar10 asymmetric noise
        self.transition = {0: 0, 2: 0, 4: 7, 7: 7, 1: 1, 9: 1, 3: 5, 5: 3, 6: 6, 8: 8}
        self.num_classes = 10 if dataset == 'cifar10' else 100

        self.train_data, self.train_labels = self._load_cifar_train(dataset, root_dir)
        self.train_noisy_labels = None

        if noise_file and os.path.exists(noise_file):
            # Reuse previously generated / saved noisy labels.
            with open(noise_file, "r") as f:
                self.train_noisy_labels = json.load(f)
        elif noise_mode == 'symmetric':
            self.train_noisy_labels = self._symmetric_noise()
            self._save_noise_labels(self.train_noisy_labels, noise_file)
        elif noise_mode == 'asymmetric':
            self.train_noisy_labels = self._asymmetric_noise()
            self._save_noise_labels(self.train_noisy_labels, noise_file)
        elif noise_mode == 'cifarn':
            if noise_type == 'clean':
                self.train_noisy_labels = list(self.train_labels)
            else:
                # Load human noisy labels
                self.train_noisy_labels = self.load_label().tolist()
                # Count over *all* classes (CIFAR-100 included).
                class_size_noisy = np.bincount(self.train_noisy_labels, minlength=self.num_classes)
                self.noise_prior = class_size_noisy / class_size_noisy.sum()
                print(f'The noisy data ratio in each class is {self.noise_prior}')
        else:
            raise ValueError(f"unknown noise_mode {noise_mode!r}")

        self.noise_or_not = np.asarray(self.train_noisy_labels) != np.asarray(self.train_labels)
        self.actual_noise_rate = np.sum(self.noise_or_not) / len(self.train_labels)
        print('over all noise rate is ', self.actual_noise_rate)

    @staticmethod
    def _load_cifar_train(dataset, root_dir):
        """Read the CIFAR training split; return ``(images as N x 32 x 32 x 3, labels list)``."""
        if dataset == 'cifar10':
            batches = [
                unpickle('%s/cifar-10-batches-py/data_batch_%d' % (root_dir, k))
                for k in range(1, 6)
            ]
            train_data = np.concatenate([batch['data'] for batch in batches])
            train_label = [label for batch in batches for label in batch['labels']]
        else:
            train_dic = unpickle('%s/cifar-100-python/train' % root_dir)
            train_data = train_dic['data']
            train_label = train_dic['fine_labels']
        # Flat rows -> (N, C, H, W) -> (N, H, W, C) as expected by Image.fromarray.
        train_data = train_data.reshape((-1, 3, 32, 32)).transpose((0, 2, 3, 1))
        return train_data, train_label

    def _random_noise_indices(self):
        """Pick ``int(r * N)`` distinct sample indices using ``random.shuffle``."""
        num_samples = len(self.train_labels)
        idx = list(range(num_samples))
        random.shuffle(idx)
        # set -> O(1) membership tests in the per-sample loops below
        return set(idx[:int(self.r * num_samples)])

    def _symmetric_noise(self):
        """Give a random fraction ``r`` of samples a uniformly random class."""
        noise_idx = self._random_noise_indices()
        return [
            random.randint(0, self.num_classes - 1) if i in noise_idx else label
            for i, label in enumerate(self.train_labels)
        ]

    def _asymmetric_noise(self):
        """Class-dependent noise; see the class docstring for the scheme."""
        if self.dataset == 'cifar10':
            noise_idx = self._random_noise_indices()
            return [
                self.transition[label] if i in noise_idx else label
                for i, label in enumerate(self.train_labels)
            ]

        if self.r <= 0.0:
            # No noise requested: keep the clean labels.
            return list(self.train_labels)

        # Block-diagonal transition matrix over blocks of 5 consecutive class
        # ids ("superclasses" in the original code).
        P = np.eye(self.num_classes)
        nb_superclasses = 20
        nb_subclasses = 5
        for i in range(nb_superclasses):
            init, end = i * nb_subclasses, (i + 1) * nb_subclasses
            P[init:end, init:end] = self._build_for_cifar100(nb_subclasses, self.r)

        noise_label = self._multiclass_noisify(np.array(self.train_labels), P=P,
                                               random_state=0).tolist()
        actual_noise = (np.array(noise_label) != np.array(self.train_labels)).mean()
        assert actual_noise > 0.0
        return noise_label

    @staticmethod
    def _save_noise_labels(noise_label, noise_file):
        """Write ``noise_label`` as JSON to ``noise_file``; no-op if it is empty."""
        if not noise_file:
            return
        print("save noisy labels to %s ..." % noise_file)
        directory = os.path.dirname(noise_file)
        if directory:  # dirname is '' for a bare file name in the cwd
            os.makedirs(directory, exist_ok=True)
        with open(noise_file, "w") as f:
            json.dump(noise_label, f)

    def load_label(self):
        """Load the ``self.noise_type`` label array from ``self.noise_path``.

        If the file also holds ``'clean_label'``, it must match
        ``self.train_labels`` exactly.

        :returns: the noisy labels, flattened with ``reshape(-1)``.
        :raises ValueError: ``clean_label`` disagrees with ``self.train_labels``.
        :raises TypeError: the file does not contain a dict.
        """
        # NOTE only load manual training label
        # The file may hold non-tensor objects (e.g. numpy arrays), which the
        # weights-only loader (default since torch 2.6) refuses.  This is a
        # full pickle load: only use label files from a trusted source.
        try:
            noise_label = torch.load(self.noise_path, weights_only=False)
        except TypeError:  # torch < 1.13 has no `weights_only` argument
            noise_label = torch.load(self.noise_path)

        if not isinstance(noise_label, dict):
            raise TypeError('Input Error')

        noisy = noise_label[self.noise_type]
        if "clean_label" in noise_label:
            clean_label = np.asarray(noise_label['clean_label']).reshape(-1)
            # Element-wise check: a zero *sum* of differences can hide mismatches.
            if not np.array_equal(clean_label, np.asarray(self.train_labels)):
                raise ValueError(
                    f"clean_label in {self.noise_path} does not match the CIFAR training labels"
                )
            print(f'Loaded {self.noise_type} from {self.noise_path}.')
            print(f'The overall noise rate is {1 - np.mean(clean_label == np.asarray(noisy).reshape(-1))}')
        return noisy.reshape(-1)

    def _multiclass_noisify(self, y, P, random_state=0):
        """ Flip classes according to transition probability matrix T.
        It expects a number between 0 and the number of classes - 1.

        :param y: 1-D integer label array (not modified).
        :param P: row-stochastic matrix; row ``i`` is the distribution of the
            new label for true label ``i``.
        :param random_state: seed of the private ``RandomState`` used.
        :returns: a new label array.
        """

        assert P.shape[0] == P.shape[1]
        assert np.max(y) < P.shape[0]

        # row stochastic matrix
        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
        assert (P >= 0.0).all()

        m = y.shape[0]
        new_y = y.copy()
        flipper = np.random.RandomState(random_state)

        for idx in np.arange(m):
            i = y[idx]
            # draw a vector with only an 1
            flipped = flipper.multinomial(1, P[i, :], 1)[0]
            # take the scalar position (assigning a size-1 array into a scalar
            # slot is deprecated since NumPy 1.25)
            new_y[idx] = np.flatnonzero(flipped)[0]

        return new_y

    def _build_for_cifar100(self, size, noise):
        """ random flip between two random classes.

        Returns a ``size x size`` identity matrix in which one random pair of
        classes (drawn from the global NumPy RNG) swaps with probability ``noise``.
        """
        assert(noise >= 0.) and (noise <= 1.)

        P = np.eye(size)
        cls1, cls2 = np.random.choice(range(size), size=2, replace=False)
        P[cls1, cls2] = noise
        P[cls2, cls1] = noise
        P[cls1, cls1] = 1.0 - noise
        P[cls2, cls2] = 1.0 - noise

        assert_array_almost_equal(P.sum(axis=1), 1, 1)
        return P

    def __getitem__(self, index):
        image = Image.fromarray(self.train_data[index])
        target = self.train_noisy_labels[index]

        if self.transform is not None:
            image = self.transform(image)

        return image, target

    def __len__(self):
        return len(self.train_labels)



class poisonedCLSDataContainer:
    '''
    Key -> sample store for poisoned classification samples.

    Two modes:
        in RAM (save_folder_path is None):
            data_dict = {key: value}
        on disk (save_folder_path given):
            value must be ``(img, *other_info)`` where ``img`` is a PIL image.
            The image is saved to
            ``{save_folder_path}/{relative_loc_to_save_folder_name}/{key}{save_file_format}``
            and only
                data_dict = {key: {"path": path, "other_info": other_info}}
            is kept in memory.  Reading re-opens the file and returns
            ``(img, *other_info)``.
    '''
    def __init__(self, save_folder_path=None, save_file_format=".png"):
        self.save_folder_path = save_folder_path
        self.data_dict = {}
        self.save_file_format = save_file_format
        # Optional directory that stored paths are re-rooted under when read
        # (normally supplied through set_state); None = use paths as stored.
        self.root = None

    def retrieve_state(self):
        return {
            "save_folder_path": self.save_folder_path,
            "data_dict": self.data_dict,
            "save_file_format": self.save_file_format,
        }

    def set_state(self, state_file, root=None):
        self.save_folder_path = state_file["save_folder_path"]
        self.data_dict = state_file["data_dict"]
        self.save_file_format = state_file["save_file_format"]
        self.root = root

    def setitem(self, key, value, relative_loc_to_save_folder_name=None):
        '''
        Store ``value`` under ``key`` (see the class docstring for the layout).

        In disk mode the sub-folder is created if needed and the image is
        written with ``img.save``.
        '''
        if self.save_folder_path is None:
            self.data_dict[key] = value
        else:
            img, *other_info = value

            save_subfolder_path = f"{self.save_folder_path}/{relative_loc_to_save_folder_name}"
            os.makedirs(save_subfolder_path, exist_ok=True)

            file_path = f"{save_subfolder_path}/{key}{self.save_file_format}"
            img.save(file_path)

            self.data_dict[key] = {
                "path": file_path,
                "other_info": other_info,
            }

    def __getitem__(self, key):
        if self.save_folder_path is None:
            return self.data_dict[key]
        else:
            file_path = self.data_dict[key]["path"]
            if self.root:
                # NOTE: lstrip('./') strips *every* leading '.' and '/'
                # character, e.g. '../x' -> 'x' and '/abs/x' -> 'abs/x',
                # not just a './' prefix.
                file_path = os.path.join(self.root, file_path.lstrip('./'))

            other_info = self.data_dict[key]["other_info"]
            img = Image.open(file_path)
            try:
                # Copy before closing so the returned image stays usable.
                im = deepcopy(img)
            finally:
                img.close()
            return (im, *other_info)

    def __len__(self):
        return len(self.data_dict)
    
class prepro_cls_DatasetBD_v2(torch.utils.data.Dataset):
    '''
    Classification dataset in which a subset of samples is replaced by
    pre-computed backdoor (bd) samples held in a poisonedCLSDataContainer.

    Clean samples are read from the wrapped dataset on demand; poisoned ones
    come from the container.  ``getitem_all`` / ``getitem_all_switch`` select
    the item layout (see ``__getitem__``).
    '''

    def __init__(
            self,
            full_dataset_without_transform,
            poison_indicator: Optional[Sequence] = None,  # one-hot to determine which image may take bd_transform

            bd_image_pre_transform: Optional[Callable] = None,
            bd_label_pre_transform: Optional[Callable] = None,
            save_folder_path = None,

            mode = 'attack',
        ):
        '''
        This class requires poisonedCLSDataContainer.

        :param full_dataset_without_transform: dataset without any transform. (just raw data)

        :param poison_indicator:
            array with 0 or 1 at each position corresponding to clean/poisoned
            Must have the same len as given full_dataset_without_transform (default None, regarded as all 0s)

        :param bd_image_pre_transform:
        :param bd_label_pre_transform:
        ( if your backdoor method is really complicated, then do not set these two params. These are for simplicity.
        You can inherit the class and rewrite method preprocess part as you like)

        :param save_folder_path:
            This is for the case to save the poisoned imgs on disk.
            (In case, your RAM may not be able to hold all poisoned imgs.)
            If you do not want this feature for small dataset, then just leave it as default, None.

        :param mode: stored as ``self.mode``; not read anywhere in this class.
        '''

        self.dataset = full_dataset_without_transform

        if poison_indicator is None:
            poison_indicator = np.zeros(len(full_dataset_without_transform))
        self.poison_indicator = poison_indicator

        assert len(full_dataset_without_transform) == len(poison_indicator)

        self.bd_image_pre_transform = bd_image_pre_transform
        self.bd_label_pre_transform = bd_label_pre_transform

        self.save_folder_path = save_folder_path # since when we want to save this dataset, this may cause problem

        self.original_index_array = np.arange(len(full_dataset_without_transform))

        self.bd_data_container = poisonedCLSDataContainer(self.save_folder_path, ".png")

        if sum(self.poison_indicator) >= 1:
            self.prepro_backdoor()

        self.getitem_all = True
        self.getitem_all_switch = False

        self.mode = mode

    def prepro_backdoor(self):
        '''Apply the bd pre-transforms to every sample flagged in poison_indicator.'''
        for selected_index in tqdm(self.original_index_array, desc="prepro_backdoor"):
            if self.poison_indicator[selected_index] == 1:
                img, label = self.dataset[selected_index]
                img = self.bd_image_pre_transform(img, target=label, image_serial_id=selected_index)
                bd_label = self.bd_label_pre_transform(label)
                self.set_one_bd_sample(
                    selected_index, img, bd_label, label
                )

    def set_one_bd_sample(self, selected_index, img, bd_label, label):

        '''
        1. To pil image
        2. set the image to container
        3. change the poison_index.

        logic is that no matter by the prepro_backdoor or not, after we set the bd sample,
        This method will automatically change the poison index to 1.

        :param selected_index: The index of bd sample
        :param img: The converted img that want to put in the bd_container
        :param bd_label: The label bd_sample has
        :param label: The original label bd_sample has

        '''

        # we need to save the bd img, so we turn it into PIL
        if (not isinstance(img, Image.Image)) :
            if isinstance(img, np.ndarray):
                img = img.astype(np.uint8)
            img = ToPILImage()(img)

        self.bd_data_container.setitem(
            key=selected_index,
            value=(img, bd_label, label),
            relative_loc_to_save_folder_name=f"{label}",
        )
        self.poison_indicator[selected_index] = 1

    def __len__(self):
        return len(self.original_index_array)

    def __getitem__(self, index):
        '''
        Return the sample at position ``index`` of the (possibly subset) view.

        - getitem_all False: ``(img, label)``
        - getitem_all True, getitem_all_switch False:
          ``(img, label, original_index, poison_or_not, original_target)``
        - getitem_all True, getitem_all_switch True:
          ``(img, original_target, original_index, poison_or_not, label)``

        ``label`` is the bd label for poisoned samples and the clean label
        otherwise; images are always returned as PIL images.
        '''

        original_index = self.original_index_array[index]
        if self.poison_indicator[original_index] == 0:
            # clean
            img, label = self.dataset[original_index]
            original_target = label
            poison_or_not = 0
        else:
            # bd
            img, label, original_target = self.bd_data_container[original_index]
            poison_or_not = 1

        if not isinstance(img, Image.Image):
            img = ToPILImage()(img)

        if self.getitem_all:
            if self.getitem_all_switch:
                # this is for the case that you want original targets, but you do not want change your testing process
                return img, \
                       original_target, \
                       original_index, \
                       poison_or_not, \
                       label

            else: # here should corresponding to the order in the bd trainer
                return img, \
                       label, \
                       original_index, \
                       poison_or_not, \
                       original_target
        else:
            return img, label

    def subset(self, chosen_index_list):
        '''Restrict the visible samples to positions ``chosen_index_list`` of the current view.'''
        self.original_index_array = self.original_index_array[chosen_index_list]

    def retrieve_state(self):
        return {
            "bd_data_container" : self.bd_data_container.retrieve_state(),
            "getitem_all":self.getitem_all,
            "getitem_all_switch":self.getitem_all_switch,
            "original_index_array": self.original_index_array,
            "poison_indicator": self.poison_indicator,
            "save_folder_path": self.save_folder_path,
        }

    def copy(self):
        '''
        Return an independent copy: the wrapped ``self.dataset`` is shared,
        the poison bookkeeping (indicator, index array, container state) is
        deep-copied.
        '''
        bd_train_dataset = prepro_cls_DatasetBD_v2(self.dataset)
        copy_state = deepcopy(self.retrieve_state())
        # Keep resolving stored image paths against the same root as self;
        # getattr because a container built in __init__ may lack `root`.
        bd_train_dataset.set_state(
            copy_state,
            getattr(self.bd_data_container, 'root', None),
        )
        # Not part of retrieve_state(); carry them over explicitly.
        bd_train_dataset.mode = self.mode
        bd_train_dataset.bd_image_pre_transform = self.bd_image_pre_transform
        bd_train_dataset.bd_label_pre_transform = self.bd_label_pre_transform
        return bd_train_dataset

    def set_state(self, state_file, root=None):
        '''
        Restore from a ``retrieve_state()`` dict.

        :param root: optional directory that the container's stored image
            paths are re-rooted under (None = use them as stored).
        '''
        self.bd_data_container = poisonedCLSDataContainer()
        self.bd_data_container.set_state(
            state_file['bd_data_container'],
            root
        )
        self.getitem_all = state_file['getitem_all']
        self.getitem_all_switch = state_file['getitem_all_switch']
        self.original_index_array = state_file["original_index_array"]
        self.poison_indicator = state_file["poison_indicator"]
        self.save_folder_path = state_file["save_folder_path"]

class dataset_wrapper_with_transform(torch.utils.data.Dataset):
    '''
    Wrap a dataset and apply extra image / label transforms on item access.

    Items of the wrapped dataset must be ``(img, label, *other_info)``; the
    wrapper returns ``(wrap_img_transform(img), wrap_label_transform(label),
    *other_info)`` (a transform that is None is skipped).  Unknown attribute
    lookups are forwarded to the wrapped dataset.

    idea from https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python
    '''

    def __init__(self, obj, wrap_img_transform=None, wrap_label_transform=None):

        # this wrapper should NEVER be wrapped twice,
        # since the attribute names would clash.
        assert "wrap_img_transform" not in obj.__dict__
        assert "wrap_label_transform" not in obj.__dict__

        self.wrapped_dataset = obj
        self.wrap_img_transform = wrap_img_transform
        self.wrap_label_transform = wrap_label_transform

    def __getattr__(self, attr):
        # Only called when normal lookup fails.  Read `wrapped_dataset` straight
        # from __dict__: on an instance created without __init__ (unpickling,
        # copying) `self.wrapped_dataset` would re-enter __getattr__ forever.
        # https://stackoverflow.com/questions/47299243/recursionerror-when-python-copy-deepcopy
        try:
            wrapped = self.__dict__['wrapped_dataset']
        except KeyError:
            raise AttributeError(attr) from None
        return getattr(wrapped, attr)

    def __getitem__(self, index):
        img, label, *other_info = self.wrapped_dataset[index]
        if self.wrap_img_transform is not None:
            img = self.wrap_img_transform(img)
        if self.wrap_label_transform is not None:
            label = self.wrap_label_transform(label)
        return (img, label, *other_info)

    def __len__(self):
        return len(self.wrapped_dataset)

    def __deepcopy__(self, memo):
        # Build the copy through __init__ so every attribute exists before
        # __getattr__ can be consulted; forward `memo` so objects shared between
        # the parts are copied only once.
        return dataset_wrapper_with_transform(
            deepcopy(self.wrapped_dataset, memo),
            deepcopy(self.wrap_img_transform, memo),
            deepcopy(self.wrap_label_transform, memo),
        )
        
if __name__ == "__main__":
    # Scratch entry point for manually smoke-testing the datasets above.
    # The paths are machine-specific; uncomment the construction to try.
    data_dir, data_type, path_prefix = (
        '/nfs196/wjx/datasets/Imagenet-lt',
        'test',
        '/nfs196/hjc/datasets/ILSVRC2012',
    )
    # data_name = 'cifar100-cifarn-noisy_label'

    # if data_type == 'train':
    #     dataset, noise_mode, noise_type = data_name.split('-')
    #     if noise_mode in ['symmetric', 'asymmetric']:
    #         noise_rate = float(noise_type)
    #         noise_type = None
    #         noise_path = None
    #     elif noise_mode == 'cifarn':
    #         noise_rate = 0.0
    #         noise_path = os.path.join(data_dir, 'cifarn', 'CIFAR-10_human.pt' if dataset == 'cifar10' else 'CIFAR-100_human.pt')
    #
    #     NoiseLabelDataset(dataset=dataset,
    #                         noise_mode=noise_mode,
    #                         noise_type=noise_type,
    #                         noise_path=noise_path,
    #                         root_dir=data_dir,
    #                         transform=None,
    #                         noise_file=os.path.join(data_dir, data_name, 'noise_label.json'),
    #                         r=noise_rate)

    # dataset = ImageFileDataset(file_dir=data_dir,
    #                            data_type=data_type,
    #                            path_prefix=path_prefix,
    #                            transform=None)
    # dataset = ImageDataset(image_dir = os.path.join('/nfs196/wjx/datasets/Tiny-Imagenet', data_type),
    #                        transform=None)

    # print(len(dataset))