from argparse import ArgumentParser
from pathlib import Path

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, FloatType, TimestampType

from common import cat_freq, collect_lists, train_test_split


# Feature columns treated as categorical: they are left uncast when the CSVs are
# read and are encoded via ``common.cat_freq``. Every other feature column
# (except ``customer_ID`` and ``S_2``) is cast to float.
CAT_FEATURES = [
    "B_30", "B_38",
    "D_114", "D_116", "D_117", "D_120", "D_126",
    "D_63", "D_64", "D_66", "D_68",
]

# Columns the per-row records are grouped by in ``collect_lists``.
INDEX_COLUMNS = ["customer_ID", "target"]

# Columns used to order rows within each group in ``collect_lists``.
ORDERING_COLUMNS = ["S_2"]

# Fraction of customers routed to the held-out "test" output.
TEST_FRACTION = 0.2


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = ArgumentParser()
    parser.add_argument(
        "--data-path",
        help="Path to directory containing CSV files",
        required=True,
        type=Path,
    )
    parser.add_argument(
        "--save-path",
        help="Where to save preprocessed parquets",
        required=True,
        type=Path,
    )
    parser.add_argument(
        "--kaggle-split",
        help="Whether to preprocess Kaggle train set, test set or their union",
        choices=["train", "test", "union"],
        required=True,
    )
    parser.add_argument(
        "--cat-codes-path",
        help="Path where to save codes for categorical features",
        type=Path,
    )
    parser.add_argument(
        "--split-seed",
        help="Random seed used to split the data on train and test",
        default=0,
        type=int,
    )
    parser.add_argument(
        "--train-partitions",
        help="Number of parquet partitions for train dataset",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--test-partitions",
        help="Number of parquet partitions for test dataset",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--overwrite",
        help='Toggle "overwrite" mode on all spark writes',
        action="store_true",
    )
    parser.add_argument(
        "--spark-master",
        help="Spark master URL (defaults to the previously hard-coded local[32])",
        default="local[32]",
    )
    return parser.parse_args()


def _read_features(spark, csv_path):
    """Read one Kaggle feature CSV and cast its columns to working types.

    ``customer_ID`` is kept as read, ``S_2`` is cast to timestamp, the
    ``CAT_FEATURES`` columns are left uncast, and every remaining column is
    cast to float.
    """
    df = spark.read.csv(csv_path.as_posix(), header=True)
    cat_features = set(CAT_FEATURES)
    # Build the numeric list from the CSV's own column order rather than from a
    # set: set iteration order depends on hashing, which made the output schema
    # order vary between runs.
    num_features = [
        c
        for c in df.columns
        if c not in cat_features and c not in ("S_2", "customer_ID")
    ]
    return df.select(
        F.col("customer_ID"),
        F.col("S_2").cast(TimestampType()),
        *[F.col(c) for c in CAT_FEATURES],
        *[F.col(c).cast(FloatType()) for c in num_features],
    )


def _load_data(spark, data_path, kaggle_split):
    """Load the requested Kaggle split(s) as a single DataFrame.

    Returns ``(df, has_labels)``. ``has_labels`` is True when the Kaggle train
    set (which is joined with ``train_labels.csv``) is part of ``df``.
    """
    frames = []
    has_labels = kaggle_split in ("train", "union")

    if has_labels:
        df_label = spark.read.csv(
            (data_path / "train_labels.csv").as_posix(), header=True
        ).withColumn("target", F.col("target").cast(LongType()))
        frames.append(
            _read_features(spark, data_path / "train_data.csv").join(
                df_label, on="customer_ID"
            )
        )

    if kaggle_split in ("test", "union"):
        # The test set has no labels. Always add a null ``target`` so that
        # grouping by INDEX_COLUMNS (which includes "target") works in
        # test-only mode too, and so that the schema matches train for the union.
        frames.append(
            _read_features(spark, data_path / "test_data.csv").withColumn(
                "target", F.lit(None).cast(LongType())
            )
        )

    if not frames:
        raise ValueError("Something went wrong, train and test are None")

    df = frames[0]
    for other in frames[1:]:
        # Match columns by name, not position, so column-order differences
        # between the two CSVs cannot silently misalign features.
        df = df.unionByName(other)
    return df, has_labels


def main():
    """Preprocess Kaggle CSVs into per-customer parquet train/test splits.

    Steps: read the requested Kaggle split(s), encode the categorical
    features (optionally saving their codes), collect each customer's rows
    into lists ordered by ``S_2``, split customers into train/test (stratified
    on ``target`` when labels are available), and write both parts as parquet
    under ``--save-path``.
    """
    args = _parse_args()
    mode = "overwrite" if args.overwrite else "error"

    spark = SparkSession.builder.master(args.spark_master).getOrCreate()  # pyright: ignore
    try:
        df, has_labels = _load_data(spark, args.data_path, args.kaggle_split)

        vcs = cat_freq(df, CAT_FEATURES)
        for vc in vcs:
            df = vc.encode(df)
            if args.cat_codes_path is not None:
                vc.write(args.cat_codes_path / vc.feature_name, mode=mode)

        df = collect_lists(
            df,
            group_by=INDEX_COLUMNS,
            order_by=ORDERING_COLUMNS,
        )

        stratify_col, stratify_col_vals = None, None
        if has_labels:  # target has non-null values
            stratify_col = "target"
            stratify_col_vals = [0, 1]

        # stratified splitting on train and test
        train_df, test_df = train_test_split(
            df=df,
            test_frac=TEST_FRACTION,
            index_col="customer_ID",
            stratify_col=stratify_col,
            stratify_col_vals=stratify_col_vals,
            random_seed=args.split_seed,
        )

        train_df.repartition(args.train_partitions).write.parquet(
            (args.save_path / "train").as_posix(), mode=mode
        )
        test_df.repartition(args.test_partitions).write.parquet(
            (args.save_path / "test").as_posix(), mode=mode
        )
    finally:
        # Release the session even if preprocessing fails part-way.
        spark.stop()


if __name__ == "__main__":
    # Run the preprocessing CLI only when executed as a script, not on import.
    main()
