from copy import deepcopy
from typing import Tuple, List
import torch.multiprocessing as mp

from fling.component.client import ClientTemplate


def _client_trainer(client: ClientTemplate, kwargs: dict) -> Tuple:
    # This is the function that each client will execute train function.
    # It will receive a task and its arguments, execute it, and return its result and the updated client.
    res = client.train(**kwargs)
    return res, client


def _client_tester(client: ClientTemplate, kwargs: dict) -> Tuple:
    # This is the function that each client will execute test function.
    # It will receive a task and its arguments, execute it, and return its result and the updated client.
    res = client.test(**kwargs)
    return res, client


def _client_finetuner(client: ClientTemplate, kwargs: dict) -> Tuple:
    # This is the function that each client will execute finetune function.
    # It will receive a task and its arguments, execute it, and return its result and the updated client.
    res = client.finetune(**kwargs)
    return res, client


# Registry mapping a task name (as passed to a launcher's ``launch``) to the worker function that runs it.
op2func = {
    'train': _client_trainer,
    'test': _client_tester,
    'finetune': _client_finetuner,
}


def copy_attributes(src: object, dst: object) -> None:
    r"""
    Overview:
        Assign every instance attribute of ``src`` onto ``dst``.
        The attribute names are taken from ``src.__dict__``; each value is read with ``getattr`` and
        written with ``setattr``, so the values are shared between both objects, not duplicated.
        Attributes that exist only on ``dst`` are left untouched.
        ``src`` and ``dst`` are expected to be instances of the same class.
    Arguments:
        src: The object whose attributes are copied.
        dst: The object whose attributes are overwritten with those of ``src``.
    """
    for name in src.__dict__:
        value = getattr(src, name)
        setattr(dst, name, value)


class SerialLauncher:
    r"""
    Overview:
        Use one process to serially execute operations on all clients.
    """

    def launch(self, clients: List[ClientTemplate], task_name: str, **kwargs) -> List:
        r"""
        Overview:
            Launch the task on each client, one after another, in the current process.
            The operation is called on the client objects themselves, so any state change made by
            the task is already present on the objects in ``clients`` when this method returns.
        Arguments:
            clients: Clients to be launched.
            task_name: Task name of the operation in each client. It must be a key of ``op2func``.
            kwargs: Arguments required by corresponding operations (e.g. train, test, finetune).
                The same keyword arguments are passed to every client.
        Returns:
            loggers: A list, each element corresponds to the logger generated by one client,
                in the same order as ``clients``.
        Raises:
            ValueError: If ``task_name`` is not a registered operation.
        """
        # Get the operation function according to the task name.
        try:
            op_func = op2func[task_name]
        except KeyError:
            # The message already names the offending task, so the KeyError context is suppressed.
            raise ValueError(f'Unrecognized task name: {task_name}') from None

        # Each operation returns ``(logger, client)``. The client is the very object that was passed in,
        # so only the logger needs to be kept.
        return [op_func(client, kwargs)[0] for client in clients]


class MultiProcessLauncher:
    r"""
    Overview:
        Accelerate the operations on clients by distributing them over a pool of worker processes.
        The updated clients returned by the workers are copied back onto the caller's original
        client objects, so the caller observes the result of the operation.
    """

    def __init__(self, num_proc: int):
        r"""
        Overview:
            Initialization for launcher.
        Arguments:
            num_proc: Number of processes used.
        """
        self.num_proc = num_proc

    def launch(self, clients: List, task_name: str, **kwargs) -> List:
        r"""
        Overview:
            Launch the tasks in each client using a process pool of ``num_proc`` workers.
        Arguments:
            clients: Clients to be launched. Their attributes are overwritten in place with the
                attributes of the updated clients returned by the workers.
            task_name: Task name of the operation in each client. It must be a key of ``op2func``.
            kwargs: Arguments required by corresponding operations (e.g. train, test, finetune).
                The same keyword arguments are passed to every client.
        Returns:
            loggers: A list, each element corresponds to the logger generated by one client,
                in the same order as ``clients``.
        Raises:
            ValueError: If ``task_name`` is not a registered operation.
        """
        # Get the operation function according to the task name.
        try:
            op_func = op2func[task_name]
        except KeyError:
            # The message already names the offending task, so the KeyError context is suppressed.
            raise ValueError(f'Unrecognized task name: {task_name}') from None

        # Each task is a tuple that contains the client and the arguments of the operation.
        tasks = [(client, kwargs) for client in clients]
        if not tasks:
            # Nothing to run: do not spawn worker processes for no work.
            return []

        with mp.Pool(self.num_proc) as pool:
            # ``starmap`` blocks until every task has finished and returns the results in task order,
            # so leaving the ``with`` block (which terminates the pool) is safe afterwards.
            results = pool.starmap(op_func, tasks)

        # Split every ``(logger, updated_client)`` result and write the updated state back to the
        # original client of the same position.
        loggers = []
        for (client, _), (logger, new_client) in zip(tasks, results):
            # Sanity check that the result really belongs to this client.
            assert new_client.client_id == client.client_id
            copy_attributes(src=new_client, dst=client)
            loggers.append(logger)

        return loggers


def get_launcher(args: dict) -> object:
    r"""
    Overview:
        Create the launcher described by ``args.launcher``.
        The entry ``name`` selects the launcher type: ``'serial'`` or ``'multiprocessing'``.
        For ``'multiprocessing'``, all remaining entries are passed to the constructor of
        ``MultiProcessLauncher`` as keyword arguments.
    Arguments:
        args: The input configurations. The launcher settings are read from its ``launcher`` attribute.
    Returns:
        Corresponding launcher.
    Raises:
        ValueError: If the launcher name is not recognized.
    """
    # Work on a private copy, otherwise ``pop()`` would remove ``name`` from the caller's configuration.
    config = deepcopy(args.launcher)
    kind = config.pop('name')

    # Build different types of launchers.
    if kind == 'serial':
        return SerialLauncher()
    if kind == 'multiprocessing':
        return MultiProcessLauncher(**config)
    raise ValueError(f'Unrecognized launcher type: {kind}')
