import pandas as pd
import os
import glob

import typing as T
from tqdm import tqdm
from data.CD_processing.utils import transform_normalise_biomarker_features, unify_units, split_data, get_biom_features, get_static_features

def read_data(directory: T.Union[str, T.List[str]], splintered_source: bool = True) -> pd.DataFrame:
    """Load CSV data into a single DataFrame.

    Args:
        directory: If ``splintered_source`` is True, a directory (or a list of
            directories) whose ``*.csv`` files are read and concatenated.
            Otherwise, anything ``pd.read_csv`` accepts, or a list/tuple of
            CSV paths that are read and concatenated.
        splintered_source: Whether the data is split across many CSV files
            inside a directory.

    Returns:
        The loaded data. When several files are read, their rows are
        concatenated (files within a directory in sorted name order) and the
        result gets a fresh ``RangeIndex``.

    Raises:
        ValueError: If ``splintered_source`` is True and no CSV file is found,
            or if an empty list of files is given.
    """
    print("Loading Data")

    print("     Running single process")
    if splintered_source:
        folders = directory if isinstance(directory, (list, tuple)) else [directory]
        # glob returns files in arbitrary filesystem order; sort them so the
        # row order of the concatenated frame is reproducible.
        csv_files = [
            path
            for folder in folders
            for path in sorted(glob.glob(os.path.join(folder, "*.csv")))
        ]
        if not csv_files:
            raise ValueError(f"No CSV files found in {directory!r}")
        return pd.concat([pd.read_csv(path) for path in csv_files], ignore_index=True)

    if isinstance(directory, (list, tuple)):
        return pd.concat([pd.read_csv(path) for path in directory], ignore_index=True)
    return pd.read_csv(directory)

def _filter_prediagnostic(data: pd.DataFrame) -> pd.DataFrame:
    """Keep pre-diagnostic patient samples and control samples at age <= 14000 days."""
    patients_data = data[data["dataset"] == "patient"]
    patients_data = patients_data[patients_data["samplingdate"] <= patients_data["prediag_token_date"]]
    controls_data = data[data["dataset"] == "control"]
    controls_data = controls_data[controls_data["age_at_sampling_in_days"] <= 14000]  # downsample by age
    # Rows whose "dataset" is neither "patient" nor "control" are dropped here.
    return pd.concat([patients_data, controls_data], axis=0)


def _save_metadata(data: pd.DataFrame, metadata_path: str) -> None:
    """Write the unique analysis codes, units and lab ids of ``data`` as CSVs."""
    # The metadata folder may not exist yet for a fresh save_dir.
    os.makedirs(metadata_path, exist_ok=True)
    for column, file_name in (
        ("analysiscode", "unique_analysiscode.csv"),
        ("unit", "unique_units.csv"),
        ("laboratorium_idcode", "unique_lab_ids.csv"),
    ):
        pd.Series(data[column].unique()).to_csv(
            os.path.join(metadata_path, file_name), index=False, header=True
        )


def _save_split_trajectories(combined: pd.DataFrame, save_dir: str, set_name: str) -> None:
    """Write one CSV per ``lbnr`` with 2-49 rows to ``{save_dir}/{set_name}__CD__{dataset}``."""
    print(f"Saving {set_name} data ...")
    try:
        combined = combined.reset_index()
    except ValueError:
        # reset_index raises ValueError when an index level name already exists
        # as a column (presumably lbnr / samplingdate are both index levels and
        # columns); drop the column copies and restore them from the index.
        duplicated = [c for c in ("lbnr", "samplingdate") if c in combined.columns]
        combined = combined.drop(columns=duplicated).reset_index()

    save_path = f"{save_dir}/{set_name}"

    # One pass over the frame instead of a boolean filter per id. sort=False keeps
    # first-appearance order; rows with a missing lbnr are skipped (groupby's
    # default dropna=True) instead of yielding an empty batch.
    groups = combined.groupby("lbnr", sort=False)
    for patient_id, batch in tqdm(groups, total=groups.ngroups):
        sample_type = batch.iloc[0]["dataset"]
        full_dir = f"{save_path}__CD__{sample_type}"
        os.makedirs(full_dir, exist_ok=True)
        if 1 < len(batch) < 50:
            # The positional index is written as the first column (to_csv default).
            batch.to_csv(f"{full_dir}/{patient_id}.csv")


def pivot_data_and_save(
        data: pd.DataFrame,
        save_dir: str,
        only_prediagnostic: bool = True,
    ):
    """Clean laboratory data, split it, and save one CSV per trajectory.

    Steps:
      1. Parse ``samplingdate``, ``prediag_token_date`` and ``date_of_birth`` as
         datetimes (unparseable values become NaT), drop rows without a
         ``date_of_birth`` and add ``age_at_sampling_in_days``.
      2. If ``only_prediagnostic``, keep only patient samples taken on or before
         ``prediag_token_date`` and control samples taken at an age of at most
         14000 days.
      3. Apply ``unify_units``, keep rows whose ``sex`` is 'k' or 'm' and encode
         them as 0 and 1.
      4. Save the unique analysis codes, units and lab ids under
         ``{save_dir}/metadata``.
      5. Split with ``split_data``, pass the train/val/test biomarker features
         through ``transform_normalise_biomarker_features``, attach the static
         features and write every ``lbnr`` with 2-49 rows to
         ``{save_dir}/{set_name}__CD__{dataset}/{lbnr}.csv``.

    The caller's ``data`` is not modified.

    Args:
        data: Long-format laboratory data.
        save_dir: Root output directory.
        only_prediagnostic: Whether to apply the filtering of step 2.
    """
    print()
    print("Starting processing of data")
    metadata_path = f"{save_dir}/metadata"

    # assign() returns a new frame, so the caller's DataFrame is left untouched
    # and later column writes cannot hit a SettingWithCopy view.
    data = data.assign(
        samplingdate=pd.to_datetime(data['samplingdate'], errors='coerce'),
        prediag_token_date=pd.to_datetime(data['prediag_token_date'], errors='coerce'),
        date_of_birth=pd.to_datetime(data['date_of_birth'], errors='coerce'),
    )
    data = data.dropna(subset=['date_of_birth'])
    data = data.assign(
        age_at_sampling_in_days=(data['samplingdate'] - data['date_of_birth']).dt.days
    )

    if only_prediagnostic:
        data = _filter_prediagnostic(data)

    data = unify_units(data)
    data = data[data['sex'].isin(['k', 'm'])]
    data = data.assign(sex=data['sex'].map({'k': 0, 'm': 1}))

    print("Saving Metadata ...")
    _save_metadata(data, metadata_path)

    train_df, val_df, test_df = split_data(data)

    train_biom_features = get_biom_features(train_df)
    val_biom_features = get_biom_features(val_df)
    test_biom_features = get_biom_features(test_df)

    biom_features = transform_normalise_biomarker_features(
        train_biom_features, val_biom_features, test_biom_features
    )

    # Build all three sets before writing anything, so a failure in feature
    # assembly does not leave a partially written output tree.
    combined_sets = [
        (pd.concat([biom, *get_static_features(split_df)], axis=1), set_name)
        for biom, split_df, set_name in zip(
            biom_features, (train_df, val_df, test_df), ("train", "val", "test")
        )
    ]

    for combined, set_name in combined_sets:
        _save_split_trajectories(combined, save_dir, set_name)

def _trajectory_kind(folder):
    """Return 'patient', 'control' or None depending on the folder's name."""
    # Check the folder's own name first so that a parent directory such as
    # ".../patient_cohort/train__CD__control" is not mistaken for a patient
    # folder; fall back to the full path for folders whose name says neither.
    name = os.path.basename(os.path.normpath(folder))
    for candidate in (name, folder):
        if 'patient' in candidate:
            return 'patient'
        if 'control' in candidate:
            return 'control'
    return None


def cut_trajectory_history_and_save(trajectories_folder, save_folder, history_cutoff):
    """Drop the most recent ``history_cutoff`` years of every trajectory CSV in a folder.

    For a patient folder, rows sampled after ``prediag_token_date`` minus
    ``history_cutoff`` years are removed. For a control folder, the reference
    date is the trajectory's latest ``samplingdate``. Folders that are neither
    are copied without truncation. A trajectory is written to
    ``save_folder/{lbnr}.csv`` (``lbnr`` taken from its first row) only if more
    than two rows remain.

    Args:
        trajectories_folder: Folder of per-trajectory CSV files; non-CSV entries
            and CSVs without rows are skipped.
        save_folder: Output folder, created if missing.
        history_cutoff: Number of years removed from the end of each trajectory.
    """
    kind = _trajectory_kind(trajectories_folder)
    offset = pd.DateOffset(years=history_cutoff) if kind is not None else None
    os.makedirs(save_folder, exist_ok=True)

    # Sorted for a deterministic processing order; only real CSV files.
    csv_names = sorted(
        name for name in os.listdir(trajectories_folder)
        if name.lower().endswith(".csv")
        and os.path.isfile(os.path.join(trajectories_folder, name))
    )

    for file_name in tqdm(csv_names):
        try:
            df = pd.read_csv(os.path.join(trajectories_folder, file_name))
        except pd.errors.EmptyDataError:
            continue  # zero-byte file: nothing to cut or save
        if df.empty:
            continue  # header only: no lbnr to name the output after

        df["samplingdate"] = pd.to_datetime(df["samplingdate"])
        df['prediag_token_date'] = pd.to_datetime(df['prediag_token_date'])

        p_id = df['lbnr'].iloc[0]
        if kind == 'patient':
            df = df[df["samplingdate"] <= df["prediag_token_date"] - offset]
        elif kind == 'control':
            reference_date = df["samplingdate"].max()
            df = df[df["samplingdate"] <= reference_date - offset]

        if df.shape[0] > 2:
            df.to_csv(os.path.join(save_folder, f"{p_id}.csv"), index=False)
