import rotation_utils
import hadamard_utils
import quant_utils
from transformers import (
    LlamaForCausalLM,
    Qwen2ForCausalLM,
)

class SrLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM whose norm / quantized layers are tagged with rotation settings.

    ``__init__`` registers a buffer ``Q`` built by
    ``rotation_utils.get_orthogonal_matrix(hidden_size, mode='hadamard')``.
    ``act_wrapper`` then wraps the model through ``quant_utils`` and assigns
    rotation-related attributes to each wrapped layer, selected by substring
    match on the layer name.
    """

    def __init__(self, config):
        super().__init__(config)

        self.register_buffer("Q", rotation_utils.get_orthogonal_matrix(self.config.hidden_size, mode='hadamard'))
        self.seqlen = 2048
        self.train_mode = False
        self.relu_init_mode = False
        self.model_type = 'llama'

    def act_wrapper(self):
        """Wrap norm and quantizable layers, then configure each one by name.

        Every layer returned by ``quant_utils.find_norm`` and
        ``quant_utils.find_qlayers`` receives its ``name`` and the model's
        ``train_mode``, and is then offered to the four ``_act_wrapper_to_*``
        configurators in turn.
        """
        quant_utils.add_norm_wrapper(self)
        quant_utils.add_actquant(self)

        # On a key collision the entry from find_qlayers wins (dict-merge order).
        wrapped = {**quant_utils.find_norm(self), **quant_utils.find_qlayers(self)}

        # NOTE(review): SrQwen2ForCausalLM also copies ``model_type`` onto each
        # layer; this class does not -- confirm whether that is intentional.
        for name in wrapped:
            wrapped[name].name = name
            wrapped[name].train_mode = self.train_mode

            # A sparse variant can be requested per call, e.g.
            # is_qlayer_trainable=True, sparse_mode=True, q=0.6.
            self._act_wrapper_to_qkv_proj(wrapped, name, is_qlayer_trainable=False)
            self._act_wrapper_to_upgate_proj(wrapped, name, is_qlayer_trainable=False)
            self._act_wrapper_to_o_proj(wrapped, name, is_qlayer_trainable=False)
            self._act_wrapper_to_down_proj(wrapped, name, is_qlayer_trainable=False)

    def _sr_set_q_rotation(self, layer):
        """Tag ``layer`` with rotate mode 'Q' and hand it the shared ``Q`` buffer."""
        layer.rotate_mode = 'Q'
        layer.Q = self.Q

    def _sr_set_sparse_func(self, layer, q):
        """Enable sparse mode on ``layer`` with a ``QuantileShiftedReLU(q)``."""
        layer.sparse_mode = True
        layer.sparse_func = quant_utils.QuantileShiftedReLU(q=q, init_mode=self.relu_init_mode)

    def _sr_set_head_hadamard(self, layer):
        """Store the ``get_hadK(num_attention_heads)`` pair and head dim on ``layer``."""
        had_K, K = hadamard_utils.get_hadK(self.config.num_attention_heads)
        layer.had_K = had_K
        # Plain K, not the 1-tuple ``(K,)`` that a trailing comma would create.
        layer.K = K
        layer.had_dim = self.config.hidden_size // self.config.num_attention_heads

    def _act_wrapper_to_qkv_proj(self, qlayers, name, is_qlayer_trainable=False, sparse_mode=False, q=None):
        """Configure the input layernorm and the q/k/v projections.

        Args:
            qlayers: mapping from layer name to wrapped layer.
            name: key into ``qlayers``; matched by substring.
            is_qlayer_trainable: in train mode, nothing is configured unless True.
            sparse_mode: if True, also attach a sparse function to the layernorm.
            q: quantile forwarded to ``QuantileShiftedReLU`` when ``sparse_mode``.

        In train mode only the layernorm is configured; otherwise q/k/v are
        also tagged and their ``rotate_weight()`` is called.
        """
        if self.train_mode and not is_qlayer_trainable:
            return

        layer = qlayers[name]
        if 'input_layernorm' in name:
            self._sr_set_q_rotation(layer)
            if sparse_mode:
                self._sr_set_sparse_func(layer, q)

        if self.train_mode:
            return

        for key in ('q_proj', 'k_proj', 'v_proj'):
            if key in name:
                self._sr_set_q_rotation(layer)
                layer.rotate_weight()

    def _act_wrapper_to_upgate_proj(self, qlayers, name, is_qlayer_trainable=False, sparse_mode=False, q=None):
        """Configure the post-attention layernorm and the up/gate projections.

        Same parameters and train-mode gating as ``_act_wrapper_to_qkv_proj``:
        in train mode only the layernorm is configured (and only when
        ``is_qlayer_trainable``); otherwise up/gate are also tagged and their
        ``rotate_weight()`` is called.
        """
        if self.train_mode and not is_qlayer_trainable:
            return

        layer = qlayers[name]
        if 'post_attention_layernorm' in name:
            self._sr_set_q_rotation(layer)
            if sparse_mode:
                self._sr_set_sparse_func(layer, q)

        if self.train_mode:
            return

        for key in ('up_proj', 'gate_proj'):
            if key in name:
                self._sr_set_q_rotation(layer)
                layer.rotate_weight()

    def _act_wrapper_to_o_proj(self, qlayers, name, is_qlayer_trainable=False, sparse_mode=False, q=None):
        """Configure ``o_proj`` ('Head_in') and the output side of ``v_proj`` ('Head_out').

        Args:
            qlayers: mapping from layer name to wrapped layer.
            name: key into ``qlayers``; matched by substring.
            is_qlayer_trainable: in train mode, nothing is configured unless True.
            sparse_mode: if True, also attach a sparse function to ``o_proj``.
            q: quantile forwarded to ``QuantileShiftedReLU`` when ``sparse_mode``.

        In train mode ``o_proj`` is flagged ``online_partial_had`` and its
        weight is left untouched; otherwise ``rotate_weight()`` is called.
        """
        if self.train_mode and not is_qlayer_trainable:
            return

        layer = qlayers[name]
        if 'o_proj' in name:
            layer.rotate_mode = 'Head_in'
            if self.train_mode:
                layer.online_partial_had = True
            self._sr_set_head_hadamard(layer)
            if not self.train_mode:
                layer.rotate_weight()
            if sparse_mode:
                self._sr_set_sparse_func(layer, q)

        if 'v_proj' in name:
            layer.out_rotate_mode = 'Head_out'
            self._sr_set_head_hadamard(layer)

    def _act_wrapper_to_down_proj(self, qlayers, name, is_qlayer_trainable=False, sparse_mode=False):
        """Configure ``down_proj`` with rotate mode 'H'.

        Stores the ``get_hadK(intermediate_size)`` pair on the layer. Outside
        train mode ``rotate_weight()`` is called as well; in train mode the
        layer is configured only when ``is_qlayer_trainable``. ``sparse_mode``
        is accepted for signature parity and is not used.
        """
        if self.train_mode and not is_qlayer_trainable:
            return

        if 'down_proj' in name:
            layer = qlayers[name]
            layer.rotate_mode = 'H'
            had_K, K = hadamard_utils.get_hadK(self.config.intermediate_size)
            layer.had_K = had_K
            layer.K = K
            if not self.train_mode:
                layer.rotate_weight()
     
class SrQwen2ForCausalLM(Qwen2ForCausalLM):
    """Qwen2 causal LM whose quantized layers are tagged with rotation settings.

    ``__init__`` registers a buffer ``Q`` built by
    ``rotation_utils.get_orthogonal_matrix(hidden_size, mode='hadamard')``.
    ``act_wrapper`` then wraps the model through ``quant_utils`` and assigns
    rotation-related attributes to each wrapped layer, selected by substring
    match on the layer name. Unlike the per-layer methods' signatures suggest,
    this class applies the same configuration regardless of ``train_mode``.
    """

    def __init__(self, config):
        super().__init__(config)

        self.register_buffer("Q", rotation_utils.get_orthogonal_matrix(self.config.hidden_size, mode='hadamard'))

        self.seqlen = 2048
        self.train_mode = False
        self.relu_init_mode = False
        self.model_type = 'qwen'

    def act_wrapper(self):
        """Wrap norm and quantizable layers, then configure each one by name.

        Every layer returned by ``quant_utils.find_norm`` and
        ``quant_utils.find_qlayers`` receives its ``name`` plus the model's
        ``train_mode`` and ``model_type``, and is then offered to the four
        ``_act_wrapper_to_*`` configurators in turn.
        """
        quant_utils.add_norm_wrapper(self)
        quant_utils.add_actquant(self)

        # On a key collision the entry from find_qlayers wins (dict-merge order).
        wrapped = {**quant_utils.find_norm(self), **quant_utils.find_qlayers(self)}

        for name in wrapped:
            wrapped[name].name = name
            wrapped[name].train_mode = self.train_mode
            wrapped[name].model_type = self.model_type

            self._act_wrapper_to_qkv_proj(wrapped, name, is_qlayer_trainable=False)
            self._act_wrapper_to_upgate_proj(wrapped, name, is_qlayer_trainable=False)
            self._act_wrapper_to_o_proj(wrapped, name, is_qlayer_trainable=False)
            self._act_wrapper_to_down_proj(wrapped, name, is_qlayer_trainable=False)

    def _sr_rotate_with_q(self, layer):
        """Tag ``layer`` with rotate mode 'Q', hand it ``Q``, and call ``rotate_weight()``."""
        layer.rotate_mode = 'Q'
        layer.Q = self.Q
        layer.rotate_weight()

    def _sr_set_head_hadamard(self, layer):
        """Store the ``get_hadK(num_attention_heads)`` pair and head dim on ``layer``."""
        had_K, K = hadamard_utils.get_hadK(self.config.num_attention_heads)
        layer.had_K = had_K
        # Plain K, not the 1-tuple ``(K,)`` that a trailing comma would create.
        layer.K = K
        layer.had_dim = self.config.hidden_size // self.config.num_attention_heads

    def _act_wrapper_to_qkv_proj(self, qlayers, name, is_qlayer_trainable=False, sparse_mode=False, q=None):
        """Apply the 'Q' rotation to the q/k/v projections.

        Args:
            qlayers: mapping from layer name to wrapped layer.
            name: key into ``qlayers``; matched by substring.
            is_qlayer_trainable, sparse_mode, q: accepted for signature
                parity; not used by this class.
        """
        for key in ('q_proj', 'k_proj', 'v_proj'):
            if key in name:
                self._sr_rotate_with_q(qlayers[name])

    def _act_wrapper_to_upgate_proj(self, qlayers, name, is_qlayer_trainable=False, sparse_mode=False, q=None):
        """Apply the 'Q' rotation to the up/gate projections.

        Args:
            qlayers: mapping from layer name to wrapped layer.
            name: key into ``qlayers``; matched by substring.
            is_qlayer_trainable, sparse_mode, q: accepted for signature
                parity; not used by this class.
        """
        for key in ('up_proj', 'gate_proj'):
            if key in name:
                self._sr_rotate_with_q(qlayers[name])

    def _act_wrapper_to_o_proj(self, qlayers, name, is_qlayer_trainable=False, sparse_mode=False, q=None):
        """Configure ``o_proj`` ('Head_in') and the output side of ``v_proj`` ('Head_out').

        ``o_proj`` additionally has ``rotate_weight()`` called; ``v_proj`` only
        receives the attributes. ``is_qlayer_trainable``, ``sparse_mode`` and
        ``q`` are accepted for signature parity and are not used.
        """
        if 'o_proj' in name:
            layer = qlayers[name]
            layer.rotate_mode = 'Head_in'
            self._sr_set_head_hadamard(layer)
            layer.rotate_weight()

        if 'v_proj' in name:
            layer = qlayers[name]
            layer.out_rotate_mode = 'Head_out'
            self._sr_set_head_hadamard(layer)

    def _act_wrapper_to_down_proj(self, qlayers, name, is_qlayer_trainable=False, sparse_mode=False):
        """Configure ``down_proj`` with rotate mode 'H' and call ``rotate_weight()``.

        Stores the ``get_hadK(intermediate_size)`` pair on the layer.
        ``is_qlayer_trainable`` and ``sparse_mode`` are accepted for signature
        parity and are not used.
        """
        if 'down_proj' in name:
            layer = qlayers[name]
            layer.rotate_mode = 'H'
            had_K, K = hadamard_utils.get_hadK(self.config.intermediate_size)
            layer.had_K = had_K
            layer.K = K
            layer.rotate_weight()