import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


def generate_linear_trend(num_points, slope, bias):
    """
    Return a straight line ``bias + slope * t`` sampled at t = 0, 1, ..., num_points - 1.

    Parameters:
        num_points (int): Number of samples.
        slope: Increment per step.
        bias: Value at t = 0.

    Returns:
        np.ndarray: The sampled line (integer dtype when slope and bias are ints).
    """
    return slope * np.arange(num_points) + bias


def generate_const_trend(num_points, level):
    """
    Return a flat segment of ``num_points`` samples, each equal to ``level``.

    Built by scaling an array of ones, so a real scalar ``level`` yields a
    float64 array.
    """
    unit = np.ones(num_points)
    return level * unit


def generate_trend_from_segments(segments):
    """
    Concatenate piecewise trend segments into one signal.

    Parameters:
        segments (iterable of (callable, sequence)): Each item is
            ``(callback, args)``; ``callback(*args)`` must return an
            array-like segment. Segments are joined in iteration order.

    Returns:
        y (np.ndarray): All segments joined end to end. When ``segments`` is
            empty, an empty float array is returned instead of letting
            ``np.concatenate`` fail on an empty list.
    """
    ys = [callback(*args) for callback, args in segments]
    if not ys:
        return np.empty(0)
    return np.concatenate(ys)


def generate_ar1_noise(num_points, phi=0.8, sigma=0.2, seed=None):
    """
    Generates AR(1) noise: e_t = phi * e_{t-1} + N(0, sigma^2).

    The first sample is a single innovation N(0, sigma^2) (no warm-up).

    Parameters:
        num_points (int): Number of samples to generate.
        phi (float): Coefficient applied to the previous sample.
        sigma (float): Standard deviation of each Gaussian innovation.
        seed (int): If given, draws come from a private
            ``np.random.RandomState(seed)``. This yields the same values as
            the legacy ``np.random.seed(seed)`` path without resetting NumPy's
            global random state. If None, the global NumPy random state is
            used.

    Returns:
        noise (np.ndarray): Float array of length ``num_points``; empty when
            ``num_points`` is 0.
    """
    # ``np.random`` (the module) draws from the global legacy RandomState, so
    # both branches expose the same ``normal`` API and stream semantics.
    rng = np.random if seed is None else np.random.RandomState(seed)
    noise = np.zeros(num_points)
    if noise.size == 0:
        return noise
    noise[0] = rng.normal(0, sigma)
    # Sequential recurrence: each sample depends on the one before it.
    for t in range(1, num_points):
        noise[t] = phi * noise[t - 1] + rng.normal(0, sigma)
    return noise


def generate_grouped_noise(
    num_points, n_groups=3, group_means=None, group_sigmas=None, seed=None
):
    """
    Generates Gaussian noise in contiguous groups.

    The output is split into ``n_groups`` contiguous blocks whose sizes differ
    by at most one; the first ``num_points % n_groups`` blocks get the extra
    point. Block ``i`` is drawn from N(group_means[i], group_sigmas[i]^2).

    Parameters:
        num_points (int): Total number of points.
        n_groups (int): Number of groups; must be at least 1.
        group_means (list or array): Means for each group. If None, defaults to equally spaced means from -0.5 to 0.5.
        group_sigmas (list or array): Standard deviations for each group. If None, defaults to 0.2 for every group.
        seed (int): Random seed for reproducibility. When given, draws come
            from a private ``np.random.RandomState(seed)``, which yields the
            same values as the legacy ``np.random.seed(seed)`` path, and
            NumPy's global random state is left untouched. When None, the
            global NumPy random state is used.

    Returns:
        noise (np.ndarray): An array of Gaussian noise with group-specific parameters.

    Raises:
        ValueError: If ``n_groups`` is less than 1.
        IndexError: If ``group_means`` or ``group_sigmas`` has fewer than
            ``n_groups`` entries.
    """
    if n_groups < 1:
        raise ValueError(f"n_groups must be at least 1, got {n_groups}")
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Contiguous block sizes: spread the remainder over the leading groups.
    base, remainder = divmod(num_points, n_groups)
    group_sizes = [base + 1 if i < remainder else base for i in range(n_groups)]

    # Default group means and sigmas, if not provided
    if group_means is None:
        group_means = np.linspace(-0.5, 0.5, n_groups)
    if group_sigmas is None:
        group_sigmas = [0.2] * n_groups

    noise = np.zeros(num_points)
    start = 0
    # Index (rather than zip) so short mean/sigma lists fail loudly instead of
    # silently leaving trailing groups at zero.
    for i, size in enumerate(group_sizes):
        end = start + size
        noise[start:end] = rng.normal(
            loc=group_means[i], scale=group_sigmas[i], size=size
        )
        start = end
    return noise


def add_outliers(signal, outlier_frac=0.01, outlier_scale=2.0, seed=None):
    """
    Adds random outliers to a copy of the signal.

    ``int(len(signal) * outlier_frac)`` distinct positions are chosen uniformly
    at random. Each receives an additive offset drawn from
    N(0, outlier_scale^2). The input is not modified.

    Parameters:
        signal (array-like): 1-D input signal.
        outlier_frac (float): Fraction of samples to perturb, in [0, 1].
        outlier_scale (float): Standard deviation of the additive offsets.
        seed (int): If given, draws come from a private
            ``np.random.RandomState(seed)``, which yields the same values as
            the legacy ``np.random.seed(seed)`` path, and NumPy's global
            random state is left untouched. If None, the global NumPy random
            state is used.

    Returns:
        np.ndarray: The perturbed copy. Integer/bool inputs are promoted to
            float so the offsets are not truncated.

    Raises:
        ValueError: If ``outlier_frac`` is not within [0, 1].
    """
    if not 0 <= outlier_frac <= 1:
        raise ValueError(f"outlier_frac must be in [0, 1], got {outlier_frac}")
    rng = np.random if seed is None else np.random.RandomState(seed)

    signal_with_outliers = np.array(signal, copy=True)
    if not np.issubdtype(signal_with_outliers.dtype, np.inexact):
        # Fancy-index assignment casts back to the array dtype, which would
        # silently truncate the float offsets for integer input.
        signal_with_outliers = signal_with_outliers.astype(float)

    n = len(signal_with_outliers)
    n_outliers = int(n * outlier_frac)
    outlier_indices = rng.choice(n, size=n_outliers, replace=False)
    outliers = rng.normal(0, outlier_scale, size=n_outliers)
    signal_with_outliers[outlier_indices] += outliers
    return signal_with_outliers


def _plot_signals(trend_signal, trend_plus_noise, full_signal, path="toy.png"):
    """
    Plot the trend, the noisy trend and the final signal on three stacked
    axes sharing the x-axis, then save the figure to ``path`` and close it.
    """
    # (data, legend label, color, axis title) for each panel, top to bottom.
    panels = [
        (trend_signal, "Trend Signal", "C0", "Trend Signal"),
        (
            trend_plus_noise,
            "Trend + AR(1) Noise + Grouped Noise",
            "C1",
            "Trend with Combined Noise",
        ),
        (full_signal, "Full Signal with Outliers", "C2", "Final Signal"),
    ]
    fig, axes = plt.subplots(len(panels), 1, figsize=(10, 10), sharex=True)
    for ax, (data, label, color, title) in zip(axes, panels):
        ax.plot(data, label=label, color=color)
        ax.set_title(title)
        ax.legend(loc="upper right")
    fig.tight_layout()
    fig.savefig(path)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)


def _save_lagged_csv(series, path="data.csv"):
    """
    Write lag-1 pairs of ``series`` to ``path`` as CSV.

    Row ``t`` holds ``Y = series[t + 1]`` and ``X = series[t]``. The DataFrame
    index is written as the first column (the pandas ``to_csv`` default).
    """
    df = pd.DataFrame(
        data=np.stack([series[1:], series[:-1]], axis=1),
        columns=["Y", "X"],
    )
    df.to_csv(path)


def main():
    """
    Build a synthetic series: a piecewise trend plus AR(1) noise, grouped
    Gaussian noise and sparse outliers.

    Saves a three-panel plot to ``toy.png`` and writes lag-1 pairs of the
    outlier-free noisy series to ``data.csv``.
    """
    # One cycle: level -1, level 4, a falling ramp starting at 4, then a rising
    # ramp starting at -4. The cycle is repeated 5 times.
    segments = [
        (generate_const_trend, [80, -1]),
        (generate_const_trend, [80, 4]),
        (generate_linear_trend, [80, -0.1, 4]),
        (generate_linear_trend, [40, 0.2, -4]),
    ] * 5

    # Generate piecewise trend signal; its length sizes every noise component.
    trend_signal = generate_trend_from_segments(segments)
    num_points = len(trend_signal)

    # Generate AR(1) noise.
    ar1_noise = generate_ar1_noise(num_points, phi=0.8, sigma=0.2, seed=42)

    # Generate grouped Gaussian noise; you can adjust n_groups and parameters.
    grouped_noise = generate_grouped_noise(
        num_points,
        n_groups=3,
        group_means=[-0.2, 0.0, 0.2],
        group_sigmas=[0.1, 0.3, 0.2],
        seed=24,
    )

    # Add the combined noise to the trend signal.
    trend_plus_noise = trend_signal + (ar1_noise + grouped_noise)

    # Add outliers.
    full_signal = add_outliers(
        trend_plus_noise, outlier_frac=0.02, outlier_scale=2.0, seed=123
    )

    _plot_signals(trend_signal, trend_plus_noise, full_signal)

    # NOTE(review): the CSV is built from trend_plus_noise (no outliers), not
    # from full_signal; confirm this is the intended training data.
    _save_lagged_csv(trend_plus_noise)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
