import os
import copy
import torch
import logging
import numpy as np
import torchvision.datasets as datasets
from scipy.io import loadmat
from .dataset import BasicDataset
from .utils import split_labeled_unlabeled, reassign_target



def load_normal_dataset(data_root, dataset, train=True, lb_per_class=0):
    """Load a torchvision dataset and wrap it in ``BasicDataset`` splits.

    Samples are shuffled with the global NumPy RNG (``np.random``) before
    splitting, so the result depends on the current ``np.random`` state.

    Args:
        data_root: Root directory; the dataset is downloaded to / read from
            ``<data_root>/<dataset>``.
        dataset: Name of a ``torchvision.datasets`` class. ``'SVHN'`` is
            handled specially (``split=`` argument, 10 classes, ``.labels``);
            any other name must accept ``train=`` and expose ``.data``,
            ``.targets`` and ``.classes`` (e.g. CIFAR10, CIFAR100, MNIST).
        train: If True, load the training split and divide it into labeled
            and unlabeled subsets; otherwise load the test split.
        lb_per_class: Number of labeled samples per class, forwarded to
            ``split_labeled_unlabeled``.

    Returns:
        dict with ``'lb_set'`` (``BasicDataset``, or None when ``train`` is
        False) and ``'ulb_set'`` (``BasicDataset``; the whole shuffled test
        split when ``train`` is False).
    """
    logging.info("load_normal_dataset {}...".format(dataset))
    dset_cls = getattr(datasets, dataset)
    data_root = os.path.join(data_root, dataset)
    if dataset == 'SVHN':
        # SVHN takes ``split=`` instead of ``train=`` and exposes ``.labels``
        # instead of ``.targets``.
        dset = dset_cls(data_root, split='train' if train else 'test', download=True)
        classes = 10
        imgs, labels = dset.data, dset.labels
    else:  # CIFAR10, CIFAR100, MNIST
        dset = dset_cls(data_root, train=train, download=True)
        classes = len(dset.classes)
        imgs, labels = dset.data, dset.targets

    # torchvision datasets may expose data/targets as tensors or Python
    # lists; normalize everything to NumPy with int64 labels.
    if isinstance(imgs, torch.Tensor):
        imgs = imgs.numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.numpy()
    labels = np.asarray(labels).astype(np.int64)

    perm = np.random.permutation(imgs.shape[0])
    imgs, labels = imgs[perm], labels[perm]

    if not train:
        return {
            'lb_set': None,
            'ulb_set': BasicDataset(dataset, imgs, labels, classes, is_train=train),
        }

    lb_img, lb_targets, ulb_img, ulb_targets = split_labeled_unlabeled(imgs, labels, classes, lb_per_class)
    return {
        'lb_set': BasicDataset(dataset, lb_img, lb_targets, classes, is_train=train),
        'ulb_set': BasicDataset(dataset, ulb_img, ulb_targets, classes, is_train=train),
    }


def load_open_dataset(data_root, dataset, num_seen_class, train=True, lb_per_class=0, close_set=False):
    """Load a torchvision dataset relabeled for open-set semi-supervised learning.

    A fixed subset of the original classes is treated as "seen" and labels
    are remapped with ``reassign_target`` from ``.utils``. Samples are
    shuffled with the global NumPy RNG before splitting.

    Args:
        data_root: Root directory; the dataset is downloaded to / read from
            ``<data_root>/<dataset>``.
        dataset: One of ``'CIFAR10'``, ``'CIFAR100'`` or ``'SVHN'``.
        num_seen_class: Number of seen classes. For CIFAR100 the seen classes
            are the members of the first ``num_seen_class // 5`` super-classes.
            For CIFAR10 and SVHN the seen classes are always 2..7.
        train: If True, load the training split and divide it into labeled
            and unlabeled subsets; otherwise load the test split.
        lb_per_class: Number of labeled samples per class, forwarded to
            ``split_labeled_unlabeled``.
        close_set: If True, keep only samples whose remapped target is
            ``< num_seen_class`` in the unlabeled/test set.

    Returns:
        dict with ``'lb_set'`` (``BasicDataset``, or None when ``train`` is
        False) and ``'ulb_set'`` (``BasicDataset``).

    Raises:
        NotImplementedError: if ``dataset`` is not CIFAR10, CIFAR100 or SVHN.
    """
    logging.info("load_open_dataset {}...".format(dataset))
    # Reject unsupported names before touching torchvision, so callers get
    # NotImplementedError rather than an AttributeError or a download.
    if dataset not in ('CIFAR10', 'CIFAR100', 'SVHN'):
        raise NotImplementedError
    dset_cls = getattr(datasets, dataset)
    data_root = os.path.join(data_root, dataset)

    # load data
    if dataset == 'SVHN':
        # SVHN takes ``split=`` instead of ``train=`` and exposes ``.labels``.
        dset = dset_cls(data_root, split='train' if train else 'test', download=True)
        num_all_classes = 10
        imgs, labels = dset.data, dset.labels
    else:  # CIFAR10, CIFAR100
        dset = dset_cls(data_root, train=train, download=True)
        num_all_classes = len(dset.classes)
        imgs, labels = dset.data, dset.targets

    # choose seen classes
    if dataset == 'CIFAR100':
        num_super_classes = num_seen_class // 5
        # Presumably the standard CIFAR-100 fine-label -> super-class
        # (coarse label) map: entry i is the super-class of fine label i.
        super_classes = np.array([4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
                                  3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
                                  6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
                                  0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
                                  5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
                                  16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
                                  10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
                                  2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
                                  16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
                                  18, 1, 2, 15, 6, 0, 17, 8, 14, 13])
        seen_classes = set(np.arange(num_all_classes)[super_classes < num_super_classes])
    else:  # CIFAR10, SVHN
        seen_classes = set(range(2, 8))

    # Later code treats remapped targets < num_seen_class as "seen".
    # NOTE(review): that presumably only matches reassign_target when the two
    # counts agree, so flag a mismatch instead of failing silently.
    if len(seen_classes) != num_seen_class:
        logging.warning(
            "load_open_dataset %s: %d seen classes selected but num_seen_class=%s",
            dataset, len(seen_classes), num_seen_class)

    # torchvision datasets may expose data/targets as tensors or Python
    # lists; normalize everything to NumPy with int64 labels.
    if isinstance(imgs, torch.Tensor):
        imgs = imgs.numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.numpy()
    labels = np.asarray(labels).astype(np.int64)

    perm = np.random.permutation(imgs.shape[0])
    imgs, labels = imgs[perm], labels[perm]

    labels = reassign_target(labels, num_all_classes, seen_classes)

    if train:
        lb_img, lb_targets, ulb_img, ulb_targets = split_labeled_unlabeled(imgs, labels, num_seen_class, lb_per_class)
        lb_set = BasicDataset(
            dataset, lb_img, lb_targets, num_seen_class, is_train=train)
    else:
        lb_set = None
        ulb_img, ulb_targets = imgs, labels

    if close_set:
        # Closed-set evaluation: drop every sample outside the seen classes.
        seen_idx = np.where(ulb_targets < num_seen_class)[0]
        ulb_img, ulb_targets = ulb_img[seen_idx], ulb_targets[seen_idx]
        ulb_num_classes = num_seen_class
    else:
        # Open-set: one extra class, presumably the bucket for unseen samples.
        ulb_num_classes = num_seen_class + 1
    ulb_set = BasicDataset(
        dataset, ulb_img, ulb_targets, ulb_num_classes, is_train=train)

    return {
        'lb_set': lb_set,
        'ulb_set': ulb_set,
    }
    

def fetch_dataset(data_root, dataset, lb_per_class=0, train=True):
    """Fetch a standard (closed-set) semi-supervised dataset.

    Args:
        data_root: Root directory that holds the per-dataset folders.
        dataset: One of ``'CIFAR10'``, ``'CIFAR100'`` or ``'SVHN'``.
        lb_per_class: Number of labeled samples per class.
        train: Load the training split if True, otherwise the test split.

    Returns:
        dict with keys ``'lb_set'`` and ``'ulb_set'`` as produced by
        ``load_normal_dataset``.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    logging.info(f'fetching dataset-{dataset} from {data_root}...')
    if dataset not in ('CIFAR10', 'CIFAR100', 'SVHN'):
        raise NotImplementedError
    return load_normal_dataset(data_root, dataset, train=train, lb_per_class=lb_per_class)
    

def fetch_os_dataset(data_root, dataset, lb_per_class=0, num_seen_class='', train=True, close_set=False):
    """Fetch an open-set semi-supervised dataset.

    Args:
        data_root: Root directory that holds the per-dataset folders.
        dataset: One of ``'CIFAR10'``, ``'CIFAR100'`` or ``'SVHN'``.
        lb_per_class: Number of labeled samples per class.
        num_seen_class: Number of seen classes. Pass an int; the ``''``
            default is not a usable value (CIFAR100, for example, computes
            ``num_seen_class // 5``).
        train: Load the training split if True, otherwise the test split.
        close_set: If True, restrict the unlabeled/test set to seen classes.

    Returns:
        dict with keys ``'lb_set'`` and ``'ulb_set'`` as produced by
        ``load_open_dataset``.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    logging.info(f'fetching open-set dataset-{dataset} from {data_root}...')
    supported = ('CIFAR10', 'CIFAR100', 'SVHN')
    if dataset not in supported:
        raise NotImplementedError
    return load_open_dataset(
        data_root,
        dataset,
        num_seen_class,
        train=train,
        lb_per_class=lb_per_class,
        close_set=close_set,
    )


        
    
