# Retrieval-Based Pre-Training for Symbolic Regression

This repository contains the implementation of our proposed **retrieval-based pre-training mechanism** for symbolic regression, designed to enable **physics-aligned video forecasting** through interpretable motion equations. Our method is built upon the [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl) package in the Julia programming language.

---

## 🧠 Overview

Our framework first extracts object motion trajectories from video using CoTracker. It then applies symbolic regression, augmented with retrieval-based pre-training, to learn equations of motion from these trajectories. The learned equations can be used to forecast future motion and to guide trajectory-conditioned video generation models.
Below, videos generated by CogVideoX, Tora, Kling, and Kling-manual-draw are shown alongside the ground truth (GT) for seven physical systems.
<h4 align="center">Damped Spring System</h4>
<table align="center">
  <tr>
    <td align="center">
      <img src="demos/CogVideoX/cogvideo_damped_spring_mass.gif" width="90"/><br/>
      <sub>CogVideoX</sub>
    </td>
    <td align="center">
      <img src="demos/Tora/Tora_damped_spring_mass.gif" width="90"/><br/>
      <sub>Tora</sub>
    </td>
    <td align="center">
      <img src="demos/Kling/kling_damped_spring_mass.gif" width="90"/><br/>
      <sub>Kling</sub>
    </td>
    <td align="center">
      <img src="demos/Kling_manual_draw/kling_draw_damped_spring_mass.gif" width="90"/><br/>
      <sub>Kling-manual-draw</sub>
    </td>
    <td align="center">
      <img src="demos/GT/damped_spring_mass.gif" width="90"/><br/>
      <sub>GT</sub>
    </td>
  </tr>
</table>

<h4 align="center">Spring System</h4>
<table align="center">
  <tr>
    <td align="center">
      <img src="demos/CogVideoX/cogvideo_spring_mass.gif" width="90"/><br/>
      <sub>CogVideoX</sub>
    </td>
    <td align="center">
      <img src="demos/Tora/Tora_spring_mass.gif" width="90"/><br/>
      <sub>Tora</sub>
    </td>
    <td align="center">
      <img src="demos/Kling/kling_spring_mass.gif" width="90"/><br/>
      <sub>Kling</sub>
    </td>
    <td align="center">
      <img src="demos/Kling_manual_draw/kling_draw_spring_mass.gif" width="90"/><br/>
      <sub>Kling-manual-draw</sub>
    </td>
    <td align="center">
      <img src="demos/GT/spring_mass.gif" width="90"/><br/>
      <sub>GT</sub>
    </td>
  </tr>
</table>

<h4 align="center">Single Pendulum</h4>
<table align="center">
  <tr>
    <td align="center">
      <img src="demos/CogVideoX/cogvideo_single_pendulum.gif" width="90"/><br/>
      <sub>CogVideoX</sub>
    </td>
    <td align="center">
      <img src="demos/Tora/Tora_single_pendulum.gif" width="90"/><br/>
      <sub>Tora</sub>
    </td>
    <td align="center">
      <img src="demos/Kling/kling_single_pendulum.gif" width="90"/><br/>
      <sub>Kling</sub>
    </td>
    <td align="center">
      <img src="demos/Kling_manual_draw/kling_draw_single_pendulum.gif" width="90"/><br/>
      <sub>Kling-manual-draw</sub>
    </td>
    <td align="center">
      <img src="demos/GT/single_pendulum.gif" width="90"/><br/>
      <sub>GT</sub>
    </td>
  </tr>
</table>

<h4 align="center">Double Pendulum</h4>
<table align="center">
  <tr>
    <td align="center">
      <img src="demos/CogVideoX/cogvideo_double_pendulum.gif" width="90"/><br/>
      <sub>CogVideoX</sub>
    </td>
    <td align="center">
      <img src="demos/Tora/Tora_double_pendulum.gif" width="90"/><br/>
      <sub>Tora</sub>
    </td>
    <td align="center">
      <img src="demos/Kling/kling_double_pendulum.gif" width="90"/><br/>
      <sub>Kling</sub>
    </td>
    <td align="center">
      <img src="demos/Kling_manual_draw/kling_draw_double_pendulum.gif" width="90"/><br/>
      <sub>Kling-manual-draw</sub>
    </td>
    <td align="center">
      <img src="demos/GT/double_pendulum.gif" width="90"/><br/>
      <sub>GT</sub>
    </td>
  </tr>
</table>

<h4 align="center">Projectile Motion</h4>
<table align="center">
  <tr>
    <td align="center">
      <img src="demos/CogVideoX/cogvideo_projectile.gif" width="90"/><br/>
      <sub>CogVideoX</sub>
    </td>
    <td align="center">
      <img src="demos/Tora/Tora_projectile.gif" width="90"/><br/>
      <sub>Tora</sub>
    </td>
    <td align="center">
      <img src="demos/Kling/kling_projectile.gif" width="90"/><br/>
      <sub>Kling</sub>
    </td>
    <td align="center">
      <img src="demos/Kling_manual_draw/kling_draw_projectile.gif" width="90"/><br/>
      <sub>Kling-manual-draw</sub>
    </td>
    <td align="center">
      <img src="demos/GT/projectile.gif" width="90"/><br/>
      <sub>GT</sub>
    </td>
  </tr>
</table>

<h4 align="center">Fluid Motion</h4>
<table align="center">
  <tr>
    <td align="center">
      <img src="demos/CogVideoX/cogvideo_flow.gif" width="90"/><br/>
      <sub>CogVideoX</sub>
    </td>
    <td align="center">
      <img src="demos/Tora/Tora_flow.gif" width="90"/><br/>
      <sub>Tora</sub>
    </td>
    <td align="center">
      <img src="demos/Kling/kling_flow.gif" width="90"/><br/>
      <sub>Kling</sub>
    </td>
    <td align="center">
      <img src="demos/Kling_manual_draw/kling_draw_flow.gif" width="90"/><br/>
      <sub>Kling-manual-draw</sub>
    </td>
    <td align="center">
      <img src="demos/GT/flow.gif" width="90"/><br/>
      <sub>GT</sub>
    </td>
  </tr>
</table>

<h4 align="center">Two-Body System</h4>
<table align="center">
  <tr>
    <td align="center">
      <img src="demos/CogVideoX/cogvideo_two_body.gif" width="90"/><br/>
      <sub>CogVideoX</sub>
    </td>
    <td align="center">
      <img src="demos/Tora/Tora_two_body.gif" width="90"/><br/>
      <sub>Tora</sub>
    </td>
    <td align="center">
      <img src="demos/Kling/kling_two_body.gif" width="90"/><br/>
      <sub>Kling</sub>
    </td>
    <td align="center">
      <img src="demos/Kling_manual_draw/kling_draw_two_body.gif" width="90"/><br/>
      <sub>Kling-manual-draw</sub>
    </td>
    <td align="center">
      <img src="demos/GT/two_body.gif" width="90"/><br/>
      <sub>GT</sub>
    </td>
  </tr>
</table>
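
As an illustration of the kind of expression this pipeline targets: if the damped spring system shown above is modeled with linear damping, $m\ddot{x} + c\dot{x} + kx = 0$, the textbook closed-form solution in the underdamped case is

$$
x(t) = A\,e^{-\gamma t}\cos(\omega_d t + \phi), \qquad \gamma = \frac{c}{2m}, \qquad \omega_d = \sqrt{\frac{k}{m} - \gamma^2}.
$$

Fitting $A$, $\gamma$, $\omega_d$, and $\phi$ to an extracted trajectory and evaluating $x(t)$ beyond the observed time window yields a forecast. This is only an example; the expressions actually recovered depend on the data and on the equation bank.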

---

## 📁 Repository Structure

```
.
├── SymbolicRegression.jl/        # Modified SymbolicRegression package with retrieval-based initialization
├── equation_bank/                # Constructed equation bank containing physics-relevant symbolic expressions
├── demos/                        # Demo videos (GIFs) shown in the Overview
├── learn_equations_ReSR.jl       # Main script to run symbolic regression with retrieval-based pre-training
├── extract_traj_co_tracker.py    # Script to extract motion trajectories using CoTracker
└── README.md                     # You're here!
```

---

## 🔧 Dependencies

- **Julia (>=1.11)**
  - [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl) (modified version provided)
- **Python (>=3.10)**
  - `torch`, `numpy`, `opencv-python`, and the other packages required by CoTracker

---

## 📌 Trajectory Extraction

To extract object motion trajectories from a video:

1. Clone and install the [CoTracker](https://github.com/facebookresearch/co-tracker) repository.
2. Run the provided script:
   ```bash
   python extract_traj_co_tracker.py --url path/to/video.mp4 --save_dir path/to/dir
   ```
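
The output format of `extract_traj_co_tracker.py` is not documented here. CoTracker predicts an (x, y) position and a visibility flag for each tracked point in each frame. As a minimal sketch, assuming those have been loaded as plain per-frame Python lists, the points on one object can be collapsed into a single time series suitable for symbolic regression. `centroid_trajectory` is a hypothetical helper, not part of this repository, and averaging visible points is just one simple choice of object position.

```python
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float]


def centroid_trajectory(
    tracks: Sequence[Sequence[Point]],
    visible: Sequence[Sequence[bool]],
    fps: float,
    image_height: Optional[float] = None,
) -> List[Tuple[float, float, float]]:
    """Collapse tracked points into one (t, x, y) sample per frame (hypothetical helper).

    tracks[f][n] is the (x, y) pixel position of point n in frame f and
    visible[f][n] is its visibility flag. Frames with no visible point are
    skipped. If image_height is given, y is flipped so positive y points up,
    as in most physics conventions (image rows grow downward).
    """
    samples = []
    for frame, (points, flags) in enumerate(zip(tracks, visible)):
        kept = [p for p, v in zip(points, flags) if v]
        if not kept:
            continue
        x = sum(p[0] for p in kept) / len(kept)
        y = sum(p[1] for p in kept) / len(kept)
        if image_height is not None:
            y = image_height - y
        samples.append((frame / fps, x, y))
    return samples


# Example: two tracked points over three frames at 10 fps in a 100-px-tall image.
tracks = [[(10.0, 20.0), (14.0, 24.0)], [(12.0, 18.0), (16.0, 22.0)], [(0.0, 0.0), (0.0, 0.0)]]
visible = [[True, True], [True, False], [False, False]]
print(centroid_trajectory(tracks, visible, fps=10.0, image_height=100))
# [(0.0, 12.0, 78.0), (0.1, 12.0, 82.0)]
```

Frames in which every point is occluded are dropped rather than interpolated, so the resulting time samples may be unevenly spaced.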

---

## 🧮 Learning Symbolic Equations

To run symbolic regression with our retrieval-based pre-training mechanism:

```bash
julia learn_equations_ReSR.jl
```

- The script retrieves initial equations from `equation_bank/` and uses them as candidate expressions to guide the symbolic regression search.
- Before running, make sure the paths to your input trajectory data and to the equation bank are set correctly in the script.
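
The retrieval criterion used by `learn_equations_ReSR.jl` is not described in this README. As a minimal sketch, assuming a simple score (how well each bank expression explains the trajectory after fitting a scale and an offset by least squares), retrieval could rank a bank and keep the top-k entries as seeds. `EQUATION_BANK`, `fit_affine`, and `retrieve` are hypothetical names. The real bank format, the real scoring, and the way seeds are passed to the modified SymbolicRegression.jl may all differ.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical in-memory bank: name -> basis expression g(t).
# Nonlinear constants (2.0, 0.5) are fixed for illustration; in the full
# pipeline the symbolic regression search would refine structure and constants.
EQUATION_BANK: Dict[str, Callable[[float], float]] = {
    "linear": lambda t: t,
    "quadratic": lambda t: t * t,
    "sine": lambda t: math.sin(2.0 * t),
    "damped_cosine": lambda t: math.exp(-0.5 * t) * math.cos(2.0 * t),
}


def fit_affine(g: Sequence[float], y: Sequence[float]) -> Tuple[float, float, float]:
    """Closed-form least squares for y ~ a * g + b. Returns (a, b, mse)."""
    n = len(y)
    mean_g = sum(g) / n
    mean_y = sum(y) / n
    var_g = sum((gi - mean_g) ** 2 for gi in g)
    cov_gy = sum((gi - mean_g) * (yi - mean_y) for gi, yi in zip(g, y))
    a = cov_gy / var_g if var_g > 0 else 0.0  # constant basis: fit the offset only
    b = mean_y - a * mean_g
    mse = sum((a * gi + b - yi) ** 2 for gi, yi in zip(g, y)) / n
    return a, b, mse


def retrieve(
    ts: Sequence[float],
    ys: Sequence[float],
    bank: Dict[str, Callable[[float], float]],
    k: int = 2,
) -> List[Tuple[str, float, float, float]]:
    """Rank bank entries by fitted MSE on (ts, ys); return the k best as (name, a, b, mse)."""
    scored = []
    for name, g in bank.items():
        a, b, mse = fit_affine([g(t) for t in ts], ys)
        scored.append((mse, name, a, b))
    scored.sort()  # lowest error first
    return [(name, a, b, mse) for mse, name, a, b in scored[:k]]


# Example: vertical position of a projectile, y(t) = 3.0 - 4.9 t^2, sampled at 10 Hz.
ts = [0.1 * i for i in range(50)]
ys = [3.0 - 4.9 * t * t for t in ts]
for name, a, b, mse in retrieve(ts, ys, EQUATION_BANK):
    print(f"{name}: y = {a:.3f} * g(t) + {b:.3f}  (mse={mse:.2e})")
```

On this noise-free projectile data the `quadratic` entry ranks first with a scale of about -4.9 and an offset of about 3.0. A score this cheap only pre-filters the bank; the symbolic regression search still has to refine or replace the retrieved seeds.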

---

## 🎥 Trajectory-Guided Video Generation

We feed our predicted trajectories into external trajectory-conditioned image-to-video (I2V) models. To reproduce the video generation results, please refer to their official repositories:

- [Kling (commercial API)](https://app.klingai.com/cn/dev/document-api/quickStart/userManual)
- [Tora](https://github.com/alibaba/Tora)
- [SG-I2V](https://github.com/Kmcode1/SG-I2V)
- [MotionCtrl](https://github.com/TencentARC/MotionCtrl)
- [DragAnything](https://github.com/showlab/DragAnything)

Please follow each repository's instructions for setup and usage. Our predicted trajectories can then be converted into the trajectory input format that each model expects.
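
Each of these models defines its own trajectory format, which this README does not reproduce. As a minimal sketch, assuming a model that takes one integer (x, y) pixel position per generated frame, a learned equation of motion can be evaluated at frame times past the observed window (the forecast) and converted back to image coordinates. `to_pixel_trajectory` and its parameters are hypothetical; check each model's repository for its actual input format, frame count, and resolution.

```python
from typing import Callable, List, Tuple


def to_pixel_trajectory(
    x_of_t: Callable[[float], float],
    y_of_t: Callable[[float], float],
    t_start: float,
    fps: float,
    num_frames: int,
    width: int,
    height: int,
    flip_y: bool = True,
) -> List[Tuple[int, int]]:
    """Sample learned x(t), y(t) at frame times and return clipped integer pixel points.

    If flip_y is True, y(t) is assumed to point upward and is mapped back to
    image coordinates, where rows grow downward.
    """
    points = []
    for i in range(num_frames):
        t = t_start + i / fps
        x = x_of_t(t)
        y = height - y_of_t(t) if flip_y else y_of_t(t)
        px = min(max(int(round(x)), 0), width - 1)  # keep points inside the frame
        py = min(max(int(round(y)), 0), height - 1)
        points.append((px, py))
    return points


def learned_x(t: float) -> float:
    return 100.0 + 50.0 * t  # hypothetical learned x(t), in pixels


def learned_y(t: float) -> float:
    return 200.0 - 4.9 * t * t  # hypothetical learned y(t), y pointing up


# Forecast three frames at 10 fps, starting at t = 2.0 s, for a 640x480 video.
print(to_pixel_trajectory(learned_x, learned_y, t_start=2.0, fps=10.0,
                          num_frames=3, width=640, height=480))
# [(200, 300), (205, 302), (210, 304)]
```

Clipping keeps forecasts that leave the frame at the image border. Depending on the model, dropping or truncating those points may be the better choice.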

