import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import lightning as L

# Default input columns, in the order they are laid out in each sample matrix.
features_default = [
    # Bid-side price columns, numbered 5 down to 1.
    'BidPrice5', 'BidPrice4', 'BidPrice3', 'BidPrice2', 'BidPrice1',
    # Ask-side price columns, numbered 1 up to 5.
    'AskPrice1', 'AskPrice2', 'AskPrice3', 'AskPrice4', 'AskPrice5',
    # Bid-side size columns, numbered 5 down to 1.
    'BidSize5', 'BidSize4', 'BidSize3', 'BidSize2', 'BidSize1',
    # Ask-side size columns, numbered 1 up to 5.
    'AskSize1', 'AskSize2', 'AskSize3', 'AskSize4', 'AskSize5',
    'DeltaTime',
    # Return columns and their derivatives.
    'ReturnBid1', 'ReturnAsk1',
    'DerivativeReturnBid1', 'DerivativeReturnAsk1',
    # Trade / cancellation sizes.
    'TradeBidSize', 'TradeAskSize',
    'CancelledBidSize', 'CancelledAskSize',
    # Trade / cancellation indicator columns.
    'TradeBidIndicator', 'TradeAskIndicator',
    'CancelledBidIndicator', 'CancelledAskIndicator',
]
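# Alternative feature sets kept for experiments; uncomment one to replace the default list above.
# Variant 1: only the "1"-numbered price/size columns, plus DeltaTime and the derived
# return / trade / cancellation columns.
# features_default = ['BidPrice1','AskPrice1', 'BidSize1','AskSize1', 'DeltaTime',
#                     'ReturnBid1', 'ReturnAsk1', 'DerivativeReturnBid1', 'DerivativeReturnAsk1',
#                     'TradeBidSize', 'TradeAskSize', 'CancelledBidSize', 'CancelledAskSize',
#                     'TradeBidIndicator', 'TradeAskIndicator', 'CancelledBidIndicator', 'CancelledAskIndicator']
# Variant 2: all five price/size columns per side plus DeltaTime, without the derived columns.
# features_default = ['BidPrice5', 'BidPrice4', 'BidPrice3', 'BidPrice2', 'BidPrice1',
#                     'AskPrice1', 'AskPrice2', 'AskPrice3', 'AskPrice4', 'AskPrice5',
#                     'BidSize5', 'BidSize4', 'BidSize3', 'BidSize2', 'BidSize1',
#                     'AskSize1', 'AskSize2', 'AskSize3', 'AskSize4', 'AskSize5','DeltaTime']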

class SequentialDataset(Dataset):
    '''Map-style dataset of fixed-length sliding windows over a Parquet table.

    Rows are grouped by `Date` and then by `StockSymbol`; windows never span
    two groups. Each item is `(x, y)` where `x` holds every selected column
    except the last one and `y` is 1.0 when any row of the window has a
    non-zero value in the last selected column, else 0.0.

    When `anomaly_ratio` is set, items are served through `sample_indices`,
    a shuffled resample of the windows that is rebuilt by `on_epoch()`.
    '''

    def __init__(self, dataset_path, seq_len, features=None,
                 response=None, stride=1, anomaly_ratio=None):
        '''Initialize the SequentialDataset.
        Parameters:
        - dataset_path: Path to the dataset file (Parquet format).
        - seq_len: Length of one sequence.
        - features: List of feature names to be used; None selects the
          module-level `features_default`.
        - response: List of response variable names (optional). When omitted,
          'FraudType' is appended as the last column.
        - stride: Step size for sliding window.
        - anomaly_ratio: Fraction (0..1) of served samples that should be
          anomalous windows; if set, anomalies are oversampled.
        Raises:
        - ValueError: if anomaly_ratio is outside [0, 1], or if a non-empty
          draw is requested from a class that has no windows.
        '''
        # Fail fast, before the (potentially large) Parquet file is read.
        if anomaly_ratio is not None and not 0 <= anomaly_ratio <= 1:
            raise ValueError(f"anomaly_ratio must be within [0, 1], got {anomaly_ratio!r}")

        frame = pd.read_parquet(dataset_path, engine='fastparquet')
        self.seq_len = seq_len
        # Resolve the default at call time so the module-level list is not
        # bound as a shared mutable default argument.
        self.features = features_default if features is None else features
        self.stride = stride
        self.anomaly_ratio = anomaly_ratio  # None disables resampling
        self.response = response

        if response is not None:
            self.columns = self.features + self.response
        else:
            # FraudType rides along as the last column; it is split off as the
            # label in __getitem__ and never reaches x.
            self.columns = self.features + ['FraudType']

        # Drop every row with a missing value in any selected column.
        frame = frame[~frame[self.columns].isnull().any(axis="columns")]

        self.index_map = {}        # window id -> (day, instrument, start row)
        self.anomaly_indices = []  # ids of windows containing a non-zero label
        self.normal_indices = []   # ids of all other windows

        self.days = frame.Date.unique()
        self.instruments_data_matrices = [[] for _ in range(len(self.days))]

        # NOTE(review): the label is always read from the LAST selected column,
        # so a custom `response` is assumed to end with the fraud column.
        has_fraud_column = 'FraudType' in self.columns

        index = 0
        for day_index, day in enumerate(self.days):
            data_day = frame[frame.Date == day]
            day_matrices = self.instruments_data_matrices[day_index]
            # One groupby pass per day instead of one boolean mask per symbol.
            for _, data_instrument in data_day.groupby('StockSymbol'):
                # NOTE(review): strictly-greater also drops instruments with
                # exactly seq_len rows (which could yield one window); kept
                # as in the original behaviour.
                if len(data_instrument) <= self.seq_len:
                    continue
                matrix = data_instrument[self.columns].to_numpy()
                instrument_index = len(day_matrices)
                day_matrices.append(matrix)
                last_start = matrix.shape[0] - self.seq_len
                for start in range(0, last_start + 1, self.stride):
                    is_anomaly = has_fraud_column and bool(
                        (matrix[start: start + self.seq_len, -1] != 0).any())
                    self.index_map[index] = (day_index, instrument_index, start)
                    if is_anomaly:
                        self.anomaly_indices.append(index)
                    else:
                        self.normal_indices.append(index)
                    index += 1

        if self.anomaly_ratio is not None:
            self.prepare_random_indices()  # initial resample

    def on_epoch(self):
        '''Redraw the resampled indices; no-op when anomaly_ratio is None.'''
        if self.anomaly_ratio is not None:
            self.prepare_random_indices()

    @staticmethod
    def _draw(pool, count, replace, label):
        '''Draw `count` window ids from `pool` using the global NumPy RNG.'''
        if count == 0:
            return np.asarray([], dtype=np.int64)
        if len(pool) == 0:
            raise ValueError(
                f"Cannot draw {count} {label} windows: the dataset contains none.")
        return np.random.choice(np.asarray(pool, dtype=np.int64), count, replace=replace)

    def prepare_random_indices(self):
        '''Build `sample_indices`: as many entries as there are windows, with
        an `anomaly_ratio` share drawn (with replacement) from the anomalies.'''
        num_samples = len(self.index_map)
        num_anomaly = int(self.anomaly_ratio * num_samples)
        num_normal = num_samples - num_anomaly

        anomaly_sampled = self._draw(self.anomaly_indices, num_anomaly, True, "anomaly")
        # Normals are drawn without replacement whenever the pool is large
        # enough; a ratio below the natural anomaly share asks for more
        # normals than exist, which requires replacement instead of crashing.
        normal_sampled = self._draw(
            self.normal_indices, num_normal,
            num_normal > len(self.normal_indices), "normal")

        self.sample_indices = np.concatenate([anomaly_sampled, normal_sampled])
        np.random.shuffle(self.sample_indices)
        print(f"Prepared {len(self.sample_indices)} samples with anomaly ratio {self.anomaly_ratio}.")

    def __getitem__(self, idx):
        '''Return `(x, y)` float tensors for sample `idx`.'''
        # With resampling active, idx addresses the shuffled sample list.
        if self.anomaly_ratio is not None:
            real_idx = self.sample_indices[idx]
        else:
            real_idx = idx

        day, instrument, book = self.index_map[real_idx]
        data = self.instruments_data_matrices[day][instrument][book: (book + self.seq_len), :]

        x = data[:, :-1]
        # Same "any non-zero label" rule used to classify windows in __init__,
        # so every window counted as an anomaly is also labelled as one.
        y = np.asarray(1 if (data[:, -1] != 0).any() else 0)
        return torch.from_numpy(x).float(), torch.from_numpy(y).float()

    def __len__(self):
        '''Number of addressable samples (resampled list when active).'''
        if self.anomaly_ratio is not None:
            return len(self.sample_indices)
        return len(self.index_map)

    
class SeqDataModule(L.LightningDataModule):
    '''Lightning data module wrapping pre-built train / validation / test datasets.

    Parameters:
    - datasets: Indexable holding the train, validation and test datasets at
      positions 0, 1 and 2.
    - batch_size: Batch size used by all three loaders.
    - is_shuffle_train: Whether the training loader shuffles.
    - num_workers: Worker count passed to every loader.
    - kwargs: Accepted and ignored.
    '''

    def __init__(self, datasets, batch_size,  is_shuffle_train=True, num_workers=20, **kwargs):
        super().__init__()
        self.train_set = datasets[0]
        self.val_set = datasets[1]
        self.test_set = datasets[2]
        self.batch_size = batch_size
        self.is_shuffle_train = is_shuffle_train
        self.num_workers = num_workers
        self.pin_memory = True

    def setup(self, stage=None):
        # Datasets arrive fully constructed, so there is nothing to set up.
        pass

    def on_train_epoch_start(self):
        # Lets a training set that exposes `on_epoch` refresh itself.
        # NOTE(review): confirm something actually invokes this method; it
        # does not appear to be one of the hooks the Trainer calls on a
        # LightningDataModule.
        if hasattr(self.train_set, "on_epoch"):
            self.train_set.on_epoch()

    def _make_loader(self, dataset, shuffle):
        '''Build a DataLoader with the settings shared by all three splits.'''
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            pin_memory=self.pin_memory,
            drop_last=False,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_set, self.is_shuffle_train)

    def val_dataloader(self):
        return self._make_loader(self.val_set, False)

    def test_dataloader(self):
        return self._make_loader(self.test_set, False)
    
if __name__ == "__main__":
    # Smoke test: load the three splits, print window statistics and check a
    # few random samples against the raw label column.
    import math
    import random

    seq_len = 25
    batch_size = 256

    split_paths = {
        "Train": "./dataset/split_LOBSTER/semi_features_train_data.parquet",
        "Valid": "./dataset/split_LOBSTER/semi_features_valid_data.parquet",
        "Test": "./dataset/split_LOBSTER/semi_features_test_data.parquet",
    }
    # Built in Train, Valid, Test order (dicts preserve insertion order).
    splits = [
        (split_name, SequentialDataset(path, seq_len, anomaly_ratio=0.3))
        for split_name, path in split_paths.items()
    ]
    train_dataset = splits[0][1]

    # Sizes and batch counts.
    for split_name, split in splits:
        total = len(split)
        batches = math.ceil(total / batch_size)
        print(f"{split_name} set: {total} samples, batch size {batch_size}, total batches: {batches}")

    # Natural (pre-resampling) window class balance.
    for split_name, split in splits:
        n_anomaly = len(split.anomaly_indices)
        n_normal = len(split.normal_indices)
        print(f"\n{split_name} anomaly windows: {n_anomaly}")
        print(f"{split_name} normal windows: {n_normal}")
        print(f"{split_name} anomaly ratio: {n_anomaly / (n_anomaly + n_normal) * 100:.2f}%")

    # Spot-check up to ten random samples per split.
    for split_name, split in splits:
        print(f"\n{split_name} Dataset:")

        picked = random.sample(range(len(split)), min(10, len(split)))
        n_features = len(split.features)

        for position, sample_idx in enumerate(picked, start=1):
            x, y = split[sample_idx]
            print(f"  Sample {position} (index {sample_idx}):")
            print(f"    x.shape: {x.shape}")
            print(f"    y.shape: {y.shape}")
            print(f"    y value: {y.item()}")
            print(f"    x dtype: {x.dtype}")
            print(f"    y dtype: {y.dtype}")

            # Recompute the label straight from the stored matrix.
            day, instrument, book = split.index_map[split.sample_indices[sample_idx]]
            window = split.instruments_data_matrices[day][instrument][book: book + seq_len]
            fraud_labels = window[:, -1]  # FraudType column
            expected_y = 1 if (fraud_labels != 0).any() else 0
            print(f"    Expected y: {expected_y}")
            print(f"    Fraud labels in window: {fraud_labels}")
            print(f"    Y validation: {'✓' if y.item() == expected_y else '✗'}")

            print(f"    Expected features: {n_features}")
            print(f"    Actual features: {x.shape[1]}")
            print(f"    Features validation: {'✓' if x.shape[1] == n_features else '✗'}")

    # Measure the anomaly share actually served by the training set.
    if train_dataset.anomaly_ratio is not None:
        total_count = len(train_dataset)
        anomaly_count = sum(
            1 for idx in range(total_count) if train_dataset[idx][1].item() == 1
        )
        print(f"\nThe Actual Ratio after Sampling: {anomaly_count}/{total_count} = {anomaly_count/total_count:.4f}")