import numpy as np
import os
import pandas as pd

def get_ordered_columns(level=5, direct=("Bid", "Ask"), pv=("Price", "Size"), index=True, price=True, volume=True):
    """
    Return the order-book column names in book order for the given depth.

    Within each quantity (price first, then size), the first-side names run from
    ``level`` down to 1 and are followed by the second-side names from 1 up to
    ``level``. For example, level=2 gives BidPrice2, BidPrice1, AskPrice1, AskPrice2.

    Parameters:
    - level: int, number of book levels per side.
    - direct: pair of side prefixes, first side (listed deepest-first) then second side.
    - pv: pair of quantity names, used for the price block and the size block respectively.
    - index: bool, prepend the literal column name "index".
    - price: bool, include the block built from pv[0].
    - volume: bool, include the block built from pv[1].
    """

    def _book_block(quantity):
        # First side counts down towards level 1, second side counts up from level 1.
        descending = [f"{direct[0]}{quantity}{depth}" for depth in reversed(range(1, level + 1))]
        ascending = [f"{direct[1]}{quantity}{depth}" for depth in range(1, level + 1)]
        return descending + ascending

    ordered = ["index"] if index else []
    if price:
        ordered += _book_block(pv[0])
    if volume:
        ordered += _book_block(pv[1])
    return ordered

class PreProcessData:
    """End-to-end preprocessing of a formatted LOBSTER order-book CSV.

    Constructing an instance runs the whole pipeline:
    1. load the CSV;
    2. optionally engineer features;
    3. optionally z-score the price/size/DeltaTime columns;
    4. split the rows into train / valid / test sets;
    5. write the three sets to parquet files.
    """

    def __init__(self,
                 csv_path,
                 level=5,
                 normalizing_method=None,
                 semi_mode=False,
                 add_features=False,
                 save_path="dataset/split_LOBSTER/",
                 **kwargs):
        ''' Initialize the PreProcessData class and run the full pipeline.

        Parameters:
        - csv_path: str, path to the ';'-separated CSV file containing the raw data.
        - level: int, number of order-book levels whose price/size columns are kept
          (when add_features is False) and z-scored by normalize_lob_data.
        - normalizing_method: list of str or None, normalization methods to apply.
          Only "feature_zscore" is recognized; None (the default) means no normalization.
        - semi_mode: bool, if True, a random share of each day's anomaly windows is moved
          into the train/valid data (split_data_semisupervised). Otherwise every anomaly
          window goes to the test set (split_data).
        - add_features: bool, if True, keeps every CSV column and adds engineered,
          daily-normalized features (add_features_norm). If False, only the book
          columns and a fixed set of metadata columns are kept.
        - save_path: str, prefix prepended to 'train_data.parquet', 'valid_data.parquet'
          and 'test_data.parquet'.
        - **kwargs: stored verbatim as instance attributes.

        Raises:
        - FileExistsError: if csv_path does not exist.
        '''
        self.level = level
        # None sentinel instead of a shared mutable [] default.
        self.normalizing_method = [] if normalizing_method is None else normalizing_method

        for key, value in kwargs.items():
            setattr(self, key, value)

        if not os.path.exists(csv_path):
            # NOTE(review): FileNotFoundError would describe this better; FileExistsError is
            # kept so callers that already catch it keep working.
            raise FileExistsError(f"file does not exist: {csv_path}")

        self.raw_data = pd.read_csv(csv_path, sep=';', dtype={'OriginalSequenceNumber': str})
        if not add_features:
            columns = get_ordered_columns(level=self.level) + [
                'Date', 'TimeInMilliSecs', 'OriginalSequenceNumber', 'StockSymbol',
                'ClusterNo', 'ManipulatedLevel', 'FraudType']
            # .copy() so the column assignments below write to an independent frame
            # instead of a slice of the full CSV (SettingWithCopyWarning).
            self.raw_data = self.raw_data[columns].copy()
        self.raw_data['FraudType'] = self.raw_data['FraudType'].fillna(0).astype(int)
        # Gap (in TimeInMilliSecs units) to the previous row of the same stock-day;
        # 0 for the first row. Assumes rows are already in time order.
        self.raw_data['DeltaTime'] = self.raw_data.groupby(['StockSymbol', 'Date'])['TimeInMilliSecs'].diff().fillna(0)

        if add_features:
            self.add_features_norm()
        else:
            self.normalized_data = self.raw_data.copy()

        if self.normalizing_method:
            self.normalize_lob_data()

        if semi_mode:
            self.train_data, self.valid_data, self.test_data = self.split_data_semisupervised()
        else:
            self.train_data, self.valid_data, self.test_data = self.split_data()

        self.save_split_data(save_path)

    # ------------------------------------------------------------------ normalization

    def normalize_lob_data(self):
        """Per-stock z-score of the book price/size columns and DeltaTime.

        Statistics (np.nanmean / np.nanstd, ddof=0) come from self.raw_data. Results are
        written into self.normalized_data. Columns whose std is 0 for a stock are only
        mean-centred (std treated as 1), so they do not become inf/NaN.
        Does nothing unless "feature_zscore" is in self.normalizing_method.

        NOTE(review): statistics use every row of the stock, including anomalous rows and
        rows that later land in valid/test. Confirm this is intended.
        """
        if "feature_zscore" not in self.normalizing_method:
            return

        price_columns = get_ordered_columns(level=self.level, index=False, price=True, volume=False)
        volume_columns = get_ordered_columns(level=self.level, index=False, price=False, volume=True)
        target_columns = price_columns + volume_columns + ['DeltaTime']
        # Cast once so writing float z-scores into originally-integer columns (e.g. sizes)
        # on a per-stock row subset never hits pandas' incompatible-dtype setitem path.
        self.normalized_data[target_columns] = self.normalized_data[target_columns].astype(float)

        for stock in self.raw_data['StockSymbol'].unique():
            stock_mask = self.raw_data['StockSymbol'] == stock
            values = self.raw_data.loc[stock_mask, target_columns].to_numpy(dtype=float)
            mean = np.nanmean(values, axis=0)
            std = np.nanstd(values, axis=0)
            std[std == 0] = 1.0
            self.normalized_data.loc[stock_mask, target_columns] = (values - mean) / std

            print(f"Normalized data for stock: {stock} using feature z-score.")

    @staticmethod
    def _zscore(values, reference):
        """Standardize `values` with the mean and std (pandas, ddof=1) of `reference`.

        A zero reference std is replaced by 1, so constant columns become mean-centred
        instead of inf/NaN.
        """
        std = reference.std()
        if std == 0:
            std = 1.0
        return (values - reference.mean()) / std

    def add_features_norm(self):
        """Build self.normalized_data = raw data + engineered features, per stock and day.

        Added or overwritten columns:
        - ReturnBid{k} / ReturnAsk{k} for k = 1..5: log(P_t / P_{t+1}).
        - DerivativeReturnBid{k} / DerivativeReturnAsk{k}: the return divided by the time
          gap to the previous row. Gaps below 1 or missing count as 1.
        - TradeBidSize / TradeAskSize: 10-row rolling mean of |TradeSize|, kept only on rows
          whose TradeIndicator is 1 (bid) / -1 (ask).
        - CancelledBidSize / CancelledAskSize: 10-row rolling means.
        - Trade*/Cancelled* indicators: divided by the same time gap.

        The trade/cancel columns and the level-1 return columns are then z-scored per
        stock-day, using mean/std of that day's rows with FraudType == 0.
        """
        self.normalized_data = self.raw_data.copy()
        return_levels = range(1, 6)  # hard-coded levels 1..5, independent of self.level
        activity_columns = ['TradeBidSize', 'TradeAskSize', 'CancelledBidSize', 'CancelledAskSize',
                            'TradeBidIndicator', 'TradeAskIndicator',
                            'CancelledBidIndicator', 'CancelledAskIndicator']
        # NOTE(review): only the level-1 return features are z-scored; levels 2-5 are
        # stored unscaled. Kept as-is; confirm this is intended.
        zscored_columns = activity_columns + ['ReturnBid1', 'ReturnAsk1',
                                              'DerivativeReturnBid1', 'DerivativeReturnAsk1']

        # These columns are overwritten with floats on per-day row subsets. Cast the
        # existing ones up front so pandas never has to up-cast an int column mid-way.
        for column in activity_columns:
            if column in self.normalized_data.columns:
                self.normalized_data[column] = self.normalized_data[column].astype(float)

        for stock in self.normalized_data['StockSymbol'].unique():
            # Boolean indexing copies, so stock_data is a snapshot taken before this
            # stock's rows are modified.
            stock_data = self.normalized_data[self.normalized_data['StockSymbol'] == stock]
            for date in stock_data['Date'].unique():
                is_day = stock_data['Date'] == date
                day_data = stock_data[is_day].copy()
                indexes = day_data.index.values
                fraud_free_indexes = stock_data[is_day & (stock_data['FraudType'] == 0)].index.values

                # Gap between consecutive LOB updates. Missing (first row), zero, negative
                # or sub-unit gaps count as 1 so the rate features never divide by zero.
                step = day_data['TimeInMilliSecs'].diff().to_numpy()
                time_diff = np.where(step >= 1, step, 1.0)

                trade_bid = (day_data['TradeIndicator'] == 1).astype(int)
                trade_ask = (day_data['TradeIndicator'] == -1).astype(int)

                # Insertion order fixes the order in which new columns are created.
                new_values = {}
                for lvl in return_levels:
                    for side in ('Bid', 'Ask'):
                        prices = day_data[f'{side}Price{lvl}']
                        # NOTE(review): shift(-1) makes this log(P_t / P_{t+1}), which looks
                        # one update ahead, while time_diff is the backward gap t-1 -> t.
                        # Kept as-is; confirm this is intended.
                        new_values[f'Return{side}{lvl}'] = np.log(prices / prices.shift(-1))
                    for side in ('Bid', 'Ask'):
                        new_values[f'DerivativeReturn{side}{lvl}'] = new_values[f'Return{side}{lvl}'] / time_diff

                rolling_trade_size = day_data['TradeSize'].abs().rolling(window=10).mean()
                new_values['TradeBidSize'] = rolling_trade_size * trade_bid
                new_values['TradeAskSize'] = rolling_trade_size * trade_ask
                new_values['CancelledBidSize'] = day_data['CancelledBidSize'].rolling(window=10).mean()
                new_values['CancelledAskSize'] = day_data['CancelledAskSize'].rolling(window=10).mean()
                new_values['TradeBidIndicator'] = trade_bid / time_diff
                new_values['TradeAskIndicator'] = trade_ask / time_diff
                new_values['CancelledBidIndicator'] = day_data['CancelledBidIndicator'] / time_diff
                new_values['CancelledAskIndicator'] = day_data['CancelledAskIndicator'] / time_diff

                # Insert features back into the main dataframe (aligned on the row index).
                for column, values in new_values.items():
                    self.normalized_data.loc[indexes, column] = values

                # Normalize on a daily instrument basis, using fraud-free rows as reference.
                for column in zscored_columns:
                    self.normalized_data.loc[indexes, column] = self._zscore(
                        self.normalized_data.loc[indexes, column],
                        self.normalized_data.loc[fraud_free_indexes, column])

        print("Added normalized features data for all stocks.")

    # ------------------------------------------------------------------ splitting helpers

    def _iter_stock_days(self):
        """Yield (stock, date, independent copy of that stock-day's rows) from normalized_data."""
        for stock in self.normalized_data['StockSymbol'].unique():
            stock_data = self.normalized_data[self.normalized_data['StockSymbol'] == stock]
            for date in stock_data['Date'].unique():
                yield stock, date, stock_data[stock_data['Date'] == date].copy()

    @staticmethod
    def _cluster_test_ranges(day_data, clusters, buffer_seconds):
        """Return padded (start, end) time windows covering the given clusters of one day.

        Each cluster spans its min..max TimeInMilliSecs. Clusters closer than two buffers
        are merged, because their padded windows would otherwise overlap. Each window is
        then padded by buffer_seconds * 1000 on both sides and clipped to the day's range.
        """
        cluster_ranges = []
        for c in clusters:
            times = day_data.loc[day_data['ClusterNo'] == c, 'TimeInMilliSecs']
            cluster_ranges.append([times.min(), times.max()])
        cluster_ranges.sort()

        merged = []
        for start, end in cluster_ranges:
            if merged and start - merged[-1][1] <= buffer_seconds * 2000:  # 2000 = 1000 ms * 2
                merged[-1][1] = max(merged[-1][1], end)
            else:
                merged.append([start, end])

        day_start = day_data['TimeInMilliSecs'].min()
        day_end = day_data['TimeInMilliSecs'].max()
        return [(max(day_start, start - buffer_seconds * 1000), min(day_end, end + buffer_seconds * 1000))
                for start, end in merged]

    @staticmethod
    def _window_mask(times, ranges):
        """Boolean Series (same index as `times`): True where a time falls in any [s, e] range."""
        mask = pd.Series(False, index=times.index)
        for s, e in ranges:
            mask |= (times >= s) & (times <= e)
        return mask

    @staticmethod
    def _label_runs(mask):
        """Label contiguous runs of equal values in `mask`; the label increments at every switch."""
        return (mask.astype(int).diff(1).fillna(0) != 0).cumsum()

    @staticmethod
    def _rechunk_blocks(train_df, max_block_len):
        """Split every (StockSymbol, Date, block_id) group into chunks of at most max_block_len rows.

        The chunks are renumbered with a block_id that is unique across the whole frame.
        """
        new_blocks = []
        new_block_id = 0
        for _, block_df in train_df.groupby(['StockSymbol', 'Date', 'block_id']):
            for start in range(0, len(block_df), max_block_len):
                chunk = block_df.iloc[start:start + max_block_len].copy()
                chunk['block_id'] = new_block_id
                new_blocks.append(chunk)
                new_block_id += 1
        return pd.concat(new_blocks, ignore_index=True)

    @staticmethod
    def _split_train_valid(train_df, rng, valid_ratio):
        """Randomly assign whole blocks to train or valid and drop the block_id column.

        Expects block_id to be globally unique (as produced by _rechunk_blocks). The
        block-key list and the shuffle are identical to a per-row (stock, date, block_id)
        lookup, but membership is tested vectorially on block_id.
        """
        block_keys = list(train_df.groupby(['StockSymbol', 'Date', 'block_id']).groups.keys())
        rng.shuffle(block_keys)
        n_valid = int(len(block_keys) * valid_ratio)
        valid_ids = {key[2] for key in block_keys[:n_valid]}
        train_ids = {key[2] for key in block_keys[n_valid:]}
        block_ids = train_df['block_id']
        train_final = train_df[block_ids.isin(train_ids)].drop(['block_id'], axis=1).reset_index(drop=True)
        valid_final = train_df[block_ids.isin(valid_ids)].drop(['block_id'], axis=1).reset_index(drop=True)
        return train_final, valid_final

    # ------------------------------------------------------------------ splitting

    def split_data(self, buffer_seconds=35, random_state=42, valid_ratio=0.2, max_block_len=5000):
        """Split normalized_data into (train, valid, test) DataFrames.

        For each stock-day:
        - If the day has no ClusterNo values, every row goes to train.
        - Otherwise the padded cluster windows (see _cluster_test_ranges) go to test, and
          each contiguous run of remaining rows becomes a training block.

        Blocks are cut to at most max_block_len rows. A fraction valid_ratio of them
        (RandomState(random_state)) forms the valid set.
        """
        train_list, test_list = [], []
        print('Splitting data into train, valid and test sets.')
        for stock, date, day_data in self._iter_stock_days():
            clusters = day_data['ClusterNo'].dropna().unique()
            if len(clusters) == 0:
                day_data['block_id'] = 0
                train_list.append(day_data)
                print(f"No clusters found for stock {stock} on date {date}. Skipping.")
                continue

            test_ranges = self._cluster_test_ranges(day_data, clusters, buffer_seconds)
            in_test = self._window_mask(day_data['TimeInMilliSecs'], test_ranges)

            # Each contiguous run of non-test rows becomes one training block.
            normal_mask = ~in_test
            normal_data = day_data[normal_mask].copy()
            normal_data['block_id'] = self._label_runs(normal_mask)[normal_mask].values

            train_list.append(normal_data)
            test_list.append(day_data[in_test])

        train_df = pd.concat(train_list, ignore_index=True)
        test_df = pd.concat(test_list, ignore_index=True)

        train_df = train_df.sort_values(['StockSymbol', 'Date', 'TimeInMilliSecs', 'index']).reset_index(drop=True)
        train_df = self._rechunk_blocks(train_df, max_block_len)
        rng = np.random.RandomState(random_state)
        train_final, valid_final = self._split_train_valid(train_df, rng, valid_ratio)
        return train_final, valid_final, test_df

    def split_data_semisupervised(self, buffer_seconds=35, random_state=42, valid_ratio=0.2, max_block_len=5000, inject_abnormal_ratio=0.3):
        """Like split_data, but some anomaly windows are moved into the training data.

        Per stock-day, the padded cluster windows are shuffled. The first
        int(len * inject_abnormal_ratio) windows become extra training blocks; the rest
        form the test set. Test rows keep a per-day block_id column.

        One RandomState(random_state) drives both the per-day shuffles and the final
        train/valid block split.
        """
        train_list, test_list = [], []
        rng = np.random.RandomState(random_state)
        print('Splitting data into train, valid and test sets under semi-supervised mode.')

        for stock, date, day_data in self._iter_stock_days():
            clusters = day_data['ClusterNo'].dropna().unique()
            if len(clusters) == 0:
                day_data['block_id'] = 0
                train_list.append(day_data)
                print(f"No clusters found for stock {stock} on date {date}. Skipping.")
                continue

            # Shuffle and inject abnormal segments into train data.
            test_ranges = self._cluster_test_ranges(day_data, clusters, buffer_seconds)
            rng.shuffle(test_ranges)
            n_inject = int(len(test_ranges) * inject_abnormal_ratio)
            inject_ranges = test_ranges[:n_inject]
            final_test_ranges = test_ranges[n_inject:]
            print(f"Stock {stock} on date {date}: Injected {n_inject} abnormal segments, remaining {len(final_test_ranges)} test segments.")

            times = day_data['TimeInMilliSecs']
            used_mask = self._window_mask(times, inject_ranges) | self._window_mask(times, final_test_ranges)
            normal_mask = ~used_mask

            # Block ids for the normal (unused) rows.
            normal_data = day_data[normal_mask].copy()
            normal_data['block_id'] = self._label_runs(normal_mask)[normal_mask].values
            train_list.append(normal_data)

            # Injected and test windows get fresh block ids after the normal ones.
            next_block_id = (normal_data['block_id'].max() if not normal_data.empty else -1) + 1
            for target, ranges in ((train_list, inject_ranges), (test_list, final_test_ranges)):
                for s, e in ranges:
                    chunk = day_data[(times >= s) & (times <= e)].copy()
                    chunk['block_id'] = next_block_id
                    next_block_id += 1
                    target.append(chunk)

        train_df = pd.concat(train_list, ignore_index=True)
        test_df = pd.concat(test_list, ignore_index=True)

        train_df = train_df.sort_values(['StockSymbol', 'Date', 'TimeInMilliSecs', 'index']).reset_index(drop=True)
        train_df = self._rechunk_blocks(train_df, max_block_len)
        train_final, valid_final = self._split_train_valid(train_df, rng, valid_ratio)
        return train_final, valid_final, test_df

    # ------------------------------------------------------------------ output

    def save_split_data(self, save_path):
        """
        Write train/valid/test to '<save_path>{train,valid,test}_data.parquet' with fastparquet.

        save_path is a filename prefix (for example 'dir/' or 'dir/name_'), so only its
        directory part is created when it is missing.
        """
        print('Saving train, valid and test sets.')
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        for split_name, frame in (('train', self.train_data),
                                  ('valid', self.valid_data),
                                  ('test', self.test_data)):
            frame.to_parquet(save_path + f'{split_name}_data.parquet', index=False, engine='fastparquet',
                             row_group_offsets=100000)

        print('Preprocessing done.')
        
if __name__ == "__main__":
    # Script entry: build the engineered-feature, level-5 split of the injected-anomaly
    # LOBSTER export, using plain (non semi-supervised) splitting.
    csv_path = "dataset/formatted_LOBSTER/All_Stocks_InjectedL1.csv"
    output_prefix = "dataset/split_LOBSTER/AnomalyL1_features_"
    data_preprocess = PreProcessData(
        csv_path=csv_path,
        level=5,
        normalizing_method=["feature_zscore"],
        semi_mode=False,
        add_features=True,
        save_path=output_prefix,
    )