import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union
import tqdm
import numpy as np
import pdb
import math
import logging
from mx.elemwise_ops import quantize_elemwise_op
from mx.mx_ops import quantize_mx_op
from mx.specs import MxSpecs

# Small positive constant; it is not referenced anywhere in the visible part
# of this file.
# NOTE(review): presumably a lower clamp bound for quantization scales kept
# for other modules that import it -- confirm before removing.
CLIPMIN = 1e-5


class UniformAffineQuantizer(nn.Module):
    """Fake-quantize tensors by delegating to the ``mx`` quantization ops.

    The module builds an ``MxSpecs`` object from a command-line style ``args``
    namespace (see :meth:`configure`) and, on ``forward``, returns a tensor
    whose forward value is the quantized input while gradients flow to the
    input unchanged (``(qx - x).detach() + x``).

    Optionally (``lwc=True``) two learnable per-row factors are created and,
    on the per-token path, squashed through a sigmoid to shrink the row-wise
    max/min used to clamp the input before quantization.

    Attributes set but not read by any method of this class: ``metric``,
    ``cached_xmin``, ``cached_xmax`` and ``enable``.
    NOTE(review): ``enable`` is never consulted in ``forward``; if callers
    toggle it expecting quantization to be bypassed, that is not what happens
    here -- confirm the intended behavior.
    """

    def __init__(
        self,
        metric="minmax",
        dynamic_method="per_cluster",
        shape=None,
        lwc=False,
        args=None,
        mtype='linear',
        btype='w',
    ):
        """Create the quantizer and, if requested, its clipping parameters.

        Args:
            metric: Stored on the instance; not used by this class.
            dynamic_method: ``"per_token"`` or ``"per_channel"`` select
                :meth:`per_token_dynamic_calibration` in ``forward``; any
                other value selects :meth:`fake_quant`.
            shape: Shape of the tensor to be quantized. Only read when
                ``lwc`` is true, where ``shape[0]``, ``shape[1]`` and
                ``shape[-1]`` are used to size the clipping parameters.
            lwc: Enable learnable clipping factors.
            args: Namespace-like object read by :meth:`configure`; must
                provide ``per_tensor``.
            mtype: Module-type suffix used to look up per-module settings in
                ``args`` (e.g. ``block_size_<mtype>``).
            btype: Prefix selecting which element format / scale mode keys
                are read and used (e.g. ``<btype>_elem_format``).
        """
        super().__init__()
        self.mx_specs = MxSpecs()
        self.metric = metric

        self.cached_xmin = None
        self.cached_xmax = None
        self.dynamic_method = dynamic_method
        self.lwc = lwc
        self.mtype = mtype
        self.btype = btype

        # Initial value of the learnable clipping factors (pre-sigmoid).
        init_value = 4.

        self.enable = True
        self.configure(args, mtype=mtype, btype=btype)
        if lwc:
            block_size = self._block_size()
            if block_size > 0:
                assert shape[-1] % block_size == 0, "group_size must be a factor of the last dimension"
                # One factor per (row, block) pair, matching the row count of
                # ``x.reshape(-1, block_size)`` for a 2-D tensor of ``shape``.
                dim1 = int(shape[0] * (shape[1] // block_size))
            else:
                dim1 = shape[0]
            self.upbound_factor = nn.Parameter(torch.ones((dim1, 1)) * init_value)
            self.lowbound_factor = nn.Parameter(torch.ones((dim1, 1)) * init_value)
        self.sigmoid = nn.Sigmoid()

    def _block_size(self):
        """Return the configured block size, mapping an unset value to 0.

        ``configure`` stores ``None`` when no block size is supplied; a bare
        ``None > 0`` comparison would raise ``TypeError``, so normalise here.
        """
        return self.mx_specs['block_size'] or 0

    @staticmethod
    def _read_arg(args, arg_name, spec_key, mtype):
        """Look up ``arg_name`` among the instance attributes of ``args``.

        Returns the stored value, or ``None`` when the attribute is absent
        (an info message is logged in that case) or equals the string
        ``'none'``.
        """
        try:
            val = vars(args)[arg_name]
        except (KeyError, TypeError):
            # KeyError: attribute not present; TypeError: ``args`` has no
            # ``__dict__``. Both fall back to "unset" as before.
            logging.info(f'[{mtype}] Set {spec_key} to None')
            return None
        return val if val != 'none' else None

    def configure(self, args, mtype='linear', btype='w'):
        """Populate ``self.mx_specs`` from ``args``.

        Reads ``scale_bits_<mtype>``, ``block_size_<mtype>``,
        ``double_quant_<mtype>``, ``<btype>_elem_format_<mtype>`` and
        ``<btype>_scale_mode``; each missing entry (or the string ``'none'``)
        is stored as ``None``. An empty ``mtype`` always forces the element
        format to ``None``. ``custom_cuda`` is always set to ``True``.

        Raises:
            AssertionError: if ``args.per_tensor`` is truthy.
        """
        assert not args.per_tensor, "Omniquant does not support per_tensor quantization"

        for key in ('scale_bits', 'block_size', 'double_quant'):
            self.mx_specs[key] = self._read_arg(args, f'{key}_{mtype}', key, mtype)

        elem_key = f'{btype}_elem_format'
        self.mx_specs[elem_key] = self._read_arg(
            args, f'{elem_key}_{mtype}', elem_key, mtype
        )
        if mtype == "":
            self.mx_specs[elem_key] = None

        # Unlike the keys above, the scale mode is not suffixed with mtype.
        scale_key = f'{btype}_scale_mode'
        self.mx_specs[scale_key] = self._read_arg(args, scale_key, scale_key, mtype)

        self.mx_specs['per_tensor'] = args.per_tensor
        self.mx_specs['custom_cuda'] = True

    def _quantize(self, x, dtype):
        """Run ``x`` through the mx element-wise and MX ops without autograd.

        The input is converted to float32 for the ops and the result is cast
        to ``dtype`` before being returned.
        """
        with torch.no_grad():
            bf_x = quantize_elemwise_op(
                x.float(), mx_specs=self.mx_specs, round=self.mx_specs["round_output"]
            )
            qx = quantize_mx_op(
                bf_x,
                self.mx_specs,
                elem_format=self.mx_specs[f'{self.btype}_elem_format'],
                scale_mode=self.mx_specs[f'{self.btype}_scale_mode'],
                axes=[-1],
                round=self.mx_specs["round_mx_output"],
            )
            return qx.to(dtype)

    def fake_quant(self, x):
        """Return the quantized value of ``x`` with an identity gradient to ``x``."""
        qx = self._quantize(x, x.dtype)
        # Forward value is qx; the detached difference contributes no gradient.
        return (qx - x).detach() + x

    def forward(self, x: torch.Tensor):
        """Quantize ``x`` according to ``dynamic_method``.

        When both ``mtype`` and ``btype`` are empty strings the input is
        returned untouched.
        """
        if self.mtype == "" and self.btype == "":
            # No quantization configured: pass the input through.
            return x
        if self.dynamic_method in ("per_token", "per_channel"):
            return self.per_token_dynamic_calibration(x)
        return self.fake_quant(x)

    def per_token_dynamic_calibration(self, x):
        """Clamp ``x`` row-wise (optionally with learned factors), then quantize.

        Rows are the last dimension of ``x``, or blocks of ``block_size``
        consecutive elements when a positive block size is configured. With
        ``lwc`` disabled the clamp bounds are the rows' own min/max. The
        original shape is stored in ``self.input_shape`` and restored before
        quantization.
        """
        dtype = x.dtype
        self.input_shape = x.shape
        block_size = self._block_size()
        if block_size > 0:
            x = x.reshape(-1, block_size)
        xmin = x.amin(dim=-1, keepdim=True)
        xmax = x.amax(dim=-1, keepdim=True)
        if self.lwc:
            # sigmoid(...) lies in (0, 1), so the bounds can only shrink
            # towards zero; gradients reach the factors through the clamp.
            xmax = self.sigmoid(self.upbound_factor) * xmax
            xmin = self.sigmoid(self.lowbound_factor) * xmin
        x = x.clamp(xmin, xmax).reshape(self.input_shape)
        # Cast to the dtype captured before clamping, which may differ from
        # x.dtype after type promotion against the float32 factors.
        qx = self._quantize(x, dtype).reshape(self.input_shape)
        return (qx - x).detach() + x
