from typing import Generator, List, NamedTuple

import numpy as np
import torch


def find_end_first_consecutive_true(arr: np.ndarray) -> int:
    """Return the (exclusive) end index of the run of True values that starts at index 0.

    Equivalently, this is the length of the leading run of True values.

    Args:
        arr: One-dimensional boolean array.

    Returns:
        ``0`` if the array is empty or its first element is False,
        ``len(arr)`` if every element is True, otherwise the index of the
        first False element.
    """
    # An empty array has no leading run; guard before indexing `arr[0]`.
    if len(arr) == 0 or not arr[0]:
        return 0

    # Running count of True values; it stops increasing wherever `arr` is False.
    prog = np.cumsum(arr)
    if prog[-1] == len(arr):
        return len(arr)

    # Index i is selected when prog[i] == prog[i + 1], i.e. arr[i + 1] is False.
    true_locs = np.where(prog[:-1] == prog[1:])[0]

    # Cast so the result is a builtin int, matching the annotation.
    return int(true_locs[0]) + 1


def find_start_last_consecutive_true(arr: np.ndarray) -> int:
    """Return the start index of the run of True values that ends at the last element.

    Args:
        arr: One-dimensional boolean array.

    Returns:
        ``-1`` if the last element is False (there is no trailing run),
        otherwise the index at which the trailing run of True values begins
        (``0`` when every element is True).
    """
    # The trailing run of `arr` is the leading run of the reversed array.
    trailing_run_length = find_end_first_consecutive_true(arr[::-1])
    if trailing_run_length <= 0:
        return -1
    return len(arr) - trailing_run_length


def group_consecutive_values(arr: np.ndarray, stepsize: int = 1) -> List[np.ndarray]:
    """Split `arr` into runs whose neighbouring elements differ by exactly `stepsize`.

    Args:
        arr: Array of values to partition.
        stepsize: Required difference between adjacent elements of a run.

    Returns:
        The sub-arrays produced by cutting `arr` at every position where the
        difference to the previous element is not `stepsize`.
    """
    breaks_run = np.diff(arr) != stepsize
    # diff[i] compares arr[i + 1] with arr[i], so the cut goes before i + 1.
    cut_points = np.nonzero(breaks_run)[0] + 1
    return np.split(arr, cut_points)


class RepetitionTuple(NamedTuple):
    """Description of one periodic run, as yielded by `find_periodic_sequences`."""

    # Index in the searched array at which the repeating run begins.
    start: int
    # Index at which the run ends; `find_periodic_sequences` computes it as
    # one past the last repeated element (exclusive).
    end: int
    # Length of the repeating unit.
    period: int
    # Number of whole periods in the run; `find_periodic_sequences` sets this
    # to (end - start) // period.
    times: int


def find_periodic_sequences(
    arr: np.ndarray, max_period: int, min_period: int = 1, mask_value: int = -1
) -> Generator[RepetitionTuple, None, None]:
    """Yield the periodic runs found in `arr` for every period in the requested range.

    For each candidate period the array is padded with `mask_value` and
    reshaped into a matrix with `period` columns. A row that is identical to
    the row above it marks a repetition; blocks of such rows are merged into
    one run. Since a run need not be aligned to row boundaries, the partially
    matching row before and after each block are inspected to extend the run
    to its actual start and end.

    Only runs spanning more than two whole periods are yielded.

    Args:
        arr: The array to search.
        max_period: Largest period to try. It is capped at ``len(arr) // 3``.
        min_period: Smallest period to try.
        mask_value: Padding value; it must not occur in `arr`.

    Yields:
        A `RepetitionTuple` for each detected run, with an exclusive `end`.

    Raises:
        ValueError: If `mask_value` occurs in `arr`. Because this is a
            generator, the check runs when iteration starts, not at call time.
    """
    # Padding must be distinguishable from data, otherwise padded cells could
    # be mistaken for genuine matches.
    if (arr == mask_value).any():
        raise ValueError("`mask_value` is in the array")

    n_items = len(arr)

    # Only runs with at least three repetitions are reported, so periods
    # longer than a third of the array can never qualify.
    longest_period = min(max_period, n_items // 3)

    for period in range(min_period, longest_period + 1):
        # Always pads between 1 and `period` cells, so the final row contains
        # at least one mask value.
        pad_width = period - (n_items % period)
        matrix = np.pad(arr, (0, pad_width), constant_values=mask_value).reshape(-1, period)

        # Cell-wise comparison of every row with the row above it (np.roll
        # wraps around, so row 0 is compared with the last, padded row).
        matches_prev = matrix == np.roll(matrix, shift=1, axis=0)
        repeated_rows = np.where(matches_prev.all(axis=1))[0]

        # Nothing repeats with this period.
        if len(repeated_rows) == 0:
            continue

        # Each group of adjacent row indices is one candidate run.
        for run in group_consecutive_values(repeated_rows):
            first_row, last_row = run[0], run[-1]

            # The row above the block may match its own predecessor only in
            # its trailing cells, e.g. [False, False, True, True]; those cells
            # belong to the run as well. The helper returns the index where
            # the trailing match starts, so convert it to a cell count.
            trailing_start = find_start_last_consecutive_true(matches_prev[first_row - 1])
            if trailing_start > 0:
                lead_in = period - trailing_start
            else:
                lead_in = 0

            # Likewise the row below the block may match in its leading
            # cells; here the helper already returns the number of cells.
            lead_out = find_end_first_consecutive_true(matches_prev[last_row + 1])

            # A matching row repeats the row above it, so the run really
            # begins one row earlier than the first matching row.
            start_pos = (first_row - 1) * period - lead_in

            # Exclusive end: one past the last matching cell.
            end_pos = (last_row + 1) * period + lead_out

            repetitions = (end_pos - start_pos) // period

            # Runs with fewer than three repetitions are not reported.
            if repetitions > 2:
                yield RepetitionTuple(start=start_pos, end=end_pos, period=period, times=repetitions)


def get_document_lengths(input_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    """Return the length of each document in a sequence of concatenated token IDs.

    Documents are delimited by `eos_token_id`; each length counts the
    terminating EOS token. Tokens after the last EOS form a final document
    without one.

    Args:
        input_ids: Token IDs. NOTE(review): the indexing below assumes a 1-D
            tensor — confirm against callers.
        eos_token_id: Token ID that marks the end of a document.

    Returns:
        An ``int32`` tensor on the same device as `input_ids` with one entry
        per document; empty when `input_ids` is empty.
    """
    device = input_ids.device
    num_tokens = input_ids.shape[0]

    # Index of the last token of every EOS-terminated document.
    eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)[0].to(dtype=torch.int32)

    # Tokens after the final EOS form one more document ending at the last
    # index. Nothing is added for empty input (guarded before indexing
    # `input_ids[-1]`) or when the sequence already ends with EOS.
    if num_tokens == 0 or input_ids[-1] == eos_token_id:
        trailing_boundary = []
    else:
        trailing_boundary = [num_tokens - 1]

    # The helper tensors are created on the input's device so that torch.cat
    # does not mix devices.
    doc_boundaries = torch.cat(
        [
            # Sentinel "end" before the first token, so the first difference
            # is the first document's full length.
            torch.tensor([-1], dtype=torch.int32, device=device),
            eos_positions,
            torch.tensor(trailing_boundary, dtype=torch.int32, device=device),
        ]
    )
    return doc_boundaries[1:] - doc_boundaries[:-1]
