import os

import numpy as np

import torchvision
from torchvision import transforms

from src.fl_datasets.augmentation import RandAugmentMC
from src.fl_datasets.utils import split_data, split_clients, reassign_target
from .dataset_base import BasicDataset


def get_cifar(cfgs, name, num_labels, num_classes, data_dir='./data'):

    data_cfgs = cfgs['Dataset']
    
    name = name.split('_')[0]  # cifar10_openset -> cifar10
    data_dir = os.path.join(data_dir, name.lower())
    dset = getattr(torchvision.datasets, name.upper())
    
    # Train dataset
    train_dset = dset(data_dir, train=True, download=True)
    train_data, train_targets = train_dset.data, train_dset.targets

    if name == 'cifar10':
        seen_classes = set(range(2, 8))
        num_all_classes = 10
    elif name == 'cifar100':
        num_super_classes = num_classes // 5  # num_classes: # of inlier classes
        num_all_classes = 100
        super_classes = np.array([4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
                                  3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
                                  6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
                                  0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
                                  5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
                                  16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
                                  10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
                                  2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
                                  16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
                                  18, 1, 2, 15, 6, 0, 17, 8, 14, 13])
        seen_classes = set(np.arange(num_all_classes)[super_classes < num_super_classes])
        
    # = split labeled / unlabeled
    lb_data, lb_targets, ulb_data, ulb_targets = split_data(cfgs=data_cfgs, 
                                                            data=train_data, 
                                                            target=train_targets, 
                                                            num_labels=num_labels, 
                                                            num_all_classes=num_all_classes, 
                                                            seen_classes=seen_classes, 
                                                            index=None, include_lb_to_ulb=False)
    
    # = split clients dataset
    clients_set = split_clients(cfgs=data_cfgs, 
                                ulb_data=ulb_data, 
                                ulb_targets=ulb_targets)
    
    # Test dataset
    test_dset = dset(data_dir, train=False, download=True)
    test_data, test_targets = test_dset.data, reassign_target(test_dset.targets, num_all_classes, seen_classes)
    

    return lb_data, lb_targets, clients_set, test_data, test_targets


def get_cifar_server(cfgs, lb_data, lb_targets, test_data, test_targets):
    """Wrap the server-side labeled and test splits in ``BasicDataset`` objects.

    Args:
        cfgs: experiment config. Reads ``cfgs['server_alg']`` and, from
            ``cfgs['Dataset']``, the keys ``'dataset'``, ``'image_size'``,
            ``'crop_ratio'`` and ``'num_classes'``.
        lb_data, lb_targets: labeled training split (as returned by
            ``get_cifar``).
        test_data, test_targets: test split (as returned by ``get_cifar``).

    Returns:
        ``(lb_dset, test_dset)``. The labeled set uses the server algorithm's
        weak/strong transforms. The test set is built with ``alg='supervised'``
        and the evaluation transform.
    """
    data_cfgs = cfgs['Dataset']
    # 'cifar10_openset' -> 'cifar10'. Lower-cased so that set_transform's
    # keys and the 'cifar' prefix parsed below match whatever case the
    # config uses.
    name = data_cfgs['dataset'].split('_')[0].lower()

    alg = cfgs['server_alg']

    # transform
    transform_weak, transform_strong, transform_val = set_transform(alg=alg,
                                                                    crop_size=data_cfgs['image_size'],
                                                                    crop_ratio=data_cfgs['crop_ratio'],
                                                                    name=name)

    # Labeled dataset: sized by the configured number of inlier classes.
    num_classes = data_cfgs['num_classes']
    lb_dset = BasicDataset(alg=alg,
                           data=lb_data, targets=lb_targets,
                           num_classes=num_classes, transform=transform_weak,
                           is_ulb=False, strong_transform=transform_strong)

    # Test dataset: sized by the dataset's total class count
    # ('cifar10' -> 10, 'cifar100' -> 100).
    num_all_classes = int(name.split('cifar')[-1])
    test_dset = BasicDataset(alg='supervised',
                             data=test_data,
                             targets=test_targets,
                             num_classes=num_all_classes, transform=transform_val,
                             is_ulb=False, strong_transform=None)

    return lb_dset, test_dset


def get_cifar_client(cfgs, cid, clients_set):
    """Build the ``BasicDataset`` for one federated client.

    Args:
        cfgs: experiment config. Reads ``cfgs['client_alg']`` and, from
            ``cfgs['Dataset']``, the keys ``'dataset'``, ``'image_size'``
            and ``'crop_ratio'``.
        cid: key/index of the client in ``clients_set``.
        clients_set: per-client splits (as returned by ``get_cifar``).
            Each entry must provide ``'data'`` and ``'targets'``.

    Returns:
        The client's dataset. It is marked unlabeled (``is_ulb=True``)
        unless the client algorithm is ``'supervised'``.
    """
    data_cfgs = cfgs['Dataset']
    # 'cifar10_openset' -> 'cifar10'. Lower-cased so that set_transform's
    # keys and the 'cifar' prefix parsed below match whatever case the
    # config uses.
    name = data_cfgs['dataset'].split('_')[0].lower()

    alg = cfgs['client_alg']

    # transform (the evaluation pipeline is not needed on clients)
    transform_weak, transform_strong, _ = set_transform(alg=alg,
                                                        crop_size=data_cfgs['image_size'],
                                                        crop_ratio=data_cfgs['crop_ratio'],
                                                        name=name)

    # Client data may contain any class: 'cifar10' -> 10, 'cifar100' -> 100.
    num_all_classes = int(name.split('cifar')[-1])

    c_dataset = clients_set[cid]
    c_data, c_targets = c_dataset['data'], c_dataset['targets']

    # Only a fully supervised client treats its data as labeled.
    is_ulb = alg != 'supervised'
    c_dset = BasicDataset(alg=alg,
                          data=c_data, targets=c_targets,
                          num_classes=num_all_classes, transform=transform_weak,
                          is_ulb=is_ulb, strong_transform=transform_strong)

    return c_dset


def set_transform(alg, crop_size, crop_ratio, name):
    """Build the weak, strong and evaluation pipelines for a CIFAR dataset.

    Every pipeline resizes to ``crop_size``, applies its augmentations,
    converts to a tensor and normalises with per-dataset statistics.

    Args:
        alg: algorithm name. ``'openmatch'``, ``'ours'`` and ``'prosub'``
            receive two weak pipelines (see Returns).
        crop_size: size passed to ``Resize`` and ``RandomCrop``.
        crop_ratio: ``RandomCrop`` pads each side by
            ``int(crop_size * (1 - crop_ratio))`` pixels (reflect padding).
        name: ``'cifar10'`` or ``'cifar100'``; selects the normalisation
            statistics. Any other value raises ``KeyError``.

    Returns:
        ``(transform_weak, transform_strong, transform_eval)``. For the
        openmatch-style algorithms ``transform_weak`` is the list
        ``[crop+flip pipeline, flip-only pipeline]``. Otherwise it is the
        single crop+flip pipeline.
    """
    # (mean, std) per dataset.
    # NOTE(review): the cifar10 values appear to be the usual ImageNet
    # statistics rather than CIFAR-10's own -- confirm this is intended
    # (e.g. for an ImageNet-pretrained backbone) before changing it.
    norm_stats = {
        'cifar10': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'cifar100': ([v / 255 for v in [129.3, 124.1, 112.4]],
                     [v / 255 for v in [68.2, 65.4, 70.4]]),
    }
    channel_mean, channel_std = norm_stats[name]
    pad = int(crop_size * (1 - crop_ratio))

    def pipeline(*augmentations):
        # Fresh transform objects for every pipeline, as the original built them.
        return transforms.Compose([
            transforms.Resize(crop_size),
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    flip_only = pipeline(transforms.RandomHorizontalFlip())
    weak = pipeline(
        transforms.RandomCrop(crop_size, padding=pad, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
    )
    strong = pipeline(
        transforms.RandomCrop(crop_size, padding=pad, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
        RandAugmentMC(n=2, m=10),
    )
    evaluation = pipeline()

    if alg not in ('openmatch', 'ours', 'prosub'):
        return weak, strong, evaluation
    return [weak, flip_only], strong, evaluation