from typing import Any, Optional
import torch
from torch.autograd import Function

from pado.utils.dist_utils import all_gather_tensor, all_reduce_tensor, get_world_size

# Public API of this module: each name is bound below to the ``apply`` of the
# matching custom autograd ``Function`` (PACTFunc, GradientScaleFunc, SyncBatchNormFunc,
# MaskedBatchNormFunc, MaskedSyncBatchNormFunc).
__all__ = [
    "pact",
    "gradient_scale",
    "sync_batch_norm_func",
    "masked_batch_norm_func",
    "masked_sync_batch_norm_func",
]


class PACTFunc(Function):
    """PACT activation: clip the input to ``[0, alpha]`` with a learnable bound ``alpha``.

    Forward returns ``clamp(inp, 0, alpha)``. Backward passes ``grad_output`` through to
    ``inp`` where ``0 <= inp <= alpha``. It accumulates ``grad_output`` into ``alpha`` over
    the positions where ``inp > alpha``.
    """
    # Some useful extensions:
    # https://github.com/cornell-zhang/dnn-gating/blob/master/utils/pg_utils.py

    @staticmethod
    def forward(ctx: Any,  # noqa
                inp: torch.Tensor,
                alpha: torch.Tensor) -> torch.Tensor:
        """Clamp ``inp`` to ``[0, alpha]``.

        Args:
            inp: input activations.
            alpha: single-element tensor holding the upper clipping bound.

        Returns:
            ``inp`` clamped to ``[0, alpha]``.
        """
        ctx.save_for_backward(inp, alpha)
        # .item() requires alpha to hold exactly one element (host round-trip for device tensors).
        return torch.clamp(inp, 0, alpha.item())

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor):  # noqa
        """Return ``(grad_inp, grad_alpha)``; either is None when not required."""
        inp, alpha = ctx.saved_tensors

        grad_input = grad_alpha = None
        if ctx.needs_input_grad[0]:
            # Identity inside [0, alpha], zero where the clamp saturated.
            grad_input = grad_output.masked_fill((inp < 0) | (inp > alpha), 0)
        if ctx.needs_input_grad[1]:
            # d clamp / d alpha is 1 only where inp > alpha. Reduce over every element and
            # reshape to alpha's own shape explicitly (summing only dim 0 would leave a
            # (1, *inp.shape[1:]) tensor that autograd has to reduce implicitly).
            grad_alpha = grad_output.masked_fill(inp <= alpha, 0).sum().reshape(alpha.shape)
        return grad_input, grad_alpha


pact = PACTFunc.apply


class GradientScaleFunc(Function):
    """Identity in forward; multiplies the incoming gradient by ``scale`` in backward.

    ``scale`` may be a tensor (broadcastable against the gradient) or a Python number.
    ``scale`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx: Any,  # noqa
                inp: torch.Tensor,
                scale: "torch.Tensor | float") -> torch.Tensor:
        """Return ``inp`` unchanged and remember ``scale`` for the backward pass."""
        ctx.scale_is_tensor = torch.is_tensor(scale)
        if ctx.scale_is_tensor:
            ctx.save_for_backward(scale)
        else:
            # save_for_backward only accepts tensors; keep plain numbers on ctx.
            ctx.scale = scale
        return inp

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor):  # noqa
        """Return ``(grad_output * scale, None)``."""
        scale = ctx.saved_tensors[0] if ctx.scale_is_tensor else ctx.scale
        # Out of place on purpose: grad_output can be the same tensor that autograd hands
        # to other graph branches (e.g. both operands of an add), or an expanded view.
        # An in-place mul_ would corrupt those gradients or fail outright.
        return grad_output * scale, None


gradient_scale = GradientScaleFunc.apply


class SyncBatchNormFunc(Function):
    """Batch normalization with statistics synchronized across all processes.

    Should be only called when bn_training=True.

    Forward computes per-channel biased mean and inverse std locally. It all-gathers them
    together with the local per-channel element count and combines them with
    ``torch.batch_norm_gather_stats_with_counts``, which also updates ``running_mean`` /
    ``running_var`` when given. The result is applied with ``torch.batch_norm_elemt``.
    Backward all-reduces the per-channel gradient sums before forming the input gradient.
    """

    @staticmethod
    def forward(ctx: Any,  # noqa
                inp: torch.Tensor,
                weight: Optional[torch.Tensor],
                bias: Optional[torch.Tensor],
                running_mean: Optional[torch.Tensor],
                running_var: Optional[torch.Tensor],
                momentum: float,
                eps: float) -> torch.Tensor:
        """Normalize ``inp`` (channels on dim 1) with globally synchronized statistics.

        Raises:
            ValueError: if running as a single process with at most one element per channel.
        """
        c = inp.shape[1]
        count = inp.numel() // c  # elements summed per channel on this process

        if (count <= 1) and (get_world_size() < 2):
            raise ValueError("SyncBN requires more than 1 summed element per channel.")

        # mean, inv_std = torch.batch_norm_stats(inp, eps)  # unbiased var, we need biased var.
        sum_dim = [0] + list(range(2, inp.ndim))
        # max(count, 1): an empty local input yields zero stats instead of 0/0 = NaN; its zero
        # count removes it from the combination below.
        denom = max(count, 1)
        mean = torch.sum(inp, dim=sum_dim) / denom
        sq_mean = torch.sum(inp * inp, dim=sum_dim) / denom
        var = torch.clamp_min(sq_mean - (mean * mean), 1e-6)
        inv_std = torch.rsqrt(var + eps)

        count_t = torch.full((1,), count, dtype=mean.dtype, device=mean.device)

        combined = torch.cat([mean, inv_std, count_t], dim=0)  # (2C + 1,)
        combined = all_gather_tensor(combined)
        combined = torch.stack(combined, dim=0)  # (world_size, 2C + 1)
        mean_all, inv_std_all, count_all = torch.split(combined, c, 1)  # (world, C), (world, C), (world, 1)

        # Drop processes that contributed no elements so their placeholder stats are not mixed in.
        valid = count_all.view(-1) > 0
        mean_all, inv_std_all, count_all = mean_all[valid], inv_std_all[valid], count_all[valid]

        mean, inv_std = torch.batch_norm_gather_stats_with_counts(
            inp,
            mean_all,
            inv_std_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )
        ctx.save_for_backward(inp, weight, mean, inv_std, count_all.to(torch.int32))
        return torch.batch_norm_elemt(inp, weight, bias, mean, inv_std, eps)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor):  # noqa
        """Return gradients for ``(inp, weight, bias)`` and None for the remaining arguments."""
        inp, weight, mean, inv_std, count = ctx.saved_tensors
        c = inp.shape[1]

        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            inp,
            mean,
            inv_std,
            weight,
            ctx.needs_input_grad[0],
            ctx.needs_input_grad[1],
            ctx.needs_input_grad[2]
        )

        grad_input = None
        if ctx.needs_input_grad[0]:
            # Every output depends on the global mean / inv_std, so the per-channel sums must
            # be taken over all processes before forming the input gradient.
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)  # (2C,)
            combined = all_reduce_tensor(combined, "sum", detach=False)
            sum_dy, sum_dy_xmu = torch.split(combined, c)  # (C,), (C,)

            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                inp,
                mean,
                inv_std,
                weight,
                sum_dy,
                sum_dy_xmu,
                count
            )

        # needs_input_grad is False for arguments passed as None, so these gates also cover
        # weight=None / bias=None independently (bias may still need a gradient without weight).
        if not ctx.needs_input_grad[1]:
            grad_weight = None
        if not ctx.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None


sync_batch_norm_func = SyncBatchNormFunc.apply


class MaskedBatchNormFunc(Function):
    """Batch normalization whose statistics ignore masked-out positions (single process).

    Should be only called when bn_training=True.

    ``mask`` has shape ``(N, 1, *inp.shape[2:])`` and is shared by every channel. The input
    is multiplied by ``mask``, and mean / variance are taken over the ``mask.sum()``
    non-masked positions. Every position, masked or not, is then normalized with those
    statistics. When a fused ``torch.batch_norm_*`` kernel raises ``NotImplementedError``
    for the input's backend (the original comments name XLA), the same math runs with
    plain tensor ops.
    """

    @staticmethod
    def forward(ctx: Any,  # noqa
                inp: torch.Tensor,
                weight: Optional[torch.Tensor],
                bias: Optional[torch.Tensor],
                mask: torch.Tensor,
                running_mean: Optional[torch.Tensor],
                running_var: Optional[torch.Tensor],
                momentum: float,
                eps: float) -> torch.Tensor:
        """Normalize ``inp`` with statistics over its non-masked positions.

        ``running_mean`` / ``running_var`` (when given) are updated in place with ``momentum``;
        ``running_var`` receives the unbiased variance.

        Raises:
            ValueError: if at most one position is left unmasked.
        """
        # mask should be in broadcast-able shape as input
        assert tuple(mask.shape) == (inp.shape[0],) + (1,) + inp.shape[2:]
        inp = inp * mask  # masked positions become 0 and drop out of the sums below
        count = mask.sum()  # number of non-masked positions, the same for every channel

        if count <= 1:
            raise ValueError("MaskedBN requires more than 1 non-masked element.")

        sum_dim = [0] + list(range(2, inp.ndim))
        mean = torch.sum(inp, dim=sum_dim) / count
        sq_mean = torch.sum(inp * inp, dim=sum_dim) / count
        var = torch.clamp_min(sq_mean - (mean * mean), 1e-6)  # biased variance
        inv_std = torch.rsqrt(var + eps)

        try:
            # Also updates running_mean / running_var in place.
            mean, inv_std = torch.batch_norm_gather_stats(
                inp,
                mean.view(1, -1),
                inv_std.view(1, -1),
                running_mean,
                running_var,
                momentum,
                eps,
                int(count.item())
            )
        except NotImplementedError:  # not supported backend: update the running stats manually
            # r' = r * (1 - momentum) + new_r * momentum, with the unbiased variance for running_var.
            if running_mean is not None:
                running_mean.mul_(1 - momentum).add_(mean.to(running_mean.dtype), alpha=momentum)
            if running_var is not None:
                unbiased_var = var * (count / (count - 1))
                running_var.mul_(1 - momentum).add_(unbiased_var.to(running_var.dtype), alpha=momentum)

        try:
            output = torch.batch_norm_elemt(inp, weight, bias, mean, inv_std, eps)
        except NotImplementedError:  # not supported backend: XLA, fall back to manual computation
            v_shape = [1] * inp.ndim
            v_shape[1] = inp.shape[1]
            output = (inp - mean.view(v_shape)) * inv_std.view(v_shape)
            if weight is not None:
                output = output * weight.view(v_shape)
            if bias is not None:
                output = output + bias.view(v_shape)

        ctx.save_for_backward(inp, weight, mean, inv_std, mask, count.to(torch.int32))
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor):  # noqa
        """Return gradients for ``(inp, weight, bias)`` and None for the remaining arguments.

        All outputs (masked positions included) depend on the masked mean / inv_std. So the
        per-channel sums run over every position, while the normalizer is the non-masked count.
        The input gradient is finally zeroed at masked positions.
        """
        inp, weight, mean, inv_std, mask, count = ctx.saved_tensors

        try:
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
                grad_output,
                inp,
                mean,
                inv_std,
                weight,
                ctx.needs_input_grad[0],
                ctx.needs_input_grad[1],
                ctx.needs_input_grad[2]
            )

            grad_input = None
            if ctx.needs_input_grad[0]:
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    inp,
                    mean,
                    inv_std,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count
                )
                grad_input *= mask

        except NotImplementedError:  # not supported backend: XLA, fall back to manual computation
            v_shape = [1] * inp.ndim
            v_shape[1] = inp.shape[1]
            s_dim = [0] + list(range(2, inp.ndim))  # every dim except channels (dim 1)

            n = count.to(grad_output.dtype)
            inv_std_view = inv_std.view(v_shape)
            inp_normalized = (inp - mean.view(v_shape)) * inv_std_view

            sum_dy = torch.sum(grad_output, dim=s_dim, keepdim=True)
            sum_dy_xhat = torch.sum(grad_output * inp_normalized, dim=s_dim, keepdim=True)

            grad_input = None
            if ctx.needs_input_grad[0]:
                # Same closed form as batch_norm_backward_elemt:
                # (dy - sum(dy)/n - x_hat * sum(dy * x_hat)/n) * inv_std * weight
                scale = inv_std_view if weight is None else inv_std_view * weight.view(v_shape)
                grad_input = (grad_output - sum_dy / n - inp_normalized * (sum_dy_xhat / n)) * scale
                grad_input *= mask

            grad_weight = sum_dy_xhat.view(-1)
            grad_bias = sum_dy.view(-1)

        # needs_input_grad is False for arguments passed as None, so these gates also cover
        # weight=None / bias=None independently.
        if not ctx.needs_input_grad[1]:
            grad_weight = None
        if not ctx.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None


masked_batch_norm_func = MaskedBatchNormFunc.apply


class MaskedSyncBatchNormFunc(Function):
    """Masked batch normalization with statistics synchronized across all processes.

    Should be only called when bn_training=True.

    ``mask`` has shape ``(N, 1, *inp.shape[2:])`` and is shared by every channel. Each
    process computes per-channel biased mean / inverse std over its non-masked positions.
    These are all-gathered with the local non-masked count and combined with
    ``torch.batch_norm_gather_stats_with_counts``, which also updates ``running_mean`` /
    ``running_var`` when given. Backward all-reduces the per-channel gradient sums and zeroes
    the input gradient at masked positions.
    """

    @staticmethod
    def forward(ctx: Any,  # noqa
                inp: torch.Tensor,
                weight: Optional[torch.Tensor],
                bias: Optional[torch.Tensor],
                mask: torch.Tensor,
                running_mean: Optional[torch.Tensor],
                running_var: Optional[torch.Tensor],
                momentum: float,
                eps: float) -> torch.Tensor:
        """Normalize ``inp`` with globally synchronized statistics over non-masked positions.

        Raises:
            ValueError: if running as a single process with at most one non-masked position.
        """
        c = inp.shape[1]

        # mask should be in broadcast-able shape as input
        assert tuple(mask.shape) == (inp.shape[0],) + (1,) + inp.shape[2:]
        inp = inp * mask  # masked positions become 0 and drop out of the sums below
        count = mask.sum()  # non-masked positions on this process, the same for every channel

        if (count <= 1) and (get_world_size() < 2):
            raise ValueError("MaskedSyncBN requires more than 1 non-masked element.")

        sum_dim = [0] + list(range(2, inp.ndim))
        # clamp_min(1): a process whose positions are all masked gets zero stats instead of
        # 0/0 = NaN; its zero count removes it from the combination below.
        denom = count.clamp_min(1)
        mean = torch.sum(inp, dim=sum_dim) / denom
        sq_mean = torch.sum(inp * inp, dim=sum_dim) / denom
        var = torch.clamp_min(sq_mean - (mean * mean), 1e-6)
        inv_std = torch.rsqrt(var + eps)

        combined = torch.cat([mean, inv_std, count.view(1)], dim=0)  # (2C + 1,)
        combined = all_gather_tensor(combined)
        combined = torch.stack(combined, dim=0)  # (world_size, 2C + 1)
        mean_all, inv_std_all, count_all = torch.split(combined, c, 1)  # (world, C), (world, C), (world, 1)

        # Drop processes that contributed no elements so their placeholder stats are not mixed in.
        valid = count_all.view(-1) > 0
        mean_all, inv_std_all, count_all = mean_all[valid], inv_std_all[valid], count_all[valid]

        mean, inv_std = torch.batch_norm_gather_stats_with_counts(
            inp,
            mean_all,
            inv_std_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )
        ctx.save_for_backward(inp, weight, mean, inv_std, mask, count_all.to(torch.int32))
        return torch.batch_norm_elemt(inp, weight, bias, mean, inv_std, eps)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor):  # noqa
        """Return gradients for ``(inp, weight, bias)`` and None for the remaining arguments."""
        inp, weight, mean, inv_std, mask, count = ctx.saved_tensors
        c = inp.shape[1]

        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            inp,
            mean,
            inv_std,
            weight,
            ctx.needs_input_grad[0],
            ctx.needs_input_grad[1],
            ctx.needs_input_grad[2]
        )

        grad_input = None
        if ctx.needs_input_grad[0]:
            # Every output depends on the global mean / inv_std, so the per-channel sums must
            # be taken over all processes before forming the input gradient.
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)  # (2C,)
            combined = all_reduce_tensor(combined, "sum", detach=False)
            sum_dy, sum_dy_xmu = torch.split(combined, c)  # (C,), (C,)

            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                inp,
                mean,
                inv_std,
                weight,
                sum_dy,
                sum_dy_xmu,
                count
            )
            grad_input *= mask  # the input was multiplied by mask in forward

        # needs_input_grad is False for arguments passed as None, so these gates also cover
        # weight=None / bias=None independently.
        if not ctx.needs_input_grad[1]:
            grad_weight = None
        if not ctx.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None


masked_sync_batch_norm_func = MaskedSyncBatchNormFunc.apply
