import torch
from torch import nn
from typing import Optional
from .int_linear import QuantLinear
from .int_matmul import QuantMatMul
import torch.nn.functional as F
from lavin.model import RMSNorm,apply_rotary_emb
from collections import OrderedDict
import math
from timm.models.layers import  DropPath
import copy

class OmniLlamaRMSNorm(nn.Module):
    """RMS normalisation that reuses the scale tensor of an existing norm module.

    Same computation as ``LlamaRMSNorm`` (equivalent to ``T5LayerNorm``).
    Inputs are divided by the root-mean-square of their last dimension, with
    no mean subtraction and no bias, and then multiplied by a per-channel
    scale.

    Args:
        ori_norm: Source normalisation module. Only its ``weight`` tensor is
            used. That tensor is registered here as a buffer, so it is not
            listed by ``parameters()`` / ``named_parameters()``.
        eps: Constant added to the mean of squares before the reciprocal sqrt.
    """

    def __init__(self, ori_norm, eps=1e-6):
        super().__init__()
        # Registered as a buffer rather than a Parameter: the scale is carried
        # over from ori_norm but is not returned by parameters().
        self.register_buffer('weight', ori_norm.weight)
        self.eps = eps
        self.variance_epsilon = eps  # alias of eps kept for API parity
        self.bias = None
        self.use_temporary_parameter = False

    def _norm(self, x):
        """Return ``x`` scaled by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x):
        """Normalise in float32, cast back to ``x``'s dtype, then apply the scale."""
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight




class QuantLlamaMLP(nn.Module):
    """SwiGLU feed-forward block whose projections are ``QuantLinear`` wrappers.

    Computes ``w2(silu(w1(x)) * w3(x))`` using quantized copies of the three
    linear layers of ``org_module``.

    Args:
        org_module: Float MLP exposing ``w1``, ``w2`` and ``w3`` layers.
        args: Config object; its ``weight_quant_params`` and
            ``act_quant_params`` are forwarded to every ``QuantLinear``.
            Despite the ``None`` default, it is required: attributes are read
            from it unconditionally.
    """

    def __init__(
        self,
        org_module: nn.Module,
        args=None,
    ):
        super().__init__()
        # Register w1, w2, w3 in that order (keeps submodule/state_dict order).
        for proj_name in ('w1', 'w2', 'w3'):
            wrapped = QuantLinear(getattr(org_module, proj_name),
                                  args.weight_quant_params,
                                  args.act_quant_params)
            setattr(self, proj_name, wrapped)

    def forward(self, x):
        """Apply the gated feed-forward: ``w2(silu(w1(x)) * w3(x))``."""
        gate = F.silu(self.w1(x), inplace=False)
        up = self.w3(x)
        return self.w2(gate * up)


class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need', with quantized projections.

    The four projections of ``org_module`` (``wq``, ``wk``, ``wv``, ``wo``) are
    wrapped in ``QuantLinear``. The score and value matmuls are plain
    ``torch.matmul`` calls and are not quantized here. No KV cache is used:
    keys and values come only from the current input.

    Args:
        org_module: Float attention module exposing ``n_local_heads``,
            ``head_dim`` and the ``wq``/``wk``/``wv``/``wo`` layers.
        args: Config providing ``weight_quant_params`` and
            ``act_quant_params``. It is required despite the ``None`` default.
    """

    def __init__(self,
                 org_module: nn.Module,
                 args=None):
        super().__init__()
        self.n_local_heads = org_module.n_local_heads
        self.head_dim = org_module.head_dim
        self.wq = QuantLinear(
            org_module.wq,
            args.weight_quant_params,
            args.act_quant_params,
        )
        self.wk = QuantLinear(
            org_module.wk,
            args.weight_quant_params,
            args.act_quant_params,
        )
        self.wv = QuantLinear(
            org_module.wv,
            args.weight_quant_params,
            args.act_quant_params,
        )
        self.wo = QuantLinear(
            org_module.wo,
            args.weight_quant_params,
            args.act_quant_params,
        )

        # Bookkeeping flags updated by set_quant_state; forward does not read them.
        self.use_weight_quant = False
        self.use_act_quant = False

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split heads: view ``tensor`` as ``(bsz, seq_len, n_local_heads, head_dim)``.

        The view is then transposed and made contiguous to
        ``(bsz, n_local_heads, seq_len, head_dim)``.
        """
        # The class only defines n_local_heads; the former self.num_heads
        # lookup raised AttributeError on every call.
        return (tensor.view(bsz, seq_len, self.n_local_heads, self.head_dim)
                .transpose(1, 2)
                .contiguous())

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
        """Self-attention over ``x`` followed by the ``wo`` output projection.

        Args:
            x: Hidden states, unpacked as ``(bsz, seqlen, _)``.
            start_pos: Accepted but not used by this implementation.
            freqs_cis: Rotary frequencies passed to ``apply_rotary_emb``.
            mask: Optional tensor added to the attention scores before softmax.
            adapter: Accepted but not used by this implementation.

        Returns:
            ``wo`` applied to the head outputs after they are merged back to
            ``(bsz, seqlen, n_local_heads * head_dim)``.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # No cache: keys/values are exactly the current step's projections.
        keys = xk
        values = xv

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bsz, n_local_heads, seqlen, seqlen)
        # Softmax in float32 for numerical stability, then back to the query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        output = torch.matmul(scores, values)  # (bsz, n_local_heads, seqlen, head_dim)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

    def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
        """Record the flags and push them to every QuantLinear / QuantMatMul inside.

        The flags stored on ``self`` do not affect ``forward``. The
        quantization actually applied is governed by the child modules' state
        set below.
        """
        self.use_weight_quant = weight_quant
        self.use_act_quant = act_quant
        for m in self.modules():
            if isinstance(m, (QuantLinear, QuantMatMul)):
                m.set_quant_state(weight_quant, act_quant)
                


class QuantTransformerBlock(nn.Module):
    """Quantized wrapper around one LLaMA-style transformer block.

    Built from a float block ``ori_layer``:

    - attention and feed-forward are replaced by ``QuantLlamaAttention`` and
      ``QuantLlamaMLP``;
    - the two RMS norms are copied into ``OmniLlamaRMSNorm``;
    - the stochastic-depth (``DropPath``) setting is kept.

    The remaining methods operate on the ``QuantLinear`` layers inside the
    block. They toggle quantization state, materialise quantized temporary
    weights, write quantized weights back, and collect the parameters whose
    names contain ``bound_factor`` or ``smooth``.

    Args:
        ori_layer: Float block providing ``n_heads``, ``dim``, ``attention``,
            ``feed_forward``, ``layer_id``, ``attention_norm``, ``ffn_norm``
            and ``drop_prob``.
        args: Config providing ``weight_quant_params``, ``act_quant_params``
            and ``max_batch_size``.
    """

    def __init__(self, ori_layer, args):
        super().__init__()
        self.n_heads = ori_layer.n_heads
        self.dim = ori_layer.dim
        self.head_dim = ori_layer.dim // ori_layer.n_heads
        self.attention = QuantLlamaAttention(
            org_module=ori_layer.attention,
            args=args,
        )
        self.feed_forward = QuantLlamaMLP(
            org_module=ori_layer.feed_forward,
            args=args,
        )
        self.layer_id = ori_layer.layer_id
        self.attention_norm = OmniLlamaRMSNorm(ori_layer.attention_norm, eps=ori_layer.attention_norm.eps)
        self.ffn_norm = OmniLlamaRMSNorm(ori_layer.ffn_norm, eps=ori_layer.ffn_norm.eps)
        self.drop_path = DropPath(ori_layer.drop_prob) if ori_layer.drop_prob > 0. else nn.Identity()
        # When True, smooth_and_quant_temporary does NOT reset temp_weight from
        # the stored weight before quantizing (presumably OmniQuant's "LET"
        # mode, where temp weights are prepared elsewhere -- confirm).
        self.let = False
        # Defined up front (as QuantLlamaAttention does) so the flags exist
        # before set_quant_state is first called.
        self.use_weight_quant = False
        self.use_act_quant = False

        # NOTE(review): these tensors are not read anywhere in this class.
        # They are plain attributes, not registered buffers, so .to() and
        # state_dict() ignore them. .cuda() requires a GPU at construction.
        self.cache_weights = torch.zeros(
            (args.max_batch_size, 2)
        ).cuda()
        self.cache_weights_ffn = torch.zeros(
            (args.max_batch_size, 2)
        ).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
        """Pre-norm residual block: ``h = x + attn(norm(x))``; ``out = h + ffn(norm(h))``.

        Args:
            x: Block input hidden states.
            start_pos, freqs_cis, mask, adapter: Forwarded unchanged to
                ``self.attention``.

        Returns:
            The block output, same shape as ``x``.
        """
        # ``.forward`` is called directly, which skips any nn.Module hooks
        # registered on the sub-modules.
        h = x + self.drop_path(self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter))
        return h + self.drop_path(self.feed_forward.forward(self.ffn_norm(h)))

    def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
        """Record the flags and push them to every QuantLinear / QuantMatMul in the block.

        The flags stored on ``self`` are bookkeeping only and ``forward`` does
        not read them. The child modules' own state decides what is quantized.
        """
        self.use_weight_quant = weight_quant
        self.use_act_quant = act_quant
        for m in self.modules():
            if isinstance(m, (QuantLinear, QuantMatMul)):
                m.set_quant_state(weight_quant, act_quant)

    def add_scaling(self):
        """Register a ``dynamic_scale`` parameter of ones on each weight quantizer.

        The parameter has shape ``(group_num, 1)``. It is added to the weight
        quantizer of every QuantLinear / QuantMatMul in the block.

        NOTE(review): the tensor is created on torch's default device (CPU
        unless changed), not on the module's device. A QuantMatMul may also
        lack ``weight_quantizer``. Confirm both against int_linear/int_matmul.
        """
        for m in self.modules():
            if isinstance(m, (QuantLinear, QuantMatMul)):
                m.weight_quantizer.register_parameter(
                    'dynamic_scale',
                    nn.Parameter(torch.ones((m.weight_quantizer.group_num, 1))))

    def _quant_linears(self):
        """Yield every QuantLinear in the block, in ``named_modules()`` order."""
        return (m for m in self.modules() if isinstance(m, QuantLinear))

    def smooth_and_quant_temporary(self):
        """Store a quantized copy of each QuantLinear weight in ``temp_weight``.

        Steps:

        1. Unless ``self.let`` is set, ``temp_weight`` is first reset to the
           stored ``weight``.
        2. ``temp_weight`` (or ``weight`` if no ``temp_weight`` exists) is
           passed through the layer's ``weight_quantizer``.
        3. ``temp_bias`` defaults to ``bias``.
        4. The layer is switched to ``use_temporary_parameter``.

        With ``let`` set, a ``temp_weight`` left over from a previous call is
        quantized again.
        """
        if not self.let:
            for module in self._quant_linears():
                module.temp_weight = module.weight
        # quant
        for module in self._quant_linears():
            if hasattr(module, "temp_weight"):
                module.temp_weight = module.weight_quantizer(module.temp_weight)
            else:
                module.temp_weight = module.weight_quantizer(module.weight)
            if not hasattr(module, "temp_bias"):
                module.temp_bias = module.bias
            module.use_temporary_parameter = True

    def clear_temp_variable(self):
        """Delete ``temp_weight`` / ``temp_bias`` from every QuantLinear.

        Attributes that are already absent are skipped. The call is therefore
        safe to repeat, or to issue before ``smooth_and_quant_temporary``.
        ``use_temporary_parameter`` is left unchanged.
        """
        for module in self._quant_linears():
            for attr in ("temp_weight", "temp_bias"):
                if hasattr(module, attr):
                    delattr(module, attr)

    @torch.no_grad()
    def smooth_and_quant_fake(self):
        """Enable weight quantization on every QuantLinear and drop its clipping factors.

        For each QuantLinear:

        - turn off ``use_temporary_parameter``;
        - enable weight quantization (activations off);
        - run the weight quantizer once on the stored weight;
        - restore the original values: the quantizer's output tensor becomes
          ``weight``, but its ``.data`` is set back to the original weight
          data;
        - put the quantizer in ``'training'`` mode;
        - delete its ``upbound_factor`` / ``lowbound_factor``.

        Finally ``register_scales_and_zeros`` is called.
        """
        for module in self._quant_linears():
            module.use_temporary_parameter = False
            module.set_quant_state(weight_quant=True, act_quant=False)
            temp = module.weight.data
            # The quantizer call is presumably made for its side effects on the
            # quantizer's internal state; the quantized values are discarded.
            module.weight = module.weight_quantizer(module.weight)
            module.weight.data = temp
            module.weight_quantizer.mode = 'training'
            del module.weight_quantizer.upbound_factor
            del module.weight_quantizer.lowbound_factor
        self.register_scales_and_zeros()

    @torch.no_grad()
    def smooth_and_quant_inplace(self):
        """Bake the quantizer output into each QuantLinear weight and disable quantization.

        Each ``weight`` is replaced by the output of its ``weight_quantizer``.
        Temporary parameters and weight/activation quantization are switched
        off, and the quantizer's ``upbound_factor`` / ``lowbound_factor`` are
        deleted.
        """
        for module in self._quant_linears():
            module.weight = module.weight_quantizer(module.weight)
            module.use_temporary_parameter = False
            module.set_quant_state(weight_quant=False, act_quant=False)
            del module.weight_quantizer.upbound_factor
            del module.weight_quantizer.lowbound_factor

    @torch.no_grad()
    def smooth_and_quant(self):
        """Calibrate each weight quantizer on its weight and drop its clipping factors.

        For each QuantLinear:

        - enable weight quantization (activations off);
        - put the quantizer in ``'training'`` mode;
        - run ``per_token_dynamic_calibration`` on the weight data;
        - delete the quantizer's ``upbound_factor`` / ``lowbound_factor``.
        """
        for module in self._quant_linears():
            module.use_temporary_parameter = False
            module.set_quant_state(weight_quant=True, act_quant=False)
            module.weight_quantizer.mode = 'training'
            module.weight_quantizer.per_token_dynamic_calibration(module.weight.data)
            del module.weight_quantizer.upbound_factor
            del module.weight_quantizer.lowbound_factor
        # NOTE(review): register_scales_and_zeros_params is not defined in this
        # class as shown (smooth_and_quant_fake calls register_scales_and_zeros
        # instead). Confirm it exists; otherwise this raises AttributeError.
        self.register_scales_and_zeros_params()

    @torch.no_grad()
    def smooth_and_quant_inplace_test(self):
        """Replace each QuantLinear weight with ``weight_quantizer.quant_int(weight)``.

        The output is presumably integer codes (confirm against the quantizer).
        Temporary parameters are turned off; quant state is left unchanged.
        """
        for module in self._quant_linears():
            module.weight = module.weight_quantizer.quant_int(module.weight)
            module.use_temporary_parameter = False

    def lwc_parameters(self):
        """Return an iterator over parameters whose name contains ``bound_factor``."""
        params = [p for n, p in self.named_parameters() if 'bound_factor' in n]
        return iter(params)

    def omni_parameters(self, use_shift=True):
        """Return an iterator over clipping and smoothing parameters.

        Args:
            use_shift: If True, match any name containing ``smooth`` (scales
                and shifts). If False, match only ``smooth_scale``.
                ``bound_factor`` parameters are always included.
        """
        template = "smooth" if use_shift else "smooth_scale"
        params = [p for n, p in self.named_parameters()
                  if 'bound_factor' in n or template in n]
        return iter(params)

    def omni_state_dict(self, destination=None, prefix='', keep_vars=False):
        """Collect parameters whose name contains ``smooth`` or ``bound_factor``.

        Args:
            destination: Mapping to fill; a new ``OrderedDict`` if None.
            prefix: String prepended to every key.
            keep_vars: If False, values are detached from autograd.

        Returns:
            ``destination``.
        """
        if destination is None:
            destination = OrderedDict()
        for name, param in self.named_parameters():
            if 'smooth' in name or 'bound_factor' in name:
                destination[prefix + name] = param if keep_vars else param.detach()
        return destination

    def lwc_state_dict(self, destination=None, prefix='', keep_vars=False):
        """Like ``omni_state_dict`` but only for ``bound_factor`` parameters."""
        if destination is None:
            destination = OrderedDict()
        for name, param in self.named_parameters():
            if 'bound_factor' in name:
                destination[prefix + name] = param if keep_vars else param.detach()
        return destination

    def register_scales_and_zeros(self):
        """Call ``register_scales_and_zeros()`` on every QuantLinear's weight quantizer."""
        for module in self._quant_linears():
            module.weight_quantizer.register_scales_and_zeros()
