import itertools
import math
from dataclasses import dataclass

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from datasets import concatenate_datasets
from mow.common.trainer import init_model_or_tokenizer
from mow.dataset import AlfworldDatasetBuilder
from mow.modules.mow import MoW


@dataclass
class EvalResult:
    """
    Running tally of scored predictions for one step bucket.

    Both counters default to 0, so an empty tally can be created with
    ``EvalResult()``; positional construction ``EvalResult(s, t)`` keeps
    working as before.
    """

    success: int = 0  # number of predictions counted as correct
    total: int = 0  # number of predictions scored in this bucket

    @property
    def rate(self) -> float:
        """Return ``success / total``, or 0.0 when nothing has been scored."""
        return self.success / self.total if self.total > 0 else 0.0


def get_device():
    """
    Choose the torch device string used for evaluation.

    Preference order: ``"cuda"`` if a CUDA device is available, otherwise
    ``"mps"`` if the Apple MPS backend is available, otherwise ``"cpu"``.
    """
    # ``torch`` is imported at module level, so no local import is needed.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _calc_result(
    result: dict[int, EvalResult],
) -> tuple[float, float]:
    """
    Aggregate per-step evaluation counts into summary metrics.

    Args:
        result: Mapping from step index to that step's ``EvalResult`` tally.
            Buckets are processed in ascending step-index order (not dict
            insertion order), so the step weighting below is stable.

    Returns:
        ``(sr, ps)``:

        * ``sr`` -- unweighted mean of the per-step success rates; a step
          with ``total == 0`` contributes a rate of 0.
        * ``ps`` ("pending steps") -- ``1 / rate`` is accumulated over the
          steps in order (a zero rate contributes 100), each running total is
          capped at 100, and the capped totals are combined with linearly
          decaying weights ``n/n, (n-1)/n, ..., 1/n`` and divided by ``n``,
          where ``n`` is the number of buckets. For the 30-bucket dict built
          by ``eval_alfworld_for_mow`` this matches the previous hard-coded
          30-step formula exactly.

        An empty ``result`` yields ``(0.0, 0.0)``.
    """
    sr_lst = [
        res.success / res.total if res.total > 0 else 0
        for _, res in sorted(result.items())
    ]
    n = len(sr_lst)
    if n == 0:
        # Previously ``math.sumprod`` raised ValueError here because the
        # weight sequence always had 30 elements.
        return 0.0, 0.0

    sr = sum(sr_lst) / n

    ps_lst = (1 / i if i > 0 else 100 for i in sr_lst)
    ps_cum = (min(i, 100) for i in itertools.accumulate(ps_lst))
    # Weights derived from ``n`` so both sumprod inputs always have the
    # same length (sumprod rejects unequal lengths).
    weights = (i / n for i in range(n, 0, -1))
    ps = math.sumprod(ps_cum, weights) / n

    return sr, ps


def eval_alfworld_for_mow(
    model: MoW,
    tokenizer: PreTrainedTokenizerBase,
    *,
    dataset_path: list[str],
):
    """
    Evaluate a MoW model step-by-step on one or more ALFWorld datasets.

    This is a generator. For every scored example it yields a progress record
    ``{"message": "action", "step", "action", "answer"}``. After the whole
    dataset has been consumed, it yields one summary record
    ``{"message": "result", "success_rate", "pending_steps"}`` computed by
    ``_calc_result``.

    Args:
        model: The MoW model to evaluate; it is moved to ``get_device()``.
        tokenizer: Tokenizer used to build chat prompts and decode outputs.
            Prompt/answer splitting relies on the literal markers
            ``<|start_header_id|>assistant<|end_header_id|>`` and
            ``<|eot_id|>`` (Llama-3 chat-template tokens) being present in the
            dataset's ``"text"`` column.
        dataset_path: Paths passed to ``AlfworldDatasetBuilder.load``; the
            resulting datasets are concatenated in order.

    Yields:
        dict: one progress record per scored example, then a summary record.
    """
    # Steps are scored in buckets 1..max_steps; episodes outside this range
    # are skipped.
    max_steps = 30
    assistant_token = "<|start_header_id|>assistant<|end_header_id|>"
    eot_token = "<|eot_id|>"

    device = get_device()
    model.to(device)  # type: ignore

    dataset = concatenate_datasets(
        [
            AlfworldDatasetBuilder.load(path)
            .expand()
            .as_chat(tokenizer=tokenizer, action_only=True)
            .prepare_graph_representation(
                sentence_transformer=model.sentence_transformer
            )
            .unwrap(type="pt", output_all_columns=True)
            for path in dataset_path
        ]
    )

    init_model_or_tokenizer(model=model, tokenizer=tokenizer)

    result: dict[int, EvalResult] = {
        i: EvalResult(0, 0) for i in range(1, max_steps + 1)
    }
    with tqdm(
        dataset,
        desc="Evaluating",
        total=len(dataset),
    ) as bar:
        for example in bar:
            length = len(example["history"])  # type: ignore
            # Only lengths with a bucket in ``result`` can be scored. An empty
            # history used to crash the whole run with KeyError on result[0].
            if not 1 <= length <= max_steps:
                continue

            input_text: str = example["text"]  # type: ignore
            # The last assistant turn is the gold action; everything before it
            # forms the prompt.
            *input_text_split, answer = input_text.split(assistant_token)
            answer = answer.split(eot_token)[0].strip()

            if length == 1:
                # First step: reset the router and emit the fixed opening
                # action without calling the model.
                model.restore_router()
                output_text = "look"
            else:
                # Re-append an empty assistant header so generation starts
                # a new assistant turn. Tokenization is done only here because
                # step 1 never uses it.
                prompt = assistant_token.join(input_text_split + [""])
                inputs = tokenizer(prompt, return_tensors="pt")

                inputs.to(device)
                hidden_states = example["nodes"].to(device)  # type: ignore
                adjacency_matrix = example["adjacency_matrix"].to(device)  # type: ignore
                relation_matrix = example["relation_matrix"].to(device)  # type: ignore
                context = example["context"].to(device)  # type: ignore

                # Filled in by ``model.generate``; not consumed further here.
                routing_scores: list[dict[str, dict[str, torch.Tensor]]] = []
                outputs = model.generate(
                    hidden_states=hidden_states,
                    adjacency_matrix=adjacency_matrix,
                    relation_matrix=relation_matrix,
                    context=context,
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=inputs["input_ids"].shape[1] + 10,  # type: ignore
                    routing_score_collector=routing_scores,
                )
                # The decoded sequence includes the prompt, so keep only the
                # text after the final assistant header, up to end-of-turn.
                output_text = tokenizer.decode(
                    outputs[0], skip_special_tokens=False
                )
                output_text = output_text.split(assistant_token)[-1]
                output_text = output_text.split(eot_token)[0]
                output_text = output_text.strip()

                model.refine_router(
                    hidden_states=hidden_states,
                    adjacency_matrix=adjacency_matrix,
                    relation_matrix=relation_matrix,
                    context=context,
                )

            # Exact string match against the gold action.
            if answer == output_text:
                result[length].success += 1
            result[length].total += 1

            sr, ps = _calc_result(result)
            bar.set_postfix({"SR": f"{sr*100:.2f}%", "PS": f"{ps:.2f}"})
            yield {
                "message": "action",
                "step": length,
                "action": output_text,
                "answer": answer,
            }

    sr, ps = _calc_result(result)
    yield {
        "message": "result",
        "success_rate": sr,
        "pending_steps": ps,
    }
