import os,gc
import time
import random
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from torchvision import models, transforms
from PIL import Image
from sklearn.model_selection import train_test_split

from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score,
    recall_score, roc_auc_score, f1_score, matthews_corrcoef
)
import numpy as np
from transformers import AutoConfig

import warnings
warnings.filterwarnings("ignore")

# ================= Hyperparameters =================
# Input locations: label spreadsheet, image folder, root folder of the text variants.
LABEL_FILE = './苹果称重-1-7.xlsx'
IMG_DIR = './MNR_figure'
TEXT_ROOT = './MNR_CL_text_right'

# Every sub-directory of TEXT_ROOT counts as one text source; names are sorted
# so the order is stable across runs.
text_root_dir = Path(TEXT_ROOT)
TEXT_SOURCES = sorted(entry.name for entry in text_root_dir.iterdir() if entry.is_dir())

# Runtime / training settings.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MAX_LEN = 512      # tokenizer max_length (inputs are padded/truncated to this)
IMG_SIZE = 224     # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 50
LR = 2e-4
SEED = 42
OUTPUT_DIR = './logs_fusion_Gate'
# ====================================================

# Seed the Python and torch RNGs once at import time.
random.seed(SEED)
torch.manual_seed(SEED)

# ============== Dataset ==============

def compute_mean_std(df, img_dir, img_size):
    """Compute per-channel mean and std over the images listed in ``df``.

    For every row, ``row['id']`` is cast to int and resolved to
    ``<img_dir>/<id>.png``. The image is converted to RGB, resized to
    ``img_size`` x ``img_size`` and scaled by 1/255.

    Returns:
        ``(mean, std)``: two arrays with one value per channel, reduced over
        all images and all pixels.
    """
    def _load_scaled(sample_id):
        # One image as an (H, W, C) float array in [0, 1].
        path = os.path.join(img_dir, f"{int(sample_id)}.png")
        picture = Image.open(path).convert('RGB').resize((img_size, img_size))
        return np.array(picture) / 255.0

    # (N, H, W, C) -> (N, C, H, W) so the channel axis is the one left over.
    stacked = np.stack([_load_scaled(record['id']) for _, record in df.iterrows()], axis=0)
    channels_first = stacked.transpose((0, 3, 1, 2))

    reduce_axes = (0, 2, 3)
    return channels_first.mean(axis=reduce_axes), channels_first.std(axis=reduce_axes)

# ==== FiLM module (lightweight, interpretable) ====
class FiLM(nn.Module):
    """Feature-wise linear modulation of a feature map by a text feature.

    Two linear heads map the text feature to one scale (gamma) and one shift
    (beta) per channel; the output is ``gamma * x + beta``.
    """

    def __init__(self, text_dim, num_channels):
        super().__init__()
        self.gamma_gen = nn.Linear(text_dim, num_channels)
        self.beta_gen  = nn.Linear(text_dim, num_channels)

    def forward(self, x, text_feat):
        # x: (B, C, H, W), text_feat: (B, D)
        # Two trailing singleton axes let the per-channel values broadcast
        # over the spatial dimensions of x.
        scale = self.gamma_gen(text_feat)[..., None, None]
        shift = self.beta_gen(text_feat)[..., None, None]
        return scale * x + shift

# Ablation settings. Each entry carries four boolean switches; apart from
# 'none' (everything off), an entry is named after the single switch it enables.
_ABLATION_FLAGS = ('mask_text', 'mask_vision', 'noise_text', 'noise_vision')
ABLATIONS = {
    name: {flag: flag == name for flag in _ABLATION_FLAGS}
    for name in ('none', 'mask_vision', 'mask_text', 'noise_vision', 'noise_text')
}


class FusionDataset(Dataset):
    """Image + text samples backed by a DataFrame.

    Each row is read through its ``id``, ``text`` and ``label`` fields. The
    image for a row is ``<img_dir>/<int(id)>.png``.
    """

    def __init__(self, df: pd.DataFrame, img_dir, tokenizer, max_len, transform):
        self.df = df
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        sample_id, target = record["id"], record["label"]

        # Image: open, force RGB, then apply the caller-supplied transform.
        picture = Image.open(os.path.join(self.img_dir, f"{int(sample_id)}.png"))
        picture = self.transform(picture.convert("RGB"))

        # Text: pad/truncate to max_len; tensors come back with a leading
        # batch axis of size 1, which is squeezed away below.
        tokens = self.tokenizer(
            record["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )

        sample = {"image": picture}
        for key in ("input_ids", "attention_mask"):
            sample[key] = tokens[key].squeeze(0)
        sample["label"] = torch.tensor(target, dtype=torch.long)
        return sample

class AblationFusionDataset(torch.utils.data.Dataset):
    """Wrap ``FusionDataset`` and corrupt one modality according to ``ablation_cfg``.

    ``ablation_cfg`` is a dict with the boolean keys ``mask_text``,
    ``noise_text``, ``mask_vision`` and ``noise_vision``:

    * ``mask_text``   - every token id is replaced by the tokenizer's mask id.
    * ``noise_text``  - each position is replaced by the mask id with
      probability ``text_mask_prob``.
    * ``mask_vision`` - the image tensor is replaced by zeros.
    * ``noise_vision``- Gaussian noise with std ``noise_std`` is added.

    Args:
        base_df, img_dir, tokenizer, max_len, transform: forwarded to
            ``FusionDataset``.
        ablation_cfg: the switch dict described above.
        noise_std: std of the additive image noise.
        text_mask_prob: per-position masking probability for ``noise_text``.
        clamp_range: optional ``(low, high)`` applied after adding image
            noise. Defaults to ``None`` (no clamping): the image comes out of
            ``transform``, and when that transform standardises the pixels
            (as a mean/std ``Normalize`` does) values legitimately fall
            outside [0, 1], so a fixed [0, 1] clamp would destroy the image
            rather than perturb it. Pass ``(0.0, 1.0)`` only for
            un-normalised inputs.

    Raises:
        ValueError: if a text ablation is requested but the tokenizer has no
            mask token.
    """

    def __init__(self, base_df, img_dir, tokenizer, max_len, transform, ablation_cfg,
                 noise_std=0.1, text_mask_prob=0.15, clamp_range=None):
        self.base_ds = FusionDataset(base_df, img_dir, tokenizer, max_len, transform)
        self.cfg = ablation_cfg
        self.mask_tok = tokenizer.mask_token_id
        self.noise_std = noise_std
        self.text_mask_prob = text_mask_prob
        self.clamp_range = clamp_range

        needs_mask_token = ablation_cfg['mask_text'] or ablation_cfg['noise_text']
        if needs_mask_token and self.mask_tok is None:
            # Fail here with a clear message instead of deep inside torch.full_like.
            raise ValueError(
                "Text ablation requested but the tokenizer defines no mask token."
            )

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        # Shallow copy: never write into a dict the base dataset might reuse.
        item = dict(self.base_ds[idx])

        # Text: full mask or random partial mask.
        if self.cfg['mask_text']:
            item['input_ids'] = torch.full_like(item['input_ids'], self.mask_tok)
        elif self.cfg['noise_text']:
            ids = item['input_ids'].clone()
            hit = torch.rand_like(ids, dtype=torch.float) < self.text_mask_prob
            ids[hit] = self.mask_tok
            item['input_ids'] = ids

        # Vision: blank image or additive Gaussian noise.
        if self.cfg['mask_vision']:
            item['image'] = torch.zeros_like(item['image'])
        elif self.cfg['noise_vision']:
            noisy = item['image'] + torch.randn_like(item['image']) * self.noise_std
            if self.clamp_range is not None:
                low, high = self.clamp_range
                noisy = torch.clamp(noisy, low, high)
            item['image'] = noisy
        return item

# ============== Model ==============
class FusionModel(nn.Module):
    """Classifier whose visual feature map is FiLM-modulated by a text feature.

    Pipeline: vision backbone -> feature map; text model -> hidden state at
    position 0; FiLM(feature map, text feature); adaptive average pool to
    1x1; linear classifier.
    """

    def __init__(self, vision_backbone, text_model, text_hidden, num_channels, num_classes=2):
        super().__init__()
        self.vision_backbone = vision_backbone  # expected to output (B, C, H, W)
        self.text_model = text_model
        self.film = FiLM(text_hidden, num_channels)
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.classifier = nn.Linear(num_channels, num_classes)

    def forward(self, image, input_ids, attention_mask):
        # Visual branch.
        visual_map = self.vision_backbone(image)

        # Text branch: keep only the first position ([CLS]) of the last hidden state.
        encoded = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = encoded.last_hidden_state[:, 0, :]

        # Modulate, pool to one vector per sample, classify.
        modulated = self.film(visual_map, cls_embedding)
        pooled_vec = torch.flatten(self.pool(modulated), 1)
        return self.classifier(pooled_vec)

# ============== Helpers ==============

def count_params(model):
    """Return the total element count of the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad=False are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

def get_vision_backbone_and_channels(name, pretrained=True, img_size=None):
    """Build a torchvision model without its last child and report its channel count.

    Args:
        name: attribute name of a model constructor in ``torchvision.models``.
        pretrained: forwarded to that constructor.
        img_size: side length of the square probe input; defaults to the
            module-level ``IMG_SIZE``.

    Returns:
        ``(backbone, num_channels)`` where ``backbone`` is an ``nn.Sequential``
        of all children except the last one and ``num_channels`` is dimension 1
        of the backbone's output on a ``(1, 3, img_size, img_size)`` input.
    """
    base = getattr(models, name)(pretrained=pretrained)
    # Drop the final child (the classification layer).
    backbone = nn.Sequential(*list(base.children())[:-1])

    side = IMG_SIZE if img_size is None else img_size

    # Probe the output shape in eval mode. A freshly built module is in
    # training mode, where a forward pass updates BatchNorm running statistics
    # (corrupting pretrained stats with probe data) and where some norm layers
    # reject a batch of one. Zeros are used so the probe does not consume the
    # global RNG stream either.
    was_training = backbone.training
    backbone.eval()
    try:
        with torch.no_grad():
            feat = backbone(torch.zeros(1, 3, side, side))
    finally:
        backbone.train(was_training)

    return backbone, feat.shape[1]

def load_data(text_source) -> pd.DataFrame:
    """Join the label spreadsheet with the text files of one text source.

    Columns 0 and 3 of ``LABEL_FILE`` become ``id`` and ``label``; ids are
    zero-padded to three characters. Text files are taken from
    ``TEXT_ROOT/<text_source>/*.txt`` and matched on their stem, which is
    compared against the ids with leading zeros stripped.

    Returns:
        DataFrame with columns ``id`` (str), ``label`` (int) and ``text``,
        containing only the ids for which a text file was found.
    """
    labels = pd.read_excel(LABEL_FILE, dtype={0: str}).iloc[:, [0, 3]].dropna()
    labels.columns = ['id', 'label']
    labels["id"] = labels["id"].astype(str).str.zfill(3)

    # File stems carry no leading zeros (e.g. "1", "42").
    wanted_stems = set(labels["id"].str.lstrip("0"))
    source_dir = Path(TEXT_ROOT) / text_source

    # Key the texts by the padded id (e.g. "001") so they line up with the sheet.
    texts = {
        path.stem.zfill(3): path.read_text(encoding='utf-8', errors='ignore')
        for path in source_dir.glob("*.txt")
        if path.stem in wanted_stems
    }

    merged = labels[labels["id"].isin(texts.keys())].copy()
    merged["text"] = merged["id"].map(texts)
    merged = merged.dropna()
    merged["label"] = merged["label"].astype(int)
    return merged

def evaluate(loader, model, criterion=None):
    """Run ``model`` over ``loader`` without gradients and compute binary metrics.

    Args:
        loader: iterable of batches holding "image", "input_ids",
            "attention_mask" and "label" tensors; ``loader.dataset`` is used to
            average the loss per sample.
        model: called as ``model(image, input_ids, attention_mask)``; column 1
            of its softmaxed output is used as the positive-class probability.
        criterion: loss callable. When omitted, a module-level ``criterion`` is
            used if one is defined (the name this function previously read
            implicitly); otherwise ``nn.CrossEntropyLoss()`` is created.

    Returns:
        dict with the keys ``loss``, ``accuracy``, ``balanced_accuracy``,
        ``precision``, ``recall``, ``auc``, ``mcc`` and ``f1``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Resolve the loss explicitly instead of hitting a NameError when no
    # global `criterion` happens to exist.
    loss_fn = criterion if criterion is not None else globals().get("criterion")
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    model.eval()
    preds, targets, losses = [], [], []
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(DEVICE)
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)

            out = model(image, input_ids, attention_mask)
            # Patch: if the logits contain NaN, replace NaN/inf entries by 0
            # so the loss and probabilities stay finite.
            if torch.isnan(out).any():
                out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)

            # Batch loss weighted by batch size; divided by dataset size below.
            loss = loss_fn(out, labels)
            losses.append(loss.item() * image.size(0))

            # Positive-class probability; any remaining NaN becomes the neutral 0.5.
            probs = torch.softmax(out, dim=1)[:, 1]
            probs = torch.nan_to_num(probs, nan=0.5)
            preds.append(probs.cpu())
            targets.append(labels.cpu())

    if not preds:
        raise ValueError("evaluate() received a loader that yielded no batches")

    preds = torch.cat(preds).numpy()
    targets = torch.cat(targets).numpy()
    y_hat = (preds >= 0.5).astype(int)

    # AUC is undefined with a single class (or NaN scores); report 0.5 then.
    auc_defined = len(np.unique(targets)) > 1 and not np.isnan(preds).any()

    metrics = {
        'loss':   sum(losses) / len(loader.dataset),
        'accuracy': accuracy_score(targets, y_hat),
        'balanced_accuracy': balanced_accuracy_score(targets, y_hat),
        'precision': precision_score(targets, y_hat, zero_division=0),
        'recall':    recall_score(targets, y_hat, zero_division=0),
        'auc':   roc_auc_score(targets, preds) if auc_defined else 0.5,
        'mcc':   matthews_corrcoef(targets, y_hat),
        'f1':    f1_score(targets, y_hat, zero_division=0),
    }
    return metrics

def train_eval(model, train_loader, val_loader, optimizer, criterion, patience=20):
    """Train for up to ``EPOCHS`` epochs with early stopping on validation F1.

    After every epoch the model is scored on both loaders with ``evaluate``.
    Note that ``evaluate`` is called with the loader and the model only; the
    ``criterion`` argument here drives the training loss.

    Args:
        model: called as ``model(image, input_ids, attention_mask)``.
        train_loader, val_loader: batch iterables with "image", "input_ids",
            "attention_mask" and "label".
        optimizer: stepped once per training batch.
        criterion: loss used for back-propagation.
        patience: number of consecutive epochs without a strict validation-F1
            improvement after which training stops.

    Returns:
        ``(best_val_metrics, best_train_metrics, last_val_metrics,
        last_train_metrics, epoch)`` where ``epoch`` is the last epoch that
        ran (0 if none did).
    """
    best_val_f1 = 0.0
    best_val_metrics = {}
    best_train_metrics = {}
    last_val_metrics = {}
    last_train_metrics = {}

    no_improve_epochs = 0
    epoch = 0  # keeps the return value defined when EPOCHS < 1

    for epoch in range(1, EPOCHS + 1):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(
                batch["image"].to(DEVICE),
                batch["input_ids"].to(DEVICE),
                batch["attention_mask"].to(DEVICE)
            )
            loss = criterion(out, batch["label"].to(DEVICE))
            loss.backward()
            optimizer.step()

        train_metrics = evaluate(train_loader,model)
        val_metrics = evaluate(val_loader,model)

        last_train_metrics = train_metrics
        last_val_metrics = val_metrics

        print(f"Ep{epoch:02d} | "
              f"tr_f1={train_metrics['f1']:.3f} "
              f"val_f1={val_metrics['f1']:.3f} "
              f"val_auc={val_metrics['auc']:.3f} "
              f"val_mcc={val_metrics['mcc']:.3f}")

        improved = val_metrics["f1"] > best_val_f1
        if improved or not best_val_metrics:
            # Besides real improvements, always keep the first epoch as a
            # baseline: a run whose validation F1 never rises above 0 would
            # otherwise return empty "best" dicts and lose its real
            # AUC / MCC / balanced-accuracy values.
            best_val_f1 = max(best_val_f1, val_metrics["f1"])
            best_val_metrics = val_metrics.copy()
            best_train_metrics = train_metrics.copy()

        if improved:
            no_improve_epochs = 0  # reset counter
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"⏹️ Early stopping triggered at epoch {epoch}")
                break

    return best_val_metrics, best_train_metrics, last_val_metrics, last_train_metrics, epoch

def measure_inference_latency(model, loader, device, max_samples=50, warmup=5):
    """Measure the average forward-pass time per sample in milliseconds.

    Args:
        model: called as ``model(image, input_ids, attention_mask)``.
        loader: iterable of batches with "image", "input_ids" and
            "attention_mask" tensors.
        device: device string or ``torch.device`` the batches are moved to.
        max_samples: stop once at least this many samples have been timed.
        warmup: number of untimed forward passes run on the first batch.

    Returns:
        Mean latency in ms per sample, or ``nan`` if the loader yielded no
        samples.
    """
    model.eval()
    # Accept both 'cuda:0'-style strings and torch.device objects.
    on_cuda = torch.device(device).type == "cuda"

    total_time = 0.0
    total_num = 0
    warmed_up = False
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            # Warm up once, not before every batch: the warm-up passes are
            # untimed, so repeating them only multiplies the cost.
            if not warmed_up:
                for _ in range(warmup):
                    model(image, input_ids, attn_mask)
                warmed_up = True

            # Synchronise around the timed call so pending GPU work is not
            # attributed to (or missing from) the measurement.
            if on_cuda:
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            model(image, input_ids, attn_mask)
            if on_cuda:
                torch.cuda.synchronize()
            total_time += time.perf_counter() - t0
            total_num += image.size(0)

            if total_num >= max_samples:
                break

    if total_num == 0:
        return float("nan")
    return total_time / total_num * 1000  # ms per sample

# ============== Main Loop ==============
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Use the first (alphabetically) non-hidden sub-directory of TEXT_ROOT as the text source.
text_source_candidates = sorted(
    d.name for d in Path(TEXT_ROOT).iterdir()
    if d.is_dir() and not d.name.startswith('.')
)
if not text_source_candidates:
    raise FileNotFoundError(f"No text source directory found under {TEXT_ROOT}")
text_source = text_source_candidates[0]

df = load_data(text_source)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=SEED, stratify=df["label"])

# Normalisation statistics come from the training split only.
mean, std = compute_mean_std(train_df, IMG_DIR, IMG_SIZE)
print(f"[{text_source}] Train set mean: {mean}, std: {std}")

# Image preprocessing shared by the training and validation datasets.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
])

# Model configuration.
text_model_name   = "roberta-base"
vision_model_name = "resnet18"
text_init = "pretrained"
vision_init = "scratch"

# Bookkeeping. Every ablation is repeated N_REPEATS times, so the progress
# counter and the ETA must be based on the product, not on len(ABLATIONS).
N_REPEATS = 30
summary_records = []
summary_path = Path(OUTPUT_DIR) / f"Gate_ablation_metrics_{text_model_name}_{vision_model_name}.csv"
global_start = time.time()
total_runs = N_REPEATS * len(ABLATIONS)
run_count = 0

tokenizer = AutoTokenizer.from_pretrained(text_model_name)

device = DEVICE
for i in range(N_REPEATS):
    for abl_name, cfg in ABLATIONS.items():
        run_count += 1
        print(f"\n==== Ablation: {abl_name} ({run_count}/{total_runs}) ====")

        vision_backbone, num_channels = get_vision_backbone_and_channels(
            vision_model_name, pretrained=(vision_init == "pretrained")
        )
        # Text model: pretrained weights, or random init from the same config.
        if text_init == "pretrained":
            text_model = AutoModel.from_pretrained(text_model_name)
        else:
            # Separate name on purpose: `cfg` is the ablation config that the
            # datasets below need and must not be overwritten.
            text_config = AutoConfig.from_pretrained(text_model_name)
            text_model = AutoModel.from_config(text_config)

        model = FusionModel(vision_backbone,
                             text_model,
                             text_model.config.hidden_size,
                             num_channels,
                             num_classes=2).to(DEVICE)

        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        # Kept as a module-level name: evaluate() looks up `criterion` globally.
        criterion = nn.CrossEntropyLoss()

        train_ds = AblationFusionDataset(train_df, IMG_DIR, tokenizer, MAX_LEN, transform, cfg)
        val_ds   = AblationFusionDataset(val_df,   IMG_DIR, tokenizer, MAX_LEN, transform, cfg)
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
        val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE)

        # Time the training/evaluation run, then measure inference latency.
        start_time = time.time()
        best_val, best_train, last_val, last_train, epoch_run = train_eval(
            model, train_loader, val_loader, optimizer, criterion
        )
        elapsed = round(time.time() - start_time, 2)
        latency_ms = measure_inference_latency(model, val_loader, device)
        total_params = count_params(model)

        record = {
            "src_text": text_source,
            "model": f"{text_model_name}+{vision_model_name}",
            "init": f"text={text_init}, vision={vision_init}",
            "input":f"{abl_name}",
            "best_val_f1":     best_val.get("f1", 0),
            "best_val_auc":    best_val.get("auc", 0),
            "best_val_mcc":    best_val.get("mcc", 0),
            "best_val_bal_acc":best_val.get("balanced_accuracy", 0),
            "best_val_precision": best_val.get("precision", 0),
            "best_val_recall":    best_val.get("recall", 0),
            "best_train_f1":    best_train.get("f1", 0),
            "best_train_auc":   best_train.get("auc", 0),
            "best_train_mcc":   best_train.get("mcc", 0),
            "best_train_bal_acc": best_train.get("balanced_accuracy", 0),
            "best_train_precision": best_train.get("precision", 0),
            "best_train_recall":    best_train.get("recall", 0),
            "train_loss_last":  last_train.get("loss", 0),
            "val_loss_last":    last_val.get("loss", 0),
            "epochs_run":       epoch_run,
            "seconds":          elapsed,
            "params_M":         round(total_params / 1e6, 3),
            "infer_latency_ms": round(latency_ms, 2),
        }
        summary_records.append(record)

        # Rewrite the CSV after every run so partial results survive a crash.
        pd.DataFrame(summary_records).to_csv(summary_path, index=False)
        avg_time = (time.time() - global_start) / run_count
        remaining = (total_runs - run_count) * avg_time
        eta_m, eta_s = divmod(int(remaining), 60)
        print(f"✅ Saved to {summary_path}")
        print(f"⏱️ Step time: {elapsed}s | ⏳ ETA: ~{eta_m}m {eta_s}s")

        # Drop every reference to the parameters (the optimizer holds them too),
        # collect garbage first, then hand the cached blocks back to the driver.
        del model, text_model, vision_backbone, optimizer
        gc.collect()
        torch.cuda.empty_cache()

# Done.
print("\nAll ablations completed.")