import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

# Relative path of the CSV loaded by the script entry point below.
DATA_FP = "../beta_similarity/Apply_Rate_2019.csv"

# Seed handed to np.random.seed() in the script entry point.
SEED = 42
# Passed to clean_data() as truncate_by_class_id: the number of most frequent
# class_id values to keep. Set to None to keep every class.
MAX_NUM_CLASSES = 10


def clean_data(df, truncate_by_class_id=None):
    """Drop incomplete rows and keep only the most frequent classes.

    The input frame is left untouched; all cleaning happens on a copy.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``"search_date_pacific"`` and ``"class_id"``.
    truncate_by_class_id : int or None, optional
        Number of most frequent ``class_id`` values to keep. ``None`` (the
        default) keeps every class.

    Returns
    -------
    df : pandas.DataFrame
        Rows without missing values, without the ``"search_date_pacific"``
        column, restricted to the retained classes.
    unique_cid : numpy.ndarray
        Retained class ids, ordered by descending frequency.
    cid_freq : numpy.ndarray
        Frequencies of *all* class ids in descending order. Unlike
        ``unique_cid`` this array is not truncated, so it can be longer.

    Raises
    ------
    KeyError
        If ``"search_date_pacific"`` or ``"class_id"`` is missing from ``df``.
    """
    # Non-inplace operations: the caller's DataFrame must not be modified as a
    # side effect. NaN rows are dropped while "search_date_pacific" is still
    # present, so a missing date also removes the row.
    cleaned = df.dropna().drop(columns=["search_date_pacific"])

    class_id = cleaned["class_id"].values

    # Count each class, then order ids and counts from most to least frequent.
    unique_cid, cid_freq = np.unique(class_id, return_counts=True)
    freq_index = np.argsort(cid_freq)[::-1]
    unique_cid, cid_freq = unique_cid[freq_index], cid_freq[freq_index]

    # Slicing with None keeps everything, so no special case is needed.
    unique_cid = unique_cid[:truncate_by_class_id]
    cleaned = cleaned.iloc[np.isin(class_id, unique_cid)]

    return cleaned, unique_cid, cid_freq


if __name__ == "__main__":
    # Seed NumPy's global RNG up front; the split below is not given a
    # random_state of its own.
    np.random.seed(SEED)

    # Load the raw CSV, then clean it and keep the MAX_NUM_CLASSES most
    # frequent classes.
    raw = pd.read_csv(DATA_FP)
    df, unique_gid, gid_freq = clean_data(
        raw,
        truncate_by_class_id=MAX_NUM_CLASSES,
    )

    # Hold out 20% of the rows as a test set; `df` becomes the training part.
    df, df_test = train_test_split(df, test_size=0.2)

    print(f"{df.shape=}")


