import numpy as np
import pandas as pd

# This version introduces a one-day lag in the price window.
def recover_price_matrix_lag(return_matrix, initial_price=100):
    """Rebuild a price matrix whose rows trail the returns by one step.

    Row 0 is ``initial_price`` in every column, and row ``t`` (``t >= 1``) is
    ``initial_price`` times the product of rows ``0 .. t-1`` of
    ``return_matrix``.  The last row of ``return_matrix`` is never used.
    Entries are multiplied together directly, i.e. they are treated as
    price ratios rather than percentage changes.

    Args:
        return_matrix: 2-D array-like of shape ``(T, N)``.
        initial_price: Value written to every column of row 0.

    Returns:
        A float ``numpy.ndarray`` of shape ``(T, N)``; empty when ``T == 0``.

    Raises:
        ValueError: If ``return_matrix`` is not two-dimensional.
    """
    returns = np.asarray(return_matrix)
    if returns.ndim != 2:
        raise ValueError(
            f"return_matrix must be 2-D (T, N), got {returns.ndim} dimension(s)"
        )

    prices = np.empty(returns.shape)
    # Slice assignment (rather than ``prices[0]``) keeps T == 0 from raising.
    prices[:1] = initial_price
    # Dropping the last return row is what produces the one-step lag.
    prices[1:] = initial_price * np.cumprod(returns[:-1], axis=0)
    return prices


def recover_price_matrix(return_matrix, initial_price=100):
    """Rebuild a price matrix by compounding returns from ``initial_price``.

    Row ``t`` of the result is ``initial_price`` times the product of rows
    ``0 .. t`` of ``return_matrix``, so row 0 already includes the first
    return.  Entries are multiplied together directly, i.e. they are treated
    as price ratios rather than percentage changes.

    Args:
        return_matrix: 2-D array-like of shape ``(T, N)``.  It is not modified.
        initial_price: Price in effect before the first return is applied.

    Returns:
        A float ``numpy.ndarray`` of shape ``(T, N)``; empty when ``T == 0``.

    Raises:
        ValueError: If ``return_matrix`` is not two-dimensional.
    """
    # ``np.array`` copies, so the in-place scaling below never touches the
    # caller's data.
    scaled = np.array(return_matrix, dtype=float)
    if scaled.ndim != 2:
        raise ValueError(
            f"return_matrix must be 2-D (T, N), got {scaled.ndim} dimension(s)"
        )

    # Folding the initial price into the first row lets one cumulative product
    # reproduce the recurrence prices[t] = prices[t - 1] * returns[t].  The
    # slice (not ``scaled[0]``) keeps T == 0 from raising.
    scaled[:1] *= initial_price
    return np.cumprod(scaled, axis=0)


def _window_prices(window_returns, base_price=100):
    """Compound one window of returns into prices that restart at ``base_price``."""
    # Copy first so the in-place scaling never writes into the caller's matrix.
    scaled = np.array(window_returns)
    # Folding the base price into the first row makes a single cumulative
    # product equal to price[0] = base * r[0], price[j] = price[j - 1] * r[j].
    scaled[:1] *= base_price
    return np.cumprod(scaled, axis=0)


def load_data(
    return_matrix, daily_date, win_size=30, lag_oneday=False, batch_norm=False
):
    """Slice a returns matrix into rolling-window samples.

    One sample is produced for every ``t`` in ``1 .. T-2``.  Its inputs cover
    rows ``max(0, t - win_size) .. t-1`` and its target is row ``t``.  Row
    ``T-1`` is never used as a target.

    Args:
        return_matrix: 2-D array of shape ``(T, N)``; entries are multiplied
            together directly when prices are built.
        daily_date: Sequence of length ``T`` aligned with the rows of
            ``return_matrix``.
        win_size: Maximum number of rows in each input window.
        lag_oneday: When ``batch_norm`` is false, build prices with
            ``recover_price_matrix_lag`` instead of ``recover_price_matrix``.
            It has no effect when ``batch_norm`` is true.
        batch_norm: When true, prices are recomputed inside every window
            starting from 100.  When false, windows are cut from one price
            matrix built over the whole history.

    Returns:
        A list of dicts with the keys ``"input_returns"`` and
        ``"input_prices"`` (lists of ``N`` one-dimensional arrays, one per
        column), ``"target_return"`` (row ``t``), ``"input_date_lists"`` and
        ``"output_date"``.
    """
    daily_return = np.asarray(return_matrix)
    # Unpacking also rejects anything that is not two-dimensional.
    n_steps, _n_assets = daily_return.shape

    # Without per-window re-basing, a single full-history price matrix serves
    # every window.
    full_price_matrix = None
    if not batch_norm:
        if lag_oneday:
            full_price_matrix = recover_price_matrix_lag(
                daily_return, initial_price=100
            )
        else:
            full_price_matrix = recover_price_matrix(daily_return, initial_price=100)

    processed_data_batch = []
    for t in range(1, n_steps - 1):
        start_idx = max(0, t - win_size)
        window_returns = daily_return[start_idx:t]
        if batch_norm:
            window_prices = _window_prices(window_returns)
        else:
            window_prices = full_price_matrix[start_idx:t]

        processed_data_batch.append(
            {
                # Iterating the transpose yields one 1-D series per column.
                "input_returns": list(window_returns.T),
                "input_prices": list(window_prices.T),
                "target_return": daily_return[t],
                "input_date_lists": list(daily_date[start_idx:t]),
                "output_date": daily_date[t],
            }
        )

    return processed_data_batch
