import os
import argparse
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Block Hadamard HiRA (baseline and optimized)
from block_hadamard_hira import (
    BlockHadamardHiRAConfig,
    apply_block_hadamard_hira,
    set_adapter_state_dict as set_bhra_adapter_state_dict,
)


def _infer_adapter_name_from_state(adapter_state: dict) -> str:
    # keys look like: "...block_lora_A.<adapter_name>" or "...block_lora_B.<adapter_name>"
    for k in adapter_state.keys():
        if ".block_lora_A." in k:
            return k.split(".block_lora_A.", 1)[1].split(".")[0]
        if ".block_lora_B." in k:
            return k.split(".block_lora_B.", 1)[1].split(".")[0]
    return "default"


def load_and_merge_bhra(base_model_name: str, adapter_path: str, output_path: str):
    """Merge a trained Block Hadamard HiRA adapter into its base model and save it.

    Reads ``adapter_config.json`` and ``adapter_model.bin`` from
    ``adapter_path``, loads the base model in bfloat16, injects BHRA modules,
    loads the adapter weights, merges them into the base weights, and writes a
    plain Hugging Face model plus the base model's tokenizer to
    ``output_path``.

    Args:
        base_model_name: Hub id or local path of the base causal LM.
        adapter_path: Directory containing ``adapter_config.json`` and
            ``adapter_model.bin``.
        output_path: Directory to write the merged model and tokenizer to
            (created if it does not exist).

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: if the adapter config or weights file is missing.
        KeyError: if the adapter config lacks the ``r`` or ``alpha`` field.
    """
    # Read the small adapter artifacts first so a missing file or malformed
    # config fails fast, before paying for a full base-model load.
    with open(os.path.join(adapter_path, "adapter_config.json"), "r", encoding="utf-8") as f:
        cfg = json.load(f)

    # Load the adapter tensors on CPU. The base model below is loaded without a
    # device_map, so mapping the adapter to CUDA would only put these tensors
    # on a different device than the weights they get merged into.
    adapter_state = torch.load(
        os.path.join(adapter_path, "adapter_model.bin"),
        map_location="cpu",
        weights_only=True,
    )
    # Reuse the adapter name from training so the state-dict keys line up with
    # the modules injected by apply_block_hadamard_hira.
    adapter_name = _infer_adapter_name_from_state(adapter_state)
    print(f"Detected adapter name in state dict: '{adapter_name}' (total keys: {len(adapter_state)})")

    # Normalize fields: a single target module may be stored as a bare string.
    target_modules = cfg.get("target_modules")
    if isinstance(target_modules, str):
        target_modules = [target_modules]

    # Building the config here also surfaces missing required keys ("r",
    # "alpha") before the expensive model load.
    bhra_cfg = BlockHadamardHiRAConfig(
        r=cfg["r"],
        alpha=cfg["alpha"],
        dropout=cfg.get("dropout", 0.0),
        target_modules=target_modules,
        bias=cfg.get("bias", "none"),
        init_lora_weights=cfg.get("init_lora_weights", True),
        num_blocks=cfg.get("num_blocks", 4),
        block_arrangement=cfg.get("block_arrangement", "square"),
        use_fast_inference=True,
    )

    print(f"Loading base model: {base_model_name}")
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)

    model = apply_block_hadamard_hira(base_model, bhra_cfg, adapter_name=adapter_name)

    # Load adapter params into the freshly injected BHRA modules.
    set_bhra_adapter_state_dict(model, adapter_state, adapter_name)

    print("Merging BHRA adapter into base model...")
    merged = model.merge_and_unload()

    print(f"Saving merged model to {output_path}")
    os.makedirs(output_path, exist_ok=True)

    # IMPORTANT: our model had a monkey-patched save_pretrained that saves adapters.
    # To save a real HF model, reinstantiate a fresh base model class and load weights.
    # load_state_dict is strict by default, so any leftover adapter-only keys
    # in the merged state dict raise instead of being silently dropped.
    merged_state = merged.state_dict()
    fresh = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
    fresh.load_state_dict(merged_state)
    fresh.save_pretrained(output_path)

    # Save the base model's tokenizer alongside so output_path is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.save_pretrained(output_path)

    return output_path


if __name__ == "__main__":
    # Command-line entry point: every path argument is mandatory.
    parser = argparse.ArgumentParser(description="Merge and save Block Hadamard HiRA adapter")
    for flag in ("--base_model", "--adapter_path", "--output_path"):
        parser.add_argument(flag, type=str, required=True)
    cli = parser.parse_args()
    load_and_merge_bhra(
        base_model_name=cli.base_model,
        adapter_path=cli.adapter_path,
        output_path=cli.output_path,
    )
