import os
from urllib import request
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning import LightningDataModule


class Adult(LightningDataModule):
    """Lightning data module for the UCI Adult dataset.

    The ``income`` column is used as the prediction target and the ``sex``
    column as the protected attribute. Both are encoded as 0/1 integers.
    Downloading, preprocessing and splitting all happen in ``__init__``, so a
    constructed instance already exposes ``train_*``, ``val_*`` and ``test_*``
    tensors (``features``, ``protected`` and ``target`` for each split).
    """

    target = "income"
    protected = "sex"

    def __init__(
        self,
        batch_size: int = 128,
        validation_size: float = 0.2,
        test_size: float = 0.2,  # ignored, test set is fixed
        seed: int = 0,
    ):
        """Download (if needed), preprocess and split the data.

        Args:
            batch_size: Batch size used by all three dataloaders.
            validation_size: Share of the training file held out for validation.
            test_size: Only used by ``train_val_test_split`` when it is called
                without a test frame; ``__init__`` always passes one.
            seed: Random state for the train/validation (and optional test) split.
        """
        super().__init__()
        # Must be called directly from __init__ so the constructor arguments are captured.
        self.save_hyperparameters()
        self.save_hyperparameters({"dataset": self.__class__.__name__})

        train_df, test_df = self.load_data()
        train_df, test_df = self.preprocess(train_df, test_df)
        self.train_val_test_split(train_df, test_df)

    @staticmethod
    def _download_if_missing(url: str, path: str) -> None:
        """Fetch ``url`` into ``path`` unless ``path`` already exists.

        The download goes to a temporary ``.part`` file that is moved into
        place only after it completes, so an interrupted download can never be
        mistaken for a valid cached file on the next run.
        """
        if os.path.exists(path):
            return
        partial_path = path + ".part"
        try:
            request.urlretrieve(url, partial_path)
            os.replace(partial_path, path)
        finally:
            # Only still present if the download or the move failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    def load_data(self) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Return the raw train and test frames, downloading the files if absent.

        Files are cached in ``<project_root>/data/adult`` where the project
        root is the parent of the directory containing this module.
        """
        # create data directory
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        data_dir = os.path.join(project_root, "data", "adult")
        os.makedirs(data_dir, exist_ok=True)

        # download train and test data files
        train_file = os.path.join(data_dir, "adult.data")
        self._download_if_missing(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
            train_file,
        )
        test_file = os.path.join(data_dir, "adult.test")
        self._download_if_missing(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test",
            test_file,
        )

        # load data into dataframes
        column_names = [
            "age",
            "workclass",
            "fnlwgt",
            "education",
            "education_num",
            "marital_status",
            "occupation",
            "relationship",
            "race",
            "sex",
            "capital_gain",
            "capital_loss",
            "hours_per_week",
            "native_country",
            "income",
        ]
        train_df = pd.read_csv(train_file, sep=",", header=None, names=column_names)
        # header=0 together with explicit names discards the first line of the
        # test file (presumably a non-data banner line -- confirm against the file).
        test_df = pd.read_csv(test_file, sep=",", header=0, names=column_names)
        return train_df, test_df

    @staticmethod
    def _clean(df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with trimmed strings and no rows containing "?"."""
        # remove leading/trailing whitespaces/dots from every string cell
        df = df.map(lambda x: x.strip(" .") if isinstance(x, str) else x)
        # "?" marks a missing value; drop every row that has one
        df.replace(to_replace="?", value=np.nan, inplace=True)
        df.dropna(axis=0, inplace=True)
        return df

    def preprocess(
        self, train_df: pd.DataFrame, test_df: pd.DataFrame
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        """Clean, encode and normalize both frames.

        Steps: strip strings and drop rows with missing values, turn
        ``income``, ``sex`` and (if present) ``native_country`` into booleans,
        ``log1p``-transform ``fnlwgt``/``capital_gain``/``capital_loss``,
        one-hot encode the remaining string columns and standardize the
        non-boolean columns with statistics computed on the training frame.

        Returns:
            The processed ``(train_df, test_df)``; both have identical columns
            in identical order.
        """
        train_df = self._clean(train_df)
        test_df = self._clean(test_df)

        # convert binary columns from str to boolean
        binary_columns = {"income": ">50K", "sex": "Female"}
        if "native_country" in train_df.columns:
            binary_columns["native_country"] = "United-States"
        for df in (train_df, test_df):
            for col, positive_value in binary_columns.items():
                df[col] = df[col] == positive_value

        # log transform columns with large magnitude differences
        for col in ["fnlwgt", "capital_gain", "capital_loss"]:
            if col in train_df.columns:
                train_df[col] = np.log1p(train_df[col])
                test_df[col] = np.log1p(test_df[col])

        # one-hot encode categorical columns
        categorical_columns = [col for col in train_df.columns if train_df[col].dtype == object]
        if categorical_columns:
            train_df = pd.get_dummies(train_df, columns=categorical_columns, prefix_sep="=")
            test_df = pd.get_dummies(test_df, columns=categorical_columns, prefix_sep="=")
            # The two frames are encoded independently, so a category present
            # in only one of them would yield different column sets. Align the
            # test frame to the training columns: missing indicators become
            # False and indicators unknown to the training data are dropped.
            test_df = test_df.reindex(columns=train_df.columns, fill_value=False)

        # normalize continuous columns
        continuous_columns = [col for col in train_df.columns if train_df[col].dtype != bool]
        if continuous_columns:
            mean = train_df[continuous_columns].mean()
            std = train_df[continuous_columns].std()
            # A constant column has zero std; divide by 1 instead so it maps
            # to 0 rather than NaN.
            std = std.where(std > 0, 1.0)
            train_df[continuous_columns] = (train_df[continuous_columns] - mean) / std
            test_df[continuous_columns] = (test_df[continuous_columns] - mean) / std
        return train_df, test_df

    def _to_tensors(self, df: pd.DataFrame) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split ``df`` into ``(features, protected, target)`` tensors without mutating it."""
        target = torch.tensor(df[self.target].to_numpy(np.int64))
        protected = torch.tensor(df[self.protected].to_numpy(np.int64))
        feature_df = df.drop(columns=[self.target, self.protected])
        features = torch.tensor(feature_df.to_numpy(np.float32))
        return features, protected, target

    def train_val_test_split(self, train_df, test_df=None):
        """Populate the ``train_*``, ``val_*`` and ``test_*`` tensor attributes.

        Args:
            train_df: Preprocessed training frame containing the target and
                protected columns.
            test_df: Preprocessed test frame. When ``None``, a test set of
                relative size ``hparams["test_size"]`` is carved out of
                ``train_df`` before the validation split.
        """
        train_features, train_protected, train_target = self._to_tensors(train_df)

        if test_df is not None:
            self.test_features, self.test_protected, self.test_target = self._to_tensors(test_df)
        else:
            # split into test and train sets
            (
                train_features,
                self.test_features,
                train_protected,
                self.test_protected,
                train_target,
                self.test_target,
            ) = train_test_split(
                train_features,
                train_protected,
                train_target,
                test_size=self.hparams["test_size"],
                random_state=self.hparams["seed"],
            )

        # split into validation and train sets
        (
            self.train_features,
            self.val_features,
            self.train_protected,
            self.val_protected,
            self.train_target,
            self.val_target,
        ) = train_test_split(
            train_features,
            train_protected,
            train_target,
            test_size=self.hparams["validation_size"],
            random_state=self.hparams["seed"],
        )

    def _make_dataloader(self, features, protected, target, *, training: bool = False):
        """Wrap one split in a ``DataLoader`` yielding ``(features, protected, target)``.

        Only the training loader shuffles and drops the last incomplete batch.
        """
        dataset = TensorDataset(features, protected, target)
        return DataLoader(
            dataset,
            self.hparams["batch_size"],
            num_workers=1,
            shuffle=training,
            drop_last=training,
        )

    def train_dataloader(self):
        """Return the shuffled training loader (incomplete final batch dropped)."""
        return self._make_dataloader(
            self.train_features, self.train_protected, self.train_target, training=True
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order, all samples)."""
        return self._make_dataloader(self.val_features, self.val_protected, self.val_target)

    def test_dataloader(self):
        """Return the test loader (fixed order, all samples)."""
        return self._make_dataloader(self.test_features, self.test_protected, self.test_target)
