"""Dataset building for token classification training."""

from typing import List, Tuple
from datasets import Dataset, DatasetDict
from tqdm import tqdm
import random

from .reconstruct import reconstruct_completions

def _prompt_difficulties(
    rollout_rewards: List[List[List[bool]]],
    num_completions_per_prompt: int,
) -> List[float]:
    """Return one difficulty score per prompt.

    The score of a prompt is the sum, over its completions, of the last
    rollout reward of each completion, divided by
    ``num_completions_per_prompt``.
    """
    return [
        sum(rewards[-1] for rewards in prompt_rewards) / num_completions_per_prompt
        for prompt_rewards in rollout_rewards
    ]


def _sample_balanced_indices(
    difficulties: List[float],
    num_prompts: int,
    num_to_select: int,
    num_difficulty_bins: int,
) -> List[int]:
    """Sample prompt indices so that difficulty bins are evenly represented.

    The range [0, 1] is cut into ``num_difficulty_bins`` equal-width bins and
    up to ``num_to_select // num_difficulty_bins`` prompts are drawn from each.
    If that yields fewer than ``num_to_select`` indices, the shortfall is drawn
    from the prompts not selected so far, regardless of difficulty. The final
    list is shuffled. Uses the module-level ``random`` state.
    """
    per_bin = num_to_select // num_difficulty_bins
    selected_indices: List[int] = []
    for i in range(num_difficulty_bins):
        is_last_bin = i == num_difficulty_bins - 1
        lower = i / num_difficulty_bins
        upper = 1.0 if is_last_bin else (i + 1) / num_difficulty_bins
        # Bins are half-open [lower, upper) except the last one, which is
        # closed so that a score of exactly 1.0 is not left out of every bin.
        idxs = [
            idx
            for idx in range(num_prompts)
            if lower <= difficulties[idx]
            and (difficulties[idx] <= upper if is_last_bin else difficulties[idx] < upper)
        ]
        print(f"Difficulty {i} has {len(idxs)} prompts")
        selected_indices += random.sample(idxs, min(per_bin, len(idxs)))

    num_remaining_prompts = num_to_select - len(selected_indices)
    if num_remaining_prompts > 0:
        print(f"Selected {len(selected_indices)} prompts, {num_remaining_prompts} remaining")
        unselected = list(set(range(num_prompts)) - set(selected_indices))
        selected_indices += random.sample(unselected, num_remaining_prompts)
    random.shuffle(selected_indices)
    return selected_indices


def build_dataset(
    save_dict: dict,
    tokenizer,
    restrict_num_prompts: int,
    args,
    balance_difficulty: bool = False,
    select_difficulty: str = None,
) -> Dataset:
    """
    Construct per-(prompt, completion) examples with token-level labels.

    A subset of prompts is selected (seeded by ``args.dataset_seed``), the
    selected data is passed to ``reconstruct_completions`` and the per-prompt
    results are flattened into one row per completion.

    NOTE(review): the labeling itself happens in ``reconstruct_completions``,
    which is not visible here. The earlier documentation of this function
    stated that labels sit on the last token of each paragraph, that rollouts
    are injected between paragraphs and that rollout texts are truncated at
    the first unbalanced ``}`` -- confirm against that helper.

    Args:
        save_dict: Dictionary with the keys "prompt_ids", "completion_ids",
            "rollout_rewards", "rollout_bin_indices", "rollout_texts" and
            "args"; the latter must expose ``paragraph_delimiter_token_id``
            and ``rollout_append_text``.
        tokenizer: Tokenizer forwarded to ``reconstruct_completions``.
        restrict_num_prompts: If > 0, use at most this many prompts;
            otherwise use all prompts.
        args: Run configuration; ``dataset_seed`` is always read and
            ``num_difficulty_bins`` is read when ``balance_difficulty`` is set.
        balance_difficulty: Sample prompts evenly across difficulty bins.
            Mutually exclusive with ``select_difficulty``.
        select_difficulty: Optional "lower,upper" string; only prompts with
            lower <= difficulty < upper are eligible.

    Returns:
        A single ``Dataset`` with the columns "input_ids", "labels",
        "section_labels", "bin_idx", "prompt_idx" and "completion_idx".
        No train/validation split is performed here. "prompt_idx" is the
        position of the prompt within the selected subset, not its index in
        ``save_dict``.

    Raises:
        AssertionError: If both ``balance_difficulty`` and
            ``select_difficulty`` are given.
        ValueError: If ``select_difficulty`` matches fewer prompts than
            requested.
    """
    prompt_ids: List[List[int]] = save_dict["prompt_ids"]
    completion_ids: List[List[List[int]]] = save_dict["completion_ids"]
    rollout_rewards: List[List[List[bool]]] = save_dict["rollout_rewards"]
    rollout_bin_indices: List[List[List[int]]] = save_dict["rollout_bin_indices"]
    rollout_texts: List[List[List[str]]] = save_dict["rollout_texts"]

    save_args = save_dict["args"]
    paragraph_delimiter_token_id: int = save_args.paragraph_delimiter_token_id
    rollout_append_text: str = save_args.rollout_append_text

    num_prompts = len(prompt_ids)
    # A non-positive restriction means "no limit" (previously it produced an
    # empty selection for 0 and a random.sample error for negative values).
    if restrict_num_prompts > 0:
        num_selected = min(num_prompts, restrict_num_prompts)
    else:
        num_selected = num_prompts

    assert not (balance_difficulty and (select_difficulty is not None))

    random.seed(args.dataset_seed)

    uses_difficulty = balance_difficulty or select_difficulty is not None
    if uses_difficulty:
        # Difficulty is only needed for the two difficulty-aware modes, so the
        # plain path does not require any rollout rewards to be present.
        num_completions_per_prompt = len(completion_ids[0]) if completion_ids else 0
        difficulties = _prompt_difficulties(rollout_rewards, num_completions_per_prompt)

    if balance_difficulty:
        selected_indices = _sample_balanced_indices(
            difficulties, num_prompts, num_selected, args.num_difficulty_bins
        )
    elif select_difficulty is not None:
        bounds = select_difficulty.split(",")
        lower_bound = float(bounds[0])
        upper_bound = float(bounds[1])
        selected_indices = [
            idx for idx in range(num_prompts)
            if lower_bound <= difficulties[idx] < upper_bound
        ]
        if len(selected_indices) < num_selected:
            raise ValueError(f"There are only {len(selected_indices)} prompts with difficulty {select_difficulty}, but you specified N = {num_selected}")
        selected_indices = random.sample(selected_indices, num_selected)
    else:
        # Shuffled order over the first num_selected prompts.
        selected_indices = random.sample(list(range(num_selected)), num_selected)

    def take_selected(data):
        return [data[i] for i in selected_indices]

    prompt_ids = take_selected(prompt_ids)
    completion_ids = take_selected(completion_ids)
    rollout_rewards = take_selected(rollout_rewards)
    rollout_bin_indices = take_selected(rollout_bin_indices)
    rollout_texts = take_selected(rollout_texts)

    # Token-based reconstruction and labeling
    input_ids, labels, bin_indices, section_labels = reconstruct_completions(
        tokenizer=tokenizer,
        prompt_ids=prompt_ids,
        completion_ids=completion_ids,
        rollout_texts=rollout_texts,
        delimiter_token_id=paragraph_delimiter_token_id,
        rollout_append_text=rollout_append_text,
        rollout_rewards=rollout_rewards,
        rollout_bin_indices=rollout_bin_indices,
    )

    # Flatten the per-prompt lists into one row per completion, remembering
    # which prompt / completion position each row came from.
    flat_input_ids: List[List[int]] = []
    flat_labels: List[List[int]] = []
    flat_section_labels: List[List[int]] = []
    flat_bin_indices: List[List[int]] = []
    prompt_idxs: List[int] = []
    completion_idxs: List[int] = []
    for pi in range(len(input_ids)):
        num_completions = len(input_ids[pi])
        flat_input_ids.extend(input_ids[pi])
        flat_labels.extend(labels[pi])
        flat_section_labels.extend(section_labels[pi])
        flat_bin_indices.extend(bin_indices[pi])
        prompt_idxs.extend([pi] * num_completions)
        completion_idxs.extend(range(num_completions))

    ds = Dataset.from_dict({
        "input_ids": flat_input_ids,
        "labels": flat_labels,
        "section_labels": flat_section_labels,
        "bin_idx": flat_bin_indices,
        "prompt_idx": prompt_idxs,
        "completion_idx": completion_idxs,
    })

    return ds
