import os

import pandas as pd
from datasets import DatasetDict, Dataset, load_dataset, load_from_disk


def _take_per_id(df, id_column, n, seed):
    """Shuffle *df* with *seed*, then keep at most *n* rows per value of *id_column*."""
    return df.sample(frac=1, random_state=seed).groupby(id_column).head(n)


def _shuffled_dataset(true_df, false_df, seed):
    """Concatenate true/false rows, shuffle them with *seed* and wrap the result as a Dataset."""
    combined = pd.concat([true_df, false_df]).sample(frac=1, random_state=seed)
    # Drop the shuffled, non-contiguous pandas index so it is not carried into the Dataset.
    return Dataset.from_pandas(combined.reset_index(drop=True))


def generate(data, versions_per_id=250, id='formula_name_id', factor_false=10, split=(0.7, 0.15, 0.15), seed=0):
    """Build a down-sampled train/test/validation copy of a labelled dataset and save it to disk.

    For every value of the ``id`` column, at most ``n`` rows with a true label and at most
    ``factor_false * n`` rows with a false label are kept, where ``n`` is that split's share
    of ``versions_per_id``. Rows are picked at random, reproducibly via ``seed``.

    Args:
        data: Dataset name passed to ``load_dataset``; if that raises, it is treated as a
            local path for ``load_from_disk``. Also used as the output directory prefix.
        versions_per_id: Total number of true-labelled rows to keep per id, summed over the
            three output splits.
        id: Name of the column that groups rows (shadows the builtin; kept for compatibility).
        factor_false: How many false-labelled rows to keep per true-labelled row, per id.
        split: ``(train, test, validation)`` fractions. They must be non-negative and sum to 1
            (within 0.01). Only the test and validation fractions are applied directly; the
            train split receives the remainder of ``versions_per_id``.
        seed: ``random_state`` used for every shuffle.

    Returns:
        The new ``DatasetDict``, which is also saved to ``f"{data}_{versions_per_id}"``.

    Raises:
        ValueError: If ``split`` or ``versions_per_id`` is invalid, or the rounded per-split
            counts leave a negative number of train rows.

    Notes:
        * The source dataset must provide ``train`` and ``test`` splits and a boolean
          ``label`` column (rows are selected with ``df['label']`` / ``~df['label']`` masks).
        * The validation split is drawn from source ``test`` rows that were not picked for the
          new test split, so the new test and validation splits never overlap.
    """
    # Validate the cheap arguments before the (potentially expensive) dataset load.
    if len(split) != 3:
        raise ValueError(f"Split must have exactly three values (train, test, validation): {split}")
    if any(part < 0 for part in split):
        raise ValueError(f"Split values must be non-negative: {split}")
    split_sum = sum(split)
    if abs(split_sum - 1) > 0.01:
        raise ValueError(f"Split values must sum to 1. Current sum: {split_sum}, split: {split}")
    if versions_per_id < 0:
        raise ValueError(f"versions_per_id must be non-negative: {versions_per_id}")

    test_versions_per_id = int(versions_per_id * split[1])
    val_versions_per_id = int(versions_per_id * split[2])
    train_versions_per_id = versions_per_id - test_versions_per_id - val_versions_per_id
    if train_versions_per_id < 0:
        # Reachable within the 0.01 tolerance, e.g. split=(0, 0.505, 0.5) with 250 versions.
        # A negative n would not mean "take n rows" in GroupBy.head, so refuse it explicitly.
        raise ValueError(
            f"Split {split} leaves a negative number of train rows per id "
            f"({train_versions_per_id}) for versions_per_id={versions_per_id}"
        )

    try:
        dsd = load_dataset(data)
    except Exception:
        # Deliberate fallback: not loadable via load_dataset, so try a dataset saved to disk.
        dsd = load_from_disk('file://' + data)

    new_dsd = DatasetDict()

    # Train split: sampled from the source train split.
    train_df = dsd['train'].to_pandas()
    train_true = _take_per_id(train_df[train_df['label']], id, train_versions_per_id, seed)
    train_false = _take_per_id(train_df[~train_df['label']], id, factor_false * train_versions_per_id, seed)
    new_dsd['train'] = _shuffled_dataset(train_true, train_false, seed)

    # Test split: sampled from the source test split.
    test_df = dsd['test'].to_pandas()
    test_true = _take_per_id(test_df[test_df['label']], id, test_versions_per_id, seed)
    test_false = _take_per_id(test_df[~test_df['label']], id, factor_false * test_versions_per_id, seed)
    new_dsd['test'] = _shuffled_dataset(test_true, test_false, seed)

    # Validation split: source-test rows not already used for the new test split.
    # The sampled frames keep test_df's original index, so isin() excludes exactly those rows.
    unused_true = test_df['label'] & ~test_df.index.isin(test_true.index)
    unused_false = ~test_df['label'] & ~test_df.index.isin(test_false.index)
    val_true = _take_per_id(test_df[unused_true], id, val_versions_per_id, seed)
    val_false = _take_per_id(test_df[unused_false], id, factor_false * val_versions_per_id, seed)
    new_dsd['validation'] = _shuffled_dataset(val_true, val_false, seed)

    output = f"{data}_{versions_per_id}"
    os.makedirs(output, exist_ok=True)
    new_dsd.save_to_disk(output)

    print(f"Dataset {data} saved to {output}")
    # Print per-split row counts and label balance.
    for key, ds in new_dsd.items():
        labels = ds.to_pandas()['label']
        n_true = labels.sum()
        print(f"Dataset {key} has {len(labels)} rows")
        print(f"Dataset {key} has {n_true} true and {len(labels) - n_true} false labels")

    return new_dsd

if __name__ == "__main__":
    # Guarded so importing this module does not immediately load and rewrite datasets.
    generate('nmf_full_merged')
    # Other invocations used with this script:
    # generate('anonymized')
    # generate('anonymized', id='formula1_name_id')