import os
from urllib import request
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning import LightningDataModule


class Compas(LightningDataModule):
    """Lightning data module for the COMPAS two-year recidivism CSV.

    On construction the CSV is downloaded (if not already cached on disk),
    preprocessed into a purely numeric/boolean table, and split into train,
    validation and test tensors. Every dataloader yields batches of
    ``(features, protected, target)`` where ``protected`` is 1 for rows whose
    ``race`` is "African-American" and 0 for "Caucasian".

    Args:
        batch_size: Batch size used by all three dataloaders.
        validation_size: ``test_size`` argument of the second
            ``train_test_split`` call, i.e. it is applied to the rows that
            remain after the test split, not to the full dataset.
        test_size: ``test_size`` argument of the first ``train_test_split``
            call; unused when an explicit test frame is supplied to
            ``train_val_test_split``.
        seed: ``random_state`` passed to both splits.
    """

    target = "two_year_recid"
    protected = "race"

    def __init__(
        self,
        batch_size: int = 128,
        validation_size: float = 0.2,
        test_size: float = 0.2,
        seed: int = 0,
    ):
        super().__init__()
        # Record the constructor arguments plus the dataset name in hparams.
        self.save_hyperparameters()
        self.save_hyperparameters({"dataset": self.__class__.__name__})

        df = self.load_data()
        df = self.preprocess(df)
        self.train_val_test_split(df)

    def load_data(self) -> pd.DataFrame:
        """Return the raw COMPAS table, downloading the CSV on first use.

        The file is cached as ``data/compas/compas-scores-two-years.csv``
        under the directory two levels above this source file.

        Returns:
            The CSV parsed with ``pandas.read_csv`` default settings.
        """
        # create data directory
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        data_dir = os.path.join(project_root, "data", "compas")
        os.makedirs(data_dir, exist_ok=True)

        # download and load train data into dataframe
        data_file = os.path.join(data_dir, "compas-scores-two-years.csv")
        if not os.path.exists(data_file):
            url = "https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv"
            # Download to a process-specific temporary name and move it into
            # place only once complete, so an interrupted download never
            # leaves a truncated file that later runs would read as the cache.
            partial_file = f"{data_file}.{os.getpid()}.part"
            try:
                request.urlretrieve(url, partial_file)
                os.replace(partial_file, data_file)
            finally:
                # Only still present if the download or the move failed.
                if os.path.exists(partial_file):
                    os.remove(partial_file)
        return pd.read_csv(data_file)

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """Filter rows, engineer features and normalise the COMPAS table.

        The input frame is not modified. Steps, in order: row filtering,
        binarising ``race``, deriving custody/jail durations in seconds,
        dropping alternative targets, identifiers, date-like columns, columns
        with missing values and known-constant columns, log-transforming the
        heavy-tailed columns, one-hot encoding object columns, and z-scoring
        the remaining non-boolean feature columns.

        Args:
            df: Raw table as returned by ``load_data``.

        Returns:
            A new frame that still contains the ``target`` and ``protected``
            columns; both are left un-normalised.
        """
        # filter out irrelevant entries and only keep African-American and Caucasian
        # NOTE(review): "0" below is the digit zero; the letter "O" may have
        # been intended for c_charge_degree -- confirm against the data.
        # NOTE(review): with read_csv defaults the text "N/A" is presumably
        # parsed as NaN, in which case the score_text test removes nothing --
        # confirm.
        keep = (
            (df["days_b_screening_arrest"] >= -30)
            & (df["days_b_screening_arrest"] <= 30)
            & (df["is_recid"] != -1)
            & (df["c_charge_degree"] != "0")
            & (df["score_text"] != "N/A")
            & df["race"].isin(["African-American", "Caucasian"])
        )
        # Explicit copy: the frame is written to below, and assigning into a
        # filtered slice triggers pandas' chained-assignment warnings.
        df = df[keep].copy()
        df["race"] = df["race"] == "African-American"

        # get total custody/jail time (in seconds)
        for duration_col, begin_col, finish_col in (
            ("diff_custody", "in_custody", "out_custody"),
            ("diff_jail", "c_jail_in", "c_jail_out"),
        ):
            elapsed = pd.to_datetime(df[finish_col]) - pd.to_datetime(df[begin_col])
            df[duration_col] = elapsed.dt.total_seconds()

        # drop rows with negative (or missing) jail time or custody time
        df = df[(df["diff_jail"] >= 0) & (df["diff_custody"] >= 0)].copy()

        # drop other possible targets
        other_targets = ["is_recid", "is_violent_recid", "violent_recid"]
        df = df.drop(columns=other_targets)

        # drop unique identifiers
        unique_ids = ["id", "name", "first", "last", "c_case_number", "r_case_number"]
        df = df.drop(columns=unique_ids)

        # drop dates and timestamps (columns absent from the frame are skipped)
        date_columns = [
            "v_screening_date",
            "compas_screening_date",
            "dob",
            "screening_date",
            "in_custody",
            "out_custody",
            "c_jail_in",
            "c_jail_out",
            "c_offense_date",
            "c_arrest_date",
            "start",
            "end",
            "event",
        ]
        df = df.drop(columns=date_columns, errors="ignore")

        # drop columns with missing values
        incomplete_columns = df.columns[df.isnull().any()]
        df = df.drop(columns=incomplete_columns)

        # drop constant columns (single value)
        constant_columns = ["type_of_assessment", "v_type_of_assessment"]
        df = df.drop(columns=constant_columns, errors="ignore")

        # log transform columns with large magnitude differences
        # ("start" and "end" are removed with the date columns above, so the
        # membership check skips them.)
        for col in ("c_days_from_compas", "diff_custody", "diff_jail", "start", "end"):
            if col in df.columns:
                df[col] = np.log1p(df[col])

        # one-hot encode categorical columns
        categorical_columns = [col for col in df.columns if df[col].dtype == object]
        if categorical_columns:
            df = pd.get_dummies(df, columns=categorical_columns, prefix_sep="=")

        # normalize continuous columns
        # The target and protected columns are labels, not features: z-scoring
        # the integer target would turn it into floats and leave the later
        # int64 conversion to recover the classes by truncation.
        label_columns = {self.target, self.protected}
        continuous_columns = [
            col
            for col in df.columns
            if df[col].dtype != bool and col not in label_columns
        ]
        if continuous_columns:
            values = df[continuous_columns]
            mean = values.mean()
            # A zero-variance column would otherwise become all-NaN (0 / 0).
            std = values.std().replace(0, 1.0)
            df[continuous_columns] = (values - mean) / std
        return df

    def _to_tensors(self, df: pd.DataFrame):
        """Split a preprocessed frame into (features, protected, target) tensors.

        The frame is left untouched; features are every column other than
        ``target`` and ``protected``, cast to float32, while the two label
        columns are cast to int64.
        """
        target = torch.tensor(df[self.target].to_numpy(np.int64))
        protected = torch.tensor(df[self.protected].to_numpy(np.int64))
        feature_frame = df.drop(columns=[self.target, self.protected])
        features = torch.tensor(feature_frame.to_numpy(np.float32))
        return features, protected, target

    def train_val_test_split(self, train_df, test_df=None):
        """Populate the train/val/test feature, protected and target tensors.

        Neither input frame is modified.

        Args:
            train_df: Preprocessed frame containing the ``target`` and
                ``protected`` columns.
            test_df: Optional preprocessed frame used as the test set as-is.
                When omitted, a ``test_size`` share of ``train_df`` is held
                out instead.

        The validation set is always split off the remaining training rows
        using ``validation_size``. Both splits use ``seed`` as random state.
        """
        train_features, train_protected, train_target = self._to_tensors(train_df)

        if test_df is not None:
            (
                self.test_features,
                self.test_protected,
                self.test_target,
            ) = self._to_tensors(test_df)
        else:
            # split into test and train sets
            (
                train_features,
                self.test_features,
                train_protected,
                self.test_protected,
                train_target,
                self.test_target,
            ) = train_test_split(
                train_features,
                train_protected,
                train_target,
                test_size=self.hparams["test_size"],
                random_state=self.hparams["seed"],
            )

        # split into validation and train sets
        (
            self.train_features,
            self.val_features,
            self.train_protected,
            self.val_protected,
            self.train_target,
            self.val_target,
        ) = train_test_split(
            train_features,
            train_protected,
            train_target,
            test_size=self.hparams["validation_size"],
            random_state=self.hparams["seed"],
        )

    def _make_loader(self, features, protected, target, training=False):
        """Wrap three aligned tensors in a ``DataLoader``.

        With ``training=True`` the loader shuffles and drops the last
        incomplete batch; otherwise it does neither.
        """
        dataset = TensorDataset(features, protected, target)
        return DataLoader(
            dataset,
            self.hparams["batch_size"],
            num_workers=1,
            shuffle=training,
            drop_last=training,
        )

    def train_dataloader(self):
        """Return the shuffled training loader (last partial batch dropped)."""
        return self._make_loader(
            self.train_features, self.train_protected, self.train_target, training=True
        )

    def val_dataloader(self):
        """Return the validation loader (no shuffling, all rows kept)."""
        return self._make_loader(self.val_features, self.val_protected, self.val_target)

    def test_dataloader(self):
        """Return the test loader (no shuffling, all rows kept)."""
        return self._make_loader(self.test_features, self.test_protected, self.test_target)