#%%
import torch

import numpy as np
import pandas as pd

import random

#%%
"""for reproducibility"""
def set_random_seed(seed):
    """Seed the PyTorch, NumPy and Python RNGs with the same value.

    Besides seeding, the cuDNN backend flags are set to
    ``deterministic = True`` and ``benchmark = False``.

    Args:
        seed: Seed value forwarded unchanged to every seeding function.
    """
    # PyTorch: default generator first, then the CUDA generator(s).
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in torch_seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # NumPy's global generator and the stdlib `random` module.
    np.random.seed(seed)
    random.seed(seed)


#%%
def undummify(imputed, prefix_sep="###"):
    """Collapse dummy (one-hot style) columns back into one column per variable.

    Columns named ``<variable><prefix_sep><level>`` are grouped by
    ``<variable>``. For every row, the level whose dummy column holds the
    largest value is selected and the resulting column is cast to ``float``.
    Columns whose name does not contain ``prefix_sep`` are passed through
    unchanged.

    Args:
        imputed: DataFrame with string column names.
        prefix_sep: Separator between the variable name and the level label
            in dummy column names.

    Returns:
        A DataFrame with one column per variable, ordered by the first
        appearance of each variable in ``imputed.columns``. If ``imputed``
        has no columns, an empty DataFrame sharing its index is returned.

    Raises:
        ValueError: If a selected level label cannot be converted to float.
    """
    # Map each variable name to whether it is spread over dummy columns.
    # Dicts keep insertion order, so the output follows the input order.
    cols2collapse = {
        col.split(prefix_sep)[0]: (prefix_sep in col) for col in imputed.columns
    }

    series_list = []
    for col, needs_to_collapse in cols2collapse.items():
        if needs_to_collapse:
            # Select by exact "<variable><prefix_sep>" prefix. A substring
            # match would also pick up dummies of other variables whose name
            # merely ends with `col` (e.g. "ba###1" when collapsing "a"), and
            # a hard-coded "###" would ignore a custom `prefix_sep`.
            dummy_prefix = f"{col}{prefix_sep}"
            is_dummy = [name.startswith(dummy_prefix) for name in imputed.columns]
            undummified = (
                imputed.loc[:, is_dummy]
                .idxmax(axis=1)
                .apply(lambda x: x.split(prefix_sep, maxsplit=1)[1])
                .rename(col)
            )
            series_list.append(undummified.astype(float))
        else:
            series_list.append(imputed[col])

    if not series_list:
        # pd.concat raises on an empty list; a frame without columns simply
        # has nothing to collapse.
        return pd.DataFrame(index=imputed.index)

    undummified_df = pd.concat(series_list, axis=1)
    return undummified_df