import torch
import argparse
import os
import shutil
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--base_model_path", required=True, help="Path to the original base model (e.g., Llama-3-8b)")
parser.add_argument("--user_lora", required=True)
parser.add_argument("--safety_lora", required=True)
parser.add_argument("--output_path", required=True, help="Path to save the fully merged model")
parser.add_argument("--alpha", type=float, default=0.1, help="Scaling factor for safety LoRA projection")

args = parser.parse_args()

print(f"Loading base model from {args.base_model_path}...")
# CPU로 로드하여 메모리 절약 (필요시 'cuda'로 변경)
base_model = AutoModelForCausalLM.from_pretrained(
    args.base_model_path, 
    torch_dtype=torch.float16,
    device_map="cpu",
    cache_dir="cache",
)
tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)

print("Loading LoRA weights...")
W_u = torch.load(os.path.join(args.user_lora, "adapter_model.bin"), map_location="cpu")
W_s = torch.load(os.path.join(args.safety_lora, "adapter_model.bin"), map_location="cpu")

def get_orth_projected_delta(Wa_u, Wb_u, Wa_s, Wb_s, alpha=0.1):
    # 1. 각각의 Delta 계산 (float32 권장)
    delta_u = (Wb_u.float() @ Wa_u.float())
    delta_s = (Wb_s.float() @ Wa_s.float())
    
    # 2. 직교 투영 (Vector space projection)
    flat_u = delta_u.flatten()
    flat_s = delta_s.flatten()
    
    dot_su = torch.dot(flat_s, flat_u)
    dot_uu = torch.dot(flat_u, flat_u)
    
    # 투영: Safety에서 User 방향 성분 제거
    proj_coeff = dot_su / (dot_uu + 1e-8)
    delta_s_perp = delta_s - (alpha * proj_coeff * delta_u)
    
    # 3. 합체 (User + Safety_perp)
    final_delta = delta_u + delta_s_perp
    return final_delta.to(torch.float16) # 베이스 모델 타입으로 변환

def check_rank_collapse(user_lora_path, eps=1e-2):
    print(f"Checking Rank Collapse for: {user_lora_path}")
    W_u = torch.load(os.path.join(user_lora_path, "adapter_model.bin"), map_location="cpu")
    
    total_layers = 0
    collapsed_layers = 0
    avg_effective_rank = 0.0
    
    print(f"{'Layer Name':<50} | {'Set Rank':<8} | {'Eff Rank':<8} | {'Ratio':<6}")
    print("-" * 80)

    for name in W_u.keys():
        if 'lora_A' in name:
            # Wa: [r, d_in]
            Wa = W_u[name].float()
            r = Wa.shape[0]
            
            # Singular Values 계산 (속도를 위해 svd_lowrank 대신 그냥 svd 사용해도 r이 작아서 빠름)
            # 여기서는 정밀하게 보기 위해 full svd of small matrix G 사용
            G = Wa @ Wa.T
            evals = torch.linalg.eigvalsh(G) # 고유값
            
            # 유효 랭크 기준: Max eigenvalue의 1%보다 큰 것의 개수
            threshold = eps * evals.max()
            eff_rank = (evals > threshold).sum().item()
            
            print(f"{name:<50} | {r:<8} | {eff_rank:<8} | {eff_rank/r:.2f}")
            
            total_layers += 1
            avg_effective_rank += eff_rank
            if eff_rank < r:
                collapsed_layers += 1

    print("-" * 80)
    print(f"Summary: {collapsed_layers}/{total_layers} layers have rank collapse.")
    print(f"Average Effective Rank: {avg_effective_rank / total_layers:.2f}")

# 사용법
# import pdb; pdb.set_trace()
check_rank_collapse(args.user_lora)

def get_fast_qr_projected_delta(Wa_u, Wb_u, Wa_s, Wb_s, alpha=0.1, eps=0.01):
    # 1. Delta S만 미리 계산 (Delta U는 나중에 필요할 때 계산)
    delta_s = (Wb_s.float() @ Wa_s.float())

    # # 2. User LoRA의 유효 Subspace 추출 (핵심 로직)
    # # 2-1. Wa_u의 Gram Matrix 계산 (작은 행렬 r x r)
    # # Wa_u: [r, d_in] -> G: [r, r]
    G = Wa_u.float() @ Wa_u.float().T
    
    # 2-2. 고유값 분해 (Eigen Decomposition) on small matrix
    # eigh는 대칭 행렬용이라 빠르고 정확함
    evals, evecs = torch.linalg.eigh(G)
    
    # 2-3. 유효 랭크 필터링 (Rank Collapse 대응)
    # 최대 고유값 대비 eps 비율 이상인 것만 살림
    valid_indices = evals > (eps * evals.max())
    
    if not valid_indices.any():
        # 유효한 랭크가 하나도 없으면(User가 아무것도 안 배운 상태), 투영 없이 그냥 더함
        return (Wb_u.float() @ Wa_u.float()) + delta_s


    if valid_indices.sum().item() == Wb_s.shape[1]:
        B_eff = Wb_u.float()
    else:
        # V_eff: 유효한 고유벡터들 [r, k] (k <= r)
        V_eff = evecs[:, valid_indices]

        # 3. Effective Basis 생성
        # Wb_u [d_out, r] @ V_eff [r, k] -> B_eff [d_out, k]
        # 이제 B_eff는 "User가 진짜로 사용하는 출력 공간"만 담고 있음
        B_eff = Wb_u.float() @ V_eff

    # 2. 최적화: delta_u 대신 Wb_u(작은 행렬)로 QR 수행
    # 결과인 Q는 delta_u를 QR한 것과 수학적으로 동일한 공간을 가짐
    # Wb_u: [d_out, rank] -> QR 속도 매우 빠름
    Q, _ = torch.linalg.qr(B_eff, mode='reduced')
    # Q, _ = torch.linalg.qr(Wb_u.float(), mode='reduced')

    # 3. 투영 및 제거
    delta_s_perp = delta_s - alpha *  Q @ (Q.T @ delta_s)

    # 4. 최종 합산
    delta_u = (Wb_u.float() @ Wa_u.float())
    final_delta = delta_u + delta_s_perp #* (delta_u.norm() / (delta_s_perp.norm() + 1e-8))
    
    print("UserLoRA Scale: ", delta_u.norm().item())
    print("Safety LoRA Projected Scale: ", (delta_s_perp).norm().item())
    print("Final Delta Scale: ", final_delta.norm().item()) 
    
    return final_delta.to(torch.float16)

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

print("Merging LoRAs into Base Model...")
with torch.no_grad():
    # Base Model의 state_dict를 직접 수정
    base_params = base_model.state_dict()
    
    processed_layers = set()
    
    for name, param in W_u.items():
        if 'lora_A' in name:
            layer_key = name.replace('.lora_A.weight', '.weight').replace('base_model.model.', '')
            
            # PEFT 네이밍 규칙에 따라 base model의 key 찾기
            # 예: base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight 
            # -> model.layers.0.self_attn.q_proj.weight
            
            # 실제 base_model 키 매칭 로직 (모델 구조에 따라 다를 수 있음. 일반적인 Llama/HF 구조 기준)
            base_key = name.replace('base_model.model.', '').replace('.lora_A.weight', '.weight')
            if base_key in base_params:
                lora_a_name = name
                lora_b_name = name.replace('lora_A', 'lora_B')
                
                # User LoRA
                ua = W_u[lora_a_name]
                ub = W_u[lora_b_name]
                
                # Safety LoRA가 있는 경우
                if lora_a_name in W_s and lora_b_name in W_s:
                    sa = W_s[lora_a_name]
                    sb = W_s[lora_b_name]
                    
                    # 직교 투영된 Delta W 계산
                    delta_w = get_fast_qr_projected_delta(ua, ub, sa, sb, alpha=args.alpha)
                    # delta_w = get_orth_projected_delta(ua, ub, sa, sb, alpha=args.alpha)
                else:
                    # User LoRA만 있는 경우
                    delta_w = (ub.float() @ ua.float()).to(torch.float16)
                
                # Base Model 가중치에 더하기 (In-place)
                base_params[base_key] += delta_w
                print(f"Merged layer: {base_key}")
                del delta_w

# import pdb; pdb.set_trace()
end_event.record()
torch.cuda.synchronize()
ont_shot_time = start_event.elapsed_time(end_event)
print("Estimated one shot time {} (h)".format(ont_shot_time/ 1000/3600))
memory_usage = torch.cuda.memory_reserved()
print(f"Memory usage: { memory_usage:.2f} GPU memory used")

print(f"Saving full model to {args.output_path}...")
base_model.save_pretrained(args.output_path)
tokenizer.save_pretrained(args.output_path)
print("Done.")