"""
Licensed Materials - Property of IBM
Restricted Materials of IBM
20190891
© Copyright IBM Corp. 2021 All Rights Reserved.
"""
#!/usr/bin/env python3

"""
Aggregator is an application which will allow users
to execute controlled Federated Learning tasks
"""

import re
import os
import sys
import logging

# Append the absolute path of the current working directory to sys.path
# (only if it is not already there). This runs before the `ibmfl` imports
# below — presumably so they resolve when the script is launched from the
# project root; confirm against the deployment layout.
fl_path = os.path.abspath('.')
if fl_path not in sys.path:
    sys.path.append(fl_path)

from ibmfl.aggregator.states import States
from ibmfl.util.config import configure_logging_from_file, \
    get_aggregator_config
from ibmfl.connection.route_declarations import get_aggregator_router
from ibmfl.connection.router_handler import Router
from ibmfl.evidencia.util.config import config_to_json_str
import copy

# Set up the module-level logger and a default root logging configuration:
# INFO level, records formatted as
# "<date time>.<millis> <level> <logger name> :: <message>".
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s.%(msecs)03d %(levelname)-6s %(name)s :: %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')



class Aggregator(object):
    """
    Aggregator class to create an aggregator application.

    The constructor wires together the components named in the aggregator
    configuration (data handlers, models, connection, protocol handler,
    router, optional evidencia recorder, fusion handler and optional
    metrics handler). The remaining methods delegate to the connection,
    protocol handler and fusion handler and log the outcome.
    """

    # Letters used to derive the per-tier data section names
    # ('data_DA', 'data_DB', ...); tier 0 maps to 'A'.
    _TIER_SUFFIXES = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'

    @staticmethod
    def _tier_data_key(tier_id):
        """
        Return the configuration key holding the data section of a tier.

        :param tier_id: zero-based tier index
        :type tier_id: `int`
        :return: configuration key, e.g. ``'data_DA'`` for tier 0
        :rtype: `str`
        :raises ValueError: if `tier_id` has no corresponding suffix letter
        """
        suffixes = Aggregator._TIER_SUFFIXES
        if not 0 <= tier_id < len(suffixes):
            raise ValueError(
                'tier id {} is out of range; at most {} tiers are '
                'supported'.format(tier_id, len(suffixes)))
        return 'data_D{}'.format(suffixes[tier_id])

    def __init__(self, **kwargs):
        """
        Initializes an `Aggregator` object

        Errors raised while building the components are logged and
        swallowed; in that case the attributes that were not reached keep
        their initial value (`None` or an empty list).

        :param kwargs: keyword arguments forwarded unchanged to \
        `get_aggregator_config` (the command line entry point of this \
        module passes `config_file`, the path to the yaml configuration)
        """
        configure_logging_from_file()

        cls_config = get_aggregator_config(**kwargs)
        # The full configuration is only useful for debugging; keep it out
        # of stdout.
        logger.debug('cls_config: %s', cls_config)

        # Give every attribute a defined value up front so that a partially
        # failed initialization leaves the object in a predictable state.
        self.data_handler = None
        self.data_handler_DA = None
        self.data_handler_DB = None
        self.data_handlers = []
        self.fl_model = None
        self.fl_models = []
        self.shapley_value_test_model = None
        self.hyperparams = None
        self.connection = None
        self.proto_handler = None
        self.router = None
        self.evidencia = None
        self.fusion = None
        self.metrics_handler = None

        data_config = cls_config.get('data')
        model_config = cls_config.get('model')
        connection_config = cls_config.get('connection')
        ph_config = cls_config.get('protocol_handler')
        fusion_config = cls_config.get('fusion')
        mh_config = cls_config.get('metrics')
        evidencia_config = cls_config.get('evidencia')
        number_of_tiers = cls_config.get('hyperparams').get('global').get('tiers')

        try:
            # Load data (optional field)
            # - in some cases the aggregator doesn't have data for testing purposes
            if data_config:
                data_cls_ref = data_config.get('cls_ref')
                data_info = data_config.get('info')
                self.data_handler = data_cls_ref(data_config=data_info)

                # One extra data handler per tier, configured in the
                # sections 'data_DA', 'data_DB', ...
                for tier_id in range(0, number_of_tiers):
                    logger.info('loading data for tier {}'.format(tier_id))
                    tier_key = self._tier_data_key(tier_id)
                    tier_config = cls_config.get(tier_key)
                    logger.info('data_config_DA: {}'.format(tier_config))
                    if tier_config is None:
                        raise ValueError(
                            "configuration section '{}' required for tier "
                            "{} is missing".format(tier_key, tier_id))
                    tier_cls_ref = tier_config.get('cls_ref')
                    tier_info = tier_config.get('info')
                    # data_handler_DA ends up referring to the last tier's
                    # handler, as before.
                    self.data_handler_DA = tier_cls_ref(data_config=tier_info)
                    self.data_handlers.append(self.data_handler_DA)

            # Read and create model (optional field)
            # In some cases aggregator doesn't need to load the model:
            if model_config:
                model_cls_ref = model_config.get('cls_ref')
                spec = model_config.get('spec')
                model_info = model_config.get('info')
                # One independent model instance per tier.
                self.fl_models.extend(
                    model_cls_ref('', spec, info=model_info)
                    for _ in range(0, number_of_tiers))
                self.fl_model = model_cls_ref('', spec, info=model_info)

                self.shapley_value_test_model = model_cls_ref(
                    '', spec, info=model_info)
            logger.debug('self.fl_models: %s', self.fl_models)

            # Load hyperparams
            self.hyperparams = cls_config.get('hyperparams')
            connection_cls_ref = connection_config.get('cls_ref')
            connection_info = connection_config.get('info')
            connection_synch = connection_config.get('sync')
            # None when 'max_timeout' is not configured.
            max_timeout = self.hyperparams.get('global').get('max_timeout')

            self.connection = connection_cls_ref(connection_info)
            self.connection.initialize_sender()

            ph_cls_ref = ph_config.get('cls_ref')
            ph_info = ph_config.get('info')
            self.proto_handler = ph_cls_ref(self.connection.sender,
                                            connection_synch,
                                            max_timeout,
                                            info=ph_info)

            self.router = Router()
            get_aggregator_router(self.router, self.proto_handler)

            if evidencia_config:
                evidencia_cls_ref = evidencia_config.get('cls_ref')
                if 'info' in evidencia_config:
                    evidencia_info = evidencia_config.get('info')
                    self.evidencia = evidencia_cls_ref(evidencia_info)
                else:
                    self.evidencia = evidencia_cls_ref()

            fusion_cls_ref = fusion_config.get('cls_ref')
            fusion_info = fusion_config.get('info')
            self.fusion = fusion_cls_ref(
                self.hyperparams,
                self.proto_handler,
                data_handler=self.data_handler,
                data_handlers=self.data_handlers,
                fl_models=self.fl_models,
                shapley_value_test_model=self.shapley_value_test_model,
                evidencia=self.evidencia,
                info=fusion_info)
            if mh_config:
                mh_cls_ref = mh_config.get('cls_ref')
                mh_info = mh_config.get('info')
                self.metrics_handler = mh_cls_ref(info=mh_info)
                mh = self.fusion.metrics_manager
                mh.register(self.metrics_handler.handle)

            self.connection.initialize_receiver(router=self.router)

        except Exception:
            # Best effort by design: the failure is logged (with traceback)
            # and the constructor still returns.
            logger.exception(
                'Error occurred while loading aggregator configuration')

        else:
            logger.info("Aggregator initialization successful")
            if self.evidencia:
                self.evidencia.add_claim(
                    "configuration",
                    "'{}'".format(config_to_json_str(cls_config)))

    def start(self):
        """
        Start the aggregator connection so that parties can connect
        to register. Errors are logged, not raised.

        :param: None
        :return: None
        """
        try:
            self.connection.start()
        except Exception:
            logger.exception("Error occurred during start")
        else:
            logger.info("Aggregator start successful")

    def stop(self):
        """
        Ask the protocol handler to stop the parties, then stop the
        aggregator connection. Errors are logged, not raised.

        :param: None
        :return: None
        """
        try:
            self.proto_handler.stop_parties()
            self.connection.stop()
        except Exception:
            logger.exception("Error occurred during stop")
        else:
            logger.info("Aggregator stop successful")

    def start_training(self):
        """
        Start federated learning training: initialize the fusion handler
        and then run its tier-wise global training.

        :param: None
        :return: `True` if both steps completed, `False` if either raised
        :rtype: `boolean`
        """
        logger.info('Initiating Global Training.')
        try:
            self.fusion.initialization()
        except Exception:
            logger.exception('Exception occurred during the initialization '
                             'of the global training.')
            return False
        try:
            self.fusion.start_global_training_by_tier()
        except Exception:
            logger.exception('Exception occurred while training.')
            return False
        logger.info('Finished Global Training')
        return True

    def save_model(self):
        """
        Request all parties to save models. Errors are logged, not raised.
        """
        logger.info('Initiating save model request.')
        try:
            self.fusion.save_parties_models()
        except Exception:
            logger.exception('Exception occurred during save model.')
        else:
            logger.info('Finished save requests')

    def eval_model(self):
        """
        Request all parties to print evaluations. Errors are logged,
        not raised.
        """
        logger.info('Initiating evaluation requests.')
        try:
            self.fusion.evaluate_model()
        except Exception:
            logger.exception('Exception occurred during party evaluations.')
        else:
            logger.info('Finished eval requests')

    def model_synch(self):
        """
        Send global model to the parties. Errors are logged, not raised.
        """
        logger.info('Initiating global model sync requests.')
        try:
            self.fusion.send_global_models()
        except Exception:
            logger.exception('Exception occurred during sync model.')
        else:
            logger.info('Finished sync model requests')


if __name__ == '__main__':
    # Interactive entry point: builds an Aggregator from the yaml
    # configuration given as the single command line argument, starts it,
    # and then executes the commands START / STOP / TRAIN / SAVE / EVAL /
    # SYNC read line by line from stdin.
    if len(sys.argv) != 2:
        logging.error('Please provide yaml configuration')
        # Without exactly one argument there is nothing to load.
        sys.exit(1)

    config_file = sys.argv[1]
    if not os.path.isfile(config_file):
        logging.error("config file '{}' does not exist".format(config_file))
        sys.exit(1)

    agg = Aggregator(config_file=config_file)
    # Aggregator.__init__ logs and swallows configuration errors; without a
    # protocol handler none of the commands below can work.
    if getattr(agg, 'proto_handler', None) is None:
        logging.error('Aggregator initialization failed')
        sys.exit(1)

    agg.proto_handler.state = States.CLI_WAIT
    logging.info("State: " + str(agg.proto_handler.state))
    # Start server
    logging.info("Starting server")
    agg.start()
    # Loop accepting user commands until STOP, a failed TRAIN, or EOF
    while True:
        msg = sys.stdin.readline()
        # TODO: move it to Aggregator
        if not msg:
            # readline() returns '' only at end of input; without this
            # check the loop would spin forever once stdin is closed.
            logging.info("End of input reached, stopping aggregator")
            agg.proto_handler.state = States.STOP
            logging.info("State: " + str(agg.proto_handler.state))
            agg.stop()
            break

        if msg.startswith('START'):
            agg.proto_handler.state = States.CLI_WAIT
            logging.info("State: " + str(agg.proto_handler.state))
            # Start server
            agg.start()

        elif msg.startswith('STOP'):
            agg.proto_handler.state = States.STOP
            logging.info("State: " + str(agg.proto_handler.state))
            agg.stop()
            break

        elif msg.startswith('TRAIN'):
            agg.proto_handler.state = States.TRAIN
            logging.info("State: " + str(agg.proto_handler.state))
            success = agg.start_training()
            if not success:
                # A failed training run shuts the aggregator down.
                agg.stop()
                break

        elif msg.startswith('SAVE'):
            agg.proto_handler.state = States.SAVE
            logging.info("State: " + str(agg.proto_handler.state))
            agg.save_model()

        elif msg.startswith('EVAL'):
            agg.proto_handler.state = States.EVAL
            logging.info("State: " + str(agg.proto_handler.state))
            agg.eval_model()

        elif msg.startswith('SYNC'):
            agg.proto_handler.state = States.SYNC
            logging.info("State: " + str(agg.proto_handler.state))
            agg.model_synch()

    sys.exit()
