import os
import time
import random
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from torchvision import models, transforms
from PIL import Image
from sklearn.model_selection import train_test_split

from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score,
    recall_score, roc_auc_score, f1_score, matthews_corrcoef
)
import numpy as np
from transformers import AutoConfig

import warnings
warnings.filterwarnings("ignore")

# ================= Hyperparameters =================
# --- data locations ---
LABEL_FILE = './苹果称重-1-7.xlsx'
IMG_DIR = './MNR_figure'
TEXT_ROOT = './MNR_CL_text_right'

# Every sub-directory of TEXT_ROOT is one text source; sources are visited in sorted name order.
text_root_dir = Path(TEXT_ROOT)
TEXT_SOURCES = sorted(entry.name for entry in text_root_dir.iterdir() if entry.is_dir())

# --- backbones swept by the main loop ---
# Text encoders: names passed to transformers' AutoModel / AutoTokenizer / AutoConfig.
TEXT_MODELS = [
    "bert-base-uncased",
    "roberta-base",
    "distilbert-base-uncased",
    "albert-base-v2",
]

# Vision encoders: constructor names looked up on torchvision.models.
VISION_MODELS = [
    "resnet18",
    "resnet34",
    "resnet50",
    "densenet121",
    "efficientnet_b0",
    "mobilenet_v3_small",
    "shufflenet_v2_x1_0",
    "vgg16",
]

# --- training setup ---
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
MAX_LEN = 512       # tokenizer max_length (texts are padded/truncated to this length)
IMG_SIZE = 224      # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 50         # upper bound on epochs; early stopping may end a run sooner
LR = 2e-4           # Adam learning rate
SEED = 42
OUTPUT_DIR = './logs_fusion_ms'
# ===================================================

# Seed the torch and Python RNGs once at import time.
torch.manual_seed(SEED)
random.seed(SEED)

# ============== Dataset ==============

def compute_mean_std(df, img_dir, img_size):
    """Compute per-channel pixel mean and standard deviation over the images of ``df``.

    For every row, ``<img_dir>/<int(id)>.png`` is converted to RGB, resized to
    ``img_size x img_size`` and scaled to [0, 1]. Statistics are accumulated in a
    single streaming pass (per-channel sum and sum of squares). Memory use
    therefore stays constant instead of holding every image in RAM at once.

    Args:
        df: DataFrame with an ``id`` column.
        img_dir: directory containing the PNG files.
        img_size: side length images are resized to before accumulation.

    Returns:
        ``(mean, std)``: two float64 arrays of shape (3,) in RGB order. ``std`` is
        the population standard deviation (ddof=0), matching ``np.std`` defaults.

    Raises:
        ValueError: if ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("compute_mean_std: df contains no rows")

    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    pixel_count = 0
    for sample_id in df["id"]:
        img_path = os.path.join(img_dir, f"{int(sample_id)}.png")
        # `with` closes the file as soon as the resized RGB copy exists.
        with Image.open(img_path) as img:
            resized = img.convert('RGB').resize((img_size, img_size))
        pixels = np.asarray(resized, dtype=np.float64) / 255.0  # normalise to [0, 1]
        flat = pixels.reshape(-1, 3)  # (H*W, C)
        channel_sum += flat.sum(axis=0)
        channel_sq_sum += np.square(flat).sum(axis=0)
        pixel_count += flat.shape[0]

    mean = channel_sum / pixel_count
    # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by float rounding.
    var = np.maximum(channel_sq_sum / pixel_count - np.square(mean), 0.0)
    std = np.sqrt(var)
    return mean, std

class FusionDataset(Dataset):
    """Map-style dataset yielding one (image, tokenised text, label) sample per DataFrame row.

    ``df`` must provide ``id``, ``text`` and ``label`` columns. The image for a row
    is ``<img_dir>/<int(id)>.png``, loaded as RGB and passed through ``transform``.
    The text is padded/truncated to ``max_len`` tokens by ``tokenizer``.
    """

    def __init__(self, df: pd.DataFrame, img_dir, tokenizer, max_len, transform):
        self.df = df
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with "image", "input_ids", "attention_mask" and a long "label" tensor."""
        row = self.df.iloc[idx]  # positional lookup, independent of the DataFrame index
        sample_id = row["id"]
        label = row["label"]

        # Image: the `with` block closes the file once the RGB copy is materialised,
        # rather than leaving the handle to the garbage collector.
        img_path = os.path.join(self.img_dir, f"{int(sample_id)}.png")
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        # Text: fixed-length encoding; squeeze(0) drops the batch dim added by return_tensors="pt".
        encoding = self.tokenizer(
            row["text"], padding="max_length", truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )
        return {
            "image": image,
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }

# ============== Model ==============
class VisionTextModel(nn.Module):
    """Late-fusion classifier over an image backbone and a transformers-style text encoder.

    The flattened vision features are concatenated with a 256-d projection of the
    text encoder's sentence vector. The result goes through a small MLP head that
    produces ``num_classes`` logits.
    """

    def __init__(self, vision_backbone, text_model, text_hidden, vision_output_dim=None, num_classes=2):
        """
        Args:
            vision_backbone: module mapping an image batch to features (any rank;
                flattened per sample in ``forward``).
            text_model: callable as ``text_model(input_ids=..., attention_mask=...)``
                returning an object with ``last_hidden_state`` and optionally
                ``pooler_output``.
            text_hidden: hidden size of ``text_model`` (input size of the projection).
            vision_output_dim: flattened vision feature size; inferred with a dummy
                IMG_SIZE x IMG_SIZE forward pass when None.
            num_classes: number of output logits.
        """
        super().__init__()
        self.vision_backbone = vision_backbone
        self.text_model = text_model
        self.text_proj = nn.Linear(text_hidden, 256)

        # Infer the vision output dimension automatically.
        if vision_output_dim is None:
            vision_output_dim = self._infer_vision_dim(vision_backbone)

        self.fusion = nn.Sequential(
            nn.Linear(256 + vision_output_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    @staticmethod
    def _infer_vision_dim(vision_backbone):
        """Return the flattened feature size the backbone produces for one RGB image."""
        was_training = vision_backbone.training
        # Probe in eval mode. In train mode BatchNorm would fold this random dummy
        # batch into its running statistics (torch.no_grad does not prevent that),
        # corrupting pretrained statistics.
        vision_backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(next(vision_backbone.parameters()).device)
                feat = vision_backbone(dummy)
        finally:
            vision_backbone.train(was_training)  # restore the caller's mode
        return feat.reshape(1, -1).shape[1]

    def forward(self, image, input_ids, attention_mask):
        """Return fusion-head logits of shape (batch, num_classes)."""
        vision_feat = self.vision_backbone(image)
        if vision_feat.ndim > 2:
            # flatten() also works on non-contiguous outputs, unlike view().
            vision_feat = vision_feat.flatten(1)
        text_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)

        # Prefer the pooled output when the encoder provides one (e.g. BERT). Some
        # outputs lack the field (e.g. DistilBERT) or set it to None; in that case
        # fall back to the first token's hidden state ([CLS]-style).
        pooled = getattr(text_out, "pooler_output", None)
        text_feat_raw = pooled if pooled is not None else text_out.last_hidden_state[:, 0]
        text_feat = self.text_proj(text_feat_raw)

        fused = torch.cat((vision_feat, text_feat), dim=1)
        return self.fusion(fused)

# ============== Helpers ==============
def get_vision_model(name, pretrained=True):
    """Build a torchvision backbone with its ImageNet classification head removed.

    Args:
        name: constructor name in ``torchvision.models`` (e.g. ``"resnet18"``).
        pretrained: load ImageNet weights when True, random init otherwise.

    Returns:
        An ``nn.Module`` producing per-image features. Most families are wrapped in
        ``nn.Flatten`` so the output is (batch, features).

    Raises:
        NotImplementedError: for architecture families not handled below.
    """
    # NOTE: `pretrained=` is the legacy torchvision keyword (newer releases prefer
    # `weights=`); kept so older torchvision installs keep working.
    model = getattr(models, name)(pretrained=pretrained)
    if 'resnet' in name or 'resnext' in name or 'wide_resnet' in name:
        model.fc = nn.Identity()
        return nn.Sequential(model, nn.Flatten())
    elif 'shufflenet' in name:
        # ShuffleNetV2 keeps its head in `fc` and has no `classifier` attribute.
        # Assigning `classifier` would be a no-op that leaves the 1000-way ImageNet
        # head in place.
        model.fc = nn.Identity()
        return nn.Sequential(model, nn.Flatten())
    elif 'densenet' in name or 'mobilenet' in name or 'efficientnet' in name:
        model.classifier = nn.Identity()
        return nn.Sequential(model, nn.Flatten())
    elif 'vgg' in name:
        # Drop only the final Linear so the penultimate activations are returned.
        model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])
        return model
    elif 'squeezenet' in name or 'mnasnet' in name:
        # NOTE(review): for squeezenet the replaced `classifier` also held the final
        # pooling, so features here are the flattened feature map — confirm intended.
        model.classifier = nn.Identity()
        return nn.Sequential(model, nn.Flatten())
    else:
        raise NotImplementedError(name)

def load_data(text_source) -> pd.DataFrame:
    """Join the label sheet with the texts of one text source.

    Reads columns 0 (id) and 3 (label) of LABEL_FILE, drops incomplete rows, and
    zero-pads ids to at least three characters. It then attaches the contents of
    ``TEXT_ROOT/<text_source>/<unpadded id>.txt`` (e.g. ``1.txt`` for id ``001``).
    Rows without a matching text file are dropped.

    Args:
        text_source: name of a sub-directory of TEXT_ROOT.

    Returns:
        DataFrame with columns ``id`` (str), ``label`` (int) and ``text`` (str).
    """
    label_df = pd.read_excel(LABEL_FILE, dtype={0: str})
    label_df = label_df.iloc[:, [0, 3]].dropna()
    label_df.columns = ['id', 'label']
    label_df["id"] = label_df["id"].astype(str).str.zfill(3)

    # Text files are named by the id without leading zeros (e.g. "1", "42").
    # Map each unpadded form back to the exact id string(s) used in the sheet.
    # Re-padding the stem with zfill(3) would silently drop wider ids such as
    # "0001". An all-zero id maps to "0", matching int(id) used for image names.
    ids_by_stem = {}
    for sheet_id in label_df["id"]:
        ids_by_stem.setdefault(sheet_id.lstrip("0") or "0", set()).add(sheet_id)

    text_dir = Path(TEXT_ROOT) / text_source
    texts = {}
    for txt_file in text_dir.glob("*.txt"):
        matching_ids = ids_by_stem.get(txt_file.stem)
        if not matching_ids:
            continue
        content = txt_file.read_text(encoding='utf-8', errors='ignore')
        for sheet_id in matching_ids:
            texts[sheet_id] = content

    df = label_df[label_df["id"].isin(texts.keys())].copy()
    df["text"] = df["id"].map(texts)
    df = df.dropna()
    df["label"] = df["label"].astype(int)
    return df

def evaluate(loader, model=None, criterion=None, device=None):
    """Run a model over ``loader`` without gradients and return binary-classification metrics.

    Args:
        loader: DataLoader yielding dicts with "image", "input_ids",
            "attention_mask" and "label" tensors.
        model: network to evaluate. When None, falls back to the module-level
            ``model`` (the original behaviour, kept for backward compatibility).
        criterion: loss function; when None, falls back to the module-level ``criterion``.
        device: device batches are moved to; defaults to ``DEVICE``.

    Returns:
        dict with 'loss' (average per-sample loss), 'accuracy', 'balanced_accuracy',
        'precision', 'recall', 'auc', 'mcc' and 'f1'. Predictions are the softmax
        probability of class 1 thresholded at 0.5. 'auc' is 0.5 when the targets
        contain a single class.

    Leaves ``model`` in eval mode.
    """
    if model is None:
        model = globals()["model"]
    if criterion is None:
        criterion = globals()["criterion"]
    if device is None:
        device = DEVICE

    model.eval()
    preds, targets, losses = [], [], []
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            out = model(image, input_ids, attention_mask)
            loss = criterion(out, labels)
            # Turn the batch loss back into a sum so a short final batch is weighted
            # correctly. This assumes the criterion averages over the batch, as
            # nn.CrossEntropyLoss does by default.
            losses.append(loss.item() * image.size(0))

            probs = torch.softmax(out, dim=1)[:, 1]  # binary task: P(class 1)
            preds.append(probs.cpu())
            targets.append(labels.cpu())

    preds = torch.cat(preds).numpy()
    targets = torch.cat(targets).numpy()
    y_hat = (preds >= 0.5).astype(int)

    metrics = {
        'loss':              sum(losses) / len(loader.dataset),
        'accuracy':          accuracy_score(targets, y_hat),
        'balanced_accuracy': balanced_accuracy_score(targets, y_hat),
        'precision':         precision_score(targets, y_hat, zero_division=0),
        'recall':            recall_score(targets, y_hat, zero_division=0),
        'auc':               roc_auc_score(targets, preds) if len(np.unique(targets)) > 1 else 0.5,
        'mcc':               matthews_corrcoef(targets, y_hat),
        'f1':                f1_score(targets, y_hat, zero_division=0),
    }
    return metrics

def train_eval(model, train_loader, val_loader, optimizer, criterion, patience=10):
    """Train for up to EPOCHS epochs with early stopping on validation F1.

    After every epoch, both splits are evaluated. Training stops once validation
    F1 has not strictly improved for ``patience`` consecutive epochs.

    Returns:
        ``(best_val_metrics, best_train_metrics, last_val_metrics,
        last_train_metrics, epochs_run)``. The "best" dicts come from the epoch
        with the highest validation F1; the "last" dicts come from the final epoch.
    """
    # Start below any achievable F1 so the first epoch is always recorded, even
    # when validation F1 is 0 (otherwise AUC/MCC of such runs were lost).
    best_val_f1 = float("-inf")
    best_val_metrics = {}
    best_train_metrics = {}
    last_val_metrics = {}
    last_train_metrics = {}

    no_improve_epochs = 0
    epoch = 0  # defined even if EPOCHS == 0

    for epoch in range(1, EPOCHS + 1):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(
                batch["image"].to(DEVICE),
                batch["input_ids"].to(DEVICE),
                batch["attention_mask"].to(DEVICE)
            )
            loss = criterion(out, batch["label"].to(DEVICE))
            loss.backward()
            optimizer.step()

        # Evaluate the model being trained, not whatever the module-level `model` is.
        train_metrics = evaluate(train_loader, model, criterion)
        val_metrics = evaluate(val_loader, model, criterion)

        last_train_metrics = train_metrics
        last_val_metrics = val_metrics

        print(f"Ep{epoch:02d} | "
              f"tr_f1={train_metrics['f1']:.3f} "
              f"val_f1={val_metrics['f1']:.3f} "
              f"val_auc={val_metrics['auc']:.3f} "
              f"val_mcc={val_metrics['mcc']:.3f}")

        if val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            best_val_metrics = val_metrics.copy()
            best_train_metrics = train_metrics.copy()
            no_improve_epochs = 0  # reset counter
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"⏹️ Early stopping triggered at epoch {epoch}")
                break

    return best_val_metrics, best_train_metrics, last_val_metrics, last_train_metrics, epoch

def measure_inference_latency(model, loader, device, max_samples=50, warmup=5):
    """Estimate the average forward-pass latency per sample, in milliseconds.

    ``warmup`` untimed forward passes run once, on the first batch. After that,
    each batch gets one timed forward pass until at least ``max_samples`` samples
    have been timed; whole batches are used, so slightly more may be timed. On
    CUDA the device is synchronized around each timed call, so asynchronous
    kernels are included in the measurement.

    Args:
        model: callable as ``model(image, input_ids, attention_mask)``.
        loader: iterable of dicts with "image", "input_ids" and "attention_mask" tensors.
        device: target device, as a string (e.g. "cuda", "cpu") or ``torch.device``.
        max_samples: stop once this many samples have been timed.
        warmup: number of untimed forward passes before timing starts.

    Returns:
        Mean milliseconds per sample, or NaN when ``loader`` yields no batches.
    """
    model.eval()
    is_cuda = str(device).startswith("cuda")  # accepts both str and torch.device
    total_time = 0.0
    total_num = 0
    warmed_up = False
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            # Warm up once (caches, cuDNN autotuning, lazy init), not before every batch.
            if not warmed_up:
                for _ in range(warmup):
                    model(image, input_ids, attn_mask)
                warmed_up = True

            if is_cuda:
                torch.cuda.synchronize()
            t0 = time.perf_counter()  # monotonic, high-resolution interval timer
            model(image, input_ids, attn_mask)
            if is_cuda:
                torch.cuda.synchronize()
            total_time += time.perf_counter() - t0
            total_num += image.size(0)

            if total_num >= max_samples:
                break

    if total_num == 0:
        return float("nan")
    return total_time / total_num * 1000  # ms per sample

# ============== Main Loop ==============
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

summary_path = Path(OUTPUT_DIR) / "summary_results.csv"
summary_records = []

# Total number of runs: text source x text model x vision model x
# vision init (2) x text init (2).
total_runs = len(TEXT_SOURCES) * len(TEXT_MODELS) * len(VISION_MODELS) * 2 * 2
run_count = 0
global_start_time = time.time()

for text_source in TEXT_SOURCES:
    df = load_data(text_source)
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=SEED, stratify=df["label"])

    # Normalisation statistics come from the training split only.
    mean, std = compute_mean_std(train_df, IMG_DIR, IMG_SIZE)
    print(f"[{text_source}] Train set mean: {mean}, std: {std}")

    # Image preprocessing shared by both splits.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])

    for text_model_name in TEXT_MODELS:
        # The tokenizer depends only on text_model_name (not on the init mode or the
        # vision backbone), so load it once per text model instead of once per run.
        tokenizer = AutoTokenizer.from_pretrained(text_model_name)

        for vision_model_name in VISION_MODELS:
            for vision_init in ("pretrained", "scratch"):
                for text_init in ("pretrained", "scratch"):
                    run_count += 1
                    tag = f"{text_source}__{text_model_name.replace('/', '_')}--{text_init}__{vision_model_name}--{vision_init}"
                    print(f"\n[{run_count}/{total_runs}] 🔄 Running: {tag}")

                    # Release the previous run's networks before loading new ones, so two
                    # sets of weights are never resident on the device at the same time.
                    # The last run's objects remain bound after the loop.
                    model = text_model = vision_backbone = optimizer = None
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    # Reseed per run so weight init, dropout and shuffling depend only on
                    # this configuration, not on how many runs preceded it.
                    random.seed(SEED)
                    torch.manual_seed(SEED)

                    # Text encoder: pretrained weights, or the same architecture randomly initialised.
                    if text_init == "pretrained":
                        text_model = AutoModel.from_pretrained(text_model_name).to(DEVICE)
                    else:
                        config = AutoConfig.from_pretrained(text_model_name)
                        text_model = AutoModel.from_config(config).to(DEVICE)

                    # Vision backbone (classification head removed by get_vision_model).
                    vision_backbone = get_vision_model(vision_model_name, pretrained=(vision_init == "pretrained")).to(DEVICE)

                    # Fusion model.
                    model = VisionTextModel(vision_backbone, text_model, text_model.config.hidden_size).to(DEVICE)

                    # Parameter count.
                    total_params = sum(p.numel() for p in model.parameters())

                    # Data and optimiser.
                    train_ds = FusionDataset(train_df, IMG_DIR, tokenizer, MAX_LEN, transform)
                    val_ds = FusionDataset(val_df, IMG_DIR, tokenizer, MAX_LEN, transform)
                    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
                    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
                    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
                    criterion = nn.CrossEntropyLoss()

                    start_time = time.time()
                    best_val, best_train, last_val, last_train, epoch_run = train_eval(model, train_loader, val_loader, optimizer, criterion)
                    elapsed = round(time.time() - start_time, 2)

                    # Per-sample inference latency on the validation loader.
                    latency_ms = measure_inference_latency(model, val_loader, DEVICE)

                    record = {
                            "src_text": text_source,
                            "model": text_model_name + "+" + vision_model_name,
                            "init": f"text={text_init}, vision={vision_init}",
                            "best_val_f1":  best_val.get("f1", 0),
                            "best_val_auc": best_val.get("auc", 0),
                            "best_val_mcc": best_val.get("mcc", 0),
                            "best_val_bal_acc": best_val.get("balanced_accuracy", 0),
                            "best_val_precision": best_val.get("precision", 0),
                            "best_val_recall":    best_val.get("recall", 0),
                            "best_train_f1":  best_train.get("f1", 0),
                            "best_train_auc": best_train.get("auc", 0),
                            "best_train_mcc": best_train.get("mcc", 0),
                            "best_train_bal_acc": best_train.get("balanced_accuracy", 0),
                            "best_train_precision": best_train.get("precision", 0),
                            "best_train_recall":    best_train.get("recall", 0),
                            "train_loss_last": last_train.get("loss", 0),
                            "val_loss_last":   last_val.get("loss", 0),
                            "epochs_run": epoch_run,
                            "seconds": elapsed,
                            "params_M": total_params / 1e6,
                            "infer_latency_ms": latency_ms,
                        }

                    summary_records.append(record)

                    # Rewrite the CSV after every run so an interruption loses at most one result.
                    pd.DataFrame(summary_records).to_csv(summary_path, index=False)
                    avg_time = (time.time() - global_start_time) / run_count
                    remaining = (total_runs - run_count) * avg_time
                    eta_min = int(remaining // 60)
                    eta_sec = int(remaining % 60)

                    print(f"✅ Saved to {summary_path}")
                    print(f"⏱️ Elapsed: {elapsed:.1f}s | ⏳ ETA: ~{eta_min}m {eta_sec}s")