#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import os
import pandas as pd
import numpy as np
import torch

# Chronos import
from chronos import BaseChronosPipeline

def parse_args():
    """
    Parse the command-line options of this script from sys.argv.

    return:
      argparse.Namespace with the attributes
        * data_direc : path to the input CSV (required)
        * model_name : Chronos model name on the HF Hub
        * ctx        : context (history) length of each sliding window
        * device_map : 'cpu' or 'cuda'
        * batch_size : number of windows sent to the model per call
    """
    cli = argparse.ArgumentParser()

    # (flag, keyword arguments) pairs, registered in this order so that the
    # generated --help output keeps its layout.
    option_specs = (
        ('--data_direc', dict(
            type=str, required=True,
            help='Path to the CSV file that contains the timeseries data with label in the last column.')),
        ('--model_name', dict(
            type=str, default='amazon/chronos-bolt-base',
            help='Chronos model name on HF Hub. e.g. amazon/chronos-bolt-base')),
        ('--ctx', dict(
            type=int, default=30,
            help='Context (history) length for sliding window. Default=30')),
        ('--device_map', dict(
            type=str, default='cuda', choices=['cpu', 'cuda'],
            help='Device for inference. Default=cuda')),
        ('--batch_size', dict(
            type=int, default=512,
            help='Batch size for inference. Default=512')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

def create_sliding_windows(data: np.ndarray, context_length: int):
    """
    Cut a multichannel series into overlapping history windows and the value
    that immediately follows each window.

    data: shape (n_steps, n_channels)
    context_length: size of the past window (must be >= 0)
    return:
      windows: shape (num_samples, context_length, n_channels)
      targets: shape (num_samples, n_channels)
      with num_samples = max(n_steps - context_length, 0) and
        windows[i] -> data[i : i+context_length, :]
        targets[i] -> data[i+context_length, :]
      When the series is too short, both arrays are empty (size == 0) but
      still carry the trailing dimensions and dtype of the non-empty case.
      Both arrays are fresh copies; `data` is not modified or aliased.
    raises:
      ValueError: if context_length is negative, or if data is not 2-D
                  (raised by the shape unpacking).
    """
    n_steps, n_channels = data.shape
    if context_length < 0:
        raise ValueError(f"context_length must be non-negative, got {context_length}")

    num_samples = n_steps - context_length
    if num_samples <= 0:
        # Keep the documented rank so callers can rely on .shape as well as .size.
        empty_windows = np.empty((0, context_length, n_channels), dtype=data.dtype)
        empty_targets = np.empty((0, n_channels), dtype=data.dtype)
        return empty_windows, empty_targets

    # row_idx[i, j] = i + j : the time step of the j-th element of window i.
    # Indexing with this 2-D integer array gathers every window in one call.
    row_idx = np.arange(num_samples)[:, None] + np.arange(context_length)[None, :]
    windows = data[row_idx]                  # (num_samples, context_length, n_channels)
    targets = data[context_length:].copy()   # (num_samples, n_channels)
    return windows, targets


def build_chronos_df_for_inference(channel_data, context_length):
    """
    Flatten a stack of single-channel sliding windows into a long-form
    DataFrame with one row per (window, time step) pair.

    channel_data: indexable as channel_data[w, t], shape (num_windows, >= context_length)
    context_length: number of leading time steps taken from every window
    return: pd.DataFrame with columns=["item_id", "timestamp", "target"]
      * item_id   : index of the sliding window the row belongs to
      * timestamp : 0, 1, 2... (position inside the window)
      * target    : channel_data[item_id, timestamp]
    Rows are ordered window by window, time steps ascending inside each window.
    """
    window_count = channel_data.shape[0]
    long_rows = [
        [window_id, step, channel_data[window_id, step]]
        for window_id in range(window_count)
        for step in range(context_length)
    ]
    return pd.DataFrame(long_rows, columns=["item_id", "timestamp", "target"])


def main():
    """
    Score a CSV time series with zero-shot Chronos one-step-ahead forecasts.

    Steps:
      1) read the CSV at --data_direc, drop rows containing NaN and split off
         the 'Label' column
      2) build sliding windows of length --ctx over the whole series
      3) load the Chronos pipeline named by --model_name
      4) forecast the step after every window, one channel at a time and in
         batches of --batch_size, then take the squared error per window
      5) write predictions, ground truth, score and label to
         ./results/Chronos_ZS_<size>/<id>_<U|M>_CHRONOS_ZS.csv
         (the directory is created when missing)

    raises:
      KeyError:   if the CSV has no 'Label' column.
      IndexError: if the path has fewer than three '_'-separated parts.
      ValueError: if the third-from-last '_'-separated part of the path is not
                  an integer, or if --batch_size is not positive.
    """
    args = parse_args()

    # -------------------------
    # 1) load csv file
    # -------------------------
    df = pd.read_csv(args.data_direc).dropna()
    # 'Label' is the anomaly flag -> exclude it from the time series features
    data = df.drop(['Label'], axis=1).values.astype(float)
    label = df['Label'].values.astype(float)

    # The path is split on '_' once; the pieces are reused for the train index
    # here and for the short id of the output file below.
    path_parts = args.data_direc.split('_')

    # File name must contain "_<train_index>_" → third piece from the end
    train_index = int(path_parts[-3])

    # (train_data, test_data) split. train_data is only reported in the log;
    # scoring runs over the full series.
    train_data = data[:train_index, :]   # shape = (train_steps, n_channels)
    test_data = data                     # shape = (total_steps, n_channels)
    n_channels = test_data.shape[1]

    print(f"[INFO] Loaded data: total shape={data.shape}, label.shape={label.shape}")
    print(f"[INFO] Train shape={train_data.shape}, Test shape={test_data.shape}")

    # -------------------------
    # 2) Sliding windows over the full series
    # -------------------------
    # create_sliding_windows -> (num_samples, ctx, n_channels), (num_samples, n_channels)
    test_windows, test_targets = create_sliding_windows(test_data, args.ctx)
    if test_windows.size == 0:
        print("[WARN] Not enough data for the given context_length.")
        return

    num_samples = test_windows.shape[0]
    print(f"[INFO] test_windows: {test_windows.shape}, test_targets: {test_targets.shape}")

    # Fail before the (slow) model load: a zero batch size would never advance
    # through the windows and a negative one would slice the wrong ranges.
    if args.batch_size <= 0:
        raise ValueError(f"--batch_size must be positive, got {args.batch_size}")

    # -------------------------
    # 3) Load Chronos model
    # -------------------------
    print(f"[INFO] Loading Chronos pipeline: {args.model_name}")
    pipeline = BaseChronosPipeline.from_pretrained(
        args.model_name,
        device_map=args.device_map,
        torch_dtype=torch.float32
    )
    pipeline.model.eval()

    # -------------------------
    # 4) Channel-wise prediction (in batches) + squared error
    # -------------------------
    # Each channel is forecast on its own:
    #   (num_samples, ctx, n_channels) -> per channel (num_samples, ctx)
    #   -> float32 tensor batches -> pipeline.predict_quantiles(prediction_length=1)
    #   -> the returned mean is used as the point forecast for the next step
    preds_matrix = np.zeros_like(test_targets)  # shape (num_samples, n_channels)

    for ch in range(n_channels):
        channel_windows = test_windows[:, :, ch]  # shape (num_samples, ctx)
        preds_for_ch = np.zeros(num_samples, dtype=np.float32)

        for start_idx in range(0, num_samples, args.batch_size):
            end_idx = min(start_idx + args.batch_size, num_samples)

            batch_input_torch = torch.tensor(
                channel_windows[start_idx:end_idx],  # shape (batch, ctx)
                dtype=torch.float32,
                device=args.device_map,
            )

            with torch.no_grad():
                _quantiles, mean = pipeline.predict_quantiles(batch_input_torch, prediction_length=1)

            # ravel() flattens the forecast to one value per window of the batch
            # (assumes mean holds batch * 1 elements — TODO confirm against Chronos docs)
            preds_for_ch[start_idx:end_idx] = mean.detach().cpu().numpy().ravel()

        preds_matrix[:, ch] = preds_for_ch

    # Squared error per channel, averaged over channels -> one score per window
    errors = (test_targets - preds_matrix) ** 2
    anomaly_scores = errors.mean(axis=1)

    # -------------------------
    # 5) Result DataFrame & CSV
    # -------------------------
    # Window i is scored against data[i + ctx], so its label is label[i + ctx];
    # with num_samples = total_steps - ctx those are the last num_samples labels.
    label_for_rolled = label[-num_samples:]

    n_ch = preds_matrix.shape[1]
    if n_ch == 1:
        pred_colnames = ["pred_y"]
        true_colnames = ["true_y"]
    else:
        pred_colnames = [f"pred_ch{ch}" for ch in range(n_ch)]
        true_colnames = [f"true_ch{ch}" for ch in range(n_ch)]

    # Column order of the CSV: predictions (all channels), ground truth
    # (all channels), mse_score, label.
    df_res = pd.concat(
        [
            pd.DataFrame(preds_matrix, columns=pred_colnames),
            pd.DataFrame(test_targets, columns=true_colnames),
            pd.DataFrame({"mse_score": anomaly_scores, "label": label_for_rolled}),
        ],
        axis=1,
    )

    # NOTE: every model name other than 'amazon/chronos-bolt-base' is filed
    # under 'small'.
    if args.model_name == 'amazon/chronos-bolt-base':
        model_size = 'base'
    else:
        model_size = 'small'

    # file name: last three characters of the ninth '_'-separated piece from
    # the end of the path, or "000" when the path has fewer pieces
    short_id = path_parts[-9][-3:] if len(path_parts) >= 9 else "000"
    if n_ch == 1:
        out_name = f"./results/Chronos_ZS_{model_size}/{short_id}_U_CHRONOS_ZS.csv"
    else:
        out_name = f"./results/Chronos_ZS_{model_size}/{short_id}_M_CHRONOS_ZS.csv"

    # to_csv does not create missing parent directories, so make sure the
    # results folder exists before writing.
    os.makedirs(os.path.dirname(out_name), exist_ok=True)
    df_res.to_csv(out_name, index=False)
    print(f"[INFO] Done. Saved results to {out_name}")

# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
