import yaml

import os
import logging
import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

def get_logger(name, save_path=None, level='INFO'):
    """
    Create (or fetch) a named logger, optionally mirroring records to a file.

    Args
        name: logger name handed to ``logging.getLogger``.
        save_path: directory that will hold ``log.txt``; it is created when
            missing. If None, no file handler is attached.
        level: name of a level attribute on the ``logging`` module
            (e.g. 'INFO', 'DEBUG'); used for the root configuration.

    Returns
        The ``logging.Logger`` registered under ``name``.
    """
    logger = logging.getLogger(name)
    # basicConfig configures the root logger; per the logging docs it does
    # nothing when the root logger already has handlers.
    logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s', level=getattr(logging, level))

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        # FileHandler stores the absolute path in ``baseFilename``, so compare
        # against the absolute form.
        log_file = os.path.abspath(os.path.join(save_path, 'log.txt'))

        # getLogger returns the same object for the same name, so without this
        # check every repeated call would add one more handler and each record
        # would be written to the file several times.
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file
            for handler in logger.handlers
        )
        if not already_attached:
            log_format = logging.Formatter('[%(asctime)s %(levelname)s] %(message)s')
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(log_format)
            logger.addHandler(file_handler)

    return logger


def get_server_dataset(cfgs):
    """
    Build the server-side datasets selected by ``cfgs['Dataset']['dataset']``.

    The dataset loader is called with ``num_labels``, ``num_classes`` and
    ``data_dir`` read from ``cfgs['Dataset']``; its labeled and test outputs
    are then handed to the matching ``*_server`` builder.

    Args
        cfgs: configuration mapping with a 'Dataset' section.

    Returns
        dict with keys 'train' (labeled set) and 'test' (test set).

    Raises
        NotImplementedError: the dataset name is not a supported CIFAR or
            FashionMNIST variant.
    """
    from src.fl_datasets.cv_datasets import get_cifar, get_cifar_server, get_fashionmnist, get_fashionmnist_server

    data_cfgs = cfgs['Dataset']
    dataset = data_cfgs['dataset']

    # Pick the (raw loader, server-set builder) pair for this dataset family.
    if dataset in ('cifar10', 'cifar100', 'cifar10_openset', 'cifar100_openset'):
        load_splits, build_server_sets = get_cifar, get_cifar_server
    elif dataset in ('fashionmnist', 'fashionmnist_openset'):
        load_splits, build_server_sets = get_fashionmnist, get_fashionmnist_server
    else:
        raise NotImplementedError

    # The third returned element (client data) is not needed on the server.
    lb_data, lb_targets, _, test_data, test_targets = load_splits(
        cfgs,
        name=dataset,
        num_labels=data_cfgs['num_labels'],
        num_classes=data_cfgs['num_classes'],
        data_dir=data_cfgs['data_dir'],
    )
    train_set, test_set = build_server_sets(cfgs, lb_data, lb_targets, test_data, test_targets)

    return {'train': train_set, 'test': test_set}


def get_client_dataset(cfgs):
    """
    Load the client portion of the dataset named in ``cfgs['Dataset']``.

    Calls the same loader as the server path and keeps only the third element
    of its five-element result.

    Args
        cfgs: configuration mapping with a 'Dataset' section providing
            'dataset', 'num_labels', 'num_classes' and 'data_dir'.

    Returns
        The clients set produced by the dataset loader (its structure is
        defined by ``src.fl_datasets.cv_datasets``).

    Raises
        NotImplementedError: the dataset name is not a supported CIFAR or
            FashionMNIST variant.
    """
    from src.fl_datasets.cv_datasets import get_cifar, get_fashionmnist

    data_cfgs = cfgs['Dataset']
    dataset = data_cfgs['dataset']

    if dataset in ('cifar10', 'cifar100', 'cifar10_openset', 'cifar100_openset'):
        load_splits = get_cifar
    elif dataset in ('fashionmnist', 'fashionmnist_openset'):
        load_splits = get_fashionmnist
    else:
        raise NotImplementedError

    # Labeled and test outputs are discarded; only the client data is used.
    _, _, clients_set, _, _ = load_splits(
        cfgs,
        name=dataset,
        num_labels=data_cfgs['num_labels'],
        num_classes=data_cfgs['num_classes'],
        data_dir=data_cfgs['data_dir'],
    )

    return clients_set


def get_client_specific_data(cfgs, cid, clients_set):
    """
    Return the data of one client, using the builder for the configured dataset.

    Args
        cfgs: configuration mapping; ``cfgs['Dataset']['dataset']`` selects
            the dataset family.
        cid: client identifier forwarded to the dataset-specific builder.
        clients_set: clients set forwarded to the dataset-specific builder
            (presumably the value returned by ``get_client_dataset`` —
            confirm against callers).

    Returns
        Whatever the dataset-specific client builder returns.

    Raises
        NotImplementedError: the dataset name is not a supported CIFAR or
            FashionMNIST variant.
    """
    from src.fl_datasets.cv_datasets import get_cifar_client, get_fashionmnist_client

    dataset = cfgs['Dataset']['dataset']

    if dataset in ('cifar10', 'cifar100', 'cifar10_openset', 'cifar100_openset'):
        return get_cifar_client(cfgs, cid, clients_set)

    if dataset in ('fashionmnist', 'fashionmnist_openset'):
        return get_fashionmnist_client(cfgs, cid, clients_set)

    raise NotImplementedError


def get_dataloader(dset,
                   batch_size=None,
                   shuffle=False,
                   num_workers=2,
                   pin_memory=True,
                   drop_last=False, data_sampler=None):
    """
    Wrap ``dset`` in a ``torch.utils.data.DataLoader``.

    Args
        dset: dataset to iterate. For ``data_sampler='balanced'`` it must
            expose a ``targets`` attribute holding the labels.
        batch_size: number of samples per batch; required.
        shuffle: forwarded to DataLoader (ignored for 'balanced').
        num_workers: forwarded to DataLoader.
        pin_memory: forwarded to DataLoader.
        drop_last: forwarded to DataLoader (ignored for 'balanced').
        data_sampler: None for the default sampling, or 'balanced' to batch
            through ``BalancedBatchSampler``.

    Returns
        The configured DataLoader.

    Raises
        AssertionError: ``batch_size`` is None.
        ValueError: ``data_sampler`` is neither None nor 'balanced'.
    """
    assert batch_size is not None, "batch_size must be provided"

    if data_sampler is None:
        return DataLoader(dset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          drop_last=drop_last)

    if data_sampler != 'balanced':
        # Previously an unknown sampler name fell through and returned None,
        # which only surfaced later as an unrelated error at the call site.
        raise ValueError(f"Unsupported data_sampler: {data_sampler!r} "
                         f"(expected None or 'balanced')")

    # Imported only on the branch that needs it, so the default path does not
    # depend on the project sampler module.
    from src.fl_datasets.utils import BalancedBatchSampler

    labels = dset.targets
    sampler = BalancedBatchSampler(labels=labels,
                                   batch_size=batch_size,
                                   num_classes_per_batch=len(np.unique(labels)))
    print("Balanced loader!")
    # The batch sampler yields whole batches, so batch_size, shuffle and
    # drop_last are not passed to DataLoader here.
    return DataLoader(dset,
                      num_workers=num_workers,
                      pin_memory=pin_memory,
                      batch_sampler=sampler)
        
        
def get_net_builder(net_name, from_name: bool):
    """
    Look up a network builder by name.
    Returns the **class/callable** of the backbone network (not an instance).

    Args
        net_name: attribute name in ``src.nets`` (e.g. 'WideResNet'), or a
            model name in ``torchvision.models`` when ``from_name`` is True.
        from_name: If True, the builder is taken from ``torchvision.models``;
            otherwise it is taken from ``src.nets``.

    Returns
        The callable registered under ``net_name``.

    Raises
        ValueError: ``from_name`` is True and ``net_name`` is not a lowercase,
            non-dunder callable of ``torchvision.models``.
        AttributeError: ``from_name`` is False and ``src.nets`` has no
            attribute ``net_name``.
    """
    if from_name:
        import torchvision.models as nets
        model_name_list = sorted(name for name in nets.__dict__
                                 if name.islower() and not name.startswith("__")
                                 and callable(nets.__dict__[name]))

        if net_name not in model_name_list:
            # The original used ``assert Exception(...)``, which never fails
            # because an exception instance is truthy; the function then
            # silently returned None for an unknown name.
            raise ValueError(f"[!] Networks' Name is wrong, check net config, "
                             f"expected: {model_name_list}  "
                             f"received: {net_name}")
        return nets.__dict__[net_name]

    import src.nets as nets
    builder = getattr(nets, net_name)
    return builder
    
    
def get_optimizer(net, optim_name='SGD', lr=0.1, 
                  momentum=0.9, weight_decay=0, 
                  layer_decay=1.0, nesterov=True, bn_wd_skip=True):
    '''
    Return a torch.optim optimizer (selected by name) for ``net``.

    If bn_wd_skip is set and ``net`` defines ``no_weight_decay()``, the value
    it returns is passed as the no-weight-decay list when the parameter groups
    are built (the original docstring describes this as skipping weight decay
    on batch normalization parameters).

    Args
        net: model whose parameters are optimized.
        optim_name: 'SGD' or 'AdamW'.
        lr: learning rate.
        momentum: SGD momentum (unused for AdamW).
        weight_decay: weight decay coefficient.
        layer_decay: layer-wise decay factor, must be <= 1.0; when it is not
            exactly 1.0 the layer-decay parameter grouping is used.
        nesterov: enable Nesterov momentum for SGD (unused for AdamW).
        bn_wd_skip: whether to honor ``net.no_weight_decay()``.

    Returns
        The constructed optimizer.

    Raises
        AssertionError: ``layer_decay`` is greater than 1.0.
        NotImplementedError: ``optim_name`` is not 'SGD' or 'AdamW'.
    '''
    from src.nets.utils import param_groups_layer_decay, param_groups_weight_decay
    
    assert layer_decay <= 1.0

    no_decay = {}
    if hasattr(net, 'no_weight_decay') and bn_wd_skip:
        no_decay = net.no_weight_decay()
    
    if layer_decay != 1.0:
        per_param_args = param_groups_layer_decay(net, lr, weight_decay, 
                                                  no_weight_decay_list=no_decay, layer_decay=layer_decay)
    else:
        per_param_args = param_groups_weight_decay(net, weight_decay, 
                                                   no_weight_decay_list=no_decay)

    if optim_name == 'SGD':
        optimizer = torch.optim.SGD(per_param_args, 
                                    lr=lr, momentum=momentum, weight_decay=weight_decay,
                                    nesterov=nesterov)
    elif optim_name == 'AdamW':
        optimizer = torch.optim.AdamW(per_param_args, 
                                      lr=lr, weight_decay=weight_decay)
    else:
        # Without this branch an unknown name reached ``return optimizer``
        # with the variable unbound and failed with UnboundLocalError.
        raise NotImplementedError(f"Unsupported optimizer: {optim_name!r} "
                                  f"(expected 'SGD' or 'AdamW')")

    return optimizer


def get_custom_cosine_scheduler(optimizer, 
                                first_cycle_step, cycle_mult, 
                                max_lr, min_lr,
                                warmup_steps,
                                gamma, last_epoch=-1):
    """
    Construct the project's ``CosineAnnealingWarmupRestarts`` scheduler.

    All arguments are forwarded unchanged: the first seven positionally, in
    the order they appear in this signature, and ``last_epoch`` by keyword.

    Returns
        The scheduler instance bound to ``optimizer``.
    """
    from src.core.utils import CosineAnnealingWarmupRestarts

    positional_args = (
        optimizer,
        first_cycle_step,
        cycle_mult,
        max_lr,
        min_lr,
        warmup_steps,
        gamma,
    )
    return CosineAnnealingWarmupRestarts(*positional_args, last_epoch=last_epoch)


def load_config(cfg_path):
    """
    Parse the YAML file at ``cfg_path`` with ``yaml.safe_load``.

    Args
        cfg_path: path of the YAML configuration file.

    Returns
        The parsed document as returned by ``yaml.safe_load``.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default text encoding.
    with open(cfg_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)