import logging

import os.path as osp
import pickle
import numpy as np

import torch
from torch.utils.data import Dataset
from datasets import transform as T

from datasets.randaugment import RandomAugment
from datasets.sampler import RandomSampler, BatchSampler

from torchvision import datasets as torchdatasets

from PIL import Image

logger = logging.getLogger(__name__)


def _load_rgb_array(img_path):
    """Read the image at ``img_path`` and return its pixels as a numpy array.

    An image whose array shape is not exactly ``(64, 64, 3)`` (for example a
    single-channel file, whose array is two-dimensional) is converted to RGB
    before being turned into an array.
    """
    # The context manager releases the underlying file deterministically once
    # the pixel data has been copied into the array.
    with Image.open(img_path) as image:
        pixels = np.asarray(image)
        if pixels.shape != (64, 64, 3):
            pixels = np.asarray(image.convert('RGB'))
    return pixels


def load_data_train(L=9000, num_val=10000, dataset='TinyImageNet', dspth='data/tiny-imagenet-200'):
    """Split the training folder into labeled, unlabeled and validation sets.

    The images below ``<dspth>/train`` are enumerated with
    ``torchvision.datasets.ImageFolder``.  For every class index in
    ``range(200)`` the sample indices of that class are shuffled in place with
    ``np.random.shuffle`` and cut into three consecutive slices:

    * ``indices[:L // 200]``            -> labeled samples,
    * ``indices[L // 200:p]``           -> unlabeled samples,
    * ``indices[p:]``                   -> validation samples,

    where ``p = len(indices) - num_val // 200``.

    Args:
        L: Total number of labeled samples; must be 9000 or 500.
        num_val: Total number of validation samples.
        dataset: Not used by this function; accepted so callers can pass the
            dataset name through.
        dspth: Directory that contains the ``train`` sub-folder.

    Returns:
        A tuple ``(data_x, label_x, data_u, label_u, data_val, label_val)``.
        Every ``data_*`` entry is a list of image arrays and every ``label_*``
        entry is the list of matching class indices.

    Raises:
        AssertionError: If ``L`` is not one of the supported values.
    """
    assert L in [9000, 500], 'L must be 9000 or 500'

    n_class = 200
    n_labels = L // n_class
    n_vals = num_val // n_class

    train_dataset = torchdatasets.ImageFolder(osp.join(dspth, 'train'))
    train_images = train_dataset.imgs

    # Bucket the sample indices by class in a single pass over the image list
    # instead of rescanning the full list once for each of the classes.  The
    # indices inside a bucket stay in ascending order, so the shuffle below
    # sees exactly the same input as a per-class scan would produce.
    indices_by_class = {}
    for idx, (_, label) in enumerate(train_images):
        indices_by_class.setdefault(label, []).append(idx)

    data_x, label_x, data_u, label_u, data_val, label_val = [], [], [], [], [], []

    for i in range(n_class):
        indices = indices_by_class.get(i, [])
        np.random.shuffle(indices)
        p = len(indices) - n_vals

        # Each slice of the shuffled indices feeds one pair of output lists.
        splits = (
            (indices[:n_labels], data_x, label_x),
            (indices[n_labels:p], data_u, label_u),
            (indices[p:], data_val, label_val),
        )
        for inds, data_out, label_out in splits:
            for ind in inds:
                img_path, label = train_images[ind]
                data_out.append(_load_rgb_array(img_path))
                label_out.append(label)

    return data_x, label_x, data_u, label_u, data_val, label_val
        

def load_data_test(dataset, dspth='data/tiny-imagenet-200'):
    """Load every image listed under ``<dspth>/val`` together with its label.

    The folder is enumerated with ``torchvision.datasets.ImageFolder``.  An
    image whose array shape is not exactly ``(64, 64, 3)`` is converted to RGB
    before it is stored.

    Args:
        dataset: Not used by this function; accepted so callers can pass the
            dataset name through.
        dspth: Directory that contains the ``val`` sub-folder.

    Returns:
        A tuple ``(data, labels)``: the list of image arrays and the list of
        matching class indices, in the order reported by ``ImageFolder``.
    """
    test_dataset = torchdatasets.ImageFolder(osp.join(dspth, 'val'))

    data, labels = [], []

    for img_path, label in test_dataset.imgs:
        # Convert to an array once and reuse it for the shape check; the
        # context manager closes the file as soon as the pixels are copied.
        with Image.open(img_path) as image:
            pixels = np.asarray(image)
            if pixels.shape != (64, 64, 3):
                pixels = np.asarray(image.convert('RGB'))
        data.append(pixels)
        labels.append(label)

    return data, labels


def compute_mean_var():
    """Print per-channel statistics of the whole training folder.

    All three splits returned by ``load_data_train()`` (called with its
    default arguments) are merged into one array.  For each of the three
    channels the pixel values are divided by 255 and their mean and standard
    deviation are printed.

    Note that the list printed with the ``var`` label holds the results of
    ``std()``, i.e. standard deviations rather than variances.
    """
    splits = load_data_train()
    print('Finish loading data.')

    # Positions 0, 2 and 4 of the returned tuple hold the image lists;
    # the label lists in between are not needed here.
    images = splits[0] + splits[2] + splits[4]

    batched = [img[None] for img in images]
    data = np.concatenate(batched, axis=0)
    print('Finish concatenating data.')
    print('data: ', len(data))

    mean = []
    var = []
    # One channel at a time, so only a single float copy is alive at once.
    for ch in range(3):
        scaled = data[:, :, :, ch].ravel() / 255
        mean.append(scaled.mean())
        var.append(scaled.std())

    print('mean: ', mean)
    print('var: ', var)


class TinyImageNet(Dataset):
    """In-memory dataset that yields a weak and a strong view of each image.

    With ``is_train=True`` an item is ``(weak_view, strong_view, label)``;
    otherwise it is ``(image, label)`` produced by the evaluation pipeline.
    """

    def __init__(self, dataset, data, labels, is_train=True):
        """Store the samples and build the transform pipelines.

        Args:
            dataset: Not used by this class; accepted for caller convenience.
            data: Sequence of images.
            labels: Sequence of labels, same length as ``data``.
            is_train: Selects the training (two views) or evaluation pipeline.
        """
        super().__init__()
        self.data = data
        self.labels = labels
        self.is_train = is_train
        assert len(self.data) == len(self.labels)

        # Per-channel normalization statistics (noted as calculated values;
        # presumably obtained with compute_mean_var -- TODO confirm).
        mean = (0.4802, 0.4481, 0.3975)
        std = (0.2770, 0.2691, 0.2821)

        if not is_train:
            eval_ops = [
                T.Resize((64, 64)),
                T.Normalize(mean, std),
                T.ToTensor(),
            ]
            self.trans = T.Compose(eval_ops)
            return

        weak_ops = [
            T.Resize((64, 64)),
            T.PadandRandomCrop(border=8, cropsize=(64, 64)),
            T.RandomHorizontalFlip(p=0.5),
            T.Normalize(mean, std),
            T.ToTensor(),
        ]
        self.trans_weak = T.Compose(weak_ops)

        strong_ops = [
            T.Resize((64, 64)),
            T.PadandRandomCrop(border=8, cropsize=(64, 64)),
            T.RandomHorizontalFlip(p=0.5),
            # cutoutSize controls the size of the Cutout operation.
            RandomAugment(2, 10, cutoutSize=32),
            T.Normalize(mean, std),
            T.ToTensor(),
        ]
        self.trans_strong = T.Compose(strong_ops)

    def __getitem__(self, idx):
        """Return the transformed sample(s) and the label stored at ``idx``."""
        image = self.data[idx]
        label = self.labels[idx]
        if not self.is_train:
            return self.trans(image), label
        weak_view = self.trans_weak(image)
        strong_view = self.trans_strong(image)
        return weak_view, strong_view, label

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)


class TinyImageNetMixMatch(Dataset):
    """In-memory dataset that yields two views from one training pipeline.

    With ``is_train=True`` an item is ``(view_a, view_b, label)`` where both
    views come from separate calls to the same training transform; otherwise
    it is ``(image, label)`` produced by the evaluation pipeline.
    """

    def __init__(self, dataset, data, labels, is_train=True):
        """Store the samples and build the transform pipeline.

        Args:
            dataset: Not used by this class; accepted for caller convenience.
            data: Sequence of images.
            labels: Sequence of labels, same length as ``data``.
            is_train: Selects the training or the evaluation pipeline.
        """
        super().__init__()
        self.data = data
        self.labels = labels
        self.is_train = is_train
        assert len(self.data) == len(self.labels)

        # Per-channel normalization statistics (noted as calculated values;
        # presumably obtained with compute_mean_var -- TODO confirm).
        mean = (0.4802, 0.4481, 0.3975)
        std = (0.2770, 0.2691, 0.2821)

        if not is_train:
            eval_ops = [
                T.Resize((64, 64)),
                T.Normalize(mean, std),
                T.ToTensor(),
            ]
            self.trans = T.Compose(eval_ops)
            return

        train_ops = [
            T.Resize((64, 64)),
            T.PadandRandomCrop(border=8, cropsize=(64, 64)),
            T.RandomHorizontalFlip(p=0.5),
            T.Normalize(mean, std),
            T.ToTensor(),
        ]
        self.trans_train = T.Compose(train_ops)

    def __getitem__(self, idx):
        """Return the transformed sample(s) and the label stored at ``idx``."""
        image = self.data[idx]
        label = self.labels[idx]
        if not self.is_train:
            return self.trans(image), label
        # The same pipeline is applied twice to obtain two views.
        view_a = self.trans_train(image)
        view_b = self.trans_train(image)
        return view_a, view_b, label

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)


def get_train_loader(dataset, batch_size, mu, n_iters_per_epoch, L, num_val, root='data/tiny-imagenet-200',
                     num_workers=2, pin_memory=True, val_batch_size=100):
    """Build the labeled, unlabeled and validation loaders.

    The split comes from ``load_data_train``.  The labeled and unlabeled
    parts are wrapped in ``TinyImageNet`` with ``is_train=True`` and drawn
    through the project's ``RandomSampler`` / ``BatchSampler``; the
    validation part uses ``is_train=False`` and a plain sequential loader.

    Args:
        dataset: Dataset name, forwarded to the loader and dataset classes.
        batch_size: Number of labeled samples per batch.
        mu: Multiplier for the unlabeled batch; an unlabeled batch holds
            ``batch_size * mu`` samples.
        n_iters_per_epoch: Used to size both samplers
            (``num_samples = n_iters_per_epoch * samples_per_batch``).
        L: Total number of labeled samples, forwarded to ``load_data_train``.
        num_val: Total number of validation samples, forwarded likewise.
        root: Dataset root directory.
        num_workers: Worker processes used by each of the three loaders.
        pin_memory: ``pin_memory`` flag passed to each of the three loaders.
        val_batch_size: Batch size of the validation loader.

    Returns:
        A tuple ``(dl_x, dl_u, dl_val)`` of data loaders.
    """
    data_x, label_x, data_u, label_u, data_val, label_val = load_data_train(L=L, num_val=num_val, dataset=dataset, dspth=root)

    logger.info('labeled dataset: {}. unlabeled dataset: {}. validation dataset: {}.'.format(len(data_x), len(data_u), len(data_val)))

    def _sampled_loader(ds, samples_per_batch):
        """Return ``(sampler, loader)`` drawing ``ds`` with replacement."""
        sampler = RandomSampler(ds, replacement=True, num_samples=n_iters_per_epoch * samples_per_batch)
        # Incomplete trailing batches are dropped.
        batch_sampler = BatchSampler(sampler, samples_per_batch, drop_last=True)
        loader = torch.utils.data.DataLoader(
            ds,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory
        )
        return sampler, loader

    ds_x = TinyImageNet(
        dataset=dataset,
        data=data_x,
        labels=label_x,
        is_train=True
    )
    sampler_x, dl_x = _sampled_loader(ds_x, batch_size)

    ds_u = TinyImageNet(
        dataset=dataset,
        data=data_u,
        labels=label_u,
        is_train=True
    )
    # The unlabeled batch is mu times larger than the labeled one.
    sampler_u, dl_u = _sampled_loader(ds_u, batch_size * mu)

    ds_val = TinyImageNet(
        dataset=dataset,
        data=data_val,
        labels=label_val,
        is_train=False
    )
    dl_val = torch.utils.data.DataLoader(
        ds_val,
        shuffle=False,
        batch_size=val_batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    logger.info('labeled sample: {}. unlabeled sample: {}.'.format(len(sampler_x), len(sampler_u)))
    logger.info('labeled loader: {}. unlabeled loader: {}. validation loader: {}.'.format(len(dl_x), len(dl_u), len(dl_val)))

    return dl_x, dl_u, dl_val


def get_test_loader(dataset, batch_size, num_workers, pin_memory=True, root='data/tiny-imagenet-200'):
    """Build the sequential loader over the test images.

    The images come from ``load_data_test`` and are wrapped in
    ``TinyImageNet`` with ``is_train=False``.

    Args:
        dataset: Dataset name, forwarded to the loader and dataset classes.
        batch_size: Number of samples per batch.
        num_workers: Worker processes used by the loader.
        pin_memory: ``pin_memory`` flag passed to the loader.
        root: Dataset root directory.  The default equals the default of
            ``load_data_test``, so existing callers are unaffected.

    Returns:
        A ``torch.utils.data.DataLoader`` that keeps the sample order and
        also yields the last, possibly smaller, batch.
    """
    data, labels = load_data_test(dataset, dspth=root)
    logger.info('test dataset: {}.'.format(len(data)))

    ds = TinyImageNet(
        dataset=dataset,
        data=data,
        labels=labels,
        is_train=False
    )
    dl = torch.utils.data.DataLoader(
        ds,
        shuffle=False,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    logger.info('test loader: {}.'.format(len(dl)))
    return dl


def get_train_loader_mixmatch(dataset, batch_size, mu, n_iters_per_epoch, L, num_val, root='data/tiny-imagenet-200',
                              num_workers=2, pin_memory=True, val_batch_size=100):
    """Build the labeled, unlabeled and validation loaders for MixMatch.

    The split comes from ``load_data_train``.  The labeled and unlabeled
    parts are wrapped in ``TinyImageNetMixMatch`` with ``is_train=True`` and
    drawn through the project's ``RandomSampler`` / ``BatchSampler``; the
    validation part uses ``is_train=False`` and a plain sequential loader.

    Args:
        dataset: Dataset name, forwarded to the loader and dataset classes.
        batch_size: Number of labeled samples per batch.
        mu: Multiplier for the unlabeled batch; an unlabeled batch holds
            ``batch_size * mu`` samples.
        n_iters_per_epoch: Used to size both samplers
            (``num_samples = n_iters_per_epoch * samples_per_batch``).
        L: Total number of labeled samples, forwarded to ``load_data_train``.
        num_val: Total number of validation samples, forwarded likewise.
        root: Dataset root directory.
        num_workers: Worker processes used by each of the three loaders.
        pin_memory: ``pin_memory`` flag passed to each of the three loaders.
        val_batch_size: Batch size of the validation loader.

    Returns:
        A tuple ``(dl_x, dl_u, dl_val)`` of data loaders.
    """
    data_x, label_x, data_u, label_u, data_val, label_val = load_data_train(L=L, num_val=num_val, dataset=dataset, dspth=root)

    logger.info('labeled dataset: {}. unlabeled dataset: {}. validation dataset: {}.'.format(len(data_x), len(data_u), len(data_val)))

    def _sampled_loader(ds, samples_per_batch):
        """Return ``(sampler, loader)`` drawing ``ds`` with replacement."""
        sampler = RandomSampler(ds, replacement=True, num_samples=n_iters_per_epoch * samples_per_batch)
        # Incomplete trailing batches are dropped.
        batch_sampler = BatchSampler(sampler, samples_per_batch, drop_last=True)
        loader = torch.utils.data.DataLoader(
            ds,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory
        )
        return sampler, loader

    ds_x = TinyImageNetMixMatch(
        dataset=dataset,
        data=data_x,
        labels=label_x,
        is_train=True
    )
    sampler_x, dl_x = _sampled_loader(ds_x, batch_size)

    ds_u = TinyImageNetMixMatch(
        dataset=dataset,
        data=data_u,
        labels=label_u,
        is_train=True
    )
    # The unlabeled batch is mu times larger than the labeled one.
    sampler_u, dl_u = _sampled_loader(ds_u, batch_size * mu)

    ds_val = TinyImageNetMixMatch(
        dataset=dataset,
        data=data_val,
        labels=label_val,
        is_train=False
    )
    dl_val = torch.utils.data.DataLoader(
        ds_val,
        shuffle=False,
        batch_size=val_batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    logger.info('labeled sample: {}. unlabeled sample: {}.'.format(len(sampler_x), len(sampler_u)))
    logger.info('labeled loader: {}. unlabeled loader: {}. validation loader: {}.'.format(len(dl_x), len(dl_u), len(dl_val)))

    return dl_x, dl_u, dl_val


def get_test_loader_mixmatch(dataset, batch_size, num_workers, pin_memory=True, root='data/tiny-imagenet-200'):
    """Build the sequential loader over the test images for MixMatch.

    The images come from ``load_data_test`` and are wrapped in
    ``TinyImageNetMixMatch`` with ``is_train=False``.

    Args:
        dataset: Dataset name, forwarded to the loader and dataset classes.
        batch_size: Number of samples per batch.
        num_workers: Worker processes used by the loader.
        pin_memory: ``pin_memory`` flag passed to the loader.
        root: Dataset root directory.  The default equals the default of
            ``load_data_test``, so existing callers are unaffected.

    Returns:
        A ``torch.utils.data.DataLoader`` that keeps the sample order and
        also yields the last, possibly smaller, batch.
    """
    data, labels = load_data_test(dataset, dspth=root)
    logger.info('test dataset: {}.'.format(len(data)))

    ds = TinyImageNetMixMatch(
        dataset=dataset,
        data=data,
        labels=labels,
        is_train=False
    )
    dl = torch.utils.data.DataLoader(
        ds,
        shuffle=False,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    logger.info('test loader: {}.'.format(len(dl)))
    return dl
