from argparse import ArgumentParser
import os
import pprint
import re

import pandas as pd
from ruamel.yaml import YAML

# YAML file keyed by dataset name; the "raw" entry of the selected dataset is
# the CSV that this script reads. Relative to the current working directory.
DATASET_PATHSPEC = "./config/data_pathspec.yml"
# CSV listing US states; its "State" column is used to filter rows when the
# config enables "treatment_filter". Relative to the current working directory.
# Sourced from https://github.com/jasonong/List-of-US-States/blob/master/states.csv
ALL_US_STATES = "../List-of-US-States/states.csv" # sourced from https://github.com/jasonong/List-of-US-States/blob/master/states.csv
# Single safe-mode YAML handler shared by config loading and column-map dumping.
yaml = YAML(typ='safe')
# Emit block-style YAML rather than inline flow style when dumping.
yaml.default_flow_style = False

def derive_output_path(data_path, replacement):
    """Build a sibling output path by swapping the ``.csv`` suffix of the file name.

    Only the last ``.csv`` in the file name is replaced, so directory
    components containing ``.csv`` are left alone and a trailing compression
    suffix (``raw.csv.gz``) is kept. If the file name has no ``.csv`` at all,
    its extension is replaced instead, so the result never aliases the input.

    Args:
        data_path: Path of the raw data file.
        replacement: Text that replaces ``.csv`` (e.g. ``"_mapped.csv"``).

    Returns:
        The derived path as a string.
    """
    filename = os.path.basename(data_path)
    directory = data_path[:len(data_path) - len(filename)]
    stem, sep, tail = filename.rpartition(".csv")
    if sep:
        return directory + stem + replacement + tail
    # No ".csv" to swap: fall back to replacing the extension so that the
    # output can never be the raw input file itself.
    return directory + os.path.splitext(filename)[0] + replacement


def select_covariates(columns, patterns, explicit_cols, exclude=()):
    """Collect covariate column names, in order and without duplicates.

    Args:
        columns: Iterable of available column names.
        patterns: Regular expressions; a column is selected when
            ``re.search`` finds a match in its name.
        explicit_cols: Column names appended after the regex matches.
        exclude: Column names to leave out of the result.

    Returns:
        List of selected names: regex matches in pattern order, then the
        explicit columns, keeping only the first occurrence of each name.
    """
    selected = []
    for pattern in patterns:
        regexp = re.compile(pattern)
        selected += [c for c in columns if regexp.search(c)]
    selected += list(explicit_cols)

    excluded = set(exclude)
    # dict.fromkeys de-duplicates while preserving first-seen order, so every
    # covariate receives exactly one "x{i}" name downstream.
    return [c for c in dict.fromkeys(selected) if c not in excluded]


def build_treatment_mapping(treatments):
    """Map each distinct treatment value to its rank in sorted order.

    Args:
        treatments: pandas Series holding the raw treatment values.

    Returns:
        Dict from treatment value to a 0-based integer index.
    """
    # Series.tolist() yields native Python scalars instead of numpy ones,
    # which keeps the mapping keys representable by the safe YAML dumper.
    levels = sorted(treatments.drop_duplicates().tolist())
    return {level: index for index, level in enumerate(levels)}


def build_column_renames(x_cols, y_col, t_col):
    """Return the rename map: covariates to ``x{i}``, outcome to ``d_obs``, treatment to ``t_raw``."""
    renames = {col: f"x{i}" for i, col in enumerate(x_cols)}
    renames[y_col] = "d_obs"
    renames[t_col] = "t_raw"
    return renames


def remap_columns(df, cfg):
    """Select, one-hot encode, and rename columns according to ``cfg``.

    Args:
        df: Raw DataFrame. It is not modified.
        cfg: Mapping with required keys ``outcome_col`` and ``treatment_col``
            and optional list keys ``covariate_regexp``, ``covariate_cols``,
            ``other_cols`` and ``dummy_covariates`` (missing or empty values
            are treated as empty lists).

    Returns:
        Tuple ``(new_df, colmap, treatment_mapping)``. ``new_df`` holds the
        ``other_cols`` (names unchanged), the covariates as ``x0..xN``, the
        outcome as ``d_obs``, the raw treatment as ``t_raw`` and the integer
        treatment index as ``t``. ``colmap`` maps each selected column name to
        its final name; ``treatment_mapping`` maps raw treatment values to
        the integers stored in ``t``.
    """
    dummy_cols = list(dict.fromkeys(cfg.get("dummy_covariates") or []))
    other_cols = list(cfg.get("other_cols") or [])
    y_col = cfg["outcome_col"]
    t_col = cfg["treatment_col"]

    # Columns that get one-hot encoded are dropped from the frame below, so
    # they must not also be selected under their raw name.
    x_cols = select_covariates(
        df.columns,
        cfg.get("covariate_regexp") or [],
        cfg.get("covariate_cols") or [],
        exclude=dummy_cols,
    )

    for dummy_col in dummy_cols:
        dummy_df = pd.get_dummies(df[dummy_col], prefix=dummy_col)
        x_cols += dummy_df.columns.tolist()
        df = pd.concat([df.drop(dummy_col, axis=1), dummy_df], axis=1)

    print("Final covariate columns:", x_cols)
    print("Treatment column:", t_col)
    print("Outcome column:", y_col)

    t_ind_mapping = build_treatment_mapping(df[t_col])

    # assign() adds "t" on a fresh frame, so neither the caller's DataFrame
    # nor a filtered selection of it is written to.
    new_df = df.loc[:, other_cols + x_cols + [y_col, t_col]].assign(
        t=df[t_col].map(t_ind_mapping)
    )
    old_columns = new_df.columns
    new_df = new_df.rename(columns=build_column_renames(x_cols, y_col, t_col))
    colmap = dict(zip(old_columns, new_df.columns))
    return new_df, colmap, t_ind_mapping


def main():
    """Command-line entry point: remap a raw dataset CSV and save the result.

    Reads the raw CSV registered for ``--dataset`` in ``DATASET_PATHSPEC``,
    applies the column configuration given by ``--config``, then writes the
    remapped CSV (``*_mapped.csv``) and a YAML file with the column and
    treatment mappings (``*_colmap.yaml``) next to the raw file.

    Raises:
        ValueError: If the mapped CSV already exists and ``--overwrite`` was
            not given.
    """
    psr = ArgumentParser()
    psr.add_argument("--config", required=True, type=str)
    psr.add_argument("--dataset", required=True, type=str)
    psr.add_argument("--overwrite", action="store_true")
    args = psr.parse_args()

    with open(DATASET_PATHSPEC, "r") as f:
        data_cfg = yaml.load(f)
    data_path = data_cfg[args.dataset]["raw"]
    new_data_path = derive_output_path(data_path, "_mapped.csv")
    # Checked before the (potentially slow) read so the script fails fast.
    if os.path.isfile(new_data_path) and not args.overwrite:
        raise ValueError("File exists")

    print("Reading DataFrame...")
    df = pd.read_csv(data_path, index_col=0, low_memory=False)
    with open(args.config, "r") as f:
        cfg = yaml.load(f)
    if cfg.get("treatment_filter"):
        # Keep only rows whose STATE_NAME_STR appears in the states list.
        states = pd.read_csv(ALL_US_STATES)
        df = df.loc[df["STATE_NAME_STR"].isin(states["State"])]

    print("Remapping columns...")
    new_df, colmap, t_ind_mapping = remap_columns(df, cfg)
    print("Final column map:")
    pprint.pprint(colmap)
    new_df.to_csv(new_data_path)

    yaml_path = derive_output_path(data_path, "_colmap.yaml")
    with open(yaml_path, 'wb') as f:
        yaml.dump({
            "column_mapping": colmap,
            "treatment_mapping": t_ind_mapping
        }, f)
    print("Saved data to", new_data_path, "and column mapping to", yaml_path)


if __name__ == '__main__':
    main()

