import collections.abc as container_abcs
import errno
import numpy as np
import os
import torch
import torch.optim as optim
from itertools import repeat
from torchvision.utils import save_image
from config import cfg


def check_exists(path):
    """Return True when ``path`` exists on the file system, else False."""
    found = os.path.exists(path)
    return found


def makedir_exist_ok(path):
    """Create directory ``path`` (and missing parents) if it does not exist.

    Args:
        path: Directory to create. An empty string is treated as "the current
            directory" and is a no-op, so callers may pass
            ``os.path.dirname('file.pt')`` without special-casing it.

    Raises:
        OSError: If the directory cannot be created for a reason other than
            it already existing.
    """
    # os.makedirs('') raises FileNotFoundError; nothing needs creating here.
    if not path:
        return
    os.makedirs(path, exist_ok=True)
    return


def save(input, path, protocol=2, mode='torch'):
    """Serialize ``input`` to ``path``, creating the parent directory first.

    Args:
        input: Object to store.
        path: Destination file path.
        protocol: Pickle protocol forwarded to ``torch.save``.
        mode: ``'torch'`` to use ``torch.save`` or ``'numpy'`` to use
            ``np.save``.

    Raises:
        ValueError: If ``mode`` is neither ``'torch'`` nor ``'numpy'``.
    """
    makedir_exist_ok(os.path.dirname(path))
    if mode not in ('torch', 'numpy'):
        raise ValueError('Not valid save mode')
    if mode == 'numpy':
        np.save(path, input, allow_pickle=True)
    else:
        torch.save(input, path, pickle_protocol=protocol)
    return


def load(path, mode='torch'):
    """Load an object previously written with ``torch.save`` or ``np.save``.

    Args:
        path: File to read.
        mode: ``'torch'`` to use ``torch.load`` or ``'numpy'`` to use
            ``np.load``.

    Returns:
        The deserialized object.

    Raises:
        ValueError: If ``mode`` is neither ``'torch'`` nor ``'numpy'``.
    """
    # NOTE: both loaders may unpickle arbitrary objects; only load trusted files.
    if mode == 'torch':
        # The identity map_location keeps every storage where it was
        # deserialized instead of moving it to the device it was saved from.
        return torch.load(path, map_location=lambda storage, loc: storage)
    if mode == 'numpy':
        return np.load(path, allow_pickle=True)
    raise ValueError('Not valid load mode')


def save_img(img, path, nrow=10, padding=2, pad_value=0, range=None):
    """Write ``img`` to ``path`` through torchvision's ``save_image``.

    The parent directory of ``path`` is created when missing. Normalization is
    enabled exactly when ``range`` is supplied.

    Args:
        img: Image data forwarded to ``save_image``.
        path: Destination file path.
        nrow: Forwarded to ``save_image``.
        padding: Forwarded to ``save_image``.
        pad_value: Forwarded to ``save_image``.
        range: Optional value range forwarded to ``save_image``.
    """
    target_dir = os.path.dirname(path)
    makedir_exist_ok(target_dir)
    normalize = range is not None
    # NOTE(review): newer torchvision releases appear to name this keyword
    # ``value_range`` - confirm against the installed version.
    save_image(img, path, nrow=nrow, padding=padding, pad_value=pad_value, normalize=normalize, range=range)
    return


def to_device(input, device):
    """Return a copy of the nested structure ``input`` with ``.to(device)``
    applied to every leaf that ``recur`` visits."""
    return recur(lambda leaf, target: leaf.to(target), input, device)


def ntuple(n):
    """Build a parser that expands a scalar into an ``n``-tuple.

    The returned function hands back non-string iterables unchanged and turns
    anything else (including strings) into a tuple of ``n`` repetitions.
    """

    def parse(x):
        is_scalar = isinstance(x, str) or not isinstance(x, container_abcs.Iterable)
        if is_scalar:
            return tuple(repeat(x, n))
        return x

    return parse


def apply_fn(module, fn):
    """Call the zero-argument method named ``fn`` on every descendant of ``module``.

    Children are visited depth-first in ``named_children()`` order; a
    descendant without an attribute called ``fn`` is skipped. ``module``
    itself is not called.

    Args:
        module: Object exposing ``named_children()`` (e.g. a ``torch.nn.Module``).
        fn: Name of the method to invoke, as a string.
    """
    for _, child in module.named_children():
        if hasattr(child, fn):
            # Plain attribute lookup instead of exec() on a formatted string.
            getattr(child, fn)()
        # Recursing into a leaf is a no-op, so no child-count check is needed.
        apply_fn(child, fn)
    return


def recur(fn, input, *args):
    """Apply ``fn(leaf, *args)`` to every tensor/array leaf of a nested structure.

    Lists, tuples and dicts are rebuilt with the same nesting (as plain
    ``list``, ``tuple`` and ``dict``); ``torch.Tensor`` and ``np.ndarray``
    values are the leaves.

    Raises:
        ValueError: If a value is none of the supported container or leaf types.
    """
    if isinstance(input, (torch.Tensor, np.ndarray)):
        return fn(input, *args)
    if isinstance(input, list):
        return [recur(fn, item, *args) for item in input]
    if isinstance(input, tuple):
        return tuple(recur(fn, item, *args) for item in input)
    if isinstance(input, dict):
        return {key: recur(fn, input[key], *args) for key in input}
    raise ValueError('Not valid input type')


def process_dataset(dataset):
    """Record dataset-derived settings in ``cfg`` for the configured data name.

    For ``MNIST``/``CIFAR10`` the class count is read from the training split,
    for ``har`` it is fixed to 6, and for ``WikiText2`` the vocabulary is stored
    and every split is reshaped in place with ``batchify``.

    Args:
        dataset: Mapping from split name to dataset object.

    Raises:
        ValueError: If ``cfg['data_name']`` is not one of the names above.
    """
    data_name = cfg['data_name']
    if data_name in ['MNIST', 'CIFAR10']:
        cfg['classes_size'] = dataset['train'].classes_size
        return
    if data_name in ['har']:
        cfg['classes_size'] = 6
        return
    if data_name in ['WikiText2']:
        vocab = dataset['train'].vocab
        cfg['vocab'] = vocab
        cfg['num_tokens'] = len(vocab)
        for split in dataset:
            dataset[split] = batchify(dataset[split], cfg['batch_size'][split])
        return
    raise ValueError('Not valid data name')


# Sub-model levels ordered from the full model (level 0) to the smallest
# sub-model (level 4): the rate stored in cfg['model_rate'] and the paired
# value stored in cfg['model_comp'].
_LEVEL_MODEL_RATE = (1, 0.5, 0.25, 0.125, 0.0625)
_LEVEL_MODEL_COMP = (1, 0.84, 0.76, 0.69, 0.615)

# Hard-coded computation time given to each successive tenth of the users.
_COMP_TIME_BY_TENTH = (19, 17, 15, 13, 11, 10.5, 9, 8.5, 3, 1)

# Strictly-greater-than bounds, largest first. Exceeding the first bound maps
# to level 4, the second to level 3, and so on; exceeding none maps to level 0.
_TOTAL_TIME_BOUNDS = (20, 10, 5, 2.5)
_COMP_TIME_BOUNDS = (16, 12.8, 10.24, 8)
_COMM_TIME_BOUNDS = (30, 15, 7.5, 3.75)

# Rule used in 'fix' mode: 1 ranks users by comp_time + comm_time, 2 ranks
# them by the worse of their computation and communication levels.
_FIX_RATE_STRATEGY = 2


def _time_level(value, bounds):
    """Return the sub-model level (0-4) implied by ``value`` against ``bounds``."""
    for rank, bound in enumerate(bounds):
        if value > bound:
            return len(bounds) - rank
    return 0


def _fixed_rate_level(comp_time, comm_time, strategy):
    """Return the sub-model level for one user under the given strategy."""
    if strategy == 1:
        return _time_level(comp_time + comm_time, _TOTAL_TIME_BOUNDS)
    if strategy == 2:
        return max(_time_level(comp_time, _COMP_TIME_BOUNDS),
                   _time_level(comm_time, _COMM_TIME_BOUNDS))
    # Any other strategy leaves every user on the full model.
    return 0


def _make_comp_time(num_users):
    """Return the per-user computation-time array (ones, then tenth-wise values)."""
    comp_time = np.ones(num_users)
    group = int(num_users / 10)
    # Users beyond 10 * group (the remainder) keep the default value of 1.
    for tenth, value in enumerate(_COMP_TIME_BY_TENTH):
        comp_time[tenth * group:(tenth + 1) * group] = value
    return comp_time


def _assign_fixed_model_rates():
    """Fill cfg['model_rate'] / cfg['model_comp'] from the per-user times."""
    num_users = cfg['num_users']
    cfg['model_rate'] = np.ones(num_users)
    cfg['model_comp'] = np.ones(num_users)
    for i in range(num_users):
        level = _fixed_rate_level(cfg['comp_time'][i], cfg['comm_time'][i], _FIX_RATE_STRATEGY)
        cfg['model_rate'][i] = _LEVEL_MODEL_RATE[level]
        cfg['model_comp'][i] = _LEVEL_MODEL_COMP[level]


def _set_sgd_defaults(lr):
    """Store the SGD + MultiStepLR hyper-parameters shared by every dataset."""
    cfg['optimizer_name'] = 'SGD'
    cfg['lr'] = lr
    cfg['momentum'] = 0.9
    cfg['weight_decay'] = 5e-4
    cfg['scheduler_name'] = 'MultiStepLR'
    cfg['factor'] = 0.1


def _set_schedule(num_epochs, batch_size, milestones):
    """Store the epoch count, batch sizes and LR milestones in cfg."""
    cfg['num_epochs'] = num_epochs
    cfg['batch_size'] = batch_size
    cfg['milestones'] = milestones


def _set_classification_schedule(data_split_mode, iid_milestones, non_iid_milestones, central_epochs):
    """Store the schedule for MNIST / har / CIFAR according to the split mode."""
    if data_split_mode == 'iid':
        _set_schedule({'global': 200, 'local': 1}, {'train': 64, 'test': 64}, iid_milestones)
    elif 'non-iid' in data_split_mode:
        _set_schedule({'global': 200, 'local': 1}, {'train': 64, 'test': 64}, non_iid_milestones)
    elif data_split_mode == 'none':
        # Without a data split num_epochs is a plain int, not a dict.
        _set_schedule(central_epochs, {'train': 100, 'test': 500}, iid_milestones)
    else:
        raise ValueError('Not valid data_split_mode')


def _set_language_schedule(data_split_mode):
    """Store the schedule for the language-modelling datasets."""
    if data_split_mode == 'iid':
        _set_schedule({'global': 200, 'local': 1}, {'train': 100, 'test': 10}, [50, 100])
    elif data_split_mode == 'none':
        _set_schedule(100, {'train': 100, 'test': 100}, [25, 50])
    else:
        raise ValueError('Not valid data_split_mode')


def _configure_training(data_name, data_split_mode):
    """Store the dataset-specific shapes and training hyper-parameters in cfg."""
    if data_name in ['MNIST']:
        cfg['data_shape'] = [1, 28, 28]
        _set_sgd_defaults(1e-2)
        _set_classification_schedule(data_split_mode, [100], [200], 200)
    elif data_name in ['har']:
        _set_sgd_defaults(1e-2)
        # An unknown split mode is rejected here too, like for other datasets.
        _set_classification_schedule(data_split_mode, [100], [200], 200)
    elif data_name in ['CIFAR10', 'CIFAR100']:
        cfg['data_shape'] = [3, 32, 32]
        _set_sgd_defaults(1e-1)
        _set_classification_schedule(data_split_mode, [150, 250], [300, 500], 400)
    elif data_name in ['PennTreebank', 'WikiText2', 'WikiText103']:
        _set_sgd_defaults(1e-1)
        cfg['bptt'] = 64
        cfg['mask_rate'] = 0.15
        _set_language_schedule(data_split_mode)
    else:
        raise ValueError('Not valid dataset')


def process_control():
    """Expand ``cfg['control']`` into the full experiment configuration.

    Reads the control strings, then writes into ``cfg``: the parsed control
    values, per-user ``comp_time`` / ``comm_time`` arrays, the per-user (or
    per-mode) ``model_rate``, the model architecture settings, and the
    dataset-specific optimizer / scheduler / batch-size settings.

    In ``'dynamic'`` mode ``cfg['model_rate']`` and ``cfg['proportion']`` are
    lists parsed from ``cfg['model_mode']`` (e.g. ``'a1-b1'``). In ``'fix'``
    mode ``cfg['model_rate']`` and ``cfg['model_comp']`` are NumPy arrays with
    one entry per user, derived from the per-user times rather than from the
    proportions encoded in ``cfg['model_mode']``.

    Raises:
        ValueError: If the model split mode, the data name, or the data split
            mode is not supported.
    """
    cfg['model_split_rate'] = {'a': 1, 'b': 0.5, 'c': 0.25, 'd': 0.125, 'e': 0.0625}
    control = cfg['control']
    cfg['fed'] = int(control['fed'])
    cfg['num_users'] = int(control['num_users'])
    cfg['frac'] = float(control['frac'])
    cfg['data_split_mode'] = control['data_split_mode']
    cfg['model_split_mode'] = control['model_split_mode']
    cfg['model_mode'] = control['model_mode']
    cfg['norm'] = control['norm']
    cfg['scale'] = bool(int(control['scale']))
    cfg['mask'] = bool(int(control['mask']))
    # The first character of the mode string selects the global model level.
    cfg['global_model_mode'] = cfg['model_mode'][0]
    cfg['global_model_rate'] = cfg['model_split_rate'][cfg['global_model_mode']]

    cfg['comp_time'] = _make_comp_time(cfg['num_users'])
    cfg['comm_time'] = np.zeros(cfg['num_users'])

    model_mode = cfg['model_mode'].split('-')
    if cfg['model_split_mode'] == 'dynamic':
        mode_rate, proportion = [], []
        for m in model_mode:
            # Each entry is a level letter followed by an integer weight.
            mode_rate.append(cfg['model_split_rate'][m[0]])
            proportion.append(int(m[1:]))
        cfg['model_rate'] = mode_rate
        cfg['proportion'] = (np.array(proportion) / sum(proportion)).tolist()
    elif cfg['model_split_mode'] == 'fix':
        _assign_fixed_model_rates()
    else:
        raise ValueError('Not valid model split mode')

    cfg['conv'] = {'hidden_size': [64, 128, 256, 512]}
    cfg['resnet'] = {'hidden_size': [64, 128, 256, 512]}
    cfg['transformer'] = {'embedding_size': 256, 'num_heads': 8, 'hidden_size': 512, 'num_layers': 4, 'dropout': 0.2}
    _configure_training(cfg['data_name'], cfg['data_split_mode'])
    return


def make_stats(dataset):
    """Return the cached ``Stats`` of ``dataset``, computing and saving them if absent.

    The cache file is ``./data/stats/<dataset.data_name>.pt``. ``dataset`` must
    not be None because its ``data_name`` attribute is read first. When no
    cache exists, the ``'img'`` entry of every batch is accumulated into a
    ``Stats(dim=1)`` object, which is then saved.
    """
    stats_path = './data/stats/{}.pt'.format(dataset.data_name)
    if os.path.exists(stats_path):
        return load(stats_path)
    loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=False, num_workers=0)
    stats = Stats(dim=1)
    with torch.no_grad():
        for batch in loader:
            stats.update(batch['img'])
    save(stats, stats_path)
    return stats


class Stats(object):
    """Running per-feature mean and standard deviation accumulated over batches.

    ``dim`` is the axis of the incoming tensors that indexes features; every
    other axis is flattened into the sample axis.
    """

    def __init__(self, dim):
        self.dim = dim
        self.n_samples = 0
        self.n_features = None
        self.mean = None
        self.std = None

    def update(self, data):
        """Fold one batch into the running statistics.

        Args:
            data: Tensor whose axis ``self.dim`` holds the features.
        """
        # Move the feature axis last, then flatten to (samples, features).
        data = data.transpose(self.dim, -1).reshape(-1, data.size(self.dim))
        n = data.size(0)
        if n == 0:
            # Nothing to accumulate; avoids NaN from reducing an empty batch.
            return
        new_mean = data.mean(dim=0)
        # std() of a single sample is NaN and would poison every later merge,
        # so a lone sample contributes zero spread.
        new_std = data.std(dim=0) if n > 1 else torch.zeros_like(new_mean)
        if self.n_samples == 0:
            self.n_samples = n
            self.n_features = data.size(1)
            self.mean = new_mean
            self.std = new_std
        else:
            m = float(self.n_samples)
            old_mean = self.mean
            old_std = self.std
            self.mean = m / (m + n) * old_mean + n / (m + n) * new_mean
            # Weighted merge of the two variances plus the between-batch term.
            self.std = torch.sqrt(m / (m + n) * old_std ** 2 + n / (m + n) * new_std ** 2 + m * n / (m + n) ** 2 * (
                    old_mean - new_mean) ** 2)
            self.n_samples += n
        return


def make_optimizer(model, lr):
    """Build the optimizer named by ``cfg['optimizer_name']`` for ``model``.

    Args:
        model: Object whose ``parameters()`` are optimized.
        lr: Learning rate.

    Returns:
        An ``optim.SGD``, ``optim.RMSprop``, ``optim.Adam`` or ``optim.Adamax``
        instance; momentum is passed only to SGD and RMSprop.

    Raises:
        ValueError: If the configured optimizer name is not supported.
    """
    name = cfg['optimizer_name']
    if name == 'SGD':
        return optim.SGD(model.parameters(), lr=lr, momentum=cfg['momentum'],
                         weight_decay=cfg['weight_decay'])
    if name == 'RMSprop':
        return optim.RMSprop(model.parameters(), lr=lr, momentum=cfg['momentum'],
                             weight_decay=cfg['weight_decay'])
    if name == 'Adam':
        return optim.Adam(model.parameters(), lr=lr, weight_decay=cfg['weight_decay'])
    if name == 'Adamax':
        return optim.Adamax(model.parameters(), lr=lr, weight_decay=cfg['weight_decay'])
    raise ValueError('Not valid optimizer name')


def make_scheduler(optimizer):
    """Build the LR scheduler named by ``cfg['scheduler_name']``.

    Args:
        optimizer: Optimizer the scheduler is attached to.

    Returns:
        A ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: If the configured scheduler name is not supported.
    """
    if cfg['scheduler_name'] == 'None':
        # A single milestone at step 65535 stands in for "no schedule".
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[65535])
    elif cfg['scheduler_name'] == 'StepLR':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=cfg['step_size'], gamma=cfg['factor'])
    elif cfg['scheduler_name'] == 'MultiStepLR':
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg['milestones'], gamma=cfg['factor'])
    elif cfg['scheduler_name'] == 'ExponentialLR':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    elif cfg['scheduler_name'] == 'CosineAnnealingLR':
        # cfg['num_epochs'] is a {'global', 'local'} dict for federated runs
        # but a plain int when there is no data split; accept both.
        num_epochs = cfg['num_epochs']
        t_max = num_epochs['global'] if isinstance(num_epochs, dict) else num_epochs
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max,
                                                         eta_min=cfg['min_lr'])
    elif cfg['scheduler_name'] == 'ReduceLROnPlateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=cfg['factor'],
                                                         patience=cfg['patience'], verbose=True,
                                                         threshold=cfg['threshold'], threshold_mode='rel',
                                                         min_lr=cfg['min_lr'])
    elif cfg['scheduler_name'] == 'CyclicLR':
        scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=cfg['lr'], max_lr=10 * cfg['lr'])
    else:
        raise ValueError('Not valid scheduler name')
    return scheduler


def resume(model, model_tag, optimizer=None, scheduler=None, load_tag='checkpoint', strict=True, verbose=True):
    """Restore training state from ``./output/model/<model_tag>_<load_tag>.pt``.

    When the checkpoint exists, the model (and, if given, the optimizer and
    scheduler) are loaded in place and the stored logger is reused. Otherwise
    training starts at epoch 1 with a fresh ``Logger`` whose path is built
    from ``cfg['model_tag']``.

    Returns:
        ``(last_epoch, data_split, label_split, model, optimizer, scheduler,
        logger)`` when ``cfg['data_split_mode']`` is not ``'none'`` (the two
        splits are None when starting from scratch), otherwise
        ``(last_epoch, model, optimizer, scheduler, logger)``.
    """
    with_splits = cfg['data_split_mode'] != 'none'
    checkpoint_path = './output/model/{}_{}.pt'.format(model_tag, load_tag)
    data_split, label_split = None, None
    if os.path.exists(checkpoint_path):
        checkpoint = load(checkpoint_path)
        last_epoch = checkpoint['epoch']
        if with_splits:
            data_split = checkpoint['data_split']
            label_split = checkpoint['label_split']
        model.load_state_dict(checkpoint['model_dict'], strict=strict)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_dict'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_dict'])
        logger = checkpoint['logger']
        if verbose:
            print('Resume from {}'.format(last_epoch))
    else:
        print('Not exists model tag: {}, start from scratch'.format(model_tag))
        # Imported lazily, only on the start-from-scratch path.
        from datetime import datetime
        from logger import Logger
        last_epoch = 1
        timestamp = datetime.now().strftime('%b%d_%H-%M-%S')
        logger_path = 'output/runs/train_{}_{}'.format(cfg['model_tag'], timestamp)
        logger = Logger(logger_path)
    if with_splits:
        return last_epoch, data_split, label_split, model, optimizer, scheduler, logger
    return last_epoch, model, optimizer, scheduler, logger


def collate(input):
    """Stack each list of tensors in ``input`` along a new leading axis.

    The mapping is modified in place and also returned.
    """
    for key, tensors in input.items():
        input[key] = torch.stack(tensors, 0)
    return input


def batchify(dataset, batch_size):
    """Trim ``dataset.token`` to a multiple of ``batch_size`` and reshape it.

    The leading ``(len(dataset) // batch_size) * batch_size`` entries along
    axis 0 are kept and reshaped to ``batch_size`` rows. ``dataset`` is
    modified in place and returned.
    """
    usable = (len(dataset) // batch_size) * batch_size
    trimmed = dataset.token.narrow(0, 0, usable)
    dataset.token = trimmed.reshape(batch_size, -1)
    return dataset
