import random
from typing import Any, Generator, Literal

import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerBase

from mow.common.tokenizer import init_model_or_tokenizer
from mow.dataset.virtualhome import VirtualHomeDatasetBuilder
from mow.environments.virtualhome.virtualhome.simulation.environment.unity_environment import (
    UnityEnvironment,
)
from mow.modules.mow import MoW
from mow.modules.utils import get_router_entropy
from mow.utils.virtualhome.const import (
    SEEN_DOMAIN,
    SEEN_TASKS,
    TASKS,
    TASKS_SET,
    UNSEEN_DOMAIN,
    UNSEEN_TASKS_HALF,
)
from mow.utils.virtualhome.knowledge_graph import KnowledgeGraph

# NOTE(review): neither constant is referenced by the code visible in this
# module. Presumably FAILURE_TRIAL is a retry budget and AGENT_ID the id of the
# controlled agent in the simulator -- confirm against callers before relying
# on either meaning.
FAILURE_TRIAL = 3
AGENT_ID = 1


def __get_device() -> str:
    """
    Pick the torch backend to run evaluation on.

    Returns:
        ``"cuda"`` when CUDA is available, otherwise ``"mps"`` when the MPS
        backend is available, otherwise ``"cpu"``.
    """
    # Checked in priority order; the first available backend wins.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class MoWAgent:
    """Wrap a ``MoW`` model and its tokenizer to predict one action at a time.

    The constructor moves the model to ``device`` and switches it to eval
    mode. ``predict`` builds a chat prompt from an observation/instruction
    pair, asks the model to generate a completion and returns the text of the
    assistant turn together with the tensors used for that prediction.
    """

    # Marker that opens the assistant turn in the rendered chat text. The
    # prompt handed to the model is cut right after its last occurrence.
    _ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"

    def __init__(
        self,
        model: MoW,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device | str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Store the model/tokenizer pair and prepare the model for inference.

        Args:
            model: Model to evaluate; moved to ``device`` and put in eval mode.
            tokenizer: Tokenizer used to encode the prompt.
            device: Device the model is moved to.
            dtype: Floating-point dtype the graph tensors are cast to in
                ``predict``. The model itself is not cast here.
        """
        self.device = torch.device(device)
        self.dtype = dtype

        self.model = model
        self.model.to(self.device)  # type: ignore
        self.model.eval()

        self.tokenizer = tokenizer

        # NOTE(review): opaque project helper; presumably aligns special
        # tokens between model and tokenizer -- confirm in mow.common.tokenizer.
        init_model_or_tokenizer(model=self.model, tokenizer=self.tokenizer)

    @classmethod
    def _prompt_up_to_assistant_header(cls, chat_text: str) -> str:
        """Return ``chat_text`` cut right after the last assistant header.

        Yields an empty string when the header does not occur at all.
        """
        head, header, _ = chat_text.rpartition(cls._ASSISTANT_HEADER)
        return head + header

    @staticmethod
    def _extract_completion(decoded: str) -> str:
        """Strip chat markers from decoded output, keeping the last turn."""
        tail = decoded.split("assistant")[-1]
        if "<|end_header_id|>" in tail:
            tail = tail.split("<|end_header_id|>")[1]
        if "<|eot_id|>" in tail:
            tail = tail.split("<|eot_id|>")[0]
        return tail.strip()

    def predict(
        self, observation: str, instruction: str
    ) -> tuple[str, dict[str, Any]]:
        """Generate the next action for ``observation`` under ``instruction``.

        Args:
            observation: Textual observation; used both in the chat prompt
                and as input to ``model.obs_to_graph``.
            instruction: Task instruction, used the same way.

        Returns:
            A pair ``(text, status)``. ``text`` is the stripped assistant
            completion. ``status`` holds the graph tensors fed to the model
            (``nodes``, ``adjacency_matrix``, ``relation_matrix``,
            ``context``) and ``routing_scores``, the list handed to
            ``generate`` as ``routing_score_collector``.

        Raises:
            AssertionError: If ``obs_to_graph`` returns no relation matrix.
        """
        chat_row = VirtualHomeDatasetBuilder.convert_to_chat(
            {
                "observation": observation,
                "instruction": instruction,
                "action": "",
                "next_observation": None,
            },
            tokenizer=self.tokenizer,
            action_only=True,
        )
        # Drop whatever the template rendered after the assistant header so
        # that the model continues from an open assistant turn.
        prompt = self._prompt_up_to_assistant_header(chat_row["text"])
        encoded = self.tokenizer(prompt, return_tensors="pt")

        hidden_states, adj_mat, rel_mat, context = self.model.obs_to_graph(
            observation=observation, instruction=instruction
        )
        assert rel_mat is not None, "Relation matrix should not be None"

        # Everything is moved once, to the device the model reports.
        model_device = self.model.device
        hidden_states = hidden_states.to(model_device, dtype=self.dtype)
        adj_mat = adj_mat.to(model_device)
        rel_mat = rel_mat.to(model_device, dtype=self.dtype)
        context = context.to(model_device, dtype=self.dtype)
        encoded = encoded.to(model_device)
        prompt_ids = encoded["input_ids"]

        # Passed to generate() as the collector; presumably filled with one
        # entry per generation step -- confirm in MoW.generate.
        routing_scores: list[dict[str, dict[str, torch.Tensor]]] = []
        output = self.model.generate(
            hidden_states=hidden_states,
            adjacency_matrix=adj_mat,
            relation_matrix=rel_mat,
            context=context,
            input_ids=prompt_ids,
            # Total length budget: the prompt plus 10 more positions.
            max_length=prompt_ids.shape[1] + 10,  # type: ignore
            routing_score_collector=routing_scores,
        )
        # NOTE(review): decoding uses the model's own tokenizer attribute
        # rather than self.tokenizer; presumably the same object -- confirm.
        decoded = self.model.tokenizer.decode(
            output[0], skip_special_tokens=False
        )
        output_str = self._extract_completion(decoded)

        return output_str, {
            "nodes": hidden_states,
            "adjacency_matrix": adj_mat,
            "relation_matrix": rel_mat,
            "context": context,
            "routing_scores": routing_scores,
        }


def __generate_embedding_fns(kg_tokenizer, kg_model):
    """Build a sentence-embedding callable backed by ``kg_model``.

    Args:
        kg_tokenizer: Callable that encodes a batch of sentences into a
            mapping with ``"input_ids"`` and ``"attention_mask"`` tensors.
        kg_model: Model called with those two tensors; its output must expose
            ``pooler_output``.

    Returns:
        A function mapping ``sentences`` to the model's ``pooler_output`` as
        a detached CPU tensor.
    """

    def embedding_fns(sentences):
        encoded = kg_tokenizer(
            sentences, return_tensors="pt", padding=True, truncation=True
        )
        # Follow the model's own device instead of hard-coding CUDA, so the
        # function also works when the model lives on CPU or MPS.
        device = next(kg_model.parameters()).device
        # The result is detached anyway; skip building an autograd graph.
        with torch.no_grad():
            pooled = kg_model(
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
            ).pooler_output
        return pooled.detach().cpu()

    return embedding_fns


def eval_virtualhome_for_mow(
    model: MoW,
    tokenizer: PreTrainedTokenizerBase,
    *,
    virtualhome_ip: str = "localhost",
    port: int = 8080,
    domain_type: Literal["seen", "unseen"] = "seen",
    task_type: Literal["seen", "unseen"] = "seen",
) -> Generator[dict[str, Any], None, None]:
    """Roll out ``model`` in VirtualHome and stream progress events.

    For every selected task (iterated in reverse order of the task list) and
    every selected environment, one episode is run: the knowledge graph is
    extended with the current observation, the agent predicts an action, the
    action is executed and the model's router is refined. The router is
    restored once per task before its environments are visited.

    Args:
        model: Model under evaluation.
        tokenizer: Tokenizer paired with ``model``.
        virtualhome_ip: Host passed to ``UnityEnvironment`` as ``url``.
        port: Passed to ``UnityEnvironment`` as ``base_port``.
        domain_type: ``"unseen"`` selects ``UNSEEN_DOMAIN``; anything else
            selects ``SEEN_DOMAIN``.
        task_type: ``"unseen"`` selects ``UNSEEN_TASKS_HALF``; anything else
            selects ``SEEN_TASKS``.

    Yields:
        Dicts whose ``"message"`` is one of ``"evaluation start"``,
        ``"episode start"``, ``"episode skipped"``, ``"action"``,
        ``"episode done"`` (``"status"`` is ``"success"`` or ``"truncated"``)
        and finally ``"evaluation completed"`` with the aggregate counters.
        The aggregate ratios are reported as ``0.0`` when no episode ran.
    """
    # A private RNG keeps the environment seed random per call without
    # reseeding the process-wide ``random`` state.
    seed = random.Random().choice([111, 333, 555, 888, 999])
    env = UnityEnvironment(url=virtualhome_ip, seed=seed, base_port=port)

    num_success = 0
    num_epi = 0

    target_domains = UNSEEN_DOMAIN if domain_type == "unseen" else SEEN_DOMAIN
    target_tasks = UNSEEN_TASKS_HALF if task_type == "unseen" else SEEN_TASKS

    # Run the sentence encoder on the same device as the agent.
    device = __get_device()
    kg_model_name = "sentence-transformers/paraphrase-MiniLM-L6-v2"
    kg_tokenizer = AutoTokenizer.from_pretrained(kg_model_name)
    kg_model = AutoModel.from_pretrained(kg_model_name).eval().to(device)
    embedding_fns = __generate_embedding_fns(kg_tokenizer, kg_model)

    agent = MoWAgent(model, tokenizer, device=device, dtype=torch.float32)

    yield {
        "message": "evaluation start",
        "domain": domain_type,
        "task": task_type,
    }

    # Episode is cut once the 0-based step index exceeds this value.
    step_limit = 30
    # Number of edges requested from the knowledge graph per observation.
    num_edges = 17

    # Visit tasks in reverse order on a copy: the task list is an imported
    # module-level constant and must not be reordered for other callers.
    ordered_tasks = list(reversed(target_tasks))
    total_timesteps = 0
    with tqdm(ordered_tasks) as tbar:
        for task_id in tbar:
            agent.model.restore_router()
            task_name = TASKS[task_id]
            for env_id in tqdm(target_domains):
                env.set_task(
                    {
                        "required_condition": TASKS_SET[task_name],
                        "prohibited_condition": [],
                    }
                )
                yield {
                    "message": "episode start",
                    "domain": env_id,
                    "task": task_name,
                }
                # A failed reset skips the episode instead of aborting the
                # whole evaluation; skipped episodes are not counted.
                try:
                    obs = env.reset(environment_id=env_id)
                except NotImplementedError:
                    yield {
                        "message": "episode skipped",
                        "reason": "NotImplementedError",
                    }
                    continue
                except Exception as e:
                    yield {
                        "message": "episode skipped",
                        "reason": str(e),
                    }
                    continue

                num_epi += 1
                kg = KnowledgeGraph.from_dict(env.get_position_graph())

                done = False
                step = 0
                # Actions that failed since the last successful one.
                failed_action: list[str] = []

                while not done:
                    kg.extend(
                        obs["agent_graph"],  # type: ignore
                        timestep=step,
                        use_refinement=True,
                    )
                    kg.extend(
                        obs["visible_graph"],  # type: ignore
                        timestep=step,
                        use_refinement=True,
                    )
                    observation = kg.retrieve(
                        task_name,
                        embedding_fns=embedding_fns,
                        num_edges=num_edges,
                    )

                    raw_action, status = agent.predict(
                        ", ".join(observation), task_name
                    )
                    # Keep only the last line of the completion; an empty
                    # prediction is sent to the environment as "nothing".
                    action = raw_action.split("\n")[-1].strip() or "nothing"

                    # Only the last collected routing entry is reported.
                    routing_scores: dict[str, dict[str, torch.Tensor]] = (
                        status["routing_scores"][-1]
                    )
                    routing_scores_py = {
                        key: {
                            k: [float(e) for e in v.cpu()]
                            for k, v in value.items()
                        }
                        for key, value in routing_scores.items()
                    }
                    router_entropy = sum(
                        get_router_entropy(e) for e in routing_scores.values()
                    ) / len(routing_scores)
                    yield {
                        "message": "action",
                        "step": step,
                        "action": action,
                        "router_entropy": router_entropy,
                        "routing_scores": routing_scores_py,
                    }

                    obs, _, done, info = env.step(action)
                    agent.model.refine_router(
                        hidden_states=status["nodes"],
                        adjacency_matrix=status["adjacency_matrix"],
                        relation_matrix=status["relation_matrix"],
                        context=status["context"],
                    )

                    if info["success"]:
                        failed_action = []
                    else:
                        failed_action.append(action)
                    info["failed_action"] = failed_action

                    # NOTE(review): truncation is tested before success, so
                    # an episode that finishes on the very last allowed step
                    # is reported as truncated -- confirm this is intended.
                    if step > step_limit:
                        yield {
                            "message": "episode done",
                            "status": "truncated",
                            "episode": num_epi,
                            "domain": env_id,
                            "task": task_name,
                            "pending_steps": step - 1,
                        }
                        break

                    if done:
                        num_success += 1
                        yield {
                            "message": "episode done",
                            "status": "success",
                            "episode": num_epi,
                            "domain": env_id,
                            "task": task_name,
                            "pending_steps": step - 1,
                        }
                    step += 1

                # NOTE(review): on success ``step`` was already incremented
                # here, so the total grows by one more than the per-episode
                # "pending_steps" reported above; on truncation both agree.
                # Kept as in the original accounting -- confirm the metric.
                total_timesteps += step - 1
                tbar.set_description(
                    f"{num_success}/{num_epi} ({num_success/num_epi*100}%), PENDING_STEPS: {total_timesteps / num_epi}"
                )

    # Every episode may have been skipped; avoid dividing by zero then.
    success_rate = num_success / num_epi * 100 if num_epi else 0.0
    mean_pending_steps = total_timesteps / num_epi if num_epi else 0.0
    yield {
        "message": "evaluation completed",
        "domain": domain_type,
        "task": task_type,
        "seed": seed,
        "total": num_epi,
        "success": num_success,
        "success%": success_rate,
        "pending_steps": mean_pending_steps,
    }
