############## import libraries ################
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM
import os
import torch.nn.functional as F

from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch.nn as nn
from tqdm import tqdm

from datasets import load_dataset

# from transformers import default_data_collator
from torch.utils.data import DataLoader , Dataset
import argparse
import math
import bitsandbytes as bnb

# import ipdb
from typing import List
import copy
import pickle
from dataset_loader import DatasetManager
# let gpus just be 2,3,4,5,6,7 
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,5,6,7"
import torch.optim as optim
import matplotlib.gridspec as gridspec
# from llama_wiki_seperate_steps import visualize_output, calculate_ppl

from solver import lasso_regression



# Command-line configuration. Parsing happens at import time, so `args` is a
# module-level global that the functions and classes below read directly.
parser = argparse.ArgumentParser(description='LLM training')
# Output locations: `store_path` and `store_model_dir` are created right after
# model_dir is printed; `store_model_dir` is where train_student writes the best
# student checkpoint and where LoRALlama2 reads it back from.
parser.add_argument('--store_path', type=str, default='xxx/llama_researcher', help='path to store the low rank factors')
parser.add_argument('--store_model_dir', type=str, default='xxx/llama_researcher/tmp/', help='path to store the low rank factors')
parser.add_argument('--pic_path', type=str, default='./pics/', help='path to store the pictures')
parser.add_argument('--ratio', type=float, default=0.1, help='downsample ratio of the test set')
# Any non-zero value makes print_debug() emit its argument.
parser.add_argument('--is_debug', type=int, default=1, help='whether to debug and print')
# Appended to the base model directory ("xxx/llama/<model_name>").
parser.add_argument('--model_name', type=str, default='llama-2-7b-hf', help='model name')
# Selects one of the hard-coded GPU assignments (default/teacher/student) below.
parser.add_argument('--file_id', type=int, default=0, help='file id')

# Forwarded by LoRAMLP to LoRALinear as its `type` argument.
# NOTE(review): the help text looks stale — LoRALinear.forward in this file
# handles the values 0, 1, 2, 4, 5, 6, 7 and 8.
parser.add_argument('--mul_type', type=int, default=4, help='0: only up @ down, 1: (gate * up) @ down')

# Used both as the SVD truncation rank in SVD_research and as the LoRA rank.
parser.add_argument('--rank', type=int, default=200, help='rank of the low rank factors')
# The base model is only loaded when this is not 99/100 and is below 1000.
parser.add_argument('--run_flag', type=int, default=1, help='whether to run the code')

# train
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--num_epochs', type=int, default=10, help='num of epochs')
# train_student evaluates every `eval_freq` epochs and on the final epoch.
parser.add_argument('--eval_freq', type=int, default=10, help='evaluation frequency')
parser.add_argument('--batch_size', type=int, default=500, help='batch size')
# Token length of each chunk produced by encode().
parser.add_argument('--max_len', type=int, default=400, help='max length of the input')
# parser.add_argument('--dataset_name', type=str, default='wikitext', help='dataset name')
# get_dataset() in this file handles 'wikitext' and 'arxiv-math'.
parser.add_argument('--dataset_name', type=str, default='arxiv-math', help='dataset name')
# Number of rows of synthetic data drawn by generate_random_data().
parser.add_argument('--random_size', type=int, default=10000000, help='random size')

# Chooses the branch taken inside generate_random_data().
parser.add_argument('--dataset_type', type=int, default=0)

# Layer indices; they appear in the checkpoint file name "tea{T}_stu{S}_best.pth".
parser.add_argument('--teacher_layer', type=int, default=2)
parser.add_argument('--student_layer', type=int, default=1)

# Only the value -1 is inspected below: it forces every device string to 'cpu'.
parser.add_argument('--device', type=int, default=0, help='device')

parser.add_argument('--if_try', type=int, default=0, help='if try the int8')

args = parser.parse_args()

# Make sure the picture output directory exists before anything tries to save to it.
os.makedirs(args.pic_path, exist_ok=True)

def print_debug(string):
    """Print ``string`` only when the ``--is_debug`` command-line flag is truthy."""
    if not args.is_debug:
        return
    print(string)

# Hard-coded GPU assignment per launch slot: `file_id` picks which physical
# devices hold the base model (default), the teacher MLP and the student MLP.
if args.file_id == 0:
    default_cuda = 'cuda:2'
    teacher_cuda = 'cuda:0'
    student_cuda = 'cuda:7'
elif args.file_id == 1:
    default_cuda = 'cuda:0'
    teacher_cuda = 'cuda:2'
    student_cuda = 'cuda:1'
else:
    default_cuda = 'cuda:3'
    teacher_cuda = 'cuda:5'
    student_cuda = 'cuda:4'

# NOTE: this is printed before the CPU override below, so with --device -1 the
# message still shows the CUDA names chosen above.
print(f'default_cuda: {default_cuda}, teacher_cuda: {teacher_cuda}, student_cuda: {student_cuda}')

if args.device == -1:
    # all cuda be 'cpu'
    default_cuda = 'cpu'
    teacher_cuda = 'cpu'
    student_cuda = 'cpu'
    

# Short aliases for two of the command-line options.
store_path = args.store_path
downsample_ratio = args.ratio


############### general path and variables ################
# Directory the tokenizer and the base model are loaded from.
model_dir = f"xxx/llama/{args.model_name}"

# tokenizer_dir = "xxx/llama/"
# store_path = "xxx/store_wiki2"

print(f'store_path: {store_path}, downsample_ratio: {downsample_ratio}, model_dir: {model_dir}')



# make sure store_path exists
os.makedirs(store_path, exist_ok=True)
os.makedirs(args.store_model_dir, exist_ok=True)



############## set the seed for reproducibility ################
# The same fixed seed is applied to torch (CPU and all CUDA devices) and to
# NumPy; np.random.shuffle in get_dataset therefore depends on it.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


############## load the model and tokenizer ################

# The tokenizer is always loaded; it is used as a global by encode().
tokenizer = LlamaTokenizer.from_pretrained(model_dir)

# set padding token for the tokenizer
tokenizer.pad_token = tokenizer.eos_token
# gradient checkpointing

# The base model (and the weight_* globals derived from it) only exist when
# run_flag is neither 99 nor 100 and is below 1000. Code that refers to `model`
# or `weight_2_*` (e.g. evaluate_model, LoRALlama2) relies on this branch having run.
if args.run_flag not in [99, 100] and args.run_flag <1000:
    model = LlamaForCausalLM.from_pretrained(model_dir)
    model = model.to(default_cuda)

    # Freeze every parameter of the base model and switch it to eval mode.
    for param in model.parameters():
        param.requires_grad = False

    model.eval()



    # MLP projection weights of decoder layers 1, 2 and 3. They are taken via
    # `.weight.data`, i.e. without cloning.
    weight_1_down, weight_1_up, weight_1_gate = model.model.layers[1].mlp.down_proj.weight.data, model.model.layers[1].mlp.up_proj.weight.data, model.model.layers[1].mlp.gate_proj.weight.data
    weight_2_down, weight_2_up, weight_2_gate = model.model.layers[2].mlp.down_proj.weight.data, model.model.layers[2].mlp.up_proj.weight.data, model.model.layers[2].mlp.gate_proj.weight.data
    weight_3_down, weight_3_up, weight_3_gate = model.model.layers[3].mlp.down_proj.weight.data, model.model.layers[3].mlp.up_proj.weight.data, model.model.layers[3].mlp.gate_proj.weight.data


    print_debug(f'weight_1_down: {weight_1_down.size()}, weight_2_down: {weight_2_down.size()}')
    print_debug(f'weight_1_up: {weight_1_up.size()}, weight_2_up: {weight_2_up.size()}')
    print_debug(f'weight_1_gate: {weight_1_gate.size()}, weight_2_gate: {weight_2_gate.size()}')
    # Sizes recorded by the author from a previous run:
    # weight_1_down: torch.Size([4096, 11008])
    # weight_1_up: torch.Size([11008, 4096])
    # weight_1_gate: torch.Size([11008, 4096])


# LlamaForCausalLM(                                                                                                                                                                                                                                                      
#   (model): LlamaModel(                                                                                                                                                                                                                                                 
#     (embed_tokens): Embedding(32000, 4096)                                                                                                                                                                                                                             
#     (layers): ModuleList(                                                                                                                                                                                                                                              
#       (0-31): 32 x LlamaDecoderLayer(                                                                                                                                                                                                                                  
#         (self_attn): LlamaAttention(                                                                                                                                                                                                                                   
#           (q_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                                                                                                                                                            
#           (k_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                                                                                                                                                            
#           (v_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                                                                                                                                                            
#           (o_proj): Linear(in_features=4096, out_features=4096, bias=False)                                                                                                                                                                                            
#           (rotary_emb): LlamaRotaryEmbedding()                                                                                                                                                                                                                         
#         )                                                                                                                                                                                                                                                              
#         (mlp): LlamaMLP(                                                                                                                                                                                                                                               
#           (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)                                                    
#           (up_proj): Linear(in_features=4096, out_features=11008, bias=False)                                                      
#           (down_proj): Linear(in_features=11008, out_features=4096, bias=False)                                                    
#           (act_fn): SiLU()                                                                                                         
#         )                                                                                                                          
#         (input_layernorm): LlamaRMSNorm()                                                                                          
#         (post_attention_layernorm): LlamaRMSNorm()                                                                                 
#       )                                                                                                                            
#     )                                                                                                                              
#     (norm): LlamaRMSNorm()                                                                                                         
#   ) 
# )


def stable_rank(A, note = '', device = 'cpu'):
    """Return the stable rank of ``A``: sum of squared singular values over the largest one.

    Args:
        A: matrix tensor; it is converted to float32 and moved to ``device``
            before the singular values are computed.
        note: label printed in front of the diagnostic line.
        device: device on which the singular values are computed.

    Returns:
        The stable rank as a 0-d NumPy array.
    """
    # Only the singular values are needed, so skip computing U and V.
    singular_values = torch.linalg.svdvals(A.float().to(device))
    squared = singular_values ** 2
    sr = squared.sum() / squared.max()

    print(f'{note}: S[:5]= {singular_values[:5]}, S[-5:]= {singular_values[-5:]}, stable rank: {sr}')

    result = sr.cpu().numpy()

    # Drop the temporaries and release cached CUDA blocks.
    del singular_values, squared
    torch.cuda.empty_cache()

    return result

def test_device_map():
    """Load the model with ``device_map='auto'`` and print where each parameter was placed."""
    print('test device map')
    new_model = LlamaForCausalLM.from_pretrained(model_dir, device_map = 'auto')

    for param_name, tensor in new_model.named_parameters():
        print(f'name: {param_name}, device: {tensor.device}')

    # Author's observation from running this: the sublayers that belong to the
    # same decoder layer end up on the same device.

def SVD_research(weight_2_down, weight_2_up, weight_2_gate, weight_3_down, weight_3_up, weight_3_gate):
    """Print norm, stable-rank and SVD diagnostics for layer-2 / layer-3 MLP weights.

    For the down/up/gate projections this reports the ``torch.norm(..., p=2)``
    value and stable rank of each layer, the same statistics for the
    layer-2 minus layer-3 difference, and the reconstruction error of the
    difference under a reduced SVD and under a rank-``args.rank`` truncation.
    Nothing is returned; all results are printed.
    """
    parts = ('down', 'up', 'gate')

    def _norm(t):
        return torch.norm(t, p=2).item()

    # ---- layer 2: move onto the default device, then report ----
    layer2 = {
        'down': weight_2_down.to(default_cuda),
        'up': weight_2_up.to(default_cuda),
        'gate': weight_2_gate.to(default_cuda),
    }
    n_down, n_up, n_gate = [_norm(layer2[p]) for p in parts]
    print(f'L2 norm of weight_2_down: {n_down}, weight_2_up: {n_up}, weight_2_gate: {n_gate}')
    for p in parts:
        stable_rank(layer2[p], f'2_{p}')

    # ---- layer 3: same procedure ----
    layer3 = {
        'down': weight_3_down.to(default_cuda),
        'up': weight_3_up.to(default_cuda),
        'gate': weight_3_gate.to(default_cuda),
    }
    n_down, n_up, n_gate = [_norm(layer3[p]) for p in parts]
    print(f'L2 norm of weight_3_down: {n_down}, weight_3_up: {n_up}, weight_3_gate: {n_gate}')

    for p in parts:
        stable_rank(layer3[p], f'3_{p}')

    # ---- difference between the two layers ----
    diffs = {p: layer2[p] - layer3[p] for p in parts}
    sr_down, sr_up, sr_gate = [stable_rank(diffs[p], f'23_diff_{p}') for p in parts]

    n_down, n_up, n_gate = [_norm(diffs[p]) for p in parts]
    print(f'L2 norm of diff_down: {n_down}, diff_up: {n_up}, diff_gate: {n_gate}')

    print(f'stable rank of diff: down: {sr_down}, up: {sr_up}, gate: {sr_gate}')

    # ---- reduced SVD of every difference (all three computed before reporting) ----
    factors = {p: torch.linalg.svd(diffs[p], full_matrices=False) for p in parts}

    for p in parts:
        U, S, Vt = factors[p]
        print(f'SVD error for {p}: {torch.norm(diffs[p] - U @ torch.diag(S) @ Vt, p=2).item()}')

    # ---- keep only the leading args.rank components ----
    k = args.rank
    for p in parts:
        U, S, Vt = factors[p]
        U_k, S_k, Vt_k = U[:, :k], S[:k], Vt[:k, :]
        print(f'reduced (rank = {args.rank}) SVD error for {p}: {torch.norm(diffs[p] - U_k @ torch.diag(S_k) @ Vt_k, p=2).item()}')
    



########################## encode function and text dataset ##############################
class TextDataset(Dataset):
    """Map-style dataset that serves items from a pre-built sequence of encoded chunks."""

    def __init__(self, encoded_chunks):
        # Kept as given; items are handed back without copying or conversion.
        self.encoded_chunks = encoded_chunks

    def __len__(self):
        """Number of stored chunks."""
        chunks = self.encoded_chunks
        return len(chunks)

    def __getitem__(self, idx):
        """Return the chunk stored at position ``idx``."""
        chunks = self.encoded_chunks
        return chunks[idx]

def encode(texts):
    """Tokenize ``texts`` as one long string and cut it into chunks of ``args.max_len`` tokens.

    The texts are joined with a blank line, tokenized once with the global
    ``tokenizer``, and the resulting ids / attention mask are split into
    consecutive pieces.

    Returns:
        A list of ``{'input_ids': ..., 'attention_mask': ...}`` dicts. When
        more than one chunk is produced, the final chunk is dropped; a single
        chunk is returned as-is.
    """
    joined = "\n\n".join(texts)
    encodings = tokenizer(joined, return_tensors='pt')

    # Report the total number of tokens.
    print(f'input length: {encodings["input_ids"].shape[1]}')

    # Cut the single long sequence into pieces of at most max_len tokens.
    chunk_len = args.max_len
    input_ids_chunks = encodings['input_ids'][0].split(chunk_len)
    attention_mask_chunks = encodings['attention_mask'][0].split(chunk_len)

    print(f'input_ids_chunks: {len(input_ids_chunks)}, attention_mask_chunks: {len(attention_mask_chunks)}')

    samples = []
    for ids, mask in zip(input_ids_chunks, attention_mask_chunks):
        samples.append({'input_ids': ids, 'attention_mask': mask})

    # Drop the trailing chunk unless it is the only one.
    return samples[:-1] if len(samples) > 1 else samples

############## load the dataset ################
def get_dataset(args, tokenizer):
    """Load the train/validation texts and the tokenized test set for ``args.dataset_name``.

    Supported names are ``'wikitext'`` (wikitext-2-raw-v1) and ``'arxiv-math'``
    (ArtifactAI/arxiv-math-instruct-50k). In both cases the validation texts
    are the test texts.

    Args:
        args: namespace providing ``dataset_name``.
        tokenizer: callable used to tokenize the joined test texts with
            ``return_tensors='pt'``.

    Returns:
        ``(train_texts, val_texts, test_encodings)``.

    Raises:
        ValueError: if ``args.dataset_name`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError at the return).
    """
    if args.dataset_name == 'wikitext':
        train_texts = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train")['text']

        # The test split doubles as the validation set, so load it only once.
        test_texts = load_dataset('wikitext', 'wikitext-2-raw-v1', split="test")['text']
        val_texts = test_texts

        test_encodings = tokenizer("\n\n".join(test_texts), return_tensors='pt')

    elif args.dataset_name == 'arxiv-math':
        template = "[Question]: {}, [Answer]: {}."
        all_data = load_dataset('ArtifactAI/arxiv-math-instruct-50k', split="train")
        dataset_list = [template.format(item['question'], item['answer']) for item in all_data]
        print(f'dataset_list[0]: {dataset_list[0]}')

        # History: the "3.9" results used a shuffled 0.8 / 0.1 / 0.1
        # train/val/test split with the test set downsampled to its last 5%
        # (author's notes: llama2-7b-chat: 0.01 -> 4.11, 0.05 -> 3.94;
        # llama2-7b: 0.05 -> 3.0; blog report 1.0 -> 3.1).

        ################ same split as hicham ################
        # Shuffle, then split 0.99 / 0.01 into train and test.
        # Validation reuses the test set.
        train_ratio = 0.99
        train_num = int(len(dataset_list) * train_ratio)

        # In-place shuffle driven by NumPy's global RNG state.
        np.random.shuffle(dataset_list)
        train_texts = dataset_list[:train_num]
        test_texts = dataset_list[train_num:]

        # The test set is not downsampled here.
        val_texts = test_texts

        ########################################################

        print(f'train: {len(train_texts)}, val: {len(val_texts)}, test: {len(test_texts)}')

        test_encodings = tokenizer("\n".join(test_texts), return_tensors='pt')

    else:
        raise ValueError(
            f"unsupported dataset_name: {args.dataset_name!r} "
            "(expected 'wikitext' or 'arxiv-math')"
        )

    return train_texts, val_texts, test_encodings
        
@torch.no_grad()
def calculate_ppl(model, encodings, stride=512, device=default_cuda):
    """Sliding-window perplexity of ``model`` on ``encodings`` (Hugging Face recipe).

    Windows of up to ``model.config.max_position_embeddings`` tokens are
    advanced by ``stride``; in each window only the tokens not scored by the
    previous window keep their labels, the rest are masked with -100.

    Args:
        model: causal LM exposing ``config.max_position_embeddings`` and
            returning an object with ``.loss`` when called with ``labels``.
        encodings: tokenizer output with an ``input_ids`` tensor indexed as
            ``[:, begin:end]``.
        stride: step between consecutive windows.
        device: device the window's input ids are moved to.

    Returns:
        ``exp`` of the mean per-window loss, as a Python float.
        NOTE: every window contributes equally to that mean, even though the
        number of scored tokens (``trg_len``) can differ between windows.
    """
    max_length = model.config.max_position_embeddings
    seq_len = encodings.input_ids.size(1)
    print(f'max_length: {max_length}, seq_len: {seq_len}')

    window_losses = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        # Number of tokens that are new in this window (may differ from stride on the last one).
        trg_len = end_loc - prev_end_loc

        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Tokens before the new segment are context only.
        target_ids[:, :-trg_len] = -100

        # The model shifts labels internally, so the loss is averaged over the
        # valid labels of this window. Gradients are already disabled by the
        # decorator.
        window_losses.append(model(input_ids, labels=target_ids).loss)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    mean_nll = torch.stack(window_losses).mean()
    return torch.exp(mean_nll).item()
    


class LoRALinear(nn.Module):
    """Linear layer whose frozen base weight is adjusted by trainable low-rank factors.

    ``weight`` is used as ``x @ W_eff.t()``, so ``weight.size(0)`` is the output
    width and ``weight.size(1)`` the input width. ``type`` selects how the
    effective weight ``W_eff`` is assembled from the frozen weight ``W``:

        0: W + u v
        1: W + lc (ld W) + u v
        2: W + W rc rd + u v
        4: W + lc (ld W) + W rc rd + u v
        5: lc (ld W) + u v
        6: W rc rd + u v
        7: alpha W + lc (ld W) + W rc rd + u v
        8: alpha W + u v

    Args:
        weight: 2-D base weight; stored as a non-trainable parameter.
        bias: optional bias; stored as a non-trainable parameter.
        rank: inner dimension of every low-rank factor pair.
        scale: multiplier on the (tiny) init standard deviation of the factors.
        type: integer variant selector, see the table above.
    """

    def __init__(self, weight, bias=None, rank=5, scale=1.0, type = 0):
        super(LoRALinear, self).__init__()
        self.weight = nn.Parameter(weight, requires_grad=False) # frozen weight
        if bias is not None:
            self.bias = nn.Parameter(bias, requires_grad=False)
        else:
            self.bias = None

        self.type = type # determine the type of LoRA

        # Additive low-rank term u @ v (created for every type).
        self.u, self.v = self.init_loras(rank, scale, weight.size(0), weight.size(1))

        # Left-multiplying factors, applied as lc @ (ld @ W).
        if type in [1, 4, 5, 7]:
            self.lc, self.ld = self.init_loras(rank, scale, weight.size(0), weight.size(0))

        # Right-multiplying factors, applied as W @ rc @ rd.
        if type in [2, 4, 6, 7]:
            self.rc, self.rd = self.init_loras(rank, scale, weight.size(1), weight.size(1))

        # Learnable scalar on the frozen weight.
        if type in [7, 8]:
            self.alpha = self.init_scalar(scale)
            print_debug(f'alpha: {self.alpha.size()}')

    def init_loras(self, rank, scale, dim1, dim2):
        """Create a trainable factor pair of shapes (dim1, rank) and (rank, dim2).

        Both factors are drawn from a zero-mean normal with std ``1e-5 * scale``.
        """
        std = 0.00001 * scale
        u = nn.Parameter(torch.zeros(dim1, rank), requires_grad=True)
        v = nn.Parameter(torch.zeros(rank, dim2), requires_grad=True)
        nn.init.normal_(u, mean=0.0, std=std)
        nn.init.normal_(v, mean=0.0, std=std)

        return u, v

    def init_scalar(self, scale):
        """Create the trainable scalar ``alpha``, initialised to 1 (``scale`` is unused)."""
        alpha = nn.Parameter(torch.ones(1), requires_grad=True)
        # nn.init.normal_(alpha, mean=1.0, std=0.00001 * scale)

        return alpha

    def forward(self, x):
        """Apply the effective weight for ``self.type`` to ``x``.

        Raises:
            ValueError: if ``self.type`` is not one of the supported variants
                (previously this failed with an UnboundLocalError).
        """
        if self.type == 0:
            new_weight = self.weight + self.u @ self.v
        elif self.type == 1:
            new_weight = self.weight + self.lc @ (self.ld @ self.weight) + self.u @ self.v
        elif self.type == 2:
            new_weight = self.weight + self.weight @ self.rc @ self.rd + self.u @ self.v
        elif self.type == 4:
            new_weight = self.weight + self.lc @ (self.ld @ self.weight) + self.weight @ self.rc @ self.rd + self.u @ self.v

        # yd's idea -> not work for different loras init. the eval loss has almost no change. I think the reason is that the lora ranks are too low.
        elif self.type == 5:
            new_weight = self.lc @ (self.ld @ self.weight) + self.u @ self.v
        elif self.type == 6:
            new_weight = self.weight @ self.rc @ self.rd + self.u @ self.v

        elif self.type == 7:
            new_weight = self.alpha * self.weight + self.lc @ (self.ld @ self.weight) + self.weight @ self.rc @ self.rd + self.u @ self.v
        elif self.type == 8:
            new_weight = self.alpha * self.weight + self.u @ self.v
        else:
            raise ValueError(
                f'unsupported LoRA type: {self.type!r} (expected one of 0, 1, 2, 4, 5, 6, 7, 8)'
            )

        if self.bias is not None:
            return x @ new_weight.t() + self.bias
        else:
            return x @ new_weight.t()
        

class SelfDefineMLP(nn.Module):
    """Gated MLP built from explicit weights: ``down(silu(gate(x)) * up(x))``.

    Each weight uses the ``nn.Linear`` layout ``(out_features, in_features)``.

    Args:
        gate_weights: weight of the gate projection.
        up_weights: weight of the up projection.
        down_weights: weight of the down projection.
    """

    def __init__(self, gate_weights, up_weights, down_weights):
        super(SelfDefineMLP, self).__init__()

        # nn.Linear takes (in_features, out_features) while its weight is stored
        # as (out_features, in_features), hence size(1) first for all three.
        # (The down layer previously had the two sizes swapped, which left its
        # in_features/out_features attributes inconsistent with its weight.)
        self.gate = nn.Linear(gate_weights.size(1), gate_weights.size(0), bias=False)
        self.up = nn.Linear(up_weights.size(1), up_weights.size(0), bias=False)
        self.down = nn.Linear(down_weights.size(1), down_weights.size(0), bias=False)

        # Replace the randomly initialised weights with the supplied ones.
        self.gate.weight = nn.Parameter(gate_weights)
        self.up.weight = nn.Parameter(up_weights)
        self.down.weight = nn.Parameter(down_weights)

    def forward(self, x):
        """Return ``down(silu(gate(x)) * up(x))``."""
        gate = F.silu(self.gate(x))
        up = self.up(x)
        down = self.down(gate * up)

        return down
    
    
class LoRAMLP(nn.Module):
    """Gated MLP whose three projections are ``LoRALinear`` layers.

    The LoRA variant of every projection is taken from ``args.mul_type``.
    When ``gate_weights`` is ``None`` no projection is created and the three
    attributes are set to ``None``.
    """

    def __init__(self, gate_weights=None, up_weights=None, down_weights=None, rank=5, scale=1.0):
        super().__init__()

        if gate_weights is None:
            # Empty shell: nothing to wrap.
            self.gate = None
            self.up = None
            self.down = None
            return

        # Same LoRA settings for all three projections (built in gate, up, down order).
        lora_kwargs = dict(rank=rank, scale=scale, type=args.mul_type)
        self.gate = LoRALinear(gate_weights, **lora_kwargs)
        self.up = LoRALinear(up_weights, **lora_kwargs)
        self.down = LoRALinear(down_weights, **lora_kwargs)

    def forward(self, x):
        """Return ``down(silu(gate(x)) * up(x))``."""
        gated = F.silu(self.gate(x))
        lifted = self.up(x)
        return self.down(gated * lifted)

class LoRALlama2(nn.Module):
    """Wrap a LLaMA-2 model and swap selected MLPs for stored LoRA students.

    Args:
        llama2_model: the causal-LM model to wrap; its layers are reached via
            ``llama2_model.model.layers``.
        ref_layers: student (reference) layer indices, paired element-wise
            with ``target_layers``.
        target_layers: indices of the layers whose MLP projections are replaced.

    The checkpoint for each pair is read from
    ``args.store_model_dir/tea{target}_stu{ref}_best.pth``.
    """

    def __init__(self, llama2_model, ref_layers, target_layers):
        super(LoRALlama2, self).__init__()
        self.llama2_model = llama2_model

        self.ref_layers = ref_layers
        self.target_layers = target_layers

        # First freeze every weight of the base model.
        for name, param in self.llama2_model.named_parameters():
            if 'weight' in name:
                param.requires_grad = False

        self.replace_stored_loramlp(args.store_model_dir)

    def replace_stored_loramlp(self, store_model_dir):
        """Load each stored student and install it in place of the target MLP projections."""
        for ref_layer, target_layer in zip(self.ref_layers, self.target_layers):
            print(f'replace layer {target_layer} with layer {ref_layer}')

            # The layer-2 weights only provide the right shapes; every tensor is
            # overwritten by the checkpoint below. They are cloned because
            # nn.Parameter shares storage with the tensor it wraps and
            # load_state_dict copies in place — without the clone, loading the
            # checkpoint would overwrite the base model's own layer-2 weights.
            loramlp = LoRAMLP(
                gate_weights=weight_2_gate.clone(),
                up_weights=weight_2_up.clone(),
                down_weights=weight_2_down.clone(),
                rank=args.rank,
                scale=0.01,
            )

            checkpoint_path = os.path.join(store_model_dir, f'tea{target_layer}_stu{ref_layer}_best.pth')
            # map_location keeps loading independent of the GPU the checkpoint was saved from.
            loramlp.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

            target_mlp = self.llama2_model.model.layers[target_layer].mlp
            # Put the student on the device of the projections it replaces.
            loramlp.to(target_mlp.down_proj.weight.device)

            # The LLaMA MLP's submodules are named gate_proj / up_proj / down_proj.
            # Assigning to `.gate` / `.up` / `.down` (as before) only
            # attached extra, unused submodules and left the layer unchanged.
            target_mlp.gate_proj = loramlp.gate
            target_mlp.up_proj = loramlp.up
            target_mlp.down_proj = loramlp.down

            print(f'layer {target_layer} replaced, llama2_model is none: {self.llama2_model is None}')

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Delegate to the wrapped model."""
        return self.llama2_model(input_ids, attention_mask=attention_mask, labels=labels)
        

@torch.no_grad()
def evaluate_model(student_mlp, teacher_mlp, dataloader, criterion, input_process = lambda x: x):
    """Return the mean per-batch ``criterion(student(x), teacher(x))`` over ``dataloader``.

    Both modules are put in eval mode. Each batch is passed through
    ``input_process`` and moved to ``student_cuda`` before being fed to both
    modules. On the first batch a set of norm diagnostics is emitted through
    ``print_debug``, including a comparison against the base model's
    ``args.student_layer`` MLP.
    """
    student_mlp.eval()
    teacher_mlp.eval()

    def _mean_norm(t):
        return torch.norm(t, p=2, dim=1).mean().item()

    running_loss = 0.0
    first_batch = True
    # Gradients are already disabled by the decorator.
    for batch in tqdm(dataloader, desc='evaluating'):
        features = input_process(batch)
        features = features.to(student_cuda)

        outputs = student_mlp(features)
        outputs_teacher = teacher_mlp(features)

        running_loss += criterion(outputs, outputs_teacher).item()

        if first_batch:
            # Trailing numbers are values the author recorded from earlier runs.
            print_debug(f'input = {_mean_norm(features)}') # 0.93
            print_debug(f'outputs_teacher = {_mean_norm(outputs_teacher)}') # 0.026 at the end
            print_debug(f'outputs_teacher - input = {_mean_norm(outputs_teacher - features)}') # 0.93

            print_debug(f'outputs = {_mean_norm(outputs)}') # 0.025 at the end
            print_debug(f'outputs - outputs_teacher = {_mean_norm(outputs - outputs_teacher)}') # 0.028 at the end

            # Reference point: the untouched student-layer MLP of the base model.
            # Note that .to() moves that module of the global model itself.
            tmp_mlp = model.model.layers[args.student_layer].mlp.to(student_cuda)
            outputs_tmp = tmp_mlp(features)
            print_debug(f'outputs_teacher - outputs_tmp = {_mean_norm(outputs_teacher - outputs_tmp)}') # 0.26 at the end

            first_batch = False
        # Author's note: torch.norm(outputs_teacher - input) = 4.5, and the norms
        # of input, outputs, outputs_teacher are 4.5, 0.36, 0.38 after 1 epoch
        # for type 0. That is much larger than when using weights from a
        # neighbouring layer, so starting from neighbour weights is necessary.

    return running_loss / len(dataloader)
                   
def L2_loss(outputs, outputs_teacher):
    """Mean of the 2-norms, taken along dim 1, of ``outputs - outputs_teacher``."""
    residual = outputs - outputs_teacher
    norms_along_dim1 = torch.norm(residual, p=2, dim=1)
    return norms_along_dim1.mean()

def train_student(args, student_mlp, teacher_mlp, train_loader, eval_loader, optimizer, scheduler, criterion, input_process = lambda x: x):
    """Train ``student_mlp`` to reproduce ``teacher_mlp``'s outputs.

    Args:
        args: namespace providing ``num_epochs``, ``eval_freq``,
            ``store_model_dir``, ``teacher_layer`` and ``student_layer``.
        student_mlp: module being optimised.
        teacher_mlp: fixed module providing the regression targets.
        train_loader: iterable of training batches.
        eval_loader: iterable of evaluation batches (passed to ``evaluate_model``).
        optimizer: optimizer stepped once per training batch.
        scheduler: optional LR scheduler stepped once per epoch, or ``None``.
        criterion: callable ``criterion(student_out, teacher_out) -> loss``.
        input_process: maps a raw batch to the tensor fed to both modules.

    The student's state dict is saved to
    ``store_model_dir/tea{teacher_layer}_stu{student_layer}_best.pth`` whenever
    an evaluation improves on the best loss so far (the initial evaluation
    sets the starting value). Evaluation runs every ``eval_freq`` epochs and
    on the last epoch.
    """
    teacher_mlp.eval()

    best_model_path = os.path.join(args.store_model_dir, f'tea{args.teacher_layer}_stu{args.student_layer}_best.pth')

    best_eval_loss = evaluate_model(student_mlp, teacher_mlp, eval_loader, criterion, input_process)
    print(f'\n\n########## initial eval loss: {best_eval_loss} ##########\n')

    for epoch in range(args.num_epochs):
        print(f'=> Epoch: {epoch}')

        student_mlp.train()
        for data in tqdm(train_loader, desc='training'):
            batch = input_process(data)
            batch = batch.to(student_cuda)

            optimizer.zero_grad()

            outputs = student_mlp(batch)
            # The teacher only supplies targets, so keep it out of the autograd
            # graph: the student's gradients are unchanged, and no graph or
            # gradients are built for the teacher's parameters.
            with torch.no_grad():
                outputs_teacher = teacher_mlp(batch)

            loss = criterion(outputs, outputs_teacher)

            loss.backward()
            optimizer.step()

        # Step the scheduler once per epoch.
        if scheduler is not None:
            scheduler.step()
            print(f'current lr: {scheduler.get_last_lr()}')

        if epoch % args.eval_freq == 0 or epoch == args.num_epochs - 1:
            eval_loss = evaluate_model(student_mlp, teacher_mlp, eval_loader, criterion, input_process)
            print(f'\n\n########## eval loss: {eval_loss} #########\n')

            # Report parameter / gradient magnitudes of everything that received a gradient.
            for name, param in student_mlp.named_parameters():
                if param.grad is not None:
                    print(f'name: {name}, norm: {torch.norm(param, p=2).item()}, grad norm: {torch.norm(param.grad, p=2).item()}')

            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                torch.save(student_mlp.state_dict(), best_model_path)
                print(f'best model saved to {best_model_path}')
            
        

def visualize_singular_values(s, output_dir, file_name):
    """Plot the natural log of the singular values ``s`` and save the figure.

    Args:
        s: singular values, as a NumPy array or a torch tensor (a tensor is
            detached and moved to CPU first).
        output_dir: directory for the image; created if missing.
        file_name: file name of the image inside ``output_dir``.
    """
    # Convert tensors to NumPy; detach so tensors that track gradients work too.
    if isinstance(s, torch.Tensor):
        s = s.detach().cpu().numpy()

    log_s = np.log(s)

    os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(log_s)
        plt.title('Log Singular Values for Layer 1')
        plt.xlabel('Index')
        plt.ylabel('Value')
        plt.grid(True)

        output_path = os.path.join(output_dir, file_name)
        plt.savefig(output_path)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f'log singular values saved to {output_path}')
    
    

def generate_random_orthogonal_matrix(dim1, dim2):
    """Return the Q factor of a QR decomposition of a random Gaussian matrix.

    A ``dim1 x dim2`` matrix of standard normal samples is drawn with numpy and
    factorised with ``np.linalg.qr``; Q is returned as a float32 torch tensor.
    The caller is expected to pass ``dim1 >= dim2`` (assumption stated by the
    original author).

    An earlier approach based on ``scipy.stats.ortho_group`` was dropped by the
    original author because it was too slow.
    """
    gaussian = np.random.normal(size=(dim1, dim2))
    q_factor = np.linalg.qr(gaussian)[0]
    ortho = torch.tensor(q_factor, dtype=torch.float32)

    # Shape is logged for debugging; the original author measured the
    # deviation of Q^T Q from the identity at about 1e-5.
    print_debug(f'Q.size: {ortho.shape}')

    return ortho
    

def generate_random_data(args, dim):
    """Build the train/eval dataloaders used to fit the student MLP.

    ``args.dataset_type`` selects where the samples come from:

    * ``0``: i.i.d. standard normal rows scaled by 0.1.
    * ``1``: Gaussian rows ``Z`` coloured as ``Z @ diag(S) @ Vt``, where ``S`` and
      ``Vt`` come from an eigen-decomposition of ``X^T X`` of the stored teacher
      inputs ``X``, so the random data follows a similar covariance structure.
    * ``2``: the stored teacher inputs themselves.

    Args:
        args: parsed CLI namespace; reads ``dataset_type``, ``random_size``,
            ``batch_size`` and, for types 1 and 2, ``teacher_layer`` / ``store_path``.
        dim: number of features per sample (column count of the data matrix).

    Returns:
        ``(train_loader, eval_loader)``: the first 80% of the rows (shuffled each
        epoch) and the remaining 20% (in order).

    Raises:
        ValueError: if ``args.dataset_type`` is not 0, 1 or 2.
    """
    if args.dataset_type == 0:
        random_data = torch.randn(args.random_size, dim) * 0.1

    elif args.dataset_type == 1:
        # The 1/500 scaling was picked empirically by the original author; an explicitly
        # orthogonal Z (generate_random_orthogonal_matrix) was tried as an alternative.
        Z = torch.randn(args.random_size, dim) / 500

        vt_path = os.path.join(args.store_path, f'Vt_complete_{args.teacher_layer}.pt')
        s_path = os.path.join(args.store_path, f'S_complete_{args.teacher_layer}.pt')

        # Reuse of the cached factors is intentionally switched off (the original
        # condition was guarded by ``1 == 0``), so they are recomputed on every call.
        use_cached_factors = False
        if use_cached_factors and os.path.exists(s_path):
            Vt = torch.load(vt_path)
            S = torch.load(s_path)
            print_debug(f'read Vt and S from {args.store_path}, Vt size: {Vt.size()}, S size: {S.size()}')
        else:
            # Stored data is B x max_len x dim (per the original note); flatten to (B*max_len) x dim.
            target_data = torch.load(f'xxx/llama_reader/train_inputs_complete_{args.teacher_layer}.pt').view(-1, dim)
            print_debug(f'== target_data: {target_data.size()}')

            # The row count may be very large, so decompose the dim x dim matrix X^T X
            # instead of running an SVD on X. Done on CPU with numpy: the original
            # author observed larger errors with a CUDA SVD.
            Cov = target_data.t() @ target_data
            Cov_numpy = Cov.cpu().numpy()
            print_debug(f'Cov_numpy size: {Cov_numpy.shape}')
            eigvals, eigvecs = np.linalg.eigh(Cov_numpy)
            print_debug(f'eigvecs {eigvecs.shape}, dist between cov {np.linalg.norm(Cov_numpy - eigvecs @ np.diag(eigvals) @ eigvecs.T)}, eigvecs @ eigvecs.T dist: {np.linalg.norm(eigvecs @ eigvecs.T - np.eye(dim))}')
            Vt = torch.tensor(eigvecs, dtype=torch.float32).t()
            # X^T X is positive semi-definite, but round-off can push tiny eigenvalues
            # slightly below zero; clamp them so sqrt cannot inject NaNs into the data.
            S = torch.sqrt(torch.clamp(torch.tensor(eigvals, dtype=torch.float32), min=0.0))

            # Persist the factors for later inspection.
            torch.save(Vt, vt_path)
            torch.save(S, s_path)

        print_debug(f'S[:5]: {S[:5]}, S[-5:]: {S[-5:]}')
        # Scaling the columns of Z by S equals Z @ diag(S) without building the dense diagonal.
        random_data = (Z * S) @ Vt

    elif args.dataset_type == 2:
        random_data = torch.load(f'xxx/llama_reader/train_inputs_complete_{args.teacher_layer}.pt').view(-1, dim)
        print_debug(f'== random_data: {random_data.size()}')

    else:
        raise ValueError(f'unsupported dataset_type: {args.dataset_type} (expected 0, 1 or 2)')

    length = random_data.size(0)
    # 80/20 split into train and eval, in storage order.
    train_size = int(length * 0.8)

    print(f'train_size: {train_size}, eval_size: {length - train_size}, random_data size: {random_data.size()}')

    train_loader = DataLoader(random_data[:train_size, :], batch_size=args.batch_size, shuffle=True)
    eval_loader = DataLoader(random_data[train_size:, :], batch_size=args.batch_size, shuffle=False)

    print(f'train_loader size: {len(train_loader)}, eval_loader size: {len(eval_loader)}')
    return train_loader, eval_loader
    
def lora_recovery():
    """Fit a LoRA-augmented student MLP so that it reproduces the teacher MLP outputs.

    The teacher is a frozen deep copy of the MLP at ``args.teacher_layer``; the
    student is a ``LoRAMLP`` built from the gate/up/down weights of the MLP at
    ``args.student_layer``. Only student parameters with ``requires_grad`` set
    are handed to the optimizer.
    """
    teacher_idx = args.teacher_layer
    student_idx = args.student_layer

    ##### STEP1: define student and teacher layers #####
    teacher_mlp = copy.deepcopy(model.model.layers[teacher_idx].mlp).to(student_cuda)
    teacher_mlp.eval()
    # The teacher only provides targets, so it never receives gradients.
    for frozen in teacher_mlp.parameters():
        frozen.requires_grad = False

    source_mlp = model.model.layers[student_idx].mlp
    weight_gate = source_mlp.gate_proj.weight.data
    weight_up = source_mlp.up_proj.weight.data
    weight_down = source_mlp.down_proj.weight.data

    student_mlp = LoRAMLP(weight_gate, weight_up, weight_down, rank=args.rank, scale=0.01).to(student_cuda)
    # List what will actually be trained.
    for name, param in student_mlp.named_parameters():
        if not param.requires_grad:
            continue
        print(f'name: {name}, requires_grad: {param.requires_grad}, param size: {param.size()}')

    ##### STEP2: define the optimizer, scheduler, and criterion #####
    trainable = [p for p in student_mlp.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=args.lr)
    # Learning rate is multiplied by 0.9 after every scheduler step.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    criterion = L2_loss

    ##### STEP3: define the dataloader #####
    # The MLP input width is the column count of the gate projection.
    train_loader, eval_loader = generate_random_data(args, weight_gate.size(1))

    ##### STEP4: train the student mlp #####
    train_student(args, student_mlp, teacher_mlp, train_loader, eval_loader, optimizer, scheduler, criterion)
    

    
@torch.no_grad()
def visualize_output(model, device):
    """Print greedy generations of ``model`` for three fixed prompts.

    The model is switched to eval mode and moved to ``device``; each prompt is
    tokenized with the module-level ``tokenizer``, completed with
    ``do_sample=False`` and the decoded text is printed after a separator line.

    Args:
        model: model exposing ``eval``, ``to`` and ``generate``.
        device: device the model and the tokenized prompts are moved to.
    """
    model.eval()
    model.to(device)

    separator = '===================================================================='

    # (prompt, max_new_tokens) pairs, generated in this order.
    cases = (
        (
            "A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
            "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
            "there?",
            200,
        ),
        (
            "A chat between a curious girl and an expert.\n\nGirl: Can you introduce Statue of Liberty for me in 100 words?\nExpert: Sure, "
            "the Statue of Liberty is a colossal neoclassical sculpture on Liberty Island in New York Harbor in New York City, in the United States. It was designed by ",
            200,
        ),
        (
            "[Question]: What is the mass correction of a light pseudoscalar decay? [Answer]: The mass correction of a light pseudoscalar decay refers to the effects of the masses of the final state particles on the decay width of a particle. These corrections can be",
            150,
        ),
    )

    for prompt, token_budget in cases:
        print(separator)
        model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=token_budget, do_sample=False)
        # A single prompt was submitted, so only the first decoded sequence is shown.
        print(tokenizer.batch_decode(generated_ids)[0])



def llama2_tester():
    """Swap stored LoRAMLP checkpoints into a fresh LLaMA model and evaluate it.

    For every (student, teacher) layer pair the checkpoint
    ``tea{teacher}_stu{student}_best.pth`` is loaded into a ``LoRAMLP`` that
    replaces the MLP of the teacher layer. The patched model is then used for
    sample generations and a perplexity measurement on the test encodings.
    """
    ref_layers = [args.student_layer]
    target_layers = [args.teacher_layer]
    # Results recorded by the original author for various layer pairings
    # (the metric is not stated alongside the numbers):
    # [2]:[3] -> 3.4
    # [2,4,6,8]:[3,5,7,9] -> 10.5
    # [2,4,6]:[3,5,7] -> 4.7
    # [1 predict 1] -> 3M
    # [28]:[29] -> 3.25
    # [20]:[21] -> 3.11
    # 4:5 -> 3.19
    # 10:11 -> 3.06
    # 16:17 -> 3.15
    # 24:25 -> 3.10
    # ref_layers = [2,4,6,8,10,12,14,16,18,20,22,24,26,28]
    # target_layers = [3,5,7,9,11,13,15,17,19,21,23,25,27,29]

    test_model = LlamaForCausalLM.from_pretrained(model_dir)

    for ref_layer, target_layer in zip(ref_layers, target_layers):
        checkpoint = os.path.join(args.store_model_dir, f'tea{target_layer}_stu{ref_layer}_best.pth')
        # NOTE(review): the adapter is always built from the module-level weight_2_*
        # tensors, independent of ref_layer - confirm this is intended when more
        # than one layer pair is used.
        adapter = LoRAMLP(weight_2_gate, weight_2_up, weight_2_down, rank=args.rank, scale=0.01).to(default_cuda)
        adapter.load_state_dict(torch.load(checkpoint))

        # The teacher layer's MLP is replaced by the trained adapter.
        test_model.model.layers[target_layer].mlp = adapter

    test_model.eval()
    test_model = test_model.to(teacher_cuda)

    visualize_output(test_model, device=teacher_cuda)

    # Only the test encodings are needed for the perplexity measurement.
    _, _, test_encodings = get_dataset(args, tokenizer)

    ppl = calculate_ppl(test_model, test_encodings, stride=512, device=teacher_cuda)
    print(f'ppl: {ppl}')

def gate_visualizer():
    """Save a heat map of a random sub-sample of every layer's gate projection.

    A fresh model is loaded from ``model_dir``. For each decoder layer, 10% of
    the rows and 10% of the columns of ``mlp.gate_proj.weight`` are drawn at
    random and rendered to ``gate_weights_layer{i}.png`` under ``args.pic_path``.
    """
    model = LlamaForCausalLM.from_pretrained(model_dir)
    model.eval()

    # Fraction of rows/columns kept for the picture.
    ratio = 0.1

    # Iterate over the layers the model actually has instead of assuming 32.
    for i, layer in enumerate(model.model.layers):
        weight_gate = layer.mlp.gate_proj.weight.data
        print_debug(f'weight_gate size: {weight_gate.size()}')

        # Columns are drawn before rows, matching the original sampling order.
        column_idx = torch.randperm(weight_gate.size(1))[:int(weight_gate.size(1) * ratio)]
        row_idx = torch.randperm(weight_gate.size(0))[:int(weight_gate.size(0) * ratio)]

        vis_weight_gate = weight_gate[row_idx][:, column_idx]

        fig = plt.figure(figsize=(10, 6))
        plt.imshow(vis_weight_gate.cpu().numpy(), cmap='hot', interpolation='nearest')
        plt.colorbar()
        plt.title(f'Gate Weights for Layer {i}')

        output_path = os.path.join(args.pic_path, f'gate_weights_layer{i}.png')
        plt.savefig(output_path)
        # One figure per layer would otherwise stay open for the rest of the run.
        plt.close(fig)
        print(f'gate weights saved to {output_path}')
        
        
        
    
    
    

def verify_selfdefine_mlp(): # verified, ok!
    """Sanity-check SelfDefineMLP and LoRAMLP against the model's own MLP layers.

    A random batch is pushed through the re-implementations built from the
    ``weight_1_*`` tensors and through the original MLPs of layers 1 and 2;
    the L2 distances between the outputs are printed.
    """
    custom_mlp = SelfDefineMLP(weight_1_gate, weight_1_up, weight_1_down)
    reference_mlp = model.model.layers[1].mlp

    # Random probe batch: 10 rows, width taken from the gate projection (10 x 4096 per the original note).
    x = torch.randn(10, weight_1_gate.size(1)).to(default_cuda)
    print_debug(f'x[:3, :3]: {x[:3, :3]}')

    # SelfDefineMLP vs. the original layer-1 MLP.
    out_custom = custom_mlp(x)
    out_reference = reference_mlp(x)
    gap_custom = torch.norm(out_custom - out_reference, p=2)

    print(f'delta_y: {gap_custom.item()}, y1: {out_custom.size()}, y2: {out_reference.size()}') # 0

    # Baseline: how far apart are two different layers on the same input?
    other_mlp = model.model.layers[2].mlp
    out_other = other_mlp(x)
    gap_layers = torch.norm(out_reference - out_other, p=2)
    print(f'delta_layer: {gap_layers.item()}') # 523

    # LoRAMLP built from the layer-1 weights vs. both original layers.
    lora_mlp = LoRAMLP(weight_1_gate, weight_1_up, weight_1_down).to(default_cuda)
    out_lora = lora_mlp(x)
    gap_lora_same = torch.norm(out_lora - out_reference, p=2)
    gap_lora_other = torch.norm(out_lora - out_other, p=2)
    print(f'delta_lora: {gap_lora_same.item()}, delta_lora_2: {gap_lora_other.item()}') # 7e-5, 523
    
@torch.no_grad()
def shuffle_researcher(matrix1, matrix2, shuffle_dim, downsample_ratio):
    """Measure how well slices of ``matrix1`` can be matched to slices of ``matrix2``.

    A random subset of the slices of ``matrix1`` along ``shuffle_dim`` is taken,
    both matrices are L2-normalised slice-wise, and for every sampled slice the
    best partner in ``matrix2`` is looked up twice: by maximum cosine similarity
    and by minimum L2 distance. The average scores of the best matches are printed.

    Args:
        matrix1: 2-D tensor whose slices are sampled.
        matrix2: 2-D tensor searched for the best match of every sampled slice.
        shuffle_dim: 0 to match rows, 1 to match columns.
        downsample_ratio: fraction of the slices of ``matrix1`` that is sampled.
    """
    if shuffle_dim == 1:
        # Work on rows throughout: transpose so the matched dimension comes first.
        matrix1 = matrix1.t()
        matrix2 = matrix2.t()

    # Both operands have to live on the same device.
    matrix1 = matrix1.to(student_cuda)
    matrix2 = matrix2.to(student_cuda)

    # Random subset of the rows of matrix1.
    random_indices = torch.randperm(matrix1.size(0))[:int(matrix1.size(0) * downsample_ratio)].to(student_cuda)
    matrix1_down = matrix1.index_select(0, random_indices)

    print_debug(f'matrix1_down: {matrix1_down.size()}, matrix2: {matrix2.size()}, matrix1: {matrix1.size()}')

    # Unit-length rows, so a dot product is a cosine similarity.
    N1 = F.normalize(matrix1_down, p=2, dim=1)
    N2 = F.normalize(matrix2, p=2, dim=1)

    ############ best match by cosine similarity
    # The original author saw sim of size [110, 11008] with a mean |sim| of about 0.0128.
    sim = torch.matmul(N1, N2.t())
    best_match = torch.argmax(sim, dim=1)
    best_N2 = torch.index_select(N2, 0, best_match)

    avg_cosine = torch.mean(torch.sum(N1 * best_N2, dim=1))
    print(f'avg_cosine: {avg_cosine.item()}')

    ############ best match by L2 distance between the normalised rows
    # torch.cdist yields the same (n1, n2) distance matrix as broadcasting
    # N1[:, None, :] - N2[None, :, :] without materialising the (n1, n2, d)
    # intermediate, which is tens of GB at the sizes noted above. The
    # matmul-based shortcut is disabled to keep the accuracy of the direct
    # difference.
    pairwise_l2 = torch.cdist(N1, N2, p=2.0, compute_mode='donot_use_mm_for_euclid_dist')
    print_debug(f'L2_loss: {pairwise_l2.size()}, avg is {torch.mean(pairwise_l2).item()}, L2_loss[:3, :3]: {pairwise_l2[:3, :3]}')
    # The original author saw size [110, 11008] with a mean of about 1.4142.

    best_match_L2 = torch.argmin(pairwise_l2, dim=1)
    best_N2_L2 = torch.index_select(N2, 0, best_match_L2)

    # Cosine similarity and L2 distance of the L2-selected matches.
    avg_L2 = torch.mean(torch.sum(N1 * best_N2_L2, dim=1))
    avg_L2_loss = torch.mean(torch.norm(N1 - best_N2_L2, p=2, dim=1))

    print(f'avg_L2: {avg_L2.item()}, avg_L2_loss: {avg_L2_loss.item()}')

@torch.no_grad()
def mul_mlp(gate, up, down):
    """Collapse an MLP's weights into one matrix product, ignoring the activation.

    With ``args.mul_type == 0`` the result is ``up^T @ down^T``; for any other
    value ``up`` is first multiplied element-wise by ``gate``.
    """
    left_factor = up if args.mul_type == 0 else gate * up
    return left_factor.t() @ down.t()

def vis_fc_layers(weight, name, max_dim = 1000):
    """Render a 2-D weight matrix as a heat map and save it as ``{name}.png``.

    Matrices with a dimension above ``max_dim`` are sub-sampled with an even
    stride along each dimension before plotting.

    Args:
        weight: 2-D tensor to visualise.
        name: used for the plot title and the file name under ``args.pic_path``.
        max_dim: size above which a dimension is strided down.
    """
    rows, cols = weight.size(0), weight.size(1)
    if rows > max_dim or cols > max_dim:
        # A stride of 0 is illegal for slicing; it appears whenever one dimension
        # is shorter than max_dim, so each stride is clamped to at least 1.
        row_step = max(1, rows // max_dim)
        col_step = max(1, cols // max_dim)
        weight = weight[::row_step, ::col_step]

        print(f'weight size is too large, sample to {weight.size()}')

    fig = plt.figure(figsize=(20, 20))
    # detach(): ``.numpy()`` refuses tensors that still track gradients.
    plt.imshow(weight.detach().cpu().numpy(), cmap='Blues', interpolation='nearest')
    plt.colorbar()
    plt.title(f'{name} weight')
    plt.savefig(os.path.join(args.pic_path, f'{name}.png'))
    # Release the (large) figure; pyplot keeps it alive until closed.
    plt.close(fig)

@torch.no_grad()
def mlp_researcher(gate_1, up_1, down_1, gate_2, up_2, down_2):
    """Compare two MLPs through the product of their weight matrices.

    The activation function is skipped: each MLP is collapsed by ``mul_mlp``
    into a single square matrix, which is additionally normalised by the mean
    of its diagonal. Stable ranks, distances between the two matrices and to the
    identity, and a few corner/centre excerpts are printed; the matrices are
    then saved as heat maps.

    Shapes per the original note: gate and up are d1 x d2, down is d2 x d1.

    Args:
        gate_1, up_1, down_1: weights of the first MLP.
        gate_2, up_2, down_2: weights of the second MLP.
    """
    # All operands have to live on the same device.
    gate_1 = gate_1.to(student_cuda)
    up_1 = up_1.to(student_cuda)
    down_1 = down_1.to(student_cuda)
    gate_2 = gate_2.to(student_cuda)
    up_2 = up_2.to(student_cuda)
    down_2 = down_2.to(student_cuda)

    mlp_1 = mul_mlp(gate_1, up_1, down_1)
    mlp_2 = mul_mlp(gate_2, up_2, down_2)
    print_debug(f'mlp_1: {mlp_1.size()}, mlp_2: {mlp_2.size()}')

    # Normalise by the average diagonal entry so both matrices are comparable to the identity.
    avg_diag_1 = torch.mean(torch.diagonal(mlp_1))
    avg_diag_2 = torch.mean(torch.diagonal(mlp_2))
    print_debug(f'avg_diag_1: {avg_diag_1.item()}, avg_diag_2: {avg_diag_2.item()}')

    N1 = mlp_1 / avg_diag_1
    N2 = mlp_2 / avg_diag_2

    # Stable rank of every factor, of the product and of the normalised product.
    sr_gate_1 = stable_rank(gate_1)
    sr_up_1 = stable_rank(up_1)
    sr_down_1 = stable_rank(down_1)
    sr_mlp_1 = stable_rank(mlp_1)
    sr_N1 = stable_rank(N1)
    sr_gate_2 = stable_rank(gate_2)
    sr_up_2 = stable_rank(up_2)
    sr_down_2 = stable_rank(down_2)
    sr_mlp_2 = stable_rank(mlp_2)
    sr_N2 = stable_rank(N2)

    print_debug(f'sr_gate_1: {sr_gate_1}, sr_up_1: {sr_up_1}, sr_down_1: {sr_down_1}, sr_mlp_1: {sr_mlp_1}, sr_N1: {sr_N1}')
    print_debug(f'sr_gate_2: {sr_gate_2}, sr_up_2: {sr_up_2}, sr_down_2: {sr_down_2}, sr_mlp_2: {sr_mlp_2}, sr_N2: {sr_N2}')

    # Distances between the normalised products and to the identity.
    L2_12 = torch.norm(N1 - N2, p=2)
    I = torch.eye(N1.size(0)).to(student_cuda)
    L2_1I = torch.norm(N1 - I, p=2)
    L2_2I = torch.norm(N2 - I, p=2)

    print(f'L2_12: {L2_12.item()}, L2_1I: {L2_1I.item()}, L2_2I: {L2_2I.item()}')
    # The original author recorded L2_12: 225.32, L2_1I: 150.06, L2_2I: 179.70 for "only up @ down".

    # Same distances on the un-normalised products. ``p=2`` is an argument of
    # torch.norm; it used to be passed to torch.eye, which raises a TypeError.
    L2_ori_12 = torch.norm(mlp_1 - mlp_2, p=2)
    L2_ori_1I = torch.norm(mlp_1 - torch.eye(mlp_1.size(0)).to(student_cuda), p=2)
    L2_ori_2I = torch.norm(mlp_2 - torch.eye(mlp_2.size(0)).to(student_cuda), p=2)

    print(f'L2_ori_12: {L2_ori_12.item()}, L2_ori_1I: {L2_ori_1I.item()}, L2_ori_2I: {L2_ori_2I.item()}')

    # Excerpts: top-left corner, bottom-right corner and centre block.
    print(f'N1[:5, :5]: {N1[:5, :5]}\nN2[:5, :5]: {N2[:5, :5]}\n')
    print(f'N1[-5:, -5:]: {N1[-5:, -5:]}\nN2[-5:, -5:]: {N2[-5:, -5:]}')
    middle_idx = N1.size(0) // 2
    print(f"middle: N1[{middle_idx}:{middle_idx+5}, {middle_idx}:{middle_idx+5}]: {N1[middle_idx:middle_idx+5, middle_idx:middle_idx+5]}\nN2[{middle_idx}:{middle_idx+5}, {middle_idx}:{middle_idx+5}]: {N2[middle_idx:middle_idx+5, middle_idx:middle_idx+5]}")

    # Heat maps of the raw and the normalised products.
    vis_fc_layers(mlp_1, f'mlp_1_t{args.mul_type}')
    vis_fc_layers(mlp_2, f'mlp_2_t{args.mul_type}')
    vis_fc_layers(N1, f'N1_t{args.mul_type}')
    vis_fc_layers(N2, f'N2_t{args.mul_type}')
    
    # vis_fc_layers(gate_1, 'gate_1')
    # vis_fc_layers(up_1, 'up_1')
    # vis_fc_layers(down_1, 'down_1')
    # vis_fc_layers(gate_2, 'gate_2')
    # vis_fc_layers(up_2, 'up_2')
    # vis_fc_layers(down_2, 'down_2')
    
    
    
# shuffle_researcher(weight_1_down, weight_2_down, 1, downsample_ratio=0.01)
# shuffle_researcher(weight_1_gate, weight_2_gate, 0, downsample_ratio=0.01)

# mlp_researcher(weight_1_gate, weight_1_up, weight_1_down, weight_2_gate, weight_2_up, weight_2_down)
    
# verify_selfdefine_mlp()



def process_matrix(matrix, k, p, r_prime=100):
    """Report how well a thresholded k-bit matrix is approximated by truncated SVDs.

    Entries inside the open band around the centre of the quantisation range,
    ``(2^(k-1) - 0.5 - 2^(p-1), 2^(k-1) - 0.5 + 2^(p-1))``, are zeroed. The
    remaining matrix is decomposed by SVD and rebuilt from the full spectrum,
    the top ``r_prime``, ``rank // 2``, ``rank // 4``, ``rank // 8`` and 16
    singular values. The norms of the reconstruction errors and of the
    reconstructions are printed; nothing is returned.

    Args:
        matrix: 2-D floating-point tensor of quantised values whose maximum is ``2^k - 1``.
        k: bit width of the quantisation.
        p: band exponent, ``0 <= p <= k - 1``.
        r_prime: number of singular values kept in reconstruction 1; ``0`` means ``rank // 2``.

    Raises:
        AssertionError: if ``p`` is out of range or the maximum of ``matrix`` is not ``2^k - 1``.
    """
    print(f'matrix size: {matrix.size()}, k: {k}, p: {p}, r_prime: {r_prime}')
    # Make sure the inputs are in the valid range (the message now states the bound that is checked).
    assert 0 <= p <= k-1, "p must be in the range [0, k-1]"

    # The largest entry has to be 2^k - 1.
    max_value = torch.max(matrix)
    assert max_value == 2**k - 1, f"Maximum value should be {2**k - 1}, but got {max_value}"

    print(f'max_value: {max_value}, 2^k - 1: {2**k - 1}, 2^p: {2**p}, 2^(p+1)-1: {2**(p+1)-1}')

    # Mask keeping the values at or beyond the two band edges.
    upper = 2**(k-1)-0.5 + 2**(p-1)
    lower = 2**(k-1)-0.5 - 2**(p-1)
    mask = (matrix >= upper) | (matrix <= lower)
    print(f'upper: {upper}, lower: {lower}')

    # Everything inside the band is set to 0.
    filtered_matrix = torch.where(mask, matrix, torch.zeros_like(matrix))

    print(f'mean of abs(matrix): {torch.mean(torch.abs(matrix)).item()}, mean of abs(filtered_matrix): {torch.mean(torch.abs(filtered_matrix)).item()}, mean of abs(matrix - filtered_matrix): {torch.mean(torch.abs(matrix - filtered_matrix)).item()}')

    non_zero = torch.count_nonzero(filtered_matrix)
    print(f'non_zero: {non_zero}')

    # Rank of the filtered matrix (the original author observed 9417 on their data).
    rank = torch.linalg.matrix_rank(filtered_matrix).item()
    print(f'!! rank of filtered_matrix: {rank}')

    U, S, Vh = torch.linalg.svd(filtered_matrix, full_matrices=False)

    if r_prime == 0:
        r_prime = rank // 2
        print(f'r_prime is set to {r_prime}')

    # Number of leading singular values kept by reconstruction 0..5.
    keep_counts = [S.numel(), r_prime, rank // 2, rank // 4, rank // 8, 16]

    reconstructions = []
    for keep in keep_counts:
        S_truncated = torch.zeros_like(S)
        S_truncated[:keep] = S[:keep]
        # Scaling the columns of U equals U @ diag(S_truncated) without the dense diagonal.
        reconstructions.append((U * S_truncated) @ Vh)

    # torch.norm with p=2 and no dim treats the matrix as a flat vector, i.e. the Frobenius norm.
    for idx, rebuilt in enumerate(reconstructions):
        print(f'Frobenius norm of error{idx}: {torch.norm(filtered_matrix - rebuilt, p=2).item()}')

    print(f'Frobenius norm of filtered_matrix: {torch.norm(filtered_matrix, p=2).item()}')
    for idx, rebuilt in enumerate(reconstructions):
        print(f'Frobenius norm of reconstructed_matrix{idx}: {torch.norm(rebuilt, p=2).item()}')
    

    
    

# clip_score=0.01 is the best
# I want to use clip_score=0.02 for test
def quantize_matrix(matrix, bits=2, clip_percent=1, clip_score=0.02):
    """Clip a matrix to ``[-clip_score, clip_score]`` and quantise it.

    * ``bits == 2``: ternary scheme. Entries above ``std`` become ``1``, entries
      below ``-std`` become ``-1``, the rest ``0``; de-quantised values are
      ``quantized * 2 * std``. ``std`` is measured before clipping.
    * otherwise: uniform affine quantisation of the clipped values onto the
      integer levels ``0 .. 2^bits - 1`` (kept as floats).

    Args:
        matrix: 2-D tensor to quantise (converted to float32).
        bits: bit width selecting the scheme; expected to be at least 1.
        clip_percent: unused; kept so existing calls stay valid (percentile
            clipping was replaced by the fixed ``clip_score`` threshold).
        clip_score: absolute clipping threshold applied before quantisation.

    Returns:
        ``(quantized_matrix, dequantized_matrix, scale, zero_point)``. In the
        ternary scheme ``scale`` is ``2 * std``; ``zero_point`` is always the one
        of the affine scheme.
    """
    # Work in floating point.
    matrix = matrix.float()
    # Statistics of the unclipped matrix; they drive the ternary thresholds.
    std = torch.std(matrix)
    mean_abs = torch.mean(torch.abs(matrix))

    # Clip with a fixed symmetric threshold.
    matrix = torch.clamp(matrix, -clip_score, clip_score)

    # Affine quantisation parameters.
    q_min, q_max = 0, 2**bits - 1
    min_val, max_val = matrix.min(), matrix.max()
    value_range = max_val - min_val
    scale = value_range / (q_max - q_min)
    if scale > 0:
        zero_point = q_min - torch.round(min_val / scale)
    else:
        # A constant matrix has zero range; dividing by the zero scale would
        # turn the zero point into NaN/inf.
        zero_point = torch.full_like(min_val, float(q_min))

    if bits == 2:
        # Special 2-bit mode storing three values: -scale, 0, scale.
        abs_max = torch.max(torch.abs(matrix))

        # An earlier variant used bound = std * 0.675.
        bound = std
        scale = bound * 2

        print(f'abs_max: {abs_max}, std: {std}, mean_abs: {mean_abs}, bound: {bound}, scale: {scale}')

        # Quantise.
        quantized_matrix = torch.zeros_like(matrix, dtype=torch.int8)
        quantized_matrix[matrix > bound] = 1
        quantized_matrix[matrix < -bound] = -1

        # Report how many entries collapse to zero.
        zero_count = torch.count_nonzero(quantized_matrix == 0)
        print(f'zero_count: {zero_count}, ratio: {zero_count.item() / matrix.numel()}')

        # De-quantise.
        dequantized_matrix = quantized_matrix.float() * scale
    else:
        # Generic uniform quantisation onto q_max + 1 levels.
        # With zero range every entry maps to level 0; normalise by 1 to avoid 0/0.
        safe_range = value_range if value_range > 0 else torch.ones_like(value_range)
        matrix_normalized = torch.clamp((matrix - min_val) / safe_range, 0, 1)
        quantized_values = torch.round(matrix_normalized * q_max)

        # The levels are kept as plain floats (the "lazy way" of the original author).
        quantized_matrix = quantized_values
        dequantized_matrix = quantized_values.float() / q_max * value_range + min_val

    print(f'quantized_matrix[:5, :5]: {quantized_matrix[:5, :5]}')
    print(f'middle value: {quantized_matrix[matrix.size(0) // 2:matrix.size(0) // 2+5, matrix.size(1) // 2:matrix.size(1) // 2+5]}')
    print(f'quantized_matrix[-5:, -5:]: {quantized_matrix[-5:, -5:]}')
    print(f'dequantized_matrix - matrix: {torch.norm(dequantized_matrix - matrix, p=2).item()}')

    return quantized_matrix, dequantized_matrix, scale, zero_point


@torch.no_grad()
def find_left_mul(A, B, device='cpu', rank=-1):
    """Solve ``min_C |C @ A - B|_F`` in closed form and report diagnostics.

    The minimiser is ``C = B @ pinv(A)``. Besides computing C, the function
    prints the residual, optionally evaluates a variant where B is replaced by
    its best rank-``rank`` approximation, and quantises both C and B with
    ``quantize_matrix`` (default settings) to inspect the resulting errors and
    value histograms.

    Args:
        A: 2-D tensor (d1 x d2 per the original note).
        B: 2-D tensor of the same size as A.
        device: device the computation runs on.
        rank: if positive, number of singular values of B kept in the low-rank experiment.

    Returns:
        C (d1 x d1), moved back to the device A originally lived on.

    Raises:
        AssertionError: if A and B differ in size.
    """
    print(f'A size: {A.size()}, B size: {B.size()}')
    assert A.size() == B.size()

    original_device = A.device

    A = A.to(device)
    B = B.to(device)

    # Least-squares solution through the pseudo inverse of A.
    A_inv = torch.pinverse(A)
    print(f'A_inv size: {A_inv.size()}')

    C = B @ A_inv

    print(f'C size: {C.size()}')
    print(f"F norm of A: {torch.norm(A, p='fro').item()}, F norm of B: {torch.norm(B, p='fro').item()}, F norm of C: {torch.norm(C, p='fro').item()}, F norm of B - C @ A: {torch.norm(B - C @ A, p='fro').item()}")

    if rank > 0:
        # Low-rank experiment: truncate the SVD of B (an earlier version truncated
        # the SVD of C instead) and solve for C again with the truncated target.
        U, S, Vh = torch.linalg.svd(B, full_matrices=False)
        print(f"U size: {U.size()}, S size: {S.size()}, Vh size: {Vh.size()}")

        # Keep only the leading `rank` singular values.
        S_truncated = torch.zeros_like(S)
        S_truncated[:rank] = S[:rank]

        B_truncated = U @ torch.diag(S_truncated) @ Vh

        print(f"B_truncated - B: {torch.norm(B_truncated - B, p='fro').item()}")

        C_truncated = B_truncated @ A_inv

        print(f'C_truncated.shape = {C_truncated.shape}')
        print(f"F norm of C_truncated: {torch.norm(C_truncated, p='fro').item()}, F norm of B - C_truncated @ A: {torch.norm(B - C_truncated @ A, p='fro').item()}")
        print(f'\n\nS: {S}\n')
        # The truncated solution is reported only; the full C is what gets returned.

    # Quantise C with the quantize_matrix defaults and measure the damage.
    print(f'================================= quantize C =================================')
    quantize_C, dequantize_C, scale, zero_point = quantize_matrix(C)
    print(f'quantize_C size: {quantize_C.size()}, scale: {scale}, zero_point: {zero_point}')
    print(f'dequantize_C size: {dequantize_C.size()}, [:5, :5]: {dequantize_C[:5, :5]}\n [-5:, -5:]: {dequantize_C[-5:, -5:]}\nmiddle: {dequantize_C[dequantize_C.size(0) // 2:dequantize_C.size(0) // 2+5, dequantize_C.size(1) // 2:dequantize_C.size(1) // 2+5]}')
    # Bring both results to float for the matmul and the histogram.
    dequantize_C = dequantize_C.float()
    quantize_C = quantize_C.float()
    vis_log_density_distribution(quantize_C, name='quantize_C')

    print(f'F norm of dequantize_C: {torch.norm(dequantize_C, p=2).item()}, F norm of B - dequantize_C @ A: {torch.norm(B - dequantize_C @ A, p=2).item()}')

    # Same inspection for B itself.
    print(f'================================= quantize B =================================')
    quantize_B, dequantize_B, scale, zero_point = quantize_matrix(B)
    print(f'quantize_B size: {quantize_B.size()}, scale: {scale}, zero_point: {zero_point}')
    print(f'dequantize_B size: {dequantize_B.size()}, [:5, :5]: {dequantize_B[:5, :5]}\n [-5:, -5:]: {dequantize_B[-5:, -5:]}\nmiddle: {dequantize_B[dequantize_B.size(0) // 2:dequantize_B.size(0) // 2+5, dequantize_B.size(1) // 2:dequantize_B.size(1) // 2+5]}')
    vis_log_density_distribution(quantize_B, name='quantize_B')

    C = C.to(original_device)
    # Drop the large temporaries before handing control back.
    del A, B, A_inv
    if rank > 0:
        del U, S, Vh, S_truncated, B_truncated, C_truncated
    torch.cuda.empty_cache()

    # A leftover debugging ``exit(-1)`` used to terminate the whole process at this
    # point, which made the return below (and every caller's follow-up work) unreachable.
    return C

@torch.no_grad()
def process_left_mul(A, B, name='', store_dir='', device='cpu', rank=400):
    """Compute the left multiplier for (A, B) via ``find_left_mul`` and save it.

    The result is written to ``C_{name}.pt`` inside ``store_dir`` and returned.
    """
    left_factor = find_left_mul(A, B, device=device, rank=rank)

    # Persist the solution next to the other stored factors.
    target_file = os.path.join(store_dir, f'C_{name}.pt')
    torch.save(left_factor, target_file)
    print(f'save C to {target_file}')

    return left_factor

@torch.no_grad()
def vis_log_density_distribution(matrix, name=''):
    """Save a 100-bin histogram (log-scaled counts) of all entries of ``matrix``.

    Args:
        matrix: tensor whose values are histogrammed.
        name: used in the plot title and in the file name
            ``log_density_distribution_{name}.png`` under ``args.pic_path``.
    """
    fig = plt.figure(figsize=(10, 6))
    # detach(): ``.numpy()`` refuses tensors that still track gradients.
    plt.hist(matrix.detach().cpu().numpy().flatten(), bins=100, log=True)
    plt.title(f'Log Density Distribution of {name}')

    output_path = os.path.join(args.pic_path, f'log_density_distribution_{name}.png')
    plt.savefig(output_path)
    # pyplot keeps figures alive until closed; this helper is called many times per run.
    plt.close(fig)
    print(f'log density distribution saved to {output_path}')
    


@torch.no_grad()
def find_left_mul_llama2(device='cpu', rank=400):
    """Fit left multipliers mapping the MLP weights of layer 2 onto those of layer 3.

    For each of the gate, up and (transposed) down projections a matrix C is
    computed and stored by ``process_left_mul``; afterwards the value histograms,
    heat maps and stable ranks of the three solutions are produced.

    Args:
        device: device passed through to the solver and to ``stable_rank``.
        rank: rank passed through to ``process_left_mul``.
    """
    model = LlamaForCausalLM.from_pretrained(model_dir)
    model.eval()

    layer1, layer2 = 2, 3
    source_mlp = model.model.layers[layer1].mlp
    target_mlp = model.model.layers[layer2].mlp

    # (tag, A, B) per projection; the down projection is transposed so that all
    # three problems share the orientation of gate/up.
    problems = [
        ('gate', source_mlp.gate_proj.weight.data, target_mlp.gate_proj.weight.data),
        ('up', source_mlp.up_proj.weight.data, target_mlp.up_proj.weight.data),
        ('down', source_mlp.down_proj.weight.data.T, target_mlp.down_proj.weight.data.T),
    ]

    # Solve for the left multipliers.
    # Original author's note: even with a 4096 x 4096 C the error stays very large.
    solutions = {}
    for tag, source_weight, target_weight in problems:
        solutions[tag] = process_left_mul(
            source_weight,
            target_weight,
            name=f'{tag}_{layer1}_{layer2}',
            store_dir=args.store_path,
            device=device,
            rank=rank,
        )

    # Value histograms.
    for tag, solved in solutions.items():
        vis_log_density_distribution(solved, name=f'{tag}_{layer1}_{layer2}')

    # Heat maps.
    for tag, solved in solutions.items():
        vis_fc_layers(solved, f'C_{tag}_{layer1}_{layer2}')

    # Stable ranks.
    ranks = {}
    for tag, solved in solutions.items():
        ranks[tag] = stable_rank(solved, device=device, note=f'{tag}_{layer1}_{layer2}')

    sr_C_gate, sr_C_up, sr_C_down = ranks['gate'], ranks['up'], ranks['down']
    print(f'sr_C_gate: {sr_C_gate}, sr_C_up: {sr_C_up}, sr_C_down: {sr_C_down}')
    
    
# from torch.optim import LBFGS

# def convex_relaxation_pytorch(B, A, max_iter=100, lambda_param=0.1, device='cuda' if torch.cuda.is_available() else 'cpu', init_C=None):
#     m, n = B.shape
    
#     # put B and A onto the same device
#     B = B.to(device)
#     A = A.to(device)
    
#     #   min |B - C @ A|_F + lambda * |C|_1
#     if init_C is not None:
#         C = torch.tensor(init_C, requires_grad=True, device=device)
#     else:
#         C = torch.zeros((m, m), requires_grad=True, device=device)
    
#     optimizer = LBFGS([C], max_iter=max_iter, line_search_fn='strong_wolfe')
    
#     def closure():
#         optimizer.zero_grad()
#         loss = torch.norm(B - C @ A, 'fro') + lambda_param * torch.sum(torch.abs(C))
#         loss.backward()
#         return loss
    
#     # optimize C
#     for i in range(max_iter):
#         optimizer.step(closure)
#         print(f'iter {i}, loss: {closure().item()}')
    
#     C = C.detach()
    
#     # calculate the error
#     print(f'F norm of B - C @ A: {torch.norm(B - C @ A, p=2).item()}, norm of C: {torch.norm(C, p=2).item()}, norm of B: {torch.norm(B, p=2).item()}, norm of A: {torch.norm(A, p=2).item()}')
    
#     C_discrete = torch.round(torch.clamp(C, -1, 1))
    
#     print(f'F norm of B - C_discrete @ A: {torch.norm(B - C_discrete @ A, p=2).item()}, norm of C_discrete: {torch.norm(C_discrete, p=2).item()}')
    
#     print(f'ratio of zeros in C_discrete: {torch.sum(C_discrete == 0).item() / (m * m)}')
    
#     # remove memory
#     del B, A, C
#     torch.cuda.empty_cache()
    
#     return C_discrete.cpu()

def calculate_zero_ratio(C, ratio=0.1):
    """Return the fraction of entries of ``C`` that are close to zero.

    An entry counts as "zero" when its absolute value is strictly below
    ``ratio * std(C)``. The standard deviation is printed as a side effect.

    Args:
        C: tensor of any shape (the count is normalised by ``C.numel()``).
        ratio: multiplier applied to ``torch.std(C)`` to obtain the threshold.

    Returns:
        Python float in ``[0, 1]``.
    """
    std_C = torch.std(C)
    print(f'std of C: {std_C}')
    zero_threshold = std_C * ratio
    near_zero = torch.sum(torch.abs(C) < zero_threshold).item()

    # numel() instead of size(0) * size(1): identical for matrices, and the
    # denominator stays correct for 1-D or >2-D tensors.
    return near_zero / C.numel()


def process_C(C, ratio=0.1):
    """Return a copy of ``C`` with small-magnitude entries set to zero.

    Entries whose absolute value is strictly below ``ratio * std(C)`` are
    replaced by 0; ``C`` itself is left untouched.
    """
    cutoff = ratio * torch.std(C)
    below_cutoff = C.abs() < cutoff

    # masked_fill (out-of-place) yields a fresh tensor with the masked entries zeroed
    return C.masked_fill(below_cutoff, 0)


def find_sparse_matrix(B, A, l1_lambda=0.1, lr=0.01, epochs=1000, device='cpu', init_C=None, step_size=100, gamma=0.5, ratio=1.0):
    """Fit ``C`` such that ``C @ A`` approximates ``B`` with an L1 penalty on ``C``.

    Minimises ``MSE(C @ A, B) + l1_lambda * ||C||_1`` with Adam and a StepLR
    learning-rate schedule, printing the loss every 10 epochs.

    Args:
        B: target matrix.
        A: right-hand matrix; ``C`` has shape ``(B.size(0), A.size(0))``.
        l1_lambda: weight of the L1 penalty.
        lr: initial Adam learning rate.
        epochs: number of optimisation steps.
        device: device on which the optimisation runs.
        init_C: optional starting point for ``C`` (copied, never modified);
            when ``None`` a standard-normal initialisation is used.
        step_size, gamma: StepLR parameters.
        ratio: threshold multiplier forwarded to ``calculate_zero_ratio`` for
            the near-zero statistics that are printed.

    Returns:
        The fitted ``C`` as a detached CPU tensor.
    """
    B = B.to(device)
    A = A.to(device)

    if init_C is not None:
        # clone().detach() is the recommended way to copy an existing tensor
        # (torch.tensor(tensor) emits a UserWarning); as_tensor keeps support
        # for array-like inputs. requires_grad_ comes last so C stays a leaf.
        C = torch.as_tensor(init_C).clone().detach().to(device).requires_grad_(True)
        init_zero_ratio = calculate_zero_ratio(C.detach(), ratio)
        print(f'ratio of zeros in init C: {init_zero_ratio}')
    else:
        C = torch.randn(B.size(0), A.size(0), requires_grad=True, device=device)
        # No meaningful initial statistic for a random start; defining the name
        # avoids a NameError in the summary print below.
        init_zero_ratio = None

    optimizer = optim.Adam([C], lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    criterion = nn.MSELoss()

    # Re-enable autograd locally so the fit also works when the caller runs
    # under torch.no_grad() (most helpers in this file are decorated with it).
    with torch.enable_grad():
        for epoch in range(epochs):
            optimizer.zero_grad()

            B_pred = C @ A
            loss = criterion(B_pred, B) + l1_lambda * torch.norm(C, p=1)

            loss.backward()
            optimizer.step()
            scheduler.step()

            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}, lr: {optimizer.param_groups[0]["lr"]}')

    C_final = C.detach()
    return_C = C_final.cpu()
    # finally calculate the error (on the detached tensor: no graph is built)
    error = torch.norm(B - C_final @ A, p=2).item()
    print(f'error: {error}, norm of B: {torch.norm(B, p=2).item()}, norm of C: {torch.norm(return_C, p=2).item()}, norm of A: {torch.norm(A, p=2).item()}')

    # calculate the ratio of zeros in C
    return_ratio = calculate_zero_ratio(return_C, ratio)
    print(f'ratio of zeros in C: {return_ratio}, init ratio: {init_zero_ratio}')

    # free memory
    del B, A, C, C_final
    torch.cuda.empty_cache()

    return return_C


# @torch.no_grad()
def find_left_mul_llama3(device='cpu', rank=400, vis_weights_only=True):
    """Analyse how the MLP weights of layer 25 relate to those of layer 24.

    Always loads the model from ``model_dir`` and saves log-density histograms
    of the six gate/up/down weight matrices of layers 24 and 25.

    With ``vis_weights_only=True`` (the default, matching the previous
    hard-coded behaviour) the process then terminates via ``exit()``.
    With ``vis_weights_only=False`` the function continues: it loads the
    initial factors ``C_<proj>_24_25.pt`` from ``args.store_path`` (the files
    ``process_left_mul`` writes), refines them with ``find_sparse_matrix``,
    thresholds them with ``process_C``, saves the thresholded factors as
    ``C_<proj>_24_25_processed.pt`` and reports reconstruction errors and
    stable ranks.

    This function is intentionally not wrapped in ``torch.no_grad()``:
    ``find_sparse_matrix`` optimises with autograd.

    Args:
        device: device used by ``find_sparse_matrix`` and ``stable_rank``.
        rank: number of singular values kept in the truncated-SVD check of
            the gate factor.
        vis_weights_only: stop after plotting the raw weight histograms.
    """
    model = LlamaForCausalLM.from_pretrained(model_dir)
    model.eval()

    layer1 = 24
    layer2 = 25

    # get the weight matrices
    mlp_1 = model.model.layers[layer1].mlp
    mlp_2 = model.model.layers[layer2].mlp

    weight_1_gate = mlp_1.gate_proj.weight.data
    weight_1_up = mlp_1.up_proj.weight.data
    weight_1_down = mlp_1.down_proj.weight.data

    weight_2_gate = mlp_2.gate_proj.weight.data
    weight_2_up = mlp_2.up_proj.weight.data
    weight_2_down = mlp_2.down_proj.weight.data

    vis_log_density_distribution(weight_1_gate, name=f'Gate_{layer1}')
    vis_log_density_distribution(weight_1_up, name=f'Up_{layer1}')
    vis_log_density_distribution(weight_1_down, name=f'Down_{layer1}')
    vis_log_density_distribution(weight_2_gate, name=f'Gate_{layer2}')
    vis_log_density_distribution(weight_2_up, name=f'Up_{layer2}')
    vis_log_density_distribution(weight_2_down, name=f'Down_{layer2}')

    if vis_weights_only:
        # Previously an unconditional exit() sat here and made everything
        # below unreachable; the default keeps that behaviour.
        exit()

    # load the initial C (expected to have been written by process_left_mul)
    C_gate_init = torch.load(os.path.join(args.store_path, f'C_gate_{layer1}_{layer2}.pt'))
    C_up_init = torch.load(os.path.join(args.store_path, f'C_up_{layer1}_{layer2}.pt'))
    C_down_init = torch.load(os.path.join(args.store_path, f'C_down_{layer1}_{layer2}.pt'))

    l1_lambda = 1e-10
    lr = 0.1
    epochs = 400
    step_size = 10
    gamma = 0.95
    ratio = 0.1

    # ratio is forwarded so the near-zero statistics printed by
    # find_sparse_matrix use the same threshold as process_C below
    # (previously find_sparse_matrix silently used its default of 1.0).
    fit_kwargs = dict(
        l1_lambda=l1_lambda, lr=lr, epochs=epochs, device=device,
        step_size=step_size, gamma=gamma, ratio=ratio,
    )

    # ---- gate ----
    C_gate = find_sparse_matrix(weight_2_gate, weight_1_gate, init_C=C_gate_init, **fit_kwargs)
    vis_log_density_distribution(C_gate, name=f'l1_gate_{layer1}_{layer2}')
    C_gate_processed = process_C(C_gate, ratio)
    print(f'errors: {torch.norm(weight_2_gate - C_gate_processed @ weight_1_gate, p=2).item()}')
    # truncated SVD
    # NOTE(review): the SVD is taken of the un-thresholded C_gate although the
    # result is compared against C_gate_processed -- confirm this is intended.
    U_pt_gate, S_pt_gate, Vh_pt_gate = torch.linalg.svd(C_gate, full_matrices=False)
    # truncate by rank
    processed_truncated_C_gate = U_pt_gate[:, :rank] @ torch.diag(S_pt_gate[:rank]) @ Vh_pt_gate[:rank, :]
    print(f'diff between processed and processed_truncated: {torch.norm(C_gate_processed - processed_truncated_C_gate, p=2).item()}')
    print(f'F norm of processed_truncated_C_gate: {torch.norm(processed_truncated_C_gate, p=2).item()}, F norm of B - processed_truncated_C_gate @ A: {torch.norm(weight_2_gate - processed_truncated_C_gate @ weight_1_gate, p=2).item()}')

    # ---- up ----
    C_up = find_sparse_matrix(weight_2_up, weight_1_up, init_C=C_up_init, **fit_kwargs)
    C_up_processed = process_C(C_up, ratio)
    print(f'errors: {torch.norm(weight_2_up - C_up_processed @ weight_1_up, p=2).item()}')
    vis_log_density_distribution(C_up, name=f'l1_up_{layer1}_{layer2}')

    # ---- down (weights are passed transposed) ----
    C_down = find_sparse_matrix(weight_2_down.T, weight_1_down.T, init_C=C_down_init, **fit_kwargs)
    C_down_processed = process_C(C_down, ratio)
    print(f'errors: {torch.norm(weight_2_down.T - C_down_processed @ weight_1_down.T, p=2).item()}')
    vis_log_density_distribution(C_down, name=f'l1_down_{layer1}_{layer2}')

    # store the processed C
    C_gate_path = os.path.join(args.store_path, f'C_gate_{layer1}_{layer2}_processed.pt')
    C_up_path = os.path.join(args.store_path, f'C_up_{layer1}_{layer2}_processed.pt')
    C_down_path = os.path.join(args.store_path, f'C_down_{layer1}_{layer2}_processed.pt')

    torch.save(C_gate_processed, C_gate_path)
    torch.save(C_up_processed, C_up_path)
    torch.save(C_down_processed, C_down_path)

    # visualize the fitted (un-thresholded) factors
    vis_fc_layers(C_gate, f'C_gate_{layer1}_{layer2}')
    vis_fc_layers(C_up, f'C_up_{layer1}_{layer2}')
    vis_fc_layers(C_down, f'C_down_{layer1}_{layer2}')

    # calculate stable rank
    sr_C_gate = stable_rank(C_gate, device=device, note=f'gate_{layer1}_{layer2}')
    sr_C_up = stable_rank(C_up, device=device, note=f'up_{layer1}_{layer2}')
    sr_C_down = stable_rank(C_down, device=device, note=f'down_{layer1}_{layer2}')

    print(f'sr_C_gate: {sr_C_gate}, sr_C_up: {sr_C_up}, sr_C_down: {sr_C_down}')
    

@torch.no_grad()
def modify_mlp_general(
    model,
    ref_layers: List[int],
    target_layers_list: List[List[int]],
):
    """Point the MLP weights of target layers at those of reference layers.

    ``ref_layers[k]`` is paired with ``target_layers_list[k]``; for every
    target index in that list the ``weight`` attribute of ``gate_proj``,
    ``up_proj`` and ``down_proj`` is rebound to the reference layer's weight
    object (no copy is made, so the layers share it afterwards).

    Returns:
        The same ``model`` object, modified in place.
    """
    projections = ('gate_proj', 'up_proj', 'down_proj')

    for ref, targets in zip(ref_layers, target_layers_list):
        for target in targets:
            source_mlp = model.model.layers[ref].mlp
            dest_mlp = model.model.layers[target].mlp
            for proj in projections:
                getattr(dest_mlp, proj).weight = getattr(source_mlp, proj).weight
            print(f'change layer {target} using layer {ref}!')

    return model


# @torch.no_grad()
# def test_replace_similarity(dataset_name='arxiv-math', device = 'cuda:0', downsample_ratio=0.1):
#     # for given llama2 model, test the perplexity of the model replace x mlp layer with the x-1 mlp layers, where x range from [1, 31]
    
#     ref_layers = [i for i in range(0, 31)]
#     target_layers_list = [[i+1, i+2, i+3] for i in range(0, 29)]
#     target_layers_list.append([30, 31])
#     target_layers_list.append([31])
    
#     print(f'len of ref_layers: {len(ref_layers)}, len of target_layers_list: {len(target_layers_list)}')
#     print(f'ref_layers: {ref_layers}, target_layers_list: {target_layers_list}')
        
        
#     ppl_file = os.path.join(args.pic_path, f'ppl_{dataset_name}.pkl')
    
#     if os.path.exists(ppl_file):
#         with open(ppl_file, 'rb') as f:
#             ref_target_ppl_dict = pickle.load(f)
        
#         print(f'load ref_target_ppl_dict from {ppl_file}')
#     else:
#         # first import the dataset
#         from dataset_loader import DatasetManager
        
#         dm = DatasetManager()
        
#         train_texts, val_texts, test_texts = dm.get_dataset_texts(dataset_name, test_type='default')
        
#         # downsample the test_texts to 0.1
#         random_idx = torch.randperm(len(test_texts))[:int(len(test_texts) * downsample_ratio)]
#         test_texts = [test_texts[i] for i in random_idx]
#         print(f'downsampled test_texts to {len(test_texts)} for dataset {dataset_name} with ratio {downsample_ratio}')
        
#         test_encodings = tokenizer("\n".join(test_texts), return_tensors='pt')
        
        
        
#         ref_target_ppl_dict = {}
        
#         # calculate the baseline ppl
#         # model = LlamaForCausalLM.from_pretrained(model_dir).to(device)
#         # ppl = calculate_ppl(model, test_encodings, stride=512, device=device)
#         # print(f'baseline ppl: {ppl}') # 2.95
        
#         # # delete the model and clean the memory
#         # del model
#         # torch.cuda.empty_cache()
        
        
#         for i in range(len(ref_layers)):
#             ref = ref_layers[i]
            
#             for j in range(len(target_layers_list[i])):
#                 target = target_layers_list[i][j]
                
#                 # replace the target layer with the ref layer
#                 model = LlamaForCausalLM.from_pretrained(model_dir).to(device)
#                 model = modify_mlp_general(model, [ref], [[target]])
                
#                 ppl = calculate_ppl(model, test_encodings, stride=512, device=device)
#                 print(f'ref: {ref}, target: {target}, ppl: {ppl}')
                
#                 ref_target_ppl_dict[(ref, target)] = ppl
                
#                 # delete the model and clean the memory
#                 del model
#                 torch.cuda.empty_cache()
    
        
#         with open(ppl_file, 'wb') as f:
#             pickle.dump(ref_target_ppl_dict, f)
    
    
#     # print the dict
#     print(f'ref_target_ppl_dict: {ref_target_ppl_dict}')
    
#     # iterate the dict and bound the max ppl as 10
#     for key in ref_target_ppl_dict.keys():
#         if ref_target_ppl_dict[key] > 10:
#             ref_target_ppl_dict[key] = 10
    
#     # plot the target-1, target-2 and target-3 perplexity. where target-n is the average of the n target layers
    
#     target_ppl_dict = [None] * 3

#     target_ppl_dict[0] = {ref: ref_target_ppl_dict[(ref, ref+1)] for ref in ref_layers}
#     print(f'target_ppl_dict[0]: {target_ppl_dict[0]}')
#     target_ppl_dict[1] = {ref: (ref_target_ppl_dict[(ref, ref+1)] + ref_target_ppl_dict[(ref, ref+2)]) / 2 for ref in ref_layers[:-1]}
#     target_ppl_dict[1][ref_layers[-1]] = target_ppl_dict[0][ref_layers[-1]]
#     print(f'target_ppl_dict[1]: {target_ppl_dict[1]}')
    
#     target_ppl_dict[2] = {ref: (ref_target_ppl_dict[(ref, ref+1)] + ref_target_ppl_dict[(ref, ref+2)] + ref_target_ppl_dict[(ref, ref+3)]) / 3 for ref in ref_layers[:-2]}
#     target_ppl_dict[2][ref_layers[-2]] = target_ppl_dict[1][ref_layers[-2]]
#     target_ppl_dict[2][ref_layers[-1]] = target_ppl_dict[0][ref_layers[-1]]
#     print(f'target_ppl_dict[2]: {target_ppl_dict[2]}')
    
#     # plot the target-1, target-2 and target-3 perplexity in the same figure
#     plt.figure(figsize=(10, 6))
#     plt.plot(list(target_ppl_dict[0].keys()), list(target_ppl_dict[0].values()), label='ref+1')
#     plt.plot(list(target_ppl_dict[1].keys()), list(target_ppl_dict[1].values()), label='ref+1 & ref+2')
#     plt.plot(list(target_ppl_dict[2].keys()), list(target_ppl_dict[2].values()), label='ref+1 & ref+2 & ref+3')
    
#     plt.xlabel('ref layer')
#     plt.ylabel('ppl')    
#     plt.legend()
#     plt.title(f'ppl for different ref layers')
    
#     plt.savefig(os.path.join(args.pic_path, f'ppl_{dataset_name}.png'))
    
    
#     print(f'save the ppl plot to {os.path.join(args.pic_path, f"ppl_{dataset_name}.png")}')


@torch.no_grad()
def test_replace_similarity(dataset_names='arxiv-math', device = 'cuda:0', downsample_ratio=0.1):
    """Measure and plot perplexity when each MLP is replaced by its predecessor's.

    For every reference layer ``r`` in ``0..30`` the MLP weights of layer
    ``r + 1`` are replaced by those of layer ``r`` (``modify_mlp_general``)
    and perplexity is evaluated with ``calculate_ppl`` on a random subsample
    of the dataset's test texts. Results are cached per dataset in
    ``<args.pic_path>/ppl_<dataset>.pkl``; an existing cache file is loaded
    instead of recomputing. All datasets are drawn into one figure with a
    broken y-axis, saved as ``ppl_<dataset_names>.png`` and ``.pdf``.

    Args:
        dataset_names: one or more dataset names joined by ``'::'``.
        device: device on which the models are evaluated.
        downsample_ratio: fraction of the test texts kept for evaluation.
    """
    print(f'dataset_names: {dataset_names}, device: {device}, downsample_ratio: {downsample_ratio}')
    dataset_name_list = dataset_names.split('::')
    ref_layers = list(range(31))
    target_layers_list = [[ref + 1] for ref in ref_layers]

    print(f'len of ref_layers: {len(ref_layers)}, len of target_layers_list: {len(target_layers_list)}')
    print(f'ref_layers: {ref_layers}, target_layers_list: {target_layers_list}')

    dataset_ppl_dict = {}
    # Hard-coded baseline perplexities; an entry is overwritten whenever the
    # baseline is recomputed below (i.e. when no cache file exists).
    baseline_ppl_dict = {
        "arxiv-math": 2.95,
        "alpaca-gpt4": 2.5,
        "databricks-dolly-15k": 4.31,
        "gsm8k": 2.38,
        "dialogsum": 3.61
    }

    # display names used in the legend
    name_dict = {
        "arxiv-math": "Arxiv-math",
        "alpaca-gpt4": "GPT4-Alpaca",
        "databricks-dolly-15k": "Dolly",
        "gsm8k": "GSM8k",
        "dialogsum": "DialogSum"
    }

    for dataset_name in dataset_name_list:
        ppl_file = os.path.join(args.pic_path, f'ppl_{dataset_name}.pkl')

        if os.path.exists(ppl_file):
            # NOTE: pickle.load executes arbitrary code from the file; this is
            # only acceptable because the cache is written by this function.
            with open(ppl_file, 'rb') as f:
                ref_target_ppl_dict = pickle.load(f)
            print(f'load ref_target_ppl_dict from {ppl_file}')
        else:
            # load the dataset and keep a random subsample of the test texts
            dm = DatasetManager()
            _, _, test_texts = dm.get_dataset_texts(dataset_name, test_type='default')

            random_idx = torch.randperm(len(test_texts))[:int(len(test_texts) * downsample_ratio)]
            test_texts = [test_texts[i] for i in random_idx]
            print(f'downsampled test_texts to {len(test_texts)} for dataset {dataset_name} with ratio {downsample_ratio}')

            test_encodings = tokenizer("\n".join(test_texts), return_tensors='pt')

            # calculate the baseline ppl of the unmodified model
            model = LlamaForCausalLM.from_pretrained(model_dir).to(device)
            baseline_ppl = calculate_ppl(model, test_encodings, stride=512, device=device)
            print(f'baseline ppl: {baseline_ppl}') # 2.95
            baseline_ppl_dict[dataset_name] = baseline_ppl

            # delete the model and clean the memory
            del model
            torch.cuda.empty_cache()

            print(f'begin calculating for {ppl_file}')
            ref_target_ppl_dict = {}

            for ref, targets in zip(ref_layers, target_layers_list):
                for target in targets:
                    # A fresh model is loaded for every pair because
                    # modify_mlp_general rebinds the weights in place.
                    model = LlamaForCausalLM.from_pretrained(model_dir).to(device)
                    model = modify_mlp_general(model, [ref], [[target]])

                    ppl = calculate_ppl(model, test_encodings, stride=512, device=device)
                    print(f'ref: {ref}, target: {target}, ppl: {ppl}')

                    ref_target_ppl_dict[(ref, target)] = ppl

                    # delete the model and clean the memory
                    del model
                    torch.cuda.empty_cache()

            with open(ppl_file, 'wb') as f:
                pickle.dump(ref_target_ppl_dict, f)

        # print the dict
        print(f'ref_target_ppl_dict: {ref_target_ppl_dict}')

        dataset_ppl_dict[dataset_name] = ref_target_ppl_dict

    # Broken y-axis: a small top axes for the very large values and a tall
    # bottom axes for the normal range, sharing the x-axis.
    fig = plt.figure(figsize=(7, 6))
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 9], hspace=0.05)

    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1], sharex=ax1)

    color_map = plt.get_cmap('tab10')

    ax2.set_ylim(0, 6)
    ax1.set_ylim(10000, 20000)

    for i, dataset_name in enumerate(dataset_name_list):
        ref_target_ppl_dict = dataset_ppl_dict[dataset_name]
        # fall back to the raw dataset name when no display name is registered
        label = name_dict.get(dataset_name, dataset_name)
        x_list = [ref + 1 for ref in ref_layers]
        y_list = [ref_target_ppl_dict[(ref, ref + 1)] for ref in ref_layers]

        # solid line: perplexity after replacement (drawn on both axes)
        ax1.plot(x_list, y_list, label=f'{label}', color=color_map(i))
        ax2.plot(x_list, y_list, label=f'{label}', color=color_map(i))

        # dashed line: baseline perplexity, skipped when none is known
        # (e.g. results loaded from cache for a dataset without an entry)
        baseline_ppl = baseline_ppl_dict.get(dataset_name)
        if baseline_ppl is not None:
            ax2.axhline(y=baseline_ppl, color=color_map(i), linestyle='--')

        print(f"draw lines for {dataset_name}")

        if dataset_name == 'arxiv-math':
            # hard-coded perplexities of the fine-tuned (LoRA) recovery runs
            lora_ref_list = [2,4,6,8,10,12,14,16,18,20, 22, 24, 26, 28]
            lora_ref_list = [l+1 for l in lora_ref_list]
            lora_arxiv_math_result_list = [2.97, 2.96, 2.93, 2.91, 2.90, 2.89, 2.97, 2.99, 2.91 , 2.91, 2.91, 2.95, 3.0, 2.93]

            ax2.plot(lora_ref_list, lora_arxiv_math_result_list, label='Arxiv-math (Finetuned)',
                        color='blue', linestyle='-', marker='o', markersize=5)

    # slanted marks indicating the axis break; drawn once (they were
    # previously redrawn for every dataset)
    d = .2
    kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
          linestyle="none", color='k', mec='k', mew=1, clip_on=False)
    ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
    ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

    # hide the xticks for ax1
    plt.setp(ax1.get_xticklabels(), visible=False)

    # hide the spines between ax1 and ax2
    ax1.spines.bottom.set_visible(False)
    ax1.spines.top.set_visible(False)
    ax2.spines.top.set_visible(False)
    ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    ax1.tick_params(labeltop=False)  # don't put tick labels at the top
    ax2.tick_params(labeltop=False)

    # hide the yticks for ax1 and label it with a single value instead
    ax1.yaxis.set_visible(False)
    ax1.text(-0.01, 0.5, '1.5e4', transform=ax1.transAxes, va='center', ha='right', fontsize=18)

    # set legend
    ax2.legend(fontsize=14)

    ax2.set_ylabel('Perplexity', fontsize=20)
    ax2.set_xlabel('Reference layer Index', fontsize=20)
    ax2.tick_params(axis='both', which='major', labelsize=18)

    # make the plot tight
    plt.tight_layout()

    # remove unnecessary space on the top and right
    plt.subplots_adjust(top=0.95, right=0.95)

    final_path = os.path.join(args.pic_path, f'ppl_{dataset_names}.png')
    plt.savefig(final_path)
    print(f'save the ppl plot to {final_path}')

    # save as pdf
    final_path = os.path.join(args.pic_path, f'ppl_{dataset_names}.pdf')
    plt.savefig(final_path)
    print(f'save the ppl plot to {final_path}')

    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
    
    
    
    
    
        


@torch.no_grad()
def vis_L2_diff():
    """Plot how much the MLP weights change between consecutive layers.

    For each pair of adjacent layers (i, i+1) among the first 32 layers, and
    for each of the down/up/gate projections, computes
    ``torch.norm(W_i - W_{i+1}, p=2)`` and ``torch.norm(W_i, p=2)``. Prints
    the per-projection averages and relative errors, then saves a bar chart of
    ``mean(diff) / mean(norm)`` per projection to
    ``<args.pic_path>/average_L2_diff_and_norms.{png,pdf}``.

    Note: this function modifies ``plt.rcParams`` globally.
    """
    model = LlamaForCausalLM.from_pretrained(model_dir)
    N = 32

    proj_keys = ('down', 'up', 'gate')
    diffs = {key: [] for key in proj_keys}
    norms = {key: [] for key in proj_keys}

    for layer1 in range(N - 1):
        layer2 = layer1 + 1
        mlp_1 = model.model.layers[layer1].mlp
        mlp_2 = model.model.layers[layer2].mlp

        for key in proj_keys:
            weight_1 = getattr(mlp_1, f'{key}_proj').weight
            weight_2 = getattr(mlp_2, f'{key}_proj').weight
            diffs[key].append(torch.norm(weight_1 - weight_2, p=2))
            norms[key].append(torch.norm(weight_1, p=2))

    # raw per-layer values for the down projection
    down_list = diffs['down']
    down_norms = norms['down']
    print(f'down_list : {down_list}')
    print(f'down_norms: {down_norms}')

    # per-projection summary: (mean diff, mean norm, mean of per-layer ratios)
    summary = {}
    for key in proj_keys:
        diff_stack = torch.stack(diffs[key])
        norm_stack = torch.stack(norms[key])
        summary[key] = (
            torch.mean(diff_stack).item(),
            torch.mean(norm_stack).item(),
            torch.mean(diff_stack / norm_stack).item(),
        )

    for key in proj_keys:
        diff_avg, norm_avg, relative_error = summary[key]
        print(f'{key}_L2_diff_avg: {diff_avg}, {key}_norm_avg: {norm_avg}, {key}_relative_error: {relative_error}')

    # plotted quantity: ratio of the averages (down, up, gate)
    averages = [summary[key][0] / summary[key][1] for key in proj_keys]

    print(f'averages: {averages}')

    labels = ["Down", "Up", "Gate"]

    colors = [
        '#3ab9dd',  # Down - light blue
        '#4ade40',  # Up - light green
        '#d74848',  # Gate - light pink
    ]

    # fonts and larger text size for the figure
    plt.rcParams.update({
        'mathtext.fontset': 'cm',
        'font.family': 'serif',
        'font.size': 18,
    })

    # draw the bar chart
    plt.figure(figsize=(4, 5))
    ax = plt.gca()
    ax.bar(labels, averages, color=colors)

    # set the y-axis range
    ax.set_ylim(1.0, 1.6)

    plt.rcParams.update({'font.size': 12})

    # y-axis shown as percentages
    import matplotlib.ticker as mtick
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))
    ax.tick_params(axis='y', labelsize=14)
    # y label
    ax.set_ylabel('Average Relative Error Across Layers', fontsize=16)
    # remove the top and right border lines
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    plt.tight_layout()

    # save the chart
    save_path = os.path.join(args.pic_path, 'average_L2_diff_and_norms.png')
    plt.savefig(save_path, dpi=300)

    save_path = os.path.join(args.pic_path, 'average_L2_diff_and_norms.pdf')
    plt.savefig(save_path)

    plt.close()
    print(f"save the plot to {save_path}")
            
        
        



@torch.no_grad()
def test_replace_similarity2(dataset_name='arxiv-math', _device = 'cuda:0', downsample_ratio=0.3):
    """
    Sweep every (reference layer, target layer) pair, record the perplexity after the
    MLP replacement, and draw the results as a heat map.

    For each reference layer ``ref`` in [0, 30] and each later layer ``target`` in
    [ref + 1, 31], a fresh model is loaded from ``model_dir``, passed through
    ``modify_mlp_general(model, [ref], [[target]])`` (presumably this makes the target
    layer reuse the reference layer's MLP -- confirm against that helper), and its
    perplexity on a random subset of the test split is stored under the key
    ``(ref, target)``.

    The resulting dict is cached at ``<args.pic_path>/full_ppl_<dataset_name>.pkl``;
    if that file already exists, the sweep is skipped and the cached dict is plotted.

    Parameters:
    dataset_name: Dataset key handed to ``DatasetManager.get_dataset_texts``. Hand-tuned
        color boundaries exist for 'arxiv-math' and 'wikitext'; any other name is
        plotted with a plain linear color scale.
    _device: Torch device string the model is moved to for evaluation.
    downsample_ratio: Fraction of the test texts (chosen at random) that is evaluated.

    Returns:
    None. The heat map is written as PNG and PDF into ``args.pic_path``.
    """
    ref_layers = list(range(0, 31))
    target_layers_list = [list(range(ref + 1, 32)) for ref in ref_layers]

    file_path = os.path.join(args.pic_path, f'full_ppl_{dataset_name}.pkl')

    if os.path.exists(file_path):
        # NOTE: pickle is only safe here because the cache file is written by this very function.
        with open(file_path, 'rb') as f:
            ref_target_ppl_dict = pickle.load(f)

        print(f'load ref_target_ppl_dict from {file_path}')

    else:
        # first import the dataset
        from dataset_loader import DatasetManager

        dm = DatasetManager()

        train_texts, val_texts, test_texts = dm.get_dataset_texts(dataset_name, test_type='default')

        # keep a random `downsample_ratio` fraction of the test texts
        random_idx = torch.randperm(len(test_texts))[:int(len(test_texts) * downsample_ratio)]
        test_texts = [test_texts[i] for i in random_idx]
        print(f'downsampled test_texts to {len(test_texts)} for dataset {dataset_name} with ratio {downsample_ratio}')

        test_encodings = tokenizer("\n".join(test_texts), return_tensors='pt')

        print(f'len of ref_layers: {len(ref_layers)}, len of target_layers_list: {len(target_layers_list)}')
        print(f'ref_layers: {ref_layers}, target_layers_list: {target_layers_list}')

        ref_target_ppl_dict = {}

        # calculate the baseline ppl of the unmodified model
        model = LlamaForCausalLM.from_pretrained(model_dir).to(_device)
        ppl = calculate_ppl(model, test_encodings, stride=512, device=_device)
        print(f'baseline ppl: {ppl}') # 2.95 for arxiv-math, and 7.68 for wikitext

        # delete the model and clean the memory
        del model
        torch.cuda.empty_cache()

        for ref, targets in zip(ref_layers, target_layers_list):
            for target in targets:
                # a fresh copy is loaded for every pair so earlier replacements cannot leak in
                model = LlamaForCausalLM.from_pretrained(model_dir).to(_device)
                model = modify_mlp_general(model, [ref], [[target]])

                ppl = calculate_ppl(model, test_encodings, stride=512, device=_device)
                print(f'ref: {ref}, target: {target}, ppl: {ppl}')

                ref_target_ppl_dict[(ref, target)] = ppl

                # delete the model and clean the memory
                del model
                torch.cuda.empty_cache()

        with open(file_path, 'wb') as f:
            pickle.dump(ref_target_ppl_dict, f)

    # print the dict
    print(f'ref_target_ppl_dict: {ref_target_ppl_dict}')

    # Lay the dict out as a 31 x 31 matrix: row = ref, column = target - 1
    # (targets start at 1, so the shift keeps the matrix square).
    ppl_matrix = torch.zeros(31, 31)
    for ref in ref_layers:
        for target in target_layers_list[ref]:
            ppl_matrix[ref, target - 1] = ref_target_ppl_dict[(ref, target)]

    # Hand-tuned color boundaries: most values sit in a narrow band while a few
    # outliers are huge (the last bin reaches 20000), so a uniform scale would hide
    # the structure. Unknown datasets get `None` and fall back to a linear scale
    # instead of failing with a NameError after the whole sweep.
    color_bounds = {
        'arxiv-math': [0, 3.0, 3.1, 3.2, 3.3, 3.5, 3.7, 4.0, 4.5, 6.5, 20000],
        'wikitext': [0, 7.9, 8.0, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.8, 9.0, 10.0, 15.0, 20000],
    }
    bounds = color_bounds.get(dataset_name)
    cmap = plt.get_cmap('Greens')
    norm = mcolors.BoundaryNorm(bounds, cmap.N) if bounds is not None else None

    # larger text size
    plt.rcParams.update({'font.size': 14})

    plt.figure(figsize=(7, 6))
    # transpose so that the x axis is the reference layer and the y axis the target layer
    img = plt.imshow(ppl_matrix.cpu().numpy().T, cmap=cmap, interpolation='nearest', norm=norm)
    cbar = plt.colorbar(img)

    if bounds is not None:
        cbar.set_ticks(bounds)
        cbar.set_ticklabels([f'{bound:.1f}' for bound in bounds])

    # add axis numbers (1-based layer ids), showing every second one
    plt.xticks(ticks=range(0, 31, 2), labels=range(1, 32, 2))
    plt.yticks(ticks=range(0, 31, 2), labels=range(2, 33, 2))

    plt.xlabel('Reference layer', fontsize=20)
    plt.ylabel('Target layer', fontsize=20)

    # let the picture be tight
    plt.tight_layout()

    # trim the remaining margin at the bottom
    plt.subplots_adjust(bottom=0.02)

    path = os.path.join(args.pic_path, f'full_ppl_{dataset_name}.png')
    plt.savefig(path)

    path = os.path.join(args.pic_path, f'full_ppl_{dataset_name}.pdf')
    plt.savefig(path)

    # release the figure so repeated calls do not accumulate open figures
    plt.close()
    print(f'save the full ppl plot to {path}')


def nonlinear_mapping_train(X, Y, hidden_dim=400, num_epochs=1000, learning_rate=0.001, device='cpu'):
    """
    Train a small bottleneck network F so that ``F(X) * X`` (element-wise product)
    approximates ``Y``.

    F is two stacked linear layers (n -> hidden_dim -> n). The activations between
    them are currently disabled, so F itself is an affine map; the only
    non-linearity of the fitted mapping is the element-wise multiplication of
    F(X) with X inside the loss.

    Parameters:
    X: Input matrix of shape (m, n), where m is the number of data points and n is the dimension of each point.
    Y: Target matrix of shape (m, n).
    hidden_dim: Width of the bottleneck layer (default: 400).
    num_epochs: Number of full passes over the data, one optimizer step each (default: 1000).
    learning_rate: Initial learning rate for Adam (default: 0.001); it is multiplied by 0.99 every 100 epochs.
    device: Torch device on which the data and the model are placed (default: 'cpu').

    Returns:
    model: The trained network F (an ``nn.Module`` living on ``device``).
    """

    # Bottleneck network used as the mapping F.
    class NonlinearMapping(nn.Module):
        def __init__(self, input_dim, output_dim, hidden_dim):
            super().__init__()
            # No activation between the two layers: the ReLU variant is disabled on purpose.
            self.model = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),   # down-projection to the bottleneck
                nn.Linear(hidden_dim, output_dim),  # up-projection back to the input size
            )

            # Xavier-normal weights and zero biases for every linear layer
            for layer in self.model:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight)
                    nn.init.zeros_(layer.bias)

        def forward(self, x):
            return self.model(x)

    # Check dimensions of input and output matrices
    m, n = X.shape
    assert Y.shape == (m, n), "X and Y must have the same shape."
    print(f'X shape: {X.shape}, Y shape: {Y.shape}') # for gate: X shape: torch.Size([11008, 4096]), Y shape: torch.Size([11008, 4096]

    # move the data to the device
    X = X.to(device)
    Y = Y.to(device)

    # Initialize the model, loss function, and optimizer
    model = NonlinearMapping(input_dim=n, output_dim=n, hidden_dim=hidden_dim).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

    # L2 norms of X and Y, reported next to the final error for scale
    norm_X = torch.norm(X)
    norm_Y = torch.norm(Y)

    batch_size = min(44032, X.shape[0])
    # Training loop: one optimizer step per epoch over the whole data set.
    for epoch in range(num_epochs):
        optimizer.zero_grad()  # Clear gradients
        epoch_loss = 0.0

        # Back-propagate batch by batch. The accumulated gradient equals the gradient
        # of the summed loss, but each batch's graph is freed right after its
        # backward pass instead of being kept alive until the end of the epoch.
        for start in range(0, m, batch_size):
            X_batch = X[start:start + batch_size]
            Y_batch = Y[start:start + batch_size]
            outputs = model(X_batch)
            # MSELoss averages, so scale by the batch size to weight batches by their row count
            batch_loss = criterion(outputs * X_batch, Y_batch) * X_batch.shape[0]
            batch_loss.backward()
            epoch_loss += batch_loss.item()

        # Report the loss of the parameters used in this epoch (before the update)
        if (epoch + 1) % 50 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, lr: {scheduler.get_last_lr()[0]}')

        optimizer.step()       # Update model parameters
        scheduler.step()

    # Final reconstruction error; no graph is needed for a pure evaluation.
    with torch.no_grad():
        outputs = model(X)
        final_error = torch.norm(outputs * X - Y).item()
    print(f'Final loss: {final_error}')
    print(f'Norm of X: {norm_X.item():.4f}, Norm of Y: {norm_Y.item():.4f}\n')

    # release the memory
    del X, Y, outputs
    torch.cuda.empty_cache()

    # Return the trained model
    return model


def nonlinear_compress(rank=400, device='cpu'):
    """
    Fit a bottleneck mapping from the gate projection of layer 24 to that of layer 25.

    The model is loaded from ``model_dir``, the two ``gate_proj`` weight matrices are
    optionally re-blocked (see ``k`` below) and handed to ``nonlinear_mapping_train``.
    Only the gate projection is fitted; the up and down projections are not used.

    Parameters:
    rank: Bottleneck width passed as ``hidden_dim`` (divided by ``k``).
    device: Torch device used for the fitting.

    Returns:
    The trained mapping returned by ``nonlinear_mapping_train``.
    """
    model = LlamaForCausalLM.from_pretrained(model_dir)
    model.eval()

    layer1 = 24
    layer2 = 25

    # get the gate weight matrices of the two layers
    weight_1_gate = model.model.layers[layer1].mlp.gate_proj.weight.data
    weight_2_gate = model.model.layers[layer2].mlp.gate_proj.weight.data

    # Re-block an (M, N) weight into (M / k, N * k): split the rows into k chunks
    # and place the chunks side by side. With k = 1 the matrices keep their shape.
    k = 1
    gap = weight_1_gate.shape[0] // k
    rank = rank // k
    w1g_splits = torch.split(weight_1_gate, gap, dim=0)
    w2g_splits = torch.split(weight_2_gate, gap, dim=0)
    print(f'len of w1g_splits: {len(w1g_splits)}, len of w2g_splits: {len(w2g_splits)}, shape of w1g_splits[0]: {w1g_splits[0].shape}')
    weight_1_gate = torch.cat(w1g_splits, dim=1)
    weight_2_gate = torch.cat(w2g_splits, dim=1)

    # fit layer-1 gate -> layer-2 gate and hand the trained mapping back to the caller
    compress_gate = nonlinear_mapping_train(weight_1_gate, weight_2_gate, hidden_dim=rank, device=device)
    return compress_gate
    
  

def dequantize_weight(weight):
    """
    Return ``weight`` as a float tensor when it is a bitsandbytes ``Int8Params``;
    any other input is returned unchanged.

    NOTE(review): ``.float()` is a dtype conversion; presumably no quantization
    scale is applied here -- confirm that this is the intended "dequantization".
    """
    return weight.float() if isinstance(weight, bnb.nn.Int8Params) else weight
    

def try_int8(device='cuda:0', rank=100, if_try=1):
    """
    Compute the left-multiplication factors between the MLP weights of layers 2 and 3,
    optionally on an 8-bit loaded model, and plot their log-density distributions.

    Parameters:
    device: Torch device forwarded to ``process_left_mul``.
    rank: Rank forwarded to ``process_left_mul``.
    if_try: 0 loads the model with default settings; any other value loads it with
        ``load_in_8bit=True``.

    Returns:
    (C_down, C_up, C_gate): the three results of ``process_left_mul``.
    """
    if if_try == 0:
        model = LlamaForCausalLM.from_pretrained(model_dir)
    else:
        quantized_load_options = dict(
            load_in_8bit=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        model = LlamaForCausalLM.from_pretrained(model_dir, **quantized_load_options)

    model.eval()

    layer1, layer2 = 2, 3
    first_mlp = model.model.layers[layer1].mlp
    second_mlp = model.model.layers[layer2].mlp

    # raw (possibly quantized) projection weights of both layers
    weight_1_down, weight_1_up, weight_1_gate = (
        first_mlp.down_proj.weight,
        first_mlp.up_proj.weight,
        first_mlp.gate_proj.weight,
    )
    weight_2_down, weight_2_up, weight_2_gate = (
        second_mlp.down_proj.weight,
        second_mlp.up_proj.weight,
        second_mlp.gate_proj.weight,
    )

    print(f'Shape of weight_1_down: {weight_1_down.shape}')
    print(f'Shape of weight_1_up: {weight_1_up.shape}')
    print(f'Shape of weight_1_gate: {weight_1_gate.shape}')
    # NOTE(review): `[:-5, :-5]` is everything except the last 5 rows/columns;
    # possibly `[-5:, -5:]` (the trailing corner) was intended -- confirm.
    print(f'weight_1_down[:5, :5]:\n{weight_1_down[:5, :5]} [:-5, :-5]:\n{weight_1_down[:-5, :-5]}')

    # float versions of all six weights
    dequant_weight_1_down, dequant_weight_1_up, dequant_weight_1_gate = (
        dequantize_weight(w) for w in (weight_1_down, weight_1_up, weight_1_gate)
    )
    dequant_weight_2_down, dequant_weight_2_up, dequant_weight_2_gate = (
        dequantize_weight(w) for w in (weight_2_down, weight_2_up, weight_2_gate)
    )

    # print shapes and a few values
    print(f'Shape of dequant_weight_1_down: {dequant_weight_1_down.shape}')
    print(f'Shape of dequant_weight_1_up: {dequant_weight_1_up.shape}')
    print(f'Shape of dequant_weight_1_gate: {dequant_weight_1_gate.shape}')

    print(f'Dequant_weight_1_down[:5, :5]:\n{dequant_weight_1_down[:5, :5]}, [:-5, :-5]:\n{dequant_weight_1_down[:-5, :-5]}')

    down_name = f'int8_{layer1}_{layer2}_down'
    up_name = f'int8_{layer1}_{layer2}_up'
    gate_name = f'int8_{layer1}_{layer2}_gate'

    # left-multiplication factors; the down projection is passed transposed
    C_down = process_left_mul(dequant_weight_1_down.T, dequant_weight_2_down.T, rank=rank, device=device, name=down_name, store_dir=args.store_path)
    C_up = process_left_mul(dequant_weight_1_up, dequant_weight_2_up, rank=rank, device=device, name=up_name, store_dir=args.store_path)
    C_gate = process_left_mul(dequant_weight_1_gate, dequant_weight_2_gate, rank=rank, device=device, name=gate_name, store_dir=args.store_path)

    # visualize log distribution of C_down, C_up, C_gate (in that order)
    for factor, plot_name in ((C_down, down_name), (C_up, up_name), (C_gate, gate_name)):
        vis_log_density_distribution(factor, name=plot_name)

    return C_down, C_up, C_gate


def _report_int8_memory_usage(model):
    """Print each parameter's dtype, then the memory footprint of the whole model and of its MLP projections."""
    for name, param in model.named_parameters():
        print(f'{name}: {param.dtype}')

    total_params = 0
    total_bytes = 0
    for param in model.parameters():
        params = param.numel()
        total_params += params
        total_bytes += params * param.element_size()

    print(f'Total parameters: {total_params}')
    print(f'Total memory usage: {total_bytes / (1024 ** 2):.2f} MB')
    print(f'Total memory usage: {total_bytes / (1024 ** 3):.2f} GB')

    print(f'GPU Memory Allocated: {torch.cuda.memory_allocated() / (1024 ** 3):.2f} GB')

    # Recorded output of an earlier run:
    # Total parameters: 6738415616
    # Total memory usage: 6676.51 MB
    # Total memory usage: 6.52 GB
    # GPU Memory Allocated: 6.65 GB

    # Second pass: a parameter counts as MLP if its name contains one of these substrings.
    mlp_layer_names = ('gate_proj', 'up_proj', 'down_proj')
    total_mlp_params = 0
    total_mlp_bytes = 0
    for name, param in model.named_parameters():
        params = param.numel()
        param_bytes = params * param.element_size()
        if any(layer_name in name for layer_name in mlp_layer_names):
            total_mlp_params += params
            total_mlp_bytes += param_bytes
            print(f'{name}: {params} parameters, {param_bytes / (1024 ** 2):.2f} MB')
        else:
            print(f'{name}: Not an MLP parameter, bytes: {param_bytes/ (1024 ** 2):.2f} MB')

    print(f'Total MLP parameters: {total_mlp_params}')
    print(f'Total MLP memory usage: {total_mlp_bytes / (1024 ** 2):.2f} MB')
    print(f'Total MLP memory usage: {total_mlp_bytes / (1024 ** 3):.2f} GB')

    # Recorded output of an earlier run:
    # Total MLP parameters: 4328521728
    # Total MLP memory usage: 4128.00 MB
    # Total MLP memory usage: 4.03 GB


def try_sparse_transform(device='cuda:0', rank=100, if_try=1):
    """
    Test how sparsely a row of layer 3's up projection can be expressed through the
    rows of layer 2's up projection.

    A random row b of layer 3's ``up_proj`` weight is drawn, ``lasso_regression`` is
    asked for x with ``A^T x ~ b`` (A = layer 2's ``up_proj`` weight), and the
    sparsity and reconstruction error of x are printed, both as returned and after
    quantizing x to the three values {-scale, 0, +scale}.

    Parameters:
    device: Torch device forwarded to ``lasso_regression`` and used for the error computations.
    rank: Unused; kept so the signature matches the other experiment entry points.
    if_try: 0 loads the model with default settings; 2 loads it in 8-bit, prints a
        memory report and terminates the process; any other value loads it in 8-bit
        with float16 for the remaining modules.

    Returns:
    None. All results are printed.
    """
    import sys  # local import: only needed for the early exit of the memory-report mode

    if if_try == 0:
        model = LlamaForCausalLM.from_pretrained(model_dir)
    elif if_try == 2:
        # memory-report mode: measure the cost of the int8 model and stop
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            load_in_8bit=True,        # Load model in 8-bit precision
        )
        _report_int8_memory_usage(model)
        sys.exit()
    else:
        model = LlamaForCausalLM.from_pretrained(
            model_dir,
            load_in_8bit=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )

    model.eval()

    layer1, layer2 = 2, 3

    # fetch the (possibly quantized) weights
    weight_1_down = model.model.layers[layer1].mlp.down_proj.weight
    weight_1_up = model.model.layers[layer1].mlp.up_proj.weight
    weight_1_gate = model.model.layers[layer1].mlp.gate_proj.weight

    weight_2_up = model.model.layers[layer2].mlp.up_proj.weight

    print(f'Shape of weight_1_down: {weight_1_down.shape}')
    print(f'Shape of weight_1_up: {weight_1_up.shape}')
    print(f'Shape of weight_1_gate: {weight_1_gate.shape}')
    # NOTE(review): `[:-5, :-5]` is everything except the last 5 rows/columns;
    # possibly `[-5:, -5:]` (the trailing corner) was intended -- confirm.
    print(f'weight_1_down[:5, :5]:\n{weight_1_down[:5, :5]} [:-5, :-5]:\n{weight_1_down[:-5, :-5]}')

    weight_1_up = weight_1_up.to(torch.float32)
    weight_2_up = weight_2_up.to(torch.float32)
    # randomly select k row vectors from the second layer's up projection
    k = 1
    idxs = torch.randperm(weight_2_up.shape[0])[:k]
    test_vector_list = [weight_2_up[idx].reshape(-1,1) for idx in idxs]
    print(f'len of test_vector_list: {len(test_vector_list)}, shape of test_vector_list[0]: {test_vector_list[0].shape}')

    # Loop-invariant quantities, computed once. The on-device copy gets its own name
    # so `weight_1_up` stays on its original device: rebinding it inside the loop
    # would make the second iteration (k > 1) mix devices with the next test vector.
    normalized_weight_1_up = weight_1_up / torch.norm(weight_1_up, dim=1).reshape(-1,1)
    weight_1_up_dev = weight_1_up.to(device)

    # run lasso regression to solve Ax = b, where A is weight_1_up^T and b is a vector in test_vector_list,
    # to see how sparse the solution x can be, and -- after quantizing x into -scale, 0, scale --
    # how many zeros are obtained and what the error is
    for i, test_vector in enumerate(test_vector_list):
        #####################
        # baseline: L2 distance between every row of weight_1_up and the test vector
        L2error = torch.norm(weight_1_up - test_vector.reshape(1,-1), dim=1)
        print(f'minimum L2 error: {torch.min(L2error).item()}, mean L2 error: {torch.mean(L2error).item()}, norm of test_vector: {torch.norm(test_vector).item()}')
        # recorded run: minimum L2 error: 1.4224458932876587, mean L2 error: 1.6024643182754517, norm of test_vector: 1.1489225625991821
        # -> no single row is close to the test vector

        # same comparison on unit-normalized rows
        normalized_test_vector = test_vector / torch.norm(test_vector)
        L2 = torch.norm(normalized_weight_1_up - normalized_test_vector.reshape(1,-1), dim=1)
        print(f'minimum L2 error: {torch.min(L2).item()}, mean L2 error: {torch.mean(L2).item()}, norm of normalized_test_vector: {torch.norm(normalized_test_vector).item()}')
        # recorded run: minimum L2 error: 1.3687177896499634, mean L2 error: 1.4140182733535767, norm of normalized_test_vector: 1.0000003576278687
        # -> distances are close to sqrt(2), i.e. the rows are almost unrelated

        #####################
        print(f'\ntest_vector {i} ==================')
        x = lasso_regression(weight_1_up.T, test_vector, device=device)
        # sparsity of x
        print(f'sparsity of x: {torch.sum(x == 0).item() / x.numel()}, l1 norm of x: {torch.norm(x, p=1).item()}, l2 norm of x: {torch.norm(x, p=2).item()}')
        # calculate the error (x is presumably returned on `device` -- confirm in solver.lasso_regression)
        test_vector_dev = test_vector.to(device)
        print(f'error: {torch.norm(weight_1_up_dev.T @ x - test_vector_dev).item()}')
        print(f'x[:5]: {x[:5].reshape(-1)}, x[-5:]: {x[-5:].reshape(-1)}')
        print(f'abs mean of x: {torch.mean(torch.abs(x)).item()}, std of x: {torch.std(x).item()}, abs max of x: {torch.max(torch.abs(x)).item()}')

        # quantize x: entries within [-bound, bound] become 0, the rest +/- scale
        bound = 0.005
        scale = bound * 2
        x_quantized = torch.where(x < -bound, -scale, torch.where(x > bound, scale, 0))
        print(f'sparsity of x_quantized: {torch.sum(torch.abs(x_quantized) < 1e-3 ).item() / x_quantized.numel()}, l1 norm of x_quantized: {torch.norm(x_quantized, p=1).item()}, l2 norm of x_quantized: {torch.norm(x_quantized, p=2).item()}')
        print(f'error of x_quantized: {torch.norm(weight_1_up_dev.T @ x_quantized - test_vector_dev).item()}')
        print(f'x_quantized[:5]: {x_quantized[:5].reshape(-1)}, x_quantized[-5:]: {x_quantized[-5:].reshape(-1)}')
        

        
# def 
    
    
    


# Entry point: pick the experiment selected by --run_flag.
def _run_svd_research():
    print(f'run SVD_research()')
    SVD_research(weight_2_down, weight_2_up, weight_2_gate, weight_3_down, weight_3_up, weight_3_gate)


def _run_default():
    lora_recovery()
    llama2_tester()


# Every entry is a zero-argument callable, so the functions and the `args` fields
# are only looked up for the branch that is actually taken.
_RUN_FLAG_ACTIONS = {
    0: lambda: llama2_tester(),
    -1: _run_svd_research,
    100: lambda: test_device_map(),
    99: lambda: gate_visualizer(),
    1000: lambda: test_replace_similarity2(dataset_name=args.dataset_name, _device=f'cuda:{args.device}', downsample_ratio=args.ratio),
    1001: lambda: test_replace_similarity(dataset_names=args.dataset_name, device=f'cuda:{args.device}', downsample_ratio=args.ratio),
    1002: lambda: find_left_mul_llama2(device=f'cuda:{args.device}', rank=args.rank),
    1003: lambda: find_left_mul_llama3(device=f'cuda:{args.device}'),
    6199: lambda: nonlinear_compress(rank=args.rank, device=f'cuda:{args.device}'),
    6200: lambda: try_int8(device=f'cuda:{args.device}', rank=args.rank, if_try=args.if_try),
    6201: lambda: try_sparse_transform(device=f'cuda:{args.device}', rank=args.rank, if_try=args.if_try),
    6202: lambda: vis_L2_diff(),
}

# Any flag without an entry runs the default pipeline.
_RUN_FLAG_ACTIONS.get(args.run_flag, _run_default)()

