from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from shutil import rmtree
import shutil
import tarfile
import tempfile
import h5py
import numpy as np
import polars as pl
import best3 as b3
from tqdm import tqdm

# Directory holding the raw Netflix Prize files (movie_titles.txt, training_set.tar).
nf_data_dir = Path("data", "nf_prize_dataset")

# Release years substituted by load_movies() for titles whose year field is
# the literal string "NULL" in movie_titles.txt, keyed by exact title.
patch_year = {
    # The three "Ancient Civilizations" titles share the same year.
    "Ancient Civilizations: Athens and Greece": 2001,
    "Ancient Civilizations: Land of the Pharaohs": 2001,
    "Ancient Civilizations: Rome and Pompeii": 2001,
    "Roti Kapada Aur Makaan": 1974,
    "Hote Hote Pyaar Ho Gaya": 1999,
    "Jimmy Hollywood": 1994,
    "Eros Dance Dhamaka": 1999,
}

# Movie ids (integers) that load_movies() is meant to skip.
to_rem = {12539, 12580, 2284}

def load_rating(path: Path) -> "pl.DataFrame":
    """Parse one Netflix Prize ``mv_*.txt`` ratings file into a DataFrame.

    The first line is the movie id followed by a colon (e.g. ``"1:"``).
    Every following non-blank line is ``customer,rating,date`` with the date
    written as ``YYYY-MM-DD``.

    Args:
        path: Path to the ratings file; it is decoded as Latin-1.

    Returns:
        A DataFrame with the columns:

        - ``customer``: integer.
        - ``rating``: Int8.
        - ``date``: Date.
        - ``alternative``: the movie id from the header, repeated on every row.

    Raises:
        ValueError: If the header is not of the form ``"<id>:"``, the id is
            not an integer, or a data line does not have exactly three
            comma-separated fields.
    """
    customers: list[str] = []
    ratings: list[str] = []
    dates: list[str] = []

    with path.open(encoding="latin-1") as fd:
        header = fd.readline().strip()
        movie_, sep, rest = header.partition(":")
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not sep or rest:
            raise ValueError(f"{path}: malformed header line {header!r}")
        movie = int(movie_)

        # Stream the file instead of materialising every line up front.
        for raw in fd:
            line = raw.strip()
            if not line:
                continue
            customer, rating, date = line.split(",")
            customers.append(customer)
            ratings.append(rating)
            dates.append(date)

    # An explicit string schema keeps the ``.str`` conversions below valid
    # even when the file has no rating lines (empty lists would otherwise
    # be inferred as the Null dtype).
    df = pl.DataFrame(
        {"customer": customers, "rating": ratings, "date": dates},
        schema={"customer": pl.Utf8, "rating": pl.Utf8, "date": pl.Utf8},
    )

    return df.with_columns(
        pl.col("date").str.to_date("%Y-%m-%d"),
        pl.col("customer").str.to_integer(),
        pl.col("rating").str.to_integer().cast(pl.Int8),
        # NOTE(review): pl.lit(int) may materialise as Int32 in recent polars,
        # while load_movies produces an integer "alternative" via
        # str.to_integer() (Int64 by default); confirm the join in __main__
        # accepts the mismatch.
        pl.lit(movie).alias("alternative"),
    )


def load_movies() -> "pl.DataFrame":
    """Load movie metadata from ``movie_titles.txt`` in ``nf_data_dir``.

    Each line is ``id,year,title``. Titles may themselves contain commas, so
    every field after the year is re-joined with ``","``. A year of
    ``"NULL"`` is replaced from ``patch_year``, and movie ids contained in
    ``to_rem`` are skipped.

    Returns:
        A DataFrame with the columns:

        - ``alternative``: integer movie id.
        - ``year``: Int16.
        - ``movie``: the title.
        - ``name``: ``"<title> (<year>)"``.

    Raises:
        KeyError: If a kept movie has a ``NULL`` year and no ``patch_year`` entry.
        ValueError: If a line cannot be split into id, year and title, or the
            id is not an integer.
    """
    movies, years, names = [], [], []
    with open(nf_data_dir / "movie_titles.txt", encoding="latin-1") as fd:
        for line in fd:
            movie, year, *name_ = line.strip().split(",")
            # ``to_rem`` contains ints; the raw field is a str and a str never
            # equals an int, so it must be converted before the membership test.
            if int(movie) in to_rem:
                continue
            name = ",".join(name_)
            if year == "NULL":
                year = str(patch_year[name])
            movies.append(movie)
            years.append(year)
            names.append(name)

    return pl.DataFrame(
        {
            "alternative": movies,
            "year": years,
            "movie": names,
            "name": [f"{n} ({y})" for (n, y) in zip(names, years, strict=True)],
        }
    ).with_columns(
        pl.col("year").str.to_integer().cast(pl.Int16),
        pl.col("alternative").str.to_integer(),
    )



if __name__ == "__main__":
    # Output root for the derived datasets; wiped and recreated on every run.
    nf_dir = Path('datasets') / 'nf'
    shutil.rmtree(nf_dir, ignore_errors=True)
    nf_dir.mkdir(exist_ok=True, parents=True)

    movies = load_movies()
    print(movies)

    # Extract the per-movie rating files into a scratch directory that is
    # removed automatically when the block exits, even on error.
    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        with tarfile.open(nf_data_dir / "training_set.tar") as tar:
            tar.extractall(tmp_dir)
        # Sorted so the concatenation order does not depend on the order in
        # which the filesystem happens to enumerate the files.
        files = sorted((tmp_dir / "training_set").rglob("mv_*.txt"))
        with ThreadPoolExecutor(min(b3.cpu_count(), 8)) as executor:
            choices = pl.concat(
                list(tqdm(executor.map(load_rating, files), total=len(files)))
            ).sort("customer")
        print(choices)

    # Rating count per movie (column "len"), joined with movie metadata,
    # most-rated first.
    ratings = (
        choices.group_by(["alternative"])
        .len()
        .join(movies, "alternative")
        .sort("len", descending=True)
    )

    # For each of 3 repetitions and each threshold, keep only movies with at
    # least i thousand ratings. Whether the repetitions differ is decided by
    # b3.save_dataset (TODO confirm).
    for rep in range(3):
        for i in [150, 100, 10]:
            out_dir = nf_dir / f"nf-{i}k" / str(rep)
            shutil.rmtree(out_dir, ignore_errors=True)
            out_dir.mkdir(exist_ok=True, parents=True)
            print(out_dir)
            b3.save_dataset(out_dir, ratings.filter(pl.col("len") >= i * 1000), choices)
