import torch as t
from PIL import Image
import torchvision.transforms as tr
from torch.utils.data import DataLoader


def get_data(args):
    """Build train/val/test ``DataLoader``s for the dataset named by ``args.data``.

    Args:
        args: run configuration. Reads ``args.log`` (logger, ``.info`` is
            used), ``args.data`` (``'CM'`` or ``'WM'``), ``args.img_size`` and
            ``args.batch``; ``args.num_workers`` is optional (defaults to 8).
            The whole object is also forwarded to the dataset class.

    Returns:
        dict mapping ``'train'``, ``'val'`` and ``'test'`` to ``DataLoader``s.
        Only the ``'train'`` loader shuffles; none drops the last partial batch.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset name.
    """
    args.log.info('Data setting ...')

    # Imported inside the branches so only the selected dataset module is loaded.
    if args.data == 'CM':
        from data.list.cm import dataclass
    elif args.data == 'WM':
        from data.list.wm import dataclass
    else:
        # Previously print() + exit(): exit() terminates with status 0, which
        # made a misconfigured run look successful to callers/schedulers.
        raise ValueError(f'Invalid dataset: {args.data!r}')

    aug_dict = {
        # Randomised pipeline.
        'aug': tr.Compose([
            tr.ToPILImage(),
            tr.RandomResizedCrop(args.img_size, scale=(0.9, 1.1)),
            tr.ColorJitter(hue=.05, saturation=.05),
            # Nearest-neighbour is RandomRotation's default interpolation; the
            # old ``resample=`` keyword was removed in newer torchvision
            # releases and raises TypeError there.
            tr.RandomRotation((-10, 10)),
            tr.ToTensor(),
        ]),
        # Deterministic pipeline: resize only.
        # NOTE(review): with an int img_size, Resize scales only the shorter
        # side (aspect ratio kept), whereas RandomResizedCrop always yields
        # img_size x img_size — confirm the source images are square.
        'noaug': tr.Compose([
            tr.ToPILImage(),
            tr.Resize(args.img_size),
            tr.ToTensor(),
        ]),
    }

    num_workers = getattr(args, 'num_workers', 8)

    loader_dict = {}
    # Datasets are constructed in the same order as before: train, val, test.
    for split in ('train', 'val', 'test'):
        loader_dict[split] = DataLoader(
            dataclass(args, aug_dict, setting=split),
            batch_size=args.batch,
            shuffle=(split == 'train'),
            num_workers=num_workers,
            drop_last=False,
        )

    return loader_dict


class loader(DataLoader):
    """Base dataset of images with observed labels, ground-truth labels and a bias flag.

    Attributes read here but never set here — presumably populated by
    subclasses (TODO confirm):
        imgs:      per-sample images; one entry is passed to ``use_aug``.
        label:     labels returned by ``__getitem__``; compared against
                   ``gt_label`` to tell clean from noisy samples.
        midx:      per-sample flag; ``True`` is counted as "minor" and
                   ``False`` as "major" in the refinement logs.
        gt_label:  ground-truth labels.
        use_aug:   callable applied to one image.

    NOTE(review): instances are used like a map-style dataset
    (``__getitem__`` / ``__len__``) and ``DataLoader.__init__`` is never
    called, so ``torch.utils.data.Dataset`` looks like the intended base.
    The ``DataLoader`` base is kept to avoid breaking subclasses or
    ``isinstance`` checks elsewhere; be aware that ``iter(instance)`` therefore
    resolves to ``DataLoader.__iter__`` rather than index-based iteration.
    """

    def __init__(self, args, aug):
        """Store run arguments and augmentation config; sampling starts disabled."""
        self.args = args
        self.aug = aug
        self.sampling = False

    def __getitem__(self, idx):
        """Return ``(img, label, bias, gt_label, idx)`` for one sample.

        Tuple layout: 0: img, 1: label, 2: bias (``midx``), 3: ground truth
        label, 4: index. When sampling is enabled the requested ``idx`` is
        ignored, a new index is drawn by :meth:`new_idx`, and that drawn index
        is what gets returned in slot 4.
        """
        if self.sampling:
            idx = self.new_idx()
        # Augment exactly once; the image used to be passed through use_aug a
        # second time on return, stacking two random augmentations.
        img = self.use_aug(self.imgs[idx])
        return img, self.label[idx], self.midx[idx], self.gt_label[idx], idx

    def __len__(self):
        """Number of samples (length of ``label``)."""
        return len(self.label)

    def new_idx(self):
        """Draw one index by inverse-CDF sampling over ``self.prob``.

        Returns a 0-d integer tensor clamped to ``[0, len(label) - 1]``; the
        clamp covers a cumulative sum that ends below the uniform draw (e.g.
        from rounding or weights summing to less than 1).
        """
        u = t.rand(1)
        # Match the dtype promotion the element-wise comparison ``u > prob``
        # would apply (``.to`` is a no-op when the dtype already matches).
        dtype = t.promote_types(u.dtype, self.prob.dtype)
        # ``prob`` is a cumulative sum, hence non-decreasing for non-negative
        # weights; the left-side binary search returns the number of entries
        # strictly below u, i.e. exactly ``sum(u > prob)``, in O(log n).
        pos = t.searchsorted(self.prob.to(dtype), u.to(dtype))[0]
        return t.clamp(pos, 0, len(self.label) - 1)

    def prob_update(self, prob):
        """Set the per-sample sampling weights used by :meth:`new_idx`.

        Only the cumulative sum along dim 0 is stored.
        NOTE(review): weights are not normalised here while draws are uniform
        in [0, 1), so ``prob`` presumably must sum to 1 — confirm at callers.
        """
        self.prob = t.cumsum(prob, dim=0)

    def sample_on(self):
        """Enable weighted resampling in ``__getitem__``."""
        self.sampling = True

    def sample_off(self):
        """Disable weighted resampling; ``__getitem__`` uses the given index."""
        self.sampling = False

    def _log_split_stats(self, tag):
        """Debug-log the clean/noisy x major/minor sample counts, prefixed by ``[tag]``."""
        clean = self.gt_label == self.label
        noisy = self.gt_label != self.label
        # Explicit comparisons (not ``~``) so non-bool flag tensors keep the
        # original ``== True`` / ``== False`` semantics.
        minor = self.midx == True  # noqa: E712 - element-wise tensor comparison
        major = self.midx == False  # noqa: E712 - element-wise tensor comparison
        self.args.log.debug(
            "[%s] Clean Major : %d, Clean Minor : %d, Noise Major : %d, Noise Minor : %d"
            % (tag,
               int((clean & major).sum()), int((clean & minor).sum()),
               int((noisy & major).sum()), int((noisy & minor).sum())))

    def refine_dataset(self, idxs):
        """Keep only the samples selected by ``idxs``, logging counts before and after.

        ``idxs`` indexes ``imgs``, ``label``, ``midx`` and ``gt_label`` in
        lockstep (anything those containers accept as an index).
        """
        self.args.log.debug('-' * 100)
        self._log_split_stats('Before')

        self.imgs = self.imgs[idxs]
        self.label = self.label[idxs]
        self.midx = self.midx[idxs]
        self.gt_label = self.gt_label[idxs]

        # Leading space keeps "[ After]" aligned with "[Before]" in the log.
        self._log_split_stats(' After')
        self.args.log.debug('-' * 100)
        