import math
import numpy as np
import pandas as pd
from numpy.lib.stride_tricks import sliding_window_view

def normalize_zscore(arr, norm_window=5, num_chunks=50):
    """Return the trailing rolling z-score of *arr* along axis 0.

    Row ``t`` of the input is normalised with the mean and the sample
    standard deviation (``ddof=1``) of the ``norm_window`` rows ending at
    ``t`` (inclusive). ``1e-8`` is added to the standard deviation before
    dividing. The first ``norm_window - 1`` rows have no full window and
    are dropped, so the result has ``len(arr) - norm_window + 1`` rows
    (zero rows when the input is shorter than the window).

    The work is done in chunks of at most 100000 rows (roughly
    ``len(arr) / num_chunks`` rows each) to bound the size of the
    temporaries created from the sliding-window view.

    Args:
        arr: Array exposing ``shape`` and row slicing; rolled along axis 0.
        norm_window: Number of rows in each rolling window.
        num_chunks: Target number of chunks used to derive the chunk size.

    Returns:
        Float array holding the normalised rows from index
        ``norm_window - 1`` onwards.
    """
    print(f"Window: {norm_window}")
    n_rows = arr.shape[0]
    # Clamp to >= 1 so that range() receives a valid step for empty input.
    chunk_size = max(1, min(100000, math.ceil(n_rows / num_chunks)))  # reduces memory requirement
    new_arr = np.zeros_like(arr, dtype=float)

    for i in range(0, n_rows, chunk_size):
        if (i // chunk_size) % 50 == 0:
            print(i)
        # Reach back norm_window - 1 rows so the first row of this chunk
        # still gets a complete window.
        chunk_start = max(i - (norm_window - 1), 0)
        chunk_end = min(i + chunk_size, n_rows)
        chunk = arr[chunk_start:chunk_end]
        if len(chunk) < norm_window:
            # Only happens at the very start of the data: none of these rows
            # has a full window yet, and sliding_window_view would raise.
            continue
        chunk_rolled = sliding_window_view(chunk, window_shape=(norm_window,), axis=0)
        chunk_mu = chunk_rolled.mean(axis=-1)
        chunk_std = chunk_rolled.std(ddof=1, axis=-1)
        chunk_norm = (chunk[norm_window - 1:] - chunk_mu) / (chunk_std + 1e-8)
        new_arr[chunk_start + norm_window - 1:chunk_end] = chunk_norm

    # Drop the leading rows that never had a complete window.
    return new_arr[norm_window - 1:]

def returns_rolling_norm(arr, k=5, rw=20, chunk_size=70000):
    """Normalise *arr* with rolling statistics lagged by ``k`` rows.

    Row ``t`` is normalised with the mean and sample standard deviation
    (``ddof=1``) of the ``rw`` rows ending at ``t - k`` (inclusive), so the
    statistics never include the ``k`` most recent rows before ``t`` nor
    ``t`` itself when ``k > 0``. Where the window's standard deviation is
    within ``1e-2`` of zero the raw value is passed through unchanged.
    Rows without a complete lagged window (the first ``rw - 1 + k`` rows)
    stay NaN; the output has the same shape as the input.

    Args:
        arr: Array rolled along axis 0.
        k: Lag (forecast horizon) between the window end and the
            normalised row. Must be non-negative; ``0`` uses the window
            ending at the row itself.
        rw: Number of rows in each rolling window.
        chunk_size: Number of output rows produced per iteration; must be
            larger than ``rw``.

    Returns:
        Float array shaped like *arr* with NaN padding at the start.

    Raises:
        ValueError: If ``k`` is negative.
    """
    print(f"Window: {rw}, Forecast Horizon: {k}")
    assert chunk_size > rw
    if k < 0:
        # A negative lag would make the statistics look ahead of the row.
        raise ValueError("k must be non-negative")
    new_arr = np.full_like(arr, np.nan, dtype=float)  # Initialize with NaNs for padding cases

    n_rows = len(arr)
    # Offset, inside a chunk, of the first row that has a full lagged window.
    valid_start = rw - 1 + k

    for i in range(0, n_rows, chunk_size):
        if (i // chunk_size) % 50 == 0:
            print(f"Processing chunk starting at index {i}")

        # Reach back far enough that row i itself has a full lagged window.
        chunk_start = max(i - valid_start, 0)
        chunk_end = min(i + chunk_size, n_rows)
        chunk = arr[chunk_start:chunk_end]

        # Rows of this chunk that can actually be normalised.
        n_valid = len(chunk) - valid_start
        if n_valid <= 0:
            # Too little history (also covers len(chunk) < rw, where
            # sliding_window_view would raise); leave the NaN padding.
            continue

        chunk_rolled = sliding_window_view(chunk, window_shape=(rw,), axis=0)
        # Keep the first n_valid windows: window j ends k rows before target
        # j. Slicing by count instead of [:-k] keeps k == 0 working.
        chunk_mu = chunk_rolled.mean(axis=-1)[:n_valid]
        chunk_std = chunk_rolled.std(ddof=1, axis=-1)[:n_valid]
        targets = chunk[valid_start:]

        # Pass the raw value through where the window is (nearly) constant.
        chunk_norm = np.where(np.isclose(chunk_std, 0, atol=1e-2), targets,
                              (targets - chunk_mu) / (chunk_std + 1e-8))

        result_start = chunk_start + valid_start
        new_arr[result_start:result_start + n_valid] = chunk_norm

    return new_arr

def welford(dataset, chunk_size=10000):
    """Compute the per-column mean and sample std of *dataset* chunk by chunk.

    The data is walked in slices of ``chunk_size`` rows along axis 0. Each
    slice contributes its own mean and sum of squared deviations, which are
    merged into running totals, so only one slice at a time is reduced.

    Args:
        dataset: Array supporting row slicing and ``mean``/``var`` along
            axis 0.
        chunk_size: Number of rows reduced per step.

    Returns:
        ``(mean, std)`` along axis 0, with ``std`` using ``ddof=1``.
    """
    running_mean = None
    sum_sq_dev = None
    seen = 0
    total_rows = len(dataset)

    for start in range(0, total_rows, chunk_size):
        block = dataset[start:min(start + chunk_size, total_rows)]
        block_mean = block.mean(axis=0)
        block_n = len(block)
        # Sum of squared deviations of this block around its own mean.
        block_m2 = block.var(axis=0, ddof=0) * block_n

        if running_mean is None:
            # First block seeds the running totals.
            running_mean, sum_sq_dev = block_mean, block_m2
        else:
            shift = block_mean - running_mean
            merged_n = seen + block_n

            # Pairwise merge of (count, mean, M2) of the two groups.
            running_mean += shift * block_n / merged_n
            sum_sq_dev += block_m2 + shift**2 * seen * block_n / merged_n

        seen += block_n

    print('std of welford is ddof=1')
    sample_var = sum_sq_dev / (seen - 1)

    return running_mean, sample_var**0.5

def mean_and_std(dataset, chunk_size=10000):
    """Compute the mean and sample std over every value in *dataset*.

    The dataset is viewed as one flat sequence and reduced in slices of
    ``chunk_size`` values; each slice's mean and sum of squared deviations
    are merged into running totals. The dataset must contain at least one
    value.

    Args:
        dataset: Array-like of numbers; all elements are pooled together.
        chunk_size: Number of values reduced per step.

    Returns:
        ``(mean, std)`` as scalars, with ``std`` using ``ddof=1``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")

    # np.ravel returns a view when the memory layout allows it, avoiding the
    # unconditional full copy made by ndarray.flatten().
    values = np.ravel(dataset)
    means = None
    M2 = None
    n = 0
    print(f'chunk size: {chunk_size}')
    for i in range(0, len(values), chunk_size):
        chunk = values[i:i + chunk_size]
        chunk_mean = chunk.mean()
        chunk_count = len(chunk)
        # Sum of squared deviations of this slice around its own mean.
        chunk_var = chunk.var(ddof=0) * chunk_count

        if means is None:
            means = chunk_mean
            M2 = chunk_var
        else:
            delta = chunk_mean - means
            total_count = n + chunk_count

            # Pairwise merge of (count, mean, M2) of the two groups.
            means += delta * chunk_count / total_count
            M2 += chunk_var + delta**2 * n * chunk_count / total_count

        n += chunk_count

    variance = M2 / (n - 1)
    std = variance**0.5

    print(f"number of values in dataset: {n}")

    return means, std

def rolling_normalization_exclude_current_chunk(data, chunk_size=10000, window_size=50000):
    """Normalise each chunk of *data* with statistics of the rows before it.

    The rows are processed in consecutive chunks of ``chunk_size``. Every
    chunk is shifted and scaled by the per-column mean and std computed by
    ``welford`` over up to ``window_size`` rows that immediately precede
    the chunk; the chunk's own rows never enter the statistics. A chunk
    with no preceding rows (the first one) is filled with NaN.

    Args:
        data: ``pandas.DataFrame``, or anything ``pd.DataFrame`` accepts.
        chunk_size: Number of rows normalised with one set of statistics.
        window_size: Maximum number of preceding rows used for the
            statistics.

    Returns:
        Float ``DataFrame`` with the same index and columns as the input
        frame.
    """
    if not isinstance(data, pd.DataFrame):
        data = pd.DataFrame(data)

    # Pre-allocated result frame; every chunk is written below.
    normalized_data = pd.DataFrame(index=data.index, columns=data.columns, dtype=float)

    num_timesteps = len(data)
    for start_idx in range(0, num_timesteps, chunk_size):
        window_start = max(0, start_idx - window_size)
        current_chunk_end = min(start_idx + chunk_size, num_timesteps)

        # History ends where the current chunk begins, so the chunk is
        # excluded from its own statistics.
        rolling_window = data.iloc[window_start:start_idx]
        current_chunk = data.iloc[start_idx:current_chunk_end]

        if rolling_window.empty:
            # No history available: mark the chunk as missing.
            normalized_data.iloc[start_idx:current_chunk_end] = np.nan
            continue

        mean, std = welford(rolling_window.values, 5000)
        normalized_data.iloc[start_idx:current_chunk_end] = (
            current_chunk - mean) / std

    return normalized_data

def rolling_normalization_exclude_current_chunk_shift(
    data, chunk_size=86400, window_size=5*86400, horizon=5
):
    """Normalise each chunk of *data* with lagged statistics of earlier rows.

    The rows are processed in consecutive chunks of ``chunk_size``. For a
    chunk starting at row ``s`` the per-column mean and std are computed by
    ``welford`` over rows ``[max(0, s - window_size), max(0, s - horizon + 1))``,
    i.e. the preceding window with its most recent ``horizon - 1`` rows
    left out. The chunk's own rows never enter the statistics. A chunk
    whose window is empty is filled with NaN.

    Args:
        data: ``pandas.DataFrame``, or anything ``pd.DataFrame`` accepts.
        chunk_size: Number of rows normalised with one set of statistics.
        window_size: How far back, in rows, the window starts before the
            chunk.
        horizon: The last ``horizon - 1`` rows before the chunk are
            excluded from the window.

    Returns:
        Float ``DataFrame`` with the same index and columns as the input
        frame.
    """
    if not isinstance(data, pd.DataFrame):
        data = pd.DataFrame(data)

    normalized_data = pd.DataFrame(index=data.index, columns=data.columns, dtype=float)
    num_timesteps = len(data)

    for start_idx in range(0, num_timesteps, chunk_size):
        window_start = max(0, start_idx - window_size)
        # End of the window: stops horizon - 1 rows short of the chunk.
        exclude_recent = max(0, start_idx - horizon + 1)
        current_chunk_end = min(start_idx + chunk_size, num_timesteps)

        rolling_window = data.iloc[window_start:exclude_recent]
        current_chunk = data.iloc[start_idx:current_chunk_end]

        if rolling_window.empty:
            # No usable history: mark the chunk as missing.
            normalized_data.iloc[start_idx:current_chunk_end] = np.nan
            continue

        mean, std = welford(rolling_window.values, 5000)
        normalized_data.iloc[start_idx:current_chunk_end] = (
            current_chunk - mean
        ) / std

    return normalized_data
