import numpy as np
from torch.utils.data import DataLoader, Dataset
from Config import *
import torchvision.transforms as T
import os
from tqdm import tqdm
from torchvision.datasets import CIFAR100, CIFAR10, FashionMNIST, SVHN, ImageFolder
from collections import OrderedDict

# Data
def create_eval_img_folder(dataset_path: str) -> None:
    """Move flat validation images into one sub-folder per class.

    Reads ``<dataset_path>/val/val_annotations.txt`` (tab-separated; the first
    field is the image file name, the second its class folder) and moves each
    listed image from ``val/images/<img>`` to ``val/images/<class>/<img>``.
    Presumably this produces the class-per-folder layout that torchvision's
    ``ImageFolder`` expects (the annotation format looks like Tiny-ImageNet's
    val split — TODO confirm).

    Images that are already moved (source file missing) are skipped, so the
    function can be re-run safely.

    Args:
        dataset_path: root directory containing the ``val`` folder.
    """
    val_dir = os.path.join(dataset_path, 'val')
    img_dir = os.path.join(val_dir, 'images')

    # image file name -> class folder, in annotation-file order
    eval_img_dict = OrderedDict()
    with open(os.path.join(val_dir, 'val_annotations.txt'), 'r') as fp:
        for line in tqdm(fp.readlines()):
            # Strip the newline first so a two-column file does not leave
            # '\n' inside the folder name.
            words = line.rstrip('\n').split('\t')
            if len(words) < 2:
                continue  # blank or malformed line (e.g. a trailing newline)
            eval_img_dict[words[0]] = words[1]

    # Create class folders as needed and move images into them.
    for img, folder in tqdm(eval_img_dict.items()):
        newpath = os.path.join(img_dir, folder)
        os.makedirs(newpath, exist_ok=True)
        src = os.path.join(img_dir, img)
        if os.path.exists(src):
            os.rename(src, os.path.join(newpath, img))

_SUPPORTED_DATASETS = (
    'cifar10', 'cifar10im',
    'cifar100', 'cifar100im',
    'fashionmnist', 'fashionmnistim',
    'svhn', 'svhnim',
)


def _class_indices(targets):
    """Return a list whose entry ``c`` holds the indices of samples labelled ``c``.

    Assumes labels are the integers ``0 .. K-1`` where ``K`` is the number of
    distinct labels present in ``targets``.
    """
    nb_classes = len(np.unique(targets))
    return [np.where(targets == c)[0] for c in range(nb_classes)]


def _make_imbalanced(data_train, targets, class_counts, label_attr='targets'):
    """Subsample ``data_train`` in place to the first ``class_counts[c]`` samples of class ``c``.

    Overwrites ``data_train.<label_attr>`` and ``data_train.data`` with the
    selected samples, grouped by class in ascending class order.

    Args:
        data_train: dataset object exposing ``data`` and ``label_attr`` attributes.
        targets: numpy array of labels aligned with ``data_train.data``.
        class_counts: per-class number of samples to keep, indexed by class id.
        label_attr: name of the label attribute ('targets' or, for SVHN, 'labels').

    Returns:
        Number of training samples kept.
    """
    class_idxs = _class_indices(targets)
    keep = np.hstack([idx[:count] for idx, count in zip(class_idxs, class_counts)])
    setattr(data_train, label_attr, targets[keep])
    data_train.data = data_train.data[keep]
    return keep.shape[0]


def load_dataset(dataset, **kwargs):
    """Build the train/test datasets for a named benchmark.

    The ``*im`` variants are class-imbalanced subsets of the training split:
      * cifar10im / cifar100im: even classes keep ``1/ir`` of their samples
        (``ir`` keyword argument, required), odd classes keep all of them.
      * fashionmnistim: even classes keep 500, odd classes 5000 (``ir`` unused).
      * svhnim: classes 0, 2, 4, 6, 8 keep 10% of their samples (``ir`` unused).

    Training data is augmented (random flip + padded random crop); test data is
    only converted to a tensor and normalized, so evaluation is deterministic.

    Args:
        dataset: one of 'cifar10', 'cifar10im', 'cifar100', 'cifar100im',
            'fashionmnist', 'fashionmnistim', 'svhn', 'svhnim'.
        **kwargs: ``ir`` (imbalance ratio) for 'cifar10im' and 'cifar100im'.

    Returns:
        Tuple ``(data_train, data_unlabeled, data_test, adden, num_classes,
        no_train)``. ``data_unlabeled`` is always None; ``adden`` is
        ``ADDENDUM`` from Config, except 2000 for cifar100 and 1000 for
        cifar100im; ``no_train`` is the number of (possibly subsampled)
        training samples.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
        KeyError: if ``ir`` is missing for 'cifar10im' or 'cifar100im'.
    """
    # Validate first so a typo fails fast instead of surfacing later as an
    # UnboundLocalError on the return statement.
    if dataset not in _SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {_SUPPORTED_DATASETS}"
        )

    if dataset in ('cifar10', 'cifar10im'):
        norm_vec = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
        crop_size = 32
    elif dataset in ('fashionmnist', 'fashionmnistim'):
        norm_vec = [0.1307], [0.3081]
        crop_size = 28
    else:
        # CIFAR-100 statistics. SVHN also reuses them here.
        # NOTE(review): SVHN has its own channel statistics — confirm intent.
        norm_vec = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
        crop_size = 32

    train_transform = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomCrop(size=crop_size, padding=4),
        T.ToTensor(),
        T.Normalize(*norm_vec),
    ])
    # No random augmentation at evaluation time.
    test_transform = T.Compose([
        T.ToTensor(),
        T.Normalize(*norm_vec),
    ])

    data_unlabeled = None
    if dataset in ('cifar10', 'cifar10im'):
        data_train = CIFAR10('../cifar10', train=True, download=True, transform=train_transform)
        if dataset == 'cifar10im':
            counts = [int(1 / kwargs['ir'] * 5000), 5000] * 5  # imbalance ratio = ir
            no_train = _make_imbalanced(data_train, np.array(data_train.targets), counts)
        else:
            no_train = len(data_train)
        data_test = CIFAR10('../cifar10', train=False, download=True, transform=test_transform)
        num_classes = 10
        adden = ADDENDUM
    elif dataset in ('cifar100', 'cifar100im'):
        data_train = CIFAR100('../cifar100', train=True, download=True, transform=train_transform)
        if dataset == 'cifar100im':
            counts = [int(1 / kwargs['ir'] * 500), 500] * 50  # imbalance ratio = ir
            no_train = _make_imbalanced(data_train, np.array(data_train.targets), counts)
            adden = 1000
        else:
            no_train = len(data_train)
            adden = 2000
        data_test = CIFAR100('../cifar100', train=False, download=True, transform=test_transform)
        num_classes = 100
    elif dataset in ('fashionmnist', 'fashionmnistim'):
        data_train = FashionMNIST('../fashionMNIST', train=True, download=True, transform=train_transform)
        if dataset == 'fashionmnistim':
            # Fixed 10:1 imbalance; the ``ir`` keyword is not consulted here.
            counts = [500, 5000] * 5
            no_train = _make_imbalanced(data_train, np.array(data_train.targets), counts)
        else:
            no_train = len(data_train)
        data_test = FashionMNIST('../fashionMNIST', train=False, download=True, transform=test_transform)
        num_classes = 10
        adden = ADDENDUM
    else:  # 'svhn' or 'svhnim' (name validated above)
        data_train = SVHN('../svhn', split='train', download=True, transform=train_transform)
        if dataset == 'svhnim':
            # SVHN stores labels in ``.labels`` rather than ``.targets``.
            targets = np.array(data_train.labels)
            counts = [len(idx) for idx in _class_indices(targets)]
            for c in range(0, 9, 2):  # classes 0, 2, 4, 6, 8 keep 10%
                counts[c] = int(0.1 * counts[c])
            no_train = _make_imbalanced(data_train, targets, counts, label_attr='labels')
        else:
            no_train = len(data_train)
        data_test = SVHN('../svhn', split='test', download=True, transform=test_transform)
        num_classes = 10
        adden = ADDENDUM

    return data_train, data_unlabeled, data_test, adden, num_classes, no_train