#!/usr/bin/env python3

import numpy as np
import json
from pathlib import Path

# Absolute directory containing this file; the dataset paths used in this
# module are built relative to it rather than the current working directory.
script_dir = Path(__file__).resolve().parent

def slice_data(dataset, observation_length, prediction_horizon, overlap, force_recreate=False, dtype=np.float32, normaliser=lambda x: x, reduced=False):
    """
    Return sliding windows (X, Y) for a dataset, reusing an on-disk cache.

    The raw series is read from ``data/<dataset>/raw/raw.npy`` and the cache
    (``X.npy``, ``Y.npy`` and ``last_config.json``) lives in
    ``data/<dataset>/sliced/``, both relative to this script's directory.
    The cache is reused only when all three files exist, ``force_recreate``
    is false and the stored configuration equals the current one.

    Args:
        dataset: name of the dataset directory under ``data/``.
        observation_length: number of observed time steps per window.
        prediction_horizon: number of target time steps per window.
        overlap: number of points shared by consecutive windows. A negative
            value is added to ``observation_length`` (e.g. -1 becomes
            ``observation_length - 1``).
        force_recreate: if true, ignore any cache and rebuild the windows;
            the rebuilt arrays are returned but NOT written to disk.
        dtype: dtype the raw data (or the cached arrays) is cast to.
        normaliser: callable applied to the raw array before windowing.
        reduced: forwarded to ``make_windows``; selects whether X carries the
            extra zero-filled horizon columns.

    Returns:
        Tuple ``(X, Y)`` of windowed inputs and targets.
    """
    if overlap < 0:
        overlap = observation_length + overlap

    save_path = script_dir / f"data/{dataset}/sliced/"
    save_path.mkdir(parents=True, exist_ok=True)

    # Paths for the cached artefacts
    config_path = save_path / "last_config.json"
    x_file = save_path / "X.npy"
    y_file = save_path / "Y.npy"

    # Everything that determines the cached arrays and can be serialised.
    # `reduced` changes the shape of X, so it must be part of the cache key;
    # caches written without it no longer compare equal and are rebuilt once.
    # NOTE(review): `normaliser` also affects the result but is an arbitrary
    # callable and is not part of the key -- pass force_recreate=True (or
    # clear the cache) when switching normalisers.
    current_config = {
        "observation_length": observation_length,
        "prediction_horizon": prediction_horizon,
        "overlap": overlap,
        "reduced": bool(reduced),
    }

    # Check if the configuration matches the last used one and if files exist
    cache_present = config_path.exists() and x_file.exists() and y_file.exists()
    if cache_present and not force_recreate:
        with open(config_path, "r", encoding="utf-8") as f:
            last_config = json.load(f)
        if last_config == current_config:
            print("Configuration matches the last used one, and data files exist. Skipping data recreation.")
            return np.load(x_file).astype(dtype), np.load(y_file).astype(dtype)
        print("Configuration does not match. Recreating data...")
    else:
        print("Configuration file or data files missing. Recreating data...")

    # Proceed to recreate the data
    data = np.load(script_dir / f"data/{dataset}/raw/raw.npy").astype(dtype)
    X, Y = make_windows(data, OL=observation_length, H=prediction_horizon, overlap=overlap, normaliser=normaliser, reduced=reduced)

    # NOTE(review): a forced rebuild leaves the existing cache untouched --
    # confirm this is intended rather than "always refresh the cache".
    if not force_recreate:
        np.save(x_file, X)
        np.save(y_file, Y)
        # Write the configuration last so it describes the arrays just saved.
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(current_config, f)

    return X, Y


def make_windows(
        data: np.ndarray,
        OL: int,
        H: int,
        overlap: int = 10,
        normaliser=lambda x: x,
        reduced=False
):
    """
    Cut a multivariate series into sliding observation/target windows.

    Window ``i`` starts at column ``i * (OL - overlap)``. Its first ``OL``
    columns become the input and the following ``H`` columns the target.
    ``normaliser`` is applied once to the whole array before slicing.

    Args:
        data: np.ndarray, shape (n_vars, T)
        OL: observation length (columns per input window)
        H: horizon (columns per target window)
        overlap: number of points overlapping between consecutive windows;
            must be smaller than OL
        normaliser: callable mapping ``data`` to the array that is actually
            sliced (identity by default)
        reduced: if true, X holds only the OL observed columns; otherwise X
            is extended with H trailing zero columns

    Returns:
        X: np.ndarray, shape (n_slices, n_vars, OL) if reduced,
           else (n_slices, n_vars, OL + H) with the last H columns zero
        Y: np.ndarray, shape (n_slices, n_vars, H)
        Both arrays use ``data.dtype``.

    Raises:
        ValueError: if ``overlap >= OL`` or the series is too short to hold
            a single window of OL + H columns.
    """
    n_vars, T = data.shape

    stride = OL - overlap
    if stride <= 0:
        raise ValueError("Overlap must be less than OL")

    # Number of complete (OL + H)-wide windows reachable with this stride
    n_slices = (T - OL - H) // stride + 1
    if n_slices <= 0:
        raise ValueError("Time series too short for given OL and H")

    scaled = normaliser(data)

    # Column indices of every window, one row per window
    starts = stride * np.arange(n_slices)
    obs_cols = starts[:, None] + np.arange(OL)          # (n_slices, OL)
    tgt_cols = starts[:, None] + np.arange(OL, OL + H)  # (n_slices, H)

    # Zero-initialised so that the optional horizon padding is already in place
    width = OL if reduced else OL + H
    X = np.zeros((n_slices, n_vars, width), dtype=data.dtype)
    # Gathering yields (n_vars, n_slices, OL); swap to (n_slices, n_vars, OL)
    X[:, :, :OL] = np.swapaxes(scaled[:, obs_cols], 0, 1)

    Y = np.empty((n_slices, n_vars, H), dtype=data.dtype)
    Y[...] = np.swapaxes(scaled[:, tgt_cols], 0, 1)

    return X, Y



