import numpy as np
from PIL import Image

from torchvision import transforms
from torch.utils.data import Dataset

from src.fl_datasets.utils import get_onehot


class BasicDataset(Dataset):
    """
    BasicDataset returns a dict holding an image, its label (target) and its index.
    If targets are not given, the label entry is None.
    This class supports strong augmentation for FixMatch-style algorithms,
    and returns both weakly and strongly augmented views of each image.
    """

    # Algorithms for which ``transform`` is an indexable pair of callables
    # (transform[0] -> '*_w' view, transform[1] -> 'x_lb' / 'x_ulb' view) and
    # which receive extra augmented views in __getitem__.
    _MULTI_VIEW_ALGS = ('openmatch', 'ours', 'prosub')

    def __init__(self,
                 alg,
                 data,
                 targets=None,
                 num_classes=None,
                 transform=None,
                 is_ulb=False,
                 strong_transform=None,
                 onehot=False,
                 *args,
                 **kwargs):
        """
        Args
            alg: algorithm name; selects the layout of the dict returned by __getitem__
            data: x_data, indexable (items may be numpy arrays or file paths, among others)
            targets: y_data (if not exist, None)
            num_classes: number of label classes (passed to get_onehot when onehot=True)
            transform: basic transformation of data; for algorithms listed in
                _MULTI_VIEW_ALGS this must be an indexable pair of transforms
            is_ulb: if True, samples are emitted with the unlabeled ('*_ulb') keys
            strong_transform: strong augmentation applied to produce the '*_s*' views
            onehot: if True, targets are converted with get_onehot
            *args, **kwargs: accepted for interface compatibility; unused here
        """
        super().__init__()
        self.alg = alg
        self.data = data
        self.targets = targets

        self.num_classes = num_classes
        self.is_ulb = is_ulb
        self.onehot = onehot

        self.transform = transform
        self.strong_transform = strong_transform

    def __sample__(self, idx):
        """ dataset specific sample function: return the raw idx-th (img, target) pair """
        # set idx-th target (optionally one-hot encoded)
        if self.targets is None:
            target = None
        else:
            target_ = self.targets[idx]
            target = target_ if not self.onehot else get_onehot(self.num_classes, target_)

        # raw (un-augmented) sample; augmentation happens in __getitem__
        img = self.data[idx]
        return img, target

    def __getitem__(self, idx):
        """
        Return the idx-th sample as a dict.

        - transform is None (regardless of is_ulb):
            {'x_lb': ToTensor(img), 'y_lb': target}
        - labeled, alg in _MULTI_VIEW_ALGS:
            idx_lb, x_lb, x_lb_w, x_lb_s0, x_lb_s1, y_lb
        - labeled, other algs:
            idx_lb, x_lb (weak view), y_lb
        - unlabeled, alg in _MULTI_VIEW_ALGS:
            idx_ulb, x_ulb, x_ulb_w, x_ulb_s, y_ulb
        - unlabeled, other algs:
            idx_ulb, x_ulb_w, x_ulb_s, y_ulb
        """
        img, target = self.__sample__(idx)

        # Resolve file paths before the no-transform branch as well:
        # ToTensor cannot convert a str, so path-based data would crash there.
        if isinstance(img, str):
            img = pil_loader(img)

        if self.transform is None:
            return {'x_lb': transforms.ToTensor()(img), 'y_lb': target}

        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)  # shape of img should be [H, W, C]

        multi_view = self.alg in self._MULTI_VIEW_ALGS
        if not multi_view:
            # computed before the strong view so the transform call order is stable
            img_w = self.transform(img)

        # [*] Labeled sample
        if not self.is_ulb:
            if multi_view:
                return {'idx_lb': idx,
                        'x_lb': self.transform[1](img), 'x_lb_w': self.transform[0](img),
                        'x_lb_s0': self.strong_transform(img), 'x_lb_s1': self.strong_transform(img),
                        'y_lb': target}
            return {'idx_lb': idx, 'x_lb': img_w, 'y_lb': target}

        # [*] Unlabeled sample
        # => y_ulb should be only used for evaluating pseudo-labels and be never used for training
        if multi_view:
            return {'idx_ulb': idx,
                    'x_ulb': self.transform[1](img), 'x_ulb_w': self.transform[0](img),
                    'x_ulb_s': self.strong_transform(img), 'y_ulb': target}
        return {'idx_ulb': idx, 'x_ulb_w': img_w, 'x_ulb_s': self.strong_transform(img), 'y_ulb': target}

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.data)
    
    
def pil_loader(path):
    """Read the image stored at ``path`` and return an RGB copy of it.

    The file is opened explicitly inside a ``with`` block so its handle is
    released once the RGB conversion has been made, which avoids the
    ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')