import json
import os
from collections import defaultdict
from typing import Any, Literal, override

from transformers import (
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from mow.dataset import AutoChatDatasetBuilder, ChatDatasetBuilder
from mow.utils.virtualhome.knowledge_graph import KnowledgeGraph

# Training phases accepted by VirtualHomeDatasetGenerator.
_Phase = Literal["behavior-cloning", "belief-state-inference"]

# Shared system-prompt skeleton; ``{instruction}`` is the slot that the
# task-specific prompts below fill in.
BASE_PROMPT = (
    "You are a home robot agent. "
    "You can use 6 skills, (walk [object or room], grab [object], "
    "switch [object], open [object], putin [target object], "
    "put [target object]). "
    "{instruction} "
    "Room: livingroom, bathroom, kitchen, bedroom."
)

# Prompt variant that asks for a skill only.
GENERATE_ACTION_PROMPT = BASE_PROMPT.format(
    instruction="You should return only a skill after 'Action:'.",
)
# Prompt variant that asks for an observation only.
GENERATE_OBSERVATION_PROMPT = BASE_PROMPT.format(
    instruction="You should return only an observation after 'Observation:'.",
)

# House (environment) ids used when the caller does not restrict them.
VALID_HOUSE = [
    0, 1, 5, 6, 7, 8, 9, 12, 13, 15,
    18, 20, 22, 24, 26, 28, 29, 31, 32, 34,
]
# Task ids used when the caller does not restrict them: 0..77 inclusive.
VALID_TASK = list(range(78))
ALLOWED_PHASES: set[str] = {"behavior-cloning", "belief-state-inference"}

# Module-level caches for the sentence-embedding tokenizer and model.
# Both start as None and are filled in by embedding_fns() on its first call,
# so the pretrained weights are loaded at most once per process.
kg_tokenizer: PreTrainedTokenizer | None = None
kg_model: PreTrainedModel | None = None


def prompt_template(knowledge_graph, task):
    """Render an observation/instruction pair as a single prompt string.

    Args:
        knowledge_graph: Value interpolated after ``"Observation: "``.
        task: Value interpolated after ``"Instruction: "``.

    Returns:
        ``"Observation: <knowledge_graph>, Instruction: <task>"``.
    """
    return f"Observation: {knowledge_graph}, Instruction: {task}"


def embedding_fns(sentences):
    """Embed ``sentences`` with the MiniLM paraphrase model.

    The tokenizer and model are loaded on first use and cached in the
    module-level ``kg_tokenizer`` / ``kg_model`` globals. The model is put in
    eval mode and moved to CUDA, so a CUDA device is required.

    Args:
        sentences: Text passed straight to the tokenizer (presumably a string
            or a list of strings -- confirm against the callers).

    Returns:
        The model's ``pooler_output`` for the batch, moved to the CPU. The
        forward pass runs without gradient tracking, so the result carries no
        autograd graph.
    """
    global kg_tokenizer, kg_model

    # Local import: torch is only needed here and is not imported at the top
    # of this module.
    import torch

    model_name = "sentence-transformers/paraphrase-MiniLM-L6-v2"

    if kg_tokenizer is None:
        kg_tokenizer = AutoTokenizer.from_pretrained(model_name)
    if kg_model is None:
        kg_model = AutoModel.from_pretrained(model_name).eval().cuda()

    encoded = kg_tokenizer(
        sentences, return_tensors="pt", padding=True, truncation=True
    )
    # Inference only: without no_grad every call would build (and keep alive
    # through the returned tensor) an autograd graph for the whole encoder.
    with torch.no_grad():
        pooled = kg_model(
            input_ids=encoded["input_ids"].cuda(),  # type: ignore
            attention_mask=encoded["attention_mask"].cuda(),  # type: ignore
        ).pooler_output
    return pooled.cpu()


class VirtualHomeDatasetBuilder(ChatDatasetBuilder):
    """Chat dataset builder that turns VirtualHome samples into chat turns."""

    @override
    @classmethod
    def _convert_to_chat(
        cls, example, /, *, action_only: bool = False, **_
    ) -> list[dict[str, str]]:
        """Convert one sample into a list of chat messages.

        Args:
            example: Mapping with ``"instruction"``, ``"observation"`` and
                ``"action"`` entries, plus ``"next_observation"`` when
                ``action_only`` is false.
            action_only: When true, stop after the assistant's action turn.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            System, user and assistant messages; when ``action_only`` is
            false, followed by a ``"Next observation: "`` user turn and the
            assistant's next-observation turn.

        Raises:
            ValueError: If ``action_only`` is false and
                ``example["next_observation"]`` is ``None``.
        """
        user_prompt = (
            f"Instruction: {example['instruction']}\n\n"
            f"Observation: {example['observation']}\n\n"
            f"Action: "
        )
        # NOTE(review): BASE_PROMPT still contains the literal "{instruction}"
        # placeholder at this point; confirm whether one of the formatted
        # GENERATE_*_PROMPT variants was intended as the system message.
        messages = [
            {"role": "system", "content": BASE_PROMPT},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": example["action"]},
        ]
        if action_only:
            return messages

        if example["next_observation"] is None:
            raise ValueError(
                "next_observation is None, but action_only is False"
            )
        messages.extend(
            [
                {"role": "user", "content": "Next observation: "},
                {"role": "assistant", "content": example["next_observation"]},
            ]
        )
        return messages


# Register the builder with AutoChatDatasetBuilder under the "virtualhome" key
# (presumably so it can be looked up by that name -- confirm in mow.dataset).
AutoChatDatasetBuilder.register("virtualhome", VirtualHomeDatasetBuilder)


class VirtualHomeDatasetGenerator:
    def __init__(
        self,
        dataset_path: str,
        house_ids: list[int] | None = None,
        task_ids: list[int] | None = None,
        phase: _Phase = "behavior-cloning",
    ) -> None:
        if phase not in ALLOWED_PHASES:
            raise ValueError(
                f"Invalid phase provided. Allowed phases are: {ALLOWED_PHASES}"
            )
        if house_ids is None:
            house_ids = VALID_HOUSE
        if task_ids is None:
            task_ids = VALID_TASK

        self.house_ids = house_ids
        self.task_ids = task_ids
        self.dataset_path = os.path.join(dataset_path, "trajectories")
        self.data = self._load_data()

    def reset_task_ids(self):
        self.task_ids = VALID_TASK

    def reset_house_ids(self):
        self.house_ids = VALID_HOUSE

    def _load_data(self) -> defaultdict[Any, list]:
        data = defaultdict(list)
        # Traverse task directories
        for room_dir in os.listdir(self.dataset_path):
            for episode_file in os.listdir(
                os.path.join(self.dataset_path, room_dir)
            ):
                if episode_file.endswith(".jsonl"):
                    episode_path = os.path.join(
                        self.dataset_path, room_dir, episode_file
                    )
                    with open(episode_path, "r") as f:
                        trajectory = []
                        for line in f:
                            entry = json.loads(line)
                            trajectory.append(entry)
                        data[
                            (trajectory[0]["env_id"], trajectory[0]["task_id"])
                        ].append(trajectory)
        return data

    def __generate_history(self, traj: list[dict[str, Any]]):
        kg = KnowledgeGraph.from_dict(traj[0]["position_graph"])
        for step in range(len(traj)):
            kg.extend(
                traj[step]["visible_graph"],
                timestep=step,
                use_refinement=True,
            )
            kg.extend(
                traj[step]["agent_graph"],
                timestep=step,
                use_refinement=True,
            )

            next_kg = kg.clone()
            next_kg.extend(
                traj[step]["next_agent_graph"],
                timestep=step + 1,
                use_refinement=True,
            )
            next_kg.extend(
                traj[step]["next_visible_graph"],
                timestep=step + 1,
                use_refinement=True,
            )

            inst: str = traj[step]["instruction"]
            act: str = traj[step]["action"]

            obs = kg.retrieve(inst, embedding_fns=embedding_fns, num_edges=17)

            next_obs = next_kg.retrieve(
                inst, embedding_fns=embedding_fns, num_edges=17
            )

            obs_triples = [
                tuple(triple.lower().strip("()").split(", ")) for triple in obs
            ]
            next_obs_triples = [
                tuple(triple.lower().strip("()").split(", "))
                for triple in next_obs
            ]

            for from_t, rel, to_t in obs_triples:
                if rel == "hold":
                    hold_object = to_t
                    break
            else:
                hold_object = "character"

            updated_triples: list[str] = []
            for from_t, rel, to_t in obs_triples:
                for (
                    from_tt,
                    rel_tt,
                    to_tt,
                ) in next_obs_triples:
                    if (
                        from_t == from_tt
                        and rel == rel_tt
                        and to_t != to_tt
                        and f"({from_tt}, {rel_tt}, {to_tt})" not in obs
                    ):
                        if (
                            (from_tt == "character" and to_tt in act)
                            or from_tt in act
                            or (from_tt == hold_object and to_tt in act)
                        ):
                            updated_triples.append(
                                f"({from_tt}, {rel_tt}, {to_tt})"
                            )
                            break

            for from_tt, rel_tt, to_tt in next_obs_triples:
                if (
                    f"({from_tt}, {rel_tt}, {to_tt})" not in updated_triples
                    and f"({from_tt}, {rel_tt}, {to_tt})" not in obs
                ):
                    if (from_tt == "character" and to_tt in act) or (
                        from_tt == hold_object and to_tt in act
                    ):
                        updated_triples.append(
                            f"({from_tt}, {rel_tt}, {to_tt})"
                        )

            if not updated_triples:
                if act.split()[0] == "put":
                    updated_triples.append(
                        f"({hold_object}, on, {act.split()[1]}"
                    )
                elif act.split()[0] == "putin":
                    updated_triples.append(
                        f"({hold_object}, inside, {act.split()[1]}"
                    )
                else:
                    updated_triples.append("No updates")

            next_obs = ", ".join(updated_triples).lower()
            obs = ", ".join(obs).lower()
            yield obs, act, next_obs

    def __generate_data_for_trajectory(self, traj: list[dict[str, Any]]):
        inst = traj[0]["instruction"]
        return {
            "instruction": inst,
            "history": [
                {
                    "observation": obs,
                    "action": act,
                    "next_observation": next_obs,
                }
                for obs, act, next_obs in self.__generate_history(traj)
            ],
        }

    def generate_dataset(self, num_augmented: int = 50):
        for env_pair in [
            (env_id, task_id)
            for task_id in self.task_ids
            for env_id in self.house_ids
        ]:
            traj_list = self.data[env_pair]
            for traj_idx in range(len(traj_list)):
                traj = traj_list[traj_idx]
                for _ in range(num_augmented):
                    yield self.__generate_data_for_trajectory(traj)
