from argparse import ArgumentParser
import ast
from functools import partial
import glob
import os

import pandas as pd
from pandarallel import pandarallel
from tqdm.auto import tqdm

from data_model import CLAIM_DICT, get_operator_dict
from hccpy.hcc import HCCEngine

# Configure pandarallel once at import time so that `DataFrame.parallel_apply`
# (used in the __main__ block) is available: one worker per CPU reported by the
# OS, a progress bar, and use_memory_fs=False to opt out of pandarallel's
# memory-file-system transfer mode.
pandarallel.initialize(
    progress_bar=True,
    nb_workers=os.cpu_count(),
    use_memory_fs=False,
)

def recode_hccs(hcc_engine, valid_hccs, row):
    """Return 0/1 indicators of which HCCs the engine assigns to one row's ICD codes.

    Args:
        hcc_engine: object whose ``profile(icds)`` returns a mapping with an
            ``"hcc_lst"`` entry (an ``HCCEngine`` in this script).
        valid_hccs: ordered HCC column names (e.g. ``"HCC19"``) to report on.
        row: mapping/Series with an ``"icds"`` entry that is passed to
            ``hcc_engine.profile``.

    Returns:
        pd.Series of ints indexed by ``valid_hccs``: 1 if that HCC appears in
        the profile's ``hcc_lst``, else 0. Profile entries that do not start
        with "HCC", and HCCs not listed in ``valid_hccs``, are ignored.
    """
    hcc_list_raw = hcc_engine.profile(row["icds"])["hcc_lst"]
    # hcc_lst may hold entries other than HCC codes; keep only the "HCC..." ones.
    present = {code for code in hcc_list_raw if code.startswith("HCC")}
    # Direct membership test per valid HCC. This yields the same 0/1 vector as
    # one-hot encoding the set and collapsing it, without building a
    # Categorical plus a dummy DataFrame for every row.
    return pd.Series(
        [int(hcc in present) for hcc in valid_hccs],
        index=valid_hccs,
        dtype=int,
    )

if __name__ == '__main__':
    # Merge the per-claim-type staging CSVs ("_staging_<name>.csv" in
    # --stage-dir) into one row per BENE_ID, recompute the HCC indicator
    # columns from the merged ICD lists, and write <stage-dir>/<save-name>.
    # An existing output file is never overwritten.

    # Registers Series.progress_apply (used below); without it that attribute
    # does not exist on pandas objects.
    tqdm.pandas()

    psr = ArgumentParser()
    psr.add_argument("--stage-dir", type=str, required=True)
    psr.add_argument(
        "--include-claims",
        type=str,
        nargs="+",
        default=["dme", "hha", "medpar", "op", "ptb"],
        # Reject unknown claim types up front with a readable argparse error
        # instead of a KeyError from the CLAIM_DICT lookup below.
        choices=CLAIM_DICT,
    )
    psr.add_argument("--save-name", type=str, default="data.csv")
    psr.add_argument('--year', type=str, default="2019")
    psr.add_argument('--cms-hcc-version', default="24")
    args = psr.parse_args()

    final_path = os.path.join(args.stage_dir, args.save_name)
    # Fail fast, before the expensive merge, rather than clobbering a prior result.
    if os.path.isfile(final_path):
        raise ValueError(f"File already exists at {final_path}")

    files = [
        os.path.join(args.stage_dir, "_staging_" + CLAIM_DICT[claim_type].__name__ + ".csv")
        for claim_type in args.include_claims
    ]

    print("Merging files:")
    print(*files, sep="\n", end="\n\n")
    dfs = [pd.read_csv(f, low_memory=False, index_col=0) for f in files]

    # NOTE(review): the aggregation spec is derived from the first file only;
    # confirm get_operator_dict covers columns that exist only in later files.
    op_dict = get_operator_dict(dfs[0])
    # Each file's own index would produce duplicate labels after concatenation,
    # and the groupby below discards the index anyway, so renumber here.
    final_df = pd.concat(dfs, ignore_index=True)
    # "icds" is stored in the CSV as the text of a Python literal; parse it back.
    # literal_eval only accepts literals, so a malformed cell raises instead of
    # executing code.
    final_df["icds"] = final_df["icds"].progress_apply(ast.literal_eval)
    final_df = final_df.groupby("BENE_ID").agg(op_dict).reset_index()  # TODO: share with staging code (DRY)

    # Recompute the HCC indicators from the merged ICD lists so hierarchies are
    # reapplied across claim types rather than only within each staged file.
    hcc_cols = [c for c in final_df.columns if c.startswith("HCC")]
    hcc_engine = HCCEngine(version=args.cms_hcc_version, dx2cc_year=args.year)
    valid_hccs = ["HCC" + str(k) for k in hcc_engine.label.keys()]
    recoded = final_df.parallel_apply(partial(recode_hccs, hcc_engine, valid_hccs), axis=1)
    # Select by name before assigning: `df[cols] = other_df` pairs columns by
    # POSITION, which would mislabel values whenever the staged HCC column
    # order differs from the engine's label order. A KeyError here means a
    # staged HCC column is unknown to this engine version.
    final_df[hcc_cols] = recoded[hcc_cols]

    print("Saving merged DF...")
    final_df.to_csv(final_path)
    print("Data successfully saved to", final_path)


   
