#!/usr/bin/env python3

from typing import List, Optional

import torch
from linear_operator import to_linear_operator
from linear_operator.operators import CatLinearOperator
from torch.nn.parallel import DataParallel

from .. import settings
from .kernel import Kernel


class MultiDeviceKernel(DataParallel, Kernel):
    r"""
    Allocates the covariance matrix on distributed devices, e.g. multiple GPUs.

    Args:
        base_kernel: Base kernel to distribute
        device_ids: list of `torch.device` objects to place kernel chunks on
        output_device: Device where outputs will be placed
    """

    def __init__(
        self,
        base_kernel: Kernel,
        device_ids: List[torch.device],
        output_device: Optional[torch.device] = None,
        create_cuda_context: Optional[bool] = True,
        **kwargs,
    ):
        # Need to warm up each GPU otherwise scattering in forward will be
        # EXTREMELY slow. This memory will be available as soon as we leave __init__
        if create_cuda_context:
            for d in device_ids:
                _ = torch.tensor([], device=d)

        DataParallel.__init__(self, module=base_kernel, device_ids=device_ids, output_device=output_device, dim=-2)

        self.output_device = output_device if output_device else device_ids[0]

        self.__cached_x1 = torch.empty(1)
        self.__cached_x2 = torch.empty(1)

    @property
    def base_kernel(self):
        return self.module

    def forward(self, x1, x2, diag=False, **kwargs):
        if diag:
            return self.module.forward(x1, x2, diag=True, **kwargs).to(self.output_device)

        if x1.size(-2) < len(self.device_ids) + 1:
            return self.module.forward(x1, x2, diag=diag, **kwargs).to(self.output_device)

        if not x1.device == self.__cached_x1.device or not torch.equal(x1, self.__cached_x1):
            self._x1_scattered, self._kwargs = self.scatter((x1,), kwargs, self.device_ids)
            self.__cached_x1 = x1

        if not x2.device == self.__cached_x2.device or not torch.equal(x2, self.__cached_x2):
            self._x2_subs = [x2.to(x1_[0].device) for x1_ in self._x1_scattered]
            self.__cached_x2 = x2

        inputs = tuple((x1_[0], x2_) for x1_, x2_ in zip(self._x1_scattered, self._x2_subs))

        if not self.device_ids:
            return self.module.forward(*inputs, **self._kwargs)

        if len(self.device_ids) == 1:
            return self.module.forward(*inputs[0], **self._kwargs[0])

        # JIT modules can't be pickled and replicated yet
        # But reinitializing the distance_module every forward pass
        # is slow and should be removed once JIT modules can be pickled
        def set_distance_module_to_none(module):
            if hasattr(module, "distance_module"):
                module.distance_module = None

        self.module.apply(set_distance_module_to_none)
        # Can't cache the replication because the base kernel module can change every time (e.g. param updates)
        replicas = self.replicate(self.module, self.device_ids[: len(inputs)])

        # TODO: parallel_apply might be too heavyweight in some cases?
        with settings.lazily_evaluate_kernels(False):
            outputs = self.parallel_apply(replicas, inputs, self._kwargs)

        return self.gather(outputs, self.output_device)

    def gather(self, outputs, output_device):
        return CatLinearOperator(
            *(to_linear_operator(o) for o in outputs), dim=self.dim, output_device=self.output_device
        )

    def num_outputs_per_input(self, x1, x2):
        return self.base_kernel.num_outputs_per_input(x1, x2)
