from argparse import ArgumentParser
from collections import defaultdict
import functools
import operator as op
import os
import pickle
import sys
import time

from hccpy.hcc import HCCEngine
import numpy as np
import pandas as pd
from pandarallel import pandarallel
from tqdm.auto import tqdm
from ruamel.yaml import YAML

from typing import Any, List, Optional

# Module-level setup: these calls run as soon as the file is imported or executed.
tqdm.pandas()
# Start pandarallel with one worker per CPU reported by os.cpu_count(), with
# progress bars enabled and the memory file system disabled.
pandarallel.initialize(nb_workers=os.cpu_count(), progress_bar=True, use_memory_fs=False)

# Shared YAML loader (typ='safe', pure=True); the script entry point uses it to read the path spec.
yaml = YAML(typ='safe', pure=True)

def get_operator_dict(df, boolify=False):
    """Build the per-column aggregation spec used when merging per-beneficiary rows.

    Every column of ``df`` except ``"BENE_ID"`` gets an entry: the ``"icds"``
    column maps to a callable that unions all sets in a group, and every other
    column maps to the string ``"sum"``.

    Args:
        df: Object exposing a ``columns`` iterable (e.g. a pandas DataFrame).
        boolify: Accepted for interface compatibility; currently unused.

    Returns:
        dict mapping column name to an aggregation (``"sum"`` or a callable).
    """
    def union_all(sets):
        # Fold the group's sets into a single set containing every member.
        return functools.reduce(set.union, sets)

    aggregations = {}
    for column in df.columns:
        if column == "BENE_ID":
            continue
        aggregations[column] = union_all if column == "icds" else "sum"
    return aggregations

class SAS7BDATClaimManager(object):
    """Read one claims file and aggregate cost, diagnoses and HCC flags per beneficiary.

    The file is either streamed in chunks (``format="sas7bdat"``) or loaded in
    full (``format="csv"``). Rows flagged as denied are excluded before costs
    are summed and diagnosis codes are collected; the collected codes are
    passed to an ``HCCEngine`` and the resulting HCC labels are k-hot encoded.
    """

    def __init__(
        self,
        path: str,
        base_columns: List[str],
        denial_column: str,
        cost_column: str,
        dx_columns: List[str],
        denial_vals: Optional[List[Any]] = None,
        non_denial_vals: Optional[List[Any]] = None,
        deny_if_nonnull: Optional[bool] = False,
        chunksize: Optional[int] = 10000,
        primary_key: Optional[str] = 'BENE_ID',
        format: Optional[str] = "sas7bdat",
        cms_hcc_version: Optional[str] = "24",
        dx2cc_year: Optional[str] = "2019",
    ):
        """Open the data source and record the column layout.

        Args:
            path: Location of the claims file.
            base_columns: Identifier/date columns to keep; must contain ``primary_key``.
            denial_column: Column inspected to decide whether a row is denied.
            cost_column: Cost column name, a list of cost column names, or None.
            dx_columns: Columns holding diagnosis codes.
            denial_vals: Values of ``denial_column`` that mark a row as denied.
                ``None`` is treated as an empty list.
            non_denial_vals: Values of ``denial_column`` that mark a row as
                accepted (every other value is denied). ``None`` is treated as
                an empty list.
            deny_if_nonnull: If true, any non-null ``denial_column`` value means denied.
            chunksize: Rows per chunk when streaming a sas7bdat file.
            primary_key: Column used to group rows.
            format: ``"sas7bdat"`` or ``"csv"``.
            cms_hcc_version: Passed to ``HCCEngine`` as ``version``.
            dx2cc_year: Passed to ``HCCEngine`` as ``dx2cc_year``.

        Raises:
            ValueError: If ``format`` is neither ``"sas7bdat"`` nor ``"csv"``.
        """
        self.path = path
        self.chunksize = chunksize
        self.format = format
        if self.format == "sas7bdat":
            self.iterator = pd.read_sas(path, chunksize=self.chunksize, iterator=True)
            self.dataset = None
        elif self.format == "csv":
            self.iterator = None
            self.dataset = pd.read_csv(path, low_memory=False, index_col=0)
        else:
            raise ValueError(f"Format {self.format} is invalid.")

        self.base_columns = base_columns
        self.denial_column = denial_column
        self.cost_column = cost_column
        self.dx_columns = dx_columns
        self.primary_key = primary_key

        self.relevant_cols = self.base_columns + [self.denial_column] + self.dx_columns
        if isinstance(cost_column, list):
            self.relevant_cols.extend(cost_column)
        elif cost_column is not None:
            self.relevant_cols.append(cost_column)

        # Copy the value lists so the instance never shares a list with the caller
        # (or with a default argument).
        denial_vals = [] if denial_vals is None else list(denial_vals)
        non_denial_vals = [] if non_denial_vals is None else list(non_denial_vals)
        if self.format == "csv":
            # CSV columns hold text, so byte-string codes are decoded to match.
            self.denial_vals = [self._to_text(v) for v in denial_vals]
            self.non_denial_vals = [self._to_text(v) for v in non_denial_vals]
        else:
            self.denial_vals = denial_vals
            self.non_denial_vals = non_denial_vals

        self.deny_if_nonnull = deny_if_nonnull

        self.lines_read = 0
        self.cms_hcc_version = cms_hcc_version
        self.dx2cc_year = dx2cc_year
        self.hcc_engine = HCCEngine(version=self.cms_hcc_version, dx2cc_year=self.dx2cc_year)
        self.valid_dx_set = self.hcc_engine.dx2cc.keys()
        self.valid_hccs = ["HCC" + str(k) for k in self.hcc_engine.label.keys()]

    @staticmethod
    def _to_text(value):
        """Decode ``bytes`` as UTF-8; return any other value unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def _rewind(self):
        """Seek the streaming reader back to the start and reset ``lines_read``.

        NOTE(review): relies on the reader's private ``_path_or_buf`` attribute;
        confirm it exists in the installed pandas version.

        Raises:
            ValueError: If the dataset is not backed by an iterator.
        """
        if self.iterator is None:
            raise ValueError("_rewind() is not defined for non-iterator datasets")
        self.iterator._path_or_buf.seek(0)
        self.lines_read = 0

    def _pre_validate(self):
        """Print the configuration and report whether it is usable.

        Returns:
            False if a required field is None, if the primary key is missing
            from the base columns, or if both denial and non-denial values are
            given (they are mutually exclusive); True otherwise.
        """
        denial_vals = self.denial_vals
        non_denial_vals = self.non_denial_vals
        print("Base columns:", self.base_columns)
        print("Denial column:", self.denial_column)
        print("Cost column:", self.cost_column)
        print("Dx columns:", self.dx_columns)
        # Guard the len() calls so a missing list is reported by the check
        # below instead of raising TypeError here.
        print("# denial values:", None if denial_vals is None else len(denial_vals))
        print("# non-denial (acceptance) values:", None if non_denial_vals is None else len(non_denial_vals))
        print("# nonnull = deny:", self.deny_if_nonnull)
        required = (self.base_columns, self.denial_column, self.dx_columns, denial_vals, self.deny_if_nonnull)
        if any(field is None for field in required):
            return False
        if self.primary_key not in self.base_columns:
            print("Primary key", self.primary_key, "not in base columns:", self.base_columns)
            return False
        if denial_vals and non_denial_vals:
            return False
        return True

    def _get_denial_mask(self, chunk):
        """Return a boolean mask that is True for rows of ``chunk`` considered denied.

        Precedence: ``deny_if_nonnull``, then ``denial_vals``, then
        ``non_denial_vals``; with none of them configured no row is denied.
        """
        column = chunk[self.denial_column]
        if self.deny_if_nonnull:
            denial_mask = column.notna()
        elif self.denial_vals:
            denial_mask = column.isin(self.denial_vals)
        elif self.non_denial_vals:
            denial_mask = ~column.isin(self.non_denial_vals)
        else:
            denial_mask = np.zeros_like(column).astype(bool)
        return denial_mask.astype(bool)

    def process_chunk(self, filter_suffix: Optional[str] = "2"):
        """Aggregate the next chunk (or the whole CSV dataset) per primary key.

        Args:
            filter_suffix: If not None, only rows whose primary key ends with
                this string are kept.

        Returns:
            DataFrame with the primary key, ``cost``, ``icds`` and one boolean
            column per entry of ``valid_hccs``; an empty ``pd.DataFrame()``
            when there are no rows to process.

        Raises:
            StopIteration: When the streaming reader is exhausted.
        """

        def collate_bene_info(group):
            # One group = all rows of a single primary-key value.
            denial_mask = self._get_denial_mask(group)
            if self.cost_column is not None:
                cost = group.loc[~denial_mask, self.cost_column].values.sum()
            else:
                cost = 0.
            # Flatten all diagnosis columns of the non-denied rows; missing
            # values become the string 'nan' and are dropped.
            str_codes = group.loc[~denial_mask, self.dx_columns].values.astype(str).ravel()
            dxs = set(str_codes[str_codes != 'nan'])
            hcc_list_raw = self.hcc_engine.profile(dxs)["hcc_lst"]
            hccs = set(filter(lambda x: x.startswith("HCC"), hcc_list_raw))
            categories = pd.Categorical(list(hccs), categories=self.valid_hccs)
            k_hot_dummies = pd.get_dummies(categories).sum(axis=0).astype(bool) # chunk level = bool
            return cost, dxs, *k_hot_dummies

        if self.iterator is not None:
            df = next(self.iterator)
        else:
            df = self.dataset
        if len(df) == 0:
            return pd.DataFrame()
        df = df[self.relevant_cols]

        if filter_suffix is not None:
            keys = df[self.primary_key]
            # SAS readers yield byte strings that need decoding; CSV keys are
            # already text (or numeric), so decoding them would fail.
            if self.format == "sas7bdat":
                keys = keys.str.decode("utf-8")
            else:
                keys = keys.astype(str)
            df = df[keys.str.endswith(filter_suffix, na=False)]
            if len(df) == 0:
                # Nothing survived the filter; skip the parallel group-by.
                return pd.DataFrame()
        results = df.groupby(self.primary_key).parallel_apply(collate_bene_info)
        agg_df = pd.DataFrame(results.tolist(), index=results.index, columns=['cost', 'icds'] + self.valid_hccs).reset_index()

        # Counts rows after column selection and suffix filtering.
        self.lines_read += len(df)
        return agg_df

    def process_file(self, filter_suffix: Optional[str] = "2", max_lines: Optional[int] = 999999999): # HACK
        """Process the whole source and return one row per primary key.

        Args:
            filter_suffix: Forwarded to ``process_chunk``.
            max_lines: Stop streaming once ``lines_read`` reaches this value;
                None disables the limit. Ignored for CSV datasets.

        Returns:
            Aggregated DataFrame; an empty ``pd.DataFrame()`` if no chunk
            produced any rows.

        Raises:
            ValueError: If ``_pre_validate`` rejects the configuration.
        """
        if not self._pre_validate():
            raise ValueError("Some fields are null that shouldn't be. Check your class initialization.")

        if self.iterator is None:
            return self.process_chunk(filter_suffix=filter_suffix)

        dfs = []
        pbar = tqdm(desc=f"Processing {self.__class__.__name__}", unit="ln", total=self.iterator.row_count)
        try:
            while True:
                hcc_df = self.process_chunk(filter_suffix=filter_suffix) # can we dispatch these async
                if len(hcc_df) != 0:
                    dfs.append(hcc_df)
                pbar.update(self.chunksize)
                if max_lines is not None and self.lines_read >= max_lines:
                    print(f"Read >={max_lines} records -- exiting.")
                    break
        except StopIteration:
            print("End of file reached.")
        finally:
            pbar.close()

        if not dfs:
            # No chunk produced rows, so there is nothing to aggregate.
            return pd.DataFrame()

        combined = pd.concat(dfs)
        # Derive the aggregation spec from the combined frame, not from
        # whichever chunk happened to be read last.
        op_dict = get_operator_dict(combined)
        op_dict.pop(self.primary_key, None)
        final_df = combined.groupby(self.primary_key).agg(op_dict).reset_index()
        return final_df

 
class DMEClaim(SAS7BDATClaimManager):
    """Claim reader configured for the ``"dme"`` claims file.

    A little hacky, and can probably be moved to YAML -- the column layout is
    copied directly from a prototyping notebook.
    """

    def __init__(self, config, year=2019, chunksize=10000, fmt="sas7bdat"):
        """Configure the reader from ``config[year][fmt]["dme"]``.

        Args:
            config: Nested mapping ``year -> format -> claim type -> path``.
            year: Key into ``config``.
            chunksize: Rows per chunk when streaming.
            fmt: ``"sas7bdat"`` or ``"csv"``.
        """
        # Column names are always probed from the sas7bdat file, whatever `fmt`
        # is; the probe reader is closed explicitly so its handle is not leaked.
        probe = pd.read_sas(config[year]["sas7bdat"]["dme"], chunksize=1)
        try:
            all_cols = next(probe).columns
        finally:
            probe.close()
        super().__init__(
            path=config[year][fmt]["dme"],
            base_columns=["BENE_ID", "CLM_ID", 'CLM_FROM_DT', 'CLM_THRU_DT'],
            denial_column="CARR_CLM_PMT_DNL_CD",
            cost_column="CLM_PMT_AMT",
            dx_columns=['PRNCPAL_DGNS_CD'] + [c for c in all_cols if c.startswith("ICD") and "VRSN" not in c],
            denial_vals=[],  # = rejected
            non_denial_vals=[b"1", b"2", b"3", b"4", b"5", b"6", b"7", b"8", b"9", b"A", b"B"],  # = accepted
            deny_if_nonnull=False,  # flag for whether non-null = deny
            chunksize=chunksize,
            primary_key="BENE_ID",
            format=fmt,
        )

class HHAClaim(SAS7BDATClaimManager):
    """Claim reader configured for the ``"hha"`` claims file.

    Any non-null ``CLM_MDCR_NON_PMT_RSN_CD`` value marks a row as denied.
    """

    def __init__(self, config, year=2019, chunksize=10000, fmt="sas7bdat"):
        """Configure the reader from ``config[year][fmt]["hha"]``.

        Args:
            config: Nested mapping ``year -> format -> claim type -> path``.
            year: Key into ``config``.
            chunksize: Rows per chunk when streaming.
            fmt: ``"sas7bdat"`` or ``"csv"``.
        """
        # Column names are always probed from the sas7bdat file, whatever `fmt`
        # is; the probe reader is closed explicitly so its handle is not leaked.
        probe = pd.read_sas(config[year]["sas7bdat"]["hha"], chunksize=1)
        try:
            all_cols = next(probe).columns
        finally:
            probe.close()
        super().__init__(
            path=config[year][fmt]["hha"],
            base_columns=["BENE_ID", "CLM_ID", 'CLM_FROM_DT', 'CLM_THRU_DT'],
            denial_column="CLM_MDCR_NON_PMT_RSN_CD",
            cost_column="CLM_PMT_AMT",
            dx_columns=['PRNCPAL_DGNS_CD'] + [c for c in all_cols if c.startswith("ICD") and "VRSN" not in c],
            denial_vals=[],
            non_denial_vals=[],  # = accepted
            deny_if_nonnull=True,
            chunksize=chunksize,
            primary_key="BENE_ID",
            format=fmt,
        )

class MedPARClaim(SAS7BDATClaimManager):
    """Claim reader configured for the ``"medpar"`` file.

    Note that MedPAR = some inpatient (institutional) + SNF.
    """

    def __init__(self, config, year=2019, chunksize=10000, fmt="sas7bdat"):
        """Configure the reader from ``config[year][fmt]["medpar"]``.

        Args:
            config: Nested mapping ``year -> format -> claim type -> path``.
            year: Key into ``config``.
            chunksize: Rows per chunk when streaming.
            fmt: ``"sas7bdat"`` or ``"csv"``.
        """
        # Column names are always probed from the sas7bdat file, whatever `fmt`
        # is; the probe reader is closed explicitly so its handle is not leaked.
        probe = pd.read_sas(config[year]["sas7bdat"]["medpar"], chunksize=1)
        try:
            all_cols = next(probe).columns
        finally:
            probe.close()
        super().__init__(
            path=config[year][fmt]["medpar"],
            base_columns=["BENE_ID", "MEDPAR_ID", 'ADMSN_DT', 'DSCHRG_DT'],
            # Placeholder -- MedPAR has no denied claims
            # (https://resdac.org/sites/datadocumentation.resdac.org/files/Using%20Medicare%20Hospitalization%20Information%20and%20the%20MedPAR%20%28Slides%29.pdf).
            # Update 3/28 -- but we still want to filter out SNF?
            denial_column="SS_LS_SNF_IND_CD",
            cost_column=["MDCR_PMT_AMT", "PASS_THRU_AMT"],  # both columns are summed
            dx_columns=[c for c in all_cols if "DGNS_" in c and "POA" not in c and "VRSN" not in c and "CNT" not in c],
            denial_vals=[b"N"],
            non_denial_vals=[],  # = accepted
            deny_if_nonnull=False,
            chunksize=chunksize,
            primary_key="BENE_ID",
            format=fmt,
        )

class OutpatientClaim(SAS7BDATClaimManager):
    """Claim reader configured for the ``"op"`` claims file.

    Any non-null ``CLM_MDCR_NON_PMT_RSN_CD`` value marks a row as denied.
    Columns containing ``"PRCDR"`` are excluded from the diagnosis columns.
    """

    def __init__(self, config, year=2019, chunksize=10000, fmt="sas7bdat"):
        """Configure the reader from ``config[year][fmt]["op"]``.

        Args:
            config: Nested mapping ``year -> format -> claim type -> path``.
            year: Key into ``config``.
            chunksize: Rows per chunk when streaming.
            fmt: ``"sas7bdat"`` or ``"csv"``.
        """
        # Column names are always probed from the sas7bdat file, whatever `fmt`
        # is; the probe reader is closed explicitly so its handle is not leaked.
        probe = pd.read_sas(config[year]["sas7bdat"]["op"], chunksize=1)
        try:
            all_cols = next(probe).columns
        finally:
            probe.close()
        super().__init__(
            path=config[year][fmt]["op"],
            base_columns=["BENE_ID", "CLM_ID", "CLM_FROM_DT", "CLM_THRU_DT"],
            denial_column="CLM_MDCR_NON_PMT_RSN_CD",
            cost_column="CLM_PMT_AMT",
            dx_columns=['PRNCPAL_DGNS_CD'] + [c for c in all_cols if c.startswith("ICD") and "VRSN" not in c and "PRCDR" not in c],
            denial_vals=[],
            non_denial_vals=[],  # = accepted
            deny_if_nonnull=True,
            chunksize=chunksize,
            primary_key="BENE_ID",
            format=fmt,
        )

class PartBClaim(SAS7BDATClaimManager):  # AKA carrier
    """Claim reader configured for the ``"ptb"`` (Part B, a.k.a. carrier) claims file."""

    def __init__(self, config, year=2019, chunksize=10000, fmt="sas7bdat"):
        """Configure the reader from ``config[year][fmt]["ptb"]``.

        Args:
            config: Nested mapping ``year -> format -> claim type -> path``.
            year: Key into ``config``.
            chunksize: Rows per chunk when streaming.
            fmt: ``"sas7bdat"`` or ``"csv"``.
        """
        # HACK: column names are always probed from the sas7bdat file, whatever
        # `fmt` is; the probe reader is closed explicitly so its handle is not leaked.
        probe = pd.read_sas(config[year]["sas7bdat"]["ptb"], chunksize=1)
        try:
            all_cols = next(probe).columns
        finally:
            probe.close()
        super().__init__(
            path=config[year][fmt]["ptb"],
            base_columns=["BENE_ID", "CLM_ID", "CLM_FROM_DT", "CLM_THRU_DT"],
            denial_column="CARR_CLM_PMT_DNL_CD",
            cost_column="CLM_PMT_AMT",
            dx_columns=['PRNCPAL_DGNS_CD'] + [c for c in all_cols if c.startswith("ICD") and "VRSN" not in c],
            denial_vals=[],
            non_denial_vals=[b"1", b"2", b"3", b"4", b"5", b"6", b"7", b"8", b"9", b"A", b"B"],  # = accepted
            deny_if_nonnull=False,
            chunksize=chunksize,
            primary_key="BENE_ID",
            format=fmt,
        )



def merge_dicts(d1, d2, operator, default=0.):
    """Combine two dicts key by key using a binary ``operator``.

    Args:
        d1: First mapping.
        d2: Second mapping.
        operator: Callable applied as ``operator(d1_value, d2_value)``.
        default: Value substituted when a key is missing from one mapping.

    Returns:
        dict holding one combined value for every key present in either input.
    """
    merged = {}
    for key in set(d1).union(d2):
        left = d1.get(key, default)
        right = d2.get(key, default)
        merged[key] = operator(left, right)
    return merged

def get_args():
    """Parse the command-line options of the claims-processing script.

    Returns:
        argparse.Namespace with ``name``, ``overwrite``, ``chunksize``,
        ``max_lines``, ``filter_suffix``, ``year``, ``include_claims`` and
        ``format``.
    """
    def parse_year(value):
        # argparse passes the raw string; purely numeric years must become ints
        # to match the integer entries in `choices` (otherwise "--year 2019"
        # is rejected as an invalid choice).
        return int(value) if value.isdigit() else value

    psr = ArgumentParser()
    psr.add_argument("--name", required=True)
    psr.add_argument("--overwrite", action='store_true')
    psr.add_argument("--chunksize", type=int, default=10000)
    psr.add_argument("--max-lines", type=int, default=10000)
    psr.add_argument("--filter-suffix", type=str, default=None)
    psr.add_argument("--year", type=parse_year, default=2019, choices=[2018, 2019, "2018-MA", "2019-MA"])
    psr.add_argument("--include-claims", type=str, nargs="+", default=["dme", "hha", "op", "medpar", "ptb"])
    psr.add_argument("--format", type=str, choices=["sas7bdat", "csv"], default="sas7bdat")
    return psr.parse_args()


# Maps the claim-type names accepted by --include-claims to their reader classes.
CLAIM_DICT = {
    "dme": DMEClaim,
    "hha": HHAClaim,
    "medpar": MedPARClaim,
    "op": OutpatientClaim,
    "ptb": PartBClaim,
}

def main():
    """Process every requested claim type and write the merged per-beneficiary table.

    For each claim type a staging CSV is written to ``intermediate/<name>/``;
    the staged frames are then concatenated, re-aggregated by ``BENE_ID`` and
    saved as ``data.csv`` in the same directory.

    Raises:
        ValueError: If the output directory exists and ``--overwrite`` is not set.
        SystemExit: If no claim type produced any rows.
    """
    args = get_args()
    save_dir = os.path.join("intermediate", args.name)
    if os.path.isdir(save_dir) and not args.overwrite:
        raise ValueError(f"Data directory with name {args.name} already exists and --overwrite flag is not set. Exiting.")
    print("Creating claim readers...")

    with open("./config/pathspec.yml", "r") as f:
        cfg = yaml.load(f)

    unknown_types = [claim_type for claim_type in args.include_claims if claim_type not in CLAIM_DICT]
    if unknown_types:
        # These were previously dropped without any notice.
        print("Ignoring unknown claim types:", unknown_types)
    claims = [CLAIM_DICT[claim_type](cfg, year=args.year, chunksize=args.chunksize, fmt=args.format) for claim_type in args.include_claims if claim_type in CLAIM_DICT]
    print("Claim types to include:", args.include_claims)

    os.makedirs(save_dir, exist_ok=True)
    dfs = []
    for claim in claims:
        claim_name = claim.__class__.__name__
        print("Processing", claim_name)
        hcc_df = claim.process_file(max_lines=args.max_lines, filter_suffix=args.filter_suffix)
        if len(hcc_df) != 0:
            avg_cost = hcc_df["cost"].mean()
            print("Average cost/beneficiary,", claim_name + ":", avg_cost)
            nnz_cost = (hcc_df["cost"] > 0).mean()
            print("% of beneficiaries with non-zero cost:", nnz_cost)

            claim_hcc_cols = [c for c in hcc_df.columns if c.startswith("HCC")]
            n_hcc = hcc_df[claim_hcc_cols].sum(axis=1)
            print("# HCCs:", n_hcc.value_counts().sort_index(ascending=True))
            if avg_cost == 0:
                print("Claim type", claim_name, "returned zero cost. Consider removing from future processing runs.")
        else:
            print("Note: DataFrame empty. Double check the original data file.")
        print("Caching intermediate data:")

        sub_df_path = os.path.join(save_dir, f"_staging_{claim_name}.csv")
        hcc_df.to_csv(sub_df_path)
        dfs.append(hcc_df)
        print()

    print("Concatenating data...")
    # Empty frames carry no columns, so they must not drive the aggregation spec.
    non_empty = [df for df in dfs if len(df) != 0]
    if not non_empty:
        sys.exit("No claim data was produced; nothing to concatenate.")
    combined = pd.concat(non_empty)
    op_dict = get_operator_dict(combined)
    final_df = combined.groupby("BENE_ID").agg(op_dict).reset_index()
    # Summed HCC flags are collapsed back to 0/1 indicators.
    hcc_cols = [c for c in final_df.columns if c.startswith("HCC")]
    final_df[hcc_cols] = final_df[hcc_cols].astype(bool).astype(int)
    df_path = os.path.join(save_dir, "data.csv") # they might overwrite each other here. Let's do a thing with only the staging DFs...

    print("Saving data...")
    final_df.to_csv(df_path)
    print("Data successfully saved to", df_path)


if __name__ == '__main__':
    main()
    

    

