import os
import json
import torch
import pandas as pd
import numpy as np
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModelForCausalLM, AutoTokenizer

import warnings

# Suppress every warning category for the whole process. Note that this also
# hides deprecation warnings emitted by third-party libraries.
warnings.filterwarnings("ignore")

# ========== 1. Configuration ==========
IMG_DIR   = "MNR_figure"                        # directory holding the synthesized images
DSL       = "AMCD-LFNMR-237-Chat.jsonl"         # multi-turn dialogue file (JSONL)
IMG_SIZE  = 224                                 # side length (pixels) images are resized to
LLM_NAME  = "Qwen/Qwen3-4B"                     # model id passed to from_pretrained()

# ========== 2. Load samples & compute mean/std ==========
def compute_mean_std(df, img_dir, img_size):
    """Compute per-channel mean and standard deviation over a set of images.

    Each value in ``df["apple_id"]`` names one image,
    ``<img_dir>/<apple_id>.png``. Every image is converted to RGB, resized to
    ``img_size`` x ``img_size`` and scaled to ``[0, 1]`` before it contributes
    to the statistics.

    The statistics are accumulated as running per-channel sums, so only one
    image is held in memory at a time rather than the full ``(N, H, W, C)``
    stack.

    Args:
        df: DataFrame with an ``apple_id`` column whose values are
            convertible to ``int``.
        img_dir: Directory containing the PNG files.
        img_size: Side length, in pixels, the images are resized to.

    Returns:
        Tuple ``(mean, std)`` of two float64 arrays of shape ``(3,)`` holding
        the population mean and standard deviation (``ddof=0``) of the R, G
        and B channels.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    # DataFrame truthiness is ambiguous, so the row count is checked explicitly.
    if len(df) == 0:
        raise ValueError(
            "compute_mean_std needs at least one image, got an empty DataFrame"
        )

    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    pixel_count = 0

    for apple_id in df["apple_id"]:
        img_path = os.path.join(img_dir, f"{int(apple_id)}.png")
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the with-block ends.
        with Image.open(img_path) as raw:
            rgb = raw.convert("RGB").resize((img_size, img_size))
        arr = np.asarray(rgb, dtype=np.float64) / 255.0   # (H, W, C)
        channel_sum += arr.sum(axis=(0, 1))
        channel_sq_sum += np.square(arr).sum(axis=(0, 1))
        pixel_count += arr.shape[0] * arr.shape[1]

    mean = channel_sum / pixel_count
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding; clamp it
    # so the square root never produces NaN.
    variance = np.maximum(channel_sq_sum / pixel_count - mean ** 2, 0.0)
    return mean, np.sqrt(variance)

# Parse the dialogue file (one JSON object per line), keeping only the
# fields the rest of the script reads.
with open(DSL, "r", encoding="utf-8") as jsonl_file:
    records = [
        {key: entry[key] for key in ("apple_id", "image", "turns")}
        for entry in map(json.loads, jsonl_file)
    ]
df = pd.DataFrame(records)

# Per-channel statistics of the images referenced by the dialogue file.
mean, std = compute_mean_std(df, IMG_DIR, IMG_SIZE)
print(f"[INFO] Train set mean: {mean}, std: {std}")

# ========== 3. Preprocessing & feature extraction ==========
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Resize -> tensor -> normalize with the per-channel statistics of this dataset.
# NOTE(review): the VGG16 used below is loaded with pretrained weights; confirm
# that dataset statistics (rather than the statistics the weights were trained
# with) are the intended normalization.
normalize = transforms.Normalize(mean=mean.tolist(), std=std.tolist())
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    normalize,
])

# Load ImageNet-pretrained VGG16, move it to the selected device (GPU or CPU)
# and switch it to eval mode. The ``weights=`` enum replaces the deprecated
# ``pretrained=True`` flag and selects the same checkpoint.
vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).to(device).eval()

def get_vgg16_feat(img_tensor: torch.Tensor) -> torch.Tensor:
    """Return the output of VGG16's first classifier layer (fc1).

    The input is moved to the device the VGG16 weights live on before the
    forward pass, so the call works even when the caller's notion of
    "current device" differs from where ``vgg`` was placed.

    Args:
        img_tensor: Batch of preprocessed images. Presumably shaped
            ``(N, 3, H, W)`` as produced by ``transform`` plus a batch
            dimension -- confirm against callers.

    Returns:
        The activations of ``vgg.classifier[0]``, one row per input image,
        located on the VGG16 device. Gradients are not tracked.
    """
    vgg_device = next(vgg.parameters()).device
    with torch.no_grad():
        x = vgg.features(img_tensor.to(vgg_device))
        x = vgg.avgpool(x)
        x = torch.flatten(x, 1)   # keep the batch dimension, flatten the rest
        x = vgg.classifier[0](x)  # fc1
    return x

# ========== 4. Load Qwen3-4B ==========
print("[INFO] Loading Qwen3-4B...")
tokenizer = AutoTokenizer.from_pretrained(LLM_NAME, trust_remote_code=True)

# Half-precision weights; device placement is delegated to the library
# through device_map="auto".
llm_load_kwargs = dict(
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
llm = AutoModelForCausalLM.from_pretrained(LLM_NAME, **llm_load_kwargs)

# From here on ``device`` is whatever device the loaded LLM reports; this
# rebinds the name that was set before the VGG16 model was created.
device = llm.device
print(f"[INFO] LLM loaded on {device}")

# ========== 5. End-to-end demo ==========
# The demo always runs on the first record of the dialogue file.
sample = records[0]
print(f"[INFO] Demo 使用 apple_id = {sample['apple_id']}")

# 5.1 Image features
img = Image.open(sample["image"])
img = img.convert("RGB")
img_t = transform(img)
img_t = img_t.unsqueeze(0).to(device)             # add the batch dimension
feat = get_vgg16_feat(img_t)                      # (1,4096)
feat_np = feat.squeeze(0).cpu().numpy()
# Only the first ten activations are rendered into the prompt text.
leading_values = [format(value, ".4f") for value in feat_np[:10]]
feat_str = ", ".join(leading_values) + ", …"

# 5.2 Build the prompt
# One "Role: text" line per turn, in file order.
dialogue = "".join(
    f"{turn['role'].capitalize()}: {turn['text']}\n" for turn in sample["turns"]
)

# NOTE(review): the prompt is sent to the model as plain text; confirm whether
# the tokenizer's chat template should be used for this model instead.
instruction = "You are a fruit quality expert. Based on visual features and prior dialogue, answer the last user question.\n\n"
prompt = f"{instruction}[VISUAL_FEATURES] {feat_str}\n{dialogue}Assistant: "

# Tokenize the prompt and remember its length in tokens.
inputs = tokenizer(prompt, return_tensors="pt").to(device)
input_len = inputs["input_ids"].size(1)

# Deterministic (non-sampling) decoding, capped at 128 new tokens.
generation_kwargs = {
    "max_new_tokens": 128,
    "do_sample": False,
    "eos_token_id": tokenizer.eos_token_id,
}
outputs = llm.generate(**inputs, **generation_kwargs)

# Drop the first ``input_len`` positions (the prompt) and decode the rest.
new_tokens = outputs[0, input_len:]
answer = tokenizer.decode(new_tokens, skip_special_tokens=True)

# ========== 6. Output ==========
# Print the prompt and the generated answer between two separator lines.
separator = "=" * 40
print("\n" + separator)
print("Prompt:\n", prompt)
print("\nGenerated Answer:\n", answer)
print(separator)