# Diffusion Sampler

This implementation is a fork of the codebase of the original model, which can be found here: https://github.com/seal-rg/recurrent-pretraining.

The new sampler is added as a new method, `generate_diffusion_style`, in the `modeling/raven_modeling_minimal.py` file; the rest of that file is copied from the original source.

For completeness, the full method is reproduced below:


```python
    @torch.no_grad()
    def generate_diffusion_style(
        self,
        input_ids: torch.Tensor,
        generation_config: Optional[GenerationConfig] = None,
        tokenizer=None,
        streamer=None,
        init_scale: float = 1.0,
        cache_lookup_strategy: str = "latest-m4-compress-s4",
        full_prefill: bool = True,
        ema_embeds: float = 0.1,
        state_noise_mixing: float = 0.5,
        inner_recurrence: int = 4,  # equal to normal generation if state_noise_mixing=1.0 and inner_recurrence=32
        num_steps: int = 32,
        freeze_adaptive: str | bool = False,
        headway: int = 1,
        dampened_state_mixer: bool = True,
        sqrt_mixer: bool = False,
        continuous_compute: bool = False,
        exit_t: float = 0.03,  # used only if freeze_adaptive == "latent-diff"
        max_wavefront: int = 128,  # if set this will stop the wave expanding until more states freeze
        max_diffusion_steps: int = 4096,  # hard cap on loop steps for badly configured hyperparam settings
        return_analysis_tablets: bool = False,
        return_full_state_tablet: bool = False,  # make sure to have enough RAM
        **model_kwargs,
    ) -> Union[torch.Tensor, dict[str, Any]]:
        """Generate tokens with a diffusion-style sampler over a window of active positions.

        Each loop step (1) optionally blends the current input embeddings with those of the
        previous step, (2) optionally mixes freshly initialized states into the active latent
        states, (3) runs ``inner_recurrence`` calls of ``iterate_one_step`` on all active
        positions, (4) decodes tokens for all non-frozen positions, (5) appends ``headway``
        new positions, and (6) decides which leading positions become "frozen" (final).
        The KV cache entries of all positions that are not frozen are cleared at the end of
        every step via ``kv_cache.clear_last_k_entries``.

        Args:
            input_ids: Prompt token ids. Only a batch size of 1 is accepted.
            generation_config: Forwarded to ``_prep_generate_args``, ``_get_stops`` and
                ``_sample_next_token``; its ``return_dict_in_generate`` selects the return type.
            tokenizer: Forwarded to ``_get_stops``.
            streamer: Optional object with ``put``/``end``; newly frozen tokens are pushed to it.
            init_scale: Passed as ``scale`` to ``initialize_state`` (and as ``init_scale`` to the
                prefill forward pass).
            cache_lookup_strategy: Forwarded to ``_prep_generate_args``.
            full_prefill: If True, run one regular forward pass over the prompt with ``num_steps``
                recurrences and sample the first token from it before entering the loop. If
                False, the prompt positions start from ``initialize_state`` and are processed
                inside the loop.
            ema_embeds: Weight of the previous step's embedded inputs in the blend
                ``new * (1 - ema_embeds) + old * ema_embeds`` on overlapping positions.
                Values ``<= 0`` disable the blend.
            state_noise_mixing: Amount of freshly initialized state mixed into the active states
                at the start of every step. Values ``<= 0`` disable the mixing.
            inner_recurrence: Number of ``iterate_one_step`` calls per loop step.
            num_steps: Recurrence count for the prefill pass. ``num_steps // inner_recurrence`` is
                also the step threshold for the fixed freezing schedule and the count threshold
                for ``"token-stability"``; stop criteria are only checked once ``step > num_steps``.
            freeze_adaptive: ``"latent-diff"`` freezes based on the relative change of the latent
                states, ``"token-stability"`` based on how often a position's token was
                re-predicted unchanged; any other value uses the fixed schedule.
            headway: Number of positions appended per step (one sampled token plus
                ``headway - 1`` random placeholder tokens).
            dampened_state_mixer: If True, the noise weight of each position is
                ``state_noise_mixing / (1 + recurrence count of that position)``; otherwise it is
                the constant ``state_noise_mixing``.
            sqrt_mixer: If True, mix with ``sqrt`` weights
                (``relu(1 - w).sqrt()`` and ``w.sqrt()``) instead of ``1 - w`` and ``w``.
            continuous_compute: If True (and at least ``headway`` states are active), the newly
                appended positions start from a copy of the last ``headway`` states instead of
                freshly initialized states.
            exit_t: Threshold on the relative state change used by ``"latent-diff"``.
            max_wavefront: Maximum number of active state positions; states and the current
                sequence are truncated when it would be exceeded. Values ``<= 0`` disable the limit.
            max_diffusion_steps: The loop is left once ``step`` exceeds this value.
            return_analysis_tablets: If True, record per-step tokens, frozen tokens, recurrence
                counters and stability counters in CPU tensors.
            return_full_state_tablet: If True (and ``return_analysis_tablets`` is True), also
                record the latent states of every step on the CPU.
            **model_kwargs: Forwarded to ``_prep_generate_args``; a ``"stopping_criteria"`` entry,
                if present, is called with the frozen tokens.

        Returns:
            The frozen token sequence (prompt included), or, if
            ``generation_config.return_dict_in_generate`` is set, a ``GenerateDecoderOnlyOutput``
            whose ``sequences`` are the frozen tokens, whose ``scores`` field carries a dict of
            summary statistics and whose ``hidden_states`` field carries the analysis tablets
            (or None).

        Raises:
            AssertionError: If ``input_ids`` has a batch size other than 1.
            ValueError: If ``full_prefill`` is False and the prompt is longer than ``max_wavefront``.
        """

        assert input_ids.shape[0] == 1, "Only batch_size=1 supported for now"
        model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
            input_ids, generation_config, cache_lookup_strategy, model_kwargs
        )
        stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)

        current_sequence = input_ids.clone()
        # Size of the per-position bookkeeping buffers. `full_prefill` is a bool and adds 1 when True
        # (presumably for the token sampled during prefill — confirm).
        blocked_size = max(self.config.block_size, input_ids.shape[1] + max_new_tokens + headway + full_prefill)
        # How many recurrence steps were spent on each position / how often its token was re-predicted unchanged.
        recurrence_counter_per_position = input_ids.new_zeros([1, blocked_size])
        token_stable_per_position = input_ids.new_zeros([1, blocked_size])
        kv_cache = model_kwargs["past_key_values"]  # reference

        # Summary statistics reported in `summary_scores` at the end.
        num_core_forward_passes = 0
        num_tokens_forward = 0
        num_cache_clears = 0
        num_standing_waves = 0

        if return_analysis_tablets:
            # One row per loop step; (max_new_tokens + headway) * 2 rows is used as a conservative
            # estimate of the number of steps. Steps beyond that are not recorded (see the loop).
            shape = [(max_new_tokens + headway) * 2, blocked_size]
            token_tablet = input_ids.new_zeros(shape, device=torch.device("cpu"))
            frozen_tablet = input_ids.new_zeros(shape, device=torch.device("cpu"))
            counter_tablet = input_ids.new_zeros(shape, device=torch.device("cpu"))
            stability_tablet = input_ids.new_zeros(shape, device=torch.device("cpu"))
            if return_full_state_tablet:
                state_tablet = input_ids.new_zeros(
                    [*shape, self.config.n_embd], dtype=torch.bfloat16, device=torch.device("cpu")
                )

        if full_prefill:
            # Regular forward pass over the whole prompt; the first new token is sampled from its last logit.
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs, init_scale=init_scale, num_steps=num_steps)
            num_core_forward_passes += 1
            num_tokens_forward += input_ids.shape[1]
            next_token = self._sample_next_token(outputs.logits[:, -1, :], input_ids, generation_config)
            if streamer:
                streamer.put(next_token.cpu())
            # Prompt plus first sampled token are frozen from the start.
            frozen_tokens = current_sequence = torch.cat([input_ids, next_token], dim=-1)
            recurrence_counter_per_position[:, : frozen_tokens.shape[1]] += num_steps
            # A single active position remains after prefill.
            # NOTE(review): presumably initialize_state only uses the shape/dtype of its argument
            # (it is also used to draw `rand_states` below) — confirm.
            states = self.initialize_state(outputs.latent_states[:, -1:, :], scale=init_scale)
        else:
            # No prefill: every prompt position becomes an active position of the loop.
            embed_shape = [1, input_ids.shape[1], self.config.n_embd]
            states = self.initialize_state(input_ids.new_zeros(embed_shape, dtype=torch.bfloat16), scale=init_scale)
            frozen_tokens = input_ids.clone()
            if max_wavefront > 0 and states.shape[1] > max_wavefront:
                raise ValueError(
                    f"The input prompt is too long to fit into the chosen max_wavefront memory limit. Use prefill, "
                    f"or increase max_wavefront to at least {states.shape[1]}"
                )

        if ema_embeds > 0.0:
            old_embeds, _ = self.embed_inputs(current_sequence)  # recompute in case of full prefill
        if freeze_adaptive == "latent-diff":
            previous_states = states.clone()
        # Values of the previous loop step, used to align old embeddings/states with the new cache offset
        # and to stream only the newly frozen tokens.
        old_cache_index = 0
        old_k = k = 0
        step = 0

        # Run until enough tokens are frozen or the sequence outgrows the model's block size.
        while ((frozen_tokens.shape[1] - input_ids.shape[1]) < max_new_tokens) and (
            current_sequence.shape[1] <= self.config.block_size
        ):
            # Number of positions currently held in the KV cache; the active states start at this offset.
            cache_index = kv_cache.get_seq_length()
            # print(states.shape, current_sequence.shape, frozen_tokens.shape, cache_index)
            model_kwargs["cache_position"] = torch.arange(
                cache_index, cache_index + states.shape[1], device=input_ids.device
            )
            model_inputs = self.prepare_inputs_for_generation(current_sequence, **model_kwargs)
            aux_inputs = dict(past_key_values=kv_cache, cache_position=model_kwargs["cache_position"])

            if ema_embeds > 0.0:
                new_embeds, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
                # Drop the old embeddings of positions that moved into the cache since the previous step.
                matching_old_embeds = old_embeds[:, cache_index - old_cache_index :]
                # Blend: new * (1 - ema_embeds) + old * ema_embeds on the overlapping positions;
                # positions without an old embedding keep only the down-weighted new embedding.
                embedded_inputs = new_embeds.clone() * (1 - ema_embeds)
                embedded_inputs[:, : matching_old_embeds.shape[1], :] += matching_old_embeds * ema_embeds
                old_embeds = embedded_inputs.clone()
            else:
                embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)

            if state_noise_mixing > 0:
                # Mix freshly initialized states into the active states.
                rand_states = self.initialize_state(states, scale=init_scale)
                if dampened_state_mixer:
                    # Per-position weight that shrinks with the recurrence count already spent on the position.
                    active_region = recurrence_counter_per_position[:, cache_index : cache_index + states.shape[1]]
                    state_noise = (state_noise_mixing / (1 + active_region))[:, :, None].to(dtype=states.dtype)
                else:  # constant
                    state_noise = torch.as_tensor(state_noise_mixing, device=states.device, dtype=states.dtype)
                if sqrt_mixer:
                    states = states * F.relu(1 - state_noise).sqrt() + state_noise.sqrt() * rand_states
                else:
                    states = states * (1 - state_noise) + state_noise * rand_states
            # Advance all active positions by `inner_recurrence` recurrence steps.
            for substep in range(inner_recurrence):
                states, block_idx, _ = self.iterate_one_step(embedded_inputs, states, block_idx=block_idx, **aux_inputs)
                recurrence_counter_per_position[:, cache_index : cache_index + states.shape[1]] += 1
                num_core_forward_passes += 1
                num_tokens_forward += states.shape[1]

            output = self.predict_from_latents(states, **aux_inputs)
            all_logits: torch.Tensor = output.logits  # type: ignore
            # remove 1) frozen_tokens, but 2) account for the logit vector being cache_index shorter than
            # the full logits vector, and subtract -1 because we are decoding from the last position as well
            final_logits = all_logits[0, max(frozen_tokens.shape[1] - 1 - cache_index, 0) :, :]  # type: ignore
            new_tokens = self._sample_next_token(final_logits, frozen_tokens, generation_config).T  # +1 toks here
            # The sequence as it would look with the new predictions (without the one extra token at the end);
            # positions whose token did not change get their stability counter incremented.
            potential_edits = torch.cat([frozen_tokens, new_tokens[:, :-1]], dim=1)
            token_stable_per_position[:, : current_sequence.shape[1]] += current_sequence == potential_edits
            # Random placeholder tokens for the remaining headway positions.
            # NOTE(review): 65510 is a hard-coded upper bound — presumably tied to the vocabulary size; confirm.
            extra_tokens = torch.randint(0, 65510, (1, headway - 1), device=input_ids.device)  # + h-1 toks here

            # frozen tokens + generated/guessed tokens
            current_sequence = torch.cat([frozen_tokens, new_tokens, extra_tokens], dim=1)
            # States for the `headway` newly appended positions: copies of the trailing states or fresh ones.
            if continuous_compute and states.shape[1] >= headway:
                new_position_state = states[:, -headway:].clone()
            else:
                new_position_state = self.initialize_state(
                    input_ids.new_zeros([1, headway, self.config.n_embd], dtype=torch.bfloat16),
                    scale=init_scale,
                )
            states = torch.cat([states, new_position_state], dim=1)

            # Cap the number of active positions; `headway_in_step` is how many of the new positions were kept.
            if max_wavefront > 0:
                states_extent = states.shape[1]
                if states_extent > max_wavefront:
                    positions_kept = max_wavefront - (states_extent - headway)
                    headway_in_step = positions_kept
                    states = states[:, :max_wavefront]
                    current_sequence = current_sequence[:, : cache_index + max_wavefront]
                    num_standing_waves += 1
                else:
                    headway_in_step = headway
            else:
                headway_in_step = headway

            # Note for this block that state.shape and current_sequence.shape have already advanced by headway many toks
            if freeze_adaptive == "latent-diff":
                # Compare each position's state with its state from the previous step (aligned to the new
                # cache offset); a position may exit when the relative change is below `exit_t`.
                matching_prev_states = previous_states[:, cache_index - old_cache_index :]
                match_states = states[:, : matching_prev_states.shape[1], :]
                new_exits = (match_states - matching_prev_states).norm(dim=-1) / match_states.norm(dim=-1) < exit_t
                if new_exits.sum() > 0:
                    # Freeze everything up to and including the last position that met the exit criterion.
                    k = cache_index + new_exits.nonzero()[-1][1].item() + 1
                    kv_cache.clear_last_k_entries(current_sequence.shape[1] - headway_in_step - k + 1)  # -k or k-1?
                    num_cache_clears += current_sequence.shape[1] - headway_in_step - k + 1
                    states = states[:, k - cache_index - 1 :, :]  # k or k-1
                    frozen_tokens = current_sequence[:, :k]
                    if streamer:
                        streamer.put(frozen_tokens[:, old_k:k].cpu())
                else:
                    # Nothing freezes: clear the cache entries of all positions processed in this step.
                    kv_cache.clear_last_k_entries(states.shape[1] - headway_in_step)
                    num_cache_clears += states.shape[1] - headway_in_step

                previous_states = states.clone()
            elif freeze_adaptive == "token-stability":
                # Freeze up to the last position whose token was re-predicted unchanged more than
                # num_steps // inner_recurrence times.
                if torch.any(token_stable_per_position > num_steps // inner_recurrence):
                    # latest freezable pos:
                    k = (token_stable_per_position > num_steps // inner_recurrence).nonzero()[-1][1].item() + 1
                    kv_cache.clear_last_k_entries(current_sequence.shape[1] - headway_in_step - k + 1)  # -k or k-1?
                    num_cache_clears += current_sequence.shape[1] - headway_in_step - k + 1
                    states = states[:, k - cache_index - 1 :, :]  # k or k-1
                    frozen_tokens = current_sequence[:, :k]
                    if streamer:
                        streamer.put(frozen_tokens[:, old_k:k].cpu())
                else:
                    kv_cache.clear_last_k_entries(states.shape[1] - headway_in_step)
                    num_cache_clears += states.shape[1] - headway_in_step
                    # NOTE(review): resetting k makes old_k 0 for the next step, so the next streamed
                    # slice starts at position 0 — confirm this is intended.
                    k = 0
            else:  # fixed exit after num_steps compute steps
                if step > (num_steps // inner_recurrence):  # start adding to frozen state after num_steps
                    # Freeze the leading `headway_in_step` active positions and keep their cache entries.
                    kv_cache.clear_last_k_entries(states.shape[1] - headway - headway_in_step)
                    num_cache_clears += states.shape[1] - headway - headway_in_step
                    states = states[:, headway_in_step:, :]
                    frozen_tokens = torch.cat([frozen_tokens, new_tokens[:, :headway_in_step]], dim=-1)
                    if streamer:
                        streamer.put(new_tokens[:, :headway_in_step].cpu())
                else:
                    # Warm-up phase: nothing is frozen yet, only the wave of active positions grows.
                    kv_cache.clear_last_k_entries(states.shape[1] - headway_in_step)  #
                    num_cache_clears += states.shape[1] - headway_in_step

            # Stop criteria are only evaluated once `step` exceeds `num_steps`.
            # NOTE(review): freezing above is gated on num_steps // inner_recurrence instead — confirm the
            # differing thresholds are intended.
            if step > num_steps:
                if stop_tokens is not None:
                    if freeze_adaptive:
                        # Only the tokens frozen in this step are inspected.
                        token_stop = any(f in stop_tokens for f in frozen_tokens[0, old_k:k].tolist())
                    else:
                        # NOTE(review): inspects the last `headway` frozen tokens although only
                        # `headway_in_step` were added in this step — confirm.
                        token_stop = any(f in stop_tokens for f in frozen_tokens[0, -headway:].tolist())
                else:
                    token_stop = False

                if "stopping_criteria" in model_kwargs:
                    crit_stop = model_kwargs["stopping_criteria"](frozen_tokens, None)
                else:
                    crit_stop = False
                if token_stop or crit_stop:
                    break
            # Safety net against settings under which the loop would otherwise run for too long.
            if step > max_diffusion_steps:
                break

            # Record this step's snapshot as long as the pre-allocated tablets have rows left.
            if return_analysis_tablets and step < token_tablet.shape[0]:
                token_tablet[step, : current_sequence.shape[1]] = current_sequence.cpu()
                frozen_tablet[step, : frozen_tokens.shape[1]] = frozen_tokens.cpu()
                counter_tablet[step] = recurrence_counter_per_position.cpu()
                stability_tablet[step] = token_stable_per_position.cpu()
                if return_full_state_tablet:
                    state_tablet[step, cache_index : cache_index + states.shape[1]] = states[0].cpu()

            old_cache_index = cache_index
            step += 1
            old_k = k

        if streamer:
            streamer.end()

        if return_analysis_tablets:
            analysis_data = dict(  # package analysis data
                last_step=step,
                last_recurrence=step * inner_recurrence,
                longest_token=current_sequence.shape[1],
                token_tablet=token_tablet,  # .repeat_interleave(inner_recurrence, dim=0),
                frozen_tablet=frozen_tablet,  # .repeat_interleave(inner_recurrence, dim=0),
                counter_tablet=counter_tablet,  # .repeat_interleave(inner_recurrence, dim=0),
                stability_tablet=stability_tablet,  # .repeat_interleave(inner_recurrence, dim=0),
                state_tablet=state_tablet if return_full_state_tablet else None,
            )

        # Run statistics; per-position counters are cut to the length of the frozen sequence.
        summary_scores = {
            "num_core_forward_passes": num_core_forward_passes,
            "num_tokens_forward": num_tokens_forward,
            "num_cache_clears": num_cache_clears,
            "num_standing_waves": num_standing_waves,
            "diffusion_steps": step,
            "gen_seq_length": current_sequence.shape[1],
            "len_prefill": input_ids.shape[1],
            "recurrence_per_position": recurrence_counter_per_position[:, : frozen_tokens.shape[1]].cpu(),
            "token_stable_per_position": token_stable_per_position[:, : frozen_tokens.shape[1]].cpu(),
        }

        if generation_config.return_dict_in_generate:
            # The `scores` and `hidden_states` fields are reused to carry the summary statistics and
            # the analysis tablets, respectively.
            return GenerateDecoderOnlyOutput(
                sequences=frozen_tokens,  # type: ignore
                scores=summary_scores,  # type: ignore
                logits=None,
                attentions=None,
                hidden_states=analysis_data if return_analysis_tablets else None,  # type: ignore
                past_key_values=model_kwargs.get("past_key_values"),
            )
        return frozen_tokens

```
