from torchvision import datasets, transforms
from torch.utils.data import Dataset
from .data_utils import get_onehot
from .augmentation.randaugment import RandAugment

from PIL import Image
import numpy as np
import copy


class BasicDataset(Dataset):
    """
    BasicDataset returns a pair of image and label (target).

    If ``targets`` is not given, ``None`` is returned as the label.
    This class supports strong augmentation for FixMatch: when
    ``use_strong_transform`` is True it returns both a weakly and a strongly
    augmented view of the same image.

    When ``label_idxs`` is given, item ``idx`` is redirected to
    ``label_idxs[idx % len(label_idxs)]``; that redirected index is used for
    both ``data`` and ``targets``.
    """

    def __init__(self,
                 data,
                 targets=None,
                 num_classes=None,
                 transform=None,
                 use_strong_transform=False,
                 strong_transform=None,
                 onehot=False,
                 label_idxs=None,
                 *args, **kwargs):
        """
        Args
            data: x_data. Indexing it must yield a 5-tuple
                ``(img, _, _, domain_label, sample_idx)``; ``img`` and
                ``domain_label`` are themselves indexed with ``[0]``.
            targets: y_data (if not exist, None)
            num_classes: number of label classes (only used when onehot=True)
            transform: basic (weak) transformation of data. When strong
                augmentation is requested without ``strong_transform``, this
                object is deep-copied and ``RandAugment(3, 5)`` is inserted at
                the front of the copy's ``transforms`` list, so it must expose
                such a list (e.g. a torchvision ``Compose``).
            use_strong_transform: If True, this dataset returns both weakly
                and strongly augmented images.
            strong_transform: transformation used for strong augmentation.
                If None while ``use_strong_transform`` is True, it is derived
                from ``transform`` as described above.
            onehot: If True, label is converted into onehot vector.
            label_idxs: optional sequence of indices into ``data``/``targets``
                (see the class docstring).
            *args, **kwargs: accepted for call-site compatibility and ignored.

        Raises
            ValueError: if strong augmentation is requested but neither
                ``strong_transform`` nor ``transform`` is provided.
        """
        super().__init__()
        self.data = data
        self.targets = targets
        self.label_idxs = label_idxs
        self.num_classes = num_classes
        self.use_strong_transform = use_strong_transform
        self.onehot = onehot

        self.transform = transform
        if use_strong_transform and strong_transform is None:
            if transform is None:
                raise ValueError(
                    "use_strong_transform=True requires either strong_transform "
                    "or transform to be provided")
            # Derive the strong pipeline from the weak one; deepcopy so that
            # inserting RandAugment does not mutate the weak pipeline.
            self.strong_transform = copy.deepcopy(transform)
            self.strong_transform.transforms.insert(0, RandAugment(3, 5))
        else:
            # Also covers use_strong_transform=True with an explicit
            # strong_transform, which was previously never stored.
            self.strong_transform = strong_transform

    def _source_index(self, idx):
        """Return the index into ``data``/``targets`` that item ``idx`` maps to."""
        if self.label_idxs is None:
            return idx
        # Wrap around so the dataset may be indexed past len(label_idxs).
        return self.label_idxs[idx % len(self.label_idxs)]

    def __getitem__(self, idx):
        """
        Return the idx-th sample.

        Return shapes differ per configuration (kept as-is for callers):
            - transform is None:
                ``(ToTensor()(img), target, domain_label)``
            - strong augmentation disabled:
                ``(weak_img, target, idx, domain_label, sample_idx)``
            - strong augmentation enabled:
                ``(weak_img, strong_img, target, domain_label, sample_idx)``
        ``target`` is None when no targets were given, and a one-hot vector
        when ``onehot`` is True.
        """
        # Resolve the redirected index once so the image and the target
        # always come from the same position (previously it was only computed
        # when targets existed, so targets=None with label_idxs crashed).
        src_idx = self._source_index(idx)

        # set idx-th target
        if self.targets is None:
            target = None
        else:
            target = self.targets[src_idx]
            if self.onehot:
                target = get_onehot(self.num_classes, target)

        # set augmented images
        img, _, _, domain_label, sample_idx = self.data[src_idx]
        # NOTE(review): the wrapped dataset appears to return the image and
        # domain label inside a container; only the first element is used.
        img = img[0]
        domain_label = domain_label[0]

        if self.transform is None:
            return transforms.ToTensor()(img), target, domain_label

        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        img_w = self.transform(img)
        if not self.use_strong_transform:
            return img_w, target, idx, domain_label, sample_idx
        return img_w, self.strong_transform(img), target, domain_label, sample_idx

    def __len__(self):
        """Number of items, i.e. ``len(data)``."""
        return len(self.data)