from __future__ import annotations

import random
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset

from .spec import DatasetSpec
from .task_formatters import format_example

# Label value written at every prompt (non-response) position by
# _encode_chat_example. -100 is the default ``ignore_index`` of
# torch.nn.CrossEntropyLoss, so those positions presumably drop out of the
# loss -- confirm against the training loop.
IGNORE_INDEX = -100


@dataclass
class Sample:
    """One tokenized training example, stored as plain Python lists.

    Instances are produced by ``_encode_chat_example`` and wrapped by
    ``MaskPruningDataset``.
    """

    # Chat-template prompt token ids followed by the response token ids.
    input_ids: List[int]
    # Same length as ``input_ids``; ``_encode_chat_example`` fills it with ones.
    attention_mask: List[int]
    # ``IGNORE_INDEX`` at prompt positions, response token ids at the tail.
    labels: List[int]
    # The user prompt text that was passed to the encoder.
    prompt_text: str
    # Task name; the encoder leaves it empty and ``build_training_dataset``
    # overwrites it with ``spec.task``.
    task: str


class MaskPruningDataset(Dataset):
    """Map-style dataset that serves stored samples as plain dictionaries."""

    def __init__(self, samples: Sequence[Sample]):
        # Materialise into our own list so later changes to the caller's
        # container do not affect this dataset.
        self.samples = [*samples]

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the sample at ``idx`` as a dict keyed by field name.

        The values are the sample's own attribute objects (no copies).
        """
        record = self.samples[idx]
        keys = ("input_ids", "attention_mask", "labels", "prompt_text", "task")
        return {key: getattr(record, key) for key in keys}


def _encode_chat_example(
    tokenizer,
    system_prompt: Optional[str],
    user_prompt: str,
    assistant_response: str,
    cutoff_len: int,
) -> Optional[Sample]:
    """Tokenize one prompt/response pair into a ``Sample`` with masked prompt labels.

    The optional system prompt and the user prompt are rendered through the
    tokenizer's chat template (with the generation prompt appended); the
    stripped response, terminated by the EOS token when one is defined, is
    tokenized separately and concatenated after the prompt ids.

    Side effect: when ``tokenizer.chat_template`` is ``None`` a fallback
    header-style template is installed on the tokenizer object itself and
    stays there for subsequent calls.

    Args:
        tokenizer: Object exposing ``chat_template``, ``apply_chat_template``,
            ``encode`` and ``eos_token`` (presumably a Hugging Face
            tokenizer -- confirm with the caller).
        system_prompt: Optional system message; skipped when falsy.
        user_prompt: Text of the user turn; also stored as ``prompt_text``.
        assistant_response: Target text the labels are built from.
        cutoff_len: Maximum allowed number of tokens for prompt + response.

    Returns:
        The encoded ``Sample`` (with an empty ``task``), or ``None`` when the
        response tokenizes to nothing or prompt + response exceed
        ``cutoff_len``.

    Raises:
        ValueError: If the prompt ids are not 1-D after removing a leading
            batch dimension.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})

    if tokenizer.chat_template is None:
        # No template configured: install a fallback so apply_chat_template
        # has something to render with.
        template = (
            "{{- bos_token }}\n"
            "{%- for message in messages %}\n"
            "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>' }}\n"
            "{{ message['content'] }}\n"
            "<|eot_id|>\n"
            "{%- endfor %}\n"
            "{{ '<|start_header_id|>' + 'assistant' + '<|end_header_id|>' }}\n"
        )
        tokenizer.chat_template = template

    prompt_enc = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        truncation=True,
        max_length=cutoff_len,
    )
    # Dict-style results may be a ``BatchEncoding``, which derives from
    # ``collections.UserDict`` and therefore is NOT a ``dict`` instance;
    # checking against ``Mapping`` covers both it and plain dicts.
    if isinstance(prompt_enc, Mapping):
        prompt_ids = prompt_enc["input_ids"]
    else:
        prompt_ids = prompt_enc

    # No-op for an int64 tensor; also accepts a plain list of ids.
    prompt_ids = torch.as_tensor(prompt_ids, dtype=torch.long)
    if prompt_ids.dim() == 2:
        # Drop the batch dimension of a single-example batch.
        prompt_ids = prompt_ids.squeeze(0)
    if prompt_ids.dim() != 1:
        raise ValueError(f"Unexpected prompt tensor shape: {prompt_ids.shape}")

    # With no EOS token the suffix is "" and the response is left unchanged.
    eos_text = tokenizer.eos_token or ""
    response_text = assistant_response.strip()
    if not response_text.endswith(eos_text):
        response_text = response_text + eos_text

    response_ids = tokenizer.encode(response_text, add_special_tokens=False)
    if len(response_ids) == 0:
        return None

    # A prompt truncated to ``cutoff_len`` always lands here as well, because
    # at least one response token has to fit after it.
    total_length = prompt_ids.size(0) + len(response_ids)
    if total_length > cutoff_len:
        return None

    response_tensor = torch.tensor(response_ids, dtype=torch.long)
    input_ids = torch.cat([prompt_ids, response_tensor], dim=0)
    attention_mask = torch.ones_like(input_ids)
    # Only the response tail keeps real labels; prompt positions stay at
    # IGNORE_INDEX.
    labels = torch.full_like(input_ids, IGNORE_INDEX)
    labels[-len(response_ids) :] = response_tensor

    return Sample(
        input_ids=input_ids.tolist(),
        attention_mask=attention_mask.tolist(),
        labels=labels.tolist(),
        prompt_text=user_prompt,
        task="",
    )


def _load_raw_dataset(spec: DatasetSpec, seed: int) -> List[dict]:
    """Load the rows described by ``spec`` and return them as a list.

    Only ``spec.source == "hf_hub"`` is handled. When ``spec.max_samples`` is
    set, the dataset is shuffled with ``seed`` and cut down to at most that
    many rows; otherwise every row is returned in its original order.

    Raises:
        NotImplementedError: For any other ``spec.source``.
    """
    if spec.source == "hf_hub":
        # NOTE(review): trust_remote_code=True allows a loading script shipped
        # with the dataset repo to run locally -- confirm every spec.hf_path
        # points at a trusted repository.
        dataset = load_dataset(
            spec.hf_path,
            name=spec.subset,
            split=spec.split,
            trust_remote_code=True,
        )
        limit = spec.max_samples
        if limit is not None:
            shuffled = dataset.shuffle(seed=seed)
            dataset = shuffled.select(range(min(limit, len(shuffled))))
        return list(dataset)

    raise NotImplementedError(f"Unsupported dataset source: {spec.source}")


def build_training_dataset(
    specs: List[DatasetSpec],
    tokenizer,
    system_prompt: str,
    cutoff_len: int,
    prompt_max_length: int,
    seed: int = 42,
) -> MaskPruningDataset:
    """Load, format, encode and shuffle every example described by ``specs``.

    Each raw example is turned into a (prompt, response) pair by
    ``format_example``. A pair is skipped when formatting returns ``None``,
    when either side is empty, or when encoding returns ``None``. Prompts
    longer than ``prompt_max_length`` are sliced down to that length before
    encoding (the slice acts on the prompt value itself, not on tokens).

    Side effect: seeds the global ``random`` module with ``seed``; the final
    shuffle draws from that same global generator.

    Args:
        specs: Dataset descriptions to load, processed in order.
        tokenizer: Passed through to ``_encode_chat_example``.
        system_prompt: System message forwarded to the encoder.
        cutoff_len: Token budget forwarded to the encoder.
        prompt_max_length: Maximum length of the prompt before encoding.
        seed: Seed for the global RNG and for dataset sub-sampling.

    Returns:
        A ``MaskPruningDataset`` holding the shuffled, task-tagged samples.
    """
    random.seed(seed)

    collected: List[Sample] = []
    for spec in specs:
        for record in _load_raw_dataset(spec, seed):
            pair = format_example(spec.task, record)
            if pair is None:
                continue

            user_text, target_text = pair
            if not len(user_text) or not len(target_text):
                continue

            too_long = len(user_text) > prompt_max_length
            clipped = user_text[:prompt_max_length] if too_long else user_text

            sample = _encode_chat_example(tokenizer, system_prompt, clipped, target_text, cutoff_len)
            if sample is not None:
                # The encoder leaves ``task`` empty; tag it with its origin.
                sample.task = spec.task
                collected.append(sample)

    random.shuffle(collected)
    return MaskPruningDataset(collected)

