import torch
import torch.nn as nn
import torch.nn.functional as F


from quantization.quantizer import Quantizer
from quantization.transforms.transforms import BaseTransform


def _broadcast_and_apply_transform(tran_dim, weight, transform, inv_t):
    W_ = weight
    init_shape = W_.shape
    temp = W_.reshape(-1, init_shape[-1] // tran_dim, tran_dim)
    temp = transform(temp, inv_t=inv_t, dim=-1)
    weight = temp.reshape(init_shape)
    return weight


def _fold_rms_gamma(weight, gamma):
    return weight @ torch.diag(gamma)


def _repeat_bias_term(ouput_size, bias_term, dim_size):
    repeat_factor = ouput_size // dim_size
    return bias_term.repeat(repeat_factor)


class QLinear(nn.Linear):
    """``nn.Linear`` with optional weight/activation quantizers and weight transforms.

    While ``_train_mode`` is True, each forward pass derives the effective
    weight and bias from the stored parameters by applying the supplied
    transforms and, if configured, the weight quantizer.
    ``fix_parametrization`` performs that derivation once, writes the result
    back into the parameters and switches ``_train_mode`` off; afterwards
    ``forward`` uses the stored parameters as they are.

    Note that ``_train_mode`` is this class's own flag and is independent of
    ``nn.Module.training``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        weight_quantizer: Quantizer = None,
        act_quantizer: Quantizer = None,
        norm_gamma: torch.Tensor = None,
        device: torch.device = None,
        dtype: torch.dtype = None
    ):
        """Build the layer.

        Args:
            in_features: Forwarded to ``nn.Linear``.
            out_features: Forwarded to ``nn.Linear``.
            bias: Forwarded to ``nn.Linear``.
            weight_quantizer: Optional quantizer applied to the (transformed) weight.
            act_quantizer: Optional quantizer applied to the input activations.
            norm_gamma: Optional scale folded into the weight by ``run_transforms``
                before a non-R2 input transform; non-tensors are converted with
                ``torch.as_tensor``.
            device: Forwarded to ``nn.Linear``.
            dtype: Forwarded to ``nn.Linear``.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.weight_quantizer = weight_quantizer
        self.act_quantizer = act_quantizer

        # Store norm_gamma as a buffer so it follows .to(...)/.cuda() and is included in state_dict.
        # Note: buffers are not trainable and won't receive gradients.
        if norm_gamma is not None and not isinstance(norm_gamma, torch.Tensor):
            norm_gamma = torch.as_tensor(norm_gamma)
        self.register_buffer("norm_gamma", norm_gamma, persistent=True)

        # True until fix_parametrization() bakes transforms/quantization into the parameters.
        self._train_mode = True

    def _bias_or_zeros(self):
        """Return ``self.bias``, or a zero vector matching the weight when the layer has no bias."""
        if self.bias is not None:
            return self.bias
        return torch.zeros(self.weight.shape[0], dtype=self.weight.dtype, device=self.weight.device)

    def run_transforms(self,
                       weight, bias,
                       in_transform: BaseTransform = None,
                       out_transform: BaseTransform = None,
                       reverse_r2_transform_dim: int = None):
        """Apply the input-side and output-side transforms to ``weight`` and ``bias``.

        Args:
            weight: Weight tensor to transform.
            bias: Bias tensor, or None. A None bias is replaced by zeros only
                when the input transform contributes a bias term; otherwise it
                stays None and the output-side bias handling is skipped.
            in_transform: Optional transform applied (with ``inv_t=True``) along
                the last axis of ``weight``.
            out_transform: Optional transform applied (with ``inv_t=False``)
                along the first axis of ``weight`` and to ``bias``, in chunks of
                ``out_transform.block_size``.
            reverse_r2_transform_dim: When not None, the input transform is
                applied chunk-wise with this chunk width (the R2 reverse
                transform case) and ``norm_gamma`` is not folded in.

        Returns:
            Tuple ``(weight, bias)`` after the transforms.
        """
        ####################################################
        # Apply transform on the INPUT features
        ####################################################
        if in_transform:
            if reverse_r2_transform_dim is not None:
                # Not None dim indicates R2 reverse transform is needed
                weight = _broadcast_and_apply_transform(reverse_r2_transform_dim, weight, in_transform, inv_t=True)
            else:
                if self.norm_gamma is not None:
                    weight = _fold_rms_gamma(weight, self.norm_gamma.to(device=weight.device, dtype=weight.dtype))
                weight = in_transform(weight, inv_t=True, dim=-1)

            # weight is already W @ gamma @ in_transform_inv, a.k.a weight --> W_tilde
            bias_term = getattr(in_transform, "bias", None)
            if bias_term is not None:
                if reverse_r2_transform_dim is not None:
                    bias_term = _repeat_bias_term(weight.shape[-1], bias_term, dim_size=reverse_r2_transform_dim)

                if bias is None:
                    # A bias-less layer still needs somewhere to absorb the transform's bias term.
                    bias = torch.zeros(weight.shape[0], dtype=weight.dtype, device=weight.device)

                bias_term = bias_term.to(device=bias.device, dtype=bias.dtype)
                bias = bias - (weight @ bias_term).to(dtype=bias.dtype)

        ####################################################
        # Apply transform on the OUTPUT features
        ####################################################
        if out_transform:
            # Transpose so the output-feature axis is last, transform, then transpose back.
            weight = _broadcast_and_apply_transform(tran_dim=out_transform.block_size, weight=weight.t(),
                                                    transform=out_transform, inv_t=False)
            weight = weight.t()

            if bias is not None:
                bias = _broadcast_and_apply_transform(tran_dim=out_transform.block_size, weight=bias,
                                                      transform=out_transform, inv_t=False)

                out_bias = getattr(out_transform, "bias", None)
                if out_bias is not None:
                    bias_term = _repeat_bias_term(bias.shape[-1], out_bias,
                                                  dim_size=out_transform.block_size).to(bias.dtype)
                    bias = bias + bias_term

        return weight, bias

    def quantize_weights(self, weight):
        """Quantize ``weight`` with ``self.weight_quantizer`` using parameters derived from ``weight`` itself."""
        w_scales, w_zeros = self.weight_quantizer.get_quantization_params(weight)
        return self.weight_quantizer(weight, w_scales, w_zeros)

    def forward(
        self,
        x: torch.Tensor,
        in_transform: BaseTransform = None,
        out_transform: BaseTransform = None,
        reverse_r2_transform_dim: int = None
    ) -> torch.Tensor:
        """Compute the linear map on ``x``.

        While ``_train_mode`` is True the weight and bias are transformed (and
        the weight quantized, if a weight quantizer is set) on every call. Once
        ``fix_parametrization`` has run, the transform arguments are ignored and
        the stored parameters are used directly. Activations are quantized in
        both modes when an activation quantizer is set.

        Args:
            x: Input activations.
            in_transform: See ``run_transforms``.
            out_transform: See ``run_transforms``.
            reverse_r2_transform_dim: See ``run_transforms``.

        Returns:
            ``F.linear(x, weight, bias)`` with the effective weight and bias.
        """
        weight = self.weight
        # Zeros stand in for a missing bias so transform bias terms have a tensor to fold into.
        bias = self._bias_or_zeros()

        if self._train_mode:
            weight, bias = self.run_transforms(weight, bias, in_transform, out_transform, reverse_r2_transform_dim)

            if self.weight_quantizer is not None:
                weight = self.quantize_weights(weight)

        if self.act_quantizer is not None:
            a_scales, a_zeros = self.act_quantizer.get_quantization_params(x)
            x = self.act_quantizer(x, a_scales, a_zeros)

        return F.linear(x, weight, bias)

    def fix_parametrization(
        self,
        in_transform: BaseTransform = None,
        out_transform: BaseTransform = None,
        reverse_r2_transform_dim: int = None
    ) -> None:
        """Bake the transforms and weight quantization into the stored parameters.

        The weight and bias are derived the same way ``forward`` derives them in
        train mode (a missing bias starts from zeros), written back into the
        parameters, and ``_train_mode`` is switched off. If the layer has no
        bias but the derived bias is non-zero, a bias parameter is registered so
        the frozen layer reproduces the train-mode output; otherwise the layer
        stays bias-less.

        Also sets ``_track_global_scale = False`` on whichever quantizers are set.

        Args:
            in_transform: See ``run_transforms``.
            out_transform: See ``run_transforms``.
            reverse_r2_transform_dim: See ``run_transforms``.
        """
        had_bias = self.bias is not None

        weight, bias = self.run_transforms(self.weight, self._bias_or_zeros(),
                                           in_transform, out_transform, reverse_r2_transform_dim)

        if self.weight_quantizer is not None:
            weight = self.quantize_weights(weight)
            self.weight_quantizer._track_global_scale = False

        if self.act_quantizer is not None:
            self.act_quantizer._track_global_scale = False

        self.weight.data = weight
        if had_bias:
            self.bias.data = bias
        elif torch.count_nonzero(bias).item() > 0:
            # The transforms produced a real bias for a bias-less layer: keep it,
            # otherwise the frozen layer would differ from the train-mode forward.
            self.bias = nn.Parameter(bias.detach(), requires_grad=self.weight.requires_grad)

        self._train_mode = False
