# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

import os
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, InplaceFunction
from torch.cuda import amp

from .Qconfig import qconfig
from .QFunction import *
from .utils import *


class QLinear(nn.Linear):
    """``nn.Linear`` whose training-time computation is routed through ``QuantLinear``.

    The layer receives an already block-quantized input together with its
    per-block scales (``Qinput``, ``Iscale``) and, in training mode, returns
    whatever ``QuantLinear.apply`` returns. In eval mode it falls back to a
    plain ``F.linear`` on ``Qinput``.

    Args:
        in_features: Size of each input sample (forwarded to ``nn.Linear``).
        out_features: Size of each output sample (forwarded to ``nn.Linear``).
        bias: Whether the layer owns an additive bias.
        args: Quantization configuration object. It is deep-copied, and the
            copy is extended with per-tensor block sizes
            (``row_blocksize_{fa,fw,fo,ba,bw,bo}`` and the ``col_`` variants).
        layer_type: Key into ``qconfig.qlinear_config``; must be non-empty.

    Attributes set here:
        apply_quantize_fw / _fo / _bw / _ba: Per-tensor quantization switches
            handed to ``QuantLinear``. They all start from the same
            ``apply_quantize`` value and may be cleared by
            ``refine_rowcol_blocksize``.
        layer_name: Name passed to ``QuantLinear``. Defaults to ``layer_type``
            and may be overwritten by the caller after construction.
    """

    def __init__(self, in_features, out_features, bias=True, args=None, layer_type=""):
        super().__init__(in_features, out_features, bias)
        # Private copy: refine_rowcol_blocksize() writes new attributes onto it.
        self.args = deepcopy(args)
        self.layer_type = layer_type
        # forward() hands self.layer_name to QuantLinear. Give it a usable
        # default so the module does not raise AttributeError when no caller
        # assigns a more specific name after construction.
        self.layer_name = layer_type
        assert layer_type != "", "layer_type is not defined"
        assert layer_type in qconfig.qlinear_config, f"{layer_type} not in qlinear_config"

        self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qlinear_config[layer_type])
        # All four switches start identical; the refinement below may turn some off.
        self.apply_quantize_fw = self.apply_quantize
        self.apply_quantize_fo = self.apply_quantize
        self.apply_quantize_bw = self.apply_quantize
        self.apply_quantize_ba = self.apply_quantize

        self.refine_rowcol_blocksize()

        # NOTE(review): `Ubit` is not assigned anywhere in this class, so the
        # fallback only works if a subclass or caller provides it - confirm.
        self.fbit = self.args.fwbit if self.args.fwbit else self.Ubit
        self.bbit = self.args.bwbit if self.args.bwbit else self.Ubit
        quantize_flag = format_string_with_condition(
            layer_type,
            {
                "apply-fw": self.apply_quantize_fw,
                "apply-fo": self.apply_quantize_fo,
                "apply-bw": self.apply_quantize_bw,
                "apply-ba": self.apply_quantize_ba,
            },
            self.args.symm,
            self.fbit,
            self.bbit,
            {
                "row-fa": self.args.row_blocksize_fa,
                "col-fa": self.args.col_blocksize_fa,
                "row-fw": self.args.row_blocksize_fw,
                "col-fw": self.args.col_blocksize_fw,
                "row-fo": self.args.row_blocksize_fo,
                "col-fo": self.args.col_blocksize_fo,
                "row-ba": self.args.row_blocksize_ba,
                "col-ba": self.args.col_blocksize_ba,
                "row-bw": self.args.row_blocksize_bw,
                "col-bw": self.args.col_blocksize_bw,
                "row-bo": self.args.row_blocksize_bo,
                "col-bo": self.args.col_blocksize_bo,
            },
        )
        # Only print the configuration summary once per job.
        if quant_get_local_rank() == 0:
            print(quantize_flag)

    def refine_rowcol_blocksize(self):
        """Populate the per-tensor block sizes on ``self.args``.

        Every ``row_blocksize_X`` / ``col_blocksize_X`` (X in fa, fw, fo, ba,
        bw, bo) starts from the global ``row_blocksize`` / ``col_blocksize``.
        When ``refine_attn_blocksize`` or ``refine_mlp_blocksize`` is set, the
        matching layer types get the ``refine_*_blocksize`` values for some
        tensors, and attention layers additionally switch off one of the
        ``apply_quantize_*`` flags.
        """
        args = self.args
        for tag in ("fa", "fw", "fo", "ba", "bw", "bo"):
            setattr(args, f"row_blocksize_{tag}", args.row_blocksize)
            setattr(args, f"col_blocksize_{tag}", args.col_blocksize)

        if args.refine_attn_blocksize:
            if self.layer_type in ["attn_q", "attn_k", "attn_v"]:
                # q/k/v: forward output is not quantized here, and the
                # gradient w.r.t. the input uses the refined block size.
                self.apply_quantize_fo = False
                args.row_blocksize_ba = args.refine_row_blocksize
                args.col_blocksize_ba = args.refine_col_blocksize
            if self.layer_type in ["attn_proj"]:
                # proj: mirror image of the q/k/v case.
                self.apply_quantize_ba = False
                args.row_blocksize_fo = args.refine_row_blocksize
                args.col_blocksize_fo = args.refine_col_blocksize

        if args.refine_mlp_blocksize:
            if self.layer_type in ["mlp_gate", "mlp_up", "mlp_down"]:
                args.row_blocksize_fo = args.refine_row_blocksize
                args.col_blocksize_fo = args.refine_col_blocksize
                args.row_blocksize_ba = args.refine_row_blocksize
                args.col_blocksize_ba = args.refine_col_blocksize

    def forward(self, Qinput, Iscale):
        """Run the layer.

        Args:
            Qinput: Input tensor (quantized representation in training mode).
            Iscale: Per-block scales belonging to ``Qinput``. Unused in eval mode.

        Returns:
            Training mode: the result of ``QuantLinear.apply``.
            Eval mode: ``(F.linear(Qinput, weight, bias), None)``; ``Iscale``
            is not applied on this path.
        """
        if not self.training:
            return F.linear(Qinput, self.weight, self.bias), None

        return QuantLinear.apply(
            Qinput,
            Iscale,
            self.weight,
            self.bias,
            self.args,
            self.layer_name,
            self.apply_quantize_fw,
            self.apply_quantize_fo,
            self.apply_quantize_bw,
            self.apply_quantize_ba,
        )


# class QuantLinear(Function):
#     @staticmethod
#     def forward(ctx, input, weight, bias, args, layer_type):
#         ctx.saved = input, weight, bias, args, layer_type
#         return F.linear(input, weight, bias)
#
#     @staticmethod
#     def backward(ctx, grad_output):
#         input, weight, bias, args, layer_type = ctx.saved
#
#         C_in = input.shape[-1]
#         C_out = grad_output.shape[-1]
#
#         grad_output_flatten = grad_output.reshape(-1, C_out)
#         input_flatten = input.reshape(-1, C_in)
#
#         if grad_output_flatten.dtype == input_flatten.dtype:
#             grad_weight = grad_output_flatten.t().mm(input_flatten)
#         else:
#             grad_weight = grad_output_flatten.float().t().mm(input_flatten)
#
#         if grad_output_flatten.dtype == weight.dtype:
#             grad_input = grad_output_flatten.mm(weight)
#         else:
#             grad_input = grad_output_flatten.float().mm(weight)
#
#         if bias is not None:
#             grad_bias = grad_output_flatten.sum(0)
#         else:
#             grad_bias = None
#
#         grad_input_transform = grad_input.reshape(input.size())
#
#         return grad_input_transform, grad_weight, grad_bias, None, None

# B%% = block_cut(%%, args.row_blocksize, args.col_blocksize)
# RQ%%, Q%%, Wscale = block_quant(B%%, args.symm, args.fwbit, stochastic=False, epsilon=args.epsilon)
# Q%% = block_reshape(Q%%, %%, args.row_blocksize, args.col_blocksize)
# RQ%% = block_reshape(RQ%%, %%, args.row_blocksize, args.col_blocksize)


class QuantLinear(Function):
    """Custom autograd function for a linear layer with block-wise quantization.

    Forward consumes a blocked input plus its scales, quantizes the weight and
    the linear output block-wise, and returns ``(Qfc_output, Oscale)``.
    Backward consumes ``(Qgrad_output, Gscale)`` in the same format and
    produces a quantized input gradient with scales plus a weight gradient.

    Naming used throughout (as returned by ``block_quant``): the ``RQ*``
    tensor is the first return value, ``Q*`` the second and ``*scale`` the
    third. NOTE(review): presumably ``Q*`` is the quantized payload and
    ``RQ*`` its rescaled (dequantized) counterpart - confirm in ``block_quant``.
    """

    @staticmethod
    def _mm_common_dtype(lhs, rhs):
        """Return ``lhs.mm(rhs)``, promoting both operands to float32 on mismatch.

        ``torch.mm`` does not type-promote, so casting only one side is not
        enough unless the other side already happens to be float32.
        """
        if lhs.dtype == rhs.dtype:
            return lhs.mm(rhs)
        return lhs.float().mm(rhs.float())

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.bfloat16)
    def forward(
        ctx,
        Qinput,
        Iscale,
        weight,
        bias,
        args,
        layer_name,
        apply_quantize_fw=True,
        apply_quantize_fo=True,
        apply_quantize_bw=True,
        apply_quantize_ba=True,
    ):
        """Compute the quantized linear output.

        Args:
            ctx: Autograd context; receives ``saved`` and ``apply_quantize``.
            Qinput: Input tensor in blocked-quantized form.
            Iscale: Scales for ``Qinput``. Its leading dimension must equal
                ``Qinput.numel() / (min_blockunit_row * min_blockunit_col)``.
            weight: Layer weight.
            bias: Layer bias or ``None``.
            args: Configuration object carrying bit widths, block sizes and flags.
            layer_name: Prefix for the names handed to ``block_quant`` / ``save_tensor``.
            apply_quantize_fw / _fo / _bw / _ba: Quantization switches for the
                forward weight, forward output, backward weight gradient and
                backward activation gradient respectively.

        Returns:
            ``(Qfc_output, Oscale)`` where ``Oscale`` is zero-padded along its
            first dimension up to the "ideal" number of scales.
        """
        # Shrink Iscale so only the scales that belong to the configured
        # block size are used; the tensor arrives padded to the ideal size.
        ideal_scale_num = Qinput.numel() / (args.min_blockunit_row * args.min_blockunit_col)
        actual_scale_num = calculate_scale_num(Qinput, args.row_blocksize_fa, args.col_blocksize_fa)
        assert Iscale.shape[0] == ideal_scale_num
        Iscale = Iscale[: int(actual_scale_num), :, :]

        # Rescale the blocked input and restore the original layout.
        Binput = block_cut(Qinput, args.row_blocksize_fa, args.col_blocksize_fa)
        RQinput = Binput * Iscale
        RQinput = block_reshape(RQinput, Qinput, args.row_blocksize_fa, args.col_blocksize_fa)

        # Block-wise weight quantization (deterministic rounding in forward).
        Bweight = block_cut(weight, args.row_blocksize_fw, args.col_blocksize_fw)
        RQweight, Qweight, Wscale = block_quant(
            Bweight,
            args.symm,
            args.fwbit,
            stochastic=False,
            epsilon=args.epsilon,
            apply_quantize=apply_quantize_fw,
            layer_name=layer_name + "WeightQuant",
        )
        Qweight = block_reshape(Qweight, weight, args.row_blocksize_fw, args.col_blocksize_fw)
        RQweight = block_reshape(RQweight, weight, args.row_blocksize_fw, args.col_blocksize_fw)

        if args.draw_distribution_forward:
            save_tensor(weight, Qweight, RQweight, fb="forward", aw="Weight", layer_name=layer_name)

        # Stash what backward needs: the quantized tensors and their scales,
        # not the rescaled copies.
        ctx.saved = Qinput, Iscale, Qweight, Wscale, bias, args, layer_name
        ctx.apply_quantize = apply_quantize_fw, apply_quantize_fo, apply_quantize_bw, apply_quantize_ba
        fc_output = F.linear(RQinput, RQweight, bias)

        # Block-wise quantization of the linear output.
        Bfc_output = block_cut(fc_output, args.row_blocksize_fo, args.col_blocksize_fo)
        RQfc_output, Qfc_output, Oscale = block_quant(
            Bfc_output,
            args.symm,
            args.fabit,
            stochastic=False,
            epsilon=args.epsilon,
            apply_quantize=apply_quantize_fo,
            layer_name=layer_name + "LinearOutput",
        )
        RQfc_output = block_reshape(RQfc_output, fc_output, args.row_blocksize_fo, args.col_blocksize_fo)
        Qfc_output = block_reshape(Qfc_output, fc_output, args.row_blocksize_fo, args.col_blocksize_fo)

        if args.draw_distribution_forward:
            save_tensor(fc_output, Qfc_output, RQfc_output, fb="forward", aw="Output", layer_name=layer_name)

        # Pad Oscale up to the ideal size so the scale tensor has the same
        # shape in forward and backward regardless of the block size in use.
        ideal_scale_num = Qfc_output.numel() / (args.min_blockunit_row * args.min_blockunit_col)
        actual_scale_num = calculate_scale_num(Qfc_output, args.row_blocksize_fo, args.col_blocksize_fo)
        assert Oscale.shape[0] == actual_scale_num
        Oscale = torch.nn.functional.pad(Oscale, (0, 0, 0, 0, 0, int(ideal_scale_num - actual_scale_num)))

        return Qfc_output, Oscale

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, Qgrad_output, Gscale):
        """Compute gradients for the quantized linear layer.

        Args:
            ctx: Autograd context populated by ``forward``.
            Qgrad_output: Gradient w.r.t. ``Qfc_output`` in blocked-quantized form.
            Gscale: Scales for ``Qgrad_output``, padded to the ideal size.

        Returns:
            A 10-tuple matching the inputs of ``forward``:
            ``(Qgrad_input, GIscale, RQgrad_weight, grad_bias, None * 6)``.
            ``grad_bias`` is ``None`` when the layer has no bias.
        """
        Qinput, Iscale, Qweight, Wscale, bias, args, layer_name = ctx.saved
        apply_quantize_fw, apply_quantize_fo, apply_quantize_bw, apply_quantize_ba = ctx.apply_quantize

        # Shrink Gscale to the scales that belong to the configured block size.
        ideal_scale_num = Qgrad_output.numel() / (args.min_blockunit_row * args.min_blockunit_col)
        actual_scale_num = calculate_scale_num(Qgrad_output, args.row_blocksize_bo, args.col_blocksize_bo)
        assert Gscale.shape[0] == ideal_scale_num
        Gscale = Gscale[: int(actual_scale_num), :, :]

        # Rescale the incoming gradient.
        Bgrad_output = block_cut(Qgrad_output, args.row_blocksize_bo, args.col_blocksize_bo)
        RQgrad_output = Bgrad_output * Gscale
        grad_output = block_reshape(RQgrad_output, Qgrad_output, args.row_blocksize_bo, args.col_blocksize_bo)

        if args.draw_distribution_backward:
            save_tensor(
                grad_output, Qgrad_output, RQgrad_output, fb="backward in", aw="Activation", layer_name=layer_name
            )

        C_in = Qinput.shape[-1]
        C_out = Qgrad_output.shape[-1]

        # Rebuild the rescaled forward input from its quantized form and scales.
        Binput = block_cut(Qinput, args.row_blocksize_fa, args.col_blocksize_fa)
        input = Binput * Iscale
        input = block_reshape(input, Qinput, args.row_blocksize_fa, args.col_blocksize_fa)

        grad_output_flatten = grad_output.reshape(-1, C_out)
        input_flatten = input.reshape(-1, C_in)

        # Weight gradient: grad_output^T @ input.
        grad_weight = QuantLinear._mm_common_dtype(grad_output_flatten.t(), input_flatten)

        # Stochastic block-wise quantization of the weight gradient. This call
        # must stay ahead of the activation-gradient quantization so the order
        # of the two stochastic block_quant calls is unchanged.
        Bgrad_weight = block_cut(grad_weight, args.row_blocksize_bw, args.col_blocksize_bw)
        RQgrad_weight, Qgrad_weight, GWscale = block_quant(
            Bgrad_weight,
            args.symm,
            args.bwbit,
            stochastic=True,
            epsilon=args.epsilon,
            apply_quantize=apply_quantize_bw,
            layer_name=layer_name + "WeightGradient",
        )
        Qgrad_weight = block_reshape(Qgrad_weight, grad_weight, args.row_blocksize_bw, args.col_blocksize_bw)
        RQgrad_weight = block_reshape(RQgrad_weight, grad_weight, args.row_blocksize_bw, args.col_blocksize_bw)

        if args.draw_distribution_backward:
            save_tensor(grad_weight, Qgrad_weight, RQgrad_weight, fb="backward", aw="Weight", layer_name=layer_name)

        # Input (activation) gradient: rebuild the rescaled weight, then
        # grad_output @ weight.
        Bweight = block_cut(Qweight, args.row_blocksize_fw, args.col_blocksize_fw)
        weight = Bweight * Wscale
        weight = block_reshape(weight, Qweight, args.row_blocksize_fw, args.col_blocksize_fw)

        # The dtype check is made against the tensor that is actually
        # multiplied (the rescaled weight), not against Qweight.
        grad_input = QuantLinear._mm_common_dtype(grad_output_flatten, weight)

        # Stochastic block-wise quantization of the activation gradient.
        Bgrad_input = block_cut(grad_input, args.row_blocksize_ba, args.col_blocksize_ba)
        RQgrad_input, Qgrad_input, GIscale = block_quant(
            Bgrad_input,
            args.symm,
            args.babit,
            stochastic=True,
            epsilon=args.epsilon,
            apply_quantize=apply_quantize_ba,
            layer_name=layer_name + "ActivationGradient",
        )
        Qgrad_input = block_reshape(Qgrad_input, grad_input, args.row_blocksize_ba, args.col_blocksize_ba)
        RQgrad_input = block_reshape(RQgrad_input, grad_input, args.row_blocksize_ba, args.col_blocksize_ba)

        if args.draw_distribution_backward:
            save_tensor(
                grad_input, Qgrad_input, RQgrad_input, fb="backward out", aw="Activation out", layer_name=layer_name
            )

        # Pad GIscale up to the ideal size so the scale gradient has the same
        # shape as the Iscale tensor that forward received.
        ideal_scale_num = Qgrad_input.numel() / (args.min_blockunit_row * args.min_blockunit_col)
        actual_scale_num = calculate_scale_num(Qgrad_input, args.row_blocksize_ba, args.col_blocksize_ba)
        assert GIscale.shape[0] == actual_scale_num
        GIscale = torch.nn.functional.pad(GIscale, (0, 0, 0, 0, 0, int(ideal_scale_num - actual_scale_num)))

        # grad_input was computed on the flattened view; restore Qinput's shape.
        Qgrad_input_transform = Qgrad_input.reshape(Qinput.size())

        if bias is not None:
            grad_bias = grad_output_flatten.sum(0)
        else:
            grad_bias = None

        # One slot per forward() argument; non-tensor arguments get None.
        return Qgrad_input_transform, GIscale, RQgrad_weight, grad_bias, None, None, None, None, None, None
