import logging
import math

import numpy as np
from PIL import Image
from torchvision import datasets
from torchvision import transforms

from .randaugment import RandAugmentMC

logger = logging.getLogger(__name__)

# Per-channel mean / std pairs handed to ``transforms.Normalize`` for each dataset.
cifar10_mean, cifar10_std = (0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)
cifar100_mean, cifar100_std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
# Generic statistics: (x - 0.5) / 0.5 maps [0, 1] inputs onto [-1, 1].
# Not referenced elsewhere in this section.
normal_mean = normal_std = (0.5, 0.5, 0.5)


def get_cifar10(args, root, trans='fixmatch'):
    """Build CIFAR-10 labeled / unlabeled / test datasets for semi-supervised training.

    Args:
        args: Namespace; ``args.use_same_idx`` chooses between sampling a fresh
            labeled split (``x_u_split``) and loading a saved one
            (``retrieve_split``). ``args`` is forwarded to that helper.
        root: Directory CIFAR-10 is downloaded to / loaded from.
        trans: Multi-view transform for unlabeled samples: ``'fixmatch'``
            (weak, strong), ``'joint'`` (weak, strong, SimCLR-style),
            ``'simclr'`` (two SimCLR-style views) or ``'strong'``
            (two strong views).

    Returns:
        ``(train_labeled_dataset, train_unlabeled_dataset, test_dataset,
        train_labeled_idxs)``. ``train_labeled_idxs`` may contain repeats when
        the split helper expands the labeled set.

    Raises:
        ValueError: If ``trans`` is not one of the supported names.
    """
    supported = ('fixmatch', 'joint', 'simclr', 'strong')
    if trans not in supported:
        # Fail fast, before anything is downloaded, with an actionable message.
        # A bare ``raise`` outside an ``except`` only yields
        # "No active exception to reraise".
        raise ValueError(
            f"Unsupported unlabeled transform {trans!r}; "
            f"expected one of {supported}")

    # Labeled samples: random flip + reflect-padded 32x32 random crop (pad 4 px).
    transform_labeled = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32,
                              padding=int(32*0.125),
                              padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])
    base_dataset = datasets.CIFAR10(root, train=True, download=True)

    if not args.use_same_idx:
        train_labeled_idxs, train_unlabeled_idxs = x_u_split(
            args, base_dataset.targets)
    else:
        train_labeled_idxs, train_unlabeled_idxs = retrieve_split(
            args, base_dataset.targets)

    train_labeled_dataset = CIFAR10SSL(
        root, train_labeled_idxs, train=True,
        transform=transform_labeled)

    # Resolved here (after the split) so the transform is constructed at the
    # same point as before.
    unlabeled_transform_cls = {
        'fixmatch': TransformFixMatch,
        'joint': TransformJoint,
        'simclr': TransformSimCLR,
        'strong': TransformStrong,
    }[trans]
    train_unlabeled_dataset = CIFAR10SSL(
        root, train_unlabeled_idxs, train=True,
        transform=unlabeled_transform_cls(mean=cifar10_mean, std=cifar10_std))

    # download=False: relies on the train-split download above having fetched
    # the archive into ``root``.
    test_dataset = datasets.CIFAR10(
        root, train=False, transform=transform_val, download=False)

    return train_labeled_dataset, train_unlabeled_dataset, test_dataset, train_labeled_idxs


def get_cifar100(args, root, trans='fixmatch'):
    """Build CIFAR-100 labeled / unlabeled / test datasets for semi-supervised training.

    Args:
        args: Namespace; ``args.use_same_idx`` chooses between sampling a fresh
            labeled split (``x_u_split``) and loading a saved one
            (``retrieve_split``). ``args`` is forwarded to that helper.
        root: Directory CIFAR-100 is downloaded to / loaded from.
        trans: Multi-view transform for unlabeled samples: ``'fixmatch'``
            (weak, strong), ``'joint'`` (weak, strong, SimCLR-style),
            ``'simclr'`` (two SimCLR-style views) or ``'strong'``
            (two strong views).

    Returns:
        ``(train_labeled_dataset, train_unlabeled_dataset, test_dataset,
        train_labeled_idxs)``. ``train_labeled_idxs`` may contain repeats when
        the split helper expands the labeled set.

    Raises:
        ValueError: If ``trans`` is not one of the supported names.
    """
    supported = ('fixmatch', 'joint', 'simclr', 'strong')
    if trans not in supported:
        # Fail fast, before anything is downloaded, with an actionable message.
        # A bare ``raise`` outside an ``except`` only yields
        # "No active exception to reraise".
        raise ValueError(
            f"Unsupported unlabeled transform {trans!r}; "
            f"expected one of {supported}")

    # Labeled samples: random flip + reflect-padded 32x32 random crop (pad 4 px).
    transform_labeled = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32,
                              padding=int(32*0.125),
                              padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar100_mean, std=cifar100_std)])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar100_mean, std=cifar100_std)])

    base_dataset = datasets.CIFAR100(
        root, train=True, download=True)

    if not args.use_same_idx:
        train_labeled_idxs, train_unlabeled_idxs = x_u_split(
            args, base_dataset.targets)
    else:
        train_labeled_idxs, train_unlabeled_idxs = retrieve_split(
            args, base_dataset.targets)

    train_labeled_dataset = CIFAR100SSL(
        root, train_labeled_idxs, train=True,
        transform=transform_labeled)

    # Resolved here (after the split) so the transform is constructed at the
    # same point as before.
    unlabeled_transform_cls = {
        'fixmatch': TransformFixMatch,
        'joint': TransformJoint,
        'simclr': TransformSimCLR,
        'strong': TransformStrong,
    }[trans]
    train_unlabeled_dataset = CIFAR100SSL(
        root, train_unlabeled_idxs, train=True,
        transform=unlabeled_transform_cls(mean=cifar100_mean, std=cifar100_std))

    # download=False: relies on the train-split download above having fetched
    # the archive into ``root``.
    test_dataset = datasets.CIFAR100(
        root, train=False, transform=transform_val, download=False)

    return train_labeled_dataset, train_unlabeled_dataset, test_dataset, train_labeled_idxs

def x_u_split(args, labels):
    """Sample a class-balanced labeled subset and return labeled/unlabeled indices.

    ``args.num_labeled // args.num_classes`` indices are drawn without
    replacement from each class ``0 .. num_classes - 1`` using NumPy's global
    RNG. The unlabeled index set is every index of ``labels``, including the
    labeled ones.

    If ``args.expand_labels`` is set, or there are fewer labeled samples than
    ``args.batch_size``, the labeled indices are repeated
    ``ceil(batch_size * eval_step / num_labeled)`` times. The labeled indices
    are then shuffled in place.

    Args:
        args: Namespace providing ``num_labeled``, ``num_classes``,
            ``expand_labels``, ``batch_size`` and ``eval_step``.
        labels: Sequence of integer class labels, one per training sample.

    Returns:
        ``(labeled_idx, unlabeled_idx)`` as NumPy integer arrays.

    Raises:
        ValueError: If ``num_labeled`` is not positive or is not divisible by
            ``num_classes``. An explicit raise is used because an ``assert``
            would be stripped under ``python -O``.
    """
    if args.num_labeled <= 0 or args.num_labeled % args.num_classes != 0:
        raise ValueError(
            f"num_labeled ({args.num_labeled}) must be a positive multiple of "
            f"num_classes ({args.num_classes})")

    label_per_class = args.num_labeled // args.num_classes
    labels = np.asarray(labels)
    unlabeled_idx = np.arange(len(labels))
    # One RNG draw per class, in class order.
    labeled_idx = np.concatenate([
        np.random.choice(np.flatnonzero(labels == cls), label_per_class, False)
        for cls in range(args.num_classes)
    ])

    if args.expand_labels or args.num_labeled < args.batch_size:
        # Repeat the labeled set so one eval period's worth of batches can be drawn.
        num_expand_x = math.ceil(
            args.batch_size * args.eval_step / args.num_labeled)
        labeled_idx = np.tile(labeled_idx, num_expand_x)
    print("NOTE: Using new set of idx!!!")
    np.random.shuffle(labeled_idx)
    return labeled_idx, unlabeled_idx

def retrieve_split(args, labels):
    """Load a saved labeled index set and build labeled/unlabeled indices.

    The file path is ``args.path_to_npy`` joined by plain string concatenation
    with ``{dataset}_{num_labeled}.npy``, or with
    ``{dataset}_{num_labeled}_seed{seed}.npy`` when ``args.seed`` is not None.
    Because of the plain concatenation, ``path_to_npy`` needs its own trailing
    separator if it is a directory. The unlabeled index set is every index of
    ``labels``.

    If ``args.expand_labels`` is set, or there are fewer labeled samples than
    ``args.batch_size``, the loaded indices are repeated
    ``ceil(batch_size * eval_step / num_labeled)`` times. They are not shuffled.

    Args:
        args: Namespace providing ``seed``, ``path_to_npy``, ``dataset``,
            ``num_labeled``, ``expand_labels``, ``batch_size`` and ``eval_step``.
        labels: Sequence of training labels; only its length is used.

    Returns:
        ``(labeled_idx, unlabeled_idx)`` as NumPy arrays.

    Raises:
        ValueError: If the loaded file does not hold exactly ``num_labeled``
            indices. An explicit raise is used because an ``assert`` would be
            stripped under ``python -O``.
    """
    seed_suffix = '' if args.seed is None else f'_seed{args.seed}'
    labeled_idx_dir = args.path_to_npy + f'{args.dataset}_{args.num_labeled}{seed_suffix}.npy'

    labeled_idx = np.load(labeled_idx_dir)
    print("Loaded idx from ", labeled_idx_dir)
    if len(labeled_idx) != args.num_labeled:
        raise ValueError(
            f"{labeled_idx_dir} holds {len(labeled_idx)} indices, "
            f"expected num_labeled={args.num_labeled}")
    unlabeled_idx = np.arange(len(labels))
    if args.expand_labels or args.num_labeled < args.batch_size:
        # Repeat the labeled set so one eval period's worth of batches can be drawn.
        num_expand_x = math.ceil(
            args.batch_size * args.eval_step / args.num_labeled)
        labeled_idx = np.tile(labeled_idx, num_expand_x)

    return labeled_idx, unlabeled_idx


class TransformFixMatch(object):
    """Produce a ``(weak, strong)`` pair of normalized views of one image.

    Both views start with a random horizontal flip and a 32x32 random crop
    with reflect padding. The strong view additionally applies
    ``RandAugmentMC(n=2, m=10)``.

    Args:
        mean: Per-channel mean for ``transforms.Normalize``.
        std: Per-channel std for ``transforms.Normalize``.
    """

    def __init__(self, mean, std):
        crop_pad = int(32 * 0.125)  # 12.5% of the 32-pixel side
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=crop_pad,
                                  padding_mode='reflect'),
        ])
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=crop_pad,
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10),
        ])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        # The weak view is drawn before the strong one.
        views = (self.weak(x), self.strong(x))
        return tuple(self.normalize(view) for view in views)

class TransformStrong(object):
    """Produce two independently augmented strong views of one image.

    Each view is a random horizontal flip, a 32x32 reflect-padded random crop
    and ``RandAugmentMC(n=2, m=10)``, followed by ToTensor + Normalize.

    Args:
        mean: Per-channel mean for ``transforms.Normalize``.
        std: Per-channel std for ``transforms.Normalize``.
    """

    def __init__(self, mean, std):
        crop_pad = int(32 * 0.125)  # 12.5% of the 32-pixel side
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=crop_pad,
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10),
        ])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        first_view = self.strong(x)
        second_view = self.strong(x)
        return self.normalize(first_view), self.normalize(second_view)

class TransformSimCLR(object):
    """Produce two independently augmented SimCLR-style views of one image.

    Each view applies a random resized crop to 32x32 (scale 0.2-1.0), a random
    horizontal flip, color jitter with probability 0.8 and grayscale with
    probability 0.2, followed by ToTensor + Normalize.

    Args:
        mean: Per-channel mean for ``transforms.Normalize``.
        std: Per-channel std for ``transforms.Normalize``.
    """

    def __init__(self, mean, std):
        jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        self.simclr_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ])

    def __call__(self, x):
        view_a = self.simclr_transforms(x)
        view_b = self.simclr_transforms(x)
        return self.normalize(view_a), self.normalize(view_b)

class TransformJoint(object):
    """Produce ``(weak, strong, simclr)`` normalized views of one image.

    - weak: random horizontal flip + 32x32 reflect-padded random crop.
    - strong: the weak pipeline followed by ``RandAugmentMC(n=2, m=10)``.
    - simclr: random resized crop (scale 0.2-1.0), flip, color jitter
      (p=0.8) and grayscale (p=0.2).

    Args:
        mean: Per-channel mean for ``transforms.Normalize``.
        std: Per-channel std for ``transforms.Normalize``.
    """

    def __init__(self, mean, std):
        crop_pad = int(32 * 0.125)  # 12.5% of the 32-pixel side
        jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=crop_pad,
                                  padding_mode='reflect'),
        ])
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=crop_pad,
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10),
        ])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        self.simclr_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ])

    def __call__(self, x):
        # Views are drawn in the order weak, strong, simclr.
        views = (self.weak(x), self.strong(x), self.simclr_transforms(x))
        return tuple(self.normalize(view) for view in views)

class CIFAR10SSL(datasets.CIFAR10):
    """CIFAR-10 split restricted to a subset of sample indices.

    Args:
        root, train, transform, target_transform, download: Forwarded to
            ``datasets.CIFAR10``.
        indexs: Indices into the chosen split to keep (repeats allowed, since
            they are applied by NumPy fancy indexing), or ``None`` to keep
            every sample.

    ``__getitem__`` returns ``(img, target)``, with ``transform`` applied to a
    PIL image built from the raw array.
    """

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        if indexs is None:
            return
        self.data = self.data[indexs]
        self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        label = self.targets[index]
        image = Image.fromarray(self.data[index])
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


class CIFAR100SSL(datasets.CIFAR100):
    """CIFAR-100 split restricted to a subset of sample indices.

    Args:
        root, train, transform, target_transform, download: Forwarded to
            ``datasets.CIFAR100``.
        indexs: Indices into the chosen split to keep (repeats allowed, since
            they are applied by NumPy fancy indexing), or ``None`` to keep
            every sample.

    ``__getitem__`` returns ``(img, target)``, with ``transform`` applied to a
    PIL image built from the raw array.
    """

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        if indexs is None:
            return
        self.data = self.data[indexs]
        self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        label = self.targets[index]
        image = Image.fromarray(self.data[index])
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


# Dataset name -> builder returning (labeled, unlabeled, test, labeled_idxs).
DATASET_GETTERS = dict(
    cifar10=get_cifar10,
    cifar100=get_cifar100,
)
