import os
from pathlib import Path
import shutil

import h5py
import numpy as np
import vllm
import best3 as b3
import tmdbsimple as ts
import yaml
from requests import HTTPError
import time
import polars as pl
from functools import cache
from tqdm.auto import tqdm

# TMDB API key is read from the TMDB environment variable (None when unset).
ts.API_KEY = os.environ.get("TMDB")
# Timestamp of the last completed TMDB request made by fetch_tmdb (0 = none yet);
# used to throttle consecutive requests.
last_call = 0


def wait(f, delay=0.01):
    """Wrap *f* so that every call first sleeps for *delay* seconds.

    Used below as a crude client-side rate limit in front of the TMDB search
    calls. The wrapper keeps *f*'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) via :func:`functools.wraps`.

    Args:
        f: Callable to wrap.
        delay: Seconds to sleep before each call (default 10 ms, matching the
            original hard-coded value).

    Returns:
        A callable with the same signature and return value as *f*.
    """
    import functools

    @functools.wraps(f)
    def inner(*args, **kwargs):
        time.sleep(delay)
        return f(*args, **kwargs)

    return inner


# Shared tmdbsimple search client. Bound to its own name so that the
# ``search`` function defined below does not silently rebind/shadow it.
tmdb_search = ts.Search()
# Throttled (via ``wait``) and memoized (via ``functools.cache``) TMDB search
# calls; cache entries are keyed on the call's keyword arguments.
search_movie = cache(wait(tmdb_search.movie))
search_tv = cache(wait(tmdb_search.tv))


def search(movie, year, tv=False):
    """Resolve a title/year pair to the id of the first TMDB search hit.

    With ``tv=False`` the movie search is tried first and, if it returns no
    results, the TV search is tried next. With ``tv=True`` only the TV search
    is used.

    Args:
        movie: Title string used as the search query.
        year: Passed through unchanged as the ``year`` search filter.
        tv: Start directly with the TV search.

    Returns:
        ``(tmdb_id, is_tv)`` for the first hit, or ``(0, False)`` when no
        search produced any result.
    """
    kinds = (True,) if tv else (False, True)
    for is_tv in kinds:
        finder = search_tv if is_tv else search_movie
        hits = finder(query=movie, include_adult=True, year=year)["results"]
        if hits:
            return hits[0]["id"], is_tv
    return (0, False)


@cache
def fetch_tmdb(key, tv):
    """Fetch a TMDB record and return it serialized as YAML text.

    Results are memoized per ``(key, tv)`` by ``functools.cache``.

    Args:
        key: TMDB id to look up.
        tv: If truthy, query the TV endpoint (``ts.TV``). Otherwise query the
            movie endpoint (``ts.Movies``) and, only when TMDB answers 404 for
            that id, retry it as a TV id.

    Returns:
        str: ``yaml.safe_dump`` of the ``info()`` response.

    Raises:
        requests.HTTPError: any movie-endpoint error other than 404 (e.g.
            401, 429, 5xx); errors from the TV request propagate unchanged.
    """
    global last_call
    # Client-side throttle: leave at least 10 ms since the previous request.
    # monotonic() is used because it is unaffected by wall-clock adjustments.
    time.sleep(max(0, 0.01 - (time.monotonic() - last_call)))

    if tv:
        res = ts.TV(key).info()
    else:
        try:
            res = ts.Movies(key).info()
        except HTTPError as err:
            # Only "movie not found" justifies retrying the id as a TV show.
            # Retrying on e.g. a rate limit or server error would return an
            # unrelated TV show that happens to share this numeric id.
            if getattr(err.response, "status_code", None) != 404:
                raise
            last_call = time.monotonic()
            return fetch_tmdb(key, True)

    last_call = time.monotonic()
    return yaml.safe_dump(res)


def _embed_texts(model, texts):
    """Embed *texts* with *model* and stack the vectors row-wise.

    Row ``i`` of the result is the embedding of ``texts[i]``; this relies on
    ``model.embed`` returning its outputs in prompt order.
    """
    outputs = model.embed(texts)
    return np.stack([np.array(out.outputs.embedding) for out in outputs])


def _rebuild_ml_llm(model, ds, links):
    """Recreate ``ds/ml_llm`` from ``ds/ml`` and overwrite every ``embed`` dataset with TMDB-text embeddings."""
    ml_llm = ds / "ml_llm"
    shutil.rmtree(ml_llm, ignore_errors=True)
    shutil.copytree(ds / "ml", ml_llm)

    for em in sorted(ml_llm.rglob("embed.hdf5")):
        alternatives = pl.DataFrame.deserialize(em.parent / "alternatives.pl")
        # join() does not promise to keep the left frame's row order, but the
        # embedding rows written below must line up with alternatives.pl.
        # Tag rows with their original position and sort back afterwards.
        # NOTE(review): this is an inner join, so alternatives without a
        # links.csv row are dropped and the row counts could diverge from
        # alternatives.pl — confirm every alternative is present in links.
        df = (
            alternatives.with_row_index("_row")
            .join(links, left_on="alternative", right_on="movieId")
            .sort("_row")
        )
        keys = list(zip(df["tmdbId"], df["tv"]))
        articles = [fetch_tmdb(tmdb_id, is_tv) for tmdb_id, is_tv in tqdm(keys)]
        embeddings = _embed_texts(model, articles)

        with h5py.File(em, "a") as f:
            if "embed" in f:
                del f["embed"]
            f.create_dataset("embed", data=embeddings, dtype=np.float32)


def _rebuild_nf_llm(model, ds):
    """Recreate ``ds/nf_llm`` from ``ds/nf``, keeping only titles found on TMDB and embedding their TMDB text."""
    nf_llm = ds / "nf_llm"
    shutil.rmtree(nf_llm, ignore_errors=True)
    shutil.copytree(ds / "nf", nf_llm)

    # Each dataset directory is deleted and rewritten inside the loop, so
    # collect all paths up front instead of walking the tree lazily while
    # it is being mutated.
    for em in sorted(nf_llm.rglob("embed.hdf5")):
        ds_dir = em.parent
        print(ds_dir)
        df = pl.DataFrame.deserialize(ds_dir / "alternatives.pl")
        print(len(df))

        # Rows are unpacked positionally: exactly five columns, with the
        # third passed to search() as the year and the fourth as the title.
        ids = [
            search(movie, year)
            for _, _, year, movie, _ in tqdm(df.iter_rows(), total=len(df))
        ]
        # search() returns (0, False) when nothing matched; drop those rows.
        df = df.with_columns(
            tmdb=pl.Series([tmdb_id for tmdb_id, _ in ids]),
            tmdb_tv=pl.Series([is_tv for _, is_tv in ids]),
        ).filter(pl.col.tmdb > 0)
        print(len(df))
        choices = pl.DataFrame.deserialize(ds_dir / "choices.pl").filter(
            pl.col("alternative").is_in(df["alternative"].implode())
        )

        articles = [
            fetch_tmdb(row["tmdb"], row["tmdb_tv"])
            for row in tqdm(df.iter_rows(named=True), total=len(df))
        ]
        embeddings = _embed_texts(model, articles)

        shutil.rmtree(ds_dir)
        ds_dir.mkdir()
        b3.save_dataset(ds_dir, df, choices, embeddings)


if __name__ == "__main__":
    model = vllm.LLM("Qwen/Qwen3-Embedding-0.6B", task="embed")
    ds = Path("datasets")
    data = Path("data")

    # Manual (movieId, tmdbId, tv) overrides applied to links.csv by movieId;
    # tv=True makes fetch_tmdb query the TV endpoint for that id.
    replace = [
        (108727, 258216, False),
        (131724, 61929, True),
        (142115, 13579, True),
        (159817, 1044, True),
        (163809, 61617, True),
        (169906, 66276, True),
        (170355, 1018, False),
        (170705, 4613, True),
        (171011, 68595, True),
        (179135, 74313, True),
        (224264, 106525, True),
        (4207, 427910, False),
        (66934, 14301, False),
        (7669, 1457, True),
        (7842, 19566, True),
        (90647, 417859, False),
        (122926, 315635, False),
    ]

    replace_df = pl.DataFrame(
        {
            "movieId": [movie_id for movie_id, _, _ in replace],
            "tmdbId": [tmdb_id for _, tmdb_id, _ in replace],
            "tv": [is_tv for _, _, is_tv in replace],
        }
    )
    links = (
        pl.read_csv(data / "ml-32m" / "links.csv")
        .with_columns(tv=False)
        .update(replace_df, on="movieId")
    )

    _rebuild_ml_llm(model, ds, links)
    _rebuild_nf_llm(model, ds)
