import json
import hashlib
import os
from datetime import datetime
import argparse
import logging


def dict_hash(dictionary):
    """Return the hex MD5 digest of *dictionary* serialized as canonical JSON.

    Keys are sorted before serialization, so dictionaries with equal contents
    produce the same digest regardless of insertion order.
    """
    canonical = json.dumps(dictionary, sort_keys=True)
    return hashlib.md5(canonical.encode()).hexdigest()

def _format_settings(settings):
    """Render *settings* as one ``'key': value,`` line per entry inside braces."""
    return "{" + "\n".join("{!r}: {!r},".format(k, v) for k, v in settings.items()) + "}"


def _resolve_settings(settings_path, param_override_dict):
    """Load the JSON settings file and flatten it for the selected algorithm/dataset.

    ``param__run`` is split off as the run configuration. Each remaining
    ``param__*`` group is merged into the top level when ``settings['algorithm']``
    is listed in its ``"for"`` entry, and each ``param_dataset__*`` group when
    ``settings['dataset']`` is. All such group keys are then removed so they do
    not take part in hashing. Finally *param_override_dict* (if any) is applied.

    Returns:
        ``(settings, run_cfg)``.
    """
    with open(settings_path) as fd:
        settings = json.load(fd)

    run_cfg = settings.pop('param__run')
    overrides = param_override_dict if param_override_dict is not None else {}

    # The selectors are overridden first because they decide which groups apply.
    for selector in ('algorithm', 'dataset'):
        if selector in overrides:
            settings[selector] = overrides[selector]

    selected = []
    for key, group in settings.items():
        if key.startswith('param_dataset__'):
            applies = settings['dataset'] in group['for']
        elif key.startswith('param__'):
            applies = settings['algorithm'] in group['for']
        else:
            continue
        if applies:
            selected.extend((name, value) for name, value in group.items() if name != 'for')

    grouped_keys = [key for key in settings if 'param__' in key or 'param_dataset__' in key]
    for key in grouped_keys:
        del settings[key]
    settings.update(selected)  # applied in order, later groups win
    settings.update(overrides)
    return settings, run_cfg


def _create_server(settings, run_path):
    """Instantiate the FL server for ``settings['algorithm']``.

    For ``'Centralized'`` the device counts in *settings* are forced to 1
    (mutating *settings*) before the hash is computed.

    Returns:
        ``(flserver, path_hash)``; the server is created with ``run_path + path_hash``.

    Raises:
        ValueError: unknown algorithm.
    """
    algorithm = settings['algorithm']
    path_hash = dict_hash(settings)

    if algorithm == 'FedAvg':
        from algorithms.fedavg import FedAvgSever
        flserver = FedAvgSever(run_path + path_hash)
    elif algorithm == 'CoCoFL':
        from algorithms.cocofl import CoCoFLServer
        flserver = CoCoFLServer(run_path + path_hash)
    elif algorithm == 'FedAvgDropDevices':
        from algorithms.fedavg_drop_devices import FedAvgDropDevicesServer
        flserver = FedAvgDropDevicesServer(run_path + path_hash)
    elif algorithm == 'HeteroFL':
        from algorithms.heterofl import HeteroFLServer
        flserver = HeteroFLServer(run_path + path_hash)
    elif algorithm == 'FjORD':
        from algorithms.fjord import FjordServer
        flserver = FjordServer(run_path + path_hash)
        flserver.set_fjord_p_values(settings['param_fjord_p'])
    elif algorithm == 'Centralized':
        from algorithms.centralized import CentralizedServer
        settings['n_active_devices'] = 1
        settings['n_devices'] = 1
        path_hash = dict_hash(settings)  # re-hash after the overrides above
        flserver = CentralizedServer(run_path + path_hash)
    else:
        raise ValueError(algorithm)
    return flserver, path_hash


def _apply_param_settings(flserver, settings):
    """Apply ``param_round_resources`` and ``param_async_averaging_ratio`` to *flserver*.

    Raises:
        ValueError: unknown ``param_round_resources`` profile.
    """
    from utils.resources import DeviceResources, Constant, Uniform

    def assign(resources, time_f, data_f, memory_f):
        # Each setter call gets a freshly constructed distribution object.
        for resource in resources:
            resource.set_time_selection_F(time_f())
            resource.set_data_selection_F(data_f())
            resource.set_memory_selection_F(memory_f())

    for key in settings:
        if key == 'param_round_resources':
            profile = settings[key]
            n_devices = settings['n_devices']
            if profile == 'strong_medium_weak_uniform':
                constraints = [DeviceResources() for _ in range(n_devices)]
                # Everyone starts "strong", then the middle and last thirds are weakened.
                assign(constraints,
                       lambda: Constant(1.0), lambda: Constant(1.0), lambda: Constant(1.0))
                assign(constraints[int(0.33 * n_devices):int(0.66 * n_devices)],
                       lambda: Constant(0.66), lambda: Uniform(0.5, 1.0), lambda: Constant(0.66))
                assign(constraints[int(0.66 * n_devices):],
                       lambda: Constant(0.33), lambda: Uniform(0.5, 1.0), lambda: Constant(0.33))
            elif profile.startswith('ablation__'):
                constraints = [DeviceResources() for _ in range(n_devices)]
                level = float(profile.split('__')[1])
                # First 10% of devices unconstrained, the rest limited to `level`.
                assign(constraints,
                       lambda: Constant(1.0), lambda: Constant(1.0), lambda: Constant(1.0))
                assign(constraints[int(0.1 * n_devices):],
                       lambda: Constant(level), lambda: Constant(1.0), lambda: Constant(level))
            else:
                raise ValueError(profile)
            flserver.set_device_constraints(constraints)
        elif key == 'param_async_averaging_ratio':
            flserver.async_averaging_ratio = settings[key]


def _configure_dataset(flserver, settings):
    """Attach the dataset named by ``settings['dataset']`` to *flserver*.

    Returns:
        Keyword arguments for the model constructors (e.g. ``num_classes``).

    Raises:
        ValueError: unknown dataset, including a CIFAR variant other than 10/100.
    """
    from torchvision import datasets, transforms

    dataset = settings['dataset']
    if 'CIFAR' in dataset:
        tf = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
        kwargs = {'download': True, 'transform': tf}
        if dataset.endswith('100'):
            flserver.set_dataset(datasets.CIFAR100, "/tmp/", **kwargs)
            return {'num_classes': 100}
        if dataset.endswith('10'):
            flserver.set_dataset(datasets.CIFAR10, "/tmp/", **kwargs)
            return {'num_classes': 10}
        raise ValueError(dataset)
    if 'CINIC10' in dataset:
        from utils.datasets.cinic10 import CINIC10
        tf = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.478, 0.472, 0.430), (0.242, 0.238, 0.258))])
        flserver.set_dataset(CINIC10, "/tmp/", download=True, transform=tf)
        return {'num_classes': 10}
    if 'XCHEST' in dataset:
        from utils.datasets.xchest import XCHEST
        tf = transforms.Compose([transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        flserver.set_dataset(XCHEST, "data/", download=True, transform=tf)
        flserver.is_unbalanced = True
        return {'num_classes': 2}
    if 'FEMNIST' in dataset:
        from utils.datasets.femnist import FEMNIST, femnist_to_cifar_format_transform
        tf = transforms.Compose([femnist_to_cifar_format_transform()])
        flserver.set_dataset(FEMNIST, "data/", transform=tf)
        return {'num_classes': 62}
    if 'IMDB' in dataset:
        from utils.datasets.imdb import IMDB
        flserver.set_dataset(IMDB, "data/", seq_len=512)
        if settings['algorithm'] in ['FjORD', 'HeteroFL']:
            flserver.drop_weakest = True
            print(f'[INFO]: {settings["algorithm"]} HAS TO DROP WEAKEST DEVICES')
        return {}
    if 'SHAKESPEARE' in dataset:
        from utils.datasets.shakespeare import SHAKESPEARE
        flserver.set_dataset(SHAKESPEARE, "data/")
        return {}
    raise ValueError(dataset)


def _require_model(model, expected, net):
    """Return *net* if *model* equals *expected*, else raise ``ValueError(model)``."""
    if model != expected:
        raise ValueError(model)
    return net


def _select_model(settings):
    """Pick the training network and the evaluation network for model/algorithm.

    Model families are matched by substring in this order: ResNet18, ResNet50,
    MobileNetLarge, MobileNet, DenseNet, TransformerSeq2Seq, Transformer.

    Returns:
        ``(net, net_eval)``.

    Raises:
        ValueError: unknown model, or a model name the algorithm does not support.
    """
    model = settings['model']
    algorithm = settings['algorithm']
    baseline_algorithms = ('Centralized', 'FedAvg', 'FedAvgDropDevices')
    subset_algorithms = ('FjORD', 'HeteroFL')

    if 'ResNet18' in model:
        from nets.Baseline.ResNet.resnet import ResNet18
        if algorithm in baseline_algorithms:
            return ResNet18, ResNet18
        if algorithm == 'CoCoFL':
            from nets.QuantizedNets.ResNet.resnet import QResNet18
            return _require_model(model, 'QResNet18', QResNet18), ResNet18
        if algorithm in subset_algorithms:
            from nets.SubsetNets.ResNet.resnet import SResNet18
            return _require_model(model, 'SResNet18', SResNet18), ResNet18
        raise ValueError(algorithm)

    if 'ResNet50' in model:
        from nets.Baseline.ResNet.resnet import ResNet50
        if algorithm in baseline_algorithms:
            return ResNet50, ResNet50
        if algorithm == 'CoCoFL':
            from nets.QuantizedNets.ResNet.resnet import QResNet50
            return _require_model(model, 'QResNet50', QResNet50), ResNet50
        if algorithm in subset_algorithms:
            from nets.SubsetNets.ResNet.resnet import SResNet50
            return _require_model(model, 'SResNet50', SResNet50), ResNet50
        raise ValueError(algorithm)

    if 'MobileNetLarge' in model:
        from nets.Baseline.MobileNet.mobilenet import MobileNetV2Large
        if algorithm in baseline_algorithms:
            return MobileNetV2Large, MobileNetV2Large
        if algorithm == 'CoCoFL':
            from nets.QuantizedNets.MobileNet.mobilenet import QMobileNetLarge
            return QMobileNetLarge, MobileNetV2Large
        if algorithm in subset_algorithms:
            from nets.SubsetNets.MobileNet.mobilenet import SMobileNetLarge
            return SMobileNetLarge, MobileNetV2Large
        raise ValueError(algorithm)

    if 'MobileNet' in model:
        from nets.Baseline.MobileNet.mobilenet import MobileNetV2 as baseline_MobileNetV2
        if algorithm in baseline_algorithms:
            return baseline_MobileNetV2, baseline_MobileNetV2
        if algorithm == 'CoCoFL':
            from nets.QuantizedNets.MobileNet.mobilenet import MobileNet, QMobileNet, FMobileNet
            from nets.QuantizedNets.MobileNet.mobilenet import (MobileNet_ablation, QMobileNet_ablation,
                                                                FMobileNet_ablation)
            variants = {
                'QMobileNet': QMobileNet,
                'MobileNet': MobileNet,
                'FMobileNet': FMobileNet,
                # Ablation models
                'MobileNet_ablation': MobileNet_ablation,
                'QMobileNet_ablation': QMobileNet_ablation,
                'FMobileNet_ablation': FMobileNet_ablation,
            }
            if model not in variants:
                raise ValueError(model)
            return variants[model], baseline_MobileNetV2
        if algorithm in subset_algorithms:
            from nets.SubsetNets.MobileNet.mobilenet import SMobileNet
            return _require_model(model, 'SMobileNet', SMobileNet), baseline_MobileNetV2
        raise ValueError(algorithm)

    if 'DenseNet' in model:
        from nets.Baseline.DenseNet.densenet import DenseNet40 as baseline_DenseNet40
        if algorithm in baseline_algorithms:
            return baseline_DenseNet40, baseline_DenseNet40
        if algorithm == 'CoCoFL':
            from nets.QuantizedNets.DenseNet.densenet import QDenseNet40
            return _require_model(model, 'QDenseNet', QDenseNet40), baseline_DenseNet40
        if algorithm in subset_algorithms:
            from nets.SubsetNets.DenseNet.densenet import SDenseNet40
            return _require_model(model, 'SDenseNet', SDenseNet40), baseline_DenseNet40
        raise ValueError(algorithm)

    if 'TransformerSeq2Seq' in model:
        from nets.Baseline.Transformer.transformer import TransformerSeq2Seq
        if algorithm in baseline_algorithms:
            return TransformerSeq2Seq, TransformerSeq2Seq
        if algorithm == 'CoCoFL':
            from nets.QuantizedNets.Transformer.transformer import QTransformerSeq2Seq
            return QTransformerSeq2Seq, TransformerSeq2Seq
        if algorithm in subset_algorithms:
            from nets.SubsetNets.Transformer.transformer import STransformerSeq2Seq
            return STransformerSeq2Seq, TransformerSeq2Seq
        raise ValueError(algorithm)

    if 'Transformer' in model:
        from nets.Baseline.Transformer.transformer import Transformer as baseline_Transformer
        if algorithm in baseline_algorithms:
            return _require_model(model, 'Transformer', baseline_Transformer), baseline_Transformer
        if algorithm == 'CoCoFL':
            from nets.QuantizedNets.Transformer.transformer import QTransformer
            return _require_model(model, 'QTransformer', QTransformer), baseline_Transformer
        if algorithm in subset_algorithms:
            from nets.SubsetNets.Transformer.transformer import STransformer
            return _require_model(model, 'STransformer', STransformer), baseline_Transformer
        raise ValueError(algorithm)

    raise ValueError(model)


def _configure_split(flserver, settings, run_cfg, run_dir, dry_run):
    """Set ``flserver.split_f`` from the split keyword contained in ``settings['dataset']``."""
    from utils.split import split_iid, split_noniid, split_noniid_resources, split_SHAKESPEARE_noniid_resources

    dataset = settings['dataset']
    seed = settings['seed']
    plot = run_cfg['plot'] if not dry_run else False  # split plotting is disabled for dry runs
    if 'noniid_resources_SHAKESPEARE' in dataset:
        flserver.split_f = split_SHAKESPEARE_noniid_resources(run_dir, plot, seed)
    elif 'noniid_resources' in dataset:
        flserver.split_f = split_noniid_resources(settings['param_dirichlet_alpha'], run_dir, plot, seed)
    elif 'noniid_' in dataset:
        flserver.split_f = split_noniid(settings['param_dirichlet_alpha'], run_dir, plot, seed)
    else:
        flserver.split_f = split_iid(run_dir, plot, seed)


def run(settings_path,
        dry_run,
        param_override_dict=None,
        callback_intermediate_accuracy=None,
        callback_final_accuracy=None):
    """Configure and run one federated-learning experiment from a JSON settings file.

    Current algorithm options: [fl_settings.json]

    *Centralized
    *FedAvg
    *-->CoCoFL
    *-->FedAvgDropDevices
    *-->HeteroFL
    *-->FjORD

    Args:
        settings_path: Path to the JSON settings file.
        dry_run: If true, build and initialize the server and return it without
            creating the run directory, training, or plotting.
        param_override_dict: Optional mapping applied on top of the file's settings.
            Its ``algorithm``/``dataset`` entries also decide which
            ``param__*``/``param_dataset__*`` groups get merged.
        callback_intermediate_accuracy: Stored as ``flserver.report_intermediate_f``.
        callback_final_accuracy: Stored as ``flserver.report_final_f``.

    Returns:
        The initialized server for a dry run, otherwise ``None``.

    Raises:
        ValueError: Unknown algorithm, dataset, model, or resource profile.
        NotImplementedError: Unknown optimizer.
    """
    settings, run_cfg = _resolve_settings(settings_path, param_override_dict)

    print(_format_settings(settings))
    assert settings['n_devices'] >= settings['n_active_devices'], "Cannot be more active devices than overall devices"

    # CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised,
    # hence before torch is imported below. A bare 'cuda' is left as-is.
    device = run_cfg['torch_device']
    if device.startswith('cuda:'):
        os.environ["CUDA_VISIBLE_DEVICES"] = device[len('cuda:'):]
        run_cfg['torch_device'] = 'cuda'

    import torch
    torch.manual_seed(settings['seed'])

    torch.set_num_threads(12)
    try:
        torch.set_num_interop_threads(12)
    except RuntimeError:
        # PyTorch permits setting this only once per process; a repeated call
        # to run() in the same process (e.g. tuning loops) would otherwise abort.
        pass

    run_path = "runs/" + settings['session_tag'] + '/run_'
    flserver, path_hash = _create_server(settings, run_path)
    run_dir = run_path + path_hash

    _apply_param_settings(flserver, settings)

    flserver.n_active_devices = settings['n_active_devices']
    flserver.n_devices = settings['n_devices']
    flserver.torch_device = run_cfg['torch_device']
    flserver.n_rounds = run_cfg['n_rounds']
    flserver.lr = settings['lr']
    flserver.set_seed(settings['seed'])
    flserver.report_intermediate_f = callback_intermediate_accuracy
    flserver.report_final_f = callback_final_accuracy
    flserver.lr_schedule = settings['lr_schedule'] if settings['lr_schedule'] else None
    flserver.progress_output = run_cfg['progress_bar']

    optimizer = settings['optimizer']
    if optimizer == 'adam':
        flserver.set_optimizer(torch.optim.Adam, {})
    elif optimizer == 'sgd':
        weight_decay = settings['weight_decay']
        sgd_kwargs = {'weight_decay': weight_decay} if weight_decay is not None else {}
        flserver.set_optimizer(torch.optim.SGD, sgd_kwargs)
    else:
        raise NotImplementedError(optimizer)

    cnn_args = _configure_dataset(flserver, settings)
    net, net_eval = _select_model(settings)
    _configure_split(flserver, settings, run_cfg, run_dir, dry_run)

    # Every device gets the same network class and the same kwargs object.
    n_devices = settings['n_devices']
    flserver.set_model([net] * n_devices, [cnn_args] * n_devices)
    flserver.set_model_evaluation(net_eval, cnn_args)

    if dry_run:
        print("DRY RUN PERFORMED SUCESSFULLY")
        print(f"Settings Hash: {path_hash}")
        flserver.initialize()
        return flserver

    try:
        os.makedirs(run_dir)
    except FileExistsError:
        # Directory name is the settings hash: keep the existing settings file.
        pass
    else:
        with open(run_dir + "/" + "fl_setting.json", "w") as fd:
            json.dump(settings, fd, indent=4)

    logging.basicConfig(format='%(asctime)s - %(message)s',
                        filename=run_dir + '/run.log', level=logging.INFO, filemode='w')
    logging.info('Started')
    print(f"Settings Hash: {path_hash}")
    logging.info(f"Settings Hash: {path_hash}")
    logging.info(_format_settings(settings))
    flserver.initialize()

    with open(run_dir + "/" + "t_start.json", "w") as fd:
        # year_month_day_hour_minute_second
        json.dump({"t_start": datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}, fd, indent=4)

    plot = run_cfg['plot'] is True
    if plot:
        import visualization.plots as plots
        flserver.set_plotting_callback(plots.plot_config, run_dir)

    try:
        flserver.run()
    except KeyboardInterrupt:
        # Ctrl-C ends training early; final plotting and the latest_run link still happen.
        pass

    if plot:
        plots.plot_config(run_dir)

    try:
        os.unlink("latest_run")
    except FileNotFoundError:
        pass
    os.symlink(run_dir, "latest_run")

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='FL commands')
    parser.add_argument("--path", type=str, nargs=1, default=["settings/shakespeare_experiments.json"], help="FL settings file location path")
    parser.add_argument("--param", nargs='+',
                        help="alternating key/value pairs overriding settings, e.g. --param seed 1 algorithm FedAvg")
    parser.add_argument("--dry_run", action='store_true',
                        help="build and initialize the server without training")
    args = parser.parse_args()

    override_dict = {}
    if args.param is not None:
        # Validate explicitly: an `assert` would vanish under `python -O`.
        if len(args.param) % 2 != 0:
            parser.error(f"parameters do not match {args.param}")
        override_dict = dict(zip(args.param[::2], args.param[1::2]))

    # CLI values arrive as strings; convert the known numeric ones.
    if 'seed' in override_dict:
        override_dict['seed'] = int(override_dict['seed'])
    if 'param_dirichlet_alpha' in override_dict:
        override_dict['param_dirichlet_alpha'] = float(override_dict['param_dirichlet_alpha'])

    print("override_dict \n{" + "\n".join("{!r}: {!r},".format(k, v) for k, v in override_dict.items()) + "}")
    run(args.path[0], args.dry_run, param_override_dict=override_dict)