# Baseline Models Training

This directory contains `train_baseline_models.py`, the training script for the baseline models (TimesNet, iTransformer, PatchTST) on multi-task classification.

## Models Supported

- **TimesNet**: General-purpose time series analysis model (temporal 2D-variation modeling), used here for classification
- **iTransformer**: Inverted Transformer for time series
- **PatchTST**: Patch-based Time Series Transformer

## Usage

### Basic Usage

```bash
cd baseline
python3 train_baseline_models.py --model TimesNet --gpu 0 --epochs 20
```

### Arguments

- `--model`: Model to train (required): `TimesNet`, `iTransformer`, or `PatchTST`
- `--gpu`: GPU ID (default: `0`)
- `--epochs`: Number of training epochs (default: `20`)
- `--lr`: Learning rate (default: `1e-4`)
- `--batch_size`: Batch size (default: `256`)
- `--num_workers`: Number of data loading workers (default: `4`)
- `--experiment_name`: Experiment name (default: auto-generated)
- `--max_samples`: Maximum number of training samples (`0` uses the full dataset; default: `80000`)
- `--max_val_samples`: Maximum number of validation samples (`0` uses the full dataset; default: `10000`)
- `--max_test_samples`: Maximum number of test samples (`0` uses the full dataset; default: `10000`)
- `--log_dir`: TensorBoard log directory (default: `runs/classify` under the release directory)
- `--config`: Path to config file (default: auto-detect)
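
The argument list above could map onto an `argparse` parser along the following lines. This is a minimal sketch built only from the documented flags and defaults; the function name `build_parser`, the argument types, and the way `0` is turned into "no limit" are assumptions, not the script's actual code.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="Train a baseline model (sketch)")
    parser.add_argument("--model", required=True,
                        choices=["TimesNet", "iTransformer", "PatchTST"])
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=4)
    # None stands in for "auto-generated" / "auto-detect" (assumption).
    parser.add_argument("--experiment_name", default=None)
    parser.add_argument("--max_samples", type=int, default=80000)
    parser.add_argument("--max_val_samples", type=int, default=10000)
    parser.add_argument("--max_test_samples", type=int, default=10000)
    parser.add_argument("--log_dir", default="runs/classify")
    parser.add_argument("--config", default=None)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is testable.
args = build_parser().parse_args(["--model", "PatchTST", "--max_samples", "0"])

# Assumed convention: 0 means "use the full dataset", represented here as None.
train_limit = args.max_samples or None
```

Restricting `--model` with `choices` makes a misspelled model name fail at parse time rather than partway through setup.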

### Offline Data Support

The script automatically uses offline data when it is available:

- It checks for `data_dir/sample_data/sample_data.pkl` in the release directory
- If the file is found, it uses the offline configuration
- Otherwise, it falls back to database mode
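
The detection steps above can be sketched as a single file-existence check. The function name `detect_data_mode`, its `release_dir` parameter, and the returned mode strings are hypothetical; only the relative path comes from this README.

```python
from pathlib import Path


def detect_data_mode(release_dir):
    """Return "offline" if the sample data file exists, else "database".

    Illustrative only: the real script may apply additional checks.
    """
    offline_file = Path(release_dir) / "data_dir" / "sample_data" / "sample_data.pkl"
    return "offline" if offline_file.is_file() else "database"
```

For example, `detect_data_mode(".")` would report which mode a run started from the current directory would pick under these assumptions.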

### Example Commands

```bash
# Train TimesNet with default settings
python3 train_baseline_models.py --model TimesNet

# Train iTransformer with custom settings
python3 train_baseline_models.py --model iTransformer --epochs 50 --batch_size 128 --lr 5e-5

# Train PatchTST with limited samples for quick testing
python3 train_baseline_models.py --model PatchTST --max_samples 1000 --max_val_samples 200
```

## Output

With the default `--log_dir`, each run writes to `runs/classify/{experiment_name}/`:

- Training logs: `runs/classify/{experiment_name}/`
- Best model: `runs/classify/{experiment_name}/best_model.pth`
- Test results: `runs/classify/{experiment_name}/test_results.json`
- Final report: `runs/classify/{experiment_name}/final_report.json`
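
To locate these artifacts from another script, the layout above can be expressed as a small path helper. The helper name `experiment_outputs` and its dictionary keys are hypothetical; the file names are the ones listed above. The contents of the JSON files are not documented here, so the sketch stops at building paths.

```python
from pathlib import Path


def experiment_outputs(experiment_name, log_dir="runs/classify"):
    """Build the expected output paths for one experiment (illustrative)."""
    run_dir = Path(log_dir) / experiment_name
    return {
        "logs": run_dir,
        "best_model": run_dir / "best_model.pth",
        "test_results": run_dir / "test_results.json",
        "final_report": run_dir / "final_report.json",
    }
```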

## Tasks

The script trains on two tasks:
- `mortality_24h_48h`: Mortality prediction within 24-48 hours
- `los_prediction_48h`: Length of stay prediction at 48 hours

## Notes

- The script automatically calculates positive class weights from training data
- Gradient clipping is applied (`max_norm=2.0`) for training stability
- All debug prints have been removed for cleaner output
- The script supports both offline and database data modes
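
The class-weighting and gradient-clipping notes above can be illustrated in plain Python. This is a sketch under stated assumptions: the README does not give the weighting formula, so the common negatives-to-positives ratio is assumed (the convention typically used for `pos_weight` in `torch.nn.BCEWithLogitsLoss`), and both function names are hypothetical. In PyTorch, norm-based clipping is usually done with `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)`; the second function only shows the underlying arithmetic.

```python
import math


def positive_class_weight(labels):
    """Return n_negative / n_positive for a sequence of 0/1 labels.

    Assumed formula; falls back to 1.0 when there are no positives.
    """
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0:
        return 1.0
    return n_neg / n_pos


def clip_by_global_norm(grads, max_norm=2.0):
    """Rescale gradient values so their L2 norm does not exceed max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)  # already within the limit, left unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads]


# One positive among four labels gives a weight of 3.0 for the positive class.
weight = positive_class_weight([1, 0, 0, 0])

# A gradient with norm 5.0 is scaled down to norm 2.0.
clipped = clip_by_global_norm([3.0, 4.0])
```

Clipping by the global norm preserves the gradient's direction and only shrinks its magnitude, which is why it is a common choice for stabilizing training.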

