from tqdm import tqdm
from transformers import GenerationMixin, PreTrainedTokenizerBase

from mow.common.trainer import init_model_or_tokenizer
from mow.dataset import AutoChatDatasetBuilder
from mow.dataset.embodied import TargetModel
from mow.modules.mow import MoW


def get_device():
    """
    Pick the torch device name to run evaluation on.

    Returns "cuda" when CUDA is available, otherwise "mps" when the MPS
    backend is available, otherwise "cpu".
    """
    import torch

    # CUDA takes priority; MPS is only probed when CUDA is unavailable.
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"


def eval_json(
    model: GenerationMixin,
    tokenizer: PreTrainedTokenizerBase,
    *,
    dataset_path: str,
    target_model: TargetModel,
) -> list[dict]:
    """
    Evaluate a model on a JSON dataset using exact-match accuracy.

    The dataset at ``dataset_path`` is loaded as chat text and restricted to
    examples whose ``next_observation`` is ``None``. For each example the text
    is cut at the last assistant header: everything up to and including the
    header is the prompt, the remainder (up to ``<|eot_id|>``) is the reference
    answer. The model generates a continuation of the prompt and its reply is
    compared to the reference after stripping whitespace.

    Args:
        model: Model to evaluate; it is moved to the device chosen by
            ``get_device``. ``MoW`` models additionally receive graph inputs
            built by ``model.obs_to_graph``.
        tokenizer: Tokenizer used to build the chat text, encode prompts and
            decode generations.
        dataset_path: Path passed to ``AutoChatDatasetBuilder.load``.
        target_model: Not read by this function; kept in the signature for
            callers.

    Returns:
        One dict per example with keys ``input``, ``output``, ``answer`` and
        ``correct``, followed by a final ``{"accuracy": float}`` entry. The
        accuracy is ``0.0`` when no example was evaluated.
    """
    device = get_device()
    model.to(device)  # type: ignore

    dataset = AutoChatDatasetBuilder.load(dataset_path=dataset_path).as_chat(
        tokenizer=tokenizer
    )
    # Keep only the examples that have no following observation.
    dataset = dataset.filter(lambda x: x["next_observation"] is None).unwrap()

    init_model_or_tokenizer(model=model, tokenizer=tokenizer)

    total = 0
    correct = 0
    history: list[dict] = []
    assistant_token = "<|start_header_id|>assistant<|end_header_id|>"
    eot_token = "<|eot_id|>"
    with tqdm(
        dataset,
        desc="Evaluating",
        total=len(dataset),
    ) as bar:
        for example in bar:
            full_text: str = example["text"]  # type: ignore
            # Split at the LAST assistant header: the prompt keeps the header
            # so the model continues from the assistant turn, and the tail is
            # the reference answer.
            prompt, header, answer = full_text.rpartition(assistant_token)
            input_text = prompt + header
            inputs = tokenizer(input_text, return_tensors="pt")
            inputs = inputs.to(device)
            if isinstance(model, MoW):
                h, adj, rel, ctx = model.obs_to_graph(
                    observation=example["observation"],  # type: ignore
                    instruction=example["instruction"],  # type: ignore
                )
                # Tensor.to() returns a new tensor instead of moving in place,
                # so the results must be rebound or they stay on the old device.
                h = h.to(device)
                adj = adj.to(device)
                rel = rel.to(device) if rel is not None else None
                ctx = ctx.to(device)
                inputs["hidden_states"] = h
                inputs["adjacency_matrix"] = adj
                inputs["relation_matrix"] = rel
                inputs["context"] = ctx

            # Allow at most 10 tokens beyond the prompt length.
            outputs = model.generate(**inputs, max_length=inputs["input_ids"].shape[1] + 10)  # type: ignore
            # Special tokens are kept so the assistant header and end-of-turn
            # marker can be used to isolate the generated reply.
            output_text = tokenizer.decode(
                outputs[0], skip_special_tokens=False
            )
            output_text = output_text.split(assistant_token)[-1]
            output_text = output_text.split(eot_token)[0].strip()

            answer = answer.split(eot_token)[0].strip()

            print("Output Text:", output_text)
            print("Answer Text:", answer)

            is_correct = answer == output_text
            if is_correct:
                correct += 1
            total += 1

            bar.set_postfix({"accuracy": f"{correct / total:.2%}"})
            history.append(
                {
                    "input": input_text,
                    "output": output_text,
                    "answer": answer,
                    "correct": is_correct,
                }
            )

    # Guard against an empty (fully filtered-out) dataset.
    accuracy = correct / total if total else 0.0
    print(f"Accuracy: {accuracy:.2%}")
    history.append(
        {
            "accuracy": accuracy,
        }
    )
    return history
