import os
from urllib import request
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning import LightningDataModule


class German(LightningDataModule):
    """Lightning data module for the UCI Statlog German credit dataset.

    At construction time the raw file is downloaded (first use only),
    encoded into numeric/boolean features, z-scored, and split into
    train/validation/test tensors. Every dataloader yields
    ``(features, protected, target)`` batches, where ``target`` is the
    binary column named by the ``target`` class attribute and ``protected``
    is the binary column named by the ``protected`` class attribute.
    """

    target = "credit"  # label column: True when the raw credit code is 1
    protected = "age"  # sensitive attribute column: True when age > 25

    def __init__(
        self,
        batch_size: int = 128,
        validation_size: float = 0.2,
        test_size: float = 0.2,
        seed: int = 0,
    ):
        """Load, preprocess, and split the dataset.

        Args:
            batch_size: Batch size used by every dataloader.
            validation_size: Fraction of the non-test rows held out for validation.
            test_size: Fraction of all rows held out for testing (ignored when
                ``train_val_test_split`` is given an explicit test frame).
            seed: ``random_state`` shared by both ``train_test_split`` calls.
        """
        super().__init__()
        self.save_hyperparameters()
        # also record which dataset class produced these hyperparameters
        self.save_hyperparameters({"dataset": self.__class__.__name__})

        df = self.load_data()
        df = self.preprocess(df)
        self.train_val_test_split(df)

    def load_data(self) -> pd.DataFrame:
        """Return the raw German credit table, downloading it on first use.

        The file is cached as ``data/german/german.data`` under the parent of
        the directory containing this module. The download is written to a
        per-process temporary file and moved into place with ``os.replace``,
        so a failed or interrupted download never leaves a truncated cache
        file that later runs would silently reuse.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        data_dir = os.path.join(project_root, "data", "german")
        os.makedirs(data_dir, exist_ok=True)

        data_file = os.path.join(data_dir, "german.data")
        if not os.path.exists(data_file):
            url = "https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/german.data"
            # pid-suffixed so concurrent processes never write the same temp file
            tmp_file = f"{data_file}.{os.getpid()}.part"
            try:
                request.urlretrieve(url, tmp_file)
                os.replace(tmp_file, data_file)
            finally:
                # the temp file only survives to here if the download/rename failed
                if os.path.exists(tmp_file):
                    os.remove(tmp_file)

        column_names = [
            "status",
            "months",
            "credit_history",
            "purpose",
            "credit_amount",
            "savings",
            "employment",
            "investment_as_income_percentage",
            "personal_status",
            "other_debtors",
            "residence_since",
            "property",
            "age",
            "installment_plans",
            "housing",
            "number_of_credits",
            "skill_level",
            "people_liable_for",
            "telephone",
            "foreign_worker",
            "credit",
        ]
        df = pd.read_csv(data_file, sep=" ", header=None, names=column_names)
        return df

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """Encode the raw attribute codes and normalize the result.

        Ordinal codes become integers and binary codes become booleans.
        ``purpose`` and ``skill_level`` are one-hot encoded. ``age``, ``credit``
        and the derived ``sex`` column are binarized, and ``personal_status``
        is dropped. Every non-boolean column is then z-scored.

        Note: the input frame's columns are overwritten in place before the
        final (new) frame is returned.
        """
        # ordinal codes -> integers
        df["status"] = df["status"].map(
            {
                "A11": 1,  # <0 DM
                "A12": 2,  # <200 DM
                "A13": 3,  # >200 DM
                "A14": 4,  # Unknown/NoAccount
            }
        )
        df["credit_history"] = df["credit_history"].map(
            {
                "A30": 0,  # No Credit
                "A31": 1,  # All Paid
                "A32": 2,  # Existing Paid
                "A33": 3,  # Delay
                "A34": 4,  # Critical Account
            }
        )
        df["savings"] = df["savings"].map(
            {
                "A61": 1,  # <100 DM
                "A62": 2,  # <500 DM
                "A63": 3,  # <1000 DM
                "A64": 4,  # >1000 DM
                "A65": 5,  # Unknown/NoAccount
            }
        )
        df["employment"] = df["employment"].map(
            {
                "A71": 1,  # Unemployed
                "A72": 2,  # <1 Year
                "A73": 3,  # <4 Years
                "A74": 4,  # <7 Years
                "A75": 5,  # >7 Years
            }
        )
        df["property"] = df["property"].map(
            {
                "A121": 1,  # Real Estate
                "A122": 2,  # Life Insurance
                "A123": 3,  # Car
                "A124": 4,  # Unknown/None
            }
        )
        df["installment_plans"] = df["installment_plans"].map(
            {
                "A141": 1,  # Bank
                "A142": 2,  # Stores
                "A143": 3,  # None
            }
        )

        # binary codes -> booleans
        df["telephone"] = df["telephone"].map({"A191": False, "A192": True})
        df["foreign_worker"] = df["foreign_worker"].map({"A201": True, "A202": False})
        df["housing"] = df["housing"].map(
            {
                "A151": True,  # Rent
                "A152": False,  # Own
                "A153": False,  # For Free
            }
        )
        df["other_debtors"] = df["other_debtors"].map(
            {
                "A101": False,  # None
                "A102": True,  # Co-Applicant
                "A103": True,  # Guarantor
            }
        )

        # nominal codes -> readable labels (later one-hot encoded)
        df["purpose"] = df["purpose"].map(
            {
                "A40": "Car (New)",
                "A41": "Car (Used)",
                "A42": "Furniture/Equipment",
                "A43": "Radio/Television",
                "A44": "Domestic Appliances",
                "A45": "Repairs",
                "A46": "Education",
                "A47": "Vacation",
                "A48": "Retraining",
                "A49": "Business",
                "A410": "Others",
            }
        )
        df["skill_level"] = df["skill_level"].map(
            {
                "A171": "Unskilled non resident",
                "A172": "Unskilled resident",
                "A173": "Skilled",
                "A174": "Highly Skilled",
            }
        )
        df["personal_status"] = df["personal_status"].map(
            {
                "A91": "Male Divorced/Seperated",
                "A92": "Female Divorced/Seperated/Married",
                "A93": "Male Single",
                "A94": "Male Married/Widowed",
                "A95": "Female Single",
            }
        )

        # binarize age, the credit label, and sex (derived from personal_status)
        df["age"] = df["age"] > 25
        df["credit"] = df["credit"] == 1
        df["sex"] = df["personal_status"].map(lambda x: x.startswith("Female"))
        df = df.drop(columns=["personal_status"])

        # One-hot encode every non-numeric column (bools count as numeric).
        # Detecting by "not numeric" rather than `dtype == object` also catches
        # pandas' dedicated string dtype; dtype=bool keeps the dummies boolean on
        # every pandas version (the default was uint8 before 2.0), so they are
        # excluded from normalization below.
        categorical_columns = [
            col for col in df.columns if not pd.api.types.is_numeric_dtype(df[col])
        ]
        if categorical_columns:
            df = pd.get_dummies(df, columns=categorical_columns, prefix_sep="=", dtype=bool)

        # z-score every non-boolean column
        # NOTE(review): mean/std come from the whole dataset before the
        # train/val/test split, so test-set statistics leak into training features.
        continuous_columns = [
            col for col in df.columns if not pd.api.types.is_bool_dtype(df[col])
        ]
        if continuous_columns:
            mean = df[continuous_columns].mean()
            # a constant column would give 0/0 = NaN; map it to 0 instead
            std = df[continuous_columns].std().replace(0.0, 1.0)
            df[continuous_columns] = (df[continuous_columns] - mean) / std
        return df

    def _to_tensors(self, df: pd.DataFrame):
        """Return ``(features, protected, target)`` tensors from a preprocessed frame.

        Features are float32 and hold every column except ``target`` and
        ``protected``, in their original order. ``protected`` and ``target`` are
        int64. ``df`` itself is not modified.
        """
        target = torch.tensor(df[self.target].to_numpy(np.int64))
        protected = torch.tensor(df[self.protected].to_numpy(np.int64))
        features = torch.tensor(
            df.drop(columns=[self.target, self.protected]).to_numpy(np.float32)
        )
        return features, protected, target

    def train_val_test_split(self, train_df, test_df=None):
        """Convert preprocessed frames to tensors and split them.

        Args:
            train_df: Preprocessed frame containing the ``target`` and
                ``protected`` columns. When ``test_df`` is None, a
                ``test_size`` fraction of its rows becomes the test set.
            test_df: Optional preprocessed frame used as the test set as-is.

        The remaining training rows are split again, and a ``validation_size``
        fraction is held out for validation. Results are stored on
        ``self.{train,val,test}_{features,protected,target}``. Neither input
        frame is modified.
        """
        train_features, train_protected, train_target = self._to_tensors(train_df)

        if test_df is not None:
            (
                self.test_features,
                self.test_protected,
                self.test_target,
            ) = self._to_tensors(test_df)
        else:
            # hold out the test set from the single provided frame
            (
                train_features,
                self.test_features,
                train_protected,
                self.test_protected,
                train_target,
                self.test_target,
            ) = train_test_split(
                train_features,
                train_protected,
                train_target,
                test_size=self.hparams["test_size"],
                random_state=self.hparams["seed"],
            )

        # carve the validation set out of what remains for training
        (
            self.train_features,
            self.val_features,
            self.train_protected,
            self.val_protected,
            self.train_target,
            self.val_target,
        ) = train_test_split(
            train_features,
            train_protected,
            train_target,
            test_size=self.hparams["validation_size"],
            random_state=self.hparams["seed"],
        )

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        dataset = TensorDataset(self.train_features, self.train_protected, self.train_target)
        return DataLoader(
            dataset, self.hparams["batch_size"], num_workers=1, shuffle=True, drop_last=True
        )

    def val_dataloader(self):
        """Unshuffled validation loader."""
        dataset = TensorDataset(self.val_features, self.val_protected, self.val_target)
        return DataLoader(dataset, self.hparams["batch_size"], num_workers=1, shuffle=False)

    def test_dataloader(self):
        """Unshuffled test loader."""
        dataset = TensorDataset(self.test_features, self.test_protected, self.test_target)
        return DataLoader(dataset, self.hparams["batch_size"], num_workers=1, shuffle=False)


class GermanSex(German):
    """German credit data with ``sex`` (True for female applicants) as the protected attribute."""

    protected, target = "sex", "credit"


class GermanAge(German):
    """German credit data with ``age`` (True when older than 25) as the protected attribute.

    Restates the base-class defaults explicitly so the chosen attribute is
    visible in the subclass name and body.
    """

    protected, target = "age", "credit"
