import argparse
import json
from typing import Literal

from mow.dataset.embodied import TargetModel
from mow.scripts.eval_alfworld import eval_alfworld_for_mow
from mow.scripts.eval_virtualhome import eval_virtualhome_for_mow
from mow.utils.program import Program


class EvalArgs(argparse.Namespace):
    type: str
    model_path: str
    env: str
    dataset_path: list[str] | None
    ip: str
    port: int
    domain: Literal["seen", "unseen"]
    task: Literal["seen", "unseen"]
    output_filename: str | None


class EvalProgram(Program, args=EvalArgs, name="eval", help="Train a model."):
    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        parser.add_argument(
            "type", choices=["expert", "mow"], help="Type of model to train"
        )
        parser.add_argument("model_path", help="Path to the model file")
        parser.add_argument(
            "-e",
            "--env",
            help="Type of environment to use",
            choices=["json", "virtualhome", "alfworld"],
            default="json",
        )
        parser.add_argument(
            "-d",
            "--dataset",
            dest="dataset_path",
            help="Path to the dataset",
            type=str,
            default=None,
            nargs="*",
        )
        parser.add_argument(
            "--ip", help="IP address of the simulator", default="localhost"
        )
        parser.add_argument(
            "-p",
            "--port",
            help="Port number of the simulator",
            type=int,
            default=8080,
        )
        parser.add_argument(
            "--domain",
            choices=["seen", "unseen"],
            help="Domain to use",
            default="seen",
        )
        parser.add_argument(
            "--task",
            choices=["seen", "unseen"],
            help="Task to use",
            default="seen",
        )
        parser.add_argument(
            "-o",
            "--output",
            dest="output_filename",
            help="Path to the output file",
            type=str,
            default=None,
        )

    @staticmethod
    def main(args: EvalArgs):
        match args.type:
            case "expert":
                from transformers import AutoModelForCausalLM, AutoTokenizer

                model = AutoModelForCausalLM.from_pretrained(args.model_path)
                tokenizer = AutoTokenizer.from_pretrained(args.model_path)
                target_model = TargetModel.WORLD
            case "mow":
                from transformers import AutoTokenizer

                from mow.modules.mow import MoW

                model = MoW.from_pretrained(args.model_path)
                tokenizer = AutoTokenizer.from_pretrained(args.model_path)
                target_model = TargetModel.POLICY
            case _:
                raise ValueError(f"Unknown model type: {args.type}")

        match args.env:
            case "json":
                from mow.scripts.eval_json import eval_json

                if args.dataset_path is None:
                    raise ValueError(
                        "Dataset path is required for JSON evaluation"
                    )
                hist = eval_json(
                    model,
                    tokenizer,
                    dataset_path=args.dataset_path[0],
                    target_model=target_model,
                )
            case "virtualhome":
                from mow.modules.mow import MoW

                if not isinstance(model, MoW):
                    raise NotImplementedError(
                        "VirtualHome evaluation is only supported for MoW models"
                    )
                hist = eval_virtualhome_for_mow(
                    model,
                    tokenizer,
                    virtualhome_ip=args.ip,
                    port=args.port,
                    domain_type=args.domain,
                    task_type=args.task,
                )
            case "alfworld":
                from mow.modules.mow import MoW

                if not isinstance(model, MoW):
                    raise NotImplementedError(
                        "ALFRED evaluation is only supported for MoW models"
                    )
                hist = eval_alfworld_for_mow(
                    model,
                    tokenizer,
                    dataset_path=(
                        args.dataset_path
                        if args.dataset_path is not None
                        else []
                    ),
                )
            case _:
                raise ValueError(f"Unknown environment: {args.env}")

        if hist and args.output_filename is not None:
            with open(args.output_filename, "w") as f:
                for h in hist:
                    f.write(json.dumps(h) + "\n")
                    f.flush()
        print("Evaluation complete.")
