import numpy as np
import polars as pl
from pathlib import Path
import itertools
import best3 as b3
import shutil
# How many times each dataset is written; each repeat gets its own
# numbered subdirectory (see the driver loop at the bottom of the file).
repeats = 3
# Root directory under which all generated datasets are saved.
datasets = Path("datasets")
# Directory holding the raw SUSHI 3 (2016) input files.
sushi_data_dir = Path("data").joinpath("sushi3-2016")
def load_order(fn):
    """Parse a SUSHI-style ``.order`` file into a long-format ranking table.

    Expected layout: a header line of exactly two integers, the second of
    which must be ``1``; then one line per customer of the form
    ``0 <k> <a_0> <a_1> ... <a_{k-1}>``. Each customer line has a leading ``0``,
    the count ``k`` of listed alternatives, then exactly ``k`` alternative ids.
    Blank lines (e.g. a trailing newline at end of file) are ignored.

    Args:
        fn: Path or path-like of the ``.order`` file.

    Returns:
        A ``pl.DataFrame`` with one row per (customer, listed alternative) and
        columns:

        * ``customer``: 0-based index of the customer line.
        * ``rating``: ``-position`` within the line. The first-listed
          alternative gets 0 and later ones get -1, -2, ...
        * ``alternative``: the alternative id as read from the file.

    Raises:
        ValueError: if the header or any customer line does not match the
            layout above. These checks raise explicitly instead of using
            ``assert`` so they still run under ``python -O``.
    """
    alternatives = []
    ratings = []
    customers = []
    with open(fn) as fd:
        header = fd.readline().split()
        if len(header) != 2:
            raise ValueError(f"{fn}: expected a 2-field header, got {header!r}")
        # First header field is not used here; NOTE(review): presumably the
        # number of items — could be used to range-check alternative ids.
        _n_items, marker = map(int, header)
        if marker != 1:
            raise ValueError(f"{fn}: expected second header field to be 1, got {marker}")

        customer = 0
        for lineno, line in enumerate(fd, start=2):
            fields = line.split()
            if not fields:
                # Skip blank/whitespace-only lines; they do not consume a
                # customer index.
                continue
            if len(fields) < 2:
                raise ValueError(f"{fn}:{lineno}: malformed order line {line.rstrip()!r}")
            lead, count, *ranked = map(int, fields)
            if lead != 0 or len(ranked) != count:
                raise ValueError(f"{fn}:{lineno}: malformed order line {line.rstrip()!r}")
            for position, alternative in enumerate(ranked):
                customers.append(customer)
                # NOTE(review): assumes alternatives are listed most-preferred
                # first, so a higher rating means more preferred — confirm
                # against the SUSHI dataset README.
                ratings.append(-position)
                alternatives.append(alternative)
            customer += 1

    return pl.DataFrame(
        {"customer": customers, "rating": ratings, "alternative": alternatives}
    )




def sushi_a(i):
    """Build and save the SUSHI "A" dataset for repeat ``i``.

    The item table pairs ids ``0..9`` with the ten fixed sushi names. Rankings
    come from ``sushi3a.5000.10.order`` via ``load_order``. The item embedding
    is an identity matrix with one row per alternative. Everything is passed
    to ``b3.save_dataset`` under ``datasets/sushi/sushi_a/<i>``.
    """
    sushi_names = [
        "ebi",
        "anago",
        "maguro",
        "ika",
        "uni",
        "ikura",
        "tamago",
        "toro",
        "tekke maki",
        "kappa maki",
    ]
    # Ids are simply the positions in the name list.
    item_table = pl.DataFrame(
        {"alternative": np.arange(len(sushi_names)), "name": sushi_names}
    )

    rankings = load_order(sushi_data_dir / "sushi3a.5000.10.order")

    target_dir = datasets / "sushi" / "sushi_a" / str(i)
    identity_embedding = np.eye(len(item_table))
    b3.save_dataset(target_dir, item_table, rankings, identity_embedding)



def sushi_b(i):
    """Build and save the two SUSHI "B" datasets for repeat ``i``.

    Item metadata is read from ``sushi3.idata``, a tab-separated file with no
    header; the columns are named below. Rankings come from
    ``sushi3b.5000.10.order``. Two datasets are written with
    ``b3.save_dataset``:

    * ``datasets/sushi/sushi_b_onehot/<i>``: identity-matrix item embedding.
    * ``datasets/sushi/sushi_b/<i>``: feature embedding built from the
      metadata columns. Integer-typed columns are one-hot encoded (one column
      per value ``0..max``); all other columns are appended as single columns.
    """
    mdf = pl.read_csv(sushi_data_dir / "sushi3.idata", separator="\t", has_header=False)
    mdf.columns = [
        "alternative",
        "name",
        "style",
        "major group",
        "minor group",
        "heaviness",
        "frequency",
        "price",
        "findability",
    ]

    df = load_order(sushi_data_dir / "sushi3b.5000.10.order")

    # NOTE(review): an identity embedding assumes row k of mdf corresponds to
    # alternative id k — confirm the .idata file is sorted by id from 0.
    b3.save_dataset(datasets / "sushi" / "sushi_b_onehot" / str(i), mdf, df, np.eye(len(mdf)))

    feature_columns = [
        "style",
        "major group",
        "minor group",
        "heaviness",
        "frequency",
        "price",
        "findability",
    ]
    blocks = []
    for col in feature_columns:
        values = mdf[col].to_numpy()
        # Use the abstract np.integer rather than np.int_. np.int_ is
        # platform-dependent (int32 on Windows with NumPy < 2), which would
        # silently skip one-hot encoding of int64 columns there.
        if np.issubdtype(values.dtype, np.integer):
            # Assumes integer codes are non-negative (used directly as indices).
            one_hot = np.zeros((len(values), values.max() + 1))
            one_hot[np.arange(len(values)), values] = 1
            values = one_hot
        if values.ndim == 1:
            values = values[:, None]
        blocks.append(values)
    # np.concatenate works on every NumPy version; np.concat only exists in >= 2.0.
    embed = np.concatenate(blocks, axis=-1)

    b3.save_dataset(datasets / "sushi" / "sushi_b" / str(i), mdf, df, embed)

# Remove any previous output under datasets/sushi before regenerating.
# ignore_errors=True means failures (including a missing directory) are
# silently ignored.
sushi_root = datasets / "sushi"
shutil.rmtree(sushi_root, ignore_errors=True)

# For each repeat, write dataset A then dataset B (same order as before).
for repeat_idx in range(repeats):
    for build in (sushi_a, sushi_b):
        build(repeat_idx)
