import numpy as np
import torch as t
import torchvision.transforms as tr
from torchvision.datasets import MNIST, FashionMNIST

from data.loader import loader

class dataclass(loader):
    """MNIST split where every digit is paired with a FashionMNIST "mark".

    Each sample is a single ``img_size x img_size`` canvas. A FashionMNIST
    image (the mark) is pasted into columns ``:28`` and the MNIST digit into
    columns ``28:``. Outside the ``'test'`` split, the mark's class equals the
    digit's label with probability ``args.bias``; otherwise a different class
    is drawn. In ``'test'`` the mark class is independent of the label.

    Attributes set here:
        data, label    -- MNIST images / labels (labels noised for train/val).
        gt_label       -- clean copy of ``label`` taken before noising.
        bdata, blabel  -- FashionMNIST pool the marks are sampled from.
        use_aug        -- ``aug['aug']`` for train, ``aug['noaug']`` otherwise.
        midx, mark, mark_idx, imgs -- produced by ``biasing``.
    """

    # Fraction of each official training split used for 'train'; the
    # remainder is 'val'.
    TRAIN_FRACTION = 0.9
    # MNIST has exactly 10 digit classes (used when drawing noisy labels).
    NUM_CLASSES = 10

    def __init__(self, args, aug, setting='train'):
        """Load the requested split and build the biased images.

        Args:
            args: experiment config; this class reads ``noise``, ``bias``,
                ``num_labels``, ``d_option`` and ``img_size`` from it.
            aug: mapping with ``'aug'`` and ``'noaug'`` entries.
            setting: one of ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            ValueError: if ``setting`` is not a known split.
        """
        super().__init__(args, aug)

        if setting not in ('train', 'val', 'test'):
            raise ValueError(
                f"setting must be 'train', 'val' or 'test', got {setting!r}")

        # 'train' and 'val' are disjoint slices of the official *training*
        # splits of both datasets. The val marks previously came from the
        # FashionMNIST test split, which overlapped the marks used by 'test'.
        # 'test' uses the official test splits unchanged.
        from_train = setting != 'test'
        mn = MNIST('./tmp', train=from_train, download=True)
        fmn = FashionMNIST('./tmp', train=from_train, download=True)

        self.data, self.label = self._select(mn, setting)
        # Keep a clean copy: noising() below mutates self.label in place.
        self.gt_label = self.label.clone()
        self.bdata, self.blabel = self._select(fmn, setting)

        self.use_aug = aug['aug'] if setting == 'train' else aug['noaug']

        if setting == 'test':
            # Test: mark class independent of the label, no label noise.
            self.biasing('uniform')
        else:
            self.biasing()
            self.noising()

    @classmethod
    def _select(cls, ds, setting):
        """Return the ``(data, targets)`` slice of ``ds`` for ``setting``."""
        data, targets = ds.data, ds.targets
        if setting == 'test':
            return data, targets
        cut = int(len(data) * cls.TRAIN_FRACTION)
        if setting == 'train':
            return data[:cut], targets[:cut]
        return data[cut:], targets[cut:]

    ##### Control Bias
    def noising(self, noise_ratio=None):
        """Corrupt ``self.label`` in place with symmetric label noise.

        With probability ``noise_ratio`` each label is replaced by a
        uniformly chosen *different* class.

        Args:
            noise_ratio: flip probability; defaults to ``args.noise``.
        """
        if noise_ratio is None:
            noise_ratio = self.args.noise
        n = len(self.label)
        if n == 0:
            return
        k = self.NUM_CLASSES
        # Flip w.p. noise_ratio. The original used '>' here, which flipped
        # w.p. 1 - noise_ratio (noise=0 corrupted every label).
        flip = t.rand(n) < noise_ratio
        # label + offset (offset in [1, k)) mod k is a uniformly random class
        # different from label, so no rejection sampling is needed.
        offset = t.randint(1, k, (n,), dtype=self.label.dtype)
        self.label[flip] = (self.label[flip] + offset[flip]) % k

    def biasing(self, setting=None):
        """Pick a mark for every sample and compose the final images.

        Args:
            setting: ``'uniform'`` aligns the mark class with the label with
                probability ``1 / num_labels`` (otherwise a different class
                is drawn uniformly). Any other value aligns it with
                probability ``args.bias``.

        Sets ``midx`` (True where the mark class differs from the label),
        ``mark`` (28x28 mark images), ``mark_idx`` (mark class, stored as
        float) and ``imgs`` (float canvases divided by 255).
        """
        n = len(self.label)
        num_labels = self.args.num_labels
        if setting == 'uniform':
            align_prob = 1. / num_labels
        else:
            align_prob = self.args.bias
        self.midx = t.rand(n) > align_prob

        self.mark = t.zeros((n, 28, 28))
        self.mark_idx = t.zeros((n,))
        for idx in range(n):
            if self.midx[idx]:
                # Rejection-sample a mark class different from the label.
                while True:
                    mark_label = t.randint(0, num_labels, (1,))
                    if mark_label != self.label[idx]:
                        break
            else:
                mark_label = self.label[idx]
            self.mark_idx[idx] = mark_label
            self.mark[idx] = self.mark_sampling(mark_label)

        # Independent per-sample row offsets in [0, d_option] for mark/digit.
        jitter = int(self.args.d_option + 1)
        dev_mark = t.randint(0, jitter, (n,))
        dev_data = t.randint(0, jitter, (n,))

        # NOTE(review): the digit fills columns 28: and rows up to 56, so this
        # layout assumes img_size == 56 — TODO confirm.
        self.imgs = t.zeros((n, self.args.img_size, self.args.img_size))
        for idx in range(n):
            top_mark = int(dev_mark[idx])
            top_data = 28 - int(dev_data[idx])
            self.imgs[idx, top_mark:top_mark + 28, :28] = self.mark[idx]
            self.imgs[idx, top_data:top_data + 28, 28:] = self.data[idx]
        self.imgs = self.imgs.float() / 255.

    def mark_sampling(self, label):
        """Draw one random image from ``bdata`` whose class equals ``label``.

        Args:
            label: class id as a tensor (0-d or single element).

        Returns:
            ``bdata`` indexed with a length-1 index array, i.e. a tensor with
            a leading dimension of size 1.

        Raises:
            ValueError: if ``blabel`` contains no sample of that class.
        """
        pos = np.where(self.blabel.numpy() == label.numpy())[0]
        if len(pos) == 0:
            raise ValueError(f'no mark images with label {int(label)}')
        idx = pos[t.randint(0, len(pos), (1,)).numpy()]
        return self.bdata[idx]
        