from __future__ import annotations

import json
import random
from pathlib import Path
from typing import Iterable, List, Optional

from .dataset import MaskPruningDataset, Sample


def _extract_text(content) -> str:
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    return str(content)


def _ensure_system_message(messages: List[dict], default_system: str) -> List[dict]:
    if not messages:
        return [{"role": "system", "content": default_system}]
    if messages[0]["role"] != "system":
        return [{"role": "system", "content": default_system}] + messages
    return messages


def _encode_chatml_conversation(
    tokenizer,
    conversation: List[dict],
    cutoff_len: int,
    default_system_prompt: str,
    prompt_text_mode: str = "system_user",
    llm_prompt_mode: str = "system_user",
    append_eos: bool = False,
) -> Optional[Sample]:
    """Encode one chat conversation into a supervised ``Sample``.

    The final message must be a non-blank assistant turn; its tokens become
    the response. The preceding messages form the prompt, which is rendered
    with ``tokenizer.apply_chat_template`` (generation prompt appended).
    Labels are ``-100`` on every prompt position and equal to the token ids
    on the response positions.

    Args:
        tokenizer: Object providing ``apply_chat_template``, ``encode``,
            ``eos_token_id``, ``eos_token`` and ``convert_tokens_to_ids``.
        conversation: Messages as dicts with ``"role"`` and ``"content"``.
        cutoff_len: Maximum allowed prompt + response length in tokens; it is
            also passed to the chat template as ``max_length``.
        default_system_prompt: System prompt inserted when the conversation
            does not start with a system message.
        prompt_text_mode: Which text is stored in ``Sample.prompt_text``:
            ``"system_user"``, ``"system_only"`` or ``"user_only"``.
        llm_prompt_mode: Which messages are fed to the chat template:
            ``"system_user"`` (all prompt messages), ``"system_only"`` or
            ``"user_only"`` (last user message only).
        append_eos: When true, append the tokenizer's EOS id to the response
            unless the response already ends with it.

    Returns:
        The encoded ``Sample`` (with an empty ``task``), or ``None`` when the
        conversation is empty, does not end with a non-blank assistant
        message, has no user message in ``"user_only"`` LLM mode, tokenizes
        to an empty response, or exceeds ``cutoff_len``.

    Raises:
        ValueError: If ``llm_prompt_mode`` or ``prompt_text_mode`` is not one
            of the supported values.
    """
    from collections.abc import Mapping

    if not conversation:
        return None
    assistant_msg = conversation[-1]
    if assistant_msg.get("role") != "assistant":
        return None

    response_text = _extract_text(assistant_msg.get("content"))
    if not response_text.strip():
        return None

    # Normalise every prompt message to plain-text content, then drop blank
    # turns; a blank system message is deliberately kept.
    prompt_messages_raw = [
        {"role": msg["role"], "content": _extract_text(msg.get("content"))}
        for msg in conversation[:-1]
    ]
    prompt_messages_raw = [
        msg for msg in prompt_messages_raw if msg["content"].strip() or msg["role"] == "system"
    ]
    prompt_messages_raw = _ensure_system_message(prompt_messages_raw, default_system_prompt)

    system_prompt_text = prompt_messages_raw[0]["content"] if prompt_messages_raw else default_system_prompt
    # The most recent user turn is the one used for user-only modes.
    user_prompt_text = next(
        (msg["content"] for msg in reversed(prompt_messages_raw) if msg["role"] == "user"),
        None,
    )

    if llm_prompt_mode == "system_user":
        prompt_messages = prompt_messages_raw
    elif llm_prompt_mode == "system_only":
        prompt_messages = [{"role": "system", "content": system_prompt_text}]
    elif llm_prompt_mode == "user_only":
        if user_prompt_text is None:
            return None
        prompt_messages = [{"role": "user", "content": user_prompt_text}]
    else:
        raise ValueError(f"Unsupported llm_prompt_mode: {llm_prompt_mode}")

    # Resolve the stored prompt text before tokenizing so that a bad
    # prompt_text_mode fails fast instead of being hidden whenever the sample
    # would be filtered out below.
    display_user_text = user_prompt_text or ""
    if prompt_text_mode == "system_user":
        prompt_text = f"{system_prompt_text}\n\n{display_user_text}"
    elif prompt_text_mode == "system_only":
        prompt_text = system_prompt_text
    elif prompt_text_mode == "user_only":
        prompt_text = display_user_text
    else:
        raise ValueError(f"Unsupported prompt_text_mode: {prompt_text_mode}")

    prompt_enc = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        truncation=True,
        max_length=cutoff_len,
    )
    # Check for Mapping rather than dict: dict-like results such as
    # transformers' BatchEncoding derive from UserDict, not dict.
    if isinstance(prompt_enc, Mapping):
        prompt_ids = prompt_enc["input_ids"]
    else:
        prompt_ids = prompt_enc
    # Drop the leading batch dimension of a (1, seq_len) result.
    if prompt_ids.dim() == 2:
        prompt_ids = prompt_ids.squeeze(0)

    response_ids = tokenizer.encode(response_text, add_special_tokens=False)
    if append_eos:
        eos_id = tokenizer.eos_token_id
        if isinstance(eos_id, (list, tuple)):
            # Several EOS ids configured: prefer the id of the tokenizer's
            # eos_token string, otherwise fall back to the last listed id.
            eos_token = tokenizer.eos_token
            eos_id = tokenizer.convert_tokens_to_ids(eos_token) if eos_token else eos_id[-1]
        if eos_id is not None:
            eos_id = int(eos_id)
            if not response_ids or response_ids[-1] != eos_id:
                response_ids = response_ids + [eos_id]
    if not response_ids:
        return None

    # Samples that do not fit are skipped, never truncated on the response side.
    if prompt_ids.size(0) + len(response_ids) > cutoff_len:
        return None

    # Build the final lists directly instead of round-tripping via tensors.
    prompt_id_list = prompt_ids.tolist()
    response_id_list = [int(token_id) for token_id in response_ids]
    input_ids = prompt_id_list + response_id_list
    attention_mask = [1] * len(input_ids)
    labels = [-100] * len(prompt_id_list) + response_id_list

    return Sample(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        prompt_text=prompt_text,
        task="",
    )


def build_chatml_training_dataset(
    files: Iterable[dict],
    tokenizer,
    cutoff_len: int,
    prompt_max_length: int,
    default_system_prompt: str,
    prompt_text_mode: str = "system_user",
    llm_prompt_mode: str = "system_user",
    seed: int = 42,
    append_eos: bool = False,
) -> MaskPruningDataset:
    """Load ChatML conversations from JSON files into a shuffled dataset.

    Each file is parsed as JSON and the list stored under the requested split
    key is encoded conversation by conversation with
    ``_encode_chatml_conversation``; conversations for which that returns
    ``None`` are skipped.

    Args:
        files: File specs, each a dict with a required ``"path"`` and optional
            ``"split"`` (default ``"train"``) and ``"task_name"`` (default:
            the file name without its suffix).
        tokenizer: Tokenizer forwarded to the conversation encoder.
        cutoff_len: Maximum prompt + response length, forwarded to the encoder.
        prompt_max_length: Accepted for interface compatibility; not used by
            this function.
        default_system_prompt: Fallback system prompt, forwarded to the encoder.
        prompt_text_mode: Forwarded to the encoder.
        llm_prompt_mode: Forwarded to the encoder.
        seed: Seed that determines the shuffle order of the samples.
        append_eos: Forwarded to the encoder.

    Returns:
        A ``MaskPruningDataset`` wrapping the shuffled samples, each tagged
        with its file's task name.
    """
    samples: List[Sample] = []
    for spec in files:
        path = Path(spec["path"])
        split = spec.get("split", "train")
        task_name = spec.get("task_name") or path.stem

        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        for conversation in data.get(split, []):
            sample = _encode_chatml_conversation(
                tokenizer,
                conversation,
                cutoff_len=cutoff_len,
                default_system_prompt=default_system_prompt,
                prompt_text_mode=prompt_text_mode,
                llm_prompt_mode=llm_prompt_mode,
                append_eos=append_eos,
            )
            if sample is None:
                continue
            sample.task = task_name
            samples.append(sample)

    # Shuffle with a private generator: same order as seeding the module-level
    # RNG with ``seed``, but without resetting global random state for callers.
    random.Random(seed).shuffle(samples)
    return MaskPruningDataset(samples)
