import time,gc
import random
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig
from torchvision import transforms
from sklearn.model_selection import train_test_split

from fusion_train_CA_7_26_v1 import (
    IMG_DIR, TEXT_ROOT, MAX_LEN, IMG_SIZE,
    BATCH_SIZE, LR, SEED, OUTPUT_DIR,
    compute_mean_std, load_data,
    get_vision_model_patch, VisionTextCrossAttentionModel,
    train_eval, measure_inference_latency
)

# Fix the random seeds for Python's `random` module and for PyTorch.
random.seed(SEED)
torch.manual_seed(SEED)

# For testing: override the EPOCHS attribute on the training module itself.
# NOTE(review): this only has an effect if the module's functions (presumably
# train_eval) look EPOCHS up as a module global at call time — confirm.
import fusion_train_CA_7_26_v1
fusion_train_CA_7_26_v1.EPOCHS = 50

def count_params(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted; frozen
    parameters are skipped.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Ablation settings. Every entry carries the same four switches; each named
# ablation turns on exactly one of them, and 'none' is the unmodified baseline.
_NO_ABLATION = {
    'mask_text': False,
    'mask_vision': False,
    'noise_text': False,
    'noise_vision': False,
}

ABLATIONS = {
    'none': dict(_NO_ABLATION),
    'mask_vision': {**_NO_ABLATION, 'mask_vision': True},
    'mask_text': {**_NO_ABLATION, 'mask_text': True},
    'noise_vision': {**_NO_ABLATION, 'noise_vision': True},
    'noise_text': {**_NO_ABLATION, 'noise_text': True},
}

class AblationFusionDataset(torch.utils.data.Dataset):
    """Wrap ``FusionDataset`` and apply one modality ablation per sample.

    Each sample is fetched from the wrapped dataset and then, depending on
    ``ablation_cfg``, its ``'input_ids'`` and/or ``'image'`` entries are
    replaced:

    * ``mask_text``    - every token id becomes the tokenizer's mask token id.
    * ``noise_text``   - each token id is independently replaced by the mask
      token id with probability ``mask_prob``. Every position is eligible,
      including whatever special/padding ids the base dataset emits.
    * ``mask_vision``  - the image becomes an all-zero tensor.
    * ``noise_vision`` - zero-mean Gaussian noise with standard deviation
      ``noise_std`` is added to the image.

    For each modality the ``mask_*`` switch takes precedence over ``noise_*``.

    Args:
        base_df, img_dir, tokenizer, max_len, transform: Forwarded unchanged
            to ``FusionDataset``.
        ablation_cfg: Mapping with the boolean keys ``'mask_text'``,
            ``'noise_text'``, ``'mask_vision'`` and ``'noise_vision'``.
        noise_std: Standard deviation of the additive image noise.
        mask_prob: Per-token replacement probability for ``noise_text``.

    Raises:
        ValueError: If a text ablation is requested but the tokenizer has no
            mask token.
    """

    def __init__(self, base_df, img_dir, tokenizer, max_len, transform, ablation_cfg,
                 noise_std=0.1, mask_prob=0.15):
        from fusion_train_CA_7_26_v1 import FusionDataset
        self.base_ds = FusionDataset(base_df, img_dir, tokenizer, max_len, transform)
        self.cfg = ablation_cfg
        self.mask_tok = tokenizer.mask_token_id
        needs_mask_token = ablation_cfg.get('mask_text') or ablation_cfg.get('noise_text')
        if needs_mask_token and self.mask_tok is None:
            # Fail at construction time instead of with an opaque error deep
            # inside the first DataLoader batch.
            raise ValueError(
                "Text ablation requested but the tokenizer defines no mask token"
            )
        self.noise_std = noise_std
        self.mask_prob = mask_prob

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        # Shallow-copy the sample and only ever build new tensors below, so the
        # wrapped dataset's data is never modified even if it caches samples.
        # NOTE(review): assumes the base dataset returns a dict-like sample.
        item = dict(self.base_ds[idx])

        # Text masking / noise
        if self.cfg['mask_text']:
            item['input_ids'] = torch.full_like(item['input_ids'], self.mask_tok)
        elif self.cfg['noise_text']:
            ids = item['input_ids']
            replace = torch.rand_like(ids, dtype=torch.float) < self.mask_prob
            # masked_fill is out-of-place: the original ids tensor is untouched.
            item['input_ids'] = ids.masked_fill(replace, self.mask_tok)

        # Vision masking / noise
        if self.cfg['mask_vision']:
            item['image'] = torch.zeros_like(item['image'])
        elif self.cfg['noise_vision']:
            noise = torch.randn_like(item['image']) * self.noise_std
            # Deliberately no clamp to [0, 1]: the image is whatever `transform`
            # produced, and the pipeline used in this file ends with Normalize,
            # so legitimate pixel values are negative as well as positive.
            # Clamping would wipe out every below-mean pixel.
            item['image'] = item['image'] + noise
        return item


def main(num_repeats: int = 30):
    """Run every ablation in ``ABLATIONS`` ``num_repeats`` times and log metrics.

    For each run a fresh vision backbone, text model and cross-attention
    fusion model are built, trained and evaluated through ``train_eval``,
    timed, and summarised as one row. All rows collected so far are rewritten
    to a CSV under ``OUTPUT_DIR`` after every run, so partial results are on
    disk if a later run fails.

    Args:
        num_repeats: How many times the full set of ablations is repeated.
            Defaults to 30, the count previously hard-coded in the loop.

    Raises:
        FileNotFoundError: If ``TEXT_ROOT`` contains no visible sub-directory
            to use as the text source.
    """
    # Text source: the alphabetically first non-hidden directory in TEXT_ROOT.
    text_sources = sorted(
        d.name for d in Path(TEXT_ROOT).iterdir()
        if d.is_dir() and not d.name.startswith('.')
    )
    if not text_sources:
        raise FileNotFoundError(f"No text-source directories found under {TEXT_ROOT}")
    text_source = text_sources[0]

    df = load_data(text_source)
    train_df, val_df = train_test_split(
        df, test_size=0.2, random_state=SEED, stratify=df["label"]
    )

    # Image mean/std are computed on the training split only.
    mean, std = compute_mean_std(train_df, IMG_DIR, IMG_SIZE)
    print(f"[{text_source}] Train set mean: {mean}, std: {std}")
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])

    # Model configuration
    text_model_name = "albert-base-v2"
    vision_model_name = "mobilenet_v3_small"
    text_init = "scratch"
    vision_init = "pretrained"

    tokenizer = AutoTokenizer.from_pretrained(text_model_name)
    # The config is loop-invariant; load it once rather than on every run.
    text_config = None if text_init == "pretrained" else AutoConfig.from_pretrained(text_model_name)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Bookkeeping. Create the output directory up front so the first CSV write
    # cannot fail after a complete training run.
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    summary_records = []
    summary_path = output_dir / f"CA_ablation_metrics_{text_model_name}_{vision_model_name}.csv"
    global_start = time.time()
    # Count every repeat, otherwise the progress counter overshoots and the
    # ETA turns negative after the first pass over ABLATIONS.
    total_runs = num_repeats * len(ABLATIONS)
    run_count = 0

    for repeat in range(num_repeats):
        for abl_name, cfg in ABLATIONS.items():
            run_count += 1
            print(
                f"\n==== Ablation: {abl_name} | repeat {repeat + 1}/{num_repeats} "
                f"({run_count}/{total_runs}) ===="
            )

            vision_backbone, vision_dim = get_vision_model_patch(
                vision_model_name, img_size=IMG_SIZE, pretrained=(vision_init == "pretrained")
            )

            # Text model: pretrained weights, or random init from the config.
            if text_init == "pretrained":
                text_model = AutoModel.from_pretrained(text_model_name)
            else:
                text_model = AutoModel.from_config(text_config)

            model = VisionTextCrossAttentionModel(
                vision_backbone, text_model, text_model.config.hidden_size, vision_dim, num_classes=2
            )
            model.to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=LR)
            criterion = nn.CrossEntropyLoss()

            # The same ablation config is applied to both splits.
            train_ds = AblationFusionDataset(train_df, IMG_DIR, tokenizer, MAX_LEN, transform, cfg)
            val_ds = AblationFusionDataset(val_df, IMG_DIR, tokenizer, MAX_LEN, transform, cfg)
            train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
            val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

            # Timed training + evaluation
            start_time = time.time()
            best_val, best_train, last_val, last_train, epoch_run = train_eval(
                model, train_loader, val_loader, optimizer, criterion
            )
            elapsed = round(time.time() - start_time, 2)
            latency_ms = measure_inference_latency(model, val_loader, device)
            total_params = count_params(model)

            record = {
                "src_text": text_source,
                "model": f"{text_model_name}+{vision_model_name}",
                "init": f"text={text_init}, vision={vision_init}",
                "input": abl_name,
                "best_val_f1": best_val.get("f1", 0),
                "best_val_auc": best_val.get("auc", 0),
                "best_val_mcc": best_val.get("mcc", 0),
                "best_val_bal_acc": best_val.get("balanced_accuracy", 0),
                "best_val_precision": best_val.get("precision", 0),
                "best_val_recall": best_val.get("recall", 0),
                "best_train_f1": best_train.get("f1", 0),
                "best_train_auc": best_train.get("auc", 0),
                "best_train_mcc": best_train.get("mcc", 0),
                "best_train_bal_acc": best_train.get("balanced_accuracy", 0),
                "best_train_precision": best_train.get("precision", 0),
                "best_train_recall": best_train.get("recall", 0),
                "train_loss_last": last_train.get("loss", 0),
                "val_loss_last": last_val.get("loss", 0),
                "epochs_run": epoch_run,
                "seconds": elapsed,
                "params_M": round(total_params / 1e6, 3),
                "infer_latency_ms": round(latency_ms, 2),
            }
            summary_records.append(record)

            # Save intermediate results after every run.
            pd.DataFrame(summary_records).to_csv(summary_path, index=False)
            avg_time = (time.time() - global_start) / run_count
            remaining = (total_runs - run_count) * avg_time
            eta_m, eta_s = divmod(int(remaining), 60)
            print(f"✅ Saved to {summary_path}")
            print(f"⏱️ Step time: {elapsed}s | ⏳ ETA: ~{eta_m}m {eta_s}s")

            # Drop every reference to this run's model before the next one is
            # built. The optimizer keeps references to the parameters and its
            # own state, so it has to go too. Collect garbage first so that
            # empty_cache() can also release memory held only by cycles.
            del model, text_model, vision_backbone, optimizer
            del train_loader, val_loader, train_ds, val_ds
            gc.collect()
            torch.cuda.empty_cache()

    # Done
    print("\nAll ablations completed.")

# Run the ablation study only when this file is executed as a script.
if __name__ == '__main__':
    main()
