{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "header",
   "metadata": {},
   "source": [
    "# MIDI-RWKV: Interactive Demo (Prototype)\n",
    "\n",
    "This notebook demonstrates the MIDI-RWKV model for symbolic music generation.\n",
    "\n",
    "This notebook is meant to be run in VS Code. We intend to release a Colab notebook in the near future to take advantage of its more user-friendly input methods.\n",
    "We are releasing this version first because we are not confident that we can maintain anonymity with cloud-based sharing.\n",
    "\n",
    "---\n",
    "\n",
    "## Table of Contents\n",
    "\n",
    "1. [Setup and Installation](#setup)\n",
    "2. [Model Conversion](#conversion)\n",
    "3. [Load Model and Tokenizer](#loading)\n",
    "4. [Bar Infilling Demo](#bar-infilling)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "setup-header",
   "metadata": {},
   "source": [
    "## 1. Setup and Installation <a id=\"setup\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "setup",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preliminary setup\n",
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "PROJECT_ROOT = Path.cwd()\n",
    "os.environ[\"PROJECT_ROOT\"] = str(PROJECT_ROOT)\n",
    "sys.path.insert(0, str(PROJECT_ROOT / \"rwkv.cpp\" / \"python\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "install-deps",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install dependencies\n",
    "%pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
    "%pip install miditok symusic transformers numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "conversion-header",
   "metadata": {},
   "source": [
    "## 2. Model Conversion <a id=\"conversion\"></a>\n",
    "\n",
    "The MIDI-RWKV model is distributed in PyTorch format (`.pth`). For efficient CPU inference, we convert it to GGML format using `rwkv.cpp`. This must first be built if it has not been already.\n",
    "\n",
    "This conversion only needs to be done once. The converted model will be saved at `rwkv.cpp/python/rwkv_cpp/rcpp.bin`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b04044f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build rwkv.cpp if necessary\n",
    "librwkv_is_built = any((PROJECT_ROOT / \"rwkv.cpp\").rglob(\"librwkv.so\"))\n",
    "if librwkv_is_built:\n",
    "    print(\"rwkv.cpp already built, skipping build process.\")\n",
    "else:\n",
    "    %pip install -q cmake\n",
    "    %cd {PROJECT_ROOT / \"rwkv.cpp\"}\n",
    "    !cmake .\n",
    "    !cmake --build . --config Release\n",
    "    %cd {PROJECT_ROOT}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "check-model",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert model if necessary\n",
    "pytorch_model = PROJECT_ROOT / \"midi_rwkv.pth\"\n",
    "ggml_model = PROJECT_ROOT / \"rwkv.cpp\" / \"python\" / \"rwkv_cpp\" / \"rcpp.bin\"\n",
    "\n",
    "assert pytorch_model.exists(), \"The model file (midi_rwkv.pth) doesn't seem to be in the right place. Please re-extract the supplemental.\"\n",
    "size_mb = pytorch_model.stat().st_size / (1024 * 1024)\n",
    "\n",
    "if not ggml_model.exists():\n",
    "    print(\"Converting PyTorch model to GGML format...\")\n",
    "    print(\"This may take a few minutes...\\n\")\n",
    "    \n",
    "    !python3 rwkv.cpp/python/convert_pytorch_to_ggml.py \\\n",
    "        midi_rwkv.pth \\\n",
    "        rwkv.cpp/python/rwkv_cpp/rcpp.bin \\\n",
    "        FP16\n",
    "    \n",
    "    print(\"\\nConversion complete!\")\n",
    "else:\n",
    "    print(\"GGML model already exists, skipping conversion.\")\n",
    "    size_mb = ggml_model.stat().st_size / (1024 * 1024)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "loading-header",
   "metadata": {},
   "source": [
    "## 3. Load Model and Tokenizer <a id=\"loading\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "import-libs",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model\n",
    "from miditok import MMM\n",
    "from symusic import Score, Synthesizer, dump_wav\n",
    "from transformers import GenerationConfig\n",
    "import numpy as np\n",
    "\n",
    "from rwkv_cpp.cpp_model import CustomGenerator, CppModelConfig\n",
    "from inference import generate\n",
    "from config import InferenceConfig\n",
    "\n",
    "TOK_PATH = str(PROJECT_ROOT / \"train\" / \"tokenizer\" / \"tokenizer_with_acs.json\")\n",
    "MODEL_PATH = str(ggml_model)\n",
    "INPUT_MIDI = str(PROJECT_ROOT / \"rwkv.cpp\" / \"python\" / \"mat\" / \"rollinggirl.mid\")\n",
    "\n",
    "print(\"Loading tokenizer...\")\n",
    "tokenizer = MMM(params=TOK_PATH)\n",
    "\n",
    "print(\"Loading RWKV model...\")\n",
    "config = CppModelConfig(MODEL_PATH, \"\")\n",
    "model = CustomGenerator(config, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bar-infilling-header",
   "metadata": {},
   "source": [
    "## 4. Bar Infilling Demo <a id=\"bar-infilling\"></a>\n",
    "\n",
    "This section demonstrates single-section infilling inference using attribute controls.\n",
    "\n",
    "### Attribute Controls\n",
    "\n",
    "We can control the generation using attribute controls (ACs). These allow fine-grained control over:\n",
    "- **Polyphony**: Min/max simultaneous notes\n",
    "- **Note Density**: Number of notes per bar\n",
    "- **Note Duration**: Distribution of note lengths (whole, half, quarter, eighth, sixteenth)\n",
    "\n",
    "Each bar can have different attributes for more nuanced control."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "408e0f84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGURATION: modify these variables\n",
    "INPUT_MIDI = \"sample.mid\"\n",
    "\n",
    "track_idx = 0  # Track to infill\n",
    "# Start and end of the infilling region\n",
    "start_bar = 14\n",
    "end_bar = 18\n",
    "\n",
    "context_length = 16  # Single-section infilling C (context length on either side of infilling region)\n",
    "\n",
    "# Sampling parameters\n",
    "gen_config = GenerationConfig(\n",
    "    num_beams=1,\n",
    "    temperature=1.0,\n",
    "    repetition_penalty=1.2,\n",
    "    top_k=20,\n",
    "    top_p=0.95,\n",
    "    max_new_tokens=400,\n",
    "    epsilon_cutoff=9e-4,\n",
    "    do_sample=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "462328ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set reasonable default attribute controls\n",
    "# These can be computed from the original music or entered by the composer\n",
    "# We do not have a good way of entering these manually in VSCode, but we will strive to make a Colab version that is much tidier in the near future\n",
    "default_controls = [\n",
    "    'ACBarOnsetPolyphonyMin_1',\n",
    "    'ACBarOnsetPolyphonyMax_3',\n",
    "    'ACBarNoteDensity_8',\n",
    "    'ACBarNoteDurationWhole_0',\n",
    "    'ACBarNoteDurationHalf_0',\n",
    "    'ACBarNoteDurationQuarter_1',\n",
    "    'ACBarNoteDurationEight_1',\n",
    "    'ACBarNoteDurationSixteenth_1'\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c56b181b",
   "metadata": {},
   "source": [
    "### A note about attribute controls.\n",
    "\n",
    "Since we train with all attribute controls on each bar, the model expects the full set of attribute controls to be provided for each bar. This requires some more effort on the part of the composer but allows more fine-grained control. Additionally, attribute controls can be inferred from existing content and injected if the composer does not wish to manually select attribute controls for each bar, or one control can be set for multiple bars.\n",
    "\n",
    "The generation may not be very good on the first try; it is expected that you will need to run it multiple times to get good results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "run-generation",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Generating...\\n\")\n",
    "\n",
    "output_score = generate(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    InferenceConfig(\n",
    "        bars_to_generate={\n",
    "            track_idx: [(start_bar, end_bar, [default_controls, default_controls], \"bar\")],\n",
    "            # For random infilling, add more track_idx: [list of bars] entries here\n",
    "        },\n",
    "        new_tracks=[],\n",
    "        context_length=context_length\n",
    "    ),\n",
    "    INPUT_MIDI,\n",
    "    {\"generation_config\": gen_config},\n",
    ")\n",
    "\n",
    "print(\"\\nGeneration complete!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "save-output",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save output MIDI and audio, plus reference baseline audio, to /demo_out\n",
    "os.makedirs(\"demo_out\", exist_ok=True)\n",
    "OUTPUT_PATH = str(PROJECT_ROOT / \"demo_out\" / \"output.mid\")\n",
    "output_score.dump_midi(OUTPUT_PATH)\n",
    "print(f\"Output MIDI saved to: {OUTPUT_PATH}\")\n",
    "\n",
    "print(\"Synthesizing audio...\")\n",
    "synth = Synthesizer()\n",
    "output_wav = synth.render(output_score, stereo=True)\n",
    "original_wav = synth.render(Score(INPUT_MIDI), stereo=True)\n",
    "OUTPUT_WAV = str(PROJECT_ROOT / \"demo_out\" / \"output.wav\")\n",
    "ORIGINAL_WAV = str(PROJECT_ROOT / \"demo_out\" / \"original.wav\")\n",
    "dump_wav(OUTPUT_WAV, output_wav, sample_rate=44100, use_int16=True)\n",
    "print(f\"Output audio saved to {OUTPUT_WAV}\")\n",
    "dump_wav(ORIGINAL_WAV, original_wav, sample_rate=44100, use_int16=True)\n",
    "print(f\"Original audio saved to {ORIGINAL_WAV}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
