import logging

import os.path as osp
import pickle
import numpy as np

import torch
from torch.utils.data import Dataset
from datasets import transform as T

from datasets.randaugment import RandomAugment
from datasets.sampler import RandomSampler, BatchSampler

from torchvision import datasets as torchdatasets

from PIL import Image

logger = logging.getLogger(__name__)


def load_data_train(L=1000, folds=None, dataset='STL10', dspth='./data'):
    """Load the STL-10 labeled fold, unlabeled split and validation images.

    The labeled set is the fold ``folds`` read from
    ``<dspth>/stl10_binary/fold_indices.txt``; every image of the ``train``
    split that is not in that fold becomes the validation set (in ascending
    index order); the whole ``unlabeled`` split becomes the unlabeled set.

    Args:
        L: number of labeled samples; only 1000 is supported.
        folds: fold index in ``0..9`` (selects the line of
            ``fold_indices.txt``).
        dataset: unused; kept for interface compatibility with other loaders.
        dspth: root directory passed to ``torchvision.datasets.STL10``
            (``download=True``, so files are fetched there if missing).

    Returns:
        ``(data_x, label_x, data_u, label_u, data_val, label_val)``: lists of
        ``(96, 96, 3)`` HWC image arrays and the matching per-image labels.
        Labeled and validation labels are ``np.int64``; unlabeled labels are
        taken as-is from the torchvision unlabeled split.
    """
    assert L in [1000], "The number of labels should be 1000"
    assert folds in list(range(10)), "fold should be a number from 0 to 9"

    stl_labeled = torchdatasets.STL10(root=dspth, split='train', download=True)
    stl_unlabeled = torchdatasets.STL10(root=dspth, split='unlabeled', download=True)

    # CHW -> HWC, i.e. (3, 96, 96) -> (96, 96, 3), so the cv2-based transforms
    # can be used; T.ToTensor transposes back to (3, 96, 96).
    images_labeled = stl_labeled.data.reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)
    images_unlabeled = stl_unlabeled.data.reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)
    # Labels are stored as uint8 (Byte); cast to Long (int64).
    labels_labeled = stl_labeled.labels.astype(np.int64)

    # Line `folds` of the file holds whitespace-separated indices into the
    # train split; parse them explicitly rather than via np.fromstring.
    path_to_folds = osp.join(dspth, 'stl10_binary', 'fold_indices.txt')
    with open(path_to_folds, 'r') as f:
        fold_line = f.read().splitlines()[folds]
    idx_list = np.array([int(tok) for tok in fold_line.split()], dtype=np.int64)

    # Everything in the train split outside the labeled fold is validation.
    is_val = np.ones(len(images_labeled), dtype=bool)
    is_val[idx_list] = False
    inds_val = np.flatnonzero(is_val)

    data_x = [images_labeled[i] for i in idx_list]
    label_x = [labels_labeled[i] for i in idx_list]
    data_val = [images_labeled[i] for i in inds_val]
    label_val = [labels_labeled[i] for i in inds_val]
    data_u = list(images_unlabeled)
    label_u = list(stl_unlabeled.labels)

    return data_x, label_x, data_u, label_u, data_val, label_val


def load_data_test(dataset, dspth='./data'):
    """Load the STL-10 ``test`` split as HWC images and int64 labels.

    Args:
        dataset: unused; kept for interface compatibility.
        dspth: root directory passed to ``torchvision.datasets.STL10``
            (``download=True``, so files are fetched there if missing).

    Returns:
        ``(data, labels)``: a list of ``(96, 96, 3)`` image arrays and a list
        of ``np.int64`` labels, in dataset order.
    """
    stl_test = torchdatasets.STL10(root=dspth, split='test', download=True)

    # CHW -> HWC, i.e. (3, 96, 96) -> (96, 96, 3), for the cv2-based
    # transforms; T.ToTensor transposes back to (3, 96, 96).
    images = stl_test.data.reshape(-1, 3, 96, 96).transpose(0, 2, 3, 1)
    # Same as the train split: labels need casting to Long (int64).
    labels = stl_test.labels.astype(np.int64)

    return list(images), list(labels)


def compute_mean_var():
    """Print and return the per-channel mean and std of the STL-10 training images.

    Uses ``load_data_train(folds=0)``. The labeled fold plus the validation
    remainder cover the whole ``train`` split, so together with the
    unlabeled split every training image is included and the fold choice
    does not matter. Pixel values are scaled to ``[0, 1]``. The reported
    spread is the population standard deviation (ddof=0, as ``np.std``),
    not the variance.

    Returns:
        ``(mean, std)``: two lists of three Python floats, indexed by the
        images' last (channel) axis.

    Raises:
        ValueError: if no images were loaded.
    """
    data_x, _, data_u, _, data_val, _ = load_data_train(folds=0)
    print('Finish loading data.')

    images = data_x + data_u + data_val
    if not images:
        raise ValueError('no images to compute statistics over')

    # Accumulate exact integer sums chunk by chunk instead of stacking every
    # image into one array (plus a float64 copy per channel), which keeps
    # peak memory bounded by the chunk size.
    chunk_size = 256
    total = np.zeros(3, dtype=np.int64)
    total_sq = np.zeros(3, dtype=np.int64)
    n_pixels = 0
    for start in range(0, len(images), chunk_size):
        pixels = np.stack(images[start:start + chunk_size]).reshape(-1, 3).astype(np.int64)
        total += pixels.sum(axis=0)
        total_sq += (pixels * pixels).sum(axis=0)
        n_pixels += pixels.shape[0]
    print('Finish accumulating statistics.')

    mean_raw = total / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    var_raw = np.maximum(total_sq / n_pixels - mean_raw ** 2, 0.0)
    mean = [float(m) for m in mean_raw / 255]
    std = [float(s) for s in np.sqrt(var_raw) / 255]

    print('mean: ', mean)
    print('std: ', std)
    return mean, std


class STL10(Dataset):
    """In-memory STL-10 dataset with weak/strong augmentation for training.

    Training items are ``(weak_view, strong_view, label)``: the weak view
    uses resize + pad/random-crop + horizontal flip, and the strong view adds
    ``RandomAugment`` on top. Evaluation items are ``(view, label)`` after
    resize + normalization only.

    Args:
        dataset: unused; accepted for interface compatibility.
        data: indexable sequence of images (as returned by the loaders above).
        labels: indexable sequence of labels, same length as ``data``.
        is_train: selects the training or evaluation pipelines.
    """

    def __init__(self, dataset, data, labels, is_train=True):
        super().__init__()
        self.data = data
        self.labels = labels
        self.is_train = is_train
        assert len(self.data) == len(self.labels)

        # Per-channel normalization statistics.
        mean = (0.4409, 0.4279, 0.3868)
        std = (0.2683, 0.2611, 0.2687)

        def geometric_ops():
            # Fresh instances on every call so the pipelines share no objects.
            return [
                T.Resize((96, 96)),
                T.PadandRandomCrop(border=12, cropsize=(96, 96)),
                T.RandomHorizontalFlip(p=0.5),
            ]

        def finishing_ops():
            return [T.Normalize(mean, std), T.ToTensor()]

        if not is_train:
            self.trans = T.Compose([T.Resize((96, 96))] + finishing_ops())
            return

        self.trans_weak = T.Compose(geometric_ops() + finishing_ops())
        self.trans_strong = T.Compose(
            geometric_ops()
            # Custom Cutout size for the strong view.
            + [RandomAugment(2, 10, cutoutSize=48)]
            + finishing_ops()
        )

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.labels[idx]
        if not self.is_train:
            return self.trans(image), label
        weak_view = self.trans_weak(image)
        strong_view = self.trans_strong(image)
        return weak_view, strong_view, label

    def __len__(self):
        return len(self.data)


class STL10MixMatch(Dataset):
    """In-memory STL-10 dataset for MixMatch-style training.

    Training items are ``(view_1, view_2, label)``, where both views come
    from separate calls to the same random pipeline: resize, pad/random-crop,
    horizontal flip and normalization. Evaluation items are ``(view, label)``
    after resize and normalization only.

    Args:
        dataset: unused; accepted for interface compatibility.
        data: indexable sequence of images.
        labels: indexable sequence of labels, same length as ``data``.
        is_train: selects the training or evaluation pipeline.
    """

    def __init__(self, dataset, data, labels, is_train=True):
        super().__init__()
        self.data = data
        self.labels = labels
        self.is_train = is_train
        assert len(self.data) == len(self.labels)

        channel_mean = (0.4409, 0.4279, 0.3868)
        channel_std = (0.2683, 0.2611, 0.2687)

        if not is_train:
            self.trans = T.Compose([
                T.Resize((96, 96)),
                T.Normalize(channel_mean, channel_std),
                T.ToTensor(),
            ])
            return

        steps = [T.Resize((96, 96))]
        steps.append(T.PadandRandomCrop(border=12, cropsize=(96, 96)))
        steps.append(T.RandomHorizontalFlip(p=0.5))
        steps += [T.Normalize(channel_mean, channel_std), T.ToTensor()]
        self.trans_train = T.Compose(steps)

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.labels[idx]
        if self.is_train:
            # Two sequential calls -> two separately sampled augmentations.
            views = tuple(self.trans_train(image) for _ in range(2))
            return views + (label,)
        return self.trans(image), label

    def __len__(self):
        return len(self.data)


def get_train_loader(dataset, batch_size, mu, n_iters_per_epoch, L, num_val, root='data', num_workers=2):
    '''Build the labeled, unlabeled and validation loaders for STL-10.

    With STL-10, num_val is actually folds (0 to 9): it selects the labeled
    fold used by ``load_data_train``; the rest of the labeled train split is
    the validation set.

    Args:
        dataset: forwarded to ``load_data_train`` and ``STL10``.
        batch_size: labeled images per batch.
        mu: unlabeled batch size is ``batch_size * mu``.
        n_iters_per_epoch: batches per epoch for both training loaders
            (assuming the project ``RandomSampler`` yields exactly
            ``num_samples`` indices).
        L: number of labels, forwarded to ``load_data_train``.
        num_val: fold index, see above.
        root: dataset root directory.
        num_workers: worker processes for every DataLoader (default 2).

    Returns:
        ``(dl_x, dl_u, dl_val)``: labeled, unlabeled and validation loaders.
    '''
    data_x, label_x, data_u, label_u, data_val, label_val = load_data_train(
        L=L, folds=num_val, dataset=dataset, dspth=root)

    logger.info('labeled dataset: %s. unlabeled dataset: %s. validation dataset: %s.',
                len(data_x), len(data_u), len(data_val))

    ds_x = STL10(dataset=dataset, data=data_x, labels=label_x, is_train=True)
    # Sampling with replacement fixes the epoch length independently of the
    # dataset size: n_iters_per_epoch full batches.
    sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size)
    batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True)
    dl_x = torch.utils.data.DataLoader(
        ds_x,
        batch_sampler=batch_sampler_x,
        num_workers=num_workers,
        pin_memory=True
    )

    ds_u = STL10(dataset=dataset, data=data_u, labels=label_u, is_train=True)
    # mu times as many unlabeled samples, in batches mu times larger.
    sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size)
    batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
    dl_u = torch.utils.data.DataLoader(
        ds_u,
        batch_sampler=batch_sampler_u,
        num_workers=num_workers,
        pin_memory=True
    )

    ds_val = STL10(dataset=dataset, data=data_val, labels=label_val, is_train=False)
    dl_val = torch.utils.data.DataLoader(
        ds_val,
        shuffle=False,
        batch_size=100,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=True
    )

    logger.info('labeled sample: %s. unlabeled sample: %s.', len(sampler_x), len(sampler_u))
    logger.info('labeled loader: %s. unlabeled loader: %s. validation loader: %s.',
                len(dl_x), len(dl_u), len(dl_val))

    return dl_x, dl_u, dl_val


def get_test_loader(dataset, batch_size, num_workers, pin_memory=True, root='./data'):
    '''Build the STL-10 test-split loader (no shuffling, last partial batch kept).

    Args:
        dataset: forwarded to ``load_data_test`` and ``STL10``.
        batch_size: images per batch.
        num_workers: DataLoader worker processes.
        pin_memory: forwarded to the DataLoader.
        root: dataset root directory passed to ``load_data_test``
            (default ``'./data'``), so it can match the training loader's root.

    Returns:
        The test ``torch.utils.data.DataLoader``.
    '''
    data, labels = load_data_test(dataset, dspth=root)
    logger.info('test dataset: %s.', len(data))

    ds = STL10(dataset=dataset, data=data, labels=labels, is_train=False)
    dl = torch.utils.data.DataLoader(
        ds,
        shuffle=False,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    logger.info('test loader: %s.', len(dl))
    return dl


def get_train_loader_mixmatch(dataset, batch_size, mu, n_iters_per_epoch, L, num_val, root='data', num_workers=2):
    '''Build the labeled, unlabeled and validation loaders for MixMatch on STL-10.

    With STL-10, num_val is actually folds (0 to 9): it selects the labeled
    fold used by ``load_data_train``; the rest of the labeled train split is
    the validation set.

    Args:
        dataset: forwarded to ``load_data_train`` and ``STL10MixMatch``.
        batch_size: labeled images per batch.
        mu: unlabeled batch size is ``batch_size * mu``.
        n_iters_per_epoch: batches per epoch for both training loaders
            (assuming the project ``RandomSampler`` yields exactly
            ``num_samples`` indices).
        L: number of labels, forwarded to ``load_data_train``.
        num_val: fold index, see above.
        root: dataset root directory.
        num_workers: worker processes for every DataLoader (default 2).

    Returns:
        ``(dl_x, dl_u, dl_val)``: labeled, unlabeled and validation loaders.
    '''
    data_x, label_x, data_u, label_u, data_val, label_val = load_data_train(
        L=L, folds=num_val, dataset=dataset, dspth=root)

    logger.info('labeled dataset: %s. unlabeled dataset: %s. validation dataset: %s.',
                len(data_x), len(data_u), len(data_val))

    ds_x = STL10MixMatch(dataset=dataset, data=data_x, labels=label_x, is_train=True)
    # Sampling with replacement fixes the epoch length independently of the
    # dataset size: n_iters_per_epoch full batches.
    sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size)
    batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True)
    dl_x = torch.utils.data.DataLoader(
        ds_x,
        batch_sampler=batch_sampler_x,
        num_workers=num_workers,
        pin_memory=True
    )

    ds_u = STL10MixMatch(dataset=dataset, data=data_u, labels=label_u, is_train=True)
    # mu times as many unlabeled samples, in batches mu times larger.
    sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size)
    batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
    dl_u = torch.utils.data.DataLoader(
        ds_u,
        batch_sampler=batch_sampler_u,
        num_workers=num_workers,
        pin_memory=True
    )

    ds_val = STL10MixMatch(dataset=dataset, data=data_val, labels=label_val, is_train=False)
    dl_val = torch.utils.data.DataLoader(
        ds_val,
        shuffle=False,
        batch_size=100,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=True
    )

    logger.info('labeled sample: %s. unlabeled sample: %s.', len(sampler_x), len(sampler_u))
    logger.info('labeled loader: %s. unlabeled loader: %s. validation loader: %s.',
                len(dl_x), len(dl_u), len(dl_val))

    return dl_x, dl_u, dl_val


def get_test_loader_mixmatch(dataset, batch_size, num_workers, pin_memory=True, root='./data'):
    '''Build the STL-10 test-split loader for MixMatch (no shuffling, last partial batch kept).

    Args:
        dataset: forwarded to ``load_data_test`` and ``STL10MixMatch``.
        batch_size: images per batch.
        num_workers: DataLoader worker processes.
        pin_memory: forwarded to the DataLoader.
        root: dataset root directory passed to ``load_data_test``
            (default ``'./data'``), so it can match the training loader's root.

    Returns:
        The test ``torch.utils.data.DataLoader``.
    '''
    data, labels = load_data_test(dataset, dspth=root)
    logger.info('test dataset: %s.', len(data))

    ds = STL10MixMatch(dataset=dataset, data=data, labels=labels, is_train=False)
    dl = torch.utils.data.DataLoader(
        ds,
        shuffle=False,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    logger.info('test loader: %s.', len(dl))
    return dl
