<script lang="ts">
	import Reveal from 'reveal.js';
	import * as lmpeek from 'lmpeek';
	import { fromPreTrained } from '@lenml/tokenizer-gpt2';
	import * as d3 from 'd3';
	import Slide from '$lib/components/Slide.svelte';

	/** A decoded token string paired with the probability the model assigned to it. */
	type TokenProbability = [string, number];

	// State for the guided slides. Values wrapped in $state are reactive and drive the template.
	let model: any = $state(null), // handle returned by lmpeek.loadModel
		tokenizer: ReturnType<typeof fromPreTrained>, // assigned once the model has loaded
		baseLogits: Float32Array = new Float32Array(), // logits for `str`; empty until the first forward pass
		// The template reads each entry as a [token, probability] pair (token.length, token.slice),
		// so the element type is a tuple rather than a Float32Array.
		// NOTE(review): presumably ordered most-likely first — confirm against lmpeek's sample().
		probabilities: TokenProbability[] | null = $state(null),
		str = $state('I think Wikipedia is very'), // prompt used for every guided slide
		deck: any, // Reveal.js instance
		tokens: string[] = $state([]), // `str` split into decoded tokens for the first slide
		topN = 10, // number of bars drawn per chart
		highlightedIndex = $state(0), // bar currently highlighted on the "Sampling!" slide
		temperature = $state(1.0),
		// Distribution recomputed whenever the temperature slider moves.
		temperatureProbabilities: TokenProbability[] | null = $state(null),
		topK = $state(3),
		topP = $state(0.9),
		// Copy of the base distribution shown on the top-p slide.
		topPProbabilities: TokenProbability[] | null = $state(null),
		showSlides = $state(false), // flipped to true once Reveal reports 'ready'
		showLoading = $state(false);

	// --- State backing the interactive GPT-2 explorer slide ---

	// Prompt text plus the most recently computed logits and distribution for it.
	let explorerInput = $state('My favorite actor is Tom');
	let explorerProbabilities: any = $state(null);
	let explorerBaseLogits: any = $state(null);

	// Current values of the three sampling parameters.
	let explorerTemperature = $state(1.0);
	let explorerTopK = $state(3);
	let explorerTopP = $state(0.75);

	// Which sampling parameters are switched on.
	let useTemperature = $state(true);
	let useTopK = $state(false);
	let useTopP = $state(true);

	// Index of the bar flashed after a token is sampled (null when nothing is flashing),
	// and a busy flag that is true while a sample/forward pass is in progress.
	let explorerFlashIndex: number | null = $state(null);
	let sampleLoading = $state(false);

	/**
	 * Loads GPT-2 and its tokenizer, computes the token list and base distribution
	 * for `str`, then boots the Reveal deck and starts the highlight animation.
	 *
	 * Any failure along the way is logged to the console; nothing is rethrown.
	 */
	async function loadModelAndComputeOutputs() {
		try {
			console.log('about to load model');
			model = await lmpeek.loadModel('gpt-2');
			console.log('loaded model');
			tokenizer = fromPreTrained();

			if (!(model && tokenizer)) {
				throw new Error('Failed to load model or tokenizer');
			}

			// Decode each token id on its own so the first slide can show one chip per token.
			const tokenIds = tokenizer.encode(str);
			tokens = tokenIds.map((tokenId) => tokenizer.decode([tokenId]));

			const forwardResult = await model.forward(str);
			baseLogits = forwardResult.final.logits.cpuData;
			probabilities = await model.sample(baseLogits);

			if (!probabilities) {
				throw new Error('Failed to compute probabilities');
			}

			// The temperature and top-p slides start from copies of the base distribution.
			temperatureProbabilities = [...probabilities];
			topPProbabilities = [...probabilities];

			deck = new Reveal();
			deck.initialize({ view: 'scroll', scrollProgress: true });
			deck.addEventListener('ready', () => {
				deck.slide(0, 0);
				showSlides = true;
			});

			// Cycle the highlighted bar through the top-N tokens every 200 ms.
			const advanceHighlight = () => {
				highlightedIndex = (highlightedIndex + 1) % topN;
			};
			setInterval(advanceHighlight, 200);
		} catch (error) {
			console.error('Error loading model:', error);
		}
	}

	// Handle of the pending debounce timer for the temperature slider, if any.
	let temperatureTimeout: ReturnType<typeof setTimeout> | undefined;

	/**
	 * Debounced `oninput` handler for the temperature slider: 300 ms after the
	 * last slider movement, re-samples the base logits at the current temperature
	 * and stores the result in `temperatureProbabilities`.
	 *
	 * Does nothing until the model and base logits are available. Sampling
	 * failures are logged rather than left as unhandled rejections.
	 */
	async function updateTemperature() {
		// A Float32Array is always truthy, so "no logits yet" must be detected via length.
		if (!model || !baseLogits || baseLogits.length === 0) return;
		clearTimeout(temperatureTimeout);
		temperatureTimeout = setTimeout(async () => {
			// Remember which temperature this request is for so a slow response
			// cannot overwrite the result of a newer slider position.
			const requestedTemperature = temperature;
			try {
				const sampled = await model.sample(baseLogits, { temperature: requestedTemperature });
				if (requestedTemperature === temperature) {
					temperatureProbabilities = sampled;
				}
			} catch (error) {
				console.error('Error updating temperature distribution:', error);
			}
		}, 300);
	}

	/**
	 * Runs a forward pass over the explorer input and stores both the resulting
	 * logits and the sampled distribution. Loads the model and tokenizer first if
	 * no model is available yet. Errors are logged to the console, not rethrown.
	 */
	async function generateExplorerDistribution() {
		try {
			const modelMissing = !model;
			if (modelMissing) {
				model = await lmpeek.loadModel('gpt-2');
				tokenizer = fromPreTrained();
			}

			const forwardResult = await model.forward(explorerInput);
			explorerBaseLogits = forwardResult.final.logits.cpuData;

			// Temperature is only passed along when its toggle is switched on.
			const samplingOptions: any = useTemperature ? { temperature: explorerTemperature } : {};
			explorerProbabilities = await model.sample(explorerBaseLogits, samplingOptions);
		} catch (error) {
			console.error('Error generating distribution:', error);
		}
	}

	/**
	 * Reacts to a top-k change: when top-k and top-p are both enabled, top-p is
	 * switched off, then the explorer distribution is refreshed.
	 */
	function handleTopKChange() {
		const bothEnabled = useTopK && useTopP;
		if (bothEnabled) {
			useTopP = false;
		}
		updateExplorerDistribution();
	}

	/**
	 * Reacts to a top-p change: when top-p and top-k are both enabled, top-k is
	 * switched off, then the explorer distribution is refreshed.
	 */
	function handleTopPChange() {
		const bothEnabled = useTopP && useTopK;
		if (bothEnabled) {
			useTopK = false;
		}
		updateExplorerDistribution();
	}

	// Handle of the pending debounce timer for explorer updates, if any.
	let explorerTimeout: ReturnType<typeof setTimeout> | undefined;
	// Resolvers for every caller still waiting on a debounced recomputation.
	let explorerWaiters: Array<() => void> = [];

	/**
	 * Debounced recomputation of the explorer distribution from the stored logits.
	 *
	 * Calls within 300 ms of each other are coalesced into a single `model.sample`
	 * call. The returned promise resolves once that recomputation has finished,
	 * including for callers whose own timer was superseded by a later call, so
	 * awaiting it can never hang. Sampling failures are logged and the promise
	 * still resolves.
	 *
	 * Resolves immediately (with no work done) when the model or logits are missing.
	 */
	async function updateExplorerDistribution() {
		if (!model || !explorerBaseLogits) return;
		clearTimeout(explorerTimeout);
		return new Promise<void>((resolve) => {
			// Keep the resolver around: if this timer is cancelled by a later call,
			// the later call's timer settles this promise as well.
			explorerWaiters.push(resolve);
			explorerTimeout = setTimeout(async () => {
				// Take ownership of the current waiters; calls arriving while sampling
				// is in flight start a fresh list and a fresh timer.
				const waiters = explorerWaiters;
				explorerWaiters = [];
				try {
					const options: { temperature?: number } = {};
					if (useTemperature) options.temperature = explorerTemperature;
					explorerProbabilities = await model.sample(explorerBaseLogits, options);
				} catch (error) {
					console.error('Error updating distribution:', error);
				} finally {
					for (const settle of waiters) settle();
				}
			}, 300);
		});
	}

	/**
	 * Samples one token from the current explorer distribution, appends it to the
	 * explorer input, and refreshes the distribution for the extended input.
	 *
	 * The candidate set is the prefix of `explorerProbabilities` selected by the
	 * active top-p and/or top-k settings. While the forward pass for the new input
	 * runs, the chosen bar is flashed for at least two seconds.
	 *
	 * `sampleLoading` is true for the duration of the call. Both it and the flash
	 * highlight are always reset, even if the forward pass fails; such a failure
	 * still propagates to the caller.
	 */
	async function sampleFromExplorerDistribution() {
		if (!explorerProbabilities || explorerProbabilities.length === 0) return;
		sampleLoading = true;
		try {
			// Build candidate list according to active sampling params (prefix of sorted probs)
			let candidates: [string, number][] = explorerProbabilities as [string, number][];
			if (useTopP && explorerTopP < 1.0) {
				let cumulative = 0;
				let cutoff = -1;
				for (let i = 0; i < candidates.length; i++) {
					cumulative += candidates[i][1];
					if (cumulative >= explorerTopP) {
						cutoff = i;
						break;
					}
				}
				// If the threshold is never reached, every candidate stays eligible.
				if (cutoff >= 0) candidates = candidates.slice(0, cutoff + 1);
			}
			if (useTopK) {
				// Always keep at least one candidate, and never more than are available.
				candidates = candidates.slice(0, Math.max(1, Math.min(explorerTopK, candidates.length)));
			}

			// Renormalize and sample via Math.random
			const mass = candidates.reduce((sum, [, p]) => sum + p, 0);
			if (mass <= 0) return;
			const r = Math.random();
			let acc = 0;
			// Default to the LAST candidate: if floating-point round-off leaves the
			// cumulative sum slightly below `r`, the draw belongs to the tail of the
			// distribution, not to the most likely token at index 0.
			let chosenLocalIndex = candidates.length - 1;
			for (let i = 0; i < candidates.length; i++) {
				acc += candidates[i][1] / mass;
				if (r <= acc) {
					chosenLocalIndex = i;
					break;
				}
			}
			const [chosenToken] = candidates[chosenLocalIndex];

			// Since candidates are always a prefix of explorerProbabilities, absolute index is the same
			explorerFlashIndex = chosenLocalIndex;

			// Append token to input
			explorerInput += chosenToken;

			// Start forward pass concurrently with the 2s highlight timer
			const forwardPromise = (async () => {
				const res = await model.forward(explorerInput);
				explorerBaseLogits = res.final.logits.cpuData;
			})();
			const delayPromise = new Promise<void>((res) => setTimeout(() => res(), 2000));

			// Proceed only after both the forward pass and the minimum highlight time are done
			await Promise.all([forwardPromise, delayPromise]);

			explorerFlashIndex = null;
			await updateExplorerDistribution();
		} finally {
			// Clear the highlight here too so a failed forward pass cannot leave a bar flashing.
			explorerFlashIndex = null;
			sampleLoading = false;
		}
	}
</script>

<div class="reveal">
	<div class="slides">
		<Slide
			topText="Many people view large language models (LLMs) as <i>next token prediction</i> machines."
		>
			<div class="flex items-center justify-center gap-16 select-none">
				<div class="flex flex-col items-center space-y-4">
					{#each tokens as token}
						<div
							class="bg-[#3b82f6] text-white px-3 py-1.5 rounded-lg text-lg border-2 border-white"
						>
							{token}
						</div>
					{/each}
				</div>
				<div class="text-6xl text-gray-400">→</div>
				<div
					class="bg-green-600 text-white px-3 py-1.5 rounded-lg text-xl font-semibold border-2 border-white"
				>
					{probabilities ? probabilities[0][0] : '...'}
				</div>
			</div>
		</Slide>

		<Slide
			topText="That's actually a misconception."
			bottomText="When an LLM processes text, it doesn't output a token. It actually outputs a probability
					distribution, like the one you see above."
		>
			{#if probabilities}
				{@const topProbs = probabilities.slice(0, topN)}
				{@const width = 700}
				{@const height = 400}
				{@const marginTop = 30}
				{@const marginRight = 0}
				{@const marginBottom = 80}
				{@const marginLeft = 50}
				{@const xScale = d3
					.scaleBand()
					.domain(topProbs.map((_, i) => i))
					.range([marginLeft, width - marginRight])
					.padding(0.1)}
				{@const yScale = d3
					.scaleLinear()
					.domain([0, d3.max(topProbs, (d) => d[1])])
					.range([height - marginBottom, marginTop])}

				<svg
					{width}
					{height}
					viewBox="0 0 {width} {height}"
					style:max-width="100%"
					style:height="auto"
					class="select-none"
				>
					<g fill="#3b82f6">
						{#each topProbs as [token, prob], i}
							<rect
								x={xScale(i)}
								y={yScale(prob)}
								height={yScale(0) - yScale(prob)}
								width={xScale.bandwidth()}
							/>
						{/each}
					</g>

					<g transform="translate(0,{height - marginBottom})">
						<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

						{#each topProbs as [token, prob], i}
							<line
								stroke="currentColor"
								x1={xScale(i) + xScale.bandwidth() / 2}
								x2={xScale(i) + xScale.bandwidth() / 2}
								y1={0}
								y2={6}
							/>

							<text
								fill="currentColor"
								text-anchor="middle"
								x={xScale(i) + xScale.bandwidth() / 2}
								y={22}
								class="text-sm"
							>
								{token.length > 8 ? token.slice(0, 8) + '...' : token}
							</text>
						{/each}

						<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
					</g>

					<g transform="translate({marginLeft},0)">
						<line stroke="currentColor" x1={0} x2={0} y1={marginTop} y2={height - marginBottom} />

						{#each yScale.ticks() as tick}
							{#if tick !== 0}
								<line stroke="currentColor" x1={0} x2={-6} y1={yScale(tick)} y2={yScale(tick)} />
							{/if}

							<text
								fill="currentColor"
								text-anchor="end"
								dominant-baseline="middle"
								x={-9}
								y={yScale(tick)}
								class="text-sm"
							>
								{Math.trunc(tick * 100)}
							</text>
						{/each}

						<text fill="currentColor" text-anchor="start" x={-marginLeft} y={15} class="text-lg">
							probability (%)
						</text>
					</g>
				</svg>
			{/if}
		</Slide>

		<Slide
			topText="LLMs are actually <strong>deterministic</strong>!"
			bottomText="If we run the model with the same input text, this resulting
					probability distribution will remain the same."
		></Slide>

		<Slide
			topText="But when you use ChatGPT, you never see this probability distribution!"
			bottomText="You only see the model continually appending tokens to its response. So how <em>do</em> LLMs pick a single token from this distribution?"
		></Slide>

		<Slide
			topText="The easiest way: pick the token with the highest probability!"
			bottomText="This is called <strong>greedy decoding</strong>."
		>
			{#if probabilities}
				{@const topProbs = probabilities.slice(0, topN)}
				{@const width = 700}
				{@const height = 400}
				{@const marginTop = 30}
				{@const marginRight = 0}
				{@const marginBottom = 80}
				{@const marginLeft = 50}
				{@const xScale = d3
					.scaleBand()
					.domain(topProbs.map((_, i) => i))
					.range([marginLeft, width - marginRight])
					.padding(0.1)}
				{@const yScale = d3
					.scaleLinear()
					.domain([0, d3.max(topProbs, (d) => d[1])])
					.range([height - marginBottom, marginTop])}

				<svg
					{width}
					{height}
					viewBox="0 0 {width} {height}"
					style:max-width="100%"
					style:height="auto"
					class="select-none"
				>
					<g>
						{#each topProbs as [token, prob], i}
							<rect
								x={xScale(i)}
								y={yScale(prob)}
								height={yScale(0) - yScale(prob)}
								width={xScale.bandwidth()}
								fill={i === 0 ? '#f59e0b' : '#3b82f6'}
							/>
						{/each}
					</g>

					<g transform="translate(0,{height - marginBottom})">
						<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

						{#each topProbs as [token, prob], i}
							<line
								stroke="currentColor"
								x1={xScale(i) + xScale.bandwidth() / 2}
								x2={xScale(i) + xScale.bandwidth() / 2}
								y1={0}
								y2={6}
							/>

							<text
								fill="currentColor"
								text-anchor="middle"
								x={xScale(i) + xScale.bandwidth() / 2}
								y={22}
								class="text-sm"
							>
								{token.length > 8 ? token.slice(0, 8) + '...' : token}
							</text>
						{/each}

						<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
					</g>

					<g transform="translate({marginLeft},0)">
						<line stroke="currentColor" x1={0} x2={0} y1={marginTop} y2={height - marginBottom} />

						{#each yScale.ticks() as tick}
							{#if tick !== 0}
								<line stroke="currentColor" x1={0} x2={-6} y1={yScale(tick)} y2={yScale(tick)} />
							{/if}

							<text
								fill="currentColor"
								text-anchor="end"
								dominant-baseline="middle"
								x={-9}
								y={yScale(tick)}
								class="text-sm"
							>
								{Math.trunc(tick * 100)}
							</text>
						{/each}

						<text fill="currentColor" text-anchor="start" x={-marginLeft} y={15} class="text-lg">
							probability (%)
						</text>
					</g>
				</svg>
			{/if}
		</Slide>

		<Slide
			topText="Greedy decoding makes next token selection deterministic."
			bottomText="But this is likely in stark contrast to your experience with tools like ChatGPT: the same prompt definitely doesn't lead to the same output. So why isn't ChatGPT using greedy decoding?"
		></Slide>

		<Slide
			topText="In practice, greedy decoding is hardly ever used."
			bottomText="Empirically, it doesn't lead to great performance from language models, as it tends to produce boring and repetitive outputs."
		></Slide>

		<Slide topText="So, what's the alternative?"></Slide>

		<Slide
			topText="Sampling!"
			bottomText="Instead of <i>always</i> picking the token with the highest probability, we intentionally introduce stochasticity into the process of selecting a token!"
		>
			{#if probabilities}
				{@const topProbs = probabilities.slice(0, topN)}
				{@const width = 700}
				{@const height = 400}
				{@const marginTop = 30}
				{@const marginRight = 0}
				{@const marginBottom = 80}
				{@const marginLeft = 50}
				{@const xScale = d3
					.scaleBand()
					.domain(topProbs.map((_, i) => i))
					.range([marginLeft, width - marginRight])
					.padding(0.1)}
				{@const yScale = d3
					.scaleLinear()
					.domain([0, d3.max(topProbs, (d) => d[1])])
					.range([height - marginBottom, marginTop])}

				<svg
					{width}
					{height}
					viewBox="0 0 {width} {height}"
					style:max-width="100%"
					style:height="auto"
					class="select-none"
				>
					<g>
						{#each topProbs as [token, prob], i}
							<rect
								x={xScale(i)}
								y={yScale(prob)}
								height={yScale(0) - yScale(prob)}
								width={xScale.bandwidth()}
								fill={i === highlightedIndex ? '#4ade80' : '#3b82f6'}
							/>
						{/each}
					</g>

					<g transform="translate(0,{height - marginBottom})">
						<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

						{#each topProbs as [token, prob], i}
							<line
								stroke="currentColor"
								x1={xScale(i) + xScale.bandwidth() / 2}
								x2={xScale(i) + xScale.bandwidth() / 2}
								y1={0}
								y2={6}
							/>

							<text
								fill="currentColor"
								text-anchor="middle"
								x={xScale(i) + xScale.bandwidth() / 2}
								y={22}
								class="text-sm"
							>
								{token.length > 8 ? token.slice(0, 8) + '...' : token}
							</text>
						{/each}

						<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
					</g>

					<g transform="translate({marginLeft},0)">
						<line stroke="currentColor" x1={0} x2={0} y1={marginTop} y2={height - marginBottom} />

						{#each yScale.ticks() as tick}
							{#if tick !== 0}
								<line stroke="currentColor" x1={0} x2={-6} y1={yScale(tick)} y2={yScale(tick)} />
							{/if}

							<text
								fill="currentColor"
								text-anchor="end"
								dominant-baseline="middle"
								x={-9}
								y={yScale(tick)}
								class="text-sm"
							>
								{Math.trunc(tick * 100)}
							</text>
						{/each}

						<text fill="currentColor" text-anchor="start" x={-marginLeft} y={15} class="text-lg">
							probability (%)
						</text>
					</g>
				</svg>
			{/if}
		</Slide>

		<Slide
			topText="Think of sampling like drawing a colored marble from a bag with a variety of colored
					marbles."
			bottomText="Each color represents a different token, and the number of marbles of each color corresponds to that token's probability. We randomly draw one marble (token) from the bag based on these proportions."
		></Slide>

		<Slide
			topText="But sampling isn't without issues!"
			bottomText="Occasionally, the very nature of sampling can work to our detriment: we might end up picking a token that was assigned a very low probability. In the worst of cases, this can lead to incoherent or nonsensical text generation."
		>
			{#if probabilities}
				{@const topProbs = probabilities.slice(0, topN)}
				{@const width = 700}
				{@const height = 400}
				{@const marginTop = 30}
				{@const marginRight = 0}
				{@const marginBottom = 80}
				{@const marginLeft = 50}
				{@const xScale = d3
					.scaleBand()
					.domain(topProbs.map((_, i) => i))
					.range([marginLeft, width - marginRight])
					.padding(0.1)}
				{@const yScale = d3
					.scaleLinear()
					.domain([0, d3.max(topProbs, (d) => d[1])])
					.range([height - marginBottom, marginTop])}

				<svg
					{width}
					{height}
					viewBox="0 0 {width} {height}"
					style:max-width="100%"
					style:height="auto"
					class="select-none"
				>
					<g>
						{#each topProbs as [token, prob], i}
							<rect
								x={xScale(i)}
								y={yScale(prob)}
								height={yScale(0) - yScale(prob)}
								width={xScale.bandwidth()}
								fill={i === topProbs.length - 1 ? '#dc2626' : '#3b82f6'}
							/>
						{/each}
					</g>

					<g transform="translate(0,{height - marginBottom})">
						<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

						{#each topProbs as [token, prob], i}
							<line
								stroke="currentColor"
								x1={xScale(i) + xScale.bandwidth() / 2}
								x2={xScale(i) + xScale.bandwidth() / 2}
								y1={0}
								y2={6}
							/>

							<text
								fill="currentColor"
								text-anchor="middle"
								x={xScale(i) + xScale.bandwidth() / 2}
								y={22}
								class="text-sm"
							>
								{token.length > 8 ? token.slice(0, 8) + '...' : token}
							</text>
						{/each}

						<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
					</g>

					<g transform="translate({marginLeft},0)">
						<line stroke="currentColor" x1={0} x2={0} y1={marginTop} y2={height - marginBottom} />

						{#each yScale.ticks() as tick}
							{#if tick !== 0}
								<line stroke="currentColor" x1={0} x2={-6} y1={yScale(tick)} y2={yScale(tick)} />
							{/if}

							<text
								fill="currentColor"
								text-anchor="end"
								dominant-baseline="middle"
								x={-9}
								y={yScale(tick)}
								class="text-sm"
							>
								{Math.trunc(tick * 100)}
							</text>
						{/each}

						<text fill="currentColor" text-anchor="start" x={-marginLeft} y={15} class="text-lg">
							probability (%)
						</text>
					</g>
				</svg>
			{/if}
		</Slide>

		<Slide topText="We need to make some adjustments to the sampling process."></Slide>

		<Slide
			topText="Introducing... sampling parameters!"
			bottomText="By introducing parameters that modify the underlying probability distribution before sampling actually takes place, we can ensure that the model generates more relevant and coherent text."
		></Slide>

		<Slide
			topText="Temperature"
			bottomText="Let's introduce our first sampling parameter: <strong>temperature</strong>. Temperature controls the <i>sharpness</i> of the overall probability distribution. A lower temperature value accentuates the differences between token probabilities, while a higher value makes the distribution more uniform."
		>
			<div class="relative">
				{#if temperatureProbabilities}
					{@const topProbs = temperatureProbabilities.slice(0, topN)}
					{@const width = 700}
					{@const height = 400}
					{@const marginTop = 30}
					{@const marginRight = 0}
					{@const marginBottom = 80}
					{@const marginLeft = 50}
					{@const xScale = d3
						.scaleBand()
						.domain(topProbs.map((_, i) => i))
						.range([marginLeft, width - marginRight])
						.padding(0.1)}
					{@const yScale = d3
						.scaleLinear()
						.domain([0, d3.max(topProbs, (d) => d[1])])
						.range([height - marginBottom, marginTop])}

					<svg
						{width}
						{height}
						viewBox="0 0 {width} {height}"
						style:max-width="100%"
						style:height="auto"
						class="select-none"
					>
						<g fill="#f59e0b">
							{#each topProbs as [token, prob], i}
								<rect
									x={xScale(i)}
									y={yScale(prob)}
									height={yScale(0) - yScale(prob)}
									width={xScale.bandwidth()}
								/>
							{/each}
						</g>

						<g transform="translate(0,{height - marginBottom})">
							<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

							{#each topProbs as [token, prob], i}
								<line
									stroke="currentColor"
									x1={xScale(i) + xScale.bandwidth() / 2}
									x2={xScale(i) + xScale.bandwidth() / 2}
									y1={0}
									y2={6}
								/>

								<text
									fill="currentColor"
									text-anchor="middle"
									x={xScale(i) + xScale.bandwidth() / 2}
									y={22}
									class="text-sm"
								>
									{token.length > 8 ? token.slice(0, 8) + '...' : token}
								</text>
							{/each}

							<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
						</g>

						<g transform="translate({marginLeft},0)">
							<line stroke="currentColor" x1={0} x2={0} y1={marginTop} y2={height - marginBottom} />

							{#each yScale.ticks() as tick}
								{#if tick !== 0}
									<line stroke="currentColor" x1={0} x2={-6} y1={yScale(tick)} y2={yScale(tick)} />
								{/if}

								<text
									fill="currentColor"
									text-anchor="end"
									dominant-baseline="middle"
									x={-9}
									y={yScale(tick)}
									class="text-sm"
								>
									{Math.trunc(tick * 100)}
								</text>
							{/each}

							<text fill="currentColor" text-anchor="start" x={-marginLeft} y={15} class="text-lg">
								probability (%)
							</text>
						</g>
					</svg>
					<div
						class="absolute z-10 top-2 right-2 bg-gray-800 p-4 pb-6 rounded-lg flex flex-col gap-y-4 items-center justify-center"
					>
						<div class="block text-sm font-medium text-white">
							Temperature: <span class="font-bold mono">{temperature.toFixed(1)}</span>
						</div>
						<input
							type="range"
							min="0.1"
							max="1.5"
							step="0.1"
							bind:value={temperature}
							oninput={updateTemperature}
							class="w-full h-2 bg-gray-700 rounded-lg appearance-none cursor-pointer"
						/>
					</div>
				{/if}
			</div>
		</Slide>

		<Slide
			topText="But this doesn't completely solve our problem."
			bottomText="Low probability tokens can still be occasionally sampled; it's just that now, they're
					sampled at a far smaller frequency."
		></Slide>

		<Slide
			topText="Why not restrict our sampling to the top K tokens?"
			bottomText="If we do this (and subsequently renormalize token probabilities), we can ensure that only the most likely candidates are considered. This is yet another sampling parameter, one that's aptly named <strong>top-k</strong>."
		>
			<div class="relative">
				{#if probabilities}
					{@const topProbs = probabilities.slice(0, topN)}
					{@const width = 700}
					{@const height = 400}
					{@const marginTop = 30}
					{@const marginRight = 0}
					{@const marginBottom = 80}
					{@const marginLeft = 50}
					{@const xScale = d3
						.scaleBand()
						.domain(topProbs.map((_, i) => i))
						.range([marginLeft, width - marginRight])
						.padding(0.1)}
					{@const yScale = d3
						.scaleLinear()
						.domain([0, d3.max(topProbs, (d) => d[1])])
						.range([height - marginBottom, marginTop])}

					<svg
						{width}
						{height}
						viewBox="0 0 {width} {height}"
						style:max-width="100%"
						style:height="auto"
						class="select-none"
					>
						<g>
							{#each topProbs as [token, prob], i}
								<rect
									x={xScale(i)}
									y={yScale(prob)}
									height={yScale(0) - yScale(prob)}
									width={xScale.bandwidth()}
									fill={i < topK ? '#f59e0b' : '#3b82f6'}
								/>
							{/each}
						</g>

						<g transform="translate(0,{height - marginBottom})">
							<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

							{#each topProbs as [token, prob], i}
								<line
									stroke="currentColor"
									x1={xScale(i) + xScale.bandwidth() / 2}
									x2={xScale(i) + xScale.bandwidth() / 2}
									y1={0}
									y2={6}
								/>

								<text
									fill="currentColor"
									text-anchor="middle"
									x={xScale(i) + xScale.bandwidth() / 2}
									y={22}
									class="text-sm"
								>
									{token.length > 8 ? token.slice(0, 8) + '...' : token}
								</text>
							{/each}

							<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
						</g>

						<g transform="translate({marginLeft},0)">
							<line stroke="currentColor" x1={0} x2={0} y1={marginTop} y2={height - marginBottom} />

							{#each yScale.ticks() as tick}
								{#if tick !== 0}
									<line stroke="currentColor" x1={0} x2={-6} y1={yScale(tick)} y2={yScale(tick)} />
								{/if}

								<text
									fill="currentColor"
									text-anchor="end"
									dominant-baseline="middle"
									x={-9}
									y={yScale(tick)}
									class="text-sm"
								>
									{Math.trunc(tick * 100)}
								</text>
							{/each}

							<text fill="currentColor" text-anchor="start" x={-marginLeft} y={15} class="text-lg">
								probability (%)
							</text>
						</g>
					</svg>

					<div
						class="absolute z-10 top-2 right-2 bg-gray-800 p-4 pb-6 rounded-lg flex flex-col gap-y-4 items-center justify-center"
					>
						<div class="block text-sm font-medium text-white">
							Top-k: <span class="font-bold mono">{topK}</span>
						</div>
						<input
							type="range"
							min="1"
							max="5"
							step="1"
							bind:value={topK}
							class="w-full h-2 bg-gray-700 rounded-lg appearance-none cursor-pointer"
						/>
					</div>
				{/if}
			</div>
		</Slide>

		<Slide topText="But can you see where top-k might also lead to suboptimal results?"></Slide>

		<Slide
			topText="Top-k doesn't pay attention to the underlying probability distribution."
			bottomText="This means that using too small of a top-k value can be overly restrictive. In the graphic above, you can see that tokens D and E are excluded from sampling, even though the model assigns relatively high probabilities to them."
		>
			{@const fakeData = [
				{ token: 'token A', prob: 0.21 },
				{ token: 'token B', prob: 0.2 },
				{ token: 'token C', prob: 0.19 },
				{ token: 'token D', prob: 0.18 },
				{ token: 'token E', prob: 0.17 },
				{ token: 'token F', prob: 0.03 },
				{ token: 'token G', prob: 0.02 }
			]}
			{@const width = 700}
			{@const height = 400}
			{@const marginTop = 30}
			{@const marginRight = 0}
			{@const marginBottom = 80}
			{@const marginLeft = 50}
			{@const xScale = d3
				.scaleBand()
				.domain(fakeData.map((_, i) => i))
				.range([marginLeft, width - marginRight])
				.padding(0.1)}
			{@const yScale = d3
				.scaleLinear()
				.domain([0, d3.max(fakeData, (d) => d.prob)])
				.range([height - marginBottom, marginTop])}

			<svg
				{width}
				{height}
				viewBox="0 0 {width} {height}"
				style:max-width="100%"
				style:height="auto"
				class="select-none"
			>
				<g>
					{#each fakeData as d, i}
						<rect
							x={xScale(i)}
							y={yScale(d.prob)}
							height={yScale(0) - yScale(d.prob)}
							width={xScale.bandwidth()}
							fill={i < 3 ? '#f59e0b' : '#3b82f6'}
						/>
					{/each}
				</g>

				<g transform="translate(0,{height - marginBottom})">
					<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

					{#each fakeData as d, i}
						<line
							stroke="currentColor"
							x1={xScale(i) + xScale.bandwidth() / 2}
							x2={xScale(i) + xScale.bandwidth() / 2}
							y1={0}
							y2={6}
						/>

						<text
							fill="currentColor"
							text-anchor="middle"
							x={xScale(i) + xScale.bandwidth() / 2}
							y={22}
							class="text-sm"
						>
							{d.token}
						</text>
					{/each}

					<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
				</g>

				<g transform="translate({marginLeft},0)">
					<line stroke="currentColor" x1={0} x2={0} y1={marginTop} y2={height - marginBottom} />

					{#each yScale.ticks() as tick}
						{#if tick !== 0}
							<line stroke="currentColor" x1={0} x2={-6} y1={yScale(tick)} y2={yScale(tick)} />
						{/if}

						<text
							fill="currentColor"
							text-anchor="end"
							dominant-baseline="middle"
							x={-9}
							y={yScale(tick)}
							class="text-sm"
						>
							{Math.trunc(tick * 100)}
						</text>
					{/each}

					<text fill="currentColor" text-anchor="start" x={-marginLeft} y={15} class="text-lg">
						probability (%)
					</text>
				</g>
			</svg>
		</Slide>

		<Slide
			topText="In many cases, we actually want to only sample from tokens that make up the top X% of the
					distribution."
		></Slide>

		<Slide
			topText="Enter top-p!"
			bottomText="<strong>Top-p</strong> sampling (also called <strong>nucleus sampling</strong>) selects from the smallest possible set of
					tokens whose cumulative probability exceeds the threshold p. This maintains diversity
					while avoiding unlikely tokens."
		>
			<div class="relative">
				{#if topPProbabilities}
					{@const topProbs = topPProbabilities.slice(0, topN)}
					{@const width = 700}
					{@const height = 400}
					{@const marginTop = 30}
					{@const marginRight = 0}
					{@const marginBottom = 80}
					{@const marginLeft = 50}
					{@const xScale = d3
						.scaleBand()
						.domain(topProbs.map((_, i) => i))
						.range([marginLeft, width - marginRight])
						.padding(0.1)}
					{@const yScale = d3
						.scaleLinear()
						.domain([0, d3.max(topProbs, (d) => d[1])])
						.range([height - marginBottom, marginTop])}

					{@const runningTotal = topProbs.reduce(
						(acc, curr, i) => {
							acc.sum += curr[1];
							if (acc.count === 0 && acc.sum >= topP) acc.count = i + 1;
							return acc;
						},
						{ sum: 0, count: 0 }
					)}
					{@const numTopPTokens = runningTotal.count || topProbs.length}

					<svg
						{width}
						{height}
						viewBox="0 0 {width} {height}"
						style:max-width="100%"
						style:height="auto"
						class="select-none"
					>
						<g>
							{#each topProbs as [token, prob], i}
								<rect
									x={xScale(i)}
									y={yScale(prob)}
									height={yScale(0) - yScale(prob)}
									width={xScale.bandwidth()}
									fill={i < numTopPTokens ? '#f59e0b' : '#3b82f6'}
								/>
							{/each}
						</g>

						<g transform="translate(0,{height - marginBottom})">
							<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

							{#each topProbs as [token, prob], i}
								<line
									stroke="currentColor"
									x1={xScale(i) + xScale.bandwidth() / 2}
									x2={xScale(i) + xScale.bandwidth() / 2}
									y1={0}
									y2={6}
								/>

								<text
									fill="currentColor"
									text-anchor="middle"
									x={xScale(i) + xScale.bandwidth() / 2}
									y={22}
									class="text-sm"
								>
									{token.length > 8 ? token.slice(0, 8) + '...' : token}
								</text>
							{/each}

							<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
						</g>

						<g transform="translate({marginLeft},0)">
							<line stroke="currentColor" x1={0} x2={0} y1={marginTop} y2={height - marginBottom} />

							{#each yScale.ticks() as tick}
								{#if tick !== 0}
									<line stroke="currentColor" x1={0} x2={-6} y1={yScale(tick)} y2={yScale(tick)} />
								{/if}

								<text
									fill="currentColor"
									text-anchor="end"
									dominant-baseline="middle"
									x={-9}
									y={yScale(tick)}
									class="text-sm"
								>
									{Math.trunc(tick * 100)}
								</text>
							{/each}

							<text fill="currentColor" text-anchor="start" x={-marginLeft} y={15} class="text-lg">
								probability (%)
							</text>
						</g>
					</svg>

					<div
						class="absolute z-10 top-2 right-2 bg-gray-800 p-4 pb-6 rounded-lg flex flex-col gap-y-4 items-center justify-center"
					>
						<div class="block text-sm font-medium text-white">
							Top-p: <span class="font-bold mono">{topP.toFixed(2)}</span>
						</div>
						<input
							type="range"
							min="0.1"
							max="1.0"
							step="0.01"
							bind:value={topP}
							class="w-full h-2 bg-gray-700 rounded-lg appearance-none cursor-pointer"
						/>
					</div>
				{/if}
			</div>
		</Slide>

		<!-- Transition slide: text only, shown right before the interactive explorer slide. -->
		<Slide
			bottomText="On the next slide, you can generate and sample tokens using the GPT-2 small model."
			topText="Now, it's your turn to explore!"
		/>

		<Slide topText="<div class='w-full'>Explore sampling with GPT-2!</div>">
			<div class="flex flex-col items-center gap-y-8">
				<!--
					Prompt row: a text field bound to `explorerInput` and a button that calls
					`generateExplorerDistribution`. The button is disabled while `sampleLoading` is set.
				-->
				<div class="flex items-center w-[600px] rounded-lg">
					<!-- A placeholder is not an accessible name, so the field is labelled explicitly. -->
					<input
						aria-label="Prompt text"
						bind:value={explorerInput}
						class="flex-1 px-3 py-2 text-black rounded-l-lg border-2 border-white focus:outline-none text-xl border-r-0"
						placeholder="Enter text here..."
					/>
					<button
						type="button"
						onclick={generateExplorerDistribution}
						disabled={sampleLoading}
						class="px-3 py-2 bg-blue-500 hover:bg-blue-600 disabled:opacity-60 disabled:cursor-not-allowed border-2 border-white text-white font-semibold rounded-r-lg transition-colors duration-200 text-xl border-l-0"
					>
						View Distribution
					</button>
				</div>
				{#if explorerProbabilities}
					<!--
						Chart geometry, scales and highlight counts for the first `topN` entries of
						`explorerProbabilities`. Each entry is read as a [token, probability] pair.
					-->
					{@const topProbs = explorerProbabilities.slice(0, topN)}
					{@const height = 400}
					{@const width = 700}
					{@const marginLeft = 50}
					{@const marginRight = 0}
					{@const marginTop = 30}
					{@const marginBottom = 80}
					{@const barIndices = topProbs.map((_, i) => i)}
					{@const xScale = d3
						.scaleBand()
						.domain(barIndices)
						.range([marginLeft, width - marginRight])
						.padding(0.1)}
					{@const peakProb = d3.max(topProbs, (entry) => entry[1])}
					{@const yScale = d3
						.scaleLinear()
						.domain([0, peakProb])
						.range([height - marginBottom, marginTop])}

					<!-- Running sums of the probabilities, accumulated left to right. -->
					{@const cumulativeProbs = topProbs.reduce(
						(sums, entry) => [...sums, (sums.length ? sums[sums.length - 1] : 0) + entry[1]],
						[]
					)}
					<!-- Index of the first bar whose running sum reaches `explorerTopP`, or -1 if none does. -->
					{@const nucleusCutoff = cumulativeProbs.findIndex((total) => total >= explorerTopP)}
					{@const numTopPTokens =
						useTopP && nucleusCutoff !== -1 ? nucleusCutoff + 1 : topProbs.length}
					{@const numTopKTokens = useTopK ? explorerTopK : topProbs.length}
					<!-- When both filters are enabled, the smaller count decides how many bars are highlighted. -->
					{@const highlightedTokens = Math.min(numTopPTokens, numTopKTokens)}

					<div class="w-full relative">
						<!--
							Bar chart of the entries in `topProbs`. A bar is green when its index equals
							`explorerFlashIndex`, amber while its index is below `highlightedTokens`, and
							blue otherwise.
						-->
						<svg
							{width}
							{height}
							viewBox="0 0 {width} {height}"
							style:max-width="100%"
							style:height="auto"
							class="select-none"
						>
							<!-- Bars: one rect per entry, positioned by index and scaled by probability. -->
							<g>
								{#each topProbs as [token, prob], i}
									<rect
										x={xScale(i)}
										y={yScale(prob)}
										height={yScale(0) - yScale(prob)}
										width={xScale.bandwidth()}
										fill={i === explorerFlashIndex
											? '#22c55e'
											: i < highlightedTokens
												? '#f59e0b'
												: '#3b82f6'}
									/>
								{/each}
							</g>

							<!-- X axis: baseline, one tick per bar, and the token text (cut to 8 characters). -->
							<g transform="translate(0,{height - marginBottom})">
								<line stroke="currentColor" x1={marginLeft - 6} x2={width} />

								{#each topProbs as [token, prob], i}
									<line
										stroke="currentColor"
										x1={xScale(i) + xScale.bandwidth() / 2}
										x2={xScale(i) + xScale.bandwidth() / 2}
										y1={0}
										y2={6}
									/>

									<text
										fill="currentColor"
										text-anchor="middle"
										x={xScale(i) + xScale.bandwidth() / 2}
										y={22}
										class="text-sm"
									>
										{token.length > 8 ? token.slice(0, 8) + '...' : token}
									</text>
								{/each}

								<text fill="currentColor" x={width / 2} y={65} class="text-lg">tokens</text>
							</g>

							<!-- Y axis: axis line, tick marks (none drawn at 0) and labels in percent. -->
							<g transform="translate({marginLeft},0)">
								<line
									stroke="currentColor"
									x1={0}
									x2={0}
									y1={marginTop}
									y2={height - marginBottom}
								/>

								{#each yScale.ticks() as tick}
									{#if tick !== 0}
										<line
											stroke="currentColor"
											x1={0}
											x2={-6}
											y1={yScale(tick)}
											y2={yScale(tick)}
										/>
									{/if}

									<text
										fill="currentColor"
										text-anchor="end"
										dominant-baseline="middle"
										x={-9}
										y={yScale(tick)}
										class="text-sm"
									>
										<!--
											Round away floating-point noise instead of truncating, so that
											sub-percent ticks (e.g. 0.005 -> 0.5) keep distinct labels rather
											than collapsing to repeated integers.
										-->
										{parseFloat((tick * 100).toPrecision(12))}
									</text>
								{/each}

								<text
									fill="currentColor"
									text-anchor="start"
									x={-marginLeft}
									y={15}
									class="text-lg"
								>
									probability (%)
								</text>
							</g>
						</svg>

						<!--
							Sampling controls overlay: a checkbox plus slider for each of temperature,
							top-k and top-p, and a Sample button. Each slider is disabled while its
							checkbox is unchecked and calls `updateExplorerDistribution` on input.
							Captions are real labels tied to the checkbox ids, and sliders carry an
							aria-label, so every control has an accessible name.
						-->
						<div
							class="absolute grid grid-cols-2 z-10 top-2 right-2 bg-gray-800 p-3 rounded-lg gap-4"
						>
							<div class="flex flex-col items-center justify-center w-full gap-y-2">
								<div class="flex items-center gap-x-1 w-full">
									<!--
										NOTE(review): unlike the Top-K and Top-P checkboxes, this one has no
										onchange handler. Confirm the distribution is refreshed elsewhere
										when it is toggled.
									-->
									<input
										type="checkbox"
										bind:checked={useTemperature}
										class="mr-1"
										id="temp-checkbox"
									/>
									<label for="temp-checkbox" class="block text-sm font-medium text-white">
										Temperature: <span class="font-bold mono">{explorerTemperature.toFixed(1)}</span>
									</label>
								</div>
								<input
									type="range"
									min="0.1"
									max="1.5"
									step="0.1"
									aria-label="Temperature"
									bind:value={explorerTemperature}
									disabled={!useTemperature}
									oninput={updateExplorerDistribution}
									class="w-full h-2 bg-gray-700 rounded-lg appearance-none cursor-pointer disabled:opacity-50"
								/>
							</div>

							<div class="flex flex-col items-center justify-center w-full gap-y-2">
								<div class="flex items-center gap-x-1 w-full">
									<input
										type="checkbox"
										bind:checked={useTopK}
										onchange={handleTopKChange}
										class="mr-1"
										id="topk-checkbox"
									/>
									<label for="topk-checkbox" class="block text-sm font-medium text-white">
										Top-K: <span class="font-bold mono">{explorerTopK}</span>
									</label>
								</div>
								<input
									type="range"
									min="1"
									max="10"
									step="1"
									aria-label="Top-K"
									bind:value={explorerTopK}
									disabled={!useTopK}
									oninput={updateExplorerDistribution}
									class="w-full h-2 bg-gray-700 rounded-lg appearance-none cursor-pointer disabled:opacity-50"
								/>
							</div>

							<div class="flex flex-col items-center justify-center w-full gap-y-2">
								<div class="flex items-center gap-x-1 w-full">
									<input
										type="checkbox"
										bind:checked={useTopP}
										onchange={handleTopPChange}
										class="mr-1"
										id="topp-checkbox"
									/>
									<label for="topp-checkbox" class="block text-sm font-medium text-white">
										Top-P: <span class="font-bold mono">{explorerTopP.toFixed(2)}</span>
									</label>
								</div>
								<input
									type="range"
									min="0.1"
									max="1.0"
									step="0.01"
									aria-label="Top-P"
									bind:value={explorerTopP}
									disabled={!useTopP}
									oninput={updateExplorerDistribution}
									class="w-full h-2 bg-gray-700 rounded-lg appearance-none cursor-pointer disabled:opacity-50"
								/>
							</div>

							<!-- Shows a spinner in place of the caption, and is disabled, while `sampleLoading` is set. -->
							<button
								type="button"
								disabled={sampleLoading}
								class="bg-green-600 hover:bg-green-700 disabled:opacity-60 disabled:cursor-not-allowed text-white font-semibold rounded-lg transition-colors duration-200 py-1 px-2 border-2 border-white text-lg flex items-center justify-center gap-2 min-h-[46px]"
								onclick={sampleFromExplorerDistribution}
							>
								{#if sampleLoading}
									<svg
										xmlns="http://www.w3.org/2000/svg"
										width="20"
										height="20"
										viewBox="0 0 24 24"
										fill="none"
										stroke="currentColor"
										stroke-width="2"
										stroke-linecap="round"
										stroke-linejoin="round"
										class="icon icon-tabler icons-tabler-outline icon-tabler-loader-2 animate-spin"
										><path stroke="none" d="M0 0h24v24H0z" fill="none" /><path
											d="M12 3a9 9 0 1 0 9 9"
										/></svg
									>
								{:else}
									Sample
								{/if}
							</button>
						</div>
					</div>
					<div class="text-2xl max-w-[650px]">
						Try experimenting with different sampling parameters to see how they affect the
						resulting probability distribution. Click <span class="font-bold">Sample</span> to pick the
						next token from the distribution!
					</div>
				{:else}
					<div class="text-2xl max-w-[650px]">
						Modify the above input text to your liking, and click <span class="font-bold"
							>View Distribution</span
						> to visualize the next token probability distribution.
					</div>
				{/if}
			</div>
		</Slide>
	</div>
</div>

{#if !showSlides}
	<!-- Landing screen, rendered only while `showSlides` is false. -->
	<div class="w-[100svw] h-[100svh] flex flex-col items-center justify-center text-white gap-y-7">
		<h1 class="text-7xl font-medium text-center max-w-[700px] leading-tighter">
			The Art of Picking the <span class="text-amber-400 font-bold">Next Token</span>
		</h1>
		<h2 class="text-2xl text-center max-w-[650px] mb-3">
			How do large language models select the next token in a sequence? This interactive
			"scrollytelling" experience teaches you about the wonderful world of <span class="font-bold"
				>sampling</span
			>!
		</h2>
		<!--
			Starts `loadModelAndComputeOutputs` and shows a spinner while it is pending.
			`showLoading` is cleared in `finally` so the button is re-enabled even if the
			load rejects; the rejection itself still propagates as before.
		-->
		<button
			onclick={async () => {
				showLoading = true;
				try {
					await loadModelAndComputeOutputs();
				} finally {
					showLoading = false;
				}
			}}
			type="button"
			disabled={showLoading}
			class="min-w-[300px] min-h-[60px] bg-blue-500 hover:bg-blue-600 disabled:opacity-60 disabled:cursor-not-allowed text-white font-semibold rounded-lg transition-colors duration-200 text-3xl border-4 border-white flex items-center justify-center"
		>
			{#if showLoading}
				<svg
					xmlns="http://www.w3.org/2000/svg"
					width="24"
					height="24"
					viewBox="0 0 24 24"
					fill="none"
					stroke="currentColor"
					stroke-width="2"
					stroke-linecap="round"
					stroke-linejoin="round"
					class="icon icon-tabler icons-tabler-outline icon-tabler-loader-2 size-8 animate-spin"
					><path stroke="none" d="M0 0h24v24H0z" fill="none" /><path d="M12 3a9 9 0 1 0 9 9" /></svg
				>
			{:else}
				Load experience
			{/if}
		</button>
	</div>
{/if}
