#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import argparse, os, gc, torch
import numpy as np, pandas as pd
from TSB_AD.evaluation.metrics import get_metrics
from MOMENT_custom import MOMENT_custom

# Pick the compute device: the GPU when torch reports CUDA as available,
# the CPU otherwise, and report the choice on stdout.
# NOTE(review): `device` is only printed here; nothing else in the visible
# script reads it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"[INFO] using device: {device}")

# ───── command-line interface ─────────────────────────────────
# Three options: the input CSV (mandatory), the directory that receives the
# results, and the window size passed to the MOMENT wrapper.
parser = argparse.ArgumentParser()
parser.add_argument('--csv', required=True, help='input CSV (TSB-AD format)')
parser.add_argument('--out_dir', default='./results/MOMENT')
parser.add_argument('--win', type=int, default=256, help='MOMENT window size')
args = parser.parse_args()

# Short aliases used by the rest of the script.
csv_path, out_dir = args.csv, args.out_dir

# Create the output directory up front; an already existing one is fine.
os.makedirs(out_dir, exist_ok=True)

# Load the series and discard every row that contains a missing value.
df = pd.read_csv(csv_path).dropna()

# Every column except the last one is treated as a feature channel; the
# labels are read from the column named 'Label'.
# NOTE(review): this presumes 'Label' is the last column — confirm for inputs
# that do not follow the layout named in the --csv help text.
data = df.iloc[:, :-1].values.astype(float)
# NOTE(review): `label` is not read by the rest of the visible script.
label = df['Label'].astype(int).to_numpy()

# The training length is taken from the third-from-last '_'-separated token
# of the file name.
# NOTE(review): presumably the TSB-AD naming scheme — confirm against the
# dataset's file names.
try:
    train_len = int(os.path.basename(csv_path).split('_')[-3])
    if train_len <= 0:
        # A zero or negative length would give an empty or mis-sliced
        # training set, so treat it like an unparsable name.
        raise ValueError(f"non-positive training length: {train_len}")
except (IndexError, ValueError) as err:
    # Only these two failures can occur: too few '_' tokens, or a token that
    # is not a usable integer. Fall back to train == test, but say so,
    # because it changes what the model is fitted on.
    train_len = len(data)
    print(f"[WARN] could not read training length from file name ({err}); "
          "fitting on the full series")

# Fit on the leading `train_len` rows; score the whole series.
data_train = data[:train_len]
data_test = data

# Build the detector with the requested window size.
clf = MOMENT_custom(win_size=args.win)

# Derive the batch size from the channel count and the window size, never
# going below 1.
# NOTE(review): the expression evaluates as (32 / channels) * win, so the
# batch size grows with the window — confirm that this scaling is intended.
n_channels = data.shape[1]
scaled_batch = int(32 / n_channels * args.win)
clf.batch_size = max(1, scaled_batch)

# NOTE(review): overrides the wrapped model's `embedding_vector` with a
# scalar zero tensor; the reason is not evident from this script — confirm
# against MOMENT_custom.
clf.model.embedding_vector = torch.tensor(0)

# Fit on the training slice, then compute scores for the test data.
clf.fit(data_train)
score = clf.decision_function(data_test)


# Persist the scores as "<input file name>_score.npy" inside the output
# directory.
score_file = f"{os.path.basename(csv_path)}_score.npy"
np.save(os.path.join(out_dir, score_file), score)

# Drop the references to the model and the arrays, run the garbage collector,
# and ask torch to release its cached CUDA memory.
del clf, score, data, data_train, data_test
gc.collect()
torch.cuda.empty_cache()
