from cgi import test
from imghdr import tests
from posixpath import split
from sys import implementation
from unittest import TestLoader
from lib.datasets.mytransforms import TransformFixMatch, TransformTest
from lib.datasets.randaugment import Color
from scipy.sparse import base
from torch.utils import data
from torchvision import datasets
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data import Dataset
from torch.utils.data import Sampler
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import numpy as np
import torch
import torchvision.transforms as tv_transforms
import argparse, os
import math
from lib.datasets.mydatasets import CIFAR10SSL, MNISTSSL, SVHNSSL, ColourMNISTSSL
from lib.datasets.colour_mnist import ColourBiasedMNIST

COUNTS = {
    "svhn": {"train": 73257, "test": 26032, "valid": 7326},
    "cifar10": {"train": 50000, "test": 10000, "valid": 5000},
    "imagenet_32": {
        "train": 1281167,
        "test": 50000,
        "valid": 50050
    },
}

CORRUPTIONS = [
    'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
    'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
    'brightness', 'contrast', 'elastic_transform', 'pixelate',
    'jpeg_compression'
]

rng = np.random.RandomState(seed=1)

_DATA_DIR = "./data"


data_path = "./data"

class RandomSampler(Sampler):
    """Yield exactly ``num_sample`` indices drawn from ``range(num_data)``.

    The index list is built by concatenating independent random permutations
    of ``range(num_data)`` and truncating the result to ``num_sample``.
    Indices never repeat *within* one permutation. Once ``num_sample``
    exceeds ``num_data``, however, each index can appear several times, so
    this is "sampling without replacement" only per pass over the data.

    The permutations are drawn once, in ``__init__``, from torch's global RNG.
    Iterating the sampler again replays the same order.

    Args:
        num_data: size of the underlying dataset; must be positive.
        num_sample: total number of indices to yield; must be non-negative.

    Raises:
        ValueError: if ``num_data <= 0`` or ``num_sample < 0``.
    """

    def __init__(self, num_data, num_sample):
        if num_data <= 0:
            raise ValueError(f"num_data must be positive, got {num_data}")
        if num_sample < 0:
            raise ValueError(f"num_sample must be non-negative, got {num_sample}")
        # `// + 1` (rather than ceil) keeps the number of torch.randperm calls,
        # and therefore the global RNG stream, identical to earlier versions.
        iterations = num_sample // num_data + 1
        permutations = [torch.randperm(num_data) for _ in range(iterations)]
        self.indices = torch.cat(permutations).tolist()[:num_sample]

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

def x_u_split(args, labels, n_classes=10):
    """Split training indices into labeled, unlabeled and validation subsets.

    For every in-distribution class ``0 .. args.tot_class - 1`` this draws,
    without replacement, ``args.n_labels // args.tot_class`` labeled,
    ``args.n_valid // args.tot_class`` validation and
    ``args.n_unlabels // n_classes`` unlabeled indices. Every remaining class
    ``args.tot_class .. n_classes - 1`` contributes only unlabeled indices,
    so the unlabeled pool mixes in- and out-of-distribution samples.

    When ``args.n_labels < args.batch_size``, the labeled indices are tiled so
    that ``args.iterations`` batches of ``args.batch_size`` can be drawn.
    All randomness comes from the global ``np.random`` state.

    Args:
        args: namespace providing ``n_labels``, ``n_unlabels``, ``n_valid``,
            ``tot_class``, ``batch_size`` and ``iterations``.
        labels: integer class label of every training sample (anything
            ``np.array`` accepts).
        n_classes: total number of classes present in ``labels``. Defaults
            to 10, the value this function previously hard-coded.

    Returns:
        ``(labeled_idx, unlabeled_idx, val_idx)``: an integer numpy array and
        two lists of indices into ``labels``, each shuffled.

    Raises:
        ValueError: from ``np.random.choice`` if a class has fewer samples
            than the number requested from it.
    """
    label_per_class = args.n_labels // args.tot_class
    unlabel_per_class = args.n_unlabels // n_classes
    val_per_class = args.n_valid // args.tot_class
    labels = np.array(labels)
    labeled_idx = []
    val_idx = []
    unlabeled_idx = []
    # unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
    per_id_class = label_per_class + val_per_class + unlabel_per_class
    for i in range(args.tot_class):
        idx = np.random.choice(np.where(labels == i)[0], per_id_class, False)
        labeled_idx.extend(idx[:label_per_class])
        val_idx.extend(idx[label_per_class:label_per_class + val_per_class])
        unlabeled_idx.extend(idx[label_per_class + val_per_class:])
    # Classes >= tot_class only feed the unlabeled pool.
    for i in range(args.tot_class, n_classes):
        idx = np.random.choice(np.where(labels == i)[0], unlabel_per_class, False)
        unlabeled_idx.extend(idx)

    # Explicit integer dtype: an empty list would otherwise become a float
    # array, which numpy refuses to use as an index.
    labeled_idx = np.array(labeled_idx, dtype=np.intp)

    assert len(labeled_idx) == label_per_class * args.tot_class
    if args.n_labels < args.batch_size:
        # Repeat the small labeled set so enough batches exist for training.
        num_expand_x = math.ceil(
            args.batch_size * args.iterations / args.n_labels)
        labeled_idx = np.hstack([labeled_idx] * num_expand_x)
    np.random.shuffle(labeled_idx)
    np.random.shuffle(unlabeled_idx)
    np.random.shuffle(val_idx)
    return labeled_idx, unlabeled_idx, val_idx


def test_split(args, labels):
    """Draw a class-balanced, shuffled subset of test indices.

    ``args.n_test // args.tot_class`` indices are sampled without replacement
    from each class ``0 .. args.tot_class - 1`` using the global
    ``np.random`` state. Samples labeled ``>= args.tot_class`` are never
    selected.

    Args:
        args: namespace providing ``n_test`` and ``tot_class``.
        labels: integer class label of every test sample.

    Returns:
        Shuffled integer numpy array of indices into ``labels``. It is empty,
        but still integer-typed, when ``args.n_test < args.tot_class``.

    Raises:
        ValueError: from ``np.random.choice`` if a class has fewer samples
            than requested.
    """
    test_per_class = args.n_test // args.tot_class
    labels = np.array(labels)
    test_idx = []
    for i in range(args.tot_class):
        idx = np.where(labels == i)[0]
        test_idx.extend(np.random.choice(idx, test_per_class, False))
    # Explicit integer dtype so an empty result is still usable as an index
    # (np.array([]) would be float64).
    test_idx = np.array(test_idx, dtype=np.intp)
    np.random.shuffle(test_idx)
    return test_idx


# The name matches pytest's ``test_*`` pattern. Without this flag, pytest
# collects the helper as a test, and fails on its arguments, whenever it is
# imported into a test module.
test_split.__test__ = False

def get_dataloaders(args, transform_fn, transform_test):
    rng = np.random.RandomState(seed=1)

    if args.dataset == "mnist":
        base_dataset = datasets.MNIST(data_path, train=True, download=True, transform=transform_fn)
        train_labeled_idxs, train_unlabeled_idxs, val_idxs = \
            x_u_split(args, base_dataset.targets)
    elif args.dataset == "cmnist":
        base_dataset = ColourMNISTSSL(data_path, None, train=True, transform=transform_fn, download=False, data_label_correlation=0.1, n_confusing_labels=2)
        train_labeled_idxs, train_unlabeled_idxs, val_idxs = x_u_split(args, base_dataset.targets)
    elif args.dataset == "svhn":
        base_dataset = datasets.SVHN(data_path, split='train', download=True, transform=transform_fn)
        train_labeled_idxs, train_unlabeled_idxs, val_idxs = \
            x_u_split(args, base_dataset.labels)
    elif args.dataset == "cifar10":
        # N x H x W x C -> N x C x H x W
        base_dataset = datasets.CIFAR10(data_path, train=True, download=True, transform=transform_fn)
        base_dataset.targets = np.array(base_dataset.targets)
        base_dataset.targets -= 2
        base_dataset.targets[np.where(base_dataset.targets == -2)[0]] = 8
        base_dataset.targets[np.where(base_dataset.targets == -1)[0]] = 9

        train_labeled_idxs, train_unlabeled_idxs, val_idxs = \
            x_u_split(args, base_dataset.targets)

    if args.dataset == "mnist":
        train_labeled_dataset = MNISTSSL(data_path, train_labeled_idxs, train=True, transform=transform_fn)
        train_unlabeled_dataset = MNISTSSL(data_path, train_unlabeled_idxs, train=True, transform=transform_fn, return_idx=False)
        val_dataset = MNISTSSL(data_path, val_idxs, train=True, transform=transform_test)
        test_dataset = MNISTSSL(data_path, None, train=False, transform=transform_test, download=False)
        
        target_ind = np.where(test_dataset.targets >= args.tot_class)[0]
        test_dataset.targets[target_ind] = args.tot_class
    elif args.dataset == "cmnist":
        train_labeled_dataset = ColourMNISTSSL(data_path, train_labeled_idxs, train=True, transform=transform_fn, download=False, data_label_correlation=0.1, n_confusing_labels=2)
        train_unlabeled_dataset = ColourMNISTSSL(data_path, train_unlabeled_idxs, train=True, transform=transform_fn, download=False, data_label_correlation=0.1, n_confusing_labels=2)
        val_dataset = ColourMNISTSSL(data_path, val_idxs, train=True, transform=transform_test, download=False, data_label_correlation=0.1, n_confusing_labels=2)
        test_dataset = ColourMNISTSSL(data_path, None, train=False, transform=transform_test, download=False, data_label_correlation=0.1, n_confusing_labels=2)

        target_ind = np.where(test_dataset.targets >= args.tot_class)[0]
        test_dataset.targets[target_ind] = args.tot_class

    elif args.dataset == "svhn":
        train_labeled_dataset = SVHNSSL(data_path, train_labeled_idxs, split='train', transform=transform_fn)
        train_unlabeled_dataset = SVHNSSL(data_path, train_unlabeled_idxs, split='train', transform=transform_fn, return_idx=False)
        val_dataset = SVHNSSL(data_path, val_idxs, split='train', transform=transform_test)
        test_dataset = SVHNSSL(data_path, None, split='test', transform=transform_test, download=False)
        
        target_ind = np.where(test_dataset.labels >= args.tot_class)[0]
        test_dataset.labels[target_ind] = args.tot_class

    elif args.dataset == "cifar10":
        train_labeled_dataset = CIFAR10SSL(data_path, train_labeled_idxs, train=True, transform=transform_fn)
        train_unlabeled_dataset = CIFAR10SSL(data_path, train_unlabeled_idxs, train=True, transform=transform_fn, return_idx=False)
        val_dataset = CIFAR10SSL(data_path, val_idxs, train=True, transform=transform_test)
        test_dataset = datasets.CIFAR10(data_path, train=False, transform=transform_test, download=False)
        test_dataset.data = []
        test_range = np.arange(0, args.n_test)
        test_dataset.targets = np.load('./data/CIFAR-10-C/' + 'labels.npy')[test_range]
        split_range = np.split(test_range, 15)
        i=0
        for corruption in CORRUPTIONS:
            test_data = np.load('./data/CIFAR-10-C/' + corruption + '.npy')[split_range[i]]
            test_dataset.data.append(test_data)
            i+=1
        test_dataset.data = np.concatenate(test_dataset.data, 0)

        train_labeled_dataset.targets -= 2
        train_unlabeled_dataset.targets -= 2
        train_labeled_dataset.targets[np.where(train_labeled_dataset.targets == -2)[0]] = 8
        train_labeled_dataset.targets[np.where(train_labeled_dataset.targets == -1)[0]] = 9
        train_unlabeled_dataset.targets[np.where(train_unlabeled_dataset.targets == -2)[0]] = 8
        train_unlabeled_dataset.targets[np.where(train_unlabeled_dataset.targets == -1)[0]] = 9
        val_dataset.targets -= 2
        val_dataset.targets[np.where(val_dataset.targets == -2)[0]] = 8
        val_dataset.targets[np.where(val_dataset.targets == -1)[0]] = 9
        test_dataset.targets -= 2
        test_dataset.targets[np.where(test_dataset.targets == -2)[0]] = 8
        test_dataset.targets[np.where(test_dataset.targets == -1)[0]] = 9

        target_ind = np.where(test_dataset.targets >= args.tot_class)[0]
        test_dataset.targets[target_ind] = args.tot_class

    unique_labeled = np.unique(train_labeled_idxs)
    val_labeled = np.unique(val_idxs)
    if args.dataset == 'svhn':
        id_num = (train_unlabeled_dataset.labels < args.tot_class).sum()
        ood_num = (train_unlabeled_dataset.labels >= args.tot_class).sum()
    else:
        id_num = (train_unlabeled_dataset.targets < args.tot_class).sum()
        ood_num = (train_unlabeled_dataset.targets >= args.tot_class).sum()

    print("ID data number: ", id_num)
    print("OOD data number: ", ood_num)        
    print("Dataset: ", args.dataset)
    print("Labeled examples: ", len(unique_labeled))
    print("Unlabeled examples: ", len(train_unlabeled_idxs))        
    print("Valdation examples: ", len(val_labeled))
    print("Test examples: ", len(test_dataset.data))

    l_loader = DataLoader(
            train_labeled_dataset, batch_size=args.batch_size//2, num_workers=args.num_workers, drop_last=True, 
            sampler=RandomSampler(len(train_labeled_dataset), args.iterations * args.batch_size//2))
    u_loader = DataLoader(
            train_unlabeled_dataset, batch_size=args.batch_size//2, num_workers=args.num_workers, drop_last=True,
            sampler=RandomSampler(len(train_unlabeled_dataset), args.iterations * args.batch_size//2))
    val_loader = DataLoader(
            val_dataset, sampler=SequentialSampler(val_dataset), batch_size=args.batch_size, shuffle=False, drop_last=False)
    test_loader = DataLoader(
            test_dataset, sampler=SequentialSampler(test_dataset), batch_size=args.batch_size, shuffle=False, drop_last=False)
    
    return l_loader, u_loader, val_loader, test_loader