import torch
import torch.nn as nn
from SNN.spike_layer import snnEmbedding
from utils.quant_utils import set_op_by_name
from quantize.quant_norm import QuantRMSNorm
from SNN.spike_layer import snnRMSNorm, snnSdpaLlamaAttention, snnLlamaMLP, snnLinear2
from SNN.spike_layer_sim import snnSdpaLlamaAttention_new, snnLlamaMLP_old
from SNN.spike_silu import snnRMSNorm_new
from SNN.spike_layer_tdf import snnRMSNorm_tdf, snnSdpaLlamaAttention_tdf, snnLlamaMLP_tdf
from quantize import int_linear_fake
from transformers.models.llama import LlamaModel
import copy


def replicate_past_key_values(past_key_values, T):
    """Spread each cached key/value tensor evenly over ``T`` slots.

    Every tensor is divided by ``T`` and the quotient is repeated ``T`` times
    along a new axis inserted at position 1, so summing the result over that
    axis gives back the original tensor (up to floating-point rounding).

    Args:
        past_key_values: iterable of ``(key, value)`` tensor pairs. Each
            tensor must be 4-D; the original comments describe the layout
            as ``[B, H, L, D]``.
        T: number of copies to create; also the divisor.

    Returns:
        A tuple of ``(key, value)`` pairs, one per input pair, where each
        tensor has shape ``[B, T, H, L, D]``.
    """

    def _spread(tensor):
        # One scaled copy, referenced T times by the stack below.
        share = tensor.clone() / T
        stacked = torch.stack([share] * T, dim=0)   # [T, B, H, L, D]
        return stacked.permute(1, 0, 2, 3, 4)       # [B, T, H, L, D]

    return tuple(
        (_spread(key), _spread(value)) for key, value in past_key_values
    )
def _replace_with_snn(model, args, norm_cls, attention_cls, mlp_cls):
    """Swap sub-modules of ``model`` for their SNN wrappers, in place.

    Shared implementation behind the public ``wrap_to_snn_model*`` entry
    points, which differ only in the wrapper classes they pass in.

    Dispatch (first match wins):
      * qualified name contains ``'lm_head'``
            -> ``snnLinear2(module, T=T)``
      * ``nn.Embedding``
            -> ``snnEmbedding(module, T=T, avg=avg)``
      * ``QuantRMSNorm``
            -> ``norm_cls(module, T=T, avg=avg)``
      * ``int_linear_fake.quantSdpaLlamaAttention``
            -> ``attention_cls(module, module.config, T, avg=avg)``
      * ``int_linear_fake.quantLlamaMLP``
            -> ``mlp_cls(module, module.config, T, avg=avg)``
    Every other module is left untouched.

    Args:
        model: root ``nn.Module``; modified in place via ``set_op_by_name``.
        args: object exposing ``T`` and ``avg_neuron``. Both are forwarded
            to the wrapper constructors (``avg_neuron`` as ``avg``);
            ``T`` is presumably the number of SNN time steps -- confirm
            against the wrapper classes.
        norm_cls: wrapper class used for ``QuantRMSNorm`` modules.
        attention_cls: wrapper class used for quantized attention modules.
        mlp_cls: wrapper class used for quantized MLP modules.

    Returns:
        None. The conversion happens by side effect on ``model``.
    """
    T = args.T
    avg = args.avg_neuron

    # named_modules() is consumed lazily rather than materialised into a
    # list: a snapshot would hold a reference to every original module
    # until the loop finishes.
    # NOTE(review): the tree is therefore modified while the generator is
    # live; this relies on set_op_by_name only rebinding an existing child
    # entry -- confirm against utils.quant_utils.
    for name, module in model.named_modules():
        # NOTE(review): substring test -- it also matches any module nested
        # under lm_head, or any other name containing 'lm_head'.
        if 'lm_head' in name:
            replacement = snnLinear2(module, T=T)
        elif isinstance(module, nn.Embedding):
            replacement = snnEmbedding(module, T=T, avg=avg)
        elif isinstance(module, QuantRMSNorm):
            replacement = norm_cls(module, T=T, avg=avg)
        elif isinstance(module, int_linear_fake.quantSdpaLlamaAttention):
            replacement = attention_cls(module, module.config, T, avg=avg)
        elif isinstance(module, int_linear_fake.quantLlamaMLP):
            replacement = mlp_cls(module, module.config, T, avg=avg)
        else:
            continue
        set_op_by_name(model, name, replacement)


def wrap_to_snn_model_new(model, args):
    """Convert ``model`` to its SNN form in place.

    Uses ``snnRMSNorm``, ``snnSdpaLlamaAttention`` and ``snnLlamaMLP_old``.
    NOTE(review): this is the same class selection as
    ``wrap_to_snn_model_old``; confirm whether that is intended.

    Args:
        model: root ``nn.Module`` to convert.
        args: object exposing ``T`` and ``avg_neuron``.
    """
    _replace_with_snn(
        model,
        args,
        norm_cls=snnRMSNorm,
        attention_cls=snnSdpaLlamaAttention,
        mlp_cls=snnLlamaMLP_old,
    )


def wrap_to_snn_model(model, args):
    """Convert ``model`` to its SNN form in place.

    Uses ``snnRMSNorm_new``, ``snnSdpaLlamaAttention_new`` and
    ``snnLlamaMLP``.

    Args:
        model: root ``nn.Module`` to convert.
        args: object exposing ``T`` and ``avg_neuron``.
    """
    _replace_with_snn(
        model,
        args,
        norm_cls=snnRMSNorm_new,
        attention_cls=snnSdpaLlamaAttention_new,
        mlp_cls=snnLlamaMLP,
    )


def wrap_to_snn_model_old(model, args):
    """Convert ``model`` to its SNN form in place.

    Uses ``snnRMSNorm``, ``snnSdpaLlamaAttention`` and ``snnLlamaMLP_old``.

    Args:
        model: root ``nn.Module`` to convert.
        args: object exposing ``T`` and ``avg_neuron``.
    """
    _replace_with_snn(
        model,
        args,
        norm_cls=snnRMSNorm,
        attention_cls=snnSdpaLlamaAttention,
        mlp_cls=snnLlamaMLP_old,
    )


def wrap_to_snn_model_tdf(model, args):
    """Convert ``model`` to its SNN form in place.

    Uses the ``*_tdf`` wrappers: ``snnRMSNorm_tdf``,
    ``snnSdpaLlamaAttention_tdf`` and ``snnLlamaMLP_tdf``.

    Args:
        model: root ``nn.Module`` to convert.
        args: object exposing ``T`` and ``avg_neuron``.
    """
    _replace_with_snn(
        model,
        args,
        norm_cls=snnRMSNorm_tdf,
        attention_cls=snnSdpaLlamaAttention_tdf,
        mlp_cls=snnLlamaMLP_tdf,
    )
                       
# def wrap_to_snn_llama_model(model, args):
#     '''
#     replace nn.Linear and norm layer to correspond quantization counterparts
#     '''
#     T = args.T
#     # L = args.L
#     for name, module in model.named_modules():
        
        
#         if isinstance(module, LlamaModel):
#             config = copy.deepcopy(module.config)
#             snnllama = snnLlamaModel(module, T, config)
#             set_op_by_name(model, name, snnllama)
#             del module