import math

import torch
from spmd.tensor.api import DTensor
from spmd.tensor.ops.utils import register_prop_rule
from spmd.tensor.ops.common_rules import pointwise_rule
from spmd.tensor.dispatch import OpSchema, OutputSharding
from spmd.tensor.placement_types import DTensorSpec, Replicate, _Partial, Shard

# Ops registered below with the generic pointwise propagation rule.
extra_pointwise_op = [
    "aten.leaky_relu.default",
    "aten.leaky_relu_backward.default",
    "aten.elu.default",
    "aten.elu_backward.default",
]

# Map every op listed above to `pointwise_rule` in DTensor's op -> rule table.
for op in extra_pointwise_op:
    DTensor._op_to_rules[op] = pointwise_rule


@register_prop_rule("aten.native_layer_norm_backward.default")
def _prop_native_layer_norm_backward(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule for ``aten.native_layer_norm_backward``.

    The input-gradient slot reuses the spec of the incoming gradient. The
    weight and bias gradients keep the (fully replicated) specs of weight and
    bias when the incoming gradient is replicated on every mesh dimension;
    otherwise they are marked ``_Partial`` on every mesh dimension.

    Each of the three output slots is replaced by ``None`` when the matching
    entry of the gradient mask (last argument) is falsy.
    """
    (
        grad_spec,
        _input_spec,
        _normalized_shape,
        _result1,
        _result2,
        weight_spec,
        bias_spec,
        output_mask,
    ) = op_schema.args_schema

    def _fully_replicated(spec: DTensorSpec) -> bool:
        # True when the spec is Replicate on every mesh dimension.
        return all(isinstance(p, Replicate) for p in spec.placements)

    assert isinstance(grad_spec, DTensorSpec)
    assert isinstance(weight_spec, DTensorSpec)
    assert isinstance(bias_spec, DTensorSpec)
    assert isinstance(output_mask, (list, tuple))
    # This rule only supports replicated affine parameters.
    assert _fully_replicated(weight_spec)
    assert _fully_replicated(bias_spec)

    if _fully_replicated(grad_spec):
        # Nothing is sharded: parameter grads share the parameters' specs.
        weight_grad, bias_grad = weight_spec, bias_spec
    else:
        # Any non-replicated incoming gradient makes the parameter gradients
        # "Partial" on every mesh dimension.
        weight_grad = DTensorSpec(
            mesh=weight_spec.mesh,
            placements=[_Partial()] * weight_spec.mesh.ndim,
            shape=weight_spec.shape,
        )
        bias_grad = DTensorSpec(
            mesh=bias_spec.mesh,
            placements=[_Partial()] * bias_spec.mesh.ndim,
            shape=bias_spec.shape,
        )

    # NOTE: type errors below are legit. This is because DTensor currently
    # doesn't support Optional return values. Need to be fixed in DTensor repo.
    input_grad_out = grad_spec if output_mask[0] else None
    weight_grad_out = weight_grad if output_mask[1] else None
    bias_grad_out = bias_grad if output_mask[2] else None
    return OutputSharding(output_spec=(input_grad_out, weight_grad_out, bias_grad_out))


@register_prop_rule("aten.convolution.default")
def _prop_convolution_default(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule for ``aten.convolution.default``.

    Output shape: the channel dim (1) takes ``weight.shape[0]`` and every
    spatial dim ``i`` is computed independently as
    ``floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1``.
    Only this non-transposed formula is implemented; ``transposed`` and
    ``output_padding`` are not consulted.

    Output placements, per mesh dimension (later rules override earlier ones):
      * default: ``Replicate``;
      * input sharded on dim 0 -> output sharded on dim 0;
      * weight sharded on dim 0 -> output sharded on dim 1;
      * weight and input both sharded on dim 1 -> output is ``_Partial``.
    """
    input, weight, _bias = op_schema.args_schema[0:3]
    assert isinstance(input, DTensorSpec)
    assert isinstance(weight, DTensorSpec)

    stride, padding = op_schema.args_schema[3:5]
    # dilation is the argument right after padding; fall back to 1 if the
    # schema was built without it.
    dilation = op_schema.args_schema[5] if len(op_schema.args_schema) > 5 else 1

    num_spatial = len(input.shape) - 2

    def _per_spatial_dim(value):
        # Expand a scalar or single-element sequence to one entry per spatial dim.
        if isinstance(value, (list, tuple)):
            return list(value) * num_spatial if len(value) == 1 else list(value)
        return [value] * num_spatial

    stride = _per_spatial_dim(stride)
    padding = _per_spatial_dim(padding)
    dilation = _per_spatial_dim(dilation)

    output_shape = list(input.shape)
    output_shape[1] = weight.shape[0]
    for i in range(num_spatial):
        axis = i + 2
        # Integer floor division keeps the result exact and applies the floor
        # to the quotient (not just the numerator).
        effective_kernel = dilation[i] * (weight.shape[axis] - 1) + 1
        output_shape[axis] = (input.shape[axis] + 2 * padding[i] - effective_kernel) // stride[i] + 1

    output_placements = [Replicate()] * input.mesh.ndim

    for idx, s in enumerate(input.placements):
        if isinstance(s, Shard) and s.dim == 0:
            output_placements[idx] = Shard(dim=0)

    for idx, s in enumerate(weight.placements):
        if isinstance(s, Shard) and s.dim == 0:
            # weight dim 0 becomes the output channel dim
            output_placements[idx] = Shard(dim=1)
        s_input = input.placements[idx]
        if isinstance(s, Shard) and s.dim == 1 and isinstance(s_input, Shard) and s_input.dim == 1:
            # both operands split along dim 1: each rank holds a partial result
            output_placements[idx] = _Partial()

    output_spec = DTensorSpec(mesh=input.mesh,
                              placements=output_placements,
                              shape=torch.Size(output_shape))

    return OutputSharding(output_spec=output_spec)


@register_prop_rule("aten.convolution_backward.default")
def _prop_convolution_backward_default(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule for ``aten.convolution_backward.default``.

    Returns a 3-tuple of specs: the input gradient reuses the input's spec,
    the weight gradient has the weight's shape with placements derived below,
    and the third slot is always ``None``.

    NOTE(review): the output mask argument is unpacked but never consulted.
    """
    (
        _grad_output,
        input_spec,
        weight_spec,
        _bias_sizes,
        _stride,
        _padding,
        _dilation,
        _transposed,
        _output_padding,
        _groups,
        _output_mask,
    ) = op_schema.args_schema

    grad_w_placements = [Replicate()] * weight_spec.mesh.ndim

    for mesh_dim, in_placement in enumerate(input_spec.placements):
        # An input sharded on dim 0 makes the weight gradient partial on that
        # mesh dimension; otherwise the weight's own placement is carried over.
        sharded_on_dim0 = isinstance(in_placement, Shard) and in_placement.dim == 0
        grad_w_placements[mesh_dim] = (
            _Partial() if sharded_on_dim0 else weight_spec.placements[mesh_dim]
        )

    weight_grad = DTensorSpec(
        mesh=weight_spec.mesh,
        placements=grad_w_placements,
        shape=weight_spec.shape,
    )

    return OutputSharding(output_spec=(input_spec, weight_grad, None))


@register_prop_rule("aten.native_batch_norm.default")
@register_prop_rule("aten.cudnn_batch_norm.default")
def _prop_batch_norm_default(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule shared by native and cuDNN batch norm forward.

    Returns four specs: the input's spec, the specs of the 4th and 5th
    arguments (running mean / running var), and a fully replicated spec of
    shape ``[0]`` built on the weight's mesh.
    """
    input_spec, weight_spec, _bias_spec, running_mean, running_var = op_schema.args_schema[0:5]

    mesh = weight_spec.mesh
    # Zero-length spec for the last output slot, replicated on all mesh dims.
    reserve = DTensorSpec(
        mesh=mesh,
        placements=[Replicate()] * mesh.ndim,
        shape=torch.Size([0]),
    )

    outputs = (input_spec, running_mean, running_var, reserve)
    return OutputSharding(output_spec=outputs)


@register_prop_rule("aten.cudnn_batch_norm_backward.default")
def _prop_cudnn_batch_norm_backward_default(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule for ``aten.cudnn_batch_norm_backward.default``.

    The input gradient reuses the input's spec. Weight and bias gradients both
    take the weight's shape and share one placement list, decided per mesh
    dimension:
      * weight sharded -> keep the weight's placement;
      * weight replicated and grad_output sharded on dim 0 -> ``_Partial``;
      * anything else -> ``Replicate``.
    """
    input_spec, grad_out_spec, weight_spec, _running_mean, _running_var = op_schema.args_schema[0:5]

    param_grad_placements = [Replicate()] * weight_spec.mesh.ndim

    for mesh_dim, w_placement in enumerate(weight_spec.placements):
        g_placement = grad_out_spec.placements[mesh_dim]
        if isinstance(w_placement, Shard):
            param_grad_placements[mesh_dim] = w_placement
        elif (isinstance(w_placement, Replicate) and isinstance(g_placement, Shard)
              and g_placement.dim == 0):
            # dim-0 sharded gradients contribute partial sums to the parameter grads
            param_grad_placements[mesh_dim] = _Partial()

    def _param_grad_spec() -> DTensorSpec:
        # weight and bias gradients have identical mesh, placements and shape
        return DTensorSpec(mesh=weight_spec.mesh,
                           placements=param_grad_placements,
                           shape=weight_spec.shape)

    weight_grad = _param_grad_spec()
    bias_grad = _param_grad_spec()

    # NOTE: type errors below are legit. This is because DTensor currently
    # doesn't support Optional return values. Need to be fixed in DTensor repo.
    return OutputSharding(output_spec=(input_spec, weight_grad, bias_grad))


@register_prop_rule("aten.native_batch_norm_backward.default")
def _prop_batch_norm_backward_default(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule for ``aten.native_batch_norm_backward.default``.

    Same logic as the cuDNN variant, but here grad_output is the first
    argument and input the second. The input gradient reuses the input's
    spec; weight and bias gradients take the weight's shape and share one
    placement list computed per mesh dimension.
    """
    grad_out_spec, input_spec, weight_spec, _running_mean, _running_var = op_schema.args_schema[0:5]

    mesh = weight_spec.mesh
    grad_placements = [Replicate()] * mesh.ndim

    for dim_idx, weight_place in enumerate(weight_spec.placements):
        grad_place = grad_out_spec.placements[dim_idx]
        grad_on_dim0 = isinstance(grad_place, Shard) and grad_place.dim == 0
        if isinstance(weight_place, Replicate) and grad_on_dim0:
            # replicated weight + dim-0 sharded grad_output -> partial parameter grads
            grad_placements[dim_idx] = _Partial()
        elif isinstance(weight_place, Shard):
            # a sharded weight keeps its own placement for the gradient
            grad_placements[dim_idx] = weight_place

    weight_grad = DTensorSpec(mesh=mesh, placements=grad_placements, shape=weight_spec.shape)
    bias_grad = DTensorSpec(mesh=mesh, placements=grad_placements, shape=weight_spec.shape)

    # NOTE: type errors below are legit. This is because DTensor currently
    # doesn't support Optional return values. Need to be fixed in DTensor repo.
    specs = (input_spec, weight_grad, bias_grad)
    return OutputSharding(output_spec=specs)


@register_prop_rule("aten.max_pool2d_with_indices.default")
def _prop_max_pool2d_with_indices(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule for ``aten.max_pool2d_with_indices.default``.

    Both outputs (pooled values and indices) keep the input's mesh and
    placements and share the pooled shape. The last two dims of the input are
    treated as the spatial dims and are computed independently as
    ``floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1``
    (``ceil`` instead of ``floor`` when ``ceil_mode`` is set, dropping a
    trailing window that would start past the padded input).

    Optional trailing arguments fall back to the aten defaults when absent:
    an empty/missing stride means stride = kernel size, padding 0, dilation 1,
    ceil_mode False.
    """
    args = op_schema.args_schema
    input = args[0]
    kernel = args[1]
    stride = args[2] if len(args) > 2 else []
    padding = args[3] if len(args) > 3 else 0
    dilation = args[4] if len(args) > 4 else 1
    ceil_mode = bool(args[5]) if len(args) > 5 else False

    def _pair(value):
        # Expand a scalar or single-element sequence to (height, width).
        if isinstance(value, (list, tuple)):
            return list(value) * 2 if len(value) == 1 else list(value)
        return [value, value]

    kernel = _pair(kernel)
    # An empty stride list is the aten default and means "same as kernel".
    if isinstance(stride, (list, tuple)) and len(stride) == 0:
        stride = kernel
    else:
        stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)

    output_shape = list(input.shape)
    first_spatial = len(output_shape) - 2
    for i in range(2):
        axis = first_spatial + i
        size = input.shape[axis]
        numerator = size + 2 * padding[i] - dilation[i] * (kernel[i] - 1) - 1
        if ceil_mode:
            out = -(-numerator // stride[i]) + 1  # ceiling division
            # the last window must start inside the input or its left padding
            if (out - 1) * stride[i] >= size + padding[i]:
                out -= 1
        else:
            out = numerator // stride[i] + 1
        output_shape[axis] = out

    def _pooled_spec():
        # values and indices have identical shape and sharding
        return DTensorSpec(mesh=input.mesh,
                           placements=input.placements,
                           shape=torch.Size(output_shape))

    return OutputSharding(output_spec=(_pooled_spec(), _pooled_spec()))


@register_prop_rule("aten.max_pool2d_with_indices_backward.default")
def _prop_max_pool2d_with_indices_backward(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule for ``aten.max_pool2d_with_indices_backward``.

    The single output reuses the spec of the second argument (the forward
    input) unchanged.
    """
    _grad_output, forward_input = op_schema.args_schema[0:2]
    return OutputSharding(output_spec=forward_input)
