import argparse
import os
import logging
import re
import string

import numpy as np
import pandas as pd
from scipy.sparse import diags
from sklearn.preprocessing import normalize
from pecos.utils import smat_util
from pecos.core import clib


# Module-level logger; output format and level are configured by
# logging.basicConfig in main().
logger = logging.getLogger(__name__)


def write_df_to_parquet_files(df_inp, output_dir, num_partitions=128):
    """Split a DataFrame row-wise into parquet shards under ``output_dir``.

    Rows are divided into ``num_partitions`` contiguous, nearly equal-sized
    chunks (``np.array_split``) and written, in row order, as
    ``part-00000.parquet``, ``part-00001.parquet``, ...

    Args:
        df_inp (pd.DataFrame): frame to write. Rows are selected by position,
            so any index (not only the default RangeIndex) is supported.
        output_dir (str): destination directory; created if it does not exist.
        num_partitions (int): number of shards to write. Default 128.

    Raises:
        ValueError: if ``num_partitions`` is not positive, or exceeds the
            number of rows (which would produce empty shards).
    """
    num_rows = len(df_inp)
    # Validate before touching the filesystem so a bad call leaves no empty dir.
    if num_partitions <= 0:
        raise ValueError(f"num_partitions must be positive, got {num_partitions}")
    if num_rows < num_partitions:
        raise ValueError(
            f"cannot split {num_rows} rows into {num_partitions} non-empty partitions"
        )
    os.makedirs(output_dir, exist_ok=True)
    for part_id, positions in enumerate(np.array_split(np.arange(num_rows), num_partitions)):
        # iloc selects by position; .loc would select by index label and only
        # coincides with positions for a default RangeIndex.
        cur_df = df_inp.iloc[positions]
        cur_path = os.path.join(output_dir, f"part-{part_id:05d}.parquet")
        cur_df.to_parquet(cur_path)


def _read_text_lines(path, errors="strict"):
    """Read a UTF-8 text file and return its lines with surrounding whitespace stripped.

    Args:
        path (str): file to read.
        errors (str): decoding-error policy forwarded to ``open``
            (``"strict"`` raises on undecodable bytes, ``"ignore"`` drops them).

    Returns:
        list[str]: one stripped string per line, in file order.
    """
    # `with` closes the handle promptly instead of leaking it until GC.
    with open(path, "r", encoding="utf-8", errors=errors) as fin:
        return [line.strip() for line in fin]


def _sort_row_labels_by_frequency(Y):
    """Return a CSR copy of ``Y`` whose per-row entries are ordered by label weight.

    Column j is scaled by its column sum (for a binary ``Y``, the number of rows
    tagged with label j in this matrix), then the result goes through pecos
    ``smat_util.sorted_csr``.

    Args:
        Y: sparse instance-by-label matrix loaded by ``smat_util.load_matrix``.

    Returns:
        CSR matrix (float32) with the same sparsity pattern, rows sorted.
    """
    Y = Y.tocsr()  # sorted_csr below operates on CSR input
    # ravel (not squeeze) keeps a 1-D vector even when there is a single label.
    label_weight = np.asarray(Y.sum(axis=0), dtype=np.float32).ravel()
    weight_diag = diags(label_weight, 0, format="csr")
    Y_weighted = clib.sparse_matmul(Y.astype(np.float32), weight_diag)
    # NOTE(review): sorted_csr presumably orders each row's nonzeros by
    # descending value (most frequent label first) -- confirm against pecos docs.
    return smat_util.sorted_csr(Y_weighted)


def build_dataset(data_name, inputs_dir, output_dir, is_lf_dataset=True):
    """Convert one raw dataset into partitioned parquet tables.

    Reads from ``<inputs_dir>/<data_name>/``:
        X.trn.txt, X.tst.txt   one input text per line
        Y.trn.npz, Y.tst.npz   sparse instance-by-label matrices
        Y.txt                  one label text per line
    Writes ``<output_dir>/<data_name>/{trn,tst,lbl}/part-*.parquet``.

    Instance tables (trn/tst) have columns:
        qid         row position in X.<split>.txt
        input       stripped input text
        pos_labels  label ids of the row, ordered by per-split label weight
        neg_labels  always an empty list
    The label table (lbl) has columns:
        lid          label id (line position in Y.txt)
        input        label text if ``is_lf_dataset`` else ""
        pos_trn_ids  training-row ids carrying the label (from Y.trn.npz)

    Args:
        data_name (str): dataset sub-directory name.
        inputs_dir (str): parent directory of the raw dataset.
        output_dir (str): parent directory for the parquet output.
        is_lf_dataset (bool): keep label texts in the lbl table.

    Raises:
        ValueError: if a text file's line count disagrees with its matrix.
    """
    inputs_dir = os.path.join(inputs_dir, data_name)
    output_dir = os.path.join(output_dir, data_name)
    os.makedirs(output_dir, exist_ok=True)

    for split in ["trn", "tst"]:
        logger.info(f"Building parquet files for {split} set")
        X_txt = _read_text_lines(f"{inputs_dir}/X.{split}.txt")
        Y_csr = smat_util.load_matrix(f"{inputs_dir}/Y.{split}.npz")
        if len(X_txt) != Y_csr.shape[0]:
            raise ValueError(
                f"X.{split}.txt has {len(X_txt)} lines but Y.{split}.npz has {Y_csr.shape[0]} rows"
            )
        num_inputs = len(X_txt)
        logger.info("| constructing dataframe")

        # NOTE(review): label weights come from this split's own matrix, so tst
        # rows are ordered by tst-set frequency rather than trn -- confirm intended.
        Y_csr = _sort_row_labels_by_frequency(Y_csr)
        df_out = pd.DataFrame.from_dict({
            "qid": list(range(num_inputs)),
            "input": X_txt,
            "pos_labels": [Y_csr.indices[Y_csr.indptr[i] : Y_csr.indptr[i + 1]] for i in range(num_inputs)],
            "neg_labels": [[] for _ in range(num_inputs)],
        })
        logger.info("| writing dataframe in partitions")
        write_df_to_parquet_files(df_out, f"{output_dir}/{split}")

    logger.info(f"Building parquet files for lbl set, is_lf_dataset={is_lf_dataset}")
    # Going through CSR before CSC yields sorted row ids within each column.
    Y_csc = smat_util.load_matrix(f"{inputs_dir}/Y.trn.npz").tocsr().tocsc()
    Y_txt = _read_text_lines(f"{inputs_dir}/Y.txt", errors="ignore")
    if len(Y_txt) != Y_csc.shape[1]:
        raise ValueError(
            f"Y.txt has {len(Y_txt)} lines but Y.trn.npz has {Y_csc.shape[1]} columns"
        )
    num_labels = len(Y_txt)
    logger.info("| constructing dataframe")
    df_out = pd.DataFrame.from_dict({
        "lid": list(range(num_labels)),
        "input": Y_txt if is_lf_dataset else [""] * num_labels,
        "pos_trn_ids": [Y_csc.indices[Y_csc.indptr[j] : Y_csc.indptr[j + 1]] for j in range(num_labels)],
    })
    logger.info("| writing dataframe in partitions")
    write_df_to_parquet_files(df_out, f"{output_dir}/lbl")


def main():
    """Command-line entry point: configure logging, parse arguments, build one dataset.

    The dataset is treated as label-feature ("LF") data, meaning label texts are
    kept in the lbl table, when "lf" occurs in ``--data-name`` (case-insensitive).
    """
    # logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # parse argument
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-name", "-d",
        type=str, required=True, help="The dataset name. E.g., LF-Amazon-131K",
    )
    parser.add_argument(
        "--inputs-dir", "-i",
        default="./datasets", type=str, required=False, help="<PATH_TO_INPUTS_DIR>/<DATA_NAME>/...",
    )
    parser.add_argument(
        "--output-dir", "-o",
        default="./proc_data", type=str, required=False, help="<PATH_TO_OUTPUT_DIR>/<DATA_NAME>/...",
    )
    args = parser.parse_args()
    logger.info(args)

    # NOTE(review): a plain substring test also matches names that merely contain
    # "lf" somewhere; consider startswith("lf-") if such names are possible.
    is_lf_dataset = "lf" in args.data_name.lower()
    build_dataset(args.data_name, args.inputs_dir, args.output_dir, is_lf_dataset)


if __name__ == "__main__":
    main()

