import os
import numpy as np
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

from .dataset_base import BasicDataset
from src.fl_datasets.utils import reassign_target


def get_inlier(args, name, data_dir='./data'):
    """Build the test loaders for the in-distribution dataset ``name``.

    Args:
        args: Run configuration; ``args.num_classes`` is read here.
        name: One of ``'cifar10'``, ``'cifar100'`` or ``'fashionmnist'``.
        data_dir: Root directory handed to the torchvision dataset
            (downloaded there if missing).

    Returns:
        ``(inlier_test_loader, full_test_loader)``: the first iterates only
        over test samples whose reassigned target is below
        ``args.num_classes``; the second iterates over the whole test split.
        Both use ``batch_size=100`` and no shuffling.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset.
    """
    # Validate before anything else: the statistics lookup below would
    # otherwise fail with a KeyError and the NotImplementedError meant for
    # unsupported names could never be reached.
    if name not in ('cifar10', 'cifar100', 'fashionmnist'):
        raise NotImplementedError(f'unsupported inlier dataset: {name!r}')

    # Per-dataset (mean, std) used for normalization.
    # NOTE(review): the cifar10 entry equals the widely used ImageNet
    # statistics rather than CIFAR-10 specific ones -- confirm this is intended.
    norm_stats = {
        'cifar10': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'cifar100': ([x / 255 for x in [129.3, 124.1, 112.4]],
                     [x / 255 for x in [68.2, 65.4, 70.4]]),
        'fashionmnist': ([0.2860], [0.3530]),
    }
    mean, std = norm_stats[name]

    if name == 'fashionmnist':
        # The samples are scaled by hand instead of going through ToTensor:
        # a channel axis is added and values are divided by 255.
        # (presumably BasicDataset yields raw 2-D uint8 tensors here -- TODO confirm)
        transform_val = transforms.Compose([
            transforms.Lambda(lambda x: x.unsqueeze(0).float() / 255.0),
            transforms.Normalize(mean, std),
        ])
    else:
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    # Decide which original class ids count as "seen" (in-distribution).
    if name == 'cifar10':
        seen_classes = set(range(2, 8))
        num_all_classes = 10
    elif name == 'cifar100':
        num_super_classes = args.num_classes // 5  # args.num_super_classes
        num_all_classes = 100
        # Entry i is the super-class id (0..19) assigned to fine class i.
        super_classes = np.array([4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
                                  3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
                                  6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
                                  0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
                                  5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
                                  16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
                                  10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
                                  2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
                                  16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
                                  18, 1, 2, 15, 6, 0, 17, 8, 14, 13])
        # Fine classes whose super-class id is below the threshold are seen.
        seen_classes = set(np.arange(num_all_classes)[super_classes < num_super_classes])
    else:  # fashionmnist
        seen_classes = set([0, 1, 2, 3, 4, 6])
        num_all_classes = 10

    # torchvision spells the class 'FashionMNIST'; the CIFAR classes are just
    # the upper-cased dataset name.
    dset_name = 'FashionMNIST' if name == 'fashionmnist' else name.upper()
    dset = getattr(torchvision.datasets, dset_name)
    test_dset = dset(data_dir, train=False, download=True)

    test_data = test_dset.data
    # presumably reassign_target relabels targets so that seen classes get
    # the lowest ids -- verify against src.fl_datasets.utils
    test_targets = reassign_target(test_dset.targets, num_all_classes, seen_classes)
    seen_indices = np.where(test_targets < args.num_classes)[0]

    eval_dset = BasicDataset('supervised', test_data[seen_indices], test_targets[seen_indices],
                             len(seen_classes), transform_val, False, None, False)
    test_full_dset = BasicDataset('supervised', test_data, test_targets, num_all_classes,
                                  transform_val, False, None, False)

    inlier_test_loader = DataLoader(eval_dset, batch_size=100, shuffle=False)
    full_test_loader = DataLoader(test_full_dset, batch_size=100, shuffle=False)

    return inlier_test_loader, full_test_loader


def get_outlier(args, name, in_name, data_dir='./data'):
    """Build a test loader over the out-of-distribution dataset ``name``.

    Images are resized to ``args.img_size`` and normalized with the
    statistics of the in-distribution dataset ``in_name`` so that they match
    what the model was trained on.

    Args:
        args: Run configuration; ``args.img_size`` and ``args.data_dir`` are
            read here.
        name: Outlier dataset: ``'cifar10'``, ``'cifar100'``, ``'svhn'``,
            ``'lsun'``, ``'imagenet'`` or ``'gaussian'``.
        in_name: In-distribution dataset whose normalization statistics are
            applied (``'cifar10'``, ``'cifar100'`` or ``'fashionmnist'``).
        data_dir: Root directory containing the ``lsun`` / ``imagenet``
            image folders.

    Returns:
        A ``DataLoader`` with ``batch_size=100`` and no shuffling.

    Raises:
        NotImplementedError: If ``name`` is not a supported outlier dataset
            (previously ``None`` was returned silently).
        KeyError: If ``in_name`` has no normalization statistics.
    """
    torchvision_sets = ('cifar10', 'cifar100', 'svhn')
    folder_sets = ('lsun', 'imagenet', 'gaussian')
    if name not in torchvision_sets + folder_sets:
        raise NotImplementedError(f'unsupported outlier dataset: {name!r}')

    # Per-dataset (mean, std) of the in-distribution data.
    norm_stats = {
        'cifar10': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'cifar100': ([x / 255 for x in [129.3, 124.1, 112.4]],
                     [x / 255 for x in [68.2, 65.4, 70.4]]),
        'fashionmnist': ([0.2860], [0.3530]),
    }
    mean, std = norm_stats[in_name]

    steps = [transforms.Resize((args.img_size, args.img_size))]
    if len(mean) == 1:
        # Single-channel in-distribution data: collapse outlier images to one
        # channel. For 3-channel statistics the images must stay RGB,
        # otherwise Normalize fails with a broadcast-shape error.
        steps.append(transforms.Grayscale(num_output_channels=1))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean, std))
    transform_ood = transforms.Compose(steps)

    if name in torchvision_sets:
        dset = getattr(torchvision.datasets, name.upper())
        # SVHN selects its split by name; the CIFAR classes use a train flag.
        split_kwargs = {'split': 'test'} if name == 'svhn' else {'train': False}
        # NOTE(review): these datasets are rooted at args.data_dir while the
        # image-folder datasets below use the data_dir parameter -- confirm
        # the two are meant to differ.
        test_dset = dset(args.data_dir, download=True, transform=transform_ood, **split_kwargs)
    else:
        if name == 'gaussian':
            # NOTE(review): machine-specific absolute path kept as-is;
            # consider relocating it under data_dir.
            test_dir = os.path.join('/ssd/sjheo/ossl/ProSub/prosub/data/data/', 'gaussian28')
        else:
            test_dir = os.path.join(data_dir, name)
        test_dset = torchvision.datasets.ImageFolder(test_dir, transform=transform_ood)

    return DataLoader(test_dset, batch_size=100, shuffle=False)