import os
import warnings
import zipfile
from urllib import request
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning import LightningDataModule


class Health(LightningDataModule):
    """Lightning data module for the Heritage Health Prize (HHP release 3) data.

    One example is one row of the outer merge of the lab, drug and claims
    tables (keyed by MemberID and Year) with the member table (keyed by
    MemberID). The target is whether the maximum encoded Charlson index is
    above zero; the protected attribute is whether the age at first claim is
    60 or older. All dataloaders yield ``(features, protected, target)``
    tensor triples.
    """

    target = "max_CharlsonIndex"
    protected = "AgeAtFirstClaim>60"

    def __init__(
        self,
        batch_size: int = 128,
        validation_size: float = 0.2,
        test_size: float = 0.2,
        seed: int = 0,
    ):
        """Load, preprocess and split the dataset.

        Args:
            batch_size: Number of examples per batch for every dataloader.
            validation_size: Fraction of the non-test examples held out for validation.
            test_size: Fraction of all examples held out for testing.
            seed: ``random_state`` used for both train/test and train/validation splits.
        """
        super().__init__()
        self.save_hyperparameters()
        self.save_hyperparameters({"dataset": self.__class__.__name__})

        df = self.load_data()
        df = self.preprocess(df)
        self.train_val_test_split(df)

    def load_data(self) -> pd.DataFrame:
        """Return the merged HHP table, building and caching it on first use.

        The merged table is cached at ``<project_root>/data/health/health_full.csv``,
        where ``<project_root>`` is two directories above this file. When the
        cache is missing, ``HHP_release3.zip`` is downloaded into the same
        directory (unless already present), its Claims/DrugCount/LabCount/Members
        tables are preprocessed and outer-merged, and the result is written to
        the cache before being returned.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        data_dir = os.path.join(project_root, "data", "health")
        os.makedirs(data_dir, exist_ok=True)

        data_file = os.path.join(data_dir, "health_full.csv")
        if os.path.exists(data_file):
            return pd.read_csv(data_file, sep=",")

        zip_file = os.path.join(data_dir, "HHP_release3.zip")
        if not os.path.exists(zip_file):
            print("Downloading data... (this may take a while)")
            url = "https://foreverdata.org/1015/content/HHP_release3.zip"
            # Download under a temporary name and rename only once complete, so an
            # interrupted download never leaves a truncated zip that later runs
            # would mistake for a finished one.
            partial_file = zip_file + ".part"
            request.urlretrieve(url, partial_file)
            os.replace(partial_file, zip_file)

        with warnings.catch_warnings():
            # All warnings emitted while reading, preprocessing and merging the
            # raw tables are suppressed.
            warnings.simplefilter("ignore")
            with zipfile.ZipFile(zip_file) as zf:
                df_claims = self.preprocess_claims(self._read_zipped_csv(zf, "Claims.csv"))
                df_drugs = self.preprocess_drugs(self._read_zipped_csv(zf, "DrugCount.csv"))
                df_labs = self.preprocess_labs(self._read_zipped_csv(zf, "LabCount.csv"))
                df_members = self.preprocess_members(self._read_zipped_csv(zf, "Members.csv"))

            yearly_keys = ["MemberID", "Year"]
            df_labs_drugs = pd.merge(df_labs, df_drugs, on=yearly_keys, how="outer")
            df_labs_drugs_claims = pd.merge(df_labs_drugs, df_claims, on=yearly_keys, how="outer")
            df_health = pd.merge(df_labs_drugs_claims, df_members, on=["MemberID"], how="outer")
            df_health.to_csv(data_file, index=False)
        return df_health

    @staticmethod
    def _read_zipped_csv(zf: zipfile.ZipFile, name: str) -> pd.DataFrame:
        """Read archive member *name* of *zf* as a comma-separated table, closing the handle."""
        with zf.open(name) as fh:
            return pd.read_csv(fh, sep=",")

    @staticmethod
    def preprocess_claims(df_claims: pd.DataFrame) -> pd.DataFrame:
        """Encode the raw claims table and aggregate it per (Year, MemberID).

        Ordinal string columns are mapped to integers, missing categorical values
        get an explicit ``"<column>_?"`` category, and the categorical columns are
        one-hot encoded with ``"="`` as the prefix separator. The input frame's
        columns are overwritten in the process.

        Returns:
            One row per (Year, MemberID) with ``no_Claims``, ``no_Providers``,
            ``no_Vendors``, ``no_PCPs``, ``max_CharlsonIndex``, ``PayDelay_total``,
            ``PayDelay_max``, ``PayDelay_min`` and, for every one-hot column, the
            number of claims falling in that category.
        """
        dsfs_codes = {
            "0- 1 month": 1,
            "1- 2 months": 2,
            "2- 3 months": 3,
            "3- 4 months": 4,
            "4- 5 months": 5,
            "5- 6 months": 6,
            "6- 7 months": 7,
            "7- 8 months": 8,
            "8- 9 months": 9,
            "9-10 months": 10,
            "10-11 months": 11,
            "11-12 months": 12,
        }
        charlson_codes = {"0": 0, "1-2": 1, "3-4": 2, "5+": 3}
        length_of_stay_codes = {
            "1 day": 1,
            "2 days": 2,
            "3 days": 3,
            "4 days": 4,
            "5 days": 5,
            "6 days": 6,
            "1- 2 weeks": 11,
            "2- 4 weeks": 21,
            "4- 8 weeks": 42,
            "26+ weeks": 180,
        }

        df_claims["PayDelay"] = df_claims["PayDelay"].replace("162+", "162").astype(int)
        # NOTE: DSFS and LengthOfStay are encoded here but are not listed in ``agg``
        # below, so they do not appear in the returned table.
        df_claims["DSFS"] = df_claims["DSFS"].replace(dsfs_codes)
        df_claims["CharlsonIndex"] = df_claims["CharlsonIndex"].replace(charlson_codes)
        # Assign results back rather than calling ``fillna(inplace=True)`` on a
        # column: chained in-place updates are silently lost under pandas
        # Copy-on-Write, which would leave NaNs for ``astype(int)`` to choke on.
        df_claims["LengthOfStay"] = (
            df_claims["LengthOfStay"].replace(length_of_stay_codes).fillna(0).astype(int)
        )

        claims_cat_names = ["PrimaryConditionGroup", "Specialty", "ProcedureGroup", "PlaceSvc"]
        for cat_name in claims_cat_names:
            df_claims[cat_name] = df_claims[cat_name].fillna(f"{cat_name}_?")
        df_claims = pd.get_dummies(df_claims, columns=claims_cat_names, prefix_sep="=")

        one_hot_columns = [col for col in df_claims.columns if "=" in col]
        agg = {
            "ProviderID": ["count", "nunique"],
            "Vendor": "nunique",
            "PCP": "nunique",
            "CharlsonIndex": "max",
            "PayDelay": ["sum", "max", "min"],
            # Summing one-hot indicators counts the claims in each category.
            **{col: "sum" for col in one_hot_columns},
        }
        df_claims = df_claims.groupby(["Year", "MemberID"]).agg(agg).reset_index()
        # Flatten the multi-level column labels produced by ``agg``; the order
        # below follows the order of ``agg``.
        df_claims.columns = [
            "Year",
            "MemberID",
            "no_Claims",
            "no_Providers",
            "no_Vendors",
            "no_PCPs",
            "max_CharlsonIndex",
            "PayDelay_total",
            "PayDelay_max",
            "PayDelay_min",
        ] + one_hot_columns
        return df_claims

    @staticmethod
    def preprocess_drugs(df_drugs: pd.DataFrame) -> pd.DataFrame:
        """Aggregate the drug-count table per (Year, MemberID).

        ``DSFS`` is dropped from the input frame in place, and any ``+`` in
        ``DrugCount`` is stripped before parsing it as an integer.

        Returns:
            Columns ``Year``, ``MemberID``, ``DrugCount_total`` (sum of counts)
            and ``DrugCount_months`` (number of records aggregated).
        """
        df_drugs.drop(columns=["DSFS"], inplace=True)
        df_drugs["DrugCount"] = (
            df_drugs["DrugCount"].astype(str).str.replace("+", "", regex=False).astype(int)
        )
        df_drugs = (
            df_drugs.groupby(["Year", "MemberID"])
            .agg({"DrugCount": ["sum", "count"]})
            .reset_index()
        )
        df_drugs.columns = ["Year", "MemberID", "DrugCount_total", "DrugCount_months"]
        return df_drugs

    @staticmethod
    def preprocess_labs(df_labs: pd.DataFrame) -> pd.DataFrame:
        """Aggregate the lab-count table per (Year, MemberID).

        ``DSFS`` is dropped from the input frame in place, and any ``+`` in
        ``LabCount`` is stripped before parsing it as an integer.

        Returns:
            Columns ``Year``, ``MemberID``, ``LabCount_total`` (sum of counts)
            and ``LabCount_months`` (number of records aggregated).
        """
        df_labs.drop(columns=["DSFS"], inplace=True)
        df_labs["LabCount"] = (
            df_labs["LabCount"].astype(str).str.replace("+", "", regex=False).astype(int)
        )
        df_labs = (
            df_labs.groupby(["Year", "MemberID"]).agg({"LabCount": ["sum", "count"]}).reset_index()
        )
        df_labs.columns = ["Year", "MemberID", "LabCount_total", "LabCount_months"]
        return df_labs

    @staticmethod
    def preprocess_members(df_members: pd.DataFrame) -> pd.DataFrame:
        """Derive the protected attribute and one-hot encode the member table.

        Missing ``AgeAtFirstClaim`` / ``Sex`` values become an explicit ``"?"``
        category, so they get their own ``=?`` indicator column after encoding;
        rows with an unknown age are removed later in ``preprocess``.
        """
        df_members["AgeAtFirstClaim"] = df_members["AgeAtFirstClaim"].fillna("?")
        df_members["Sex"] = df_members["Sex"].fillna("?")

        # Despite the column name, the "60-69" bucket is included (age >= 60).
        df_members["AgeAtFirstClaim>60"] = df_members["AgeAtFirstClaim"].isin(
            ["60-69", "70-79", "80+"]
        )
        df_members = pd.get_dummies(df_members, columns=["AgeAtFirstClaim", "Sex"], prefix_sep="=")
        return df_members

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """Clean the merged table, binarize the target and standardize features.

        Drops the ``Year``/``MemberID`` identifiers, fills missing values with 0,
        removes rows whose age at first claim is unknown, turns the target into
        ``max_CharlsonIndex > 0`` and z-scores every non-boolean column.
        ``df`` is modified in place and also returned.
        """
        df.drop(["Year", "MemberID"], axis=1, inplace=True)
        df.fillna(0, inplace=True)

        # Drop rows with an unknown protected attribute. The explicit bool cast
        # keeps this a mask even when the indicator is not bool dtype (uint8
        # dummies from older pandas, or object after NaN-filling); otherwise
        # ``.loc`` would treat the 0/1 values as row labels.
        unknown_age = "AgeAtFirstClaim=?"
        if unknown_age in df.columns:
            is_unknown = df[unknown_age].astype(bool).to_numpy()
            df.drop(index=df.index[is_unknown], inplace=True)
            df.drop(columns=[unknown_age], inplace=True)

        # Keep the protected attribute boolean so it is never standardized below.
        df[self.protected] = df[self.protected].astype(bool)
        # Binarize the target: any non-zero encoded Charlson index is positive.
        df[self.target] = df[self.target] > 0

        # NOTE: mean/std are computed over all rows, including those that
        # ``train_val_test_split`` later assigns to the validation/test sets.
        continuous_columns = [col for col in df.columns if df[col].dtype != bool]
        if continuous_columns:
            mean = df[continuous_columns].mean()
            std = df[continuous_columns].std()
            # Constant columns have zero std; divide by 1 so they become 0, not NaN.
            std = std.replace(0, 1)
            df[continuous_columns] = (df[continuous_columns] - mean) / std
        return df

    def _to_tensors(self, df: pd.DataFrame):
        """Pop target and protected columns from *df*; return (features, protected, target).

        Target and protected become int64 tensors; the remaining columns become
        a float32 feature tensor.
        """
        target = torch.tensor(df.pop(self.target).to_numpy(np.int64))
        protected = torch.tensor(df.pop(self.protected).to_numpy(np.int64))
        features = torch.tensor(df.to_numpy(np.float32))
        return features, protected, target

    def train_val_test_split(self, train_df: pd.DataFrame, test_df=None) -> None:
        """Convert frames to tensors and store train/val/test splits on ``self``.

        Args:
            train_df: Preprocessed examples; its target/protected columns are popped.
            test_df: Optional preprocessed test frame. When ``None``, a
                ``test_size`` fraction of ``train_df`` is held out as the test set.

        The remaining training data is then split again, holding out a
        ``validation_size`` fraction for validation. Both splits use ``seed``.
        """
        train_features, train_protected, train_target = self._to_tensors(train_df)

        if test_df is not None:
            self.test_features, self.test_protected, self.test_target = self._to_tensors(test_df)
        else:
            (
                train_features,
                self.test_features,
                train_protected,
                self.test_protected,
                train_target,
                self.test_target,
            ) = train_test_split(
                train_features,
                train_protected,
                train_target,
                test_size=self.hparams["test_size"],
                random_state=self.hparams["seed"],
            )

        (
            self.train_features,
            self.val_features,
            self.train_protected,
            self.val_protected,
            self.train_target,
            self.val_target,
        ) = train_test_split(
            train_features,
            train_protected,
            train_target,
            test_size=self.hparams["validation_size"],
            random_state=self.hparams["seed"],
        )

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        dataset = TensorDataset(self.train_features, self.train_protected, self.train_target)
        return DataLoader(
            dataset, self.hparams["batch_size"], num_workers=1, shuffle=True, drop_last=True
        )

    def val_dataloader(self):
        """Validation loader in fixed order."""
        dataset = TensorDataset(self.val_features, self.val_protected, self.val_target)
        return DataLoader(dataset, self.hparams["batch_size"], num_workers=1, shuffle=False)

    def test_dataloader(self):
        """Test loader in fixed order."""
        dataset = TensorDataset(self.test_features, self.test_protected, self.test_target)
        return DataLoader(dataset, self.hparams["batch_size"], num_workers=1, shuffle=False)