from pathlib import Path
import pandas as pd
import json

# Root of the tabular data tree. parents[2] is two directory levels above the
# folder containing this file (presumably the repository root — confirm if
# this script is moved).
_TABULAR_ROOT = Path(__file__).parents[2] / "data" / "tabular"

# Processed tables plus the shared data_configs.json read by main().
PROCESSED_DIR = _TABULAR_ROOT / "processed"
# Plain classification datasets, balanced in place by main().
CLF_DIR_1 = PROCESSED_DIR / "classification"
# Classification datasets whose tables carry linked entities.
CLF_DIR_2 = _TABULAR_ROOT / "tables_with_linked_entities" / "classification"


def balance_classification_data(df, target_col, random_state=None):
    """Undersample every class in ``target_col`` down to the rarest class's size.

    Args:
        df: Input table containing the label column.
        target_col: Name of the column holding the class labels. Rows whose
            label is missing (NaN/None) are dropped, because both
            ``value_counts`` and ``groupby`` ignore missing keys by default.
        random_state: Optional seed or random generator forwarded to both
            pandas sampling steps (per-class subsampling and the final
            shuffle) for reproducible output. ``None`` (the default) keeps
            the original non-deterministic behaviour.

    Returns:
        A new, shuffled DataFrame with the same columns as ``df``, exactly
        ``min_class_count`` rows per class and a fresh ``RangeIndex``.
    """
    class_counts = df[target_col].value_counts()
    if class_counts.empty:
        # No labelled rows: nothing to balance. Keep the column layout.
        return df.iloc[0:0].reset_index(drop=True)
    min_class_count = int(class_counts.min())
    # DataFrameGroupBy.sample keeps every column, including target_col. The
    # previous groupby(...).apply(lambda x: x.sample(...)) depended on apply
    # passing the grouping column to the lambda. pandas 2.2 deprecates that,
    # and later versions drop the column from the result.
    balanced_df = df.groupby(target_col).sample(
        n=min_class_count, random_state=random_state
    )
    # Shuffle so the classes are interleaved instead of stored in runs.
    return balanced_df.sample(frac=1, random_state=random_state).reset_index(
        drop=True
    )


def _balance_files(files, configs, min_samples):
    """Balance each dataset file in place, deleting those left too small.

    Args:
        files: Paths of datasets readable by ``pd.read_parquet``.
        configs: Mapping from dataset name (file stem) to its config dict;
            the ``"target"`` entry names the label column.
        min_samples: A balanced dataset is kept (and overwritten) only if it
            has strictly more rows than this; otherwise its file is deleted.
    """
    for file in files:
        print(f"Processing {file.stem}...")
        df = pd.read_parquet(file)
        target_col = configs[file.stem]["target"]
        balanced_df = balance_classification_data(df, target_col)
        if len(balanced_df) > min_samples:
            print(
                f"Keeping {file.stem}. Balanced dataset has {len(balanced_df)} samples."
            )
            balanced_df.to_parquet(file)
        else:
            print(
                f"Deleting {file.stem}. Balanced dataset has only {len(balanced_df)} samples."
            )
            file.unlink()


def main(*, min_samples=1050):
    """Balance every classification dataset in place and prune small ones.

    Reads per-dataset configs from ``PROCESSED_DIR / "data_configs.json"``.
    Each regular file in ``CLF_DIR_1`` and then ``CLF_DIR_2`` is rebalanced
    with ``balance_classification_data``. The file is overwritten with the
    balanced table when that table has more than ``min_samples`` rows, and
    deleted otherwise.

    Args:
        min_samples: Row-count threshold for keeping a balanced dataset
            (default 1050, the value previously hard-coded).

    Raises:
        KeyError: If any dataset file has no entry in data_configs.json.
            This is checked before any file is modified.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(PROCESSED_DIR / "data_configs.json", "r", encoding="utf-8") as f:
        configs = json.load(f)

    # Snapshot both directories up front. Keep regular files only (a
    # sub-directory can't be unlinked or overwritten as a file), sorted so
    # the processing order is stable.
    batches = [
        (header, sorted(p for p in directory.iterdir() if p.is_file()))
        for header, directory in (
            ("Balancing classification datasets...", CLF_DIR_1),
            ("Balancing linked classification datasets...", CLF_DIR_2),
        )
    ]

    # Processing overwrites and deletes files. A config lookup failing midway
    # would leave the data half-processed, so validate everything first.
    missing = sorted(
        {p.stem for _, files in batches for p in files if p.stem not in configs}
    )
    if missing:
        raise KeyError(f"data_configs.json has no entry for: {', '.join(missing)}")

    for header, files in batches:
        print(header)
        _balance_files(files, configs, min_samples)


if __name__ == "__main__":
    # Script entry point: balancing and deletion run only when this file is
    # executed directly, never on import.
    main()
