from generate_data import poisson_stream 
import numpy as np

# ---------------------------------------------------------------
# Simulation configuration (first simulation setting, d = 3)
# ---------------------------------------------------------------
dim = 3  # ambient dimension d; points live in [0, 1]^d

# Stream layout: the first N_train realizations are used for training and the
# stream contains N_total realizations overall (experiment defaults).
N_train, N_total = 1000, 2000

# Change-point location b: realizations from index `target` on are post-change.
target = 1200

# Location parameters referenced by the intensity functions:
# mu3 centres the post-change bump in lam2; mu1 and mu2 belong to the
# Gaussian-mixture prototype of lam1.
mu1, mu2, mu3 = 0.7, 0.3, 0.5
##############################

# =========================
# Pre-/post-change intensities
# =========================
def lam1(X):
    """
    Pre-change intensity λ1 (the d = 3 setting).

    With s = x1 + x2 + x3 (the sum of each row of X):
        λ1(x) = 5*(sin(s) + 1) + 5*(cos(s) + 1)

    Each row of X is one point in [0,1]^3; one intensity value is returned
    per row.

    An earlier Gaussian-mixture prototype used
        10*(exp(-||x - mu1||^2) + exp(-||x - mu2||^2))
    instead.
    """
    row_sum = X.sum(axis=1)
    sine_term = 5 * (np.sin(row_sum) + 1.0)
    cosine_term = 5 * (np.cos(row_sum) + 1.0)
    return sine_term + cosine_term

def lam2(X, amplitude=10):
    """
    Post-change intensity λ2: a Gaussian bump centred at (mu3, ..., mu3).

        λ2(x) = a * exp(-||x - mu3||^2),   with a = `amplitude`.

    Vectorized over the rows of X (one point per row); returns one intensity
    value per row.

    Parameters
    ----------
    X : 2-D array, one point per row.
    amplitude : float, optional
        Bump height `a`. The default of 10 reproduces the original setting.
        To change the signal strength without editing this function (e.g.
        a = 20), pass a bound callable such as
        ``functools.partial(lam2, amplitude=20)`` to the stream generator.
    """
    # Squared Euclidean distance of every row to the broadcast centre mu3.
    sq_dist = np.linalg.norm(X - mu3, axis=1) ** 2
    return amplitude * np.exp(-sq_dist)

# The training segment is drawn from λ1; the manuscript requires enough
# pre-change samples to initialize/tune the detector.

# Stream generator producing PPP realizations on [0,1]^d
generator = poisson_stream(dim)

# =========================
# Detector hyperparameters
# =========================
max_len = 100  # maximum lag/window length W considered at each time
min_lag = (max_len // 3) * 2  # minimal gap between past and recent windows

# Basis size M per dimension: 3 when d == 1, otherwise 2
# (M = 2 for d > 1 matches the robustness defaults).
m = 3 if dim == 1 else 2

# Two-group coordinate split for the tensor matricization (y|z split):
# coordinates 0 .. dim//2 - 1 go to group 0, the remaining ones to group 1.
index = [[], []]
for j in range(dim):
    index[int(j >= dim // 2)].append(j)

# Per-dimension basis sizes (uniform M), for the tensor/SVD routines
shapes = [m] * dim
# Target rank r for the restricted SVD step
rank = 3

from matrix_detection import matrix_detection

# =========================
# Monte Carlo replication loop
# =========================
RR = 100  # number of replications per configuration

# Counters for the Matrix detector (T = Tensor/Matrix method).
# Counters for the Mean (M_*) and MMD (MMD_*) baselines are declared as
# placeholders for consistency; this script does not update them.
T_count_false = 0
T_count_success = 0
T_delays = 0
T_count_missed = 0  # runs with no alarm by N_total (censored; also in T_count_success)
M_count_false = 0
M_count_success = 0
M_delays = 0
MMD_count_false = 0
MMD_count_success = 0
MMD_delays = 0

for rr in range(RR):
    print(rr)

    # -------------------------
    # Generate one full stream
    # -------------------------
    # Pre-change segment up to the change point (length = target)
    data = generator.generate(target, lam1, [])

    # Post-change segment after the change point (length = N_total - target)
    data = generator.generate(N_total - target, lam2, data)
    print('data')

    # --------------------------------
    # Initialize detector on training
    # --------------------------------
    # matrix_detection consumes:
    #   - dim, m, max_len, min_lag, rank: algorithm hyperparameters
    #   - data[:N_train]: training realizations
    #   - index: coordinate split for tensor-product basis/matricization
    detection_function = matrix_detection(dim, m, max_len, min_lag, rank, data[:N_train], index)

    # --------------------------------
    # Online detection phase
    # --------------------------------
    # Scan sequentially from the end of training; `detect` returns True at the
    # first alarm. If the scan finishes without an alarm, the run is censored
    # at N_total, so its delay is N_total - target. Reusing the last loop index
    # instead would record N_total - 1 - target and would make a censored run
    # look like a genuine alarm at the last step.
    alarm_time = N_total
    for i in range(N_train, N_total):
        if detection_function.detect(data[i]):
            alarm_time = i
            break

    # --------------------------------
    # Tally outcomes for the Matrix detector
    # --------------------------------
    # alarm_time < target -> false alarm (alarm before the change point).
    # Otherwise           -> detection with delay (alarm_time - target);
    #                        censored runs contribute N_total - target.
    if alarm_time < target:
        T_count_false += 1
    else:
        T_count_success += 1
        T_delays += alarm_time - target
        if alarm_time == N_total:
            T_count_missed += 1

    print('T false discoveries =', T_count_false)
    if T_count_success > 0:
        print('T average delay =', T_delays / T_count_success)
    print('T missed (censored at N_total) =', T_count_missed)
    print("\n")

# Notes:
# • To reproduce full FAP–ADD tradeoff curves, wrap this loop in an outer sweep
#   of the detector’s threshold factor (as implemented inside matrix_detection)
#   and record (FAP, ADD) per level, including the censoring rule when no alarm
#   occurs by N_total (delay = N_total - target).
# • Additional baselines (Mean, KIE, MMD) can be plugged in analogously with
#   their own detectors and counters.
