from typing import cast

from peft import PeftMixedModel, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerBase,
)

from mow.dataset import AutoChatDatasetBuilder
from mow.scripts.train_expert import TrainExpertConfig
from mow.scripts.train_mow import TrainMoWConfig


def test_expert(config: TrainExpertConfig):
    """Load a trained expert checkpoint and run the interactive dataset check.

    The model is read from ``config.train_config.output_dir`` (with
    ``config.model_config`` forwarded as the ``config`` argument) and the
    tokenizer from ``config.main_model_path``. Both are then handed to
    ``test_expert_with_dataset`` together with ``config``.

    Args:
        config: Expert training configuration that supplies the checkpoint
            directory, the model config and the tokenizer path.
    """
    checkpoint_dir = config.train_config.output_dir
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        config=config.model_config,
    )
    print(f"🌟 Loaded MoW model from: {checkpoint_dir}")

    # Resolved only after the model has loaded, mirroring the load order.
    tokenizer_path = config.main_model_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    print(f"🌟 Loaded tokenizer from: {tokenizer_path}")

    test_expert_with_dataset(model=model, tokenizer=tokenizer, config=config)


# Chat-template markers (Llama-3 style) used to separate the prompt from the
# reference answer inside the rendered example text.
_ASSISTANT_HEADER = "assistant<|end_header_id|>"
_END_OF_TURN = "<|eot_id|>"
# Upper bound on the number of tokens generated for each example.
_MAX_NEW_TOKENS = 10


def _split_prompt_and_answer(text: str) -> tuple[str, str]:
    """Split rendered chat text at its last assistant header.

    Returns:
        ``(prompt, answer)`` where ``prompt`` is everything before the last
        assistant header followed by that header and a blank line, and
        ``answer`` is the text after the header up to the first end-of-turn
        marker, stripped of surrounding whitespace. When the header is
        absent, the prompt is just the header and the whole text is treated
        as the answer.
    """
    head, _, tail = text.rpartition(_ASSISTANT_HEADER)
    prompt = f"{head}{_ASSISTANT_HEADER}\n\n"
    answer = tail.split(_END_OF_TURN)[0].strip()
    return prompt, answer


def test_expert_with_dataset(
    *,
    model: PeftModel | PeftMixedModel,
    tokenizer: PreTrainedTokenizerBase,
    config: TrainExpertConfig,
):
    """Interactively compare model generations with reference answers.

    Loads ``config.eval_dataset``, renders it as chat text with ``tokenizer``
    and, for every example, prints the prompt, the expected answer and the
    model's short completion. After each example the user is asked to press
    Enter; Ctrl-C or a closed stdin ends the loop.

    Args:
        model: Model used for generation; inputs are moved to ``model.device``.
        tokenizer: Tokenizer used to render, encode and decode the chat text.
        config: Expert configuration providing ``eval_dataset``.
    """
    dataset = AutoChatDatasetBuilder.load(config.eval_dataset)
    dataset = dataset.as_chat(tokenizer=tokenizer).unwrap()

    for example in dataset:
        example = cast(dict[str, str], example)
        prompt, answer = _split_prompt_and_answer(example["text"])

        print(f"Input: {prompt}")
        print(f"Expected Output: {answer}")

        # NOTE(review): if the rendered text already starts with a BOS token,
        # the tokenizer may prepend a second one here -- confirm against the
        # chat template and consider add_special_tokens=False.
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded.input_ids.to(model.device)
        # Forward the attention mask when the tokenizer provides one so that
        # generate() does not have to guess it from padding tokens.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(model.device)

        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=_MAX_NEW_TOKENS,
        )

        # A causal LM returns the prompt followed by the completion. Decoding
        # only the new token ids avoids re-parsing the prompt out of the text,
        # which broke whenever the completion contained the word "assistant".
        prompt_length = input_ids.shape[1]
        pred = tokenizer.decode(
            output[0][prompt_length:], skip_special_tokens=True
        )
        # Defensive cut in case the end-of-turn marker survives decoding.
        pred = pred.split(_END_OF_TURN)[0].strip()
        print(f"Output: {pred}")

        try:
            input("\n👉 Press Enter to process the next example... 🚀")
        except KeyboardInterrupt:
            print("\n😢 Test interrupted by user.")
            break
        except EOFError:
            # stdin is closed or non-interactive: stop instead of crashing.
            print("\n😢 No more input available; stopping.")
            break


def test_mow_without_dataset(config: TrainMoWConfig):
    """Placeholder for a dataset-free MoW check; not implemented yet.

    Args:
        config: MoW training configuration (currently unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
