import os
import time
import random
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from torchvision import models, transforms
from PIL import Image
from sklearn.model_selection import train_test_split

from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score,
    recall_score, roc_auc_score, f1_score, matthews_corrcoef
)
import numpy as np
from transformers import AutoConfig

import warnings
# Suppress every library warning so the run log stays readable.
warnings.filterwarnings("ignore")

# ================= Hyperparameters & paths =================
# Label spreadsheet, image folder, and the root folder that contains one
# sub-directory per text source.
LABEL_FILE = './苹果称重-1-7.xlsx'
IMG_DIR = './MNR_figure'
TEXT_ROOT = './MNR_CL_text_right'

# Each sub-directory of TEXT_ROOT is one text source (sorted by name).
text_root_dir = Path(TEXT_ROOT)
TEXT_SOURCES = sorted(entry.name for entry in text_root_dir.iterdir() if entry.is_dir())

# Hugging Face checkpoints for the text branch.
TEXT_MODELS = ["bert-base-uncased", "roberta-base", "distilbert-base-uncased", "albert-base-v2"]

# torchvision constructor names for the vision branch.
VISION_MODELS = [
    "resnet18",
    "resnet34",
    "resnet50",
    "densenet121",
    "efficientnet_b0",
    "mobilenet_v3_small",
    "shufflenet_v2_x1_0",
    "vgg16",
]

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MAX_LEN = 512      # tokenizer max_length (sequences are padded/truncated to this)
IMG_SIZE = 224     # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 1
LR = 2e-4
SEED = 42
OUTPUT_DIR = './logs_fusion_Gate'
# ===========================================================

# Seed Python's and PyTorch's global RNGs once at import time.
random.seed(SEED)
torch.manual_seed(SEED)

# ============== Dataset ==============

def compute_mean_std(df, img_dir, img_size):
    """Compute per-channel mean and std of the images referenced by ``df``.

    Each image ``<img_dir>/<int(id)>.png`` is converted to RGB, resized to
    ``img_size`` x ``img_size`` and scaled to [0, 1]. Statistics are taken over
    every pixel of every image (population std, i.e. ddof=0, matching
    ``np.std``). Sums are accumulated in a streaming fashion. This avoids
    materializing an (N, C, H, W) float64 stack of the whole split.

    Args:
        df: DataFrame with an ``id`` column.
        img_dir: directory containing the PNG files.
        img_size: side length each image is resized to.

    Returns:
        ``(mean, std)``: two float64 arrays of shape (3,), in RGB order.

    Raises:
        ValueError: if ``df`` is empty (nothing to compute statistics from).
    """
    if len(df) == 0:
        raise ValueError("compute_mean_std: df is empty; cannot compute image statistics")

    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    pixel_count = 0
    for raw_id in df["id"]:
        img_path = os.path.join(img_dir, f"{int(raw_id)}.png")
        # Context manager closes the file handle deterministically.
        with Image.open(img_path) as im:
            pixels = np.asarray(
                im.convert('RGB').resize((img_size, img_size)), dtype=np.float64
            ) / 255.0  # normalize to [0, 1]
        flat = pixels.reshape(-1, 3)  # (H*W, C)
        channel_sum += flat.sum(axis=0)
        channel_sq_sum += np.square(flat).sum(axis=0)
        pixel_count += flat.shape[0]

    mean = channel_sum / pixel_count
    # E[x^2] - E[x]^2; clamp tiny negative round-off before the square root.
    var = np.maximum(channel_sq_sum / pixel_count - mean ** 2, 0.0)
    std = np.sqrt(var)
    return mean, std

# ==== 定义 FiLM 模块 （轻量可解释）====
class FiLM(nn.Module):
    """Feature-wise Linear Modulation of a 4-D feature map by a conditioning vector.

    Two linear heads map the conditioning vector to a per-channel scale
    (``gamma_gen``) and shift (``beta_gen``). The output is
    ``scale * x + shift``, broadcast over the spatial dimensions.
    """

    def __init__(self, text_dim, num_channels):
        super().__init__()
        # gamma_gen is created before beta_gen; this fixes the RNG order of their init.
        self.gamma_gen = nn.Linear(text_dim, num_channels)
        self.beta_gen = nn.Linear(text_dim, num_channels)

    def forward(self, x, text_feat):
        # x: (B, C, H, W); text_feat: (B, D). Heads give (B, C) -> (B, C, 1, 1).
        scale = self.gamma_gen(text_feat)[..., None, None]
        shift = self.beta_gen(text_feat)[..., None, None]
        return scale * x + shift

class FusionDataset(Dataset):
    """One (image, tokenized text, label) sample per row of ``df``.

    ``df`` must provide ``id``, ``label`` and ``text`` columns. The image for a
    row is read from ``<img_dir>/<int(id)>.png``. The text is tokenized to a
    fixed length of ``max_len`` tokens.
    """

    def __init__(self, df: pd.DataFrame, img_dir, tokenizer, max_len, transform):
        self.df, self.img_dir = df, img_dir
        self.tokenizer, self.max_len = tokenizer, max_len
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        sample = self.df.iloc[idx]
        raw_id, target = sample["id"], sample["label"]

        # Image branch: load as 3-channel RGB, then apply the configured transform.
        picture = Image.open(os.path.join(self.img_dir, f"{int(raw_id)}.png"))
        tensor_img = self.transform(picture.convert("RGB"))

        # Text branch: padded/truncated to max_len, returned as 1-D tensors.
        enc = self.tokenizer(
            sample["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)

        return dict(
            image=tensor_img,
            input_ids=input_ids,
            attention_mask=attention_mask,
            label=torch.tensor(target, dtype=torch.long),
        )

# ============== Model ==============
class FusionModel(nn.Module):
    """Image classifier whose visual feature map is FiLM-modulated by text.

    The vision backbone is expected to output a 4-D map (B, C, H, W) with
    C == ``num_channels``. The text model's output must expose
    ``last_hidden_state``, and its first-token vector (size ``text_hidden``)
    conditions the FiLM layer. The modulated map is average-pooled to
    (B, C) and a linear head produces ``num_classes`` logits.
    """

    def __init__(self, vision_backbone, text_model, text_hidden, num_channels, num_classes=2):
        super().__init__()
        # FiLM is built before the classifier; this order fixes the RNG
        # sequence used to initialize their Linear layers.
        self.vision_backbone = vision_backbone
        self.text_model = text_model
        self.film = FiLM(text_hidden, num_channels)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(num_channels, num_classes)

    def forward(self, image, input_ids, attention_mask):
        visual = self.vision_backbone(image)
        encoded = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # First-token ([CLS]-position) embedding: (B, text_hidden).
        sentence_vec = encoded.last_hidden_state[:, 0]
        conditioned = self.film(visual, sentence_vec)
        return self.classifier(torch.flatten(self.pool(conditioned), 1))

# ============== Helpers ==============
def get_vision_backbone_and_channels(name, pretrained=True):
    """Build a torchvision model minus its last child module and report its output channels.

    Args:
        name: torchvision model constructor name (e.g. "resnet18").
        pretrained: forwarded to the constructor's ``pretrained`` flag.

    Returns:
        ``(backbone, num_channels)``. ``backbone`` is an ``nn.Sequential`` of
        all top-level children except the last (assumed to be the
        classification head). ``num_channels`` is dim 1 of its output for an
        IMG_SIZE x IMG_SIZE input. The backbone is returned in training mode,
        as before.
    """
    base = getattr(models, name)(pretrained=pretrained)
    # Drop the final child (the classifier head for the configured models).
    backbone = nn.Sequential(*list(base.children())[:-1])

    # Probe the output shape in eval mode. In train mode BatchNorm layers
    # update running_mean/running_var even under no_grad, so a random dummy
    # forward would overwrite part of the pretrained statistics.
    was_training = backbone.training
    backbone.eval()
    try:
        with torch.no_grad():
            dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
            num_channels = backbone(dummy).shape[1]
    finally:
        backbone.train(was_training)
    return backbone, num_channels

def _canonical_id(value) -> str:
    """Return a sample id as a digit string without leading zeros ("001" -> "1", "1.0" -> "1")."""
    text = str(value).strip()
    # Ids read from Excel as floats stringify as e.g. "12.0".
    if text.endswith(".0"):
        text = text[:-2]
    return text.lstrip("0") or "0"


def load_data(text_source) -> pd.DataFrame:
    """Load labels and the matching text files for one text source.

    Labels come from LABEL_FILE: column 0 is the sample id and column 3 the label,
    selected positionally. Texts are the ``*.txt`` files in
    ``TEXT_ROOT/text_source``. Excel ids and file stems are joined on a shared
    canonical key, so "1.txt", "001.txt" and ids "1"/"001"/1.0 all match.

    Args:
        text_source: name of a sub-directory of TEXT_ROOT.

    Returns:
        DataFrame with ``id`` (zero-padded to at least 3 digits, e.g. "001"),
        ``label`` (int) and ``text`` columns. It contains only rows that have
        both a label and a text file.
    """
    label_df = pd.read_excel(LABEL_FILE, dtype={0: str})
    label_df = label_df.iloc[:, [0, 3]].dropna()
    label_df.columns = ['id', 'label']
    label_df["id"] = label_df["id"].map(_canonical_id).str.zfill(3)

    id_set = set(label_df["id"].map(_canonical_id))
    text_dir = Path(TEXT_ROOT) / text_source

    texts = {}
    # Sorted so the result is deterministic if two files map to the same id.
    for txt_file in sorted(text_dir.glob("*.txt")):
        key = _canonical_id(txt_file.stem)
        if key in id_set:
            # Undecodable bytes are dropped rather than failing the whole load.
            texts[key.zfill(3)] = txt_file.read_text(encoding='utf-8', errors='ignore')

    df = label_df[label_df["id"].isin(list(texts))].copy()
    df["text"] = df["id"].map(texts)
    df = df.dropna()
    df["label"] = df["label"].astype(int)
    return df

def evaluate(loader, net=None, loss_fn=None):
    """Run a model over ``loader`` without gradients and return binary-classification metrics.

    Args:
        loader: iterable of batches (dicts with "image", "input_ids",
            "attention_mask", "label" tensors), e.g. a DataLoader.
        net: model to evaluate. Defaults to the module-level ``model`` when
            omitted (the original behavior).
        loss_fn: loss criterion. Defaults to the module-level ``criterion``.

    Returns:
        dict with keys loss, accuracy, balanced_accuracy, precision, recall,
        auc, mcc, f1. ``loss`` is the mean per-sample loss over the samples
        actually seen. Hard predictions threshold the class-1 softmax
        probability at 0.5. ``auc`` falls back to 0.5 when only one class is
        present.
    """
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    net.eval()

    prob_chunks, target_chunks = [], []
    loss_sum = 0.0
    n_seen = 0
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(DEVICE)
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)

            out = net(image, input_ids, attention_mask)
            # Zero out any non-finite logit (NaN or +/-inf). Checking isfinite
            # rather than only isnan also catches inf-only outputs, which
            # would otherwise produce NaN losses/probabilities.
            if not torch.isfinite(out).all():
                out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)

            loss = loss_fn(out, labels)
            batch_size = image.size(0)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size

            # Class-1 probability; any residual NaN becomes the neutral 0.5.
            probs = torch.nan_to_num(torch.softmax(out, dim=1)[:, 1], nan=0.5)
            prob_chunks.append(probs.cpu())
            target_chunks.append(labels.cpu())

    preds = torch.cat(prob_chunks).numpy()
    targets = torch.cat(target_chunks).numpy()
    y_hat = (preds >= 0.5).astype(int)

    metrics = {
        'loss':   loss_sum / n_seen,
        'accuracy': accuracy_score(targets, y_hat),
        'balanced_accuracy': balanced_accuracy_score(targets, y_hat),
        'precision': precision_score(targets, y_hat, zero_division=0),
        'recall':    recall_score(targets, y_hat, zero_division=0),
        # AUC is undefined with a single class (or NaN scores): report 0.5.
        'auc':   roc_auc_score(targets, preds) if (len(np.unique(targets))>1 and not np.isnan(preds).any()) else 0.5,
        'mcc':   matthews_corrcoef(targets, y_hat),
        'f1':    f1_score(targets, y_hat, zero_division=0),
    }
    return metrics

def train_eval(model, train_loader, val_loader, optimizer, criterion, patience=10):
    """Train for up to EPOCHS epochs, tracking the best validation F1 with early stopping.

    After each epoch both loaders are evaluated. Training stops once the
    validation F1 has not improved for ``patience`` consecutive epochs.

    Returns:
        ``(best_val_metrics, best_train_metrics, last_val_metrics,
        last_train_metrics, epochs_run)``. The "best" dicts come from the epoch
        with the highest validation F1. The first epoch always counts, even if
        its F1 is 0. ``epochs_run`` is 0 if EPOCHS < 1.
    """
    # -inf (not 0.0) so the first epoch is always recorded; with 0.0 and a
    # strict ">" a run whose val F1 stays 0 returned empty best-metric dicts.
    best_val_f1 = float("-inf")
    best_val_metrics = {}
    best_train_metrics = {}
    last_val_metrics = {}
    last_train_metrics = {}

    no_improve_epochs = 0
    epoch = 0  # defined even if the loop body never runs

    for epoch in range(1, EPOCHS + 1):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(
                batch["image"].to(DEVICE),
                batch["input_ids"].to(DEVICE),
                batch["attention_mask"].to(DEVICE)
            )
            loss = criterion(out, batch["label"].to(DEVICE))
            loss.backward()
            optimizer.step()

        # NOTE(review): evaluate() is called with the loader only, so it
        # resolves the model/criterion itself (module-level globals); the main
        # loop binds those globals to the same objects passed here.
        train_metrics = evaluate(train_loader)
        val_metrics = evaluate(val_loader)

        last_train_metrics = train_metrics
        last_val_metrics = val_metrics

        print(f"Ep{epoch:02d} | "
              f"tr_f1={train_metrics['f1']:.3f} "
              f"val_f1={val_metrics['f1']:.3f} "
              f"val_auc={val_metrics['auc']:.3f} "
              f"val_mcc={val_metrics['mcc']:.3f}")

        if val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            best_val_metrics = val_metrics.copy()
            best_train_metrics = train_metrics.copy()
            no_improve_epochs = 0  # reset counter
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"⏹️ Early stopping triggered at epoch {epoch}")
                break

    return best_val_metrics, best_train_metrics, last_val_metrics, last_train_metrics, epoch

def measure_inference_latency(model, loader, device, max_samples=50, warmup=5):
    """Estimate average inference latency per sample, in milliseconds.

    ``warmup`` untimed forward passes are run once, on the first batch. Then
    each batch is timed individually until at least ``max_samples`` samples
    have been measured. On CUDA devices the stream is synchronized around each
    timed call so the time reflects finished work.

    Args:
        model: callable ``model(image, input_ids, attention_mask)`` with ``.eval()``.
        loader: iterable of batch dicts ("image", "input_ids", "attention_mask").
        device: device string or ``torch.device``.
        max_samples: stop after this many samples have been timed.
        warmup: number of untimed passes before measuring.

    Returns:
        Mean milliseconds per sample, or ``nan`` if the loader yields nothing.
    """
    model.eval()
    is_cuda = str(device).startswith("cuda")

    def _sync():
        # Block until queued CUDA kernels finish so timings are accurate.
        if is_cuda:
            torch.cuda.synchronize()

    total_time = 0.0
    total_num = 0
    warmed_up = False
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            if not warmed_up:
                for _ in range(warmup):
                    model(image, input_ids, attn_mask)
                warmed_up = True

            _sync()
            t0 = time.perf_counter()  # monotonic, high-resolution clock
            model(image, input_ids, attn_mask)
            _sync()
            total_time += time.perf_counter() - t0
            total_num += image.size(0)

            if total_num >= max_samples:
                break

    if total_num == 0:
        return float("nan")
    return total_time / total_num * 1000  # ms per sample

# ============== Main Loop ==============
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

summary_path = Path(OUTPUT_DIR) / "summary_results.csv"
summary_records = []

# Total runs = text sources x text models x vision models x 2 vision inits x 2 text inits.
total_runs = len(TEXT_SOURCES) * len(TEXT_MODELS) * len(VISION_MODELS) * 2 * 2
run_count = 0
global_start_time = time.time()

for text_source in TEXT_SOURCES:
    df = load_data(text_source)
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=SEED, stratify=df["label"])

    # Normalization statistics come from the training split only.
    mean, std = compute_mean_std(train_df, IMG_DIR, IMG_SIZE)
    print(f"[{text_source}] Train set mean: {mean}, std: {std}")

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])

    for text_model_name in TEXT_MODELS:
        # The tokenizer and datasets depend only on the text model, so build
        # them once per model instead of once per run.
        tokenizer = AutoTokenizer.from_pretrained(text_model_name)
        train_ds = FusionDataset(train_df, IMG_DIR, tokenizer, MAX_LEN, transform)
        val_ds = FusionDataset(val_df, IMG_DIR, tokenizer, MAX_LEN, transform)

        for vision_model_name in VISION_MODELS:
            for vision_init in ("pretrained", "scratch"):
                for text_init in ("pretrained", "scratch"):
                    run_count += 1
                    tag = f"{text_source}__{text_model_name.replace('/', '_')}--{text_init}__{vision_model_name}--{vision_init}"
                    print(f"\n[{run_count}/{total_runs}] 🔄 Running: {tag}")

                    # Reseed per run so each configuration's weight init and
                    # shuffle order do not depend on the runs before it.
                    random.seed(SEED)
                    torch.manual_seed(SEED)

                    # Text model: pretrained weights or a randomly initialized config.
                    if text_init == "pretrained":
                        text_model = AutoModel.from_pretrained(text_model_name).to(DEVICE)
                    else:
                        config = AutoConfig.from_pretrained(text_model_name)
                        text_model = AutoModel.from_config(config).to(DEVICE)

                    # Vision backbone and fused model. `model` and `criterion`
                    # stay module-level names because evaluate() falls back to them.
                    vision_backbone, num_channels = get_vision_backbone_and_channels(vision_model_name, pretrained=(vision_init=="pretrained"))
                    model = FusionModel(vision_backbone,
                                        text_model,
                                        text_model.config.hidden_size,
                                        num_channels,
                                        num_classes=2).to(DEVICE)

                    total_params = sum(p.numel() for p in model.parameters())

                    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
                    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
                    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
                    criterion = nn.CrossEntropyLoss()

                    start_time = time.time()
                    best_val, best_train, last_val, last_train, epoch_run = train_eval(model, train_loader, val_loader, optimizer, criterion)
                    elapsed = round(time.time() - start_time, 2)

                    # Inference latency (ms per sample) on the validation loader.
                    latency_ms = measure_inference_latency(model, val_loader, DEVICE)

                    record = {
                            "src_text": text_source,
                            "model": text_model_name + "+" + vision_model_name,
                            "init": f"text={text_init}, vision={vision_init}",
                            "best_val_f1":  best_val.get("f1", 0),
                            "best_val_auc": best_val.get("auc", 0),
                            "best_val_mcc": best_val.get("mcc", 0),
                            "best_val_bal_acc": best_val.get("balanced_accuracy", 0),
                            "best_val_precision": best_val.get("precision", 0),
                            "best_val_recall":    best_val.get("recall", 0),
                            "best_train_f1":  best_train.get("f1", 0),
                            "best_train_auc": best_train.get("auc", 0),
                            "best_train_mcc": best_train.get("mcc", 0),
                            "best_train_bal_acc": best_train.get("balanced_accuracy", 0),
                            "best_train_precision": best_train.get("precision", 0),
                            "best_train_recall":    best_train.get("recall", 0),
                            "train_loss_last": last_train.get("loss", 0),
                            "val_loss_last":   last_val.get("loss", 0),
                            "epochs_run": epoch_run,
                            "seconds": elapsed,
                            "params_M": total_params / 1e6,
                            "infer_latency_ms": latency_ms,
                        }

                    summary_records.append(record)

                    # Rewrite the CSV after every run so an interruption loses nothing.
                    pd.DataFrame(summary_records).to_csv(summary_path, index=False)
                    avg_time = (time.time() - global_start_time) / run_count
                    remaining = (total_runs - run_count) * avg_time
                    eta_min = int(remaining // 60)
                    eta_sec = int(remaining % 60)

                    print(f"✅ Saved to {summary_path}")
                    print(f"⏱️ Elapsed: {elapsed:.1f}s | ⏳ ETA: ~{eta_min}m {eta_sec}s")

                    # Release this run's model and optimizer before the next run
                    # builds its own, keeping peak (GPU) memory bounded.
                    del model, text_model, vision_backbone, optimizer
                    if DEVICE == "cuda":
                        torch.cuda.empty_cache()