import json
import hashlib
import os
from datetime import datetime
import argparse
import logging


def dict_hash(dictionary):
    """Return the hexadecimal MD5 digest of a JSON-serialisable dictionary.

    The dictionary is serialised with ``sort_keys=True`` first, so two
    dictionaries with equal content but different key insertion order yield
    the same digest. Used as a content fingerprint, not for security.
    """
    canonical_json = json.dumps(dictionary, sort_keys=True)
    return hashlib.md5(canonical_json.encode()).hexdigest()

def run(settings_path,
        dry_run,
        param_override_dict=None,
        callback_intermediate_accuracy=None,
        callback_final_accuracy=None):
    '''
    Build, configure and (unless ``dry_run``) train a federated-learning
    server described by a JSON settings file.

    Current algorithm options (the ``algorithm`` setting, e.g. in fl_setting.json):

    *FLBase
    *-->FedAvg
    *-->Caldas
    *-->SLT
    *-->FedRolex
    *-->HeteroFL
    *-->FjORD

    Settings resolution, in order:
      1. ``param__run`` is removed from the settings and used as run-time
         options (``torch_device``, ``n_rounds``, ``progress_bar``, ``plot``);
         it is therefore not part of the settings hash.
      2. ``algorithm`` / ``dataset`` from ``param_override_dict`` are applied
         first, because they decide which parameter groups are selected.
      3. Every ``param_dataset__*`` group whose ``"for"`` list contains the
         dataset, and every ``param__*`` group whose ``"for"`` list contains
         the algorithm, has its entries (except ``"for"``) merged into the top
         level; all group entries are then dropped.
      4. ``param_override_dict`` is applied in full and wins over everything.

    The resolved settings are hashed with ``dict_hash`` and results are
    written to ``runs/<session_tag>/run_<hash>``.

    Args:
        settings_path: path of the JSON settings file.
        dry_run: if true, only build and initialize the server (no run
            directory, log file, split plots or training) and return it.
        param_override_dict: optional settings overriding the file's values.
        callback_intermediate_accuracy: stored on the server as
            ``report_intermediate_f``.
        callback_final_accuracy: stored on the server as ``report_final_f``.

    Returns:
        The initialized server if ``dry_run`` is true, otherwise ``None``.

    Raises:
        ValueError: unknown ``algorithm`` or ``model``.
        NotImplementedError: unknown ``optimizer`` or ``dataset``.
    '''
    with open(settings_path) as fd:
        settings = json.load(fd)

    # Run-time options are kept out of `settings` so they do not affect the hash.
    run = settings.pop('param__run')

    if param_override_dict is not None:
        # Applied up front: they select which parameter groups apply below.
        for key in ('algorithm', 'dataset'):
            if key in param_override_dict:
                settings[key] = param_override_dict[key]

    # Collect the entries of every parameter group that applies to the chosen
    # dataset / algorithm, in file order (later groups win on key clashes).
    selected_params = []
    for key, group in settings.items():
        if key.startswith("param_dataset__"):
            applies = settings['dataset'] in group["for"]
        elif key.startswith("param__"):
            applies = settings['algorithm'] in group["for"]
        else:
            continue
        if applies:
            selected_params.extend(
                (name, value) for name, value in group.items() if name != "for")

    # Drop all group entries (used or not) so only resolved values are hashed.
    group_keys = [key for key in settings
                  if 'param__' in key or 'param_dataset__' in key]
    for key in group_keys:
        del settings[key]
    settings.update(selected_params)

    if param_override_dict is not None:
        settings.update(param_override_dict)

    print("{" + "\n".join("{!r}: {!r},".format(k, v) for k, v in settings.items()) + "}")
    assert settings['n_devices'] >= settings['n_active_devices'], "Cannot be more active devices than overall devices"

    # Restrict visible GPUs before torch is imported below; the selected GPU is
    # then addressed as plain 'cuda'. A bare 'cuda' leaves visibility untouched.
    if run['torch_device'].startswith('cuda'):
        _, _, gpu_index = run['torch_device'].partition(':')
        if gpu_index:
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_index
        run['torch_device'] = 'cuda'

    import torch
    torch.manual_seed(settings['seed'])

    from torchvision import datasets

    # Store scale_factor as a list before hashing, so a scalar and the
    # equivalent one-element list produce the same hash.
    if not isinstance(settings['param_scale_factor'], list):
        settings['param_scale_factor'] = [settings['param_scale_factor']]

    path_hash = dict_hash(settings)
    run_dir = "runs/" + settings['session_tag'] + '/run_' + path_hash
    cnn_args = {}

    algorithm = settings['algorithm']
    dataset = settings['dataset']
    n_devices = settings['n_devices']
    scale_factors = settings['param_scale_factor']

    # Algorithm
    if algorithm == 'FedAvg':
        # NOTE(review): class name spelled as exported by algorithms.fedavg — confirm.
        from algorithms.fedavg import FedAvgSever as Server
    elif algorithm == 'FedRolex':
        from algorithms.fedrolex import FedRolexServer as Server
    elif algorithm == 'Caldas':
        from algorithms.caldas import CaldasServer as Server
    elif algorithm == 'SLT':
        from algorithms.SLT import SLTServer as Server
    elif algorithm == 'HeteroFL':
        from algorithms.heterofl import HeteroFLServer as Server
    elif algorithm == 'FjORD':
        from algorithms.fjord import FjordServer as Server
    else:
        raise ValueError(algorithm)

    flserver = Server(run_dir)

    # Per-device scale factors
    if algorithm == 'FedAvg':
        # Homogeneous baseline: every device falls back to the smallest model.
        flserver.scale_factor_list = [min(scale_factors) for _ in range(n_devices)]
    else:
        # Devices are split into len(scale_factors) contiguous, (near-)equally
        # sized groups; group g uses scale_factors[g].
        flserver.scale_factor_list = [
            scale_factors[int(i * len(scale_factors) / n_devices)] for i in range(n_devices)]

    flserver.n_active_devices = settings['n_active_devices']
    flserver.n_devices = n_devices
    flserver.torch_device = run['torch_device']
    flserver.n_rounds = run['n_rounds']
    flserver.lr = settings['lr']
    flserver.lr_min = float(settings['lr'] / settings['lr_reduction_factor'])
    flserver.set_seed(settings['seed'])
    flserver.report_intermediate_f = callback_intermediate_accuracy
    flserver.report_final_f = callback_final_accuracy

    flserver.progress_output = run['progress_bar']

    # Optimizer
    if settings['optimizer'] == 'adam':
        flserver.set_optimizer(torch.optim.Adam, {'weight_decay': settings['weight_decay']})
    elif settings['optimizer'] == 'sgd':
        if settings['weight_decay'] is not None:
            flserver.set_optimizer(torch.optim.SGD, {'weight_decay': settings['weight_decay'], 'momentum': 0.9})
        else:
            flserver.set_optimizer(torch.optim.SGD, {'momentum': 0.9})
    else:
        raise NotImplementedError

    # Dataset
    if 'CIFAR' in dataset:
        from utils.datasets.cifar import tf_cifar_train, tf_cifar_test
        kwargs = {'download': True}
        if dataset.endswith('100'):
            flserver.set_dataset(datasets.CIFAR100, "/tmp/", kwargs, tf_cifar_train, tf_cifar_test)
            cnn_args.update({'num_classes': 100})
        elif dataset.endswith('10'):
            flserver.set_dataset(datasets.CIFAR10, "/tmp/", kwargs, tf_cifar_train, tf_cifar_test)
            cnn_args.update({'num_classes': 10})
        else:
            # Fail here rather than continue with no dataset configured.
            raise NotImplementedError(dataset)
    elif 'FEMNIST' in dataset:
        from utils.datasets.femnist import tf_femnist_train, tf_femnist_test, FEMNIST
        kwargs = {}
        flserver.set_dataset(FEMNIST, "data/", kwargs, tf_femnist_train, tf_femnist_test)
        cnn_args.update({'num_classes': 62})
    elif 'TINYIMAGENET' in dataset:
        from utils.datasets.tinyimagenet import TinyImageNetDataset, tf_tinyimagenet_test, tf_tinyimagenet_train
        kwargs = {}
        flserver.set_dataset(TinyImageNetDataset, "/data", kwargs, tf_tinyimagenet_train, tf_tinyimagenet_test)
        cnn_args.update({'num_classes': 200})
    else:
        raise NotImplementedError
    net_eval = None

    from nets.SubsetNets.resnet_cifar import ResNet20
    from nets.SubsetNets.densenet_cifar import DenseNet40

    # Models: SLT uses its own network classes and a per-model config file;
    # the other submodel algorithms use SubsetNets plus an extraction function.
    # FedAvg sets no extract_fnc.
    if settings['model'] == 'ResNet20':
        if algorithm == 'SLT':
            from nets.SLTNets.resnet_cifar import ResNet20 as SLTResNet20
            from utils.SLT_submodel import extract_submodel_resnet_structure
            net = SLTResNet20
            net_eval = SLTResNet20
            flserver.configs_path = 'nets/SLTNets/configs/config__resnet20.json'
            flserver.extract_fnc = extract_submodel_resnet_structure
        else:
            net = ResNet20
            net_eval = ResNet20
            if algorithm == 'Caldas':
                from utils.caldas_submodel import extract_submodel_resnet_structure
                flserver.extract_fnc = extract_submodel_resnet_structure
            elif algorithm == 'FedRolex':
                from utils.fedrolex_submodel import extract_submodel_resnet_structure
                flserver.extract_fnc = extract_submodel_resnet_structure
            elif algorithm in ('HeteroFL', 'FjORD'):
                # FjORD shares HeteroFL's extraction, as for the other models.
                from utils.heterofl_submodel import extract_submodel_resnet_structure
                flserver.extract_fnc = extract_submodel_resnet_structure

    elif settings['model'] == 'ResNet32Large':
        from nets.SubsetNets.resnet_cifar import ResNet32Large as ResNet32
        if algorithm == 'SLT':
            from nets.SLTNets.resnet_cifar import ResNet32Large as SLTResNet32
            from utils.SLT_submodel import extract_submodel_resnet_structure
            net = SLTResNet32
            net_eval = SLTResNet32
            flserver.configs_path = 'nets/SLTNets/configs/config__resnet32large.json'
            flserver.extract_fnc = extract_submodel_resnet_structure
        else:
            net = ResNet32
            net_eval = ResNet32
            if algorithm == 'Caldas':
                from utils.caldas_submodel import extract_submodel_resnet_structure
                flserver.extract_fnc = extract_submodel_resnet_structure
            elif algorithm == 'FedRolex':
                from utils.fedrolex_submodel import extract_submodel_resnet_structure
                flserver.extract_fnc = extract_submodel_resnet_structure
            elif algorithm in ('HeteroFL', 'FjORD'):
                from utils.heterofl_submodel import extract_submodel_resnet_structure
                flserver.extract_fnc = extract_submodel_resnet_structure

    elif settings['model'] == 'DenseNet40':
        if algorithm == 'SLT':
            from nets.SLTNets.densenet_cifar import DenseNet40 as SLTDenseNet40
            from utils.SLT_submodel import extract_submodel_densenet_structure

            net = SLTDenseNet40
            net_eval = SLTDenseNet40
            flserver.configs_path = 'nets/SLTNets/configs/config__densenet40.json'
            flserver.extract_fnc = extract_submodel_densenet_structure
        else:
            net = DenseNet40
            net_eval = DenseNet40
            if algorithm == 'Caldas':
                from utils.caldas_submodel import extract_submodel_densenet_structure
                flserver.extract_fnc = extract_submodel_densenet_structure
            elif algorithm == 'FedRolex':
                from utils.fedrolex_submodel import extract_submodel_densenet_structure
                flserver.extract_fnc = extract_submodel_densenet_structure
            elif algorithm in ('HeteroFL', 'FjORD'):
                from utils.heterofl_submodel import extract_submodel_densenet_structure
                flserver.extract_fnc = extract_submodel_densenet_structure

    else:
        raise ValueError(settings['model'])

    # Data split; split plots are never produced in a dry run.
    from utils.split import split_iid, split_noniid, split_noniid_resources
    plot_split = run['plot'] if not dry_run else False
    if 'noniid_resources' in dataset:
        flserver.split_f = split_noniid_resources(settings['param_dirichlet_alpha'], run_dir,
                                                  plot_split, settings['seed'])
    elif 'noniid_' in dataset:
        flserver.split_f = split_noniid(settings['param_dirichlet_alpha'], run_dir,
                                        plot_split, settings['seed'])
    else:
        flserver.split_f = split_iid(run_dir, plot_split, settings['seed'])

    # One independent copy of the network kwargs per device, each with its own scale factor.
    import copy
    device_cnn_args_list = []
    for idx in range(n_devices):
        device_args = copy.deepcopy(cnn_args)
        device_args['scale_factor'] = flserver.scale_factor_list[idx]
        device_cnn_args_list.append(device_args)
    flserver.set_model([net for _ in range(n_devices)], device_cnn_args_list)

    # Server-side evaluation network:
    #   FedAvg           -> the (smallest) subset network all devices train,
    #   HeteroFL / FjORD -> the largest device network,
    #   otherwise        -> the full network (scale_factor=1.0).
    if algorithm == 'FedAvg':
        eval_scale_factor = min(scale_factors)
    elif algorithm in ('HeteroFL', 'FjORD'):
        eval_scale_factor = max(scale_factors)
    else:
        eval_scale_factor = 1.0
    eval_args = copy.deepcopy(cnn_args)
    eval_args['scale_factor'] = eval_scale_factor
    flserver.set_model_evaluation(net_eval, eval_args)

    if dry_run:
        print("DRY RUN PERFORMED SUCESSFULLY")
        print(f"Settings Hash: {path_hash}")
        flserver.initialize()
        return flserver

    # An existing directory means these exact settings were run before; its
    # stored settings file is left as is.
    try:
        os.makedirs(run_dir)
        with open(run_dir + "/" + "fl_setting.json", "w") as fd:
            json.dump(settings, fd, indent=4)
    except FileExistsError:
        pass

    # NOTE(review): logging.basicConfig is a no-op when the root logger already
    # has handlers, so repeated run() calls in one process keep logging to the
    # first run's file.
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        filename=run_dir + '/run.log', level=logging.INFO, filemode='w')
    logging.info('Started')
    print(f"Settings Hash: {path_hash}")
    logging.info(f"Settings Hash: {path_hash}")
    logging.info("{" + "\n".join("{!r}: {!r},".format(k, v) for k, v in settings.items()) + "}")
    flserver.initialize()

    with open(run_dir + "/" + "t_start.json", "w") as fd:
        # year_month_day_hour_minute_second
        json.dump({"t_start": datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}, fd, indent=4)

    if run['plot'] is True:
        import visualization.plots as plots

        flserver.set_plotting_callback(plots.plot_config, run_dir)

    logging.info(f'scale_factor_list: {flserver.scale_factor_list}')
    try:
        flserver.run()
    except KeyboardInterrupt:
        # Allow stopping training early while still plotting and linking results.
        pass

    if run['plot'] is True:
        try:
            plots.plot_config(run_dir)
        except Exception:
            print("Error plotting!")
            logging.exception('Final plotting failed')

    # Point ./latest_run at this run's directory.
    try:
        os.unlink("latest_run")
    except FileNotFoundError:
        pass
    os.symlink(run_dir, "latest_run")

def parse_scale_factor(value):
    """Parse a ``param_scale_factor`` command-line value.

    Accepts either a single float (``"0.5"``) or a bracketed, comma-separated
    list of floats (``"[0.25,0.5, 1.0]"``); spaces inside the list are ignored.

    Returns:
        A float, or a list of floats for the bracketed form.

    Raises:
        ValueError: if ``value`` is neither a float nor a bracketed list of floats.
    """
    try:
        return float(value)
    except ValueError:
        pass
    if not value.startswith('['):
        raise ValueError(f"cannot parse param_scale_factor {value!r}: "
                         "expected a float or a list like [0.25,0.5]")
    items = value.replace('[', '').replace(']', '').replace(' ', '').split(',')
    return [float(item) for item in items]


def parse_param_overrides(param_list):
    """Build the settings override dict from ``--param key value ...`` tokens.

    Tokens are consumed as consecutive (key, value) pairs; a repeated key keeps
    its last value. ``seed`` is converted to int, ``param_dirichlet_alpha`` to
    float and ``param_scale_factor`` via ``parse_scale_factor``; every other
    value stays a string.

    Args:
        param_list: the raw ``--param`` tokens, or ``None`` when not given.

    Returns:
        A dict of overrides (empty when ``param_list`` is ``None``).

    Raises:
        ValueError: odd number of tokens, or a value that fails conversion.
    """
    if param_list is None:
        return {}
    if len(param_list) % 2 != 0:
        raise ValueError(f"parameters do not match {param_list}")
    overrides = dict(zip(param_list[0::2], param_list[1::2]))
    if 'seed' in overrides:
        overrides['seed'] = int(overrides['seed'])
    if 'param_dirichlet_alpha' in overrides:
        overrides['param_dirichlet_alpha'] = float(overrides['param_dirichlet_alpha'])
    if 'param_scale_factor' in overrides:
        overrides['param_scale_factor'] = parse_scale_factor(overrides['param_scale_factor'])
    return overrides


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='FL commands')
    parser.add_argument("--path", type=str, nargs=1, default=["fl_setting.json"], help="FL settings file location path")
    parser.add_argument("--param", nargs='+', help="settings overrides as key value pairs, e.g. --param seed 1 algorithm FedAvg")
    parser.add_argument("--dry_run", action='store_true', help="build and initialize the server without training")
    args = parser.parse_args()

    try:
        override_dict = parse_param_overrides(args.param)
    except ValueError as err:
        # Report bad CLI input as a usage error (exit status 2) instead of a traceback.
        parser.error(str(err))

    print("override_dict \n{" + "\n".join("{!r}: {!r},".format(k, v) for k, v in override_dict.items()) + "}")
    run(args.path[0], args.dry_run, param_override_dict=override_dict)
