from curses import meta
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import random
import numpy as np
from PIL import Image
import json
import os
import torch
from torchnet.meter import AUCMeter
from utils import RandAugment, CIFAR10Policy, auc_acc_metric
import copy
import pickle

# Per-channel (mean, std) used to normalise images after ``ToTensor``.
_CIFAR10_STATS = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
_CIFAR100_STATS = ((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))


def _compose_with_normalize(stats, *augmentations):
    """Build ``Compose([*augmentations, ToTensor(), Normalize(mean, std)])``.

    ``augmentations`` are the image-level ops applied before tensor
    conversion; a fresh ``ToTensor``/``Normalize`` pair is created per call.
    """
    mean, std = stats
    pipeline = list(augmentations)
    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize(mean, std))
    return transforms.Compose(pipeline)


def _crop_flip():
    """Return fresh ``RandomCrop(32, padding=4)`` and ``RandomHorizontalFlip()`` ops."""
    return [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]


# No augmentation: tensor conversion and normalisation only.
transform_none_10_compose = _compose_with_normalize(_CIFAR10_STATS)
transform_none_100_compose = _compose_with_normalize(_CIFAR100_STATS)

# Weak augmentation: padded random crop plus horizontal flip.
transform_weak_10_compose = _compose_with_normalize(_CIFAR10_STATS, *_crop_flip())
transform_weak_100_compose = _compose_with_normalize(_CIFAR100_STATS, *_crop_flip())

# Strong augmentation: weak ops followed by ``CIFAR10Policy`` from ``utils``.
# NOTE(review): the CIFAR-100 pipeline also uses the CIFAR-10 policy — confirm intended.
transform_strong_10_compose = _compose_with_normalize(
    _CIFAR10_STATS, *_crop_flip(), CIFAR10Policy()
)
transform_strong_100_compose = _compose_with_normalize(
    _CIFAR100_STATS, *_crop_flip(), CIFAR10Policy()
)

# NOTE(review): ``RandAugment(1, 6)`` is disabled (commented out in the
# original pipeline), so these two "strong_randaugment" pipelines currently
# perform exactly the same ops as the weak ones above.
transform_strong_randaugment_10_compose = _compose_with_normalize(_CIFAR10_STATS, *_crop_flip())
transform_strong_randaugment_100_compose = _compose_with_normalize(_CIFAR100_STATS, *_crop_flip())

def get_data_aug_args(dataset, noise_mode, noise_ratio):
    """Return the augmentation strategy as transform *method names*.

    The names refer to methods of ``cifar_dataloader``, which resolves them
    with ``getattr``.

    Args:
        dataset: ``'cifar10'`` or ``'cifar100'``.
        noise_mode: label-noise type. For CIFAR-100 only ``'sym'`` is
            supported; for CIFAR-10 ``'asym'`` is special-cased and any other
            value is handled like symmetric noise.
        noise_ratio: fraction of noisy labels.

    Returns:
        A new dict with
        ``"labeled_transforms"`` / ``"unlabeled_transforms"``: four names each
        (two strong views followed by two weak views), and
        ``"warmup_transform"``: weak for CIFAR-10 asymmetric noise or
        ``noise_ratio > 0.5``, strong otherwise.

    Raises:
        ValueError: unsupported ``dataset`` / ``noise_mode`` combination
            (these previously returned ``None`` and crashed later).
    """
    if dataset == 'cifar10':
        suffix = '10'
    elif dataset == 'cifar100' and noise_mode == 'sym':
        suffix = '100'
    else:
        raise ValueError(
            "unsupported dataset/noise_mode combination: %r/%r" % (dataset, noise_mode)
        )

    strong = 'transform_strong_' + suffix
    weak = 'transform_weak_' + suffix
    # The asym check comes first so noise_ratio is not compared in that case,
    # matching the original control flow.
    warmup_is_weak = (dataset == 'cifar10' and noise_mode == 'asym') or noise_ratio > 0.5

    return {
        "labeled_transforms": [strong, strong, weak, weak],
        "unlabeled_transforms": [strong, strong, weak, weak],
        "warmup_transform": weak if warmup_is_weak else strong,
    }

def unpickle(file):
    """Load and return the object pickled in ``file``.

    ``encoding='latin1'`` lets Python 3 decode 8-bit strings from pickles
    written by Python 2 (presumably how the CIFAR "python version" batch
    files were produced).

    SECURITY: unpickling can execute arbitrary code; only call this on
    trusted dataset files.
    """
    # ``pickle.load`` already uses the C accelerator (``_pickle``) when available.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')

class cifar_dataset(Dataset):
    """CIFAR-10/100 dataset with (possibly) injected label noise.

    ``mode`` selects what is loaded and what ``__getitem__`` returns:

    * ``'test'``: the test split; items are ``(img, target)``.
    * ``'all'`` / ``'eval_train'``: the full training split with noisy
      labels; items are ``(img, target, index)``.
    * ``'labeled'`` / ``'unlabeled'``: the subset of the training split chosen
      by ``pred`` (or by ``meta_pred``, see ``use_meta_label``); items are four
      transformed views followed by ``(target, prob, eval_loss)``, so
      ``transform`` must be indexable with four callables.

    Noisy labels are read from ``noise_file`` (a JSON list) when it exists;
    otherwise they are generated with ``random`` and written to it, so every
    instance sharing that file sees the same labels.
    """

    def __init__(self, dataset, r, noise_mode, root_dir, transform, mode, noise_file='', pred=None, probability=None, log='', strong_aug=None, noise_trans=None, tf_writer=None, epoch=0, model_name='', eval_train_loss=None, meta_pred=None, meta_probability=None, use_meta_label=-1, test_transform=None):
        """Load the requested split and, for training modes, its noisy labels.

        Args:
            dataset: ``'cifar10'`` or ``'cifar100'``.
            r: noise ratio; fraction of training labels replaced when noise is
                generated.
            noise_mode: ``'sym'`` (uniformly random class, which may equal the
                true one) or ``'asym'`` (``self.transition`` mapping, which only
                covers CIFAR-10 class ids).
            root_dir: directory containing the pickled CIFAR batch files.
            transform: a callable (test/all/eval_train) or a sequence of four
                callables (labeled/unlabeled).
            mode: see the class docstring.
            noise_file: JSON path noisy labels are read from or saved to.
            pred: per-sample array over the training split; nonzero entries
                are selected in ``'labeled'`` mode, entries where ``1 - pred``
                is nonzero in ``'unlabeled'`` mode.
            probability: per-sample values stored as ``self.probability``.
            log, tf_writer, model_name: accepted but not used here.
            strong_aug: stored as ``self.strong_aug``.
            noise_trans: optional array OR-ed with ``pred`` for the labeled
                selection and AND-ed (as ``1 - noise_trans``) for the
                unlabeled selection.
            epoch: current epoch, compared against ``use_meta_label``.
            eval_train_loss: per-sample values stored as
                ``self.eval_train_loss``.
            meta_pred, meta_probability: alternative selection/values used
                when ``0 <= use_meta_label < epoch``.
            use_meta_label: a negative value disables the meta branch.
            test_transform: stored as ``self.transform_test``; unused here.

        Raises:
            ValueError: unsupported ``dataset``, ``mode``, or ``noise_mode``
                (the latter only when noise has to be generated), or the meta
                branch is selected without ``meta_pred``.
        """
        # None sentinels instead of shared mutable ``[]`` defaults.
        pred = [] if pred is None else pred
        probability = [] if probability is None else probability

        self.transform_test = test_transform
        self.meta = False  # not read within this class
        self.meta_num = 1000  # not read within this class
        self.r = r  # noise ratio
        self.strong_aug = strong_aug
        self.transform = transform
        self.mode = mode
        # Class transition for asymmetric noise (CIFAR-10 class ids only).
        self.transition = {0: 0, 2: 0, 4: 7, 7: 7, 1: 1, 9: 1, 3: 5, 5: 3, 6: 6, 8: 8}

        if self.mode == 'test':
            self.test_data, self.test_label = self._load_test_split(dataset, root_dir)
            return

        train_data, train_label = self._load_train_split(dataset, root_dir)
        noise_label = self._load_or_inject_noise(dataset, noise_mode, noise_file, train_label)

        # 1.0 where the (possibly noisy) label equals the original label.
        clean = np.array(noise_label) == np.array(train_label)
        self.c_or_n = np.where(clean, 1.0, 0.0)

        if self.mode in ('all', 'eval_train'):
            self.train_data = train_data
            self.noise_label = noise_label
            return

        pred_idx = self._select_subset(
            pred, probability, noise_trans, eval_train_loss,
            meta_pred, meta_probability, use_meta_label, epoch,
        )
        self.train_data = train_data[pred_idx]
        self.noise_label = [noise_label[i] for i in pred_idx]
        print("%s data has a size of %d" % (self.mode, len(self.noise_label)))

    @staticmethod
    def _load_test_split(dataset, root_dir):
        """Return ``(images, labels)`` of the test split; images are (10000, 32, 32, 3)."""
        if dataset == 'cifar10':
            test_dic = unpickle('%s/test_batch' % root_dir)
            labels = test_dic['labels']
        elif dataset == 'cifar100':
            test_dic = unpickle('%s/test' % root_dir)
            labels = test_dic['fine_labels']
        else:
            raise ValueError("unsupported dataset: %r" % (dataset,))
        # Flat rows -> (N, C, H, W) -> (N, H, W, C), the layout Image.fromarray expects.
        data = test_dic['data'].reshape((10000, 3, 32, 32)).transpose((0, 2, 3, 1))
        return data, labels

    @staticmethod
    def _load_train_split(dataset, root_dir):
        """Return ``(images, labels)`` of the training split; images are (50000, 32, 32, 3)."""
        if dataset == 'cifar10':
            chunks, labels = [], []
            for n in range(1, 6):
                batch = unpickle('%s/data_batch_%d' % (root_dir, n))
                chunks.append(batch['data'])
                labels += batch['labels']
            data = np.concatenate(chunks)
        elif dataset == 'cifar100':
            train_dic = unpickle('%s/train' % root_dir)
            data = train_dic['data']
            labels = train_dic['fine_labels']
        else:
            raise ValueError("unsupported dataset: %r" % (dataset,))
        data = data.reshape((50000, 3, 32, 32)).transpose((0, 2, 3, 1))
        return data, labels

    def _load_or_inject_noise(self, dataset, noise_mode, noise_file, train_label):
        """Return noisy labels from ``noise_file``, generating and saving them if absent."""
        if os.path.exists(noise_file):
            with open(noise_file, "r") as f:
                return json.load(f)

        train_num = len(train_label)
        idx = list(range(train_num))
        random.shuffle(idx)
        # A set gives O(1) membership tests; the original list made this loop O(n^2).
        noise_idx = set(idx[:int(self.r * train_num)])
        max_class = 9 if dataset == 'cifar10' else 99

        noise_label = []
        for i in range(train_num):
            if i not in noise_idx:
                noise_label.append(train_label[i])
            elif noise_mode == 'sym':
                noise_label.append(random.randint(0, max_class))
            elif noise_mode == 'asym':
                noise_label.append(self.transition[train_label[i]])
            else:
                # Previously this silently skipped the sample, misaligning labels.
                raise ValueError("unsupported noise_mode: %r" % (noise_mode,))

        print("save noisy labels to %s ..." % noise_file)
        with open(noise_file, "w") as f:
            json.dump(noise_label, f)
        return noise_label

    def _select_subset(self, pred, probability, noise_trans, eval_train_loss, meta_pred, meta_probability, use_meta_label, epoch):
        """Return the training indices for labeled/unlabeled mode.

        Also sets ``self.probability`` and ``self.eval_train_loss`` for the
        selected indices.
        """
        labeled = self.mode == 'labeled'
        if labeled:
            if noise_trans is None:
                pred_idx = pred.nonzero()[0]
            else:
                pred_idx = np.logical_or(pred > 0, noise_trans > 0).nonzero()[0]
        elif self.mode == 'unlabeled':
            if noise_trans is None:
                pred_idx = (1 - pred).nonzero()[0]
            else:
                pred_idx = np.logical_and((1 - pred) > 0, (1 - noise_trans) > 0).nonzero()[0]
        else:
            raise ValueError("unsupported mode: %r" % (self.mode,))

        # Same condition for both modes. Previously the labeled branch used
        # ``use_meta_label > 0``, so use_meta_label == 0 with epoch > 0 set
        # neither attribute and __getitem__ failed.
        if 0 <= use_meta_label < epoch:
            if meta_pred is None:
                raise ValueError("meta_pred is required when 0 <= use_meta_label < epoch")
            meta_mask = meta_pred if labeled else 1 - meta_pred
            pred_idx = meta_mask.nonzero()[0]
            self.probability = [meta_probability[i] for i in pred_idx]
            # NOTE: stores ``probability`` (not eval_train_loss) on purpose,
            # matching the original; the eval_train_loss variant was commented out there.
            self.eval_train_loss = [probability[i] for i in pred_idx]
        else:
            self.probability = [probability[i] for i in pred_idx]
            self.eval_train_loss = [eval_train_loss[i] for i in pred_idx]
        return pred_idx

    def __getitem__(self, index):
        """Return the sample at ``index``; the tuple layout depends on ``self.mode``."""
        if self.mode in ('labeled', 'unlabeled'):
            image = Image.fromarray(self.train_data[index])
            views = [self.transform[k](image) for k in range(4)]
            return (*views, self.noise_label[index], self.probability[index], self.eval_train_loss[index])
        if self.mode in ('eval_train', 'all'):
            img = self.transform(Image.fromarray(self.train_data[index]))
            return img, self.noise_label[index], index
        if self.mode == 'test':
            img = self.transform(Image.fromarray(self.test_data[index]))
            return img, self.test_label[index]
        if self.mode == 'val':
            # NOTE(review): val_data/val_label are never set by __init__.
            img = self.transform(Image.fromarray(self.val_data[index]))
            return img, self.val_label[index]

    def __len__(self):
        """Number of samples in the active split."""
        if self.mode == 'val':
            return len(self.val_data)
        if self.mode != 'test':
            return len(self.train_data)
        return len(self.test_data)
        
        
class cifar_dataloader:
    """Build the CIFAR ``DataLoader``s used by each training phase.

    Transform pipelines are exposed as methods (``transform_weak_10``,
    ``transform_strong_100``, ...) so ``get_data_aug_args`` can name them by
    string; ``__init__`` resolves those names into ``self.transforms``.
    """

    def prob_transform_100(self, x):
        """Strong CIFAR-100 augmentation with probability ``self.warmup_aug_prob``, else weak.

        NOTE(review): ``warmup_aug_prob`` is never assigned in this class; it
        must be set on the instance before this method is used.
        """
        if random.random() < self.warmup_aug_prob:
            return transform_strong_100_compose(x)
        return transform_weak_100_compose(x)

    def prob_transform_10(self, x):
        """Strong CIFAR-10 augmentation with probability ``self.warmup_aug_prob``, else weak.

        NOTE(review): ``warmup_aug_prob`` is never assigned in this class.
        """
        if random.random() < self.warmup_aug_prob:
            return transform_strong_10_compose(x)
        return transform_weak_10_compose(x)

    def transform_strong_100(self, x):
        """Apply the module-level strong CIFAR-100 pipeline."""
        return transform_strong_100_compose(x)

    def transform_strong_10(self, x):
        """Apply the module-level strong CIFAR-10 pipeline."""
        return transform_strong_10_compose(x)

    def transform_weak_100(self, x):
        """Apply the module-level weak CIFAR-100 pipeline."""
        return transform_weak_100_compose(x)

    def transform_weak_10(self, x):
        """Apply the module-level weak CIFAR-10 pipeline."""
        return transform_weak_10_compose(x)

    def transform_strong_randaugment_10(self, x):
        """Apply the module-level ``strong_randaugment`` CIFAR-10 pipeline."""
        return transform_strong_randaugment_10_compose(x)

    def transform_strong_randaugment_100(self, x):
        """Apply the module-level ``strong_randaugment`` CIFAR-100 pipeline."""
        return transform_strong_randaugment_100_compose(x)

    def transform_none_10(self, x):
        """Apply the CIFAR-10 tensor-conversion + normalisation pipeline."""
        return transform_none_10_compose(x)

    def transform_none_100(self, x):
        """Apply the CIFAR-100 tensor-conversion + normalisation pipeline."""
        return transform_none_100_compose(x)

    def __init__(self, dataset, r, noise_mode, batch_size, num_workers, root_dir, log, noise_file=''):
        """Store loader settings and resolve the augmentation strategy.

        Args:
            dataset: ``'cifar10'`` or ``'cifar100'``.
            r: noise ratio forwarded to ``cifar_dataset``.
            noise_mode: noise type forwarded to ``cifar_dataset``.
            batch_size: base batch size (warmup uses twice this).
            num_workers: ``DataLoader`` worker count.
            root_dir: directory containing the CIFAR batch files.
            log: stored and forwarded to the labeled dataset.
            noise_file: JSON path for the noisy labels.

        Raises:
            ValueError: unsupported ``dataset``.
        """
        self.dataset = dataset
        self.r = r
        self.noise_mode = noise_mode
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.root_dir = root_dir
        self.log = log
        self.noise_file = noise_file
        self.strong_aug = True

        if self.dataset == 'cifar10':
            mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        elif self.dataset == 'cifar100':
            mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
        else:
            raise ValueError("unsupported dataset: %r" % (self.dataset,))

        self.transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        self.augmentation_strategy = get_data_aug_args(self.dataset, self.noise_mode, self.r)
        strategy = self.augmentation_strategy
        # Names are resolved to bound methods of this instance rather than
        # lambdas. NOTE(review): presumably so the transforms can be pickled
        # into DataLoader worker processes on Windows (spawn) — confirm.
        self.transforms = {
            "warmup": getattr(self, strategy["warmup_transform"]),
            "unlabeled": [getattr(self, name) for name in strategy["unlabeled_transforms"]],
            "labeled": [getattr(self, name) for name in strategy["labeled_transforms"]],
            "test": None,
        }

    def _make_loader(self, dataset, batch_size, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` using this factory's worker count."""
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def run(self, mode, pred=None, prob=None, noise_trans_idx=None, eval_train_loss=None, tf_writer=None, epoch=0, model_name='', meta_pred=None, meta_prob=None, use_meta_label=-1):
        """Return the loader(s) for ``mode``.

        Modes:
            ``'warmup'``: shuffled loader over the full noisy training set
                (dataset mode ``'all'``) with twice the batch size.
            ``'train'``: ``(labeled_loader, unlabeled_loader)`` built from
                ``pred``/``prob`` and the meta-label arguments.
            ``'test'``: unshuffled loader over the test split.
            ``'eval_train'`` / ``'eval_train_log'``: unshuffled loader over
                the full training set with the test transform; they use
                dataset modes ``'all'`` and ``'eval_train'`` respectively,
                which yield the same ``(img, target, index)`` items.

        ``noise_trans_idx`` is accepted but not used.

        Raises:
            ValueError: unknown ``mode`` (previously returned ``None``).
        """
        pred = [] if pred is None else pred
        prob = [] if prob is None else prob
        common = dict(
            dataset=self.dataset,
            noise_mode=self.noise_mode,
            r=self.r,
            root_dir=self.root_dir,
        )

        if mode == 'warmup':
            warmup_transform = self.transforms["warmup"] if self.strong_aug else self.transform_train
            all_dataset = cifar_dataset(
                transform=warmup_transform, mode="all", noise_file=self.noise_file,
                strong_aug=self.strong_aug, **common,
            )
            return self._make_loader(all_dataset, self.batch_size * 2, shuffle=True)

        if mode == 'train':
            subset_kwargs = dict(
                noise_file=self.noise_file, pred=pred, probability=prob,
                strong_aug=self.strong_aug, noise_trans=None, tf_writer=tf_writer,
                epoch=epoch, model_name=model_name, eval_train_loss=eval_train_loss,
                meta_pred=meta_pred, meta_probability=meta_prob,
                use_meta_label=use_meta_label, test_transform=self.transform_test,
            )
            labeled_dataset = cifar_dataset(
                transform=self.transforms["labeled"], mode="labeled", log=self.log,
                **common, **subset_kwargs,
            )
            labeled_trainloader = self._make_loader(labeled_dataset, self.batch_size, shuffle=True)

            unlabeled_dataset = cifar_dataset(
                transform=self.transforms["unlabeled"], mode="unlabeled",
                **common, **subset_kwargs,
            )
            unlabeled_trainloader = self._make_loader(unlabeled_dataset, int(self.batch_size), shuffle=True)
            return labeled_trainloader, unlabeled_trainloader

        if mode == 'test':
            test_dataset = cifar_dataset(transform=self.transform_test, mode='test', **common)
            return self._make_loader(test_dataset, self.batch_size, shuffle=False)

        if mode in ('eval_train', 'eval_train_log'):
            eval_dataset = cifar_dataset(
                transform=self.transform_test,
                mode='all' if mode == 'eval_train' else 'eval_train',
                noise_file=self.noise_file, **common,
            )
            return self._make_loader(eval_dataset, self.batch_size, shuffle=False)

        raise ValueError("unknown mode: %r" % (mode,))