# Split a DyLoRA model into standard LoRA models of lower ranks (one output file per multiple of --unit below the trained rank)
# This code is based on the extract_lora_from_models.py file, which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo

import argparse
import math
import os
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

def load_state_dict(file_name):
    """Load a LoRA state dict from disk, plus its header metadata when available.

    Args:
        file_name: path to a safetensors file or a torch checkpoint.

    Returns:
        A ``(state_dict, metadata)`` pair. ``metadata`` is the safetensors
        header metadata, or ``None`` when the file is not safetensors.
    """
    if not model_util.is_safetensors(file_name):
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        return torch.load(file_name, map_location="cpu"), None

    tensors = load_file(file_name)
    # load_file returns only tensors, so re-open the file to read its metadata.
    with safe_open(file_name, framework="pt") as handle:
        header = handle.metadata()
    return tensors, header


def save_to_file(file_name, model, metadata):
    """Write the state dict ``model`` to ``file_name``.

    safetensors targets are written with ``metadata`` in the file header.
    Any other target is written with ``torch.save`` and ``metadata`` is
    ignored.
    """
    if not model_util.is_safetensors(file_name):
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)


def split_lora_model(lora_sd, unit):
    """Slice a DyLoRA state dict into progressively larger-rank LoRA state dicts.

    The maximum rank is the largest first dimension among all ``lora_down``
    tensors. For every rank ``unit, 2*unit, ...`` strictly below that maximum,
    a new state dict is built:

    - every ``lora_down`` weight keeps its first ``rank`` rows;
    - every ``lora_up`` weight keeps its first ``rank`` columns;
    - all other entries (e.g. alphas) are copied unchanged.

    Args:
        lora_sd: mapping of key -> tensor as loaded from the DyLoRA file.
        unit: rank step between consecutive splits; must be a positive int.

    Returns:
        ``(max_rank, split_models)``. ``split_models`` is a list of
        ``(state_dict, rank, new_alpha)`` tuples. ``new_alpha`` is always
        ``None`` because alpha rescaling is disabled (see the note below).

    Raises:
        ValueError: if ``unit`` is ``None`` or not positive. Without this
            check the loop below would never terminate (or would fail with an
            opaque TypeError).
    """
    if unit is None or unit <= 0:
        raise ValueError(f"unit must be a positive integer, got {unit!r}")

    # Extract the largest trained rank from the lora_down weights.
    max_rank = 0
    for key, value in lora_sd.items():
        if "lora_down" in key:
            max_rank = max(max_rank, value.shape[0])
    logger.info(f"Max rank: {max_rank}")

    rank = unit
    split_models = []
    new_alpha = None
    while rank < max_rank:
        logger.info(f"Splitting rank {rank}")
        new_sd = {}
        for key, value in lora_sd.items():
            if "lora_down" in key:
                new_sd[key] = value[:rank].contiguous()
            elif "lora_up" in key:
                new_sd[key] = value[:, :rank].contiguous()
            else:
                # Alpha (and any other entry) is copied unchanged. Rescaling
                # alpha by sqrt(this_rank / rank) was tried, but for unknown
                # reasons it produced broken results, so it is disabled.
                new_sd[key] = value

        split_models.append((new_sd, rank, new_alpha))
        rank += unit

    return max_rank, split_models


def split(args):
    """Load a DyLoRA model and save one LoRA file per rank step below its max rank.

    Output files are named ``<save_to stem>-<rank:04d><save_to ext>``. Each
    output gets a copy of the input metadata (empty when the input has none)
    with:

    - an updated training comment;
    - ``ss_network_dim``;
    - ``sshs_*`` hashes computed for that specific split.

    Args:
        args: parsed CLI namespace with ``model``, ``save_to`` and ``unit``.
    """
    logger.info("loading Model...")
    lora_sd, metadata = load_state_dict(args.model)

    logger.info("Splitting Model...")
    original_rank, split_models = split_lora_model(lora_sd, args.unit)

    # load_state_dict returns None metadata for non-safetensors inputs, so
    # guard before reading from it.
    comment = "" if metadata is None else metadata.get("ss_training_comment", "")
    base_name, ext = os.path.splitext(args.save_to)

    for state_dict, new_rank, _alpha in split_models:
        # Start each split from a fresh copy so per-split values never leak
        # into other splits.
        new_metadata = {} if metadata is None else metadata.copy()

        new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
        new_metadata["ss_network_dim"] = str(new_rank)
        # ss_network_alpha is left untouched: split_lora_model does not rescale
        # alpha.

        # Hash this split with the metadata it will be saved with, and store the
        # hashes in that same dict so they are written to the output file.
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
        new_metadata["sshs_model_hash"] = model_hash
        new_metadata["sshs_legacy_hash"] = legacy_hash

        model_file_name = base_name + f"-{new_rank:04d}{ext}"

        logger.info(f"saving model to: {model_file_name}")
        save_to_file(model_file_name, state_dict, new_metadata)


def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the DyLoRA split script.

    Options (each defaults to ``None``): ``--unit`` (int rank step),
    ``--save_to`` (output base file name) and ``--model`` (input DyLoRA file).
    """
    option_table = (
        ("--unit", {"type": int, "default": None, "help": "size of rank to split into / rankを分割するサイズ"}),
        (
            "--save_to",
            {
                "type": str,
                "default": None,
                "help": "destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
            },
        ),
        (
            "--model",
            {
                "type": str,
                "default": None,
                "help": "DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the split.
    cli_args = setup_parser().parse_args()
    split(cli_args)
