import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification
from .dataset import BaseVFLDataset

class SyntheticVFLDataset(BaseVFLDataset):
    """Synthetic classification dataset partitioned vertically between two parties.

    Samples are generated with ``sklearn.datasets.make_classification`` and
    standardized. The feature columns are then split column-wise: party A
    receives the first ``n_features_a`` features and party B receives the
    remaining ``n_features - n_features_a`` features. The per-party arrays and
    labels are divided into train/test sets by the base class's ``_split_data``.
    """

    def __init__(self, n_samples=40000, n_features=500, n_informative=100,
                 test_size=0.2, random_state=42, n_features_a=50):
        """Generate, standardize and partition the synthetic data.

        Args:
            n_samples: Number of generated samples.
            n_features: Total number of feature columns across both parties.
            n_informative: Number of informative features passed to
                ``make_classification``.
            test_size: Forwarded to ``BaseVFLDataset.__init__``; presumably the
                held-out fraction used by ``_split_data`` -- TODO confirm.
            random_state: Seed forwarded to ``make_classification`` and to the
                base class.
            n_features_a: Number of leading feature columns assigned to party A.
                Party B receives all remaining feature columns. Defaults to 50,
                matching the previous hard-coded split.

        Raises:
            ValueError: If ``n_features_a`` does not leave at least one feature
                column for each party.
        """
        # Validate before doing any (potentially expensive) data generation.
        if not 0 < n_features_a < n_features:
            raise ValueError(
                "n_features_a must satisfy 0 < n_features_a < n_features "
                f"(got n_features_a={n_features_a}, n_features={n_features})"
            )

        super().__init__(test_size, random_state)

        # Read by _split_features to decide the column partition.
        self.n_features = n_features
        self.n_features_a = n_features_a

        X, y = make_classification(
            n_samples=n_samples,
            n_features=n_features,
            n_informative=n_informative,
            random_state=random_state,
        )

        # NOTE(review): the scaler is fit on all samples before the train/test
        # split, so test-set statistics leak into the training features.
        # Preserved as-is; fitting on the training rows only would require
        # changes to the base class's _split_data.
        X = StandardScaler().fit_transform(X)

        data = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(X.shape[1])])
        data['target'] = y

        self.preprocessed_data = self._preprocess_data(data)
        data_a, data_b = self._split_features(self.preprocessed_data)

        (self.train_data_a, self.train_data_b, self.train_labels,
         self.test_data_a, self.test_data_b, self.test_labels) = self._split_data(
            data_a, data_b, y)

    def _preprocess_data(self, data):
        """Hook for subclass-specific preprocessing; returns ``data`` unchanged here."""
        return data

    def _split_features(self, data):
        """Partition the ``feature_<i>`` columns of ``data`` between the two parties.

        Args:
            data: DataFrame containing columns ``feature_0`` through
                ``feature_{n_features - 1}``. Any other columns (e.g.
                ``target``) are ignored.

        Returns:
            Tuple ``(data_a, data_b)`` of NumPy arrays. ``data_a`` holds
            features ``0 .. n_features_a - 1``. ``data_b`` holds features
            ``n_features_a .. n_features - 1``.
        """
        # Previously hard-coded as range(50) / range(50, 500), which broke
        # (KeyError) or silently dropped columns whenever n_features != 500.
        features_a = [f'feature_{i}' for i in range(self.n_features_a)]
        features_b = [f'feature_{i}' for i in range(self.n_features_a, self.n_features)]

        data_a = data[features_a].values
        data_b = data[features_b].values
        return data_a, data_b



if __name__ == "__main__":
    # Smoke run: build the default dataset and print the array shapes each
    # party (and the labels) end up with after the train/test split.
    dataset = SyntheticVFLDataset()

    # The party-A accessors return a (features, labels) pair; the party-B
    # accessors return features only.
    train_x_a, train_y = dataset.get_train_data_for_a()
    train_x_b = dataset.get_train_data_for_b()
    test_x_a, test_y = dataset.get_test_data_for_a()
    test_x_b = dataset.get_test_data_for_b()

    print("\nData Shape:")
    print("Training Data Party A:", train_x_a.shape)
    print("Training Data Party B:", train_x_b.shape)
    print("Training Label:", train_y.shape)
    print("Testing Data Party A:", test_x_a.shape)
    print("Testing Data Party B:", test_x_b.shape)
    print("Testing Label:", test_y.shape)