import numpy as np
import torch as t
import torchvision.transforms as tr
from torchvision.datasets import MNIST

from data.loader import loader

class dataclass(loader):
    """Colored-MNIST dataset: digit images tinted with a label-correlated color.

    Each grayscale digit is multiplied by an RGB color row of ``self.color``
    (plus per-sample Gaussian jitter). A sample is *bias-aligned* when its
    color index equals its digit label, and *bias-conflicting* otherwise
    (then the color index is a uniformly chosen different label). Train/val
    labels are additionally passed through ``noising``. The labels as loaded
    (before noising) are kept in ``self.gt_label``.

    Args:
        args: config object. This class reads ``bias``, ``noise``,
            ``num_labels``, ``d_option``, ``img_dim`` and ``img_size`` from
            ``self.args``, which is presumably set by ``loader.__init__``
            (TODO confirm).
        aug: mapping with at least the keys ``'aug'`` and ``'noaug'``.
        setting: ``'train'`` (first 90% of the MNIST training split),
            ``'val'`` (last 10% of it) or ``'test'`` (MNIST test split).

    Raises:
        ValueError: if ``setting`` is not one of the values above, or if the
            colored images do not match ``args.img_dim`` / ``args.img_size``.
    """

    def __init__(self, args, aug, setting='train'):
        super().__init__(args, aug)
        # RGB palette: row i is the "aligned" color of digit i (values in [0, 1]).
        self.color = t.tensor([[0.8627451,0.07843137,0.23529412],       #0
                                [0,0.50196078,0.50196078],              #1
                                [0.99215686,0.91372549,0.0627451],      #2
                                [0,0.58431373,0.71372549],              #3
                                [0.929411765,0.568627451,0.129411765],  #4
                                [0.568627451,0.117647059,0.737254902],  #5
                                [0.274509804,0.941176471,0.941176471],  #6
                                [0.980392157,0.77254902,0.733333333],   #7
                                [0.823529412,0.960784314,0.235294118],  #8
                                [0.501960784,0,0]])                     #9

        # Fail fast instead of leaving a half-initialized dataset behind.
        if setting not in ('train', 'val', 'test'):
            raise ValueError(
                f"setting must be 'train', 'val' or 'test', got {setting!r}")

        # 'train' and 'val' are both carved out of the official MNIST training split.
        mn = MNIST('./tmp', train=(setting != 'test'), download=True)
        cut = int(len(mn.data) * 0.9)
        if setting == 'train':
            self.data, self.label = mn.data[:cut], mn.targets[:cut]
        elif setting == 'val':
            self.data, self.label = mn.data[cut:], mn.targets[cut:]
        else:
            self.data, self.label = mn.data, mn.targets
        # Snapshot the labels before noising() overwrites self.label in place.
        self.gt_label = self.label.clone()
        self.use_aug = aug['aug'] if setting == 'train' else aug['noaug']

        if setting == 'test':
            # Test split: color independent of the label, labels left clean.
            self.biasing('uniform')
        else:
            self.biasing()
            self.noising()

    ##### Label noise and color bias
    def noising(self, noise_ratio=None):
        """Corrupt ``self.label`` in place by relabelling samples to a wrong class.

        A sample is relabelled when ``t.rand(1) > noise_ratio``, i.e. with
        probability ``1 - noise_ratio``. The replacement is drawn uniformly
        from the 9 other digit classes.

        NOTE(review): despite its name, ``noise_ratio`` therefore behaves as
        the fraction of labels kept clean (``0`` corrupts essentially every
        label). This mirrors ``biasing``'s ``rand > args.bias`` convention, so
        it is left unchanged — confirm against how ``args.noise`` is set.

        Args:
            noise_ratio: threshold described above; defaults to ``self.args.noise``.
        """
        if noise_ratio is None:
            noise_ratio = self.args.noise
        for idx in range(len(self.label)):
            if t.rand(1) > noise_ratio:
                # Rejection-sample until the new label differs from the current one.
                while True:
                    label_tmp = t.randint(0, 10, (1,))   # MNIST has 10 classes
                    if label_tmp != self.label[idx]:
                        self.label[idx] = label_tmp
                        break

    def biasing(self, setting=None):
        """Pick a color for every sample and build the colored images.

        Sets ``self.midx`` (True = bias-conflicting sample), ``self.color_idx``,
        ``self.dev`` (per-sample RGB jitter) and ``self.imgs``.

        Args:
            setting: ``'uniform'`` makes a sample bias-aligned with probability
                ``1 / args.num_labels``, so the color index is uniformly
                distributed regardless of the label. Any other value uses
                probability ``args.bias``.

        Raises:
            ValueError: if the colored images do not have shape
                ``(N, args.img_dim, args.img_size, args.img_size)``.
        """
        n = len(self.label)

        # Bias splitting: midx[i] is True when sample i gets a *mismatched* color.
        if setting == 'uniform':
            self.midx = t.rand(n) > 1. / self.args.num_labels
        else:
            self.midx = t.rand(n) > self.args.bias

        # Coloring: aligned samples use their own label as color index.
        # Conflicting samples get a uniformly random different index.
        # Rejection sampling stays per-sample so seeded runs reproduce the
        # same random draws.
        self.color_idx = t.zeros(n, dtype=t.int)
        for idx in range(n):
            if not self.midx[idx]:
                self.color_idx[idx] = self.label[idx]
                continue
            while True:
                rand_label = t.randint(0, self.args.num_labels, (1,))
                if rand_label != self.label[idx]:
                    break
            self.color_idx[idx] = rand_label

        # Deviation: per-sample Gaussian jitter of the RGB color (std = args.d_option).
        self.dev = t.normal(0, self.args.d_option, (n, 3))

        # Image preprocessing, vectorized over the whole split. This gives the
        # same values as a per-sample (digit * color) / 255 then clamp.
        # In-place ops keep peak memory near one copy of the output.
        colors = self.color[self.color_idx.long()] + self.dev             # (n, 3)
        imgs = self.data.float().unsqueeze(1) * colors[:, :, None, None]  # (n, 3, H, W)
        imgs.div_(255.).clamp_(0., 1.)

        expected = (n, self.args.img_dim, self.args.img_size, self.args.img_size)
        if tuple(imgs.shape) != expected:
            raise ValueError(
                f"colored images have shape {tuple(imgs.shape)}, "
                f"expected {expected} from args.img_dim / args.img_size")
        self.imgs = imgs