from typing import Optional
import torch
from torch.autograd.function import Function
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.modules.utils import _single, _pair, _triple
from torch.cuda.amp import autocast

import lbmqt.cpp_extension.backward_func as ext_backward_func
from lbmqt.optim.optimizer import LowBitOptimizer
from lbmqt.qscheme import QScheme
from lbmqt.conf import config


def get_forward_param(p, scheme: QScheme):
    """Return the dequantized form of parameter ``p`` for use in a forward pass.

    ``p`` is passed through ``scheme.dequantize_data`` together with its
    ``pname``. The result has ``requires_grad`` switched on only when
    ``config.training`` is set and ``p.trainable`` is truthy; otherwise it is
    switched off. A ``None`` parameter (e.g. a missing bias) yields ``None``.
    """
    if p is None:
        return None

    dequantized = scheme.dequantize_data(p, p.pname)
    # Gradients are only tracked for trainable parameters while training.
    dequantized.requires_grad = bool(config.training and p.trainable)
    return dequantized


class linear(Function):
    """Autograd function for a linear layer whose parameters go through a ``QScheme``.

    Parameter gradients are not returned to autograd: ``backward`` hands them
    to ``optimizer.accumulate_gradient`` under the parameter's ``pname`` and
    returns ``None`` for every argument except ``input``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None, scheme: Optional[QScheme] = None, optimizer: Optional[LowBitOptimizer] = None):
        """Compute ``F.linear(input, W, b)`` with dequantized ``W`` / ``b``.

        Args:
            ctx: autograd context; receives ``saved``, ``optimizer`` and
                ``other_args`` when ``config.training`` is set.
            input: activation of shape ``(*, C_in)``.
            weight: stored weight parameter (must expose ``pname``).
            bias: optional stored bias parameter.
            scheme: quantization scheme used to dequantize the parameters.
            optimizer: receives the parameter gradients in ``backward``;
                despite the default it cannot actually be ``None``.
        """
        forward_weight = get_forward_param(weight, scheme)
        forward_bias = get_forward_param(bias, scheme)

        if config.training:
            # The stored (not dequantized) parameters are kept for backward.
            ctx.saved = input, weight, bias, scheme
            ctx.optimizer = optimizer
            weight_requires_grad = forward_weight.requires_grad
            bias_requires_grad = forward_bias.requires_grad if (bias is not None) else None
            weight_pname = weight.pname
            bias_pname = bias.pname if bias is not None else None
            ctx.other_args = weight_requires_grad, bias_requires_grad, weight_pname, bias_pname

        out = F.linear(input, forward_weight, forward_bias)
        # Drop the local references to the dequantized copies.
        if scheme.is_quantized(weight.pname):
            del forward_weight
        if bias is not None and scheme.is_quantized(bias.pname):
            del forward_bias
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return the input gradient and accumulate parameter gradients.

        Returns:
            A 5-tuple matching ``forward``'s arguments: ``grad_input`` (same
            shape as ``input``) followed by four ``None`` values.
        """
        input, weight, bias, scheme = ctx.saved
        weight_requires_grad, bias_requires_grad, weight_pname, bias_pname = ctx.other_args
        optimizer = ctx.optimizer

        del ctx.saved

        C_in = input.shape[-1]
        C_out = grad_output.shape[-1]

        # reshape (rather than view) also handles non-contiguous tensors.
        grad_output_flatten = grad_output.reshape(-1, C_out)
        input_flatten = input.reshape(-1, C_in)

        # input gradient
        weight = scheme.dequantize_data(weight, weight_pname)
        with autocast(enabled=False):
            grad_input = grad_output_flatten.float().mm(weight)
        if scheme.is_quantized(weight_pname):
            del weight
        # Autograd requires the gradient to have the input's shape, which may
        # have leading dimensions beyond (N, C_in).
        grad_input = grad_input.reshape(input.shape)

        # weight gradient
        # NOTE(review): unlike grad_input, the operands here are not cast to
        # float; confirm mixed-dtype inputs cannot reach this matmul.
        if weight_requires_grad:
            grad_weight = grad_output_flatten.t().mm(input_flatten)
            optimizer.accumulate_gradient(weight_pname, grad_weight)

        # bias gradient
        if bias is not None and bias_requires_grad:
            grad_bias = grad_output_flatten.sum(0)
            optimizer.accumulate_gradient(bias_pname, grad_bias)

        return grad_input, None, None, None, None


class convnd(Function):
    """Shared forward/backward helpers for the 1-D, 2-D and 3-D convolution functions.

    Parameter gradients are handed to ``optimizer.accumulate_gradient``; only
    the input gradient is returned to autograd.
    """

    @staticmethod
    def run_forward(n, forward_op, ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, scheme: Optional[QScheme] = None, optimizer: Optional[LowBitOptimizer] = None):
        """Run ``forward_op`` on ``input`` with dequantized weight and bias.

        When ``config.training`` is set, the stored parameters, the scheme,
        the optimizer and the convolution hyper-parameters are recorded on
        ``ctx`` for ``run_backward``. ``n`` is accepted but not used here.
        """
        has_bias = bias is not None
        w_full = get_forward_param(weight, scheme)
        b_full = get_forward_param(bias, scheme)

        if config.training:
            ctx.saved = (input, weight, bias, scheme)
            ctx.optimizer = optimizer
            ctx.other_args = (
                w_full.requires_grad,
                b_full.requires_grad if has_bias else None,
                weight.pname,
                bias.pname if has_bias else None,
                stride,
                padding,
                dilation,
                groups,
            )

        result = forward_op(input, w_full, b_full, stride, padding, dilation, groups)

        # Drop the local references to the dequantized copies.
        if scheme.is_quantized(weight.pname):
            del w_full
        if has_bias and scheme.is_quantized(bias.pname):
            del b_full

        return result

    @staticmethod
    def run_backward(n, ctx, grad_output, bias_reduce_dims, aug):
        """Compute the input gradient and accumulate parameter gradients.

        ``aug`` normalizes ``padding`` / ``stride`` / ``dilation`` before they
        are passed to the extension; ``bias_reduce_dims`` lists the dimensions
        summed over to obtain the bias gradient. ``n`` is accepted but unused.

        Returns:
            A 9-tuple: ``grad_input`` followed by eight ``None`` values.
        """
        (w_needs_grad, b_needs_grad, weight_pname, bias_pname,
         stride, padding, dilation, groups) = ctx.other_args
        padding = aug(padding)
        stride = aug(stride)
        dilation = aug(dilation)

        input, weight, bias, scheme = ctx.saved
        optimizer = ctx.optimizer
        del ctx.saved

        weight = scheme.dequantize_data(weight, weight_pname)

        output_mask = [ctx.needs_input_grad[0], w_needs_grad]
        with autocast(enabled=False):
            input = input.float()
            grad_output = grad_output.float()
            # NOTE(review): the three literal flags are presumably
            # benchmark / deterministic / allow_tf32 -- confirm against the
            # extension's signature.
            grad_input, grad_weight = ext_backward_func.cudnn_convolution_backward(
                input, grad_output, weight, padding, stride, dilation, groups,
                True, False, False,
                output_mask)

        # Release the dequantized weight before accumulating gradients.
        if scheme.is_quantized(weight_pname):
            del weight

        if w_needs_grad:
            optimizer.accumulate_gradient(weight_pname, grad_weight)

        if bias is not None and b_needs_grad:
            optimizer.accumulate_gradient(bias_pname, grad_output.sum(bias_reduce_dims))

        return (grad_input,) + (None,) * 8


class conv1d(Function):
    """1-D convolution; delegates to ``convnd`` with ``F.conv1d``."""

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, scheme=None, optimizer=None):
        """Forward pass via ``convnd.run_forward``."""
        return convnd.run_forward(
            n=1, forward_op=F.conv1d, ctx=ctx,
            input=input, weight=weight, bias=bias,
            stride=stride, padding=padding, dilation=dilation, groups=groups,
            scheme=scheme, optimizer=optimizer,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass via ``convnd.run_backward``; bias gradient sums over dims 0 and 2."""
        return convnd.run_backward(
            n=1, ctx=ctx, grad_output=grad_output,
            bias_reduce_dims=[0, 2], aug=_single,
        )


class conv2d(Function):
    """2-D convolution; delegates to ``convnd`` with ``F.conv2d``."""

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, scheme=None, optimizer=None):
        """Forward pass via ``convnd.run_forward``."""
        return convnd.run_forward(
            n=2, forward_op=F.conv2d, ctx=ctx,
            input=input, weight=weight, bias=bias,
            stride=stride, padding=padding, dilation=dilation, groups=groups,
            scheme=scheme, optimizer=optimizer,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass via ``convnd.run_backward``; bias gradient sums over dims 0, 2 and 3."""
        return convnd.run_backward(
            n=2, ctx=ctx, grad_output=grad_output,
            bias_reduce_dims=[0, 2, 3], aug=_pair,
        )


class conv3d(Function):
    """3-D convolution; delegates to ``convnd`` with ``F.conv3d``."""

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, scheme=None, optimizer=None):
        """Forward pass via ``convnd.run_forward``."""
        return convnd.run_forward(
            n=3, forward_op=F.conv3d, ctx=ctx,
            input=input, weight=weight, bias=bias,
            stride=stride, padding=padding, dilation=dilation, groups=groups,
            scheme=scheme, optimizer=optimizer,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass via ``convnd.run_backward``; bias gradient sums over dims 0, 2, 3 and 4."""
        return convnd.run_backward(
            n=3, ctx=ctx, grad_output=grad_output,
            bias_reduce_dims=[0, 2, 3, 4], aug=_triple,
        )


class conv_transposend(Function):
    """Shared forward/backward helpers for the 1-D, 2-D and 3-D transposed convolutions.

    Parameter gradients are handed to ``optimizer.accumulate_gradient``; only
    the input gradient is returned to autograd.
    """

    @staticmethod
    def run_forward(n, forward_op, ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, scheme: Optional[QScheme] = None, optimizer: Optional[LowBitOptimizer] = None):
        """Run ``forward_op`` on ``input`` with dequantized weight and bias.

        When ``config.training`` is set, the stored parameters, the scheme,
        the optimizer and the convolution hyper-parameters are recorded on
        ``ctx`` for ``run_backward``. ``bias`` may be ``None``. ``n`` is
        accepted but not used here.
        """
        forward_weight = get_forward_param(weight, scheme)
        forward_bias = get_forward_param(bias, scheme)

        if config.training:
            ctx.saved = input, weight, bias, scheme
            ctx.optimizer = optimizer
            weight_requires_grad = forward_weight.requires_grad
            bias_requires_grad = forward_bias.requires_grad if (bias is not None) else None
            weight_pname = weight.pname
            bias_pname = bias.pname if bias is not None else None
            ctx.other_args = weight_requires_grad, bias_requires_grad, weight_pname, bias_pname, stride, padding, output_padding, dilation, groups

        out = forward_op(input, forward_weight, forward_bias, stride, padding, output_padding, groups, dilation)
        # Drop the local references to the dequantized copies.
        if scheme.is_quantized(weight.pname):
            del forward_weight
        # bias is optional: guard before touching bias.pname.
        if bias is not None and scheme.is_quantized(bias.pname):
            del forward_bias

        return out

    @staticmethod
    def run_backward(n, ctx, grad_output, bias_reduce_dims, aug):
        """Compute the input gradient and accumulate parameter gradients.

        ``aug`` normalizes ``padding`` / ``output_padding`` / ``stride`` /
        ``dilation`` before they are passed to the extension;
        ``bias_reduce_dims`` lists the dimensions summed over to obtain the
        bias gradient. ``n`` is accepted but unused.

        Returns:
            A 10-tuple: ``grad_input`` followed by nine ``None`` values.
        """
        weight_requires_grad, bias_requires_grad, weight_pname, bias_pname, stride, padding, output_padding, dilation, groups = ctx.other_args
        padding = aug(padding)
        output_padding = aug(output_padding)
        stride = aug(stride)
        dilation = aug(dilation)

        input, weight, bias, scheme = ctx.saved
        optimizer = ctx.optimizer
        del ctx.saved

        weight = scheme.dequantize_data(weight, weight_pname)
        with autocast(enabled=False):
            input = input.float()
            grad_output = grad_output.float()
            grad_input, grad_weight = ext_backward_func.cudnn_convolution_transpose_backward(
                input, grad_output, weight, padding, output_padding, stride, dilation, groups,
                config.cudnn_benchmark_conv2d, False, False, [ctx.needs_input_grad[0], weight_requires_grad])
        # Release the dequantized weight before accumulating gradients.
        if scheme.is_quantized(weight_pname):
            del weight
        if weight_requires_grad:
            optimizer.accumulate_gradient(weight_pname, grad_weight)

        if bias is not None and bias_requires_grad:
            grad_bias = grad_output.sum(bias_reduce_dims)
            optimizer.accumulate_gradient(bias_pname, grad_bias)

        return grad_input, None, None, None, None, None, None, None, None, None


class conv_transpose1d(Function):
    """1-D transposed convolution; delegates to ``conv_transposend`` with ``F.conv_transpose1d``."""

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, scheme=None, optimizer=None):
        """Forward pass via ``conv_transposend.run_forward``."""
        return conv_transposend.run_forward(
            n=1, forward_op=F.conv_transpose1d, ctx=ctx,
            input=input, weight=weight, bias=bias,
            stride=stride, padding=padding, output_padding=output_padding,
            groups=groups, dilation=dilation,
            scheme=scheme, optimizer=optimizer,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass via ``conv_transposend.run_backward``; bias gradient sums over dims 0 and 2."""
        return conv_transposend.run_backward(
            n=1, ctx=ctx, grad_output=grad_output,
            bias_reduce_dims=[0, 2], aug=_single,
        )


class conv_transpose2d(Function):
    """2-D transposed convolution; delegates to ``conv_transposend`` with ``F.conv_transpose2d``."""

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, scheme=None, optimizer=None):
        """Forward pass via ``conv_transposend.run_forward``."""
        return conv_transposend.run_forward(
            n=2, forward_op=F.conv_transpose2d, ctx=ctx,
            input=input, weight=weight, bias=bias,
            stride=stride, padding=padding, output_padding=output_padding,
            groups=groups, dilation=dilation,
            scheme=scheme, optimizer=optimizer,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass via ``conv_transposend.run_backward``; bias gradient sums over dims 0, 2 and 3."""
        return conv_transposend.run_backward(
            n=2, ctx=ctx, grad_output=grad_output,
            bias_reduce_dims=[0, 2, 3], aug=_pair,
        )


class conv_transpose3d(Function):
    """3-D transposed convolution; delegates to ``conv_transposend`` with ``F.conv_transpose3d``."""

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, scheme=None, optimizer=None):
        """Forward pass via ``conv_transposend.run_forward``."""
        return conv_transposend.run_forward(
            n=3, forward_op=F.conv_transpose3d, ctx=ctx,
            input=input, weight=weight, bias=bias,
            stride=stride, padding=padding, output_padding=output_padding,
            groups=groups, dilation=dilation,
            scheme=scheme, optimizer=optimizer,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass via ``conv_transposend.run_backward``; bias gradient sums over dims 0, 2, 3 and 4."""
        return conv_transposend.run_backward(
            n=3, ctx=ctx, grad_output=grad_output,
            bias_reduce_dims=[0, 2, 3, 4], aug=_triple,
        )

