"""
This script contains the backend code; there is no need to read it to implement new functionality.
It provides the functions that run the models in separate processes. These processes run on GPUs and are queried by processes that run only on CPUs.
"""

import dill
import inspect
import queue
import torch
import torch.multiprocessing as mp
from rich.console import Console
from time import time
from typing import Callable, Union

from configs import config

console = Console(highlight=False)

# Models are only discovered (and the multiprocessing manager only started) in the main process. Any other process
# that imports this module sees None for both.
list_models = None
manager = None
if mp.current_process().name == 'MainProcess':
    # No need to initialize the models inside each process
    import vision_models

    # Every class defined in vision_models that has BaseModel as a direct base class is a loadable model
    list_models = [
        cls
        for _, cls in inspect.getmembers(vision_models, inspect.isclass)
        if vision_models.BaseModel in cls.__bases__
    ]
    if config.multiprocessing:
        manager = mp.Manager()

import os
def find_best_gpu_device():
    """
    Return the index of the visible CUDA device with the most memory not yet allocated by this process.

    Available memory is estimated as ``total_memory - torch.cuda.memory_allocated(device)``.
    ``memory_allocated`` only counts tensors allocated by the *current* process, so memory used by other
    processes (for example other model consumers) is not taken into account.

    Returns 0 when no CUDA device is visible, instead of failing on an empty ``max``.
    NOTE(review): 0 is then only a placeholder index; confirm the model classes ignore ``gpu_number`` on CPU.
    """
    def get_available_memory(device):
        total = torch.cuda.get_device_properties(device).total_memory
        return total - torch.cuda.memory_allocated(device)

    # default=0 avoids "max() arg is an empty sequence" on machines without CUDA devices
    return max(range(torch.cuda.device_count()), key=get_available_memory, default=0)

def make_fn(model_class, process_name, counter):
    """
    Instantiate ``model_class`` and return a function that calls its ``forward`` method.

    model_class.name and process_name will be the same unless the same model is used in multiple processes, for
    different tasks. When they differ, ``process_name`` is passed to ``forward`` as a keyword argument.

    When ``model_class.to_batch`` is set and multiprocessing is disabled, each call is a single sample. Every
    argument is wrapped in a one-element list, parameters that were not given are filled with their default
    (also wrapped), and the first element of the model output is returned. With multiprocessing, arguments and
    outputs are passed through unchanged.

    :param model_class: model class; ``name``, ``to_batch`` and ``forward`` are read from it.
    :param process_name: name this model instance is registered under.
    :param counter: index of this model among the loaded ones. Not used any more (the device is chosen by
        ``find_best_gpu_device``); kept so existing callers keep working.
    :return: ``_function(*args, **kwargs)``, which returns the model output.
    """
    # Put the model on the device with the most memory not yet allocated, to reduce the risk of out of memory errors
    gpu_number = find_best_gpu_device()
    model_instance = model_class(gpu_number=gpu_number)

    # The signature of `forward` is fixed, so inspect it once here instead of on every batched call
    param_names, default_dict = [], {}
    if model_class.to_batch:
        full_arg_spec = inspect.getfullargspec(model_instance.forward)
        if full_arg_spec.defaults is not None:
            default_dict = dict(zip(full_arg_spec.args[-len(full_arg_spec.defaults):], full_arg_spec.defaults))
        param_names = full_arg_spec.args[1:]  # getfullargspec keeps `self` even for bound methods

    def _function(*args, **kwargs):
        if process_name != model_class.name:
            kwargs['process_name'] = process_name

        if model_class.to_batch and not config.multiprocessing:
            # Batchify the input. Model expects a batch. And later un-batchify the output.
            # NOTE(review): `process_name` was added to kwargs above, so on this path the model receives
            # `[process_name]` (a list), whereas in multiprocessing mode `collate` leaves it out and it is added above
            # as a plain string. Confirm which form batched models expect.
            args = [[arg] for arg in args]
            kwargs = {k: [v] for k, v in kwargs.items()}

            # The defaults that are not in args or kwargs also need to be listified
            for arg_name in set(param_names[len(args):]) - set(kwargs):
                kwargs[arg_name] = [default_dict[arg_name]]

        out = model_instance.forward(*args, **kwargs)
        if model_class.to_batch and not config.multiprocessing:
            out = out[0]
        return out

    return _function


# Module globals bound by init_vision_models() and read by forward(). A `global` statement at module scope has no
# effect; it only lists these names in one place.
global make_fn_process, queues_in, consumers, finish_all_consumers
def init_vision_models():
    """
    Load every model enabled in ``config.load_models`` and bind the module globals used by ``forward``.

    With ``config.multiprocessing``:
      * ``make_fn_process(model_class, process_name, counter)`` builds the target function of a consumer process.
        Batched models (``to_batch``) collect requests for up to ``seconds_collect_data`` seconds or until
        ``max_batch_size`` requests have arrived, run them as one batch, and reply on each request's own queue.
        Other models serve one request at a time.
      * In the main process, one consumer process is started per enabled model. ``queues_in`` maps the process name
        to the manager queue that consumer reads from, and ``consumers`` maps it to the ``mp.Process``.
      * In any other process, ``queues_in`` is set to None (``consumers`` is not bound).
      * ``finish_all_consumers()`` sends the ``None`` stop sentinel to every consumer and joins them.
      * If a model call fails inside a consumer, the error is printed and the affected caller(s) receive ``None``,
        so that no caller blocks forever waiting for its reply.

    Without multiprocessing, ``consumers`` maps the process name to the in-process callable built by ``make_fn``,
    ``queues_in`` is None, and ``finish_all_consumers()`` does nothing.
    """
    global make_fn_process, queues_in, consumers, finish_all_consumers
    if config.multiprocessing:

        def make_fn_process(model_class, process_name, counter):
            """Return the target function of the consumer process that serves `process_name`."""

            if model_class.to_batch:
                seconds_collect_data = model_class.seconds_collect_data  # Window of seconds to group inputs
                max_batch_size = model_class.max_batch_size

                def _function(queue_in):

                    fn = make_fn(model_class, process_name, counter)

                    to_end = False
                    while True:
                        # Collect requests until the time window closes, the batch is full, or the stop sentinel
                        # (None) arrives
                        start_time = time()
                        time_left = seconds_collect_data
                        batch_inputs = []
                        batch_queues = []
                        while time_left > 0 and len(batch_inputs) < max_batch_size:
                            try:
                                received = queue_in.get(timeout=time_left)
                            except queue.Empty:  # Time-out expired
                                break
                            if received is None:
                                to_end = True
                                break
                            batch_inputs.append(received[0])
                            batch_queues.append(received[1])
                            time_left = seconds_collect_data - (time() - start_time)

                        if batch_inputs:
                            try:
                                batch_kwargs = collate(batch_inputs, model_class.forward)
                                outs = list(fn(**batch_kwargs))
                            except Exception as e:
                                # Do not let one bad batch kill the consumer: its callers would wait forever
                                print(f'Error in {process_name} model: {e!r}')
                                outs = []

                            n_replied = 0
                            try:
                                for out, qu in zip(outs, batch_queues):
                                    qu.put(out)
                                    n_replied += 1
                            except Exception as e:
                                print(f'Error sending {process_name} outputs: {e!r}')
                            # Each caller waits for exactly one reply on its queue. Send None to every caller that did
                            # not get an output (model error, failed put, or fewer outputs than inputs), but never a
                            # second item to a caller that already got its output.
                            for qu in batch_queues[n_replied:]:
                                qu.put(None)

                        if to_end:
                            print(f'{process_name} model exiting')
                            break

            else:
                def _function(queue_in):
                    fn = make_fn(model_class, process_name, counter)
                    while True:
                        received = queue_in.get()
                        if received is None:
                            print(f'{process_name} exiting')
                            return
                        (args, kwargs), queue_out = received
                        try:
                            out = fn(*args, **kwargs)
                        except Exception as e:
                            # Reply anyway: the caller is blocked on queue_out until it receives one item
                            print(f'Error in {process_name} model: {e!r}')
                            out = None
                        try:
                            queue_out.put(out)
                        except Exception as e:
                            print(f'Error sending {process_name} output: {e!r}')
                            queue_out.put(None)

            return _function


        if mp.current_process().name == 'MainProcess':
            queues_in = dict()
            consumers = dict()

            counter_ = 0
            for model_class_ in list_models:
                for process_name_ in model_class_.list_processes():
                    if process_name_ in config.load_models and config.load_models[process_name_]:
                        queue_in_ = manager.Queue()  # For transfer of data from producer to consumer
                        queues_in[process_name_] = queue_in_

                        fn_process = make_fn_process(model_class_, process_name_, counter_)
                        # Otherwise, it is not possible to pickle the _function (not defined at top level)
                        aux = mp.reducer.dump
                        mp.reducer.dump = dill.dump
                        try:
                            consumer = mp.Process(target=fn_process, kwargs={'queue_in': queue_in_})
                            consumer.start()
                        finally:
                            # Restore the stock pickler even if starting the process fails
                            mp.reducer.dump = aux
                        consumers[process_name_] = consumer

                        counter_ += 1

        else:
            queues_in = None


        def finish_all_consumers():
            # Wait for consumers to finish
            for q_in in queues_in.values():
                q_in.put(None)
            for cons in consumers.values():
                cons.join()

    else:

        consumers = dict()

        counter_ = 0
        for model_class_ in list_models:
            for process_name_ in model_class_.list_processes():
                if process_name_ in config.load_models and config.load_models[process_name_]:
                    consumers[process_name_] = make_fn(model_class_, process_name_, counter_)
                    counter_ += 1

        queues_in = None

        def finish_all_consumers():
            pass


def forward(model_name, *args, queues=None, **kwargs):
    """
    Send a call to the consumer registered as ``model_name`` (which runs its model's ``forward``) and return the result.

    :param model_name: process name of a loaded model.
    :param args: positional arguments passed on to the model.
    :param queues: multiprocessing only. Optional ``(consumer_queues_in, queue_results)`` pair. If
        ``consumer_queues_in`` is given, it is used instead of the module-level ``queues_in``. If ``queue_results``
        is given, it is reused to receive the result instead of creating a new manager queue.
    :param kwargs: keyword arguments passed on to the model.
    :return: the model output (with multiprocessing: the item the consumer put on the results queue).
    :raises KeyError: if no loaded model is named ``model_name``.
    """
    def unknown_model_error(options):
        return KeyError(f'No model named {model_name}. '
                        f'The available models are: {options}. Make sure to activate it in the configs files')

    if not config.multiprocessing:
        try:
            consumer = consumers[model_name]
        except KeyError as e:
            raise unknown_model_error(list(consumers.keys())) from e
        # Called outside the `try`, so a KeyError raised inside the model is not reported as an unknown model
        return consumer(*args, **kwargs)

    consumer_queues_in, queue_results = (None, None) if queues is None else queues
    available_queues = queues_in if consumer_queues_in is None else consumer_queues_in
    try:
        consumer_queue_in = available_queues[model_name]
    except KeyError as e:
        raise unknown_model_error(list(available_queues.keys())) from e
    if queue_results is None:
        # No results queue was provided, so create one. Doing this on every call is inefficient; callers can pass a
        # reusable queue through `queues`.
        queue_results = manager.Queue()
    consumer_queue_in.put([(args, kwargs), queue_results])
    return queue_results.get()  # Block until the consumer replies


def collate(batch_inputs, fn):
    """
    Combine a list of ``(args, kwargs)`` calls into a single dictionary of keyword arguments for ``fn``.

    The dictionary has one key per parameter of ``fn``, excluding the first one (``self``) and the special
    ``process_name`` parameter, which is filled in later. Each value is a list with one entry per sample: the
    positional value if given, otherwise the keyword value, otherwise ``fn``'s default.

    Keyword arguments whose name is not a parameter of ``fn`` are ignored.

    :param batch_inputs: sequence of ``(args, kwargs)`` pairs, one per sample.
    :param fn: function whose first parameter is ``self`` (e.g. an unbound ``forward`` method).
    :return: dict mapping parameter name to the list of per-sample values; the lists are empty if
        ``batch_inputs`` is empty.
    :raises Exception: if a sample provides more arguments than ``fn`` accepts.
    :raises TypeError: if a sample omits a parameter that has no default.
    """
    full_arg_spec = inspect.getfullargspec(fn)
    defaults = full_arg_spec.defaults or ()
    # Defaults belong to the last len(defaults) parameters
    default_dict = dict(zip(full_arg_spec.args[-len(defaults):], defaults)) if defaults else {}
    default_dict.pop('process_name', None)  # process_name is a special parameter filled in later

    args_list = full_arg_spec.args[1:]  # Remove self

    # process_name is a special parameter filled in later
    if 'process_name' in args_list:
        assert args_list[-1] == 'process_name', 'process_name must be the last argument'
        args_list.remove('process_name')

    kwargs_output = {k: [] for k in args_list}
    for args, kwargs in batch_inputs:
        if len(args) + len(kwargs) > len(args_list):
            raise Exception(
                f'You provided more arguments than the function {fn.__name__} accepts, or some kwargs/args '
                f'overlap. The arguments are: {args_list}')
        for j, arg_name in enumerate(args_list):
            if j < len(args):
                value = args[j]
            elif arg_name in kwargs:
                value = kwargs[arg_name]
            elif arg_name in default_dict:
                value = default_dict[arg_name]
            else:
                raise TypeError(f'You did not provide a value for the argument {arg_name}.')
            kwargs_output[arg_name].append(value)

    return kwargs_output
