import argparse
import sys
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from peft import PeftConfig, get_peft_model

class Logger(object):
    """Tee-style text stream: every write goes to a console stream and,
    optionally, to a log file.

    Instances expose ``write``/``flush``/``close`` so they can be assigned to
    ``sys.stdout`` and can also be used as a context manager.

    Args:
        f: An already-open text stream (anything with ``write``, ``flush``
            and ``close``) that receives every message.
        fpath: Optional path of a log file. It is opened in ``'w'`` mode, so
            an existing file is truncated. The parent directory must already
            exist; it is not created here.
    """

    def __init__(self, f, fpath=None):
        self.console = f
        self.file = None
        if fpath is not None:
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return the logger so that ``with Logger(...) as log:`` binds a
        # usable object instead of None.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Write ``msg`` to the console and, if configured, to the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush the console; flush and fsync the log file if there is one."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            # Force the OS to push the data to disk so the log survives a crash.
            os.fsync(self.file.fileno())

    def close(self):
        """Release both streams. Safe to call more than once.

        The interpreter's own ``sys.__stdout__`` / ``sys.__stderr__`` are only
        flushed, never closed: this method also runs from ``__del__``, and
        closing the real standard streams would make every later ``print``
        fail with "I/O operation on closed file". Any other console stream is
        closed as before.
        """
        console = self.console
        if console is sys.__stdout__ or console is sys.__stderr__:
            if not getattr(console, "closed", False):
                console.flush()
        else:
            console.close()
        if self.file is not None:
            self.file.close()

# ---------------------------------------------------------------------------
# Script body.
#
# Loads a base causal LM plus the adapter stored under
# ``<output_dir>/adapter_model``, then, for every decoder layer and each of the
# seven projections (q/k/v/o/gate/up/down), takes the adapter's delta weight,
# estimates its singular values with a randomized SVD, and records three
# statistics: their sum ("nuc"), their 2-norm ("fro") and the first one
# ("inf"). Results are printed (stdout is tee'd to ``log_norm.txt``) and saved
# as one CSV per projection, each of shape (num_layers, 3).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str, default="meta-llama/Llama-2-7b-hf")
# NOTE: --model_max_length is accepted but not used anywhere below.
parser.add_argument('--model_max_length', type=int, default=512)
# NOTE: the default is the literal string "None", not the None object, so
# running without --output_dir resolves every path relative to a "None/" dir.
parser.add_argument('--output_dir', type=str, default="None")
parser.add_argument('--rank', type=int, default=128)
args = parser.parse_args()

model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    quantization_config=None,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto"
)

# From here on everything printed is mirrored into the log file.
sys.stdout = Logger(sys.stdout, f"{args.output_dir}/log_norm.txt")

adapter_dir = f"{args.output_dir}/adapter_model"
config = PeftConfig.from_pretrained(adapter_dir)
peft_model = get_peft_model(model, config)
peft_model.load_adapter(adapter_dir, adapter_name='default')

r = args.rank
decoder_layers = peft_model.base_model.model.model.layers
L = len(decoder_layers)

# One (L, 3) table per projection; columns are nuc, fro, inf.
q = np.zeros((L, 3))
k = np.zeros((L, 3))
v = np.zeros((L, 3))
o = np.zeros((L, 3))
g = np.zeros((L, 3))
u = np.zeros((L, 3))
d = np.zeros((L, 3))

# (tag, parent attribute on the layer, projection attribute, result table).
# The order matters: it fixes both the print order and the order in which the
# randomized SVD consumes the RNG, matching the original per-layer sequence.
_projections = (
    ("q", "self_attn", "q_proj", q),
    ("k", "self_attn", "k_proj", k),
    ("v", "self_attn", "v_proj", v),
    ("o", "self_attn", "o_proj", o),
    ("g", "mlp", "gate_proj", g),
    ("u", "mlp", "up_proj", u),
    ("d", "mlp", "down_proj", d),
)

# Only scalar statistics are extracted, so no autograd graph is needed.
with torch.no_grad():
    for index, layer in enumerate(decoder_layers):
        for tag, parent_name, proj_name, table in _projections:
            module = getattr(getattr(layer, parent_name), proj_name)
            dW = module.get_delta_weight('default')
            # The base model is loaded in float16; upcast so the decomposition
            # runs in full precision (half-precision inputs may be rejected by
            # the QR step inside svd_lowrank -- TODO confirm for the installed
            # torch/peft versions). This is a no-op if dW is already float32.
            _, s, _ = torch.svd_lowrank(dW.float(), q=2 * r, niter=r + 2)
            # "inf" is the first returned singular value, assumed to be the
            # largest one (i.e. the spectral norm).
            nuc, fro, inf = s.sum().item(), s.norm().item(), s[0].item()
            table[index, 0] = nuc
            table[index, 1] = fro
            table[index, 2] = inf
            print(f"layer {index+1:2d}, {tag} | nuc: {nuc:.2f}, fro: {fro:.2f}, inf: {inf:.2f} ")

for tag, _, _, table in _projections:
    np.savetxt(f'{args.output_dir}/{tag}.csv', table, delimiter=',')