import logging
import multiprocessing
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

from constants import OUTPUT_DIR


def _count_gap_ratio_columns(array: np.ndarray) -> tuple[int, int, int]:
    """Return (successful, valid, total) column counts for a 2D float array."""
    n_rows, total_num = array.shape

    # With fewer than two rows there are no gaps between sorted values, so no
    # column can be checked.
    if n_rows < 2:
        return 0, 0, total_num

    # The threshold only depends on the number of rows; compute it once.
    threshold = (n_rows + 3) ** 0.5

    successful_num, valid_num = 0, 0
    for column in array.T:
        # np.sort places NaNs last, so any NaN in the column yields a NaN gap.
        gaps = np.diff(np.sort(column))

        if np.isnan(gaps).any():
            continue

        # Duplicate values give a zero gap; treat the ratio as infinite.
        # This theoretically happens with probability 0 for continuous data.
        smallest_gap = gaps.min()
        if smallest_gap == 0:
            ratio = float("inf")
        else:
            ratio = gaps.max() / smallest_gap

        # int() keeps the counter a plain Python int (the comparison may
        # produce a numpy bool).
        successful_num += int(ratio > threshold)
        valid_num += 1

    return successful_num, valid_num, total_num


def verify_dataset(data_path: str) -> tuple[int, int, int]:
    """Count the feature columns of a CSV dataset that pass the gap-ratio check.

    The last column is treated as the label and dropped. Each remaining column
    is sorted and the gaps between consecutive sorted values are computed. A
    column is *valid* when the dataset has at least two rows and none of the
    column's gaps is NaN. A valid column is *successful* when
    ``max(gap) / min(gap)`` exceeds ``sqrt(n_rows + 3)``; a zero gap
    (duplicate values) makes the ratio infinite.

    Args:
        data_path (str): The path to the CSV dataset.

    Returns:
        tuple[int, int, int]: The number of successful columns, the number of
            valid columns, and the total number of feature columns. If the
            dataset cannot be loaded or processed, the failure is logged and
            ``(0, 0, 0)`` is returned.
    """
    try:
        df = pd.read_csv(data_path)
        # Delete label column
        df = df.iloc[:, :-1]
        array = df.values.astype(np.float64)
        return _count_gap_ratio_columns(array)
    except Exception as exc:
        # Best effort: one unreadable dataset must not abort the whole run,
        # but the failure should not go unnoticed either.
        logging.getLogger(__name__).warning("Skipping %s: %s", data_path, exc)
        return 0, 0, 0


def main():
    """Run ``verify_dataset`` on every entry of ``OUTPUT_DIR`` and print totals.

    The entries are processed by a process pool with one worker per CPU, with
    a progress bar advancing once per finished entry. The per-entry counts
    are summed and printed as a single summary line.
    """
    data_paths = [os.path.join(OUTPUT_DIR, entry) for entry in os.listdir(OUTPUT_DIR)]

    # Running sums of (successful, valid, total) across all entries.
    totals = [0, 0, 0]

    worker_count = multiprocessing.cpu_count()
    with multiprocessing.Pool(processes=worker_count) as pool, tqdm(
        total=len(data_paths)
    ) as progress:
        # Results arrive in completion order; only the sums matter.
        for counts in pool.imap_unordered(verify_dataset, data_paths):
            totals = [running + new for running, new in zip(totals, counts)]
            progress.update(1)

    successful_num, valid_num, total_num = totals
    print(f"Successful: {successful_num}, Valid: {valid_num}, Total: {total_num}")


# Run only when executed as a script. The guard also keeps main() from being
# re-run when this module is imported by multiprocessing worker processes
# (e.g. under the "spawn" start method).
if __name__ == "__main__":
    main()
