import pandas as pd
# import pickle5 as pickle
from multiprocessing import Pool
from numpy.lib.stride_tricks import sliding_window_view
import numpy as np
import math
import torch
from data_preprocessing.rolling_norms import welford

def flip_cols(df):
    """Reorder the first 40 columns of ``df`` level by level.

    For each even offset ``base`` in 0, 2, ..., 18 the output receives the
    input columns at positions ``base + 20``, ``base + 21``, ``base`` and
    ``base + 1`` (in that order), so the result has 40 columns.  Column labels
    are replaced by 0..39; the row index is kept.  The head and shape of the
    result are printed as a side effect.

    Raises:
        IndexError: if ``df`` has fewer than 40 columns.
    """
    # Positional permutation: second group of 20 columns first, then the first group,
    # interleaved two columns at a time.
    order = [
        position
        for base in range(0, 20, 2)
        for position in (base + 20, base + 21, base, base + 1)
    ]
    flipped = df.iloc[:, order].set_axis(np.arange(len(order)), axis=1)

    print(flipped.head())
    print(flipped.shape)
    return flipped


def calculate_labels(mid_prices, k, threshold):
    """Label each step by the relative move of the mean of the next ``k`` prices.

    For step ``t`` the change is ``(mean(p[t+1..t+k]) - p[t]) / p[t]``.

    Args:
        mid_prices: pandas Series (or single-column DataFrame) of prices.
        k: number of future steps averaged.
        threshold: relative-change threshold.

    Returns:
        Single-column float DataFrame with a fresh 0..n-1 index (not the index
        of ``mid_prices``):
        1 where change >= threshold, 3 where change <= -threshold, 2 otherwise,
        and NaN for the last ``k`` rows, which have no complete future window.
        Because the "up" assignment runs last, a change that satisfies both
        conditions (only possible with threshold <= 0) is labeled 1.
        If there are fewer than ``k`` prices, every row is NaN.
    """
    # Rolling mean ending at t+k, shifted back by k -> mean of p[t+1..t+k].
    future_mean = mid_prices.rolling(window=k).mean().shift(-k)
    changes = (future_mean - mid_prices) / mid_prices

    n = len(changes)
    labels = np.full(n, 2.0)
    # Clamp so inputs shorter than k yield all-NaN instead of a mis-sized array.
    labels[max(n - k, 0):] = np.nan
    # NaN changes compare False, so the NaN tail is never overwritten.
    labels[(changes <= -threshold).values.flatten()] = 3
    labels[(changes >= threshold).values.flatten()] = 1
    return pd.DataFrame(labels, dtype=float)

def rolling_norm_exclude_current_tstep(df, norm_window=2000, num_chunks=50):
    """Z-score each row against the ``norm_window`` rows strictly before it.

    Per column, row ``t`` becomes
    ``(x[t] - mean(x[t-w:t])) / (std(x[t-w:t], ddof=1) + 1e-8)``, so a time
    step never contributes to its own statistics.  Rows are processed in about
    ``num_chunks`` chunks (each extended backwards by ``norm_window`` rows of
    history) to bound peak memory.  Progress is printed as a side effect.

    Args:
        df: pandas object whose ``.values`` hold the series (rows = time steps).
        norm_window: number of preceding rows used for the statistics.
        num_chunks: approximate number of chunks to split the rows into.

    Returns:
        np.ndarray with ``len(df) - norm_window`` rows (no rows when ``df`` has
        at most ``norm_window`` rows).  Output row ``j`` corresponds to input
        row ``j + norm_window``.
    """
    print(f"Window: {norm_window}")
    arr = df.values
    n_rows = arr.shape[0]
    # At least 1 so empty input does not produce range(..., step=0).
    chunk_size = max(math.ceil(n_rows / num_chunks), 1)  # reduces memory requirement
    print(f"Chunk Size: {chunk_size}")
    new_arr = np.zeros_like(arr, dtype=float)
    for i in range(0, n_rows, chunk_size):
        print(i)
        chunk_start = max(i - norm_window, 0)
        chunk_end = min(i + chunk_size, n_rows)
        if chunk_end - chunk_start <= norm_window:
            # Only possible while chunk_start == 0: every row here lacks a full
            # history window (and is dropped below), and sliding_window_view
            # would raise on a chunk shorter than the window.
            continue
        chunk = arr[chunk_start:chunk_end]
        chunk_rolled = sliding_window_view(chunk, window_shape=(norm_window,), axis=0)
        chunk_mu = chunk_rolled.mean(axis=-1)
        chunk_std = chunk_rolled.std(ddof=1, axis=-1)
        # Window j covers chunk rows j..j+w-1, so it normalizes chunk row j+w;
        # the last window has no following row and is discarded.
        chunk_norm = (chunk[norm_window:] - chunk_mu[:-1]) / (chunk_std[:-1] + 1e-8)
        new_arr[chunk_start + norm_window:chunk_end] = chunk_norm
    return new_arr[norm_window:]

def rolling_norm_include_current_tstep(df, norm_window=5, chunk_size=50000):
    """Z-score each row against the ``norm_window`` rows ending at (and including) it.

    Per column, row ``t`` becomes
    ``(x[t] - mean(x[t-w+1:t+1])) / (std(x[t-w+1:t+1], ddof=1) + 1e-8)``.
    Rows are processed ``chunk_size`` at a time (each chunk extended backwards
    by ``norm_window - 1`` rows of history) to bound peak memory.  Progress is
    printed as a side effect.

    Args:
        df: pandas object whose ``.values`` hold the series (rows = time steps).
        norm_window: number of rows in each window, including the current one.
        chunk_size: number of output rows computed per chunk.

    Returns:
        pd.DataFrame with ``len(df) - norm_window + 1`` rows (no rows when
        ``df`` is shorter than the window), a fresh 0-based index and default
        integer column labels.  Output row ``j`` corresponds to input row
        ``j + norm_window - 1``.
    """
    print(f"Window: {norm_window}")
    print(f"df shape: {df.shape}")
    arr = df.values
    n_rows = len(arr)
    new_arr = np.zeros_like(arr, dtype=float)

    for i in range(0, n_rows, chunk_size):
        print(i)
        chunk_start = max(i - (norm_window - 1), 0)
        chunk_end = min(i + chunk_size, n_rows)
        if chunk_end - chunk_start < norm_window:
            # Only possible while chunk_start == 0: no row here has a full window
            # (all are dropped below) and sliding_window_view would raise.
            continue
        chunk = arr[chunk_start:chunk_end]
        chunk_rolled = sliding_window_view(chunk, window_shape=(norm_window,), axis=0)
        chunk_mu = chunk_rolled.mean(axis=-1)
        chunk_std = chunk_rolled.std(ddof=1, axis=-1)
        # Window j covers chunk rows j..j+w-1 and normalizes its last row.
        chunk_norm = (chunk[norm_window - 1:] - chunk_mu) / (chunk_std + 1e-8)
        new_arr[chunk_start + norm_window - 1:chunk_end] = chunk_norm

    return pd.DataFrame(new_arr[norm_window - 1:])


def rolling_normalization_exclude_current_chunk(data, chunk_size=10000, window_size=50000):
    """Normalize each block of ``chunk_size`` rows using only rows before it.

    For the block starting at row ``s``, statistics come from
    ``welford(data.iloc[max(0, s - window_size):s].values, 5000)`` and the
    block is replaced by ``(block - mean) / std``.  Blocks with no preceding
    rows (e.g. the first block) are left as NaN.

    Args:
        data: DataFrame, or anything ``pd.DataFrame`` accepts.
        chunk_size: number of rows normalized with one set of statistics.
        window_size: maximum number of preceding rows used for the statistics.

    Returns:
        Float DataFrame with the same index and columns as ``data``.
    """
    if not isinstance(data, pd.DataFrame):
        data = pd.DataFrame(data)

    # Starts all-NaN, so blocks without history need no explicit write.
    normalized_data = pd.DataFrame(index=data.index, columns=data.columns, dtype=float)

    num_timesteps = len(data)
    for start_idx in range(0, num_timesteps, chunk_size):
        window_start = max(0, start_idx - window_size)
        chunk_end = min(start_idx + chunk_size, num_timesteps)

        rolling_window = data.iloc[window_start:start_idx]
        if rolling_window.empty:
            continue

        current_chunk = data.iloc[start_idx:chunk_end]
        # NOTE(review): the meaning of welford's second argument (5000) is defined
        # in data_preprocessing.rolling_norms - presumably a batch size; confirm.
        mean, std = welford(rolling_window.values, 5000)
        normalized_data.iloc[start_idx:chunk_end] = (current_chunk - mean) / std

    return normalized_data


if __name__ == '__main__':
    # Script entry point: build normalized order-book features from
    # ./data/btcusdt.csv and pickle them under ./data/.
    #   task == 'mprf': features + mid price in a single file.
    #   task == 'mptp': features + direction labels for several horizons,
    #                   split chronologically into train/val/test files.
    task = 'mprf'
    raw_df = pd.read_csv("./data/btcusdt.csv")
    # Skip the first three CSV columns (presumably timestamp/metadata - confirm).
    raw_df = raw_df.iloc[:, 3:]
    print(raw_df.head())
    raw_df = flip_cols(raw_df)
    raw_df.dropna(inplace=True)
    # dropna leaves holes in the index, but every frame assembled below is
    # combined with pd.concat, which aligns on index *labels*. Rebuild a
    # contiguous 0..n-1 index so that label == position.
    raw_df.reset_index(drop=True, inplace=True)

    # Mean of reordered columns 0 and 2 (presumably best ask / best bid - confirm).
    mid_price = (raw_df.iloc[:, 0] + raw_df.iloc[:, 2]) / 2
    print(np.max(np.abs(mid_price)))
    print(np.min(np.abs(mid_price)))
    print(np.max(mid_price))
    print(np.min(mid_price))

    print(raw_df.head())

    k_list = [1, 2, 3, 5, 10]  # label horizons: number of future steps averaged

    rw = 20  # rolling-normalization window
    print(f'RW: {rw}')
    print(f'Length of dataset: {len(raw_df)}')
    x_norm = rolling_norm_include_current_tstep(raw_df, rw, 50000)
    # Row j of x_norm is the normalized snapshot of raw row j + rw - 1, but it is
    # returned with a fresh 0-based index. Relabel it so concat pairs each
    # feature row with the target of the same time step, not one rw - 1 steps
    # earlier (which leaked future information into the features).
    x_norm.index = raw_df.index[rw - 1:]

    if task == 'mprf':
        df = pd.concat([x_norm, mid_price.iloc[rw - 1:]], axis=1)
        df.dropna(inplace=True)
        df.to_pickle(f'./data/btcusdt_rw{rw}_mprf.pkl')
        print('done')
    elif task == 'mptp':
        alpha = 1e-4
        labels_df = pd.DataFrame(index=mid_price.index)
        for k in k_list:
            labels = calculate_labels(mid_price, k, alpha * 0.01)
            print('Done', k)
            # Assign positionally: calculate_labels returns a single-column frame
            # built with its own fresh index rather than mid_price's.
            labels_df['labels_k_' + str(k)] = labels.to_numpy().ravel()
            print(labels_df['labels_k_' + str(k)].value_counts())
        labels_df = labels_df.iloc[rw - 1:]

        df = pd.concat([x_norm, labels_df], axis=1)
        print(df.head())
        print(df.tail())
        print(df.shape)
        df.dropna(inplace=True)
        print('after drop')
        print(df.head())
        print(df.tail())
        print(df.shape)

        # Chronological split: first 80% -> train/val (80/20), last 20% -> test.
        split = 0.8
        train_val_df = df.iloc[:int(split * len(df))]
        n_samples_train = int(np.floor(len(train_val_df) * 0.8))
        train_df = train_val_df.iloc[:n_samples_train]
        val_df = train_val_df.iloc[n_samples_train:]
        test_df = df.iloc[int(split * len(df)):]

        print(train_df.shape)
        print(val_df.shape)
        print(test_df.shape)
        print(df.dtypes)

        train_df.to_pickle("./data/btcusdt_train_rw{}_a{}.pkl".format(rw, alpha))
        val_df.to_pickle("./data/btcusdt_val_rw{}_a{}.pkl".format(rw, alpha))
        test_df.to_pickle("./data/btcusdt_test_rw{}_a{}.pkl".format(rw, alpha))
        print("DONE!")

    
