import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer


class DataHandler:
    """Load a CSV dataset, encode it, and keep train/val/test splits.

    On construction the dataset named by ``args.dataset`` is read from the
    ``data/`` directory, its text-valued feature columns are binarized, and
    the rows are split into ``"train"``, ``"val"`` and ``"test"`` parts.
    Each part is a dict with keys ``"x"``, ``"y"`` and ``"pred"``; ``"pred"``
    starts as an empty list and is filled through :meth:`add_pred`.
    """

    def __init__(self, args):
        """Load, encode and split the dataset described by *args*.

        Args:
            args: Object exposing ``dataset`` (file stem under ``data/``) and
                ``seed`` (random state used for both splits).
        """
        self.args = args
        self.x, self.y = self.load_and_process_dataset(self.args.dataset)
        self.split_data()

    def load_and_process_dataset(self, filepath):
        """Read ``data/<filepath>.csv`` and return ``(features, labels)``.

        The file is read without a header row. The last column is the
        target: entries equal to ``0`` become ``-1`` and every other entry
        becomes ``1``. The remaining columns are passed through
        :meth:`binarize_dataset`.

        Args:
            filepath: Dataset name, without the ``data/`` prefix or the
                ``.csv`` suffix.

        Returns:
            Tuple ``(x, y)`` of the encoded feature frame and the label array.
        """
        data = pd.read_csv(f"data/{filepath}.csv", header=None)
        # With header=None the columns are labelled positionally, so the last
        # label is the target column.
        target_column = data.columns[-1]

        x = data.drop(columns=[target_column])
        y = np.where(data[target_column].to_numpy() == 0, -1, 1)

        return self.binarize_dataset(x), y

    def binarize_dataset(self, x: pd.DataFrame) -> pd.DataFrame:
        """Replace text-valued columns of *x* with ``LabelBinarizer`` columns.

        Columns of object or string dtype are treated as categorical. Each is
        removed and the columns produced by ``LabelBinarizer`` are appended
        at the end of the frame, named ``"<col>_<i>"``. All other columns
        keep their position and values. *x* itself is never modified.

        Args:
            x: Feature frame to encode.

        Returns:
            A new frame with the categorical columns encoded.
        """
        categorical_cols = []
        encoded_frames = []

        for col in x.columns:
            dtype = x[col].dtype
            # Checking only for "object" misses pandas' dedicated string
            # dtype, which would leave raw text in the feature matrix.
            if not (
                pd.api.types.is_object_dtype(dtype)
                or pd.api.types.is_string_dtype(dtype)
            ):
                continue

            encoded = LabelBinarizer().fit_transform(x[col])
            categorical_cols.append(col)
            encoded_frames.append(
                pd.DataFrame(
                    encoded,
                    index=x.index,
                    columns=[f"{col}_{i}" for i in range(encoded.shape[1])],
                )
            )

        if not encoded_frames:
            return x.copy()

        # A single concat avoids the frame fragmentation caused by inserting
        # the encoded columns one at a time.
        return pd.concat(
            [x.drop(columns=categorical_cols), *encoded_frames], axis=1
        )

    def split_data(self):
        """
        Splits the dataset into train, validation, and test sets.
        Also initializes `pred` as empty lists for each split.

        20% of the rows go to the test set; the remaining 80% are split
        75/25 into train and validation, i.e. 60/20/20 overall. Both splits
        use ``self.args.seed`` as the random state.
        """
        X_train_val, X_test, y_train_val, y_test = train_test_split(
            self.x, self.y, test_size=0.2, random_state=self.args.seed
        )
        X_train, X_val, y_train, y_val = train_test_split(
            X_train_val,
            y_train_val,
            test_size=0.25,  # 25% of 80% is 20% of the total
            random_state=self.args.seed,
        )
        self.data = {
            "train": {"x": X_train, "y": y_train, "pred": []},
            "val": {"x": X_val, "y": y_val, "pred": []},
            "test": {"x": X_test, "y": y_test, "pred": []},
        }

    def _require_split(self, split):
        """Return the stored dict for *split*, raising ValueError if unknown."""
        if split not in self.data:
            raise ValueError(
                f"Invalid split '{split}'. Must be one of 'train', 'val', or 'test'."
            )
        return self.data[split]

    def get_split(self, split):
        """Retrieve a specific dataset split.

        Args:
            split: Key of the split to return.

        Returns:
            The dict stored for that split (keys ``"x"``, ``"y"``, ``"pred"``).

        Raises:
            ValueError: If *split* is not a stored split.
        """
        return self._require_split(split)

    def add_pred(self, split, pred):
        """Adds a pred array to the specified split.

        Args:
            split: Key of the split to attach the predictions to.
            pred: Predictions. Objects with a ``flatten`` method are
                flattened with it; anything else (e.g. a list) is converted
                with ``np.asarray`` and flattened to one dimension.

        Raises:
            ValueError: If *split* is not a stored split.
        """
        target = self._require_split(split)
        if hasattr(pred, "flatten"):
            flat = pred.flatten()
        else:
            # Plain sequences have no .flatten(); go through NumPy instead.
            flat = np.asarray(pred).ravel()
        target["pred"].append(flat)

    def get_all_splits(self):
        """Retrieve all dataset splits."""
        return self.data
