from transformers import  LlamaTokenizer
from smoothquant.llama_multi import Int8LlamaForCausalLM , Int8LlamaDecoderLayer
from datasets import load_dataset
import torch
from transformers import AdamW, get_linear_schedule_with_warmup
import torch.nn as nn
import tqdm
def save_parameters(model, filepath):
    """Write every decoder layer's per-projection ``cscales`` to *filepath*.

    The saved object is a dict keyed ``'layer_<i>'`` (one entry per element of
    ``model.model.layers``).  Each value maps a projection name (``k_proj``,
    ``v_proj``, ``q_proj``, ``o_proj``, ``gate_proj``, ``up_proj``,
    ``down_proj``) to that projection's ``cscales`` tensor (``.data``, i.e.
    without autograd history).
    """
    attention_projections = ('k_proj', 'v_proj', 'q_proj', 'o_proj')
    mlp_projections = ('gate_proj', 'up_proj', 'down_proj')

    checkpoint = {}
    for idx, decoder_layer in enumerate(model.model.layers):
        # Attention scales first, then MLP scales -- same key order as the loader expects.
        entry = {
            proj_name: getattr(decoder_layer.self_attn, proj_name).cscales.data
            for proj_name in attention_projections
        }
        entry.update(
            (proj_name, getattr(decoder_layer.mlp, proj_name).cscales.data)
            for proj_name in mlp_projections
        )
        checkpoint[f'layer_{idx}'] = entry
    torch.save(checkpoint, filepath)
def load_parameters(model, filepath, map_location="cpu"):
    """Load per-layer ``cscales`` (as written by ``save_parameters``) into *model*.

    For every layer ``i`` of ``model.model.layers`` that has a ``'layer_<i>'``
    entry in the checkpoint, each attention projection (``k/v/q/o_proj``) and
    MLP projection (``gate/up/down_proj``) gets its ``cscales`` replaced by a
    fresh, trainable ``nn.Parameter`` holding a copy of the saved tensor.
    Layers missing from the checkpoint are left untouched and a warning is
    printed.

    Args:
        model: object exposing ``model.model.layers``; each layer must have
            ``self_attn.{k,v,q,o}_proj`` and ``mlp.{gate,up,down}_proj``.
        filepath: path to the checkpoint *file* (no trailing "/").
        map_location: forwarded to ``torch.load``.  Defaults to ``"cpu"`` so a
            checkpoint saved on any GPU layout can be read anywhere; each tensor
            is then moved onto the device of the ``cscales`` it replaces.

    Raises:
        KeyError: if a present ``'layer_<i>'`` entry lacks a projection name.
    """
    scales_dict = torch.load(filepath, map_location=map_location)
    projection_groups = (
        ('self_attn', ('k_proj', 'v_proj', 'q_proj', 'o_proj')),
        ('mlp', ('gate_proj', 'up_proj', 'down_proj')),
    )
    for i, layer in enumerate(model.model.layers):
        key = f'layer_{i}'
        if key not in scales_dict:
            print(f"Warning: No parameters found for layer_{i}")
            continue
        layer_scales = scales_dict[key]
        for parent_name, proj_names in projection_groups:
            parent = getattr(layer, parent_name)
            for proj_name in proj_names:
                proj = getattr(parent, proj_name)
                # clone() so the new Parameter never aliases the checkpoint dict.
                new_value = layer_scales[proj_name].clone()
                current = getattr(proj, 'cscales', None)
                if isinstance(current, torch.Tensor):
                    # Place the loaded scale next to the module that uses it,
                    # not on whatever device it was saved from.
                    new_value = new_value.to(current.device)
                proj.cscales = nn.Parameter(new_value, requires_grad=True)
class Evaluator:
    """Perplexity evaluator over a fixed text corpus, tokenized once up front.

    The corpus' ``"text"`` entries are joined with blank lines, tokenized as a
    single string, and the resulting 2-D tensor of token ids is stored on
    *device* as ``self.dataset``.
    """

    def __init__(self, dataset, tokenizer, device, n_samples=5):
        self.tokenizer = tokenizer
        self.device = device
        # Only the token ids are kept; the raw dataset is not needed afterwards.
        self.dataset = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt").input_ids.to(device)
        self.n_samples = n_samples

    @torch.no_grad()
    def evaluate(self, model, seq_length=2048):
        """Return the perplexity of *model* on the first ``n_samples`` windows.

        Windows are consecutive, non-overlapping slices of ``seq_length`` tokens.
        The result is ``exp`` of the mean per-window next-token cross-entropy.
        Windows that fall (partly) past the end of the corpus are shortened;
        once fewer than two tokens remain, evaluation stops early.

        Args:
            model: callable returning an object with ``.logits``; must expose
                ``.device`` and ``.eval()``.  It is left in eval mode.
            seq_length: tokens per window.

        Returns:
            A 0-dim tensor holding the perplexity.

        Raises:
            ValueError: if not a single window contains at least two tokens.
        """
        model.eval()
        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in tqdm.tqdm(range(self.n_samples), desc="Evaluating..."):
            window = self.dataset[:, (i * seq_length) : ((i + 1) * seq_length)]
            # A window needs >= 2 tokens for one (input, next-token) pair; a
            # shorter one lies past the end of the corpus and would give a NaN
            # loss. Windows are ascending, so all later ones are empty too.
            if window.shape[1] < 2:
                break
            lm_logits = model(window.to(model.device)).logits
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            # Labels must be on the logits' device (the LM head may sit on a
            # different GPU than self.dataset).
            shift_labels = window[:, 1:].to(shift_logits.device)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
            )
            nlls.append(loss.float())
        if not nlls:
            raise ValueError(
                f"no evaluation window of seq_length={seq_length} contains at least two tokens"
            )
        return torch.exp(torch.stack(nlls).mean())
# ---------------------------------------------------------------------------
# Training script: fine-tunes the decoder layers of an Int8 Llama-2 model on
# WikiText-103 (train split), periodically measuring WikiText-2 (test split)
# perplexity and saving the per-layer `cscales` whenever it improves.
# All "/path_to_..." strings are placeholders to be edited before running.
# ---------------------------------------------------------------------------
tokenizer = LlamaTokenizer.from_pretrained("/path_to_fp16_llama2_model/")
output_path = '/path_to_int8_llama2_model/'
model_int8 = Int8LlamaForCausalLM.from_pretrained(output_path, torch_dtype=torch.bfloat16, device_map=None).cuda()
# Checkpoint paths name a *file*: a trailing "/" makes torch.load / torch.save fail.
load_parameters(model_int8, '/path_to_your_initial_scales/scales.pth')

# Split every decoder layer across two GPUs: attention (and its input norm) on
# cuda:0, MLP (and its post-attention norm) on cuda:1.
for _, module in model_int8.model.named_modules():
    if isinstance(module, Int8LlamaDecoderLayer):
        module.self_attn.to('cuda:0')
        module.input_layernorm.to('cuda:0')
        module.mlp.to('cuda:1')
        module.post_attention_layernorm.to('cuda:1')

# Freeze everything, then make only the decoder-layer parameters trainable.
for param in model_int8.parameters():
    param.requires_grad = False
for param in model_int8.model.layers.parameters():
    param.requires_grad = True
# Collect the trainable parameters once. Filtering on `p.grad is not None` here,
# before any backward pass, would always yield an empty list and silently
# disable clipping; clip_grad_norm_ already skips parameters whose .grad is None.
params_to_clip = [p for p in model_int8.parameters() if p.requires_grad]

train_data = load_dataset('/path_to_wikitext/', 'wikitext-103-v1', split='train')
self_dataset = tokenizer("\n\n".join(train_data["text"]), return_tensors="pt").input_ids.to('cuda')
total_tokens = self_dataset.shape[1]
seq_length = 2048
# Number of *full* seq_length windows; a trailing partial window is never trained on.
total_batches = total_tokens // seq_length
if total_batches == 0:
    raise ValueError(
        f"training corpus has {total_tokens} tokens, fewer than seq_length={seq_length}"
    )
optimizer = torch.optim.AdamW(model_int8.model.layers.parameters(), lr=2e-4)
num_epochs = 5
# Exactly one optimizer/scheduler step per full window per epoch, so the linear
# decay reaches zero at the end of training.
total_steps = total_batches * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
loss_fct = nn.CrossEntropyLoss()
model_int8.train()
dataset = load_dataset('/path_to_wikitext/', 'wikitext-2-raw-v1', split='test')
evaluator = Evaluator(dataset, tokenizer, 'cuda')
eval_interval = 10  # training steps between loss logging + perplexity evaluation
scales_path = '/path_to_your_scales/scales.pth'
# Best perplexity seen so far; +inf so the first evaluation always writes a checkpoint.
ppl_min = float("inf")
for epoch in range(num_epochs):
    # Visit every full window exactly once per epoch, in random order.
    indices = torch.randperm(total_batches).tolist()
    for step, i in enumerate(tqdm.tqdm(indices, desc=f"Epoch {epoch}")):
        batch = self_dataset[:, i * seq_length:(i + 1) * seq_length]
        optimizer.zero_grad()
        lm_logits = model_int8(batch).logits
        # Next-token prediction: logits at position t are scored against token t+1.
        # No epsilon is added to the logits: a constant shift of every logit
        # leaves softmax cross-entropy unchanged.
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        # Labels must be on the logits' device (the LM head may not be on cuda:0).
        shift_labels = batch[:, 1:].to(shift_logits.device)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=1.0)
        optimizer.step()
        scheduler.step()
        if step % eval_interval == 0:
            print(f"Epoch: {epoch}, Step: {step}, Loss: {loss.item()}")
            model_int8.eval()
            ppl = evaluator.evaluate(model_int8).item()
            print(f"Step: {step}, Perplexity: {ppl}")
            if ppl < ppl_min:
                save_parameters(model_int8, scales_path)
                ppl_min = ppl
            print(f"now ppl_min is {ppl_min}")
            model_int8.train()
    torch.cuda.empty_cache()
    model_int8.eval()
    ppl = evaluator.evaluate(model_int8).item()
    print(f"Epoch: {epoch}, Perplexity: {ppl}")
    model_int8.train()

