from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import random
import numpy as np
from PIL import Image
import json
import torch
from utils import RandAugment, CIFAR10Policy, ImageNetPolicy
import copy

# Weak augmentation pipeline:
# Resize(256) -> RandomCrop(224) -> RandomHorizontalFlip -> ToTensor -> Normalize
# with the fixed per-channel mean/std given below.
transform_weak_c1m_c10_compose = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.6959, 0.6537, 0.6371),
        std=(0.3113, 0.3192, 0.3214),
    ),
])


def transform_weak_c1m(x):
    """Apply the weak augmentation pipeline to image *x*.

    The dataset in this module feeds RGB PIL images here. This is a named
    module-level function (not a lambda) so that ``clothing_dataloader`` can
    resolve it by name via ``globals()``; presumably this also keeps it
    picklable for DataLoader worker processes.
    """
    return transform_weak_c1m_c10_compose(x)


# Strong augmentation pipeline, CIFAR10Policy variant:
# Resize(256) -> RandomCrop(224) -> RandomHorizontalFlip -> CIFAR10Policy
# (from utils, applied on the PIL image) -> ToTensor -> Normalize.
transform_strong_c1m_c10_compose = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    CIFAR10Policy(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.6959, 0.6537, 0.6371),
        std=(0.3113, 0.3192, 0.3214),
    ),
])


def transform_strong_c1m_c10(x):
    """Apply the CIFAR10Policy-based strong augmentation to image *x*.

    ``clothing_dataloader`` resolves this function by name as its warm-up
    transform. It is kept as a named module-level function, presumably so it
    stays picklable for DataLoader worker processes.
    """
    return transform_strong_c1m_c10_compose(x)


# Strong augmentation pipeline, ImageNetPolicy variant:
# Resize(256) -> RandomCrop(224) -> RandomHorizontalFlip -> ImageNetPolicy
# (from utils, applied on the PIL image) -> ToTensor -> Normalize.
transform_strong_c1m_in_compose = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    ImageNetPolicy(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.6959, 0.6537, 0.6371),
        std=(0.3113, 0.3192, 0.3214),
    ),
])


def transform_strong_c1m_in(x):
    """Apply the ImageNetPolicy-based strong augmentation to image *x*.

    ``clothing_dataloader`` resolves this function by name for the first two
    views of its labeled/unlabeled transform lists. It is kept as a named
    module-level function, presumably so it stays picklable for DataLoader
    worker processes.
    """
    return transform_strong_c1m_in_compose(x)


class clothing_dataset(Dataset):
    """Clothing1M-style dataset with several sampling modes.

    Every key read from the label / key-list text files under ``root`` has its
    first 7 characters (presumably an ``images/`` prefix -- confirm against the
    data files) replaced by ``root/`` to form the image path.

    Modes:
        'all' / 'warm': a shuffled, class-balanced subset of
            ``noisy_train_key_list.txt``. It holds at most ``num_samples``
            images and at most ``num_samples / num_class`` per noisy label.
        'labeled' / 'unlabeled': the entries of ``paths`` selected by ``pred``
            (labeled) or by its complement (unlabeled). While
            ``0 <= use_meta_label < epoch``, ``meta_pred`` and
            ``meta_probability`` are used instead, and the same condition
            applies to both modes so the two subsets partition ``paths``.
        'test' / 'val': the clean test / validation key lists.

    Args:
        root: dataset directory containing the label and key-list files.
        transform: a single callable for 'all'/'warm'/'test'/'val'; an
            indexable of four callables for 'labeled'/'unlabeled'.
        mode: one of the modes above; anything else raises ``ValueError``.
        num_samples: subset size for 'all'/'warm'.
        pred: per-path 0/1 selection mask. ``pred.nonzero()[0]`` is used, so a
            1-D numpy array is assumed (TODO confirm with callers).
        probability: per-path probabilities, indexable by position.
        paths: image paths aligned position-wise with ``pred``.
        num_class: number of classes used for balancing in 'all'/'warm'.
        eval_train_loss: per-path values supporting fancy indexing (e.g. a
            numpy array). In ablation runs it may hold meta logits, not losses.
        meta_pred: meta-model counterpart of ``pred``.
        meta_probability: meta-model counterpart of ``probability``.
        use_meta_label: meta labels are used once ``epoch`` exceeds this value;
            a negative value disables them.
        epoch: current epoch.

    Raises:
        ValueError: for an unknown ``mode``, or when meta labels are active
            but ``meta_pred`` is None.
    """

    def __init__(self, root, transform, mode, num_samples=0, pred=None, probability=None, paths=None, num_class=14, eval_train_loss=None, meta_pred=None, meta_probability=None, use_meta_label=-1, epoch=0):
        self.root = root
        self.transform = transform
        self.mode = mode
        # Noisy labels cover the training keys; clean labels cover both the
        # test and the val keys (val lookups use test_labels below).
        self.train_labels = self._read_label_file('noisy_label_kv.txt')
        self.test_labels = self._read_label_file('clean_label_kv.txt')
        self.val_labels = {}  # not filled here; kept for attribute compatibility

        if mode in ('all', 'warm'):
            self.train_imgs = self._balanced_subset(
                self._read_key_list('noisy_train_key_list.txt'), num_samples, num_class)
        elif mode in ('labeled', 'unlabeled'):
            self._select_partition(mode, pred, probability, paths, eval_train_loss,
                                   meta_pred, meta_probability, use_meta_label, epoch)
            print("%s data has a size of %d" % (self.mode, len(self.train_imgs)))
        elif mode == 'test':
            self.test_imgs = self._read_key_list('clean_test_key_list.txt')
        elif mode == 'val':
            self.val_imgs = self._read_key_list('clean_val_key_list.txt')
        else:
            raise ValueError('unknown mode: %r' % (mode,))

    def _full_path(self, key):
        # Replace the 7-character key prefix with the dataset root.
        return '%s/' % self.root + key[7:]

    def _read_label_file(self, filename):
        # Parse "<key> <int label>" lines into {image path: label}.
        labels = {}
        with open('%s/%s' % (self.root, filename), 'r') as f:
            for line in f.read().splitlines():
                entry = line.split()
                if not entry:  # tolerate blank lines
                    continue
                labels[self._full_path(entry[0])] = int(entry[1])
        return labels

    def _read_key_list(self, filename):
        # Read one key per line and map each to its image path.
        with open('%s/%s' % (self.root, filename), 'r') as f:
            return [self._full_path(line) for line in f.read().splitlines() if line.strip()]

    def _balanced_subset(self, candidates, num_samples, num_class):
        # Randomly pick at most num_samples paths, capping each noisy class at
        # num_samples / num_class.
        random.shuffle(candidates)
        per_class_cap = num_samples / num_class
        counts = [0] * num_class
        selected = []
        for img_path in candidates:
            if len(selected) >= num_samples:
                break
            label = self.train_labels[img_path]
            if counts[label] < per_class_cap:
                selected.append(img_path)
                counts[label] += 1
        random.shuffle(selected)
        return selected

    def _select_partition(self, mode, pred, probability, paths, eval_train_loss,
                          meta_pred, meta_probability, use_meta_label, epoch):
        # Fill train_imgs / probability / eval_train_loss for labeled/unlabeled.
        # One shared condition keeps labeled and unlabeled complementary.
        use_meta = 0 <= use_meta_label < epoch
        if use_meta:
            if meta_pred is None:
                raise ValueError('meta labels are active (use_meta_label=%r, epoch=%r) '
                                 'but meta_pred is None' % (use_meta_label, epoch))
            mask = meta_pred
        else:
            mask = pred
        if mode == 'unlabeled':
            mask = 1 - mask
        idx = mask.nonzero()[0]

        self.train_imgs = [paths[i] for i in idx]
        if use_meta:
            self.probability = [meta_probability[i] for i in idx]
            # In the meta branch the original probability is stored in
            # eval_train_loss's place.
            self.eval_train_loss = [probability[i] for i in idx]
        else:
            self.probability = [probability[i] for i in idx]
            # This variable may hold meta logits instead of losses in ablation runs.
            self.eval_train_loss = eval_train_loss[idx]

    @staticmethod
    def _load_rgb(img_paths, index):
        """Open ``img_paths[index]`` as an RGB image.

        An unreadable file is logged and replaced by entry ``index - 1``, which
        wraps to the last entry when index is 0. Returns ``(image, used_index)``
        so labels and per-sample metadata come from the entry actually loaded.
        """
        def _open(path):
            # The context manager closes the file; convert() returns an
            # independent, fully loaded image.
            with Image.open(path) as im:
                return im.convert('RGB')

        try:
            return _open(img_paths[index]), index
        except Exception as e:  # deliberate best-effort skip of a bad image
            print('{}, img_path: {}'.format(e, img_paths[index]))
            fallback = index - 1
            return _open(img_paths[fallback]), fallback

    def __getitem__(self, index):
        if self.mode in ('labeled', 'unlabeled'):
            image, index = self._load_rgb(self.train_imgs, index)
            target = self.train_labels[self.train_imgs[index]]
            # Four independently augmented views of the same image.
            views = [self.transform[k](image) for k in range(4)]
            return (*views, target, self.probability[index], self.eval_train_loss[index])
        if self.mode in ('all', 'warm'):
            image, index = self._load_rgb(self.train_imgs, index)
            img_path = self.train_imgs[index]
            target = self.train_labels[img_path]
            img = self.transform(image)
            if self.mode == 'warm':
                return img, target, img_path, self.transform(image)
            return img, target, img_path
        if self.mode == 'test':
            image, index = self._load_rgb(self.test_imgs, index)
            target = self.test_labels[self.test_imgs[index]]
            return self.transform(image), target
        if self.mode == 'val':
            image, index = self._load_rgb(self.val_imgs, index)
            target = self.test_labels[self.val_imgs[index]]
            return self.transform(image), target
        raise ValueError('unknown mode: %r' % (self.mode,))

    def __len__(self):
        if self.mode == 'test':
            return len(self.test_imgs)
        if self.mode == 'val':
            return len(self.val_imgs)
        return len(self.train_imgs)
        
class clothing_dataloader():
    """Build the DataLoaders used in each phase of training and evaluation.

    Args:
        root: dataset directory forwarded to ``clothing_dataset``.
        batch_size: base batch size; the warm-up loader uses twice this.
        num_batches: batches per epoch, used to size the warm-up and
            eval_train subsets.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self, root, batch_size, num_batches, num_workers):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_batches = num_batches
        self.root = root
        self.strong_aug = True  # not read inside this class; kept for callers

        self.transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
        ])
        self.transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
        ])

        # Transforms are named by module-level function; the four labeled /
        # unlabeled entries are the four views returned per sample.
        self.augmentation_strategy = {
            "labeled_transforms": [
                "transform_strong_c1m_in",
                "transform_strong_c1m_in",
                "transform_weak_c1m",
                "transform_weak_c1m",
            ],
            "unlabeled_transforms": [
                "transform_strong_c1m_in",
                "transform_strong_c1m_in",
                "transform_weak_c1m",
                "transform_weak_c1m",
            ],
            "warmup_transform": "transform_strong_c1m_c10",
        }

        # Resolve the names to the module-level functions. Named functions
        # (rather than lambdas) are presumably used so the transforms can be
        # pickled for DataLoader workers on both Windows and Linux.
        module_scope = globals()
        strategy = self.augmentation_strategy
        self.transforms = {
            "warmup": module_scope[strategy["warmup_transform"]],
            "unlabeled": [module_scope[name] for name in strategy["unlabeled_transforms"]],
            "labeled": [module_scope[name] for name in strategy["labeled_transforms"]],
            "test": None,
        }

    def run(self, mode, pred=None, prob=None, paths=None, amplifier=1, eval_train_loss=None, epoch=0, meta_pred=None, meta_prob=None, use_meta_label=-1):
        """Return the DataLoader(s) for ``mode``.

        Args:
            mode: 'warmup', 'train', 'eval_train', 'test' or 'val'.
            pred, prob, paths, eval_train_loss, meta_pred, meta_prob,
            use_meta_label, epoch: forwarded to the labeled/unlabeled
                ``clothing_dataset`` in 'train' mode.
            amplifier: multiplies the warm-up subset size.

        Returns:
            One DataLoader, or ``(labeled_loader, unlabeled_loader)`` for 'train'.

        Raises:
            ValueError: if ``mode`` is not recognised.
        """
        if mode == 'warmup':
            warmup_dataset = clothing_dataset(
                self.root, transform=self.transforms["warmup"], mode='warm',
                num_samples=self.num_batches * self.batch_size * 2 * amplifier)
            return DataLoader(
                dataset=warmup_dataset,
                batch_size=self.batch_size * 2,
                shuffle=True,
                num_workers=self.num_workers)
        if mode == 'train':
            # Both datasets share one selection so they partition ``paths``.
            partition_kwargs = dict(
                pred=pred, probability=prob, paths=paths,
                eval_train_loss=eval_train_loss, meta_pred=meta_pred,
                meta_probability=meta_prob, use_meta_label=use_meta_label, epoch=epoch)
            labeled_dataset = clothing_dataset(
                self.root, transform=self.transforms["labeled"], mode='labeled', **partition_kwargs)
            labeled_loader = DataLoader(
                dataset=labeled_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.num_workers)
            unlabeled_dataset = clothing_dataset(
                self.root, transform=self.transforms["unlabeled"], mode='unlabeled', **partition_kwargs)
            unlabeled_loader = DataLoader(
                dataset=unlabeled_dataset,
                batch_size=int(self.batch_size),
                shuffle=True,
                num_workers=self.num_workers)
            return labeled_loader, unlabeled_loader
        if mode == 'eval_train':
            eval_dataset = clothing_dataset(
                self.root, transform=self.transform_test, mode='all',
                num_samples=self.num_batches * self.batch_size)
            return DataLoader(
                dataset=eval_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers)
        if mode == 'test':
            test_dataset = clothing_dataset(self.root, transform=self.transform_test, mode='test')
            return DataLoader(
                dataset=test_dataset,
                batch_size=1000,
                shuffle=False,
                num_workers=self.num_workers)
        if mode == 'val':
            val_dataset = clothing_dataset(self.root, transform=self.transform_test, mode='val')
            return DataLoader(
                dataset=val_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers)
        raise ValueError('unknown mode: %r' % (mode,))