import numpy as np
import torch
from bisect import bisect_left
import sys
import random, json
import os
from PIL import Image
import copy
from datasets.tools import sample_id
from datasets.tools import load_image

class CIFAR5M(torch.utils.data.Dataset):
    """Subset of the sampled CIFAR-5M arrays with fully corrupted labels.

    The first ``data_num`` samples are read from the ``.npy`` files under
    ``./data/cifar5m``. Every sample receives a *noisy* label drawn
    uniformly from the classes other than its true one, so the noisy label
    never equals the clean label. ``__getitem__`` serves the noisy label;
    the clean one stays available in ``self.label``.

    Attributes:
        transform: Optional callable applied to the PIL image in
            ``__getitem__``.
        data_num: Number of samples actually held (``len(self)``).
        data: Image array served by ``__getitem__``; overwritten in place
            by ``mix_out``.
        origin_data: Pristine copy of ``data`` used as the mixing source.
        label: Clean labels as stored in the label file.
        noisy_label: int64 array of corrupted labels.
        id_sampled: Result of ``sample_id(data_num, True)``; iterated in
            ``mix_out`` to pick the image blended into each sample.
    """

    # Number of classes used when corrupting labels.
    NUM_CLASSES = 10

    def __init__(self, transform=None, data_num=50000):
        """Load the data and draw the noisy labels.

        Args:
            transform: Optional callable applied to each PIL image.
            data_num: Number of leading samples to keep. If the file holds
                fewer samples, ``self.data_num`` is reduced to the number
                actually loaded so that ``len(self)`` matches the data.
        """
        # Memory-map the full arrays so only the requested prefix is read.
        data_all = np.load("./data/cifar5m/cifar5m_sampled_x_500k.npy", mmap_mode='r')
        labels_all = np.load("./data/cifar5m/cifar5m_sampled_y_500k.npy", mmap_mode='r')

        self.transform = transform

        # np.array(...) copies the prefix into ordinary (writable) memory.
        self.data = np.array(data_all[:data_num])
        self.label = np.array(labels_all[:data_num])

        # Slicing past the end silently truncates; keep the bookkeeping
        # consistent with what was really loaded.
        self.data_num = len(self.data)

        # mix_out() always blends from this untouched copy, so repeated
        # calls do not compound.
        self.origin_data = self.data.copy()

        # Called before label sampling, matching the original call order.
        self.id_sampled = sample_id(self.data_num, True)

        self.noisy_label = self._sample_noisy_labels(self.label, self.NUM_CLASSES)

    @staticmethod
    def _sample_noisy_labels(labels, num_classes=10):
        """Return one wrong label per entry of ``labels``.

        Each noisy label is drawn uniformly from the ``num_classes - 1``
        classes that differ from the true label, using the global
        ``np.random`` state (one ``np.random.choice`` call per sample).

        Args:
            labels: 1-D integer array of true class indices in
                ``[0, num_classes)``.
            num_classes: Total number of classes.

        Returns:
            int64 array with the same length as ``labels``.
        """
        labels = np.asarray(labels)
        # Row i is zero at the true class and 1/(num_classes-1) elsewhere.
        noisy_prob = (1 - np.eye(num_classes)[labels]) / (num_classes - 1.0)

        # Explicit int64 rather than inheriting the label file's dtype,
        # which may be unsigned and unable to hold arbitrary ints.
        noisy = np.empty(labels.shape[0], dtype=np.int64)
        for i, prob in enumerate(noisy_prob):
            # Scalar draw: avoids storing a size-1 array into one element.
            noisy[i] = np.random.choice(num_classes, p=prob)
        return noisy

    def __getitem__(self, index):
        """Return ``(image, noisy_label, index)`` for sample ``index``.

        The image is converted to a PIL image and passed through
        ``self.transform`` when one is set.
        """
        img = Image.fromarray(self.data[index])
        label = self.noisy_label[index]
        if self.transform is not None:
            img = self.transform(img)

        return img, label, index

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.data_num

    def mix_out(self, lam=0.5):
        """Blend every sample with an external image, in place.

        For each sample ``i`` paired with ``id_out`` from
        ``self.id_sampled``::

            data[i] = lam * origin_data[i] + (1 - lam) * load_image(id_out)

        Args:
            lam: Weight given to the original image; ``1 - lam`` goes to
                the loaded image.
        """
        # NOTE(review): assumes load_image returns an array broadcastable
        # to a single sample - confirm against datasets.tools.
        for idx, id_out in zip(range(self.data_num), self.id_sampled):
            # The float mixture is cast to self.data's dtype on assignment.
            self.data[idx] = lam * self.origin_data[idx] + (1 - lam) * load_image(id_out)
