import os
import time
import random
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from torchvision import models, transforms
from PIL import Image
from sklearn.model_selection import train_test_split

from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score,
    recall_score, roc_auc_score, f1_score, matthews_corrcoef
)
import numpy as np
from transformers import AutoConfig

import warnings
warnings.filterwarnings("ignore")

# ================= Hyperparameters =================
# Data locations: the label spreadsheet, the image folder, and the root folder
# whose sub-directories each hold one "text source" of per-sample .txt files.
LABEL_FILE = './苹果称重-1-7.xlsx'
IMG_DIR = './MNR_figure'
TEXT_ROOT = './MNR_CL_text_right'

# Every sub-directory of TEXT_ROOT is one text source, processed in name order.
# NOTE: evaluated at import time; Path.iterdir raises if TEXT_ROOT is missing.
text_root_dir = Path(TEXT_ROOT)
TEXT_SOURCES = sorted(
    entry.name for entry in text_root_dir.iterdir() if entry.is_dir()
)

# Hugging Face identifiers swept for the text branch.
TEXT_MODELS = [
    "bert-base-uncased",
    "roberta-base",
    "distilbert-base-uncased",
    "albert-base-v2",
]

# torchvision constructor names swept for the vision branch.
VISION_MODELS = [
    "resnet18",
    "resnet34",
    "resnet50",
    "densenet121",
    "efficientnet_b0",
    "mobilenet_v3_small",
    "shufflenet_v2_x1_0",
    "vgg16",
]

# Training / preprocessing settings.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
MAX_LEN = 512       # tokenizer max_length: texts are padded/truncated to this
IMG_SIZE = 224      # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 50         # upper bound; early stopping may end a run sooner
LR = 2e-4           # Adam learning rate for the whole model
SEED = 42
OUTPUT_DIR = './logs_fusion_CA'
# ====================================================

# Seed the stdlib and torch RNGs once for the whole script.
random.seed(SEED)
torch.manual_seed(SEED)

# ============== Dataset ==============

def compute_mean_std(df, img_dir, img_size):
    """Compute per-channel mean and std of the images referenced by ``df``.

    Each image ``<img_dir>/<int(id)>.png`` is converted to RGB, resized to
    ``img_size x img_size`` with a bilinear filter (the same filter
    ``torchvision.transforms.Resize`` uses by default, so the statistics match
    the pixels the model is fed), and scaled to [0, 1].

    Statistics are accumulated as running per-channel sums, so memory use does
    not grow with the number of images (the previous version stacked every
    image into one float64 array).

    Args:
        df: DataFrame with an ``id`` column convertible to ``int``.
        img_dir: Directory containing the ``<id>.png`` files.
        img_size: Side length images are resized to before measuring.

    Returns:
        (mean, std): numpy float64 arrays of shape (3,), in RGB order. ``std``
        is the population standard deviation (ddof=0), as ``np.std`` computes.

    Raises:
        ValueError: if ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("compute_mean_std: df contains no rows")

    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    n_pixels = 0
    for image_id in df['id']:
        img_path = os.path.join(img_dir, f"{int(image_id)}.png")
        # Context manager closes the file handle once the pixels are decoded.
        with Image.open(img_path) as img:
            rgb = img.convert('RGB').resize((img_size, img_size), Image.BILINEAR)
        pixels = np.asarray(rgb, dtype=np.float64).reshape(-1, 3) / 255.0  # (H*W, 3) in [0, 1]
        channel_sum += pixels.sum(axis=0)
        channel_sq_sum += np.square(pixels).sum(axis=0)
        n_pixels += pixels.shape[0]

    mean = channel_sum / n_pixels
    # Var = E[x^2] - E[x]^2; clip tiny negatives caused by floating-point cancellation.
    var = np.maximum(channel_sq_sum / n_pixels - np.square(mean), 0.0)
    std = np.sqrt(var)
    return mean, std

class CrossModalFusion(nn.Module):
    """Bidirectional cross-attention fusion head for vision patches and text tokens.

    Both modalities are linearly projected to ``fusion_dim``. Vision patches
    attend to the text tokens, then the text tokens attend to the updated vision
    patches (each step with a residual connection and LayerNorm). Each side is
    mean-pooled; the pooled vectors are concatenated, passed through a small FFN
    and classified.

    Args:
        vision_dim: Channel size of the incoming vision patch features.
        text_dim: Hidden size of the incoming text token features.
        fusion_dim: Shared embedding size used by the attention layers.
        n_heads: Number of attention heads (must divide ``fusion_dim``).
        num_classes: Number of output logits.
        dropout: Dropout probability in the attention layers and the FFN.
    """

    def __init__(self, vision_dim, text_dim, fusion_dim=256, n_heads=4, num_classes=2, dropout=0.1):
        super().__init__()
        # Project both modalities to the same width.
        self.vision_proj = nn.Linear(vision_dim, fusion_dim)
        self.text_proj = nn.Linear(text_dim, fusion_dim)

        # Bidirectional cross-modal attention: vision queries -> text, text queries -> vision.
        self.attn_v2t = nn.MultiheadAttention(embed_dim=fusion_dim, num_heads=n_heads, dropout=dropout, batch_first=True)
        self.attn_t2v = nn.MultiheadAttention(embed_dim=fusion_dim, num_heads=n_heads, dropout=dropout, batch_first=True)

        # Residual + LayerNorm per modality, then a small feed-forward fusion layer.
        self.norm_v = nn.LayerNorm(fusion_dim)
        self.norm_t = nn.LayerNorm(fusion_dim)
        self.ffn = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Classification head.
        self.classifier = nn.Linear(fusion_dim, num_classes)

    def forward(self, vision_feat, text_feat, text_mask=None):
        """Fuse the two modalities and return class logits.

        Args:
            vision_feat: (B, N_v, vision_dim) patch features.
            text_feat: (B, N_t, text_dim) token features.
            text_mask: Optional (B, N_t) mask in the Hugging Face
                ``attention_mask`` convention (nonzero/True = real token,
                0/False = padding). When given, padded tokens are ignored as
                attention keys and excluded from the text mean-pool. When
                ``None``, every token is used.

        Returns:
            (B, num_classes) logits.
        """
        # 1) Project to the shared width.
        v = self.vision_proj(vision_feat)  # (B, N_v, D)
        t = self.text_proj(text_feat)      # (B, N_t, D)

        if text_mask is None:
            valid = None
            key_padding_mask = None
        else:
            valid = text_mask.to(torch.bool)
            # A row with no valid token would make every attention score -inf
            # (NaN after softmax); treat such rows as "use all tokens" instead.
            valid = valid | ~valid.any(dim=1, keepdim=True)
            key_padding_mask = ~valid  # MultiheadAttention: True = ignore this key

        # 2) Vision -> Text cross-attention (query = v, key/value = t).
        v2t, _ = self.attn_v2t(query=v, key=t, value=t, key_padding_mask=key_padding_mask)
        v = self.norm_v(v + v2t)

        # 3) Text -> Vision cross-attention (query = t, key/value = v).
        t2v, _ = self.attn_t2v(query=t, key=v, value=v)
        t = self.norm_t(t + t2v)

        # 4) Pool each modality and concatenate.
        v_pool = v.mean(dim=1)  # (B, D)
        if valid is None:
            t_pool = t.mean(dim=1)  # (B, D)
        else:
            # masked_fill (not multiplication) so values at padded positions cannot leak in.
            keep = valid.unsqueeze(-1)  # (B, N_t, 1)
            t_sum = t.masked_fill(~keep, 0.0).sum(dim=1)
            t_pool = t_sum / keep.sum(dim=1).to(t.dtype)  # count >= 1 by construction
        fused = torch.cat([v_pool, t_pool], dim=1)  # (B, 2D)
        fused = self.ffn(fused)                     # (B, D)

        # 5) Classify.
        return self.classifier(fused)               # (B, num_classes)


class FusionDataset(Dataset):
    """Pairs each row's image with its tokenized text and integer label.

    Args:
        df: DataFrame with ``id``, ``label`` and ``text`` columns.
        img_dir: Directory holding ``<int(id)>.png`` images.
        tokenizer: Hugging Face tokenizer applied to ``row["text"]``.
        max_len: Texts are padded/truncated to exactly this many tokens.
        transform: Callable applied to the RGB PIL image.
    """

    def __init__(self, df: pd.DataFrame, img_dir, tokenizer, max_len, transform):
        self.df = df
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with "image", "input_ids", "attention_mask" and "label" tensors."""
        row = self.df.iloc[idx]
        id_ = row["id"]
        label = row["label"]

        # Image: open inside a context manager so the file handle is released.
        img_path = os.path.join(self.img_dir, f"{int(id_)}.png")
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        # Text: fixed-length encoding; squeeze drops the batch dim added by return_tensors="pt".
        encoding = self.tokenizer(
            row["text"], padding="max_length", truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )
        return {
            "image": image,
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }

# ============== Model ==============
class VisionTextCrossAttentionModel(nn.Module):
    """Vision backbone + text encoder + ``CrossModalFusion`` classification head.

    Args:
        vision_backbone: Module mapping images to (B, N_patch, patch_dim) patches.
        text_model: Module called as ``text_model(input_ids=..., attention_mask=...)``
            whose output has a ``last_hidden_state`` of shape (B, N_txt, text_hidden).
        text_hidden: Hidden size of the text encoder output.
        patch_dim: Channel size of the vision patches.
        num_classes: Number of output logits.
        fusion_dim: Shared width inside the fusion head.
        n_heads: Attention heads inside the fusion head.
    """

    def __init__(self, vision_backbone, text_model, text_hidden, patch_dim, num_classes=2,
                 fusion_dim=256, n_heads=4):
        super().__init__()
        self.vision_backbone = vision_backbone
        self.text_model = text_model

        self.fusion = CrossModalFusion(
            vision_dim=patch_dim,
            text_dim=text_hidden,
            fusion_dim=fusion_dim,
            n_heads=n_heads,
            num_classes=num_classes
        )

    def forward(self, image, input_ids, attention_mask):
        # 1) Vision patches.
        vision_feat = self.vision_backbone(image)              # (B, N_patch, C)
        # 2) Text tokens.
        text_out = self.text_model(input_ids=input_ids,
                                   attention_mask=attention_mask)
        text_feat = text_out.last_hidden_state                 # (B, N_txt, D)
        # 3) Fuse & classify; the mask keeps padding tokens out of the fusion.
        return self.fusion(vision_feat, text_feat, text_mask=attention_mask)

# ============== Helpers ==============
class PatchExtractor(nn.Module):
    """Run a CNN feature extractor and flatten its spatial grid into a patch sequence.

    (B, 3, H, W) -> backbone -> (B, C, h, w) -> (B, h*w, C).
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):
        x = self.backbone(x)                 # (B, C, h, w)
        return x.flatten(2).transpose(1, 2)  # (B, N_patch = h*w, C)


def _infer_patch_dim(extractor, img_size):
    """Return the channel size of the patches ``extractor`` emits for a square image.

    The probe runs in eval mode: a forward pass in train mode updates BatchNorm
    running statistics even under ``torch.no_grad()``, which would blend the
    dummy input into the (pretrained) statistics. The previous train/eval mode
    is restored afterwards. A zero tensor is used so that the probe does not
    consume the global RNG.
    """
    first_param = next(extractor.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    was_training = extractor.training
    extractor.eval()
    try:
        with torch.no_grad():
            feat = extractor(torch.zeros(1, 3, img_size, img_size, device=device))  # (1, N_patch, C)
    finally:
        extractor.train(was_training)
    return feat.shape[-1]


def get_vision_model_patch(name, img_size=224, patch_size=32, pretrained=True):
    """Build a torchvision CNN that returns spatial patch features, plus their channel size.

    Supported families (matched by substring of ``name``): ResNet, DenseNet,
    EfficientNet, MobileNetV3, ShuffleNetV2 and VGG. ``name`` must be a
    torchvision ``models`` constructor name.

    Args:
        name: torchvision model constructor name, e.g. "resnet18".
        img_size: Side length of the square dummy input used to infer ``patch_dim``.
        patch_size: Unused; kept for backward compatibility. The patch grid is
            determined by the backbone's total stride.
        pretrained: Passed through to the torchvision constructor.

    Returns:
        (PatchExtractor, patch_dim): the module maps (B, 3, H, W) to
        (B, N_patch, patch_dim).

    Raises:
        NotImplementedError: for an unsupported model family.
    """
    # 1. Build the convolutional trunk (without the classifier head).
    if 'resnet' in name:
        base = getattr(models, name)(pretrained=pretrained)
        backbone_layers = nn.Sequential(*list(base.children())[:-2])  # drop avgpool & fc
    elif 'densenet' in name:
        base = getattr(models, name)(pretrained=pretrained)
        # `features` ends with BatchNorm (norm5); torchvision's DenseNet.forward
        # applies a ReLU before pooling, so add it to match the pretrained model.
        backbone_layers = nn.Sequential(base.features, nn.ReLU(inplace=True))
    elif 'efficientnet' in name:
        base = getattr(models, name)(pretrained=pretrained)
        backbone_layers = base.features
    elif 'mobilenet' in name:
        base = getattr(models, name)(pretrained=pretrained)
        backbone_layers = base.features
    elif 'shufflenet' in name:
        base = getattr(models, name)(pretrained=pretrained)
        # children: conv1, maxpool, stage2, stage3, stage4, conv5, fc -> drop fc
        backbone_layers = nn.Sequential(*list(base.children())[:-1])
    elif 'vgg' in name:
        base = getattr(models, name)(pretrained=pretrained)
        backbone_layers = base.features  # conv blocks incl. their MaxPools
    else:
        raise NotImplementedError(f"get_vision_model_patch 不支持模型: {name}")

    # 2. Wrap it so it returns a patch sequence, then 3. infer the patch width.
    backbone = PatchExtractor(backbone_layers)
    patch_dim = _infer_patch_dim(backbone, img_size)
    return backbone, patch_dim

def load_data(text_source) -> pd.DataFrame:
    """Join the label spreadsheet with one text source's .txt files.

    Column 0 of LABEL_FILE is read as the sample id (zero-padded to at least
    three characters) and column 3 as the label. A text file is matched when
    its stem equals the id without leading zeros, e.g. ``1.txt`` <-> ``"001"``.
    Rows lacking a label or a text are dropped.

    Args:
        text_source: Name of a sub-directory of TEXT_ROOT.

    Returns:
        DataFrame with columns ``id`` (str), ``label`` (int) and ``text`` (str).
    """
    sheet = pd.read_excel(LABEL_FILE, dtype={0: str})
    labels = sheet.iloc[:, [0, 3]].dropna()
    labels.columns = ['id', 'label']
    labels["id"] = labels["id"].astype(str).str.zfill(3)

    # File stems carry no leading zeros, so compare against the stripped ids.
    unpadded_ids = set(labels["id"].str.lstrip("0"))
    source_dir = Path(TEXT_ROOT) / text_source
    texts = {
        path.stem.zfill(3): path.read_text(encoding='utf-8', errors='ignore')
        for path in source_dir.glob("*.txt")
        if path.stem in unpadded_ids
    }

    matched = labels[labels["id"].isin(list(texts))]
    return (
        matched.assign(text=matched["id"].map(texts))
        .dropna()
        .astype({"label": int})
    )

def evaluate(loader, model=None, criterion=None, device=None):
    """Evaluate a binary classifier on ``loader`` without gradient tracking.

    Args:
        loader: DataLoader yielding dicts with "image", "input_ids",
            "attention_mask" and "label" tensors.
        model: Model called as ``model(image, input_ids, attention_mask)``.
            Defaults to the module-level ``model`` (the original behaviour,
            kept so ``evaluate(loader)`` still works).
        criterion: Loss applied to the logits; defaults to the module-level
            ``criterion``.
        device: Device the batch tensors are moved to; defaults to ``DEVICE``.

    Returns:
        dict with "loss" (sample-weighted mean over ``loader.dataset``) plus
        accuracy, balanced_accuracy, precision, recall, auc, mcc and f1.
        Class 1 is the positive class and is predicted when its softmax
        probability is >= 0.5. AUC is reported as 0.5 when the targets
        contain a single class.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if model is None:
        model = globals()["model"]
    if criterion is None:
        criterion = globals()["criterion"]
    if device is None:
        device = DEVICE

    model.eval()
    preds, targets, losses = [], [], []
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            out = model(image, input_ids, attention_mask)
            # A diverged model can emit NaN/inf logits; replace them with 0 so
            # the loss and softmax stay finite. isfinite (rather than isnan)
            # also catches +/-inf, which would otherwise produce NaN downstream.
            if not torch.isfinite(out).all():
                out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
            loss = criterion(out, labels)
            losses.append(loss.item() * image.size(0))

            probs = torch.softmax(out, dim=1)[:, 1]
            # Defensive: any remaining NaN probability becomes the neutral 0.5.
            probs = torch.nan_to_num(probs, nan=0.5)
            preds.append(probs.cpu())
            targets.append(labels.cpu())

    if not preds:
        raise ValueError("evaluate: loader yielded no batches")

    preds = torch.cat(preds).numpy()
    targets = torch.cat(targets).numpy()
    y_hat = (preds >= 0.5).astype(int)
    has_both_classes = len(np.unique(targets)) > 1

    metrics = {
        'loss': sum(losses) / len(loader.dataset),
        'accuracy': accuracy_score(targets, y_hat),
        'balanced_accuracy': balanced_accuracy_score(targets, y_hat),
        'precision': precision_score(targets, y_hat, zero_division=0),
        'recall': recall_score(targets, y_hat, zero_division=0),
        # AUC is undefined for a single class; report the chance value instead.
        'auc': roc_auc_score(targets, preds) if (has_both_classes and not np.isnan(preds).any()) else 0.5,
        'mcc': matthews_corrcoef(targets, y_hat),
        'f1': f1_score(targets, y_hat, zero_division=0),
    }
    return metrics

def train_eval(model, train_loader, val_loader, optimizer, criterion, patience=10, epochs=None, device=None):
    """Train ``model`` and early-stop on validation F1.

    After each epoch both loaders are evaluated with ``evaluate``. Training
    stops once validation F1 has not improved for ``patience`` consecutive
    epochs, or after ``epochs`` epochs.

    Args:
        model, optimizer, criterion: The model, its optimizer and the loss.
        train_loader, val_loader: Loaders yielding the dict batches ``evaluate`` expects.
        patience: Epochs without validation-F1 improvement before stopping.
        epochs: Maximum number of epochs; defaults to ``EPOCHS``.
        device: Device for the batches; defaults to ``DEVICE``.

    Returns:
        (best_val_metrics, best_train_metrics, last_val_metrics,
        last_train_metrics, epochs_run). The "best" pair comes from the first
        epoch with the highest validation F1 and is populated whenever at least
        one epoch ran, even if validation F1 stayed 0 throughout.
    """
    epochs = EPOCHS if epochs is None else epochs
    device = DEVICE if device is None else device

    # -inf so the first epoch is always recorded (0.0 with a strict ">" left
    # the best metrics empty whenever validation F1 was 0 for every epoch).
    best_val_f1 = float("-inf")
    best_val_metrics = {}
    best_train_metrics = {}
    last_val_metrics = {}
    last_train_metrics = {}

    no_improve_epochs = 0
    epoch = 0  # returned as-is if no epoch runs

    for epoch in range(1, epochs + 1):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(
                batch["image"].to(device),
                batch["input_ids"].to(device),
                batch["attention_mask"].to(device)
            )
            loss = criterion(out, batch["label"].to(device))
            loss.backward()
            optimizer.step()

        train_metrics = evaluate(train_loader, model, criterion, device)
        val_metrics = evaluate(val_loader, model, criterion, device)

        last_train_metrics = train_metrics
        last_val_metrics = val_metrics

        print(f"Ep{epoch:02d} | "
              f"tr_f1={train_metrics['f1']:.3f} "
              f"val_f1={val_metrics['f1']:.3f} "
              f"val_auc={val_metrics['auc']:.3f} "
              f"val_mcc={val_metrics['mcc']:.3f}")

        if val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            best_val_metrics = val_metrics.copy()
            best_train_metrics = train_metrics.copy()
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"⏹️ Early stopping triggered at epoch {epoch}")
                break

    return best_val_metrics, best_train_metrics, last_val_metrics, last_train_metrics, epoch

def measure_inference_latency(model, loader, device, max_samples=50, warmup=5):
    """Measure the average forward-pass time per sample, in milliseconds.

    ``warmup`` untimed forward passes are run once, on the first batch. After
    that, each batch is timed with one forward pass until at least
    ``max_samples`` samples have been timed or the loader is exhausted. The
    result is the total timed duration divided by the number of samples, so it
    is a batch-amortised per-sample latency. CUDA work is synchronised around
    each timed pass.

    Args:
        model: Called as ``model(image, input_ids, attention_mask)``; put in eval mode.
        loader: Iterable of dicts with "image", "input_ids", "attention_mask".
        device: Target device, as a string or ``torch.device``.
        max_samples: Stop after at least this many samples have been timed.
        warmup: Number of untimed forward passes before timing starts.

    Returns:
        Milliseconds per sample, or ``float("nan")`` if the loader is empty.
    """
    is_cuda = str(device).startswith("cuda")

    def _sync():
        if is_cuda:
            torch.cuda.synchronize()

    model.eval()
    total_time = 0.0
    total_num = 0
    warmed_up = False
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            # Warm up once (kernel selection, caches), not before every batch.
            if not warmed_up:
                for _ in range(warmup):
                    model(image, input_ids, attn_mask)
                warmed_up = True

            _sync()
            t0 = time.perf_counter()  # monotonic, high-resolution interval clock
            model(image, input_ids, attn_mask)
            _sync()
            total_time += time.perf_counter() - t0
            total_num += image.size(0)

            if total_num >= max_samples:
                break

    if total_num == 0:
        return float("nan")
    return total_time / total_num * 1000.0  # ms per sample

# ============== Main Loop ==============
Path(OUTPUT_DIR).mkdir(exist_ok=True)

summary_path = Path(OUTPUT_DIR) / "summary_results.csv"
summary_records = []

# Total runs: text sources x text models x vision models x 2 vision inits x 2 text inits.
total_runs = len(TEXT_SOURCES) * len(TEXT_MODELS) * len(VISION_MODELS) * 2 * 2
run_count = 0
global_start_time = time.time()

for text_source in TEXT_SOURCES:
    df = load_data(text_source)
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=SEED, stratify=df["label"])

    # Normalisation statistics are computed on the training split only.
    mean, std = compute_mean_std(train_df, IMG_DIR, IMG_SIZE)
    print(f"[{text_source}] Train set mean: {mean}, std: {std}")

    # Image preprocessing shared by the train and validation datasets.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])

    for text_model_name in TEXT_MODELS:
        # The tokenizer (also used for the scratch text model) and the datasets
        # depend only on the text model name, so build them once here instead
        # of reloading them for every vision model / init combination.
        tokenizer = AutoTokenizer.from_pretrained(text_model_name)
        train_ds = FusionDataset(train_df, IMG_DIR, tokenizer, MAX_LEN, transform)
        val_ds = FusionDataset(val_df, IMG_DIR, tokenizer, MAX_LEN, transform)

        for vision_model_name in VISION_MODELS:
            for vision_init in ("pretrained", "scratch"):
                for text_init in ("pretrained", "scratch"):
                    run_count += 1
                    tag = f"{text_source}__{text_model_name.replace('/', '_')}--{text_init}__{vision_model_name}--{vision_init}"
                    print(f"\n[{run_count}/{total_runs}] 🔄 Running: {tag}")

                    # Re-seed every run so a configuration's weight init and
                    # shuffling do not depend on how many runs came before it.
                    random.seed(SEED)
                    np.random.seed(SEED)
                    torch.manual_seed(SEED)

                    # Text encoder: pretrained weights, or random init from the same config.
                    if text_init == "pretrained":
                        text_model = AutoModel.from_pretrained(text_model_name).to(DEVICE)
                    else:
                        config = AutoConfig.from_pretrained(text_model_name)
                        text_model = AutoModel.from_config(config).to(DEVICE)

                    # Vision backbone wrapped to emit patch sequences.
                    vision_backbone, vision_dim = get_vision_model_patch(
                        vision_model_name, img_size=IMG_SIZE, patch_size=32,
                        pretrained=(vision_init == "pretrained"),
                    )

                    model = VisionTextCrossAttentionModel(
                        vision_backbone, text_model, text_model.config.hidden_size, vision_dim, num_classes=2
                    ).to(DEVICE)

                    total_params = sum(p.numel() for p in model.parameters())

                    # Loaders and optimizer for this run.
                    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
                    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
                    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
                    criterion = nn.CrossEntropyLoss()

                    start_time = time.time()
                    best_val, best_train, last_val, last_train, epoch_run = train_eval(
                        model, train_loader, val_loader, optimizer, criterion
                    )
                    elapsed = round(time.time() - start_time, 2)

                    # Per-sample inference latency on the validation loader.
                    latency_ms = measure_inference_latency(model, val_loader, DEVICE)

                    record = {
                        "src_text": text_source,
                        "model": text_model_name + "+" + vision_model_name,
                        "init": f"text={text_init}, vision={vision_init}",
                        "best_val_f1": best_val.get("f1", 0),
                        "best_val_auc": best_val.get("auc", 0),
                        "best_val_mcc": best_val.get("mcc", 0),
                        "best_val_bal_acc": best_val.get("balanced_accuracy", 0),
                        "best_val_precision": best_val.get("precision", 0),
                        "best_val_recall": best_val.get("recall", 0),
                        "best_train_f1": best_train.get("f1", 0),
                        "best_train_auc": best_train.get("auc", 0),
                        "best_train_mcc": best_train.get("mcc", 0),
                        "best_train_bal_acc": best_train.get("balanced_accuracy", 0),
                        "best_train_precision": best_train.get("precision", 0),
                        "best_train_recall": best_train.get("recall", 0),
                        "train_loss_last": last_train.get("loss", 0),
                        "val_loss_last": last_val.get("loss", 0),
                        "epochs_run": epoch_run,
                        "seconds": elapsed,
                        "params_M": total_params / 1e6,
                        "infer_latency_ms": latency_ms,
                    }

                    summary_records.append(record)

                    # Rewrite the CSV after every run so an interruption loses at most one run.
                    pd.DataFrame(summary_records).to_csv(summary_path, index=False)
                    avg_time = (time.time() - global_start_time) / run_count
                    remaining = (total_runs - run_count) * avg_time
                    eta_min = int(remaining // 60)
                    eta_sec = int(remaining % 60)

                    print(f"✅ Saved to {summary_path}")
                    print(f"⏱️ Elapsed: {elapsed:.1f}s | ⏳ ETA: ~{eta_min}m {eta_sec}s")

                    # Drop this run's model/optimizer before the next run allocates
                    # new ones, so peak GPU memory holds one model rather than two.
                    del model, optimizer, text_model, vision_backbone, train_loader, val_loader
                    if DEVICE == 'cuda':
                        torch.cuda.empty_cache()