from typing import Optional, Generic, TypeVar, Union

import torch

from XXX.uib.utils.safe_module import SafeModule

# Module-level state shared by stochastic_pipeline / stochastic_forward (both
# declare these `global`).
# NOTE(review): plain module globals with no locking, so presumably not safe to
# drive concurrent stochastic runs from several threads — confirm before doing so.

# Number of samples currently folded into the leading (flattened) dimension of
# the active stochastic run. Nested stochastic_forward calls multiply it, and it
# is reset to 1 when the outermost run finishes.
_current_num_samples: Optional[int] = 1
# True while a stochastic_pipeline / stochastic_forward call is executing; used
# to tell nested calls (which return flat outputs) from the outermost call.
_stochastic_run: bool = False


def get_current_num_samples():
    """Return the sample count tracked for the active stochastic run.

    Outside of any run this is the reset value, 1.
    """
    sample_count = _current_num_samples
    return sample_count


def in_stochastic_run():
    """Report whether a stochastic_pipeline / stochastic_forward call is in progress."""
    active = _stochastic_run
    return active


def unflatten_tensor(input: torch.Tensor, k: int):
    """Split the leading dimension of ``input`` into ``(-1, k)``.

    A tensor of shape ``(B * k, *rest)`` is returned as a view of shape
    ``(B, k, *rest)``. ``view`` raises if the leading size is not divisible by ``k``.
    """
    trailing_dims = tuple(input.shape[1:])
    return input.view(-1, k, *trailing_dims)


def flatten_tensor(expanded_input: torch.Tensor):
    """Merge the first two dimensions: ``(B, k, *rest)`` -> ``(B * k, *rest)``.

    This is the inverse of ``unflatten_tensor``. Rows belonging to the same
    leading index stay adjacent in the result.
    """
    return torch.flatten(expanded_input, start_dim=0, end_dim=1)


def expand_tensor(input: torch.Tensor, k: int) -> torch.Tensor:
    """Insert a sample dimension of size ``k`` right after the leading dimension.

    A tensor of shape ``(B, *rest)`` becomes ``(B, k, *rest)``, where every
    slice ``[:, j]`` equals ``input``.

    ``expand`` returns a view without copying data. All ``k`` entries along the
    new dimension alias the same memory, so do not write to the result in place.
    """
    # Annotated with ``torch.Tensor`` (the type); ``torch.tensor`` is the
    # factory function and is not a valid type hint.
    expanded_shape = [input.shape[0], k, *input.shape[1:]]
    return input.unsqueeze(1).expand(expanded_shape)


def upsample_tensor(input_x_: torch.Tensor, k: int) -> torch.Tensor:
    """Repeat every leading-dimension row of ``input_x_`` ``k`` times consecutively.

    ``(B, *rest)`` -> ``(B * k, *rest)``, ordered ``x0, x0, ..., x1, x1, ...``.
    As a result, ``unflatten_tensor(result, k)`` groups the ``k`` copies of each
    input row under the same leading index.
    """
    # Annotated with ``torch.Tensor`` (the type), not the ``torch.tensor``
    # factory function.
    return flatten_tensor(expand_tensor(input_x_, k))


def stochastic_pipeline(model, input_x_k_):
    """Run ``model`` over a batch that already carries ``k`` samples per item.

    ``input_x_k_`` is treated as ``(B, k, *rest)``. Its first two dimensions are
    merged, and ``model`` is called on the ``(B * k, *rest)`` tensor while a
    stochastic run is active. The output is then split back into
    ``(B, k', *out)``, where ``k'`` is the sample count after any nested
    ``stochastic_forward`` calls made inside ``model``.

    Raises:
        AssertionError: if called while another stochastic run is already
            active (nesting is not supported).
    """
    global _current_num_samples, _stochastic_run

    # Explicit raise rather than ``assert`` so the guard survives ``python -O``.
    # AssertionError is kept so callers observe the same exception type.
    if _stochastic_run:
        raise AssertionError(
            "stochastic_pipeline cannot be nested inside another stochastic run"
        )

    _current_num_samples = input_x_k_.shape[1]
    _stochastic_run = True

    try:
        input_xk_ = flatten_tensor(input_x_k_)
        output_xk_ = model(input_xk_)
        # Nested stochastic_forward calls inside ``model`` may have multiplied
        # _current_num_samples. Split using the final value.
        return unflatten_tensor(output_xk_, _current_num_samples)
    finally:
        # The guard above guarantees this is the outermost run, so fully reset.
        _stochastic_run = False
        _current_num_samples = 1


def stochastic_forward(model, input_x_, num_samples):
    """Call ``model`` on ``input_x_`` with its rows replicated to draw samples.

    The replication factor adapts to samples already drawn by enclosing
    stochastic runs, which are tracked in ``_current_num_samples``:

    * If at least ``num_samples`` are already present, the input is not repeated.
    * Otherwise each row is repeated ``max(1, num_samples // _current_num_samples)``
      times. Because of the floor division, the resulting count can fall short
      of ``num_samples`` when it is not a multiple of the existing count.

    Returns:
        For the outermost call, the model output split to ``(B, k, *out)``,
        where ``k`` is the final sample count. When nested inside another
        stochastic run, the flat ``(B * k, *out)`` output is returned instead.
        In that case the increased count stays in ``_current_num_samples`` so
        the enclosing call can unflatten with it.
    """
    global _current_num_samples, _stochastic_run

    was_running = _stochastic_run

    # Reuse samples that enclosing runs already drew.
    upsample_factor = (
        1
        if _current_num_samples >= num_samples
        else max(1, num_samples // _current_num_samples)
    )

    _stochastic_run = True
    try:
        replicated_xk_ = upsample_tensor(input_x_, upsample_factor)
        _current_num_samples *= upsample_factor
        result_xk_ = model(replicated_xk_)
        if was_running:
            # Nested call: hand back the flat tensor for the enclosing run.
            return result_xk_
        return unflatten_tensor(result_xk_, _current_num_samples)
    finally:
        _stochastic_run = was_running
        if not was_running:
            # Outermost run finished: reset the sample count for the next run.
            _current_num_samples = 1


class StochasticModel(SafeModule):
    """A module that can be sampled several times from a single input batch.

    The forward pass is split into two stages so that only part of it is
    repeated per sample:

    * ``deterministic_forward_impl`` runs on the input batch as given.
    * ``stochastic_forward_impl`` runs on the intermediate result after
      ``stochastic_forward`` has replicated it along the leading dimension to
      draw ``num_samples`` samples. It is repeated fewer times if an enclosing
      stochastic run already drew samples.

    Subclasses override one or both hooks. By default both return their input.
    """

    num_samples: int

    def __init__(self, num_samples: int):
        super().__init__()
        self.num_samples = num_samples

    def safe_forward(self, input_x_: torch.Tensor):
        """Return the sampled outputs as ``B x n x output``.

        When called from inside another stochastic run (e.g. from another
        ``StochasticModel``), the flat ``Bn x output`` tensor is returned
        instead, which allows models to be nested.
        """

        def run_both_stages(batch_x_):
            features_x_ = self.deterministic_forward_impl(batch_x_)
            return stochastic_forward(
                self.stochastic_forward_impl, features_x_, self.num_samples
            )

        # The outer call requests a single sample: it only opens a stochastic
        # run (or joins an enclosing one) around both stages.
        return stochastic_forward(run_both_stages, input_x_, 1)

    def deterministic_forward_impl(self, input_x_: torch.Tensor):
        """Shared, non-replicated part of the forward pass (identity by default)."""
        return input_x_

    def stochastic_forward_impl(self, intermediate_xk_: torch.Tensor):
        """Per-sample part of the forward pass on the replicated batch (identity by default)."""
        return intermediate_xk_

    @staticmethod
    def predict(output_x_k_: torch.Tensor):
        """Average a ``B x k x ...`` tensor over its sample dimension (dim 1)."""
        # TODO: averaging is not a meaningful way to merge samples for every
        # kind of network output; the right reduction depends on the outputs.
        # ``torch.mean`` also requires a floating-point (or complex) dtype.
        return torch.mean(output_x_k_, dim=1)


# Type variable for the concrete torch.nn.Module type wrapped by
# StochasticModelWrapper. It lets as_stochastic_model(model) be typed as
# StochasticModelWrapper[<type of model>].
MT = TypeVar("MT", bound=torch.nn.Module)


class StochasticModelWrapper(StochasticModel, Generic[MT]):
    """Adapt an arbitrary module into a :class:`StochasticModel`.

    The wrapped module becomes the stochastic (per-sample) stage.
    ``deterministic_forward_impl`` is not overridden, so the inherited
    identity stage is used.
    """

    wrapped_model: MT

    def __init__(self, wrapped_model: MT, num_samples: int):
        super().__init__(num_samples)
        # Assigned after super().__init__(). Presumably SafeModule is a
        # torch.nn.Module, which only accepts submodule assignment once
        # initialised — confirm.
        self.wrapped_model = wrapped_model

    def stochastic_forward_impl(self, intermediate_xk_: torch.Tensor):
        """Feed the replicated batch through the wrapped module."""
        module = self.wrapped_model
        return module(intermediate_xk_)


def as_stochastic_model(model: MT, num_samples=1) -> StochasticModelWrapper[MT]:
    """Wrap ``model`` as the stochastic stage of a StochasticModel.

    This is shorthand for ``StochasticModelWrapper(model, num_samples)``.
    """
    wrapper = StochasticModelWrapper(wrapped_model=model, num_samples=num_samples)
    return wrapper


class AsDeterministicModel(torch.nn.Module):
    """Collapse the sample dimension (dim 1) of a stochastic model's output.

    A single sample is returned as-is. Several samples are averaged.
    """

    wrapped_model: StochasticModel

    def __init__(self, wrapped_model: StochasticModel):
        super().__init__()
        self.wrapped_model = wrapped_model

    def forward(self, *args, **kwargs):
        samples_b_k_ = self.wrapped_model(*args, **kwargs)
        if samples_b_k_.shape[1] != 1:
            return samples_b_k_.mean(dim=1, keepdim=False)
        # One sample: select it directly instead of reducing. Note that
        # ``mean`` would also reject integer dtypes here.
        return samples_b_k_[:, 0]

