import pandas as pd
from tqdm import tqdm


def run_pipeline(steps=None):
    """Run the requested pipeline stages in a fixed order.

    The stages are always executed in the order ``fetch_data`` ->
    ``create_trajectories`` -> ``cut_trajectories``, regardless of the order
    in which they appear in ``steps``. Names in ``steps`` that match none of
    these stages are ignored.

    Args:
        steps: Container of stage names to run. ``None`` (the default) or an
            empty container runs no stage; only the start/end banners are
            printed.

    Returns:
        None.
    """
    # A literal ``[]`` default would be one list object shared by every call;
    # use None as the sentinel and normalise it to an empty selection.
    if steps is None:
        steps = ()

    print("Starting pipeline ...")
    print()

    raw_data_save_path = ""
    sequential_data_save_path = ""

    # Stays None unless the fetch stage runs; it is handed to
    # create_trajectories either way.
    data = None

    if "fetch_data" in steps:
        print("Starting step: fetch_data")
        print()
        data = fetch_data(raw_data_save_path)

    if "create_trajectories" in steps:
        print("Starting step: create_trajectories")
        print()
        create_trajectories(
            data,
            sequential_data_save_path,
            raw_data_save_path=raw_data_save_path,
        )

    if "cut_trajectories" in steps:
        print("Starting step: cut_trajectories")
        print()
        cut_trajectories(sequential_data_save_path)

    print()
    print("Pipeline Completed ...")

def fetch_data(raw_data_save_path, return_data=True):
    """Run the ``cd_vs_control_17_bioms`` query through ``query_data``.

    Args:
        raw_data_save_path: Forwarded to ``query_data`` as its second
            positional argument (presumably where the raw result is
            written -- confirm against ``query_data``).
        return_data: When truthy, hand back whatever ``query_data``
            returned; otherwise return ``True``.

    Returns:
        The result of ``query_data`` if ``return_data`` is truthy, else
        ``True``.
    """
    # Project imports are kept local so importing this module does not pull
    # in the data-processing package.
    from data.CD_processing.fetch_data import query_data
    from data.CD_processing.queries.cd_vs_control_17_bioms import cd_vs_control_17_bioms

    connection_string = ""

    query_result = query_data(
        connection_string,
        raw_data_save_path,
        db_query=cd_vs_control_17_bioms,
    )
    return query_result if return_data else True

def create_trajectories(raw_data, sequential_data_save_path, raw_data_save_path=None):
    """Pivot the raw data into trajectories with ``pivot_data_and_save``.

    Previously ``raw_data`` was always overwritten by a CSV read, so the
    caller's in-memory data was silently discarded and the default
    ``raw_data_save_path=None`` produced the path ``"None/raw_data.csv"``.
    Now a supplied DataFrame is used directly and the CSV is only a fallback.

    Args:
        raw_data: DataFrame holding the raw data. If it is not a DataFrame
            (e.g. ``None``), the data is loaded from
            ``{raw_data_save_path}/raw_data.csv`` instead.
        sequential_data_save_path: Passed to ``pivot_data_and_save`` as
            ``save_dir``.
        raw_data_save_path: Directory containing ``raw_data.csv``. Only
            required when ``raw_data`` is not a DataFrame.

    Returns:
        True, once ``pivot_data_and_save`` has returned.

    Raises:
        ValueError: If ``raw_data`` is not a DataFrame and
            ``raw_data_save_path`` is ``None``, i.e. there is no data source.
    """
    load_from_disk = not isinstance(raw_data, pd.DataFrame)

    # Fail fast (before the heavier project import) when there is nothing
    # to process.
    if load_from_disk and raw_data_save_path is None:
        raise ValueError(
            "create_trajectories needs either a DataFrame in raw_data or a "
            "raw_data_save_path to load raw_data.csv from"
        )

    from data.CD_processing.process_data import pivot_data_and_save

    if load_from_disk:
        raw_data = pd.read_csv(f"{raw_data_save_path}/raw_data.csv")
    # NOTE(review): an in-memory frame is used as-is; its dtypes may differ
    # from a CSV round-trip -- confirm pivot_data_and_save does not rely on
    # the CSV-parsed dtypes.

    pivot_data_and_save(
        data=raw_data,
        save_dir=sequential_data_save_path,
        ffill=False,
        split_by_type=True,
        only_prediagnostic=True,
    )

    return True

def cut_trajectories(sequential_data_save_path, save_root="", max_history_in_years=1):
    """Cut the history of every split/label trajectory folder.

    For each split in ``train``/``val``/``test`` and each label in
    ``control``/``patient``, calls ``cut_trajectory_history_and_save`` on the
    folder ``{sequential_data_save_path}/{split}__CD__{label}`` and writes to
    ``{save_root}/{split}__CD__{label}``.

    Previously the ``sequential_data_save_path`` argument was overwritten
    with a hard-coded empty string inside the function, so the caller's
    value was ignored; it is now honoured.

    Args:
        sequential_data_save_path: Root directory holding the per-split
            trajectory folders.
        save_root: Root directory for the output folders. Defaults to ``""``,
            the value that used to be hard-coded.
        max_history_in_years: Forwarded unchanged to
            ``cut_trajectory_history_and_save``. Defaults to ``1``, the value
            that used to be hard-coded.

    Returns:
        None.
    """
    from data.CD_processing.process_data import cut_trajectory_history_and_save

    for split in tqdm(["train", "val", "test"]):
        print(f"Cutting: {split}")
        for label in ["control", "patient"]:
            # Input and output use the same folder name under different roots.
            folder_name = f"{split}__CD__{label}"

            cut_trajectory_history_and_save(
                trajectories_folder=f"{sequential_data_save_path}/{folder_name}",
                save_folder=f"{save_root}/{folder_name}",
                max_history_in_years=max_history_in_years,
            )



if __name__ == '__main__':
    import warnings

    # Silence these warning categories for the whole run; the filters are
    # registered in the same order as before (FutureWarning, then UserWarning).
    for silenced_category in (FutureWarning, UserWarning):
        warnings.simplefilter(action='ignore', category=silenced_category)

    # Stage selection for this invocation: only the trajectory-cutting stage.
    steps = ["cut_trajectories"]
    run_pipeline(steps)
