#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import os
import pandas as pd
import numpy as np
import torch
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split

# Moirai model imports (from uni2ts)
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule

def create_sliding_windows(data: np.ndarray, context_length: int):
    """
    Cut a multichannel series into (past window, next step) training pairs.

    Sample ``i`` pairs the window ``data[i : i + context_length, :]`` with the
    target ``data[i + context_length, :]``, i.e. the step right after it.

    data: 2-D array of shape (n_steps, n_channels)
    context_length: size of the past window (must be >= 0)
    return:
      windows: shape (num_samples, context_length, n_channels)
      targets: shape (num_samples, n_channels)
      where num_samples = max(n_steps - context_length, 0). When the series
      is too short to yield any sample, both arrays are empty but keep the
      documented rank and trailing dimensions.
    raises:
      ValueError: if context_length is negative, or if data is not 2-D.
    """
    if context_length < 0:
        raise ValueError(f"context_length must be >= 0, got {context_length}")

    n_steps, n_channels = data.shape
    num_samples = n_steps - context_length
    if num_samples <= 0:
        # Keep rank and trailing dims so downstream shape logic still works.
        return (
            np.empty((0, context_length, n_channels), dtype=data.dtype),
            np.empty((0, n_channels), dtype=data.dtype),
        )

    # row_index[i, j] = i + j, so data[row_index][i] == data[i : i + context_length].
    row_index = np.arange(num_samples)[:, None] + np.arange(context_length)[None, :]
    windows = data[row_index]
    # Copy so the targets never alias the caller's array.
    targets = data[context_length:].copy()
    return windows, targets

def parse_args():
    """
    Parse the command-line options of this script.

    Options:
      --data_direc  (required) path to the input CSV file.
      --model_size  one of 'small', 'base', 'large' (default 'base').
      --ctx         context length, i.e. history size (default 256).
      --batch_size  inference batch size (default 32).

    return: the populated argparse.Namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_direc', type=str, required=True,
                        help='Path to the CSV file that contains the timeseries data '
                             'with the anomaly label in a column named "Label".')
    # %(default)s lets argparse print the real default, so the help text
    # cannot drift from the value actually configured here.
    parser.add_argument('--model_size', type=str, default='base', choices=['small', 'base', 'large'],
                        help='Moirai model size to use. Default: %(default)s')
    parser.add_argument('--ctx', type=int, default=256,
                        help='Context length (history) for Moirai. Default: %(default)s')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size for inference. Default: %(default)s')
    return parser.parse_args()


def main():
    """
    Score a labelled time series with zero-shot Moirai one-step forecasts.

    Pipeline:
      1. Read the CSV given by --data_direc, drop rows containing NaN, and
         separate the 'Label' column from the feature columns.
      2. Wrap the first feature column in a GluonTS PandasDataset and build
         one rolling instance per time step (prediction_length=1).
      3. Load the pretrained "Salesforce/moirai-1.0-R-<model_size>" module.
      4. For every instance, compare the mean forecast with the observed
         value; the squared error is the anomaly score.
      5. Write true value, prediction, score and label to
         ./results/Moirai_ZS_<model_size>/<prefix>_{U|M}_MOIRAI_ZS.csv.

    The path in --data_direc is split on '_' and read positionally: the
    third token from the end must be an integer (the train index) and the
    last three characters of the ninth token from the end become <prefix>.

    raises:
      IndexError: if the path has fewer than nine '_'-separated tokens.
      ValueError: if the train-index token is not an integer, or if the
        number of scored windows does not match the number of labels.
      RuntimeError: if no forecast window was produced.
    """
    args = parse_args()

    # ---------------------------------------
    # 1) Load CSV + Train/Test split
    # ---------------------------------------
    df = pd.read_csv(args.data_direc).dropna()
    # Everything except the 'Label' column is treated as a time series feature.
    data = df.drop(['Label'], axis=1).values.astype(float)
    label = df['Label'].values.astype(float)
    print(args.data_direc)

    # Parse both filename-derived values up front so a badly named file fails
    # here rather than after the (expensive) model inference.
    name_tokens = args.data_direc.split('_')
    train_index = int(name_tokens[-3])
    out_prefix = str(name_tokens[-9][-3:])

    train_data = data[:train_index, :]   # only reported below, not used for fitting
    # The whole series is scored, which yields one result row per label.
    test_data = data

    print(f"[INFO] Loaded data: shape={data.shape}, label.shape={label.shape}")
    print(f"[INFO] Train shape={train_data.shape}, Test shape={test_data.shape}")

    # ---------------------------------------
    # 2) Create GluonTS Dataset
    #    → Predict one step at a time with rolling window
    # ---------------------------------------
    # The CSV carries no timestamps here, so a synthetic 1-second index is used.
    dates = pd.date_range("2025-01-01", periods=len(test_data), freq="S")
    # NOTE(review): only the first feature column is forecast, even when the
    # CSV has several channels — confirm this is intended.
    ts_series = pd.Series(test_data[:, 0].astype("float32"), index=dates)

    test_ds = PandasDataset({"item_0": ts_series}, freq="S")

    # Only the test template is needed; the training part is discarded.
    _, test_template = split(test_ds, offset=0)

    test_data_rolled = test_template.generate_instances(
        prediction_length=1,
        windows=len(ts_series),
        distance=1,
    )

    # ---------------------------------------
    # 3) Load model (pretrained Moirai 1.0-R)
    # ---------------------------------------
    model = MoiraiForecast(
        module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.0-R-{args.model_size}"),
        prediction_length=1,
        context_length=args.ctx,
        patch_size=64,
        num_samples=100,
        target_dim=1,
        feat_dynamic_real_dim=0,
        past_feat_dynamic_real_dim=0,
    )

    predictor = model.create_predictor(batch_size=args.batch_size)

    # ---------------------------------------
    # 4) Prediction & Anomaly Score (MSE)
    # ---------------------------------------
    forecasts = predictor.predict(test_data_rolled.input)
    print("[INFO] Predicting...")
    rows = []
    n_ch = 0
    for fct, lbl in zip(forecasts, test_data_rolled.label):
        # Flatten to 1-D; presumably one value here, since prediction_length=1
        # and target_dim=1 — TODO confirm against the forecast type.
        y_pred = fct.mean.reshape(-1)
        y_true = lbl["target"].reshape(-1)

        # Squared errors averaged over all values give a single score.
        err = np.mean((y_true - y_pred) ** 2)
        # Order must match the column names below: true values, predictions, score.
        rows.append(y_true.tolist() + y_pred.tolist() + [err])
        n_ch = y_pred.shape[0]

    if not rows:
        raise RuntimeError("No forecast windows were produced; nothing to score.")

    res_df = pd.DataFrame(rows)
    if n_ch == 1:
        res_df_col = ['true_y', 'pred_y', 'mse_score']
    else:
        res_df_col = [f"true_ch{ch}" for ch in range(n_ch)] + [f"pred_ch{ch}" for ch in range(n_ch)] + ['mse_score']
    res_df.columns = res_df_col

    # Row i holds the score of rolled window i; pandas raises ValueError if the
    # number of windows differs from the number of labels.
    res_df['label'] = label

    # ---------------------------------------
    # 5) Save results
    # ---------------------------------------
    out_dir = f'./results/Moirai_ZS_{args.model_size}/'
    # Create the directory so a fresh checkout does not fail after inference.
    os.makedirs(out_dir, exist_ok=True)
    variant = 'U' if n_ch == 1 else 'M'   # U = single value per step, M = several
    res_df.to_csv(out_dir + out_prefix + f'_{variant}_MOIRAI_ZS.csv')
    print("[INFO] Done predicting. Anomaly scores shape:", res_df.shape)

    print("Done.")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
