from __future__ import annotations
import os, csv, random, time
from pathlib import Path
from typing import Tuple, List

import pandas as pd
import numpy as np
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score,
    recall_score, roc_auc_score, f1_score, matthews_corrcoef
)

# Optional dependency: try to import fvcore, which is only used for FLOPs counting.
# When it is missing, FlopCountAnalysis is set to None so callers can test for it.
try:
    from fvcore.nn import FlopCountAnalysis  # type: ignore
except ImportError:  # fvcore is not installed: FLOPs counting is skipped
    FlopCountAnalysis = None

# -----------------------------------------------------------------------------
# CONFIG — 修改这里即可
# -----------------------------------------------------------------------------
CONFIG = {
    # Directory that is scanned for ``*.png`` images.
    "DATA_DIR": Path("./MNR_figure"),
    # Excel sheet with the labels (column 0 = sample id, column 3 = class label).
    "LABEL_FILE": Path("./苹果称重-1-7.xlsx"),
    # Fraction of the matched images that goes into the training split.
    "TRAIN_RATIO": 0.8,
    # Images are resized to IMG_SIZE x IMG_SIZE.
    "IMG_SIZE": 224,
    # Training batch size (the validation loader uses twice this value).
    "BATCH_SIZE": 32,
    # Number of epochs each model/initialisation pair is trained for.
    "NUM_EPOCHS": 60,
    # Adam learning rate and weight decay.
    "LR": 2e-4,
    "WEIGHT_DECAY": 1e-4,
    # Use CUDA whenever it is available, otherwise fall back to the CPU.
    "DEVICE": "cpu" if not torch.cuda.is_available() else "cuda",
    # No DataLoader worker processes on Windows (os.name == "nt"), four elsewhere.
    "NUM_WORKERS": 4 if os.name != "nt" else 0,
    # torchvision architectures to benchmark, in this order.
    "MODELS": [
        "resnet18",
        "resnet34",
        "resnet50",
        "densenet121",
        "efficientnet_b0",
        "mobilenet_v3_small",
        "shufflenet_v2_x1_0",
        "vgg16",
    ],
    # Root directory for per-experiment logs and the summary workbook.
    "OUTPUT_DIR": Path("./logs_ms"),
    # Seed applied to the Python, NumPy and PyTorch random generators.
    "SEED": 42,
}

# -----------------------------------------------------------------------------
# 随机种子
# -----------------------------------------------------------------------------
# Seed Python's `random`, NumPy and PyTorch with CONFIG['SEED'] when the module
# is loaded. The train/validation split below is drawn with `random.shuffle`,
# so it depends on this seed.
random.seed(CONFIG['SEED'])
np.random.seed(CONFIG['SEED'])
torch.manual_seed(CONFIG['SEED'])

# -----------------------------------------------------------------------------
# 数据集
# -----------------------------------------------------------------------------
class AppleDataset(Dataset):
    """Dataset that pairs image files with integer class labels.

    Sample ``i`` is the image at ``paths[i]`` opened with PIL and converted to
    RGB, passed through ``transform`` when one is supplied, together with
    ``labels[i]`` wrapped in a ``torch.long`` tensor.
    """

    def __init__(self, paths: List[Path], labels: List[int], transform=None):
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        # One sample per image path.
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert('RGB')
        # The transform is optional; without it the PIL image is returned as is.
        sample = self.transform(image) if self.transform else image
        return sample, torch.tensor(self.labels[idx], dtype=torch.long)

# -----------------------------------------------------------------------------
# DataLoader 构建
# -----------------------------------------------------------------------------

def compute_mean_std(paths, img_size):
    """Compute the per-channel mean and standard deviation of a set of images.

    Every image is opened with PIL, converted to RGB, resized to
    ``img_size x img_size`` and scaled to the ``[0, 1]`` range. The statistics
    are accumulated image by image (running sum and sum of squares), so memory
    use does not grow with the number of images.

    Args:
        paths: Iterable of image file paths.
        img_size: Side length the images are resized to before measuring.

    Returns:
        ``(mean, std)``: two NumPy arrays of shape ``(3,)`` in R, G, B order.
        ``std`` is the population standard deviation (``ddof=0``).

    Raises:
        ValueError: If ``paths`` yields no image.
    """
    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    pixel_count = 0

    for p in paths:
        # The context manager closes the underlying file as soon as the pixels
        # have been decoded by convert().
        with Image.open(p) as raw:
            rgb = raw.convert('RGB').resize((img_size, img_size))
        pixels = np.asarray(rgb, dtype=np.float64).reshape(-1, 3) / 255.0
        channel_sum += pixels.sum(axis=0)
        channel_sq_sum += np.square(pixels).sum(axis=0)
        pixel_count += pixels.shape[0]

    if pixel_count == 0:
        raise ValueError('compute_mean_std() needs at least one image')

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clip guards against tiny negative values caused
    # by floating-point rounding on (near-)constant channels.
    variance = np.clip(channel_sq_sum / pixel_count - np.square(mean), 0.0, None)
    std = np.sqrt(variance)
    return mean, std

def build_dataloaders() -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation DataLoaders described by ``CONFIG``.

    Steps:
        1. Read the label sheet ``CONFIG['LABEL_FILE']`` (column 0 = sample id,
           zero-padded to three characters; column 3 = integer label).
        2. Keep every ``*.png`` in ``CONFIG['DATA_DIR']`` whose file stem is a
           known sample id.
        3. Shuffle with the global ``random`` generator and split according to
           ``CONFIG['TRAIN_RATIO']``.
        4. Measure mean/std on the training images only and use them to
           normalise both splits.

    Returns:
        ``(train_dl, val_dl)``. The training loader shuffles and applies random
        horizontal/vertical flips; the validation loader does neither and uses
        twice the training batch size.

    Raises:
        FileNotFoundError: If no image matches an id from the label sheet.
        ValueError: If the training split would be empty.
    """
    label_table = (
        pd.read_excel(CONFIG['LABEL_FILE'], dtype={0: str})
        .iloc[:, [0, 3]]
        .dropna()
    )
    sample_ids = label_table.iloc[:, 0].str.zfill(3)
    id_to_label = dict(zip(sample_ids, label_table.iloc[:, 1].astype(int)))

    paths, labels = [], []
    for p in sorted(CONFIG['DATA_DIR'].glob('*.png')):
        label = id_to_label.get(p.stem)
        if label is not None:
            paths.append(p)
            labels.append(label)
    # Explicit raise instead of `assert`, which is removed under `python -O`.
    if not paths:
        raise FileNotFoundError('未找到匹配的图像文件！')

    total = len(paths)
    train_size = int(total * CONFIG['TRAIN_RATIO'])
    if train_size == 0:
        raise ValueError(
            f"Training split is empty: {total} image(s) with "
            f"TRAIN_RATIO={CONFIG['TRAIN_RATIO']}"
        )
    indices = list(range(total))
    random.shuffle(indices)
    train_idx, val_idx = indices[:train_size], indices[train_size:]
    train_paths = [paths[i] for i in train_idx]
    val_paths = [paths[i] for i in val_idx]

    # Normalisation statistics come from the training split only, so nothing
    # about the validation images leaks into preprocessing.
    mean, std = compute_mean_std(train_paths, CONFIG['IMG_SIZE'])
    print(f"[INFO] mean: {mean}, std: {std}")
    normalize = transforms.Normalize(mean=mean.tolist(), std=std.tolist())

    target_size = (CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE'])
    train_tf = transforms.Compose([
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = AppleDataset(train_paths, [labels[i] for i in train_idx], train_tf)
    val_ds = AppleDataset(val_paths, [labels[i] for i in val_idx], val_tf)

    # Pinned host memory only helps host-to-GPU copies; requesting it without
    # CUDA just makes PyTorch emit a warning.
    pin_memory = torch.cuda.is_available()
    train_dl = DataLoader(train_ds, batch_size=CONFIG['BATCH_SIZE'], shuffle=True,
                          num_workers=CONFIG['NUM_WORKERS'], pin_memory=pin_memory)
    val_dl = DataLoader(val_ds, batch_size=CONFIG['BATCH_SIZE'] * 2, shuffle=False,
                        num_workers=CONFIG['NUM_WORKERS'], pin_memory=pin_memory)
    return train_dl, val_dl

# -----------------------------------------------------------------------------
# 模型工厂
# -----------------------------------------------------------------------------

def get_model(name: str, num_classes: int = 2, pretrained: bool = True):
    """Create a torchvision classifier with its final layer resized.

    Args:
        name: One of the supported torchvision constructor names (see
            ``weight_enums`` below).
        num_classes: Output size of the replacement classification layer.
        pretrained: Load the architecture's ``DEFAULT`` torchvision weights
            when True; use torchvision's random initialisation when False.

    Returns:
        The model with a fresh ``nn.Linear`` as its last classification layer.

    Raises:
        ValueError: If ``name`` is not a supported architecture.
    """
    # Constructor name -> name of its weights enum in ``torchvision.models``.
    weight_enums = {
        'resnet18': 'ResNet18_Weights',
        'resnet34': 'ResNet34_Weights',
        'resnet50': 'ResNet50_Weights',
        'densenet121': 'DenseNet121_Weights',
        'efficientnet_b0': 'EfficientNet_B0_Weights',
        'mobilenet_v3_small': 'MobileNet_V3_Small_Weights',
        'shufflenet_v2_x1_0': 'ShuffleNet_V2_X1_0_Weights',
        'vgg16': 'VGG16_Weights',
    }
    # Validate before any lookup so an unknown name gives the intended
    # ValueError rather than a KeyError from the table.
    if name not in weight_enums:
        raise ValueError(f"未支持的模型: {name}")

    weights = getattr(models, weight_enums[name]).DEFAULT if pretrained else None
    model = getattr(models, name)(weights=weights)

    # Swap the final Linear layer; where it lives depends on the architecture.
    if 'densenet' in name:
        in_dim = model.classifier.in_features
        model.classifier = nn.Linear(in_dim, num_classes)
    elif 'efficientnet' in name:
        in_dim = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_dim, num_classes)
    elif 'mobilenet' in name:
        in_dim = model.classifier[3].in_features
        model.classifier[3] = nn.Linear(in_dim, num_classes)
    elif 'shufflenet' in name or 'resnet' in name:
        in_dim = model.fc.in_features
        model.fc = nn.Linear(in_dim, num_classes)
    elif 'vgg' in name:
        in_dim = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(in_dim, num_classes)
    else:
        # Defensive: reached only if a name is added to the table above
        # without a matching head-replacement branch.
        raise ValueError(f"未支持的模型: {name}")
    return model

# -----------------------------------------------------------------------------
# 评估 / 训练
# -----------------------------------------------------------------------------

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; it is switched to training mode.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(model(inputs), targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        The loss averaged over the samples actually processed, or NaN when the
        loader yields no batch.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        batch_size = x.size(0)
        # Weight each batch mean by its size so a smaller last batch is not
        # over-represented in the epoch average.
        loss_sum += loss.item() * batch_size
        seen += batch_size
    # Divide by the samples seen rather than len(dataloader.dataset): the two
    # differ when the loader drops the last batch or samples a subset.
    return loss_sum / seen if seen else float('nan')


def evaluate(model, dataloader, criterion, device):
    """Evaluate a two-class classifier on ``dataloader`` without gradients.

    The softmax probability of class index 1 is used as the score; a sample is
    predicted positive when that probability is at least 0.5.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(model(inputs), targets)``.
        device: Device the batches are moved to.

    Returns:
        Dict with the keys ``loss``, ``accuracy``, ``balanced_accuracy``,
        ``precision``, ``recall``, ``auc``, ``mcc`` and ``f1``. ``loss`` is
        averaged over the evaluated samples; ``auc`` falls back to 0.5 when
        the targets contain a single class.
    """
    model.eval()
    prob_chunks, target_chunks = [], []
    loss_sum = 0.0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss_sum += criterion(out, y).item() * x.size(0)
            prob_chunks.append(torch.softmax(out, dim=1)[:, 1].cpu())
            target_chunks.append(y.cpu())
    preds = torch.cat(prob_chunks).numpy()
    targets = torch.cat(target_chunks).numpy()
    y_hat = (preds >= 0.5).astype(int)
    # ROC-AUC is undefined with a single class present; report chance level.
    has_both_classes = len(np.unique(targets)) > 1
    metrics = {
        # Normalise by the samples actually evaluated, which may differ from
        # len(dataloader.dataset) if the loader drops or subsamples data.
        'loss':              loss_sum / len(targets),
        'accuracy':          accuracy_score(targets, y_hat),
        'balanced_accuracy': balanced_accuracy_score(targets, y_hat),
        'precision':         precision_score(targets, y_hat, zero_division=0),
        'recall':            recall_score(targets, y_hat, zero_division=0),
        'auc':               roc_auc_score(targets, preds) if has_both_classes else 0.5,
        'mcc':               matthews_corrcoef(targets, y_hat),
        'f1':                f1_score(targets, y_hat, zero_division=0),
    }
    return metrics

# -----------------------------------------------------------------------------
# 统计 & CSV
# -----------------------------------------------------------------------------

def measure_inference_latency(model, dataloader, device, max_samples: int = 50, warmup: int = 5) -> float:
    """Return the forward-pass time in milliseconds per image.

    The first ``max_samples`` inputs yielded by ``dataloader`` are gathered
    into a single batch. After ``warmup`` untimed passes, one forward pass
    over that batch is timed and divided by the number of images in it.

    Args:
        model: Network to time; it is switched to eval mode.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        device: Device string (e.g. ``"cuda"``) or ``torch.device``.
        max_samples: Maximum number of images in the timed batch.
        warmup: Number of untimed forward passes run before the timed one.

    Returns:
        Milliseconds per image for the timed pass.
    """
    model.eval()
    batches, collected = [], 0
    for x, _ in dataloader:
        batches.append(x)
        collected += x.size(0)
        if collected >= max_samples:
            break
    imgs = torch.cat(batches)[:max_samples].to(device)

    # Accept both "cuda:0"-style strings and torch.device objects.
    on_cuda = torch.device(device).type == "cuda"

    with torch.no_grad():
        # Warm-up passes keep one-off start-up costs out of the measurement.
        for _ in range(warmup):
            model(imgs)
        # Wait for queued GPU work to finish so it is not counted below.
        if on_cuda:
            torch.cuda.synchronize()

        # perf_counter is the monotonic, high-resolution clock for intervals.
        t0 = time.perf_counter()
        model(imgs)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

    return elapsed / imgs.size(0) * 1000.0  # ms / image


def save_csv_header(path: Path):
    """Create ``path`` containing only the metrics header row.

    An existing file is left untouched.
    """
    if path.exists():
        return
    columns = [
        'epoch', 'loss', 'accuracy', 'balanced_accuracy', 'precision',
        'recall', 'auc', 'mcc', 'f1',
    ]
    with path.open('w', newline='') as handle:
        csv.writer(handle).writerow(columns)


def append_csv_row(path: Path, epoch: int, metrics: dict):
    """Append one row to ``path``: the epoch, then eight metrics to six decimals.

    The metrics are written in the order loss, accuracy, balanced_accuracy,
    precision, recall, auc, mcc, f1; other keys in ``metrics`` are ignored.
    """
    columns = (
        'loss', 'accuracy', 'balanced_accuracy', 'precision',
        'recall', 'auc', 'mcc', 'f1',
    )
    with path.open('a', newline='') as handle:
        row = [epoch]
        row.extend(f"{metrics[name]:.6f}" for name in columns)
        csv.writer(handle).writerow(row)


# -----------------------------------------------------------------------------
# 主入口
# -----------------------------------------------------------------------------

def _count_flops(model) -> int:
    """Return fvcore's total FLOPs figure for one IMG_SIZE x IMG_SIZE input.

    Returns 0 when fvcore is not installed or the analysis fails. The analysis
    runs in eval mode so the random dummy input neither updates BatchNorm
    running statistics nor activates dropout; the previous train/eval mode is
    restored afterwards.
    """
    if FlopCountAnalysis is None:
        return 0

    import warnings

    was_training = model.training
    model.eval()
    try:
        dummy = torch.randn(1, 3, CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]).to(CONFIG["DEVICE"])
        with torch.no_grad(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return int(FlopCountAnalysis(model, dummy).total())
    except Exception as exc:
        # FLOPs are a best-effort statistic: report the failure and carry on.
        print(f"[WARN] FLOPs analysis failed: {exc}")
        return 0
    finally:
        model.train(was_training)


def _run_experiment(model_name: str, init: str, out_dir: Path,
                    train_dl: DataLoader, val_dl: DataLoader) -> dict:
    """Train and evaluate one model/initialisation pair; return its summary row."""
    tag = f"{model_name}--{init}"
    print(f"\n===== {tag} =====")
    device = CONFIG["DEVICE"]

    # Experiment directory and per-epoch CSV log.
    exp_dir = out_dir / tag
    exp_dir.mkdir(parents=True, exist_ok=True)
    csv_path = exp_dir / "epoch_metrics.csv"
    # Start from a clean CSV: the model is retrained from scratch on every
    # run, and appending to an old log would put new epoch rows after the
    # previous run's trailing efficiency rows.
    if csv_path.exists():
        csv_path.unlink()
    save_csv_header(csv_path)

    # Model and static statistics (parameter count / FLOPs).
    model = get_model(model_name, pretrained=(init == "pretrained")).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    flops_total = _count_flops(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=CONFIG["LR"], weight_decay=CONFIG["WEIGHT_DECAY"]
    )

    # Training loop. Starting from -inf guarantees that the first epoch is
    # checkpointed even when validation F1 never rises above 0.
    best_val_f1 = float("-inf")
    best_val_metrics, best_train_metrics = {}, {}
    train_metrics, val_metrics = {}, {}
    epoch_records: List[dict] = []
    start_time = time.time()

    for epoch in range(1, CONFIG["NUM_EPOCHS"] + 1):
        train_one_epoch(model, train_dl, criterion, optimizer, device)
        train_metrics = evaluate(model, train_dl, criterion, device)
        val_metrics = evaluate(model, val_dl, criterion, device)

        append_csv_row(csv_path, epoch, val_metrics)

        print(
            f"Ep{epoch:02d} | "
            f"tr_f1={train_metrics['f1']:.3f} tr_mcc={train_metrics['mcc']:.3f} "
            f"val_f1={val_metrics['f1']:.3f} val_mcc={val_metrics['mcc']:.3f} "
            f"val_auc={val_metrics['auc']:.3f}"
        )

        # Keep the checkpoint with the best validation F1.
        if val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            best_val_metrics = val_metrics.copy()
            best_train_metrics = train_metrics.copy()
            torch.save(model.state_dict(), exp_dir / "best.pth")

        epoch_records.append({
            "epoch": epoch,
            **{f"train_{k}": v for k, v in train_metrics.items()},
            **{f"val_{k}": v for k, v in val_metrics.items()},
        })

    # Stop the clock before any report writing so it only covers the epochs.
    total_train_time = round(time.time() - start_time, 2)

    pd.DataFrame(epoch_records).to_excel(exp_dir / "epoch_metrics.xlsx", index=False)

    # Efficiency figures.
    latency_ms = measure_inference_latency(model, val_dl, device)
    if torch.cuda.is_available():
        gpu_mem = round(torch.cuda.max_memory_allocated(device) / 1024 ** 2, 2)  # MB
    else:
        gpu_mem = 0.0

    # Append the efficiency figures as trailing key/value rows of the CSV.
    with csv_path.open("a", newline="") as f:
        w = csv.writer(f)
        w.writerow(["total_train_time_sec", total_train_time])
        w.writerow(["inference_latency_ms", f"{latency_ms:.2f}"])
        w.writerow(["gpu_memory_MB", gpu_mem])
        w.writerow(["num_params", num_params])
        w.writerow(["FLOPs", f"{flops_total:.0f}"])

    return {
        "model": model_name,
        "init": init,
        # best - validation
        "best_val_f1":  best_val_metrics.get("f1", 0),
        "best_val_auc": best_val_metrics.get("auc", 0),
        "best_val_mcc": best_val_metrics.get("mcc", 0),
        "best_val_bal_acc": best_val_metrics.get("balanced_accuracy", 0),
        # best - training (measured at the best-validation epoch)
        "best_train_f1":  best_train_metrics.get("f1", 0),
        "best_train_auc": best_train_metrics.get("auc", 0),
        "best_train_mcc": best_train_metrics.get("mcc", 0),
        "best_train_bal_acc": best_train_metrics.get("balanced_accuracy", 0),
        # last-epoch losses (NaN when no epoch was run)
        "train_loss_last": train_metrics.get("loss", float("nan")),
        "val_loss_last":   val_metrics.get("loss", float("nan")),
        "epochs_run": CONFIG["NUM_EPOCHS"],
        "seconds": total_train_time,
        # efficiency
        "infer_latency_ms": latency_ms,
        "gpu_mem_MB": gpu_mem,
        "params": num_params,
        "FLOPs": flops_total,
    }


def run_training() -> None:
    """Train every configured model, pretrained and from scratch, and log results.

    For each name in ``CONFIG["MODELS"]`` and each of the two initialisations
    this writes, under ``CONFIG["OUTPUT_DIR"]/<model>--<init>/``:
    ``epoch_metrics.csv`` (validation metrics per epoch plus trailing
    efficiency rows), ``epoch_metrics.xlsx`` (train and validation metrics per
    epoch) and ``best.pth`` (state dict with the best validation F1). A
    ``summary_results.xlsx`` with one row per experiment is written to the
    output directory at the end.
    """
    out_dir = Path(CONFIG["OUTPUT_DIR"])
    out_dir.mkdir(parents=True, exist_ok=True)
    train_dl, val_dl = build_dataloaders()

    summary_records: List[dict] = []      # rows of summary_results.xlsx

    for model_name in CONFIG["MODELS"]:
        for init in ("pretrained", "scratch"):
            summary_records.append(
                _run_experiment(model_name, init, out_dir, train_dl, val_dl)
            )
            # The finished experiment's model and optimizer went out of scope
            # with the helper call. Release cached blocks and restart the peak
            # counter so the next experiment's memory figure is its own.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats(CONFIG["DEVICE"])

    pd.DataFrame(summary_records).to_excel(out_dir / "summary_results.xlsx", index=False)
    print(f"\nSummary saved → { (out_dir / 'summary_results.xlsx').resolve() }")


# Script entry point: runs only when the file is executed directly.
if __name__ == '__main__':
    # Force the 'spawn' multiprocessing start method before any DataLoader is
    # built (force=True overrides a start method that was already set).
    # NOTE(review): presumably chosen so DataLoader workers behave the same on
    # every platform and work with CUDA — confirm the original motivation.
    torch.multiprocessing.set_start_method('spawn', force=True)
    run_training()