# TODO: Adapted from cli
from typing import Callable, List, Optional

import numpy as np


def compute_num_context(init_video_length, context_size, context_overlap):
    """Return how many full sliding windows fit inside a video.

    Windows are ``context_size`` frames long and consecutive windows share
    ``context_overlap`` frames, so each window starts
    ``context_size - context_overlap`` frames after the previous one. Only
    windows lying entirely within ``init_video_length`` frames are counted.

    Args:
        init_video_length: Number of frames in the video.
        context_size: Number of frames per window.
        context_overlap: Number of frames shared by consecutive windows.

    Returns:
        A non-negative window count (0 if the video is shorter than one window).

    Raises:
        ValueError: If ``context_overlap >= context_size``, i.e. the windows
            would never advance.
    """
    step = context_size - context_overlap
    if step <= 0:
        raise ValueError(
            f"context_overlap ({context_overlap}) must be smaller than "
            f"context_size ({context_size})"
        )
    # Floor division of a negative span can produce a negative count; a window
    # count can never be below zero.
    return max(0, (init_video_length - context_size) // step + 1)


def compute_context_indices(num_context, context_size, context_overlap):
    """Return inclusive ``(start, end)`` frame bounds for each sliding window.

    Window ``k`` begins at ``k * (context_size - context_overlap)`` and covers
    ``context_size`` frames, so its last frame is ``start + context_size - 1``.
    """
    stride = context_size - context_overlap
    return [
        (window * stride, window * stride + context_size - 1)
        for window in range(num_context)
    ]

def ordered_halving(val):
    """Return the base-2 van der Corput value of ``val``.

    The zero-padded 64-bit binary form of ``val`` is mirrored around the
    binary point: 0 -> 0.0, 1 -> 0.5, 2 -> 0.25, 3 -> 0.75, ... so that
    successive integers keep bisecting the unit interval.

    Intended for ``0 <= val < 2**64`` (result in ``[0, 1)``). A negative
    ``val`` puts a ``-`` sign into the digit string and the parse raises
    ``ValueError``; values of ``2**64`` or more can yield results >= 1.
    """
    mirrored_bits = format(val, "064b")[::-1]
    return int(mirrored_bits, 2) / (1 << 64)


def uniform(
    step: int = ...,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
    **kwargs,
):
    """Yield overlapping windows of frame indices covering ``num_frames``.

    Windows are produced for several spacing levels ``context_step`` = 1, 2,
    4, ...; each window holds ``context_size`` indices spaced
    ``context_step`` apart, and consecutive windows at the same level start
    ``context_size * context_step - context_overlap`` indices apart. Window
    starts are shifted by an amount derived from ``ordered_halving(step)``,
    so different ``step`` values place window boundaries on different frames.
    Indices that run past the last frame are reflected back into
    ``[0, num_frames)``.

    Args:
        step: Index used only to derive the start offset.
        num_frames: Total number of frames to cover.
        context_size: Number of indices per window.
        context_stride: Maximum number of spacing levels; clamped so the
            largest spacing is the smallest power of two that is
            >= ``num_frames / context_size``.
        context_overlap: Number of positions shared by consecutive windows.
        closed_loop: If False, the range of window starts is shortened by
            ``context_overlap``.
        **kwargs: Ignored; lets callers pass options meant for other
            schedulers (e.g. ``start_num_frames``).

    Yields:
        Lists of frame indices. If ``num_frames <= context_size``, yields
        ``list(range(num_frames))`` once and stops.
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    context_stride = min(
        context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
    )

    # Reflect out-of-range indices without repeating the edge frame:
    # N -> N-2, N+1 -> N-3, ..., 2N-2 -> 0, 2N-1 -> 1, ... (period 2N - 2).
    # The simpler ``num_frames - 2 - index % num_frames`` is only correct up
    # to 2N - 2 and gives the invalid index -1 at 2N - 1.
    period = 2 * num_frames - 2

    def _reflect(index):
        if index < num_frames:
            return index
        index %= period
        return index if index < num_frames else period - index

    offset = ordered_halving(step)
    pad = int(round(num_frames * offset))
    stop = num_frames + pad + (0 if closed_loop else -context_overlap)

    for context_step in 1 << np.arange(context_stride):
        span = context_size * context_step
        for j in range(
            int(offset * context_step) + pad,
            stop,
            span - context_overlap,
        ):
            yield [_reflect(e) for e in range(j, j + span, context_step)]


def uniform_prefix(
    step: int = 0,
    num_frames: int = 0,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
    start_num_frames: int = 16,  # Length of the special first window.
):
    """Yield frame-index windows like ``uniform``, with a dedicated first window.

    The first window is yielded before the regular ones:

    * if ``start_num_frames > context_size``: ``start_num_frames`` consecutive
      indices beginning at a step-dependent offset;
    * otherwise: ``list(range(context_size))``.

    The remaining windows start at ``start_num_frames - context_overlap``
    (plus the step-dependent pad) and follow the same spacing levels as
    ``uniform``. Indices that run past the last frame are reflected back into
    ``[0, num_frames)``.

    NOTE(review): when ``start_num_frames < context_size`` the first window
    still spans ``context_size`` frames while the next one starts at
    ``start_num_frames - context_overlap``, so they overlap by more than
    ``context_overlap``; the first-window offset also multiplies by the
    clamped ``context_stride`` rather than a spacing. Confirm both are
    intended.

    Yields:
        Lists of frame indices. If ``num_frames <= context_size``, yields
        ``list(range(num_frames))`` once and stops.
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    context_stride = min(
        context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
    )

    # Reflect out-of-range indices without repeating the edge frame:
    # N -> N-2, ..., 2N-2 -> 0, 2N-1 -> 1, ... (period 2N - 2). The simpler
    # ``num_frames - 2 - index % num_frames`` gives -1 at 2N - 1.
    period = 2 * num_frames - 2

    def _reflect(index):
        if index < num_frames:
            return index
        index %= period
        return index if index < num_frames else period - index

    offset = ordered_halving(step)
    pad = int(round(num_frames * offset))

    # Special first window.
    if start_num_frames > context_size:
        first_start = int(offset * context_stride) + pad
        yield [
            _reflect(e)
            for e in range(first_start, first_start + start_num_frames)
        ]
    else:
        # start_num_frames <= context_size: always the first context_size frames.
        yield list(range(context_size))

    # Regular windows begin where the first window's overlap region starts.
    start_j = start_num_frames - context_overlap
    stop = num_frames + pad + (0 if closed_loop else -context_overlap)

    for context_step in 1 << np.arange(context_stride):
        span = context_size * context_step
        for j in range(start_j + pad, stop, span - context_overlap):
            yield [_reflect(e) for e in range(j, j + span, context_step)]


def get_context_scheduler(name: str) -> Callable:
    """Return the context-window generator function registered under ``name``.

    Recognised names are ``"uniform"`` and ``"uniform_prefix"``.

    Raises:
        ValueError: If ``name`` is not one of the recognised names.
    """
    if name not in ("uniform", "uniform_prefix"):
        raise ValueError(f"Unknown context_overlap policy {name}")
    return uniform if name == "uniform" else uniform_prefix


def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
    start_num_frames: int = 16,  # Only used by schedulers that accept it.
):
    """Count the windows ``scheduler`` yields across all entries of ``timesteps``.

    The scheduler is invoked once per entry of ``timesteps`` with ``step`` set
    to that entry's index (not its value), and every yielded window is counted.

    Args:
        scheduler: A scheduler generator function such as ``uniform`` or
            ``uniform_prefix``.
        timesteps: Sequence whose length determines how many steps are run.
        num_steps: Unused; kept for backward compatibility.
        num_frames, context_size, context_stride, context_overlap,
        closed_loop: Forwarded to ``scheduler`` by keyword.
        start_num_frames: Forwarded by keyword; ``uniform`` absorbs it through
            its ``**kwargs``.

    Returns:
        The total number of windows over all steps.
    """
    # Arguments go by keyword: the schedulers take ``step`` first and have no
    # ``num_steps`` parameter, so a positional call shifts every argument by
    # one slot (``num_steps`` would land in ``num_frames``).
    window_kwargs = dict(
        num_frames=num_frames,
        context_size=context_size,
        context_stride=context_stride,
        context_overlap=context_overlap,
        closed_loop=closed_loop,
        start_num_frames=start_num_frames,
    )
    return sum(
        sum(1 for _ in scheduler(step=step, **window_kwargs))
        for step in range(len(timesteps))
    )


if __name__ == "__main__":
    # Demo: print the windows uniform_prefix produces for a 120-frame clip,
    # grouped into batches of context_batch_size windows.
    num_frames = 120
    context_size = 16
    context_stride = 1
    context_overlap = 0
    step = 0
    n_motion_frames = 0
    context_batch_size = 1

    context_queue = list(
        uniform_prefix(
            step=step,
            num_frames=num_frames,
            context_size=context_size,
            context_stride=context_stride,
            context_overlap=context_overlap,
            start_num_frames=context_size + n_motion_frames,
            closed_loop=False,
        )
    )

    # Consecutive slices of context_batch_size windows; the last may be shorter.
    global_context = [
        context_queue[first : first + context_batch_size]
        for first in range(0, len(context_queue), context_batch_size)
    ]
    for frames in global_context:
        print(frames)
