from typing import List

import os
import json
import glob
import torch
import transformers
from datasets import load_dataset

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

def get_checkpoints(checkpoint_folder):
    """Return the ``step_<N>`` checkpoint paths in *checkpoint_folder*, ordered by N.

    Only entries named ``step_`` followed by a decimal integer are returned.
    Other names that still match the ``step_*`` glob (e.g. ``step_final`` or
    ``step_100_merged``) are skipped rather than crashing the numeric sort.

    Args:
        checkpoint_folder: Directory containing the checkpoint entries.

    Returns:
        List of paths (joined onto *checkpoint_folder*), sorted by ascending
        step number.
    """
    prefix = 'step_'
    numbered = []
    for path in glob.glob(os.path.join(checkpoint_folder, prefix + '*')):
        suffix = os.path.basename(path)[len(prefix):]
        # A non-numeric suffix would make int() raise inside the sort key.
        if suffix.isdecimal():
            numbered.append((int(suffix), path))
    numbered.sort()
    return [path for _, path in numbered]

def get_lora_weight(checkpoint, map_location='cpu'):
    """Load a PEFT LoRA checkpoint and group its A/B matrices by module name.

    Reads ``adapter_config.json`` and ``adapter_model.bin`` from *checkpoint*.
    A state-dict key that contains a ``lora_A`` or ``lora_B`` component is
    assigned to the module named immediately before that component.
    For example, ``...self_attn.q_proj.lora_A.weight`` goes to ``q_proj``.

    Args:
        checkpoint: Directory containing the adapter files.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            checkpoints saved on another device can be inspected anywhere.

    Returns:
        ``{module: {'lora_A': [tensor, ...], 'lora_B': [tensor, ...]}}``.
        When ``target_modules`` in the config is a list, every listed module
        gets an entry, even one with no weights. Modules found only in the
        state dict are added as they are found. Tensors are kept in
        state-dict iteration order.
    """
    with open(os.path.join(checkpoint, 'adapter_config.json')) as f:
        lora_config = json.load(f)
    # SECURITY: torch.load unpickles the file; only use on trusted checkpoints.
    weight_dict = torch.load(os.path.join(checkpoint, 'adapter_model.bin'),
                             map_location=map_location)

    target_modules = lora_config.get('target_modules')
    # target_modules may also be a regex string (or missing). Iterating a
    # string would create one bogus entry per character, so only pre-seed
    # the dict from a list.
    seeded = target_modules if isinstance(target_modules, list) else []
    lora_weights = {module: {'lora_A': [], 'lora_B': []} for module in seeded}

    for name, weight in weight_dict.items():
        parts = name.split('.')
        for kind in ('lora_A', 'lora_B'):
            if kind not in parts:
                continue
            idx = parts.index(kind)
            if idx == 0:
                # No module component precedes the LoRA tag; nothing to group by.
                break
            module = parts[idx - 1]
            entry = lora_weights.setdefault(module, {'lora_A': [], 'lora_B': []})
            entry[kind].append(weight)
            break
    return lora_weights
    


# Grouped LoRA weights for each checkpoint, in ascending step order.
lora_weights = []
checkpoint_folder = 'output/llama7b_lora'
for ckpt in get_checkpoints(checkpoint_folder):
    # Keep each checkpoint's weights so they are available after the loop.
    lora_weights.append(get_lora_weight(ckpt))
    # NOTE(review): only the first checkpoint is processed; this looks like a
    # debugging shortcut. Remove the `break` to load every checkpoint.
    break