import os
from urllib import request
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning import LightningDataModule


class Crime(LightningDataModule):
    """Lightning data module for the UCI "Communities and Crime Unnormalized" data.

    On construction the raw file is downloaded (if not already cached), preprocessed
    into numeric features plus a binary protected attribute and a binary target, and
    split into train/validation/test tensors. Every dataloader yields
    ``(features, protected, target)`` tuples.
    """

    # column holding the (binarized) target
    target = "ViolentCrimesPerPop"
    # derived binary protected-attribute column, created in `preprocess`
    protected = "racePctWhiteVsracePctOther"

    def __init__(
        self,
        batch_size: int = 128,
        validation_size: float = 0.2,
        test_size: float = 0.2,
        seed: int = 0,
    ):
        """Load, preprocess and split the dataset.

        Args:
            batch_size: Batch size used by all three dataloaders.
            validation_size: Fraction of the non-test rows held out for validation.
            test_size: Fraction of all rows held out for testing.
            seed: ``random_state`` passed to both ``train_test_split`` calls.
        """
        super().__init__()
        self.save_hyperparameters()
        # also record which dataset these hyperparameters belong to
        self.save_hyperparameters({"dataset": self.__class__.__name__})

        df = self.load_data()
        df = self.preprocess(df)
        self.train_val_test_split(df)

    def load_data(self) -> pd.DataFrame:
        """Return the raw dataset as a DataFrame, downloading it on first use.

        The file is cached at ``<project_root>/data/crime/communities.data``, where
        ``project_root`` is the parent of the directory containing this source file.
        """
        # create data directory
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        data_dir = os.path.join(project_root, "data", "crime")
        os.makedirs(data_dir, exist_ok=True)

        # download the raw data once; later calls reuse the cached file
        data_file = os.path.join(data_dir, "communities.data")
        if not os.path.exists(data_file):
            url = "http://archive.ics.uci.edu/ml/machine-learning-databases/00211/CommViolPredUnnormalizedData.txt"
            # Download under a temporary name and move it into place only once it is
            # complete, so an interrupted/truncated download is never mistaken for a
            # valid cached file on the next run.
            partial_file = data_file + ".part"
            request.urlretrieve(url, partial_file)
            os.replace(partial_file, data_file)

        # the file has no header row (header=None), so names are supplied in file order
        column_names = [
            "communityname",
            "state",
            "countyCode",
            "communityCode",
            "fold",
            "population",
            "householdsize",
            "racePctBlack",
            "racePctWhite",
            "racePctAsian",
            "racePctHisp",
            "agePct12t21",
            "agePct12t29",
            "agePct16t24",
            "agePct65up",
            "NumUrban",
            "PctUrban",
            "medIncome",
            "PctWWage",
            "PctWFarmSelf",
            "PctWInvInc",
            "PctWSocSec",
            "PctWPubAsst",
            "PctWRetire",
            "medFamInc",
            "PerCapInc",
            "whitePerCap",
            "blackPerCap",
            "indianPerCap",
            "AsianPerCap",
            "OtherPerCap",
            "HispPerCap",
            "NumUnderPov",
            "PctPopUnderPov",
            "PctLess9thGrade",
            "PctNotHSGrad",
            "PctBSorMore",
            "PctUnemployed",
            "PctEmploy",
            "PctEmplManu",
            "PctEmplProfServ",
            "PctOccupManu",
            "PctOccupMgmtProf",
            "MalePctDivorce",
            "MalePctNevMarr",
            "FemalePctDiv",
            "TotalPctDiv",
            "PersPerFam",
            "PctFam2Par",
            "PctKids2Par",
            "PctYoungKids2Par",
            "PctTeen2Par",
            "PctWorkMomYoungKids",
            "PctWorkMom",
            "NumKidsBornNeverMar",
            "PctKidsBornNeverMar",
            "NumImmig",
            "PctImmigRecent",
            "PctImmigRec5",
            "PctImmigRec8",
            "PctImmigRec10",
            "PctRecentImmig",
            "PctRecImmig5",
            "PctRecImmig8",
            "PctRecImmig10",
            "PctSpeakEnglOnly",
            "PctNotSpeakEnglWell",
            "PctLargHouseFam",
            "PctLargHouseOccup",
            "PersPerOccupHous",
            "PersPerOwnOccHous",
            "PersPerRentOccHous",
            "PctPersOwnOccup",
            "PctPersDenseHous",
            "PctHousLess3BR",
            "MedNumBR",
            "HousVacant",
            "PctHousOccup",
            "PctHousOwnOcc",
            "PctVacantBoarded",
            "PctVacMore6Mos",
            "MedYrHousBuilt",
            "PctHousNoPhone",
            "PctWOFullPlumb",
            "OwnOccLowQuart",
            "OwnOccMedVal",
            "OwnOccHiQuart",
            "OwnOccQrange",
            "RentLowQ",
            "RentMedian",
            "RentHighQ",
            "RentQrange",
            "MedRent",
            "MedRentPctHousInc",
            "MedOwnCostPctInc",
            "MedOwnCostPctIncNoMtg",
            "NumInShelters",
            "NumStreet",
            "PctForeignBorn",
            "PctBornSameState",
            "PctSameHouse85",
            "PctSameCity85",
            "PctSameState85",
            "LemasSwornFT",
            "LemasSwFTPerPop",
            "LemasSwFTFieldOps",
            "LemasSwFTFieldPerPop",
            "LemasTotalReq",
            "LemasTotReqPerPop",
            "PolicReqPerOffic",
            "PolicPerPop",
            "RacialMatchCommPol",
            "PctPolicWhite",
            "PctPolicBlack",
            "PctPolicHisp",
            "PctPolicAsian",
            "PctPolicMinor",
            "OfficAssgnDrugUnits",
            "NumKindsDrugsSeiz",
            "PolicAveOTWorked",
            "LandArea",
            "PopDens",
            "PctUsePubTrans",
            "PolicCars",
            "PolicOperBudg",
            "LemasPctPolicOnPatr",
            "LemasGangUnitDeploy",
            "LemasPctOfficDrugUn",
            "PolicBudgPerPop",
            "murders",
            "murdPerPop",
            "rapes",
            "rapesPerPop",
            "robberies",
            "robbbPerPop",
            "assaults",
            "assaultPerPop",
            "burglaries",
            "burglPerPop",
            "larcenies",
            "larcPerPop",
            "autoTheft",
            "autoTheftPerPop",
            "arsons",
            "arsonsPerPop",
            "ViolentCrimesPerPop",
            "nonViolPerPop",
        ]
        df = pd.read_csv(data_file, sep=",", header=None, names=column_names)
        return df

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """Turn the raw DataFrame into model-ready columns.

        Steps:
          1. drop identifier columns and every crime statistic except the target;
          2. treat ``"?"`` as missing, drop rows with a missing target, then drop
             every column that still has a missing value;
          3. convert remaining object columns (except ``state``) to numbers;
          4. ``log1p``-transform columns with large magnitude differences;
          5. derive the boolean protected attribute ``self.protected``;
          6. binarize the target at its median;
          7. one-hot encode object columns and standardize all non-bool columns.

        ``df`` is modified in place by the drop/replace steps; the processed frame
        is returned.
        """
        # remove features that are not predictive
        df.drop(["communityname", "countyCode", "communityCode", "fold"], axis=1, inplace=True)

        # remove all other potential goal variables
        other_targets = [
            "murders",
            "murdPerPop",
            "rapes",
            "rapesPerPop",
            "robberies",
            "robbbPerPop",
            "assaults",
            "assaultPerPop",
            "burglaries",
            "burglPerPop",
            "larcenies",
            "larcPerPop",
            "autoTheft",
            "autoTheftPerPop",
            "arsons",
            "arsonsPerPop",
            "nonViolPerPop",
        ]
        df.drop(other_targets, axis=1, inplace=True)

        # remove rows without a target, then any column that still has missing values
        df.replace(to_replace="?", value=np.nan, inplace=True)
        df.dropna(axis=0, subset=[self.target], inplace=True)
        df.dropna(axis=1, inplace=True)

        # columns that contained "?" were read as strings; convert them back to numbers
        for col in df.columns:
            if (df[col].dtype == object) and (col != "state"):
                df[col] = pd.to_numeric(df[col])

        # log transform columns with large magnitude differences
        log_transform_columns = [
            "population",
            "NumUrban",
            "NumUnderPov",
            "NumKidsBornNeverMar",
            "NumImmig",
            "NumInShelters",
            "NumStreet",
            "HousVacant",
            "OwnOccLowQuart",
            "OwnOccMedVal",
            "OwnOccHiQuart",
            "OwnOccQrange",
            "LemasSwornFT",
            "LemasSwFTPerPop",
            "LemasSwFTFieldOps",
            "LemasSwFTFieldPerPop",
            "LemasTotalReq",
            "LemasTotReqPerPop",
            "LemasGangUnitDeploy",
            "PolicReqPerOffic",
            "PolicPerPop",
            "PolicCars",
            "PolicOperBudg",
            "PolicBudgPerPop",
            "OfficAssgnDrugUnits",
            "LandArea",
            "PopDens",
            "PctRecentImmig",
            "PctRecImmig5",
            "PctRecImmig8",
            "PctRecImmig10",
            "PctNotSpeakEnglWell",
            "PctPersDenseHous",
            "PctVacantBoarded",
            "PctPolicBlack",
            "PctPolicHisp",
            "PctPolicAsian",
            "PctPolicMinor",
            "PctUsePubTrans",
        ]
        for col in log_transform_columns:
            # some of these may already have been dropped for containing missing values
            if col in df.columns:
                df[col] = np.log1p(df[col])

        # protected attribute: True when racePctWhite / 5 is smaller than the summed
        # Black, Asian and Hispanic percentages
        df[self.protected] = np.less(
            df["racePctWhite"] / 5, df["racePctBlack"] + df["racePctAsian"] + df["racePctHisp"]
        )

        # binarize target: True for rows strictly below the median value
        df[self.target] = df[self.target] < df[self.target].median()

        # one-hot encode categorical columns
        categorical_columns = [col for col in df.columns if df[col].dtype == object]
        if categorical_columns:
            df = pd.get_dummies(df, columns=categorical_columns, prefix_sep="=")

        # standardize continuous (non-bool) columns
        # NOTE(review): mean/std are computed over all rows before the train/val/test
        # split, so evaluation rows influence the normalization statistics.
        continuous_columns = [col for col in df.columns if df[col].dtype != bool]
        if continuous_columns:
            mean = df[continuous_columns].mean()
            std = df[continuous_columns].std()
            # a constant column has std 0; dividing by it would turn the column into NaN
            std = std.mask(std == 0, 1.0)
            df[continuous_columns] = (df[continuous_columns] - mean) / std
        return df

    def train_val_test_split(self, train_df, test_df=None):
        """Split DataFrames into feature/protected/target tensors and store them.

        Args:
            train_df: Preprocessed frame. Its target and protected columns are popped
                (the frame is mutated). If ``test_df`` is None, a ``test_size``
                fraction of it is held out as the test set.
            test_df: Optional separate test frame with the same columns; also mutated.

        The remaining training rows are split again, holding out ``validation_size``
        for validation. Results are stored on ``self.{train,val,test}_{features,
        protected,target}``.
        """
        # split into features, protected attribute, and target
        train_target = torch.tensor(train_df.pop(self.target).to_numpy(np.int64))
        train_protected = torch.tensor(train_df.pop(self.protected).to_numpy(np.int64))
        train_features = torch.tensor(train_df.to_numpy(np.float32))

        if test_df is not None:
            # split into features, protected attribute, and target
            self.test_target = torch.tensor(test_df.pop(self.target).to_numpy(np.int64))
            self.test_protected = torch.tensor(test_df.pop(self.protected).to_numpy(np.int64))
            self.test_features = torch.tensor(test_df.to_numpy(np.float32))
        else:
            # split into test and train sets
            (
                train_features,
                self.test_features,
                train_protected,
                self.test_protected,
                train_target,
                self.test_target,
            ) = train_test_split(
                train_features,
                train_protected,
                train_target,
                test_size=self.hparams["test_size"],
                random_state=self.hparams["seed"],
            )

        # split into validation and train sets
        (
            self.train_features,
            self.val_features,
            self.train_protected,
            self.val_protected,
            self.train_target,
            self.val_target,
        ) = train_test_split(
            train_features,
            train_protected,
            train_target,
            test_size=self.hparams["validation_size"],
            random_state=self.hparams["seed"],
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training split (incomplete last batch dropped)."""
        dataset = TensorDataset(self.train_features, self.train_protected, self.train_target)
        return DataLoader(
            dataset, self.hparams["batch_size"], num_workers=1, shuffle=True, drop_last=True
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        dataset = TensorDataset(self.val_features, self.val_protected, self.val_target)
        return DataLoader(dataset, self.hparams["batch_size"], num_workers=1, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split."""
        dataset = TensorDataset(self.test_features, self.test_protected, self.test_target)
        return DataLoader(dataset, self.hparams["batch_size"], num_workers=1, shuffle=False)


