import torch
import torch.nn as nn
from . import kernels

class ParallelLinear(torch.autograd.Function):
    """Autograd function applying per-expert weights via the ``kernels.ops`` kernels.

    ``forward`` runs ``kernels.ops.scatter2scatter`` and, when ``gates`` is
    supplied, combines the per-row results with a batched matrix product.
    ``backward`` produces gradients for ``x``, ``expert_weights`` and
    ``gates``; every other input receives ``None``.

    NOTE(review): ``backward`` hands tensors saved in ``forward`` to the
    kernels as ``out=`` buffers. If the kernels write into ``out`` (as the
    inline "reuse" comments indicate), the saved ``output_expanded`` is
    clobbered, so running backward twice over the same graph would presumably
    give a wrong ``d_gates`` -- confirm before relying on ``retain_graph``.
    """
    @staticmethod
    def forward(
        ctx, x, expert_weights, k,
        sorted_expert_idxs, sorted_scattered_idxs,
        padded_block_idxs, expert_offsets,
        gates=None, grouped_in=False, grouped_out=False,
    ):
        """Compute the expert outputs and stash what ``backward`` needs.

        Args:
            ctx: autograd context; receives the saved tensors plus ``k``,
                ``grouped_in`` and ``grouped_out`` as plain attributes.
            x: input tensor, passed to ``scatter2scatter`` as ``X``.
            expert_weights: weight tensor, passed as ``W``. ``backward`` reads
                ``size(0)`` as the expert count and swaps its last two dims,
                so it is 3-D.
            k: forwarded to ``scatter2scatter`` as ``k``; presumably the
                number of experts each input row is routed to -- confirm
                against ``kernels.ops``.
            sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs:
                routing index tensors passed straight through to the kernels.
            expert_offsets: not used in ``forward``; saved for
                ``group_bwd_W`` in ``backward``.
            gates: optional 2-D tensor. When given, the kernel output is
                viewed as ``(gates.size(0), gates.size(1), features)`` and
                reduced over dim 1 using ``gates`` as weights.
            grouped_in: forwarded as ``x_grouped``.
            grouped_out: forwarded as ``y_grouped``.

        Returns:
            The kernel output, or its gate-weighted combination when
            ``gates`` is not ``None``.
        """

        output = kernels.ops.scatter2scatter(
            X=x, W=expert_weights,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            padded_block_idxs=padded_block_idxs,
            k=k, x_grouped=grouped_in, y_grouped=grouped_out
        )
        if gates is not None:
            # (B, 1, K) @ (B, K, D) -> (B, 1, D) -> (B, D): weighted sum of the
            # K rows belonging to each of the B entries of ``gates``.
            output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
            output = torch.bmm(
                gates[:, None, :],
                output_expanded
            ).squeeze(1)
        else:
            output_expanded = None

        # ``gates`` and ``output_expanded`` may be None here; they come back
        # as None from ``ctx.saved_tensors`` in that case.
        ctx.save_for_backward(
            x, expert_weights,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            padded_block_idxs, expert_offsets,
            gates,
            output_expanded
        )
        ctx.grouped_in = grouped_in
        ctx.grouped_out = grouped_out
        ctx.k = k
        return output
    @staticmethod
    def backward(ctx, grad_out):
        """Return one gradient slot per ``forward`` input (ten in total).

        Only ``x``, ``expert_weights`` and ``gates`` get gradients; the slot
        for ``gates`` is ``None`` when no gates were passed to ``forward``.
        """
        (x, expert_weights,
         sorted_expert_idxs,
         sorted_scattered_idxs,
         padded_block_idxs, expert_offsets,
         gates, output_expanded) = ctx.saved_tensors
        k = ctx.k
        grouped_in = ctx.grouped_in
        grouped_out = ctx.grouped_out
        if gates is not None:
            # calculate gates gradient
            # (B, K, D) @ (B, D, 1) -> (B, K, 1) -> (B, K): dot product of each
            # pre-combination row with the incoming gradient.
            d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
            gates_flat = gates.flatten()
            gate_fan = gates.size(1)
            # d_gates has already consumed output_expanded, so its storage is
            # offered to kernels.ops.group below as the ``out`` buffer.
            grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later
        else:
            d_gates = None
            gates_flat = None
            gate_fan = 1
            grouped_grad_out = None

        if grouped_out:
            # NOTE(review): on this path ``gates_flat`` is never applied to
            # grad_out inside this function -- confirm that combining gates
            # with grouped_out=True is not an expected configuration.
            grouped_grad_out = grad_out
        else:
            grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
                                                 fan_out=gate_fan, coeff=gates_flat,
                                                 out=grouped_grad_out)
        if grouped_in:
            grouped_x = x
            d_expanded_input = None
        else:
            grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
            # Same tensor object as grouped_x; handed to scatter2scatter as
            # ``out`` further down.
            d_expanded_input = grouped_x
        # Must run before the scatter2scatter call below: that call receives
        # grouped_x as its ``out`` buffer when grouped_in is False.
        d_weights = kernels.ops.group_bwd_W(
            DY=grouped_grad_out, X=grouped_x,
            expert_offsets=expert_offsets,
            E=expert_weights.size(0)
        )
        # Input gradient: the grouped output gradient is pushed through the
        # weights with their last two dims swapped.
        d_expanded_input = kernels.ops.scatter2scatter(
            X=grouped_grad_out, x_grouped=True,
            W=expert_weights.permute(0, 2, 1),
            padded_block_idxs=padded_block_idxs,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            k=1,
            y_grouped=grouped_in,
            out=d_expanded_input # Reuse grouped_x buffer
        )

        if k == 1:
            d_input = d_expanded_input
        else:
            # Fold the k rows that belong to each row of x back into one.
            d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
        return (
            # x, expert_weights, k,
            d_input, d_weights, None,
            # sorted_expert_idxs, sorted_scattered_idxs,
            None, None,
            # padded_block_idxs, expert_offsets,
            None, None,
            # gates, grouped_in, grouped_out
            d_gates, None, None
        )

def parallel_linear(inputs, expert_weights, k,
                    sorted_expert_idxs, sorted_scattered_idxs,
                    padded_block_idxs, expert_offsets,
                    gates=None, grouped_in=False, grouped_out=False):
    """Functional entry point for ``ParallelLinear``.

    Args:
        inputs: input tensor, passed as ``x``.
        expert_weights: weight tensor, passed through unchanged (no permute
            is applied here, unlike ``ParallelExperts.forward``).
        k: forwarded as ``k``.
        sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs,
        expert_offsets: routing tensors forwarded unchanged.
        gates: optional gating tensor; ``None`` skips the gate-weighted
            combination.
        grouped_in: forwarded to ``ParallelLinear``; defaults to ``False``,
            which matches the behaviour before this parameter existed.
        grouped_out: forwarded to ``ParallelLinear``; defaults to ``False``,
            which matches the behaviour before this parameter existed.

    Returns:
        Whatever ``ParallelLinear.apply`` returns for these arguments.
    """
    # Arguments are passed positionally, mirroring ParallelExperts.forward,
    # so this works regardless of keyword support in Function.apply.
    return ParallelLinear.apply(
        inputs, expert_weights, k,
        sorted_expert_idxs, sorted_scattered_idxs,
        padded_block_idxs, expert_offsets,
        gates, grouped_in, grouped_out,
    )

class ParallelExperts(nn.Module):
    """Stack of per-expert linear weights evaluated through ``ParallelLinear``.

    All experts share one parameter tensor ``weight`` of shape
    ``(num_experts, output_size, input_size)``; there is no bias.
    """

    def __init__(self, num_experts, input_size, output_size) -> None:
        super().__init__()
        weight_shape = (num_experts, output_size, input_size)
        self.weight = nn.Parameter(torch.empty(*weight_shape))
        self.reset_parameters()
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size

    def extra_repr(self):
        """Return the size summary shown inside the module's ``repr``."""
        return (
            f'num_experts={self.num_experts}, '
            f'input_size={self.input_size}, '
            f'output_size={self.output_size}'
        )

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` uniformly from ``[-1/input_size, 1/input_size]``."""
        # Dim 2 of the weight tensor is the input size.
        bound = 1. / self.weight.size(2)
        nn.init.uniform_(self.weight, -bound, bound)

    def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
                padded_block_idxs, expert_offsets,
                gates=None, grouped_in=False, grouped_out=False):
        """Apply the experts to ``inputs`` via ``ParallelLinear``.

        The stored weight is passed with its last two dims swapped, i.e. as
        ``(num_experts, input_size, output_size)``. Every other argument is
        forwarded unchanged.
        """
        transposed_weight = self.weight.permute(0, 2, 1)
        return ParallelLinear.apply(
            inputs, transposed_weight, k,
            sorted_expert_idxs, sorted_scattered_idxs,
            padded_block_idxs, expert_offsets,
            gates, grouped_in, grouped_out
        )
