from __future__ import annotations

import argparse
import os
import random
import re
import sys
import time
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score, balanced_accuracy_score,
    matthews_corrcoef, precision_score, recall_score
)
from sklearn.model_selection import train_test_split

# ---------------------------------------------------------------------------
# Configuration – default values (can be overridden via CLI)
# ---------------------------------------------------------------------------
CONFIG: Dict[str, object] = dict(
    # ── Inputs ──────────────────────────────────────────────────────────────
    DATA_DIR=Path("./MNR_CL_text_right"),   # root folder; each sub-folder is one text source
    LABEL_FILE=Path("./苹果称重-1-7.xlsx"),  # Excel sheet holding the sample labels
    # Label-sheet columns for the sample ID and for the target: either a
    # column *name* (str) or a positional *index* (int).
    LABEL_ID_COL=0,
    LABEL_TARGET_COL=3,
    ID_ZFILL=0,  # zero-pad IDs to this width (0 = no padding); match your file-naming rule
    # ── Training ────────────────────────────────────────────────────────────
    TRAIN_RATIO=0.8,  # fraction of samples used for training; the rest is validation
    # Tokenizer max length. The original note says 1024 fails — presumably the
    # backbones' 512-position limit; verify before raising it.
    MAX_LENGTH=512,
    BATCH_SIZE=32,
    NUM_EPOCHS=50,
    LR=1e-5,
    WEIGHT_DECAY=1e-2,
    WARMUP_RATIO=0.1,  # share of total optimisation steps used for linear LR warm-up
    DEVICE="cuda" if torch.cuda.is_available() else "cpu",
    NUM_WORKERS=0 if os.name == "nt" else 4,  # DataLoader workers; 0 on Windows
    MODELS=[
        "bert-base-uncased",
        "roberta-base",
        "distilbert-base-uncased",
        "albert-base-v2",
    ],
    # ── Outputs / misc ──────────────────────────────────────────────────────
    OUTPUT_DIR=Path("./logs_right_text"),
    SEED=42,
    EARLY_STOP=5,  # patience: epochs without val-F1 improvement before stopping; 0 disables
    USE_IMBALANCE_SAMPLER=True,  # inverse-class-frequency weighted sampling for training
)

# ---------------------------------------------------------------------------
# CLI override – optional
# ---------------------------------------------------------------------------
def _parse_col_spec(spec: str) -> "int | str":
    """Interpret a CLI column spec as a positional index or a column name.

    Text that is an (optionally negative) integer becomes an ``int`` index, so
    ``"-1"`` selects the last column. Anything else is returned unchanged and
    used as a column name. ``\\d`` only matches decimal digits that ``int()``
    accepts, so strings such as ``"²"`` (for which ``str.isdigit`` is True but
    ``int`` fails) stay names instead of crashing.

    >>> _parse_col_spec("3"), _parse_col_spec("-1"), _parse_col_spec("weight")
    (3, -1, 'weight')
    """
    return int(spec) if re.fullmatch(r"-?\d+", spec) else spec


parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--label_id_col", help="label-sheet ID column: name or (negative) integer index")
parser.add_argument("--label_target_col", help="label-sheet target column: name or (negative) integer index")
parser.add_argument("--id_zfill", type=int, help="zero-padding width applied to sample IDs")
# parse_known_args: tolerate unrelated arguments (e.g. from a launcher or notebook kernel).
args, _ = parser.parse_known_args()
if args.label_id_col is not None:
    CONFIG["LABEL_ID_COL"] = _parse_col_spec(args.label_id_col)
if args.label_target_col is not None:
    CONFIG["LABEL_TARGET_COL"] = _parse_col_spec(args.label_target_col)
if args.id_zfill is not None:
    CONFIG["ID_ZFILL"] = args.id_zfill

# ---------------------------------------------------------------------------
# Reproducibility helpers
# ---------------------------------------------------------------------------

def set_seed(seed: int) -> None:
    """Seed the random generators this script uses and make cuDNN deterministic.

    Seeds Python's ``random`` module, PyTorch's CPU generator and every CUDA
    device. It also disables cuDNN benchmarking so that kernel selection does
    not vary between runs.
    """
    # NOTE(review): CPython reads PYTHONHASHSEED only at interpreter start-up,
    # so setting it here presumably affects child processes only — confirm.
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once at import time so module-level randomness is reproducible.
set_seed(seed=int(CONFIG["SEED"]))

# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------


class AppleTextDataset(Dataset):
    """Map-style dataset that yields tokenised text plus an integer label.

    Texts and labels come from the ``text`` and ``label`` columns of the
    DataFrame. Each item is tokenised lazily in ``__getitem__``, truncated and
    padded to ``max_len`` tokens.
    """

    def __init__(self, df: pd.DataFrame, tokenizer, max_len: int):
        self.texts = df["text"].to_list()
        self.labels = df["label"].to_list()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenizer outputs for sample *idx* plus a ``labels`` tensor.

        ``return_tensors="pt"`` gives each field a leading batch dimension of
        size 1. That dimension is squeezed away so the DataLoader can stack
        samples itself.
        """
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        sample = {field: tensor.squeeze(0) for field, tensor in encoding.items()}
        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample


# ---------------------------------------------------------------------------
# Data processing helpers
# ---------------------------------------------------------------------------


def _get_col(df: pd.DataFrame, col_spec):
    """Select a single column of *df* as a Series.

    An ``int`` *col_spec* is treated as a positional index (via ``iloc``).
    Any other value is used as a column label.
    """
    if not isinstance(col_spec, int):
        return df[col_spec]
    return df.iloc[:, col_spec]


def load_labels(path: Path) -> pd.DataFrame:
    """Read the Excel label sheet and return a clean ``id`` / ``label`` frame.

    The ID and target columns are chosen by CONFIG["LABEL_ID_COL"] and
    CONFIG["LABEL_TARGET_COL"]; each may be a name or a positional index (see
    ``_get_col``). The ID column is read as text and stripped of surrounding
    whitespace. It is then zero-padded to CONFIG["ID_ZFILL"] characters.

    A row is dropped when its ID is missing or blank, or when its label is
    missing or ±inf. The remaining labels are cast to ``int``.

    Returns:
        DataFrame with columns ``id`` (str) and ``label`` (int).
    """
    df_raw = pd.read_excel(path, dtype={CONFIG["LABEL_ID_COL"]: str})
    # Strip stray whitespace (common in hand-edited sheets) so " 12" matches "12".
    raw_ids = _get_col(df_raw, CONFIG["LABEL_ID_COL"]).str.strip()
    label_series = _get_col(df_raw, CONFIG["LABEL_TARGET_COL"])

    # Data cleaning: keep rows with a usable ID *and* a finite label. Blank IDs
    # are rejected before zero-padding; otherwise "" would become "000…" and
    # could falsely match a text file.
    valid_mask = (
        raw_ids.notna()
        & raw_ids.ne("")
        & label_series.notna()
        & ~label_series.isin([float("inf"), float("-inf")])
    )
    id_series = raw_ids[valid_mask].str.zfill(CONFIG["ID_ZFILL"])
    label_series = label_series[valid_mask].astype(int)

    return pd.DataFrame({"id": id_series, "label": label_series})



def gather_texts(data_dir: Path, id_zfill: "int | None" = None) -> pd.DataFrame:
    """Collect the ``*.txt`` files of every sub-folder of *data_dir* into one frame.

    Each immediate sub-directory is one text source (the code calls them
    "model" folders). Inside it, each ``*.txt`` file contributes one cell:

    - the sample ID is the first run of digits in the file stem,
      zero-padded to *id_zfill* characters;
    - the value is the file content, decoded as UTF-8 with undecodable bytes
      ignored.

    Files whose stem has no digits are skipped.

    Both directory levels are visited in sorted order. The column order (which
    fixes the word order of the "concat" text downstream) and the choice among
    files that map to the same ID are therefore reproducible across
    filesystems.

    Args:
        data_dir: root directory whose sub-folders hold the text files.
        id_zfill: zero-padding width for IDs; defaults to CONFIG["ID_ZFILL"].

    Returns:
        DataFrame with an ``id`` column (sorted) followed by one column per
        sub-folder (sorted by name). An ID missing from a folder is NaN.
    """
    if id_zfill is None:
        id_zfill = int(CONFIG["ID_ZFILL"])
    id_pattern = re.compile(r"(\d+)")

    model_to_texts: Dict[str, pd.Series] = {}
    for model_dir in sorted(p for p in data_dir.iterdir() if p.is_dir()):
        texts: Dict[str, str] = {}
        for txt_path in sorted(model_dir.glob("*.txt")):
            match = id_pattern.search(txt_path.stem)
            if not match:
                continue
            idx = match.group(1).zfill(id_zfill)
            if idx in texts:
                # Two files resolve to the same ID; the later (in sorted order) wins.
                print(f"⚠️  Duplicate ID {idx} in '{model_dir.name}': {txt_path.name} overrides an earlier file.")
            texts[idx] = txt_path.read_text(encoding="utf-8", errors="ignore")
        # Explicit object dtype so an empty folder still yields a well-typed text column.
        model_to_texts[model_dir.name] = pd.Series(texts, name=model_dir.name, dtype=object)

    all_ids = sorted(set().union(*(s.index for s in model_to_texts.values())))
    df = pd.DataFrame(index=pd.Index(all_ids, name="id"))
    for model, series in model_to_texts.items():
        df[model] = series  # aligns on ID; absent IDs become NaN
    return df.reset_index()


def prepare_data(text_column: str) -> Tuple[pd.DataFrame, pd.DataFrame, int]:
    """Build the train/validation split for one text source.

    Texts (``gather_texts``) and labels (``load_labels``) are inner-joined on
    ``id``. Raw labels are re-encoded as contiguous integers ``0..K-1``
    following the sorted order of the raw label values.

    Args:
        text_column: name of a text-source sub-folder, or ``"concat"`` to join
            the texts of all sources with a single space. In concat mode a
            missing text counts as "".

    Returns:
        ``(train_df, val_df, num_labels)``. Both frames have columns ``id``,
        ``text`` and ``label`` and a fresh RangeIndex. ``num_labels`` is the
        number of distinct labels among all merged samples.

    Exits the process with status 1 in two cases: no label row matches any
    text file, or the chosen source has no text for any labelled sample.
    """
    texts_df = gather_texts(Path(CONFIG["DATA_DIR"]))
    labels_df = load_labels(Path(CONFIG["LABEL_FILE"]))
    merged = pd.merge(texts_df, labels_df, on="id", how="inner")

    if merged.empty:
        print("❌ Merge produced 0 rows – check ID extraction & zero‑fill length.")
        sys.exit(1)
    dropped = len(labels_df) - len(merged)
    if dropped:
        print(f"⚠️  {dropped} label rows had no matching text files and were skipped.")

    # The label mapping is built before any per-source filtering. Every text
    # source therefore shares the same mapping and class count, so
    # classification heads are comparable.
    label_map = {raw: i for i, raw in enumerate(sorted(merged["label"].unique()))}
    merged["label"] = merged["label"].map(label_map)

    if text_column == "concat":
        # Fuse the texts of all sources; a source lacking a sample contributes "".
        model_cols = [c for c in merged.columns if c not in {"id", "label"}]
        merged["text"] = merged[model_cols].fillna("").agg(" ".join, axis=1)
    else:
        merged["text"] = merged[text_column]
        # IDs come from the union over all sources, so this source may lack
        # some of them. NaN is not a string and cannot be tokenized, so those
        # rows are dropped.
        missing = merged["text"].isna()
        if missing.any():
            print(f"⚠️  {int(missing.sum())} samples have no '{text_column}' text and were skipped.")
            merged = merged.loc[~missing]
        if merged.empty:
            print(f"❌ No labelled samples have text from '{text_column}'.")
            sys.exit(1)

    # Stratification needs at least 2 samples per class. Otherwise fall back
    # to a plain random split instead of letting train_test_split raise.
    stratify = merged["label"] if merged["label"].value_counts().min() >= 2 else None
    if stratify is None:
        print("⚠️  Some classes have fewer than 2 samples; using a non-stratified split.")

    train_df, val_df = train_test_split(
        merged[["id", "text", "label"]],
        test_size=1 - float(CONFIG["TRAIN_RATIO"]),
        random_state=int(CONFIG["SEED"]),
        stratify=stratify,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), len(label_map)


def make_loader(df: pd.DataFrame, tokenizer, is_train: bool) -> DataLoader:
    """Wrap *df* in an ``AppleTextDataset`` and return a DataLoader over it.

    A weighted sampler is used when all of the following hold:
    ``is_train`` is True, CONFIG["USE_IMBALANCE_SAMPLER"] is set, and *df*
    holds at least two classes. The loader then draws ``len(df)`` samples
    with replacement, each weighted by the inverse frequency of its class, so
    every class has the same expected share of draws.

    Otherwise a plain loader is returned, shuffled only when ``is_train``.
    """
    ds = AppleTextDataset(df, tokenizer, int(CONFIG["MAX_LENGTH"]))
    batch_size = int(CONFIG["BATCH_SIZE"])
    num_workers = int(CONFIG["NUM_WORKERS"])

    if is_train and CONFIG["USE_IMBALANCE_SAMPLER"] and df["label"].nunique() > 1:
        # Look up each sample's class count by label *value*. Indexing a sorted
        # count array by label would misalign, or raise IndexError, whenever
        # this split's labels are not exactly 0..K-1.
        per_sample_counts = df["label"].map(df["label"].value_counts())
        sample_weights = 1.0 / torch.tensor(per_sample_counts.to_numpy(), dtype=torch.float)
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
        return DataLoader(ds, batch_size=batch_size, sampler=sampler, num_workers=num_workers)

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
    )


# ---------------------------------------------------------------------------
# Training / evaluation loops
# ---------------------------------------------------------------------------


def train_epoch(model: nn.Module, loader: DataLoader, optim, sched, device) -> float:
    """Run one optimisation pass over *loader* and return the mean batch loss.

    Each batch is moved to *device* and passed to the model as keyword
    arguments (including ``labels``, so the model reports ``.loss``). The
    loop then back-propagates, steps the optimizer and the LR scheduler once
    per batch, and clears gradients.
    """
    model.train()
    batch_losses = []
    for raw_batch in tqdm(loader, leave=False):
        on_device = {name: tensor.to(device) for name, tensor in raw_batch.items()}
        loss = model(**on_device).loss
        loss.backward()
        optim.step()
        sched.step()
        optim.zero_grad()
        batch_losses.append(loss.item())
    return float(sum(batch_losses) / len(batch_losses))


def compute_metrics(golds, preds, probs, n_labels: int) -> Dict[str, float]:
    """Score predictions against gold labels.

    Args:
        golds: true class indices.
        preds: predicted class indices.
        probs: per-sample class-probability rows. In the binary case, column 1
            is the positive-class score.
        n_labels: number of classes. 2 selects binary ROC-AUC; any other
            value selects macro one-vs-rest ROC-AUC.

    Returns:
        dict with ``acc``, ``bal_acc``, macro ``f1``/``precision``/``recall``
        (ill-defined per-class terms count as 0), ``mcc`` and ``auc``. ``auc``
        is NaN when ``roc_auc_score`` raises ValueError.
    """
    macro_kw = {"average": "macro", "zero_division": 0}
    scores = {
        "acc": accuracy_score(golds, preds),
        "bal_acc": balanced_accuracy_score(golds, preds),
        "f1": f1_score(golds, preds, **macro_kw),
        "precision": precision_score(golds, preds, **macro_kw),
        "recall": recall_score(golds, preds, **macro_kw),
        "mcc": matthews_corrcoef(golds, preds),
    }
    try:
        if n_labels == 2:
            auc = roc_auc_score(golds, [row[1] for row in probs])
        else:
            auc = roc_auc_score(golds, probs, multi_class="ovr", average="macro")
    except ValueError:
        auc = float("nan")
    scores["auc"] = auc
    return scores

def eval_epoch(model: nn.Module, loader: DataLoader, device, n_labels: int) -> Dict[str, float]:
    """Evaluate *model* on *loader* without gradients.

    Returns the ``compute_metrics`` dictionary, extended with ``loss`` (the
    mean batch loss). Predictions are the argmax of the softmax probabilities.
    """
    model.eval()
    batch_losses, all_preds, all_probs, all_golds = [], [], [], []
    with torch.no_grad():
        for raw_batch in loader:
            on_device = {name: tensor.to(device) for name, tensor in raw_batch.items()}
            output = model(**on_device)
            batch_losses.append(output.loss.item())
            probabilities = output.logits.softmax(dim=-1).cpu()
            all_probs.extend(probabilities.numpy())
            all_preds.extend(probabilities.argmax(dim=-1).numpy())
            all_golds.extend(on_device["labels"].cpu().numpy())
    results = compute_metrics(all_golds, all_preds, all_probs, n_labels)
    results["loss"] = float(sum(batch_losses) / len(batch_losses))
    return results


# ---------------------------------------------------------------------------
# Main experiment orchestration
# ---------------------------------------------------------------------------


def _run_single_experiment(
    text_source: str,
    backbone: str,
    init: str,
    tok,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    num_labels: int,
    out_dir: Path,
) -> Dict[str, object]:
    """Train and evaluate one (text source, backbone, init) combination.

    ``init == "pretrained"`` loads the backbone weights; any other value
    builds a randomly initialised model from the backbone's config.
    TensorBoard logs and ``epoch_metrics.xlsx`` are written under
    ``out_dir/<tag>/``. Returns the row for ``summary_results.xlsx``.
    """
    tag = f"{text_source}__{backbone.replace('/', '_')}--{init}"
    print(f"\n===== {tag} =====")

    exp_dir = out_dir / tag
    exp_dir.mkdir(parents=True, exist_ok=True)

    device = CONFIG["DEVICE"]
    num_epochs = int(CONFIG["NUM_EPOCHS"])
    patience = int(CONFIG["EARLY_STOP"] or 0)  # 0 / None disables early stopping

    # Re-seed per experiment so each configuration is reproducible on its own,
    # regardless of how many experiments ran before it in this process.
    set_seed(int(CONFIG["SEED"]))

    # ── Model
    if init == "pretrained":
        model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_labels)
    else:
        cfg = AutoConfig.from_pretrained(backbone, num_labels=num_labels)
        model = AutoModelForSequenceClassification.from_config(cfg)
    model.to(device)

    # ── DataLoaders & optimizer
    train_loader = make_loader(train_df, tok, True)
    # train_loader may use a weighted with-replacement sampler (see make_loader),
    # so training metrics come from a separate plain, unshuffled pass over the
    # actual training split.
    train_eval_loader = make_loader(train_df, tok, False)
    val_loader = make_loader(val_df, tok, False)
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=float(CONFIG["LR"]),
        weight_decay=float(CONFIG["WEIGHT_DECAY"]),
    )
    total_steps = len(train_loader) * num_epochs
    sched = get_linear_schedule_with_warmup(
        optim,
        num_warmup_steps=int(float(CONFIG["WARMUP_RATIO"]) * total_steps),
        num_training_steps=total_steps,
    )

    # ── Training loop
    # -inf so the first epoch always becomes the initial "best".
    best_val_f1, no_improve = float("-inf"), 0
    best_val_metrics: Dict[str, float] = {}
    best_train_metrics: Dict[str, float] = {}
    train_metrics: Dict[str, float] = {}
    val_metrics: Dict[str, float] = {}
    epoch_records: List[Dict[str, float]] = []
    epochs_run = 0
    start = time.time()

    tb = SummaryWriter(log_dir=exp_dir / "tensorboard")
    try:
        for epoch in range(1, num_epochs + 1):
            train_epoch(model, train_loader, optim, sched, device)
            train_metrics = eval_epoch(model, train_eval_loader, device, num_labels)
            val_metrics = eval_epoch(model, val_loader, device, num_labels)
            epochs_run = epoch

            # TensorBoard
            for split, mets in (("train", train_metrics), ("val", val_metrics)):
                for k, v in mets.items():
                    tb.add_scalar(f"{k.upper()}/{split}", v, epoch)

            # Console
            print(
                f"Ep{epoch:02d} | "
                f"tr_f1={train_metrics['f1']:.3f} tr_mcc={train_metrics['mcc']:.3f} "
                f"val_f1={val_metrics['f1']:.3f} val_mcc={val_metrics['mcc']:.3f} "
                f"val_auc={val_metrics['auc']:.3f}"
            )

            # Record every epoch, including the one that triggers early stopping.
            epoch_records.append(
                {"epoch": epoch,
                 **{f"train_{k}": v for k, v in train_metrics.items()},
                 **{f"val_{k}": v for k, v in val_metrics.items()}}
            )

            # Early stopping & best-epoch bookkeeping (selected on val macro-F1)
            if val_metrics["f1"] > best_val_f1:
                best_val_f1 = val_metrics["f1"]
                best_val_metrics = val_metrics.copy()
                best_train_metrics = train_metrics.copy()
                no_improve = 0
            else:
                no_improve += 1
                if patience and no_improve >= patience:
                    print("Early stopping ✋")
                    break
    finally:
        tb.close()
    elapsed = round(time.time() - start, 2)

    # ── Per-epoch metrics
    pd.DataFrame(epoch_records).to_excel(exp_dir / "epoch_metrics.xlsx", index=False)

    nan = float("nan")
    return {
        "src_text": text_source,
        "model": backbone,
        "init": init,
        # best – VAL
        "best_val_f1": best_val_metrics.get("f1", nan),
        "best_val_auc": best_val_metrics.get("auc", nan),
        "best_val_mcc": best_val_metrics.get("mcc", nan),
        "best_val_bal_acc": best_val_metrics.get("bal_acc", nan),
        # best – TRAIN (same epoch as the best VAL)
        "best_train_f1": best_train_metrics.get("f1", nan),
        "best_train_auc": best_train_metrics.get("auc", nan),
        "best_train_mcc": best_train_metrics.get("mcc", nan),
        "best_train_bal_acc": best_train_metrics.get("bal_acc", nan),
        # last-epoch losses
        "train_loss_last": train_metrics.get("loss", nan),
        "val_loss_last": val_metrics.get("loss", nan),
        "epochs_run": epochs_run,
        "seconds": elapsed,
    }


def run() -> None:
    """Run the full experiment grid: text source → backbone → init.

    Text sources are the sub-folders of CONFIG["DATA_DIR"] (sorted) plus
    ``"concat"``. For each source, every backbone in CONFIG["MODELS"] is
    trained twice: from pretrained weights and from scratch.

    Per-experiment outputs go to ``CONFIG["OUTPUT_DIR"]/<tag>/``. The summary
    table is written to ``summary_results.xlsx`` and rewritten after every
    experiment, so a crash late in a long grid keeps the finished rows.
    """
    out_dir = Path(CONFIG["OUTPUT_DIR"])
    out_dir.mkdir(parents=True, exist_ok=True)
    summary_path = out_dir / "summary_results.xlsx"

    # Sub-folders + "concat", sorted so the run order is reproducible.
    data_dir = Path(CONFIG["DATA_DIR"])
    text_sources = sorted(d.name for d in data_dir.iterdir() if d.is_dir()) + ["concat"]

    summary_records: List[Dict[str, object]] = []

    for text_source in text_sources:
        print(f"\n### Using text from: {text_source} ###")
        train_df, val_df, num_labels = prepare_data(text_source)

        for backbone in CONFIG["MODELS"]:
            # The tokenizer depends only on the backbone: load it once for both inits.
            tok = AutoTokenizer.from_pretrained(backbone)
            for init in ("pretrained", "scratch"):
                summary_records.append(
                    _run_single_experiment(
                        text_source, backbone, init, tok, train_df, val_df, num_labels, out_dir
                    )
                )
                pd.DataFrame(summary_records).to_excel(summary_path, index=False)
                # The finished model/optimizer were local to the helper; release
                # cached CUDA memory before the next model is built.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    pd.DataFrame(summary_records).to_excel(summary_path, index=False)
    print(f"\nSummary saved → {summary_path.resolve()}")

if __name__ == "__main__":
    # Launch the experiment grid only when executed as a script; importing
    # this module does not start training.
    run()