from pathlib import Path
import h5py
import polars as pl
import best3 as b3
import tmdbsimple as ts
import yaml
import shutil
import numpy as np
import os




if __name__ == "__main__":
    # Build popularity-filtered MovieLens datasets from the raw ml-32m CSVs.
    # Output layout: datasets/ml/ml-{1,10,50}k/<rep>/ for rep in 0..2.
    ml_dir = Path("datasets") / "ml"
    # Start from a clean output tree so results of earlier runs never linger.
    shutil.rmtree(ml_dir, ignore_errors=True)

    # Rename the MovieLens id columns to the names used in the rest of this
    # script ("alternative" for movies, "customer" for users).
    alternatives = pl.read_csv("data/ml-32m/movies.csv").rename(
        {"movieId": "alternative"}
    )
    choices = pl.read_csv("data/ml-32m/ratings.csv").rename(
        {"movieId": "alternative", "userId": "customer"}
    )

    # One row per movie: its rating count (column "len") joined with the
    # movie metadata, most-rated movies first.
    ratings = (
        choices.group_by(["alternative"])
        .len()
        .join(alternatives, "alternative")
        .sort("len", descending=True)
    )

    # The popularity subsets depend only on the threshold, not on the
    # repetition, so filter once per threshold instead of once per
    # (repetition, threshold) pair.
    # NOTE(review): assumes b3.save_dataset does not mutate its inputs in
    # place (the original already reuses `choices` across calls) -- confirm.
    thresholds_k = (1, 10, 50)
    subsets = {
        k: ratings.filter(pl.col("len") >= k * 1000) for k in thresholds_k
    }

    for rep in range(3):
        for k, subset in subsets.items():
            # `out_dir` rather than `dir`, which would shadow the builtin.
            out_dir = ml_dir / f"ml-{k}k" / str(rep)
            # Best-effort cleanup kept from the original; normally a no-op
            # because the whole tree was removed above.
            shutil.rmtree(out_dir, ignore_errors=True)
            out_dir.mkdir(exist_ok=True, parents=True)
            # Full `choices` table is passed for every subset, as before.
            b3.save_dataset(out_dir, subset, choices)
