import logging

import os.path as osp
import pickle
import numpy as np

import torch
from torch.utils.data import Dataset
from datasets import transform as T

from datasets.randaugment import RandomAugment
from datasets.sampler import RandomSampler, BatchSampler

from torchvision import datasets as torchdatasets

from PIL import Image

logger = logging.getLogger(__name__)


def load_data_train(L=1000, num_val=7325, dataset='SVHN', dspth='./data'):
    """Build a class-balanced labeled / unlabeled / validation split of SVHN train.

    For each of the 10 classes the sample indices are shuffled with the global
    NumPy RNG (``np.random.shuffle``). The first ``L // 10`` indices become
    labeled samples, the last ``num_val // 10`` become validation samples and
    everything in between is treated as unlabeled.

    Args:
        L: Total number of labeled samples. Must be 40, 250 or 1000.
        num_val: Total number of validation samples (integer-divided by the
            number of classes, so any remainder is dropped).
        dataset: Not used by this function.
        dspth: Root directory passed to ``torchvision.datasets.SVHN``, which is
            called with ``download=True``.

    Returns:
        A 6-tuple of lists ``(data_x, label_x, data_u, label_u, data_val,
        label_val)``. Every image is reshaped to (3, 32, 32) and then
        transposed to (32, 32, 3).

    Raises:
        AssertionError: If ``L`` is not one of the supported label counts.
    """
    assert L in [40, 250, 1000], "The number of labels should be 40, 250, or 1000"

    n_class = 10

    svhn_dataset = torchdatasets.SVHN(root=dspth, split='train', download=True)
    images, targets = svhn_dataset.data, svhn_dataset.labels

    def to_channel_last(idx):
        # Channel-last layout for the cv2-based transforms; the later tensor
        # conversion is expected to bring it back to (3, 32, 32).
        return images[idx].reshape(3, 32, 32).transpose(1, 2, 0)

    labeled_per_class = L // n_class
    val_per_class = num_val // n_class

    data_x, label_x = [], []
    data_u, label_u = [], []
    data_val, label_val = [], []

    for cls in range(n_class):
        cls_indices = np.where(targets == cls)[0]
        np.random.shuffle(cls_indices)
        val_start = len(cls_indices) - val_per_class

        # (selected indices, image sink, label sink) for each of the three parts.
        partitions = (
            (cls_indices[:labeled_per_class], data_x, label_x),
            (cls_indices[labeled_per_class:val_start], data_u, label_u),
            (cls_indices[val_start:], data_val, label_val),
        )
        for chosen, image_sink, label_sink in partitions:
            image_sink.extend(to_channel_last(j) for j in chosen)
            label_sink.extend(targets[j] for j in chosen)

    return data_x, label_x, data_u, label_u, data_val, label_val


def load_data_test(dataset, dspth='./data'):
    """Load the SVHN test split as channel-last images plus their labels.

    Args:
        dataset: Not used by this function.
        dspth: Root directory passed to ``torchvision.datasets.SVHN``, which is
            called with ``download=True``.

    Returns:
        ``(data, labels)`` where ``data`` is a list of images, each reshaped to
        (3, 32, 32) and transposed to (32, 32, 3), and ``labels`` is the
        ``labels`` attribute of the torchvision dataset, returned unchanged.
    """
    test_split = torchdatasets.SVHN(root=dspth, split='test', download=True)
    labels = test_split.labels

    images = []
    for sample in test_split.data:
        # Channel-last layout for the cv2-based transforms; the later tensor
        # conversion is expected to bring it back to (3, 32, 32).
        images.append(sample.reshape(3, 32, 32).transpose(1, 2, 0))

    return images, labels


def load_extra_set(dataset, dspth='./data'):
    """Placeholder with no implementation.

    Both arguments are ignored and ``None`` is returned.

    NOTE(review): presumably meant to load the SVHN 'extra' split - confirm
    before implementing.
    """
    return None


def compute_mean_var():
    """Print per-channel statistics of the SVHN training images.

    Labeled, unlabeled and validation images returned by ``load_data_train()``
    (called with its defaults) are pooled, pixel values are divided by 255 and,
    for each of the three channels, the mean and the standard deviation are
    computed.

    Note that the list printed under the ``'var: '`` label is produced by
    ``np.std``, i.e. it holds standard deviations rather than variances.

    Returns:
        None. The results are only written to stdout.
    """
    splits = load_data_train()
    # Positions 0, 2 and 4 of the returned tuple are the three image lists.
    images = splits[0] + splits[2] + splits[4]
    stacked = np.stack(images, axis=0)

    mean, var = [], []
    for channel_idx in range(3):
        values = stacked[:, :, :, channel_idx].ravel() / 255
        mean.append(np.mean(values))
        var.append(np.std(values))

    print('mean: ', mean)
    print('var: ', var)


class SVHN(Dataset):
    """SVHN dataset producing a weak and a strong augmentation of each image.

    In training mode an item is ``(weak_view, strong_view, label)``; otherwise
    it is ``(image, label)`` using only resize, normalisation and tensor
    conversion.

    Args:
        dataset: Accepted for interface compatibility; not used by the class.
        data: Indexable collection of images.
        labels: Indexable collection of labels, same length as ``data``.
        is_train: Selects the augmenting (True) or plain (False) pipeline.

    Raises:
        AssertionError: If ``data`` and ``labels`` differ in length.
    """

    def __init__(self, dataset, data, labels, is_train=True):
        super().__init__()
        self.data = data
        self.labels = labels
        self.is_train = is_train
        assert len(self.data) == len(self.labels)

        # Per-channel normalisation constants.
        mean = (0.4377, 0.4438, 0.4728)
        std = (0.1980, 0.2010, 0.1970)

        if not is_train:
            self.trans = T.Compose([
                T.Resize((32, 32)),
                T.Normalize(mean, std),
                T.ToTensor(),
            ])
            return

        # RandomHorizontalFlip is left out of both training pipelines
        # (presumably because mirrored digits are not valid samples).
        weak_steps = [
            T.Resize((32, 32)),
            T.PadandRandomCrop(border=4, cropsize=(32, 32)),
            T.Normalize(mean, std),
            T.ToTensor(),
        ]
        self.trans_weak = T.Compose(weak_steps)

        # Same as the weak pipeline with RandomAugment(2, 10) inserted before
        # normalisation.
        strong_steps = [
            T.Resize((32, 32)),
            T.PadandRandomCrop(border=4, cropsize=(32, 32)),
            RandomAugment(2, 10),
            T.Normalize(mean, std),
            T.ToTensor(),
        ]
        self.trans_strong = T.Compose(strong_steps)

    def __getitem__(self, idx):
        """Return the transformed sample(s) and the label stored at ``idx``."""
        image = self.data[idx]
        label = self.labels[idx]
        if not self.is_train:
            return self.trans(image), label
        # The weak view is computed before the strong one.
        weak_view = self.trans_weak(image)
        strong_view = self.trans_strong(image)
        return weak_view, strong_view, label

    def __len__(self):
        """Return the number of stored images."""
        return len(self.data)


class SVHNMixMatch(Dataset):
    """SVHN dataset producing two draws of the same training augmentation.

    In training mode an item is ``(view_a, view_b, label)`` where both views
    come from separate calls to the same transform pipeline; otherwise it is
    ``(image, label)`` using only resize, normalisation and tensor conversion.

    Args:
        dataset: Accepted for interface compatibility; not used by the class.
        data: Indexable collection of images.
        labels: Indexable collection of labels, same length as ``data``.
        is_train: Selects the augmenting (True) or plain (False) pipeline.

    Raises:
        AssertionError: If ``data`` and ``labels`` differ in length.
    """

    def __init__(self, dataset, data, labels, is_train=True):
        super().__init__()
        self.data = data
        self.labels = labels
        self.is_train = is_train
        assert len(self.data) == len(self.labels)

        # Per-channel normalisation constants.
        mean = (0.4377, 0.4438, 0.4728)
        std = (0.1980, 0.2010, 0.1970)

        if not is_train:
            self.trans = T.Compose([
                T.Resize((32, 32)),
                T.Normalize(mean, std),
                T.ToTensor(),
            ])
            return

        train_steps = [
            T.Resize((32, 32)),
            T.PadandRandomCrop(border=4, cropsize=(32, 32)),
            T.Normalize(mean, std),
            T.ToTensor(),
        ]
        self.trans_train = T.Compose(train_steps)

    def __getitem__(self, idx):
        """Return the transformed sample(s) and the label stored at ``idx``."""
        image = self.data[idx]
        label = self.labels[idx]
        if not self.is_train:
            return self.trans(image), label
        # Two separate passes through the same pipeline.
        first_view = self.trans_train(image)
        second_view = self.trans_train(image)
        return first_view, second_view, label

    def __len__(self):
        """Return the number of stored images."""
        return len(self.data)


def get_train_loader(dataset, batch_size, mu, n_iters_per_epoch, L, num_val, root='data'):
    """Create labeled, unlabeled and validation loaders backed by ``SVHN``.

    Args:
        dataset: Forwarded to ``load_data_train`` and to the ``SVHN`` datasets.
        batch_size: Labeled batch size; unlabeled batches hold
            ``batch_size * mu`` samples.
        mu: Multiplier applied to the unlabeled batch size and sample count.
        n_iters_per_epoch: The labeled sampler is asked for
            ``n_iters_per_epoch * batch_size`` samples and the unlabeled one
            for ``mu`` times that amount.
        L: Number of labeled samples, forwarded to ``load_data_train``.
        num_val: Number of validation samples, forwarded to ``load_data_train``.
        root: Dataset root directory, forwarded as ``dspth``.

    Returns:
        ``(dl_x, dl_u, dl_val)``: the labeled, unlabeled and validation
        ``DataLoader`` objects. The validation loader uses a fixed batch size
        of 100 without shuffling.
    """
    data_x, label_x, data_u, label_u, data_val, label_val = load_data_train(
        L=L, num_val=num_val, dataset=dataset, dspth=root
    )

    logger.info('labeled dataset: {}. unlabeled dataset: {}. validation dataset: {}.'.format(len(data_x), len(data_u), len(data_val)))

    def build_train_stream(samples, targets, per_batch, total_samples):
        # Dataset -> random sampler -> batch sampler -> loader.
        train_ds = SVHN(
            dataset=dataset,
            data=samples,
            labels=targets,
            is_train=True
        )
        sampler = RandomSampler(train_ds, replacement=True, num_samples=total_samples)
        batches = BatchSampler(sampler, per_batch, drop_last=True)
        loader = torch.utils.data.DataLoader(
            train_ds,
            batch_sampler=batches,
            num_workers=2,
            pin_memory=True
        )
        return sampler, loader

    sampler_x, dl_x = build_train_stream(
        data_x, label_x, batch_size, n_iters_per_epoch * batch_size
    )
    # The unlabeled stream is scaled by mu in both batch size and sample count.
    sampler_u, dl_u = build_train_stream(
        data_u, label_u, batch_size * mu, mu * n_iters_per_epoch * batch_size
    )

    ds_val = SVHN(dataset=dataset, data=data_val, labels=label_val, is_train=False)
    dl_val = torch.utils.data.DataLoader(
        ds_val,
        shuffle=False,
        batch_size=100,
        drop_last=False,
        num_workers=2,
        pin_memory=True
    )

    logger.info('labeled sample: {}. unlabeled sample: {}.'.format(len(sampler_x), len(sampler_u)))
    logger.info('labeled loader: {}. unlabeled loader: {}. validation loader: {}.'.format(len(dl_x), len(dl_u), len(dl_val)))

    return dl_x, dl_u, dl_val


def get_test_loader(dataset, batch_size, num_workers, pin_memory=True):
    """Create a sequential loader over the SVHN test split using ``SVHN``.

    Args:
        dataset: Forwarded to ``load_data_test`` and to the ``SVHN`` dataset.
        batch_size: Number of samples per batch.
        num_workers: Worker count handed to the ``DataLoader``.
        pin_memory: Handed to the ``DataLoader``.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` and ``drop_last=False``.
    """
    data, labels = load_data_test(dataset)

    # The dataset size is reported both through the logger and on stdout.
    size_message = 'test dataset: {}.'.format(len(data))
    logger.info(size_message)
    print(size_message)

    test_ds = SVHN(dataset=dataset, data=data, labels=labels, is_train=False)
    loader_options = dict(
        shuffle=False,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = torch.utils.data.DataLoader(test_ds, **loader_options)

    logger.info('test loader: {}.'.format(len(test_loader)))
    return test_loader


def get_train_loader_mixmatch(dataset, batch_size, mu, n_iters_per_epoch, L, num_val, root='data'):
    """Create labeled, unlabeled and validation loaders backed by ``SVHNMixMatch``.

    Args:
        dataset: Forwarded to ``load_data_train`` and to the datasets.
        batch_size: Labeled batch size; unlabeled batches hold
            ``batch_size * mu`` samples.
        mu: Multiplier applied to the unlabeled batch size and sample count.
        n_iters_per_epoch: The labeled sampler is asked for
            ``n_iters_per_epoch * batch_size`` samples and the unlabeled one
            for ``mu`` times that amount.
        L: Number of labeled samples, forwarded to ``load_data_train``.
        num_val: Number of validation samples, forwarded to ``load_data_train``.
        root: Dataset root directory, forwarded as ``dspth``.

    Returns:
        ``(dl_x, dl_u, dl_val)``: the labeled, unlabeled and validation
        ``DataLoader`` objects. The validation loader uses a fixed batch size
        of 100 without shuffling.
    """
    data_x, label_x, data_u, label_u, data_val, label_val = load_data_train(
        L=L, num_val=num_val, dataset=dataset, dspth=root
    )

    logger.info('labeled dataset: {}. unlabeled dataset: {}. validation dataset: {}.'.format(len(data_x), len(data_u), len(data_val)))

    def build_train_stream(samples, targets, per_batch, total_samples):
        # Dataset -> random sampler -> batch sampler -> loader.
        train_ds = SVHNMixMatch(
            dataset=dataset,
            data=samples,
            labels=targets,
            is_train=True
        )
        sampler = RandomSampler(train_ds, replacement=True, num_samples=total_samples)
        batches = BatchSampler(sampler, per_batch, drop_last=True)
        loader = torch.utils.data.DataLoader(
            train_ds,
            batch_sampler=batches,
            num_workers=2,
            pin_memory=True
        )
        return sampler, loader

    sampler_x, dl_x = build_train_stream(
        data_x, label_x, batch_size, n_iters_per_epoch * batch_size
    )
    # The unlabeled stream is scaled by mu in both batch size and sample count.
    sampler_u, dl_u = build_train_stream(
        data_u, label_u, batch_size * mu, mu * n_iters_per_epoch * batch_size
    )

    ds_val = SVHNMixMatch(dataset=dataset, data=data_val, labels=label_val, is_train=False)
    dl_val = torch.utils.data.DataLoader(
        ds_val,
        shuffle=False,
        batch_size=100,
        drop_last=False,
        num_workers=2,
        pin_memory=True
    )

    logger.info('labeled sample: {}. unlabeled sample: {}.'.format(len(sampler_x), len(sampler_u)))
    logger.info('labeled loader: {}. unlabeled loader: {}. validation loader: {}.'.format(len(dl_x), len(dl_u), len(dl_val)))

    return dl_x, dl_u, dl_val


def get_test_loader_mixmatch(dataset, batch_size, num_workers, pin_memory=True):
    """Create a sequential loader over the SVHN test split using ``SVHNMixMatch``.

    Args:
        dataset: Forwarded to ``load_data_test`` and to the dataset.
        batch_size: Number of samples per batch.
        num_workers: Worker count handed to the ``DataLoader``.
        pin_memory: Handed to the ``DataLoader``.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` and ``drop_last=False``.
    """
    data, labels = load_data_test(dataset)

    # The dataset size is reported both through the logger and on stdout.
    size_message = 'test dataset: {}.'.format(len(data))
    logger.info(size_message)
    print(size_message)

    test_ds = SVHNMixMatch(dataset=dataset, data=data, labels=labels, is_train=False)
    loader_options = dict(
        shuffle=False,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = torch.utils.data.DataLoader(test_ds, **loader_options)

    logger.info('test loader: {}.'.format(len(test_loader)))
    return test_loader
