import torch
from PIL import Image
from pathlib import Path

from diffusers import DiffusionPipeline
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, get_peft_model_state_dict
from safetensors.torch import load_file
from flow_grpo.diffusers_patch.qwenimage_pipeline_with_logprob import pipeline_with_logprob
from transformers import Qwen2_5_VLForConditionalGeneration

# --- Paths: fill these in before running ---
base_model_path = ""
te_model_path = ""
lora_checkpoint_path = ""
# Merged weights go to "<first three '/'-separated components of the LoRA path>/merged_model".
# NOTE(review): with the empty placeholder above this resolves to "/merged_model" — set the path first.
merged_model_save_path = "/".join(lora_checkpoint_path.split("/", 3)[:3]) + "/merged_model"
output_dir = "scripts/demo/"
if torch.cuda.is_available():
    device = "cuda:1"
else:
    device = "cpu"

print(f"配置已设置 -> base_model: {base_model_path}, te_model: {te_model_path}, lora: {lora_checkpoint_path}, device: {device}", flush=True)

# --- Generation settings (consumed by the optional inference section) ---
prompt = ["A coffee shop entrance features a chalkboard sign reading \"Qwen Coffee 😊 $2 per cup,\" with a neon light beside it displaying \"通义千问\". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written \"π≈3.1415926-53589793-23846264-33832795-02384197\"."]
# One blank-ish negative prompt per positive prompt.
negative_prompt = [" " for _ in prompt]
width, height = 512, 512
num_inference_steps = 50
true_cfg_scale = 4.0
seed = 42

# Load the base pipeline in bfloat16.
print("开始加载 DiffusionPipeline ...", flush=True)
pipe = DiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
print("已加载 DiffusionPipeline。", flush=True)

# Swap in a separately provided text encoder (also bfloat16). Rebinding
# pipe.text_encoder drops the pipeline's reference to the original encoder.
print("开始加载新的 text encoder (Qwen2_5_VLForConditionalGeneration) ...", flush=True)
new_te = Qwen2_5_VLForConditionalGeneration.from_pretrained(te_model_path, torch_dtype=torch.bfloat16)
pipe.text_encoder = new_te
print("已替换 text_encoder。", flush=True)

# Modules wrapped with LoRA: the attention projections (to_* and add_* variants)
# followed by the img_mlp / txt_mlp linear layers.
target_modules = (
    ["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"]
    + ["attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out"]
    + ["img_mlp.net.0.proj", "img_mlp.net.2", "txt_mlp.net.0.proj", "txt_mlp.net.2"]
)
# NOTE(review): rank / alpha / targets presumably mirror the training config — confirm;
# a different rank would give LoRA tensors whose shapes do not fit the checkpoint.
lora_config = LoraConfig(
    target_modules=target_modules,
    init_lora_weights="gaussian",
    lora_alpha=128,
    r=64,
)
pipe.transformer = get_peft_model(pipe.transformer, lora_config)
print("已为 transformer 应用 LoRA PEFT 配置。", flush=True)

# Read the LoRA checkpoint onto the CPU.
print(f"开始加载 LoRA 权重文件：{lora_checkpoint_path} ...", flush=True)
lora_state_dict = load_file(lora_checkpoint_path, device="cpu")
print(f"已加载 LoRA 权重文件，包含 {len(lora_state_dict)} 个键。", flush=True)

# LoRA keys that the freshly wrapped transformer reports via PEFT (iterating a dict yields its keys).
expected_lora_keys = set(filter(lambda key: "lora_" in key, get_peft_model_state_dict(pipe.transformer)))

ANCHORS = ["transformer_blocks.", "img_mlp.", "txt_mlp."]

def canonical_suffix(name: str) -> str | None:
    for a in ANCHORS:
        pos = name.find(a)
        if pos != -1:
            return name[pos:]
    return None

def normalize_suffix(name: str) -> str | None:
    suf = canonical_suffix(name)
    if suf is None:
        return None
    # 统一去掉 ckpt 中的适配器名片段 ".default"
    # 例如: lora_A.default.weight -> lora_A.weight
    suf = suf.replace(".default.weight", ".weight")
    return suf

# Expected keys: normalized suffix -> full key reported by the PEFT-wrapped transformer.
exp_suffix2full = {}
for k in expected_lora_keys:
    suf = normalize_suffix(k)
    if suf is not None:
        exp_suffix2full[suf] = k

# LoRA entries from the checkpoint (only keys containing "lora_"), indexed by normalized suffix.
ckpt_lora_items = {k: v for k, v in lora_state_dict.items() if "lora_" in k}
ckpt_suffix2tensor = {}
for k, v in ckpt_lora_items.items():
    suf = normalize_suffix(k)
    if suf is not None:
        ckpt_suffix2tensor[suf] = v

# Align by suffix: every expected key whose suffix also appears in the checkpoint.
mapped_state = {
    full_key: ckpt_suffix2tensor[suf]
    for suf, full_key in exp_suffix2full.items()
    if suf in ckpt_suffix2tensor
}

matched = len(mapped_state)
total_expected = len(exp_suffix2full)
missing = total_expected - matched
extra_ckpt = max(0, len(ckpt_suffix2tensor) - matched)
print(f"LoRA 对齐统计 -> 期望: {total_expected}, 匹配: {matched}, 缺失: {missing}, 多余(ckpt未用): {extra_ckpt}", flush=True)

if matched == 0:
    sample_exp = sorted(exp_suffix2full)[:5]
    sample_ckpt = sorted(ckpt_suffix2tensor)[:5]
    print("样例 expected 后缀:", *sample_exp, sep="\n  ", flush=True)
    print("样例 ckpt 后缀:", *sample_ckpt, sep="\n  ", flush=True)
    raise ValueError("未匹配到任何 LoRA 权重，请检查训练/推理的模块命名是否一致。")

# A partial match still proceeds, but show what was left out: expected LoRA tensors that are
# not in mapped_state keep the initialization from get_peft_model and are merged as-is later.
if missing:
    missing_suffixes = sorted(s for s in exp_suffix2full if s not in ckpt_suffix2tensor)
    print("缺失的 LoRA 后缀 (前 5 个):", *missing_suffixes[:5], sep="\n  ", flush=True)
if extra_ckpt:
    unused_suffixes = sorted(s for s in ckpt_suffix2tensor if s not in exp_suffix2full)
    print("ckpt 中未使用的 LoRA 后缀 (前 5 个):", *unused_suffixes[:5], sep="\n  ", flush=True)

# Snapshot one representative base weight before loading/merging, so we can check afterwards
# that merge_and_unload actually modified it.
# After get_peft_model every targeted Linear (including "attn.to_q") is wrapped by a PEFT LoRA
# layer that keeps the original weight under "<module>.base_layer.weight"; matching only the
# plain "attn.to_q.weight" form finds nothing and silently skips the verification below.
rep_param_name = None
rep_before = None
for n, p in pipe.transformer.base_model.model.named_parameters():
    if n.endswith(("attn.to_q.base_layer.weight", "attn.to_q.weight")):
        rep_param_name = n
        rep_before = p.detach().float().cpu().clone()
        break
print(f"代表性参数: {rep_param_name}", flush=True)

print("加载映射后的 LoRA 权重到模型 ...", flush=True)
# No adapter_name is passed, so PEFT's default adapter is targeted.
load_result = set_peft_model_state_dict(pipe.transformer, mapped_state)
# Depending on the PEFT version this may carry load_state_dict-style results; surface rejected keys.
unexpected = getattr(load_result, "unexpected_keys", None) or []
if unexpected:
    print(f"警告：有 {len(unexpected)} 个 LoRA 键未被模型接受，例如: {unexpected[0]}", flush=True)
print("LoRA 权重加载完成。", flush=True)

print("执行 merge_and_unload ...", flush=True)
pipe.transformer = pipe.transformer.merge_and_unload()
print("Merge 完成。", flush=True)

# Compare the representative weight before/after merging. PEFT's merge_and_unload replaces each
# LoRA wrapper with its base layer, so the ".base_layer" segment drops out of the parameter name.
if rep_param_name:
    merged_name = rep_param_name.replace(".base_layer.", ".")
    rep_after = None
    for n, p in pipe.transformer.named_parameters():
        if n == merged_name or n.endswith("." + merged_name):
            rep_after = p.detach().float().cpu()
            break
    if rep_after is not None:
        delta = (rep_after - rep_before).abs().mean().item()
        print(f"merge 前后代表性权重平均绝对差: {delta:.6f}", flush=True)
        if delta == 0.0:
            print("警告：merge 前后代表性权重完全相同，LoRA 可能未生效。", flush=True)
    else:
        print("未在合并后模型中定位到代表性参数，跳过差异对比。", flush=True)

# Persist the whole pipeline (with its current transformer and text encoder) to disk.
print(f"开始保存合并后的模型到 {merged_model_save_path} ...", flush=True)
pipe.save_pretrained(save_directory=merged_model_save_path)
print(f"Merged model saved to {merged_model_save_path}", flush=True)

# print(f"将 pipeline 移动到设备 {device} ...", flush=True)
# pipe = pipe.to(device)
# print(f"pipeline 已移动到设备：{device}", flush=True)

# generator = torch.Generator(device=device).manual_seed(seed)
# print(f"已设置随机种子 seed={seed}，generator 设备={device}", flush=True)

# print("开始推理（pipeline_with_logprob） ...", flush=True)
# with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=device.startswith("cuda")):
#     results = pipeline_with_logprob(
#         pipe,
#         prompt=prompt,
#         negative_prompt=negative_prompt,
#         num_inference_steps=num_inference_steps,
#         true_cfg_scale=true_cfg_scale,
#         height=height,
#         width=width,
#         generator=generator,
#         noise_level=0.0, 
#     )
# print("推理完成。", flush=True)

# generated_images = results['images']
# print(f"生成图片数量：{len(generated_images)}", flush=True)

# output_path = Path(output_dir)
# output_path.mkdir(parents=True, exist_ok=True)
# for i, img in enumerate(generated_images):
#     save_path = output_path / f"qwenimage_lora_inference_example_{i}.png"
#     img.save(save_path)
#     print(f"已保存图片 -> {save_path}", flush=True)