from typing import Any, cast

import torch
from sentence_transformers import SentenceTransformer

from mow.common.graph import map_observation_to_graph
from mow.dataset import AutoChatDatasetBuilder
from mow.modules.mow import MoW
from mow.scripts.train_mow import TrainMoWConfig


def test_mow(config: TrainMoWConfig):
    """Load the MoW model and run the interactive test loop.

    Requires ``config.train_config.output_dir``. If ``<output_dir>/best``
    exists, MoW is loaded from that checkpoint; otherwise a freshly
    initialized ``MoW(config.mow)`` is used and a warning is printed.

    When ``config.datasets`` is empty, this delegates to
    ``test_mow_without_dataset`` before loading any model. That path only
    receives the config, so loading the models first would be wasted work.
    """
    assert config.train_config.output_dir, "Output directory must be provided."

    if not config.datasets:
        test_mow_without_dataset(config)
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"

    sentence_transformer = SentenceTransformer(
        config.sentence_transformer_model
    )
    sentence_transformer.to(device)
    sentence_transformer.eval()
    print(
        f"🌟 Loaded SentenceTransformer model: {config.sentence_transformer_model}"
    )

    if (path := config.train_config.output_dir / "best").exists():
        mow = cast(MoW, MoW.from_pretrained(path))
        print(f"🌟 Loaded MoW model from: {path}")
    else:
        mow = MoW(config.mow)
        # Make it obvious that predictions come from untrained weights.
        print(
            f"⚠️ No checkpoint found at {path}; "
            "using a freshly initialized MoW model."
        )
    mow.to(device)  # type: ignore
    # Put MoW in inference mode for generation, as done for the sentence
    # transformer above; a freshly constructed model starts in train mode.
    mow.eval()

    test_mow_with_dataset(
        mow=mow,
        sentence_transformer=sentence_transformer,
        config=config,
    )


def _split_prompt_and_answer(text: str) -> tuple[str, str]:
    """Split a rendered chat example into (generation prompt, expected answer).

    The prompt is everything up to and including the last
    ``assistant<|end_header_id|>`` marker, followed by ``"\\n\\n"``. The answer
    is the text after that marker, cut at the first ``<|eot_id|>`` and
    stripped. If the marker is absent, the prompt is just the marker plus
    ``"\\n\\n"`` and the whole text is treated as the answer.
    """
    head, _, answer = text.rpartition("assistant<|end_header_id|>")
    prompt = head + "assistant<|end_header_id|>\n\n"
    return prompt, answer.split("<|eot_id|>")[0].strip()


def _extract_prediction(decoded: str) -> str:
    """Return the final assistant turn from a decoded generation.

    Splits on the full ``assistant<|end_header_id|>`` header rather than the
    bare word "assistant", so generated text that mentions "assistant" is not
    truncated. The result is cut at the first ``<|eot_id|>`` and stripped.
    """
    _, _, pred = decoded.rpartition("assistant<|end_header_id|>")
    return pred.split("<|eot_id|>")[0].strip()


def test_mow_with_dataset(
    *,
    mow: MoW,
    sentence_transformer: SentenceTransformer,
    config: TrainMoWConfig,
):
    """Interactively run MoW generation over a dataset's ``test`` split.

    For each example, prints the prompt, the expected assistant answer, and
    the model's prediction, then waits for Enter. Ctrl-C, or end-of-input on
    stdin, stops the loop.

    Args:
        mow: Model used for generation; also supplies the tokenizer and device.
        sentence_transformer: Encoder passed to the dataset's graph
            preparation step.
        config: Training config. Only the first entry of ``config.datasets``
            is evaluated.
    """
    assert config.datasets, "Dataset path must be provided."

    dataset_path = next(iter(config.datasets.values())) / "test"
    dataset = (
        AutoChatDatasetBuilder.load(dataset_path=dataset_path)
        .as_chat(tokenizer=mow.tokenizer, action_only=True)
        .prepare_graph_representation(sentence_transformer=sentence_transformer)
        .unwrap(type="pt", output_all_columns=True)
    )

    device = mow.device
    for example in dataset:
        # Mixed-type row: "text" is a string, the graph columns are tensors.
        example = cast(dict[str, Any], example)

        prompt, answer = _split_prompt_and_answer(example["text"])
        print(f"Input: {prompt}")
        print(f"Expected Output: {answer}")

        # NOTE(review): if the rendered text already contains the tokenizer's
        # BOS token, the default add_special_tokens=True may add a second
        # one -- verify against the chat template.
        input_ids = mow.tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(device)

        # Inference only; avoid building an autograd graph during generation.
        with torch.no_grad():
            output = mow.generate(
                hidden_states=example["nodes"].to(device),
                adjacency_matrix=example["adjacency_matrix"].to(device),
                relation_matrix=example["relation_matrix"].to(device),
                context=example["context"].to(device),
                input_ids=input_ids,
                max_length=input_ids.shape[1] + 10,
            )
        decoded = mow.tokenizer.decode(output[0], skip_special_tokens=False)
        pred = _extract_prediction(decoded)
        print(f"Output: {pred}")

        try:
            input("\n👉 Press Enter to process the next example... 🚀")
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin closed or non-interactive; stop cleanly too.
            print("\n😢 Test interrupted by user.")
            break


def test_mow_without_dataset(config: TrainMoWConfig):
    """Placeholder for testing MoW without a dataset; not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "Testing MoW without a dataset is not implemented yet; "
        "set `datasets` in the config."
    )
