from collections import OrderedDict
import os
import sys
from typing import Dict
import typing
import torch
import bitsandbytes as bnb
from argparse import ArgumentParser

# Command-line interface. The module-level names bound below are read by the
# merge step further down in this script.
parser = ArgumentParser(
    description="Merge a trained LoRA/PiSSA checkpoint into base model weights.",
)
parser.add_argument(
    "--type", default="pissa", type=str,
    help="adapter type: 'pissa' subtracts the initial adapter (B0 @ A0) from the "
         "base weight before adding the trained delta; any other value merges "
         "as plain LoRA scaled by lora_alpha / rank",
)
parser.add_argument("--base_model", default="", type=str,
                    help="path to the base model state dict")
parser.add_argument("--lora_init", default="none", type=str,
                    help="path to the initial adapter checkpoint (used with --type pissa)")
parser.add_argument("--lora_checkpoint", default="", type=str,
                    help="path to the trained LoRA-only checkpoint")
parser.add_argument("--output", default="", type=str,
                    help="path the merged state dict is saved to")
# Restricting choices turns a typo (e.g. 'NF4') into an immediate error instead
# of silently producing an unquantized merge.
parser.add_argument(
    "--quant", default="none", type=str,
    choices=["none", "4bit", "nf4", "fp4", "int8"],
    help="bitsandbytes quantization scheme to round-trip the base weight "
         "through before the trained delta is added",
)
parser.add_argument("--device", default="cuda", type=str,
                    help="torch device used for the merge arithmetic")
parser.add_argument("--lora_alpha", default=16, type=int,
                    help="LoRA alpha; plain-LoRA deltas are scaled by lora_alpha / rank")
args = parser.parse_args()
device = args.device
base_model = args.base_model
init_lora = args.lora_init
lora = args.lora_checkpoint
output = args.output
quant = args.quant
lora_alpha = args.lora_alpha

def _quant_roundtrip(t: torch.Tensor, mode: str) -> typing.Optional[torch.Tensor]:
    """Quantize *t* with bitsandbytes, dequantize it, and return the bf16 result.

    The merged weight then carries the rounding error of the chosen scheme.

    Args:
        t: weight tensor to round-trip.
        mode: one of '4bit', 'nf4', 'fp4', 'int8'.

    Returns:
        The dequantized tensor cast to ``torch.bfloat16``, or ``None`` when
        *mode* names no supported scheme (e.g. 'none'), so that the caller
        decides how to treat the unquantized weight.
    """
    fn = bnb.functional
    if mode == '4bit':
        qw, qs = fn.quantize_4bit(t)
        restored = fn.dequantize_4bit(qw, quant_state=qs)
    elif mode == 'nf4':
        qw, qs = fn.quantize_nf4(t)
        restored = fn.dequantize_nf4(qw, quant_state=qs)
    elif mode == 'fp4':
        qw, qs = fn.quantize_fp4(t)
        restored = fn.dequantize_fp4(qw, quant_state=qs)
    elif mode == 'int8':
        qw, qs = fn.quantize(t)
        restored = fn.dequantize(qw, state=qs)
    else:
        return None
    return restored.to(dtype=torch.bfloat16)


with torch.no_grad():
    # NOTE(review): torch.load unpickles its input; only pass trusted checkpoints.
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
    # merge LoRA-only slim checkpoint into the main weights
    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')

    if args.type == 'pissa':
        # Initial adapter factors; their product is subtracted from each base
        # weight below before the trained delta is added.
        w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu')
    # Slim-checkpoint entries override (or extend) the base weights.
    w.update(w_lora)
    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
    # merge LoRA weights
    keys = list(w.keys())
    # Set view of the snapshot: O(1) membership tests while the list keeps order.
    key_set = set(keys)
    for k in keys:
        if k.endswith('.weight'):
            prefix = k[:-len('.weight')]
            lora_A = prefix + '.lora_A'
            lora_B = prefix + '.lora_B'
            init_lora_A = 'model.' + prefix + '.init_lora_A'
            init_lora_B = 'model.' + prefix + '.init_lora_B'
            if lora_A in key_set:
                # Explicit checks instead of `assert`, which is stripped under -O.
                if lora_B not in key_set:
                    raise KeyError(f'{lora_A} present but {lora_B} missing')
                print(f'merging {lora_A} and {lora_B} into {k}')
                if w[lora_B].shape[1] != w[lora_A].shape[0]:
                    raise ValueError(
                        f'rank mismatch: {lora_B} has {w[lora_B].shape[1]} columns '
                        f'but {lora_A} has {w[lora_A].shape[0]} rows'
                    )
                w[k] = w[k].to(device=device)
                w[lora_A] = w[lora_A].to(device=device)
                w[lora_B] = w[lora_B].to(device=device)

                if args.type == 'pissa':
                    # Pop and keep the device copies local so they are freed after
                    # this layer instead of accumulating on the device in the dict.
                    init_A = w_init_lora.pop(init_lora_A).to(device=device)
                    init_B = w_init_lora.pop(init_lora_B).to(device=device)
                    residual = w[k] - init_B @ init_A
                    del init_A, init_B
                    quantized = _quant_roundtrip(residual, quant)
                    if quantized is not None:
                        w[k] = quantized
                    else:
                        w[k] = residual.to(dtype=torch.bfloat16)
                    del residual, quantized
                    # Trained delta is added unscaled in this branch.
                    w[k] += w[lora_B] @ w[lora_A]
                else:
                    lora_r = w[lora_B].shape[1]
                    quantized = _quant_roundtrip(w[k], quant)
                    if quantized is not None:
                        w[k] = quantized
                    del quantized
                    w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
                output_w[k] = w[k].to(device='cpu', copy=True)
                del w[k]
                del w[lora_A]
                del w[lora_B]
                continue

        # Non-adapter tensors are copied through; leftover adapter keys are dropped.
        if 'lora' not in k:
            print(f'retaining {k}')
            output_w[k] = w[k].clone()
            del w[k]
    torch.save(output_w, output)