import os
import argparse
import json
import numpy as np
import pandas as pd
import torch, warnings
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.exceptions import UndefinedMetricWarning
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

# imports from TSB_AD
from TSB_AD.utils.slidingWindows import find_length_rank
from TSB_AD.evaluation.metrics import get_metrics


from .model.MOMEMTO import MOMEMTO
from .utils.utils import (
    set_seed,
    ensure_dir,
    load_dataset_from_folder,
    windowize,
    stack_train_lists,
)


def evaluate_one_file(
    model: MOMEMTO,
    ts_norm: np.ndarray,
    labels: np.ndarray,
    win_len: int,
    stride: int,
    device: "str | torch.device | None" = None,
) -> dict:
    """Score one test series with ``model`` and compute TSB-AD metrics.

    The series is cut into windows with ``windowize``, scored by
    ``model.pred``, and the scores at mask == 1 positions are flattened back
    into one score per timestep before calling ``get_metrics``.

    Args:
        model: trained MOMEMTO model. Its ``pred(x, mask)`` is expected to
            return a pair whose second item is a per-timestep score tensor
            shaped like the mask — (B, T) per the original inline note,
            TODO confirm.
        ts_norm: series passed to ``find_length_rank`` and ``windowize``.
        labels: per-timestep anomaly labels; cast to int before scoring.
        win_len: window length passed to ``windowize``.
        stride: window stride passed to ``windowize``.
        device: torch device for the window tensors. ``None`` (default) uses
            ``"cuda"`` when available and falls back to ``"cpu"`` otherwise.

    Returns:
        dict with float values under ``AUC_ROC``, ``AUC_PR``, ``VUS_ROC`` and
        ``VUS_PR``; a metric missing from ``get_metrics`` output (under either
        the underscore or hyphen spelling) is ``nan``.
    """
    if device is None:
        # The original hard-coded "cuda", which fails on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Sliding-window size for the VUS metrics, estimated from the series.
    sw = int(find_length_rank(ts_norm))

    test_x_np, test_m_np = windowize(ts_norm, win_len=win_len, stride=stride)
    test_x = torch.from_numpy(test_x_np).float().unsqueeze(1).to(device)
    test_m = torch.from_numpy(test_m_np).long().to(device)

    with torch.no_grad():
        _, mse = model.pred(test_x, test_m)  # (B, T)

    # Boolean-mask indexing selects elements in row-major order, which equals
    # concatenating each window's valid (mask == 1) positions in turn. It does
    # so without a Python loop and without torch.cat on an empty list.
    # Both tensors are moved to CPU first, so a model that returns its scores
    # on a different device still works.
    flat_score = mse.detach().cpu()[test_m.cpu() == 1].numpy()

    # Windowing may pad or trim the tail; align scores and labels on the overlap.
    n = min(len(flat_score), len(labels))
    s, y = flat_score[:n], labels[:n].astype(int)

    gm = get_metrics(s, y, slidingWindow=sw)
    # get_metrics may spell keys with "_" or "-"; accept either, else NaN.
    return {
        "AUC_ROC": float(gm.get("AUC_ROC", gm.get("AUC-ROC", np.nan))),
        "AUC_PR":  float(gm.get("AUC_PR",  gm.get("AUC-PR", np.nan))),
        "VUS_ROC": float(gm.get("VUS_ROC", gm.get("VUS-ROC", np.nan))),
        "VUS_PR":  float(gm.get("VUS_PR",  gm.get("VUS-PR", np.nan))),
    }


def run(args):
    """Run the full train-then-evaluate pipeline and write ``metrics.csv``.

    The pipeline seeds RNGs and creates ``args.out_dir``, then loads
    train/test data from ``args.data_root``. It stacks the training lists,
    fits a MOMEMTO model, scores every test file with ``evaluate_one_file``,
    and saves one metrics row per file.

    Args:
        args: namespace produced by ``build_parser().parse_args()``.
    """
    set_seed(args.seed)
    ensure_dir(args.out_dir)

    # Stage 1: read CSVs into training windows and per-file test series.
    print("[1/4] Loading dataset...")
    train_files, train_masks, train_domains, test_files, _ = load_dataset_from_folder(
        folder_path=args.data_root,
        win_len=args.win_len,
        stride=args.stride,
    )
    n_test = len(test_files)
    print(f"Loaded {n_test} files")

    # Stage 2: merge per-file training lists into model-ready tensors.
    print("[2/4] Building tensors...")
    train_X, train_M, train_D = stack_train_lists(
        train_files, train_masks, train_domains
    )

    # Stage 3: build and fit the model.
    print("[3/4] Training model...")
    model = MOMEMTO(freeze_enc=args.freeze_enc, top_k=args.top_k)
    model.fit(train_X, train_M, train_D, epochs=args.epochs, lr=args.lr)

    # Stage 4: score every test file, tagging each metrics dict with its name.
    print("[4/4] Evaluating...")
    records = []
    with torch.no_grad():
        progress = tqdm(test_files.items(), total=n_test)
        for fname, (ts_norm, labels) in progress:
            record = evaluate_one_file(
                model=model,
                ts_norm=ts_norm,
                labels=labels,
                win_len=args.win_len,
                stride=args.stride,
            )
            record["file"] = fname
            records.append(record)

    csv_path = os.path.join(args.out_dir, "metrics.csv")
    pd.DataFrame(records).to_csv(csv_path, index=False)
    print(f"Saved: {csv_path}")
    print("Done.")


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the MOMEMTO train/eval script.

    Only ``--data_root`` is required. Every other option has a default, and
    ``--freeze_enc`` is a boolean switch.
    """
    parser = argparse.ArgumentParser(description="Train and evaluate MOMEMTO")
    parser.add_argument(
        "--data_root",
        type=str,
        required=True,
        help="Folder that contains CSV files",
    )

    # Numeric training/windowing options, registered in their original order.
    numeric_options = (
        ("--epochs", int, 2),
        ("--lr", float, 1e-4),
        ("--win_len", int, 512),
        ("--stride", int, 512),
        ("--top_k", int, 3),
    )
    for flag, caster, fallback in numeric_options:
        parser.add_argument(flag, type=caster, default=fallback)

    parser.add_argument("--freeze_enc", action="store_true")
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--out_dir", type=str, default="./results_run")
    return parser


if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the train/eval pipeline.
    run(build_parser().parse_args())
