from os import device_encoding
import torch
import torch.nn as nn
from .observer import ObserverBase
from .fake_quant import QuantizeBase
from .util_quant import floor_ste
from .util_quant import (
    fake_quantize_per_channel_affine,
    fake_quantize_per_tensor_affine,
    fake_quantize_learnable_per_tensor_affine_training,
    fake_quantize_learnable_per_channel_affine_training,
    fake_quantize_learnableplus_per_channel_affine_training,
    fake_quantize_learnableplus_per_tensor_affine_training,
    fake_quantize_msafinetune_per_channel_affine_training,
    fake_quantize_msafinetune_per_tensor_affine_training,
)

class MSAFinetuneFakeQuantize(QuantizeBase):
    """Affine fake quantizer that can be switched into an "MSA finetune" mode.

    Default mode: while the observer is enabled, ``scale``/``zero_point`` are
    refreshed from the observer's ``min_val``/``max_val``. When fake
    quantization is enabled, the input goes through
    ``fake_quantize_per_channel_affine`` (``ch_axis != -1``) or
    ``fake_quantize_per_tensor_affine`` (``ch_axis == -1``).

    Finetune mode: after :meth:`init` has been called, :meth:`forward` routes
    through :meth:`finetune_forward`. That method calls the
    ``fake_quantize_msafinetune_*_affine_training`` helpers with ``beta`` and
    a gradient factor.
    """

    def __init__(self, observer, bit=8, symmetric=False, ch_axis=-1, use_grad_scaling=False):
        super().__init__(observer, bit=bit, symmetric=symmetric, ch_axis=ch_axis)
        # Quantization parameters are registered as buffers (not nn.Parameters).
        # They start as single-element placeholders; forward() resizes them
        # whenever the observer returns qparams of a different shape.
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
        # float32 machine epsilon, used as the lower clamp for ``scale``.
        float32_eps = torch.finfo(torch.float32).eps
        self.register_buffer('eps', torch.tensor([float32_eps]))
        # When True, finetune_forward passes grad_factor = 1/sqrt(N * quant_max)
        # to the finetune helper. N is the elements per channel (per-channel)
        # or all elements (per-tensor). When False, it passes 1.0.
        self.use_grad_scaling = use_grad_scaling
        # Finetune mode stays off until init() is called.
        self.finetune = False
        self.beta = 20.0

    def init(self, beta=20.0):
        """Enable finetune mode.

        Args:
            beta: stored on the module and forwarded to the
                ``fake_quantize_msafinetune_*`` helpers by
                :meth:`finetune_forward`. NOTE(review): its exact meaning
                (e.g. a rounding-relaxation temperature) is defined in
                util_quant; confirm there.
        """
        self.beta = beta
        self.finetune = True

    def finetune_forward(self, X, hard_value=False):
        """Fake-quantize ``X`` using the MSA-finetune helpers.

        Args:
            X: input tensor.
            hard_value: forwarded unchanged to the finetune helper. Presumably
                it selects a hard (non-relaxed) output; TODO confirm in
                util_quant.

        Returns:
            The per-channel helper's output when ``ch_axis != -1``, otherwise
            the per-tensor helper's output.
        """
        per_channel = self.ch_axis != -1

        if self.use_grad_scaling:
            # Count the elements that share one scale: a single channel slice
            # (per-channel) or the whole tensor (per-tensor).
            elems_per_scale = X.numel() / X.shape[self.ch_axis] if per_channel else X.numel()
            grad_factor = 1.0 / (elems_per_scale * self.quant_max) ** 0.5
        else:
            grad_factor = 1.0

        if per_channel:
            return fake_quantize_msafinetune_per_channel_affine_training(
                X, self.scale, self.zero_point.data.int(), self.ch_axis,
                self.quant_min, self.quant_max, grad_factor, self.beta,
                hard_value)
        return fake_quantize_msafinetune_per_tensor_affine_training(
            X, self.scale, self.zero_point.item(),
            self.quant_min, self.quant_max, grad_factor, self.beta,
            hard_value)

    def get_hard_value(self, X):
        """Return :meth:`finetune_forward` applied to ``X`` with ``hard_value=True``."""
        return self.finetune_forward(X, hard_value=True)

    def forward(self, X):
        """Optionally refresh qparams, then optionally fake-quantize ``X``.

        With the observer enabled:
            The observer is fed a detached copy of ``X``. ``scale`` and
            ``zero_point`` are then overwritten in place with its qparams;
            the buffers are resized first if the shapes differ.

        With the observer disabled:
            ``scale`` is made non-negative and clamped to at least ``eps``,
            in place.

        Returns ``X`` unchanged when fake quantization is disabled.
        """
        if self.observer_enabled == 1:
            obs = self.observer
            obs(X.detach())
            new_scale, new_zero_point = obs.calculate_qparams(obs.min_val, obs.max_val)
            new_scale = new_scale.to(self.scale.device)
            new_zero_point = new_zero_point.to(self.zero_point.device)
            # copy_ needs matching shapes, so grow/shrink the buffers first.
            if new_scale.shape != self.scale.shape:
                self.scale.resize_(new_scale.shape)
                self.zero_point.resize_(new_zero_point.shape)
            self.scale.copy_(new_scale)
            self.zero_point.copy_(new_zero_point)
        else:
            # Guard against a scale that became negative or ~0 after being
            # modified elsewhere. NOTE(review): scale is a buffer here; confirm
            # what updates it outside this class.
            scale_data = self.scale.data
            scale_data.abs_()
            scale_data.clamp_(min=self.eps.item())

        if self.fake_quant_enabled != 1:
            return X

        if self.finetune:
            return self.finetune_forward(X)

        if self.ch_axis != -1:
            return fake_quantize_per_channel_affine(
                X, self.scale.data, self.zero_point.data.int(), self.ch_axis,
                self.quant_min, self.quant_max)
        return fake_quantize_per_tensor_affine(
            X, self.scale.item(), self.zero_point.item(),
            self.quant_min, self.quant_max)