# Ulysses Memory Issue Fix

## Problem
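When using Ulysses parallelism, FSDP allocates memory for the full model size on each rank instead of the Ulysses-sharded size. As a result, per-rank memory does not shrink as the Ulysses degree grows, and actual usage exceeds the expected sharded footprint by a factor equal to the Ulysses degree.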

## Root Cause
The `dp_shard_cp` mesh used by FSDP only includes `dp_shard` and `cp_ring` dimensions, but **not** `cp_ulysses`. FSDP is unaware that attention heads should be distributed across Ulysses ranks.
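
To make the effect concrete, the sketch below computes how many ranks FSDP shards across, given the mesh dimensions it can see. The helper `fsdp_shard_degree` and the example sizes are hypothetical and chosen only for illustration; they are not taken from the codebase.

```python
from math import prod
from typing import Dict, List


def fsdp_shard_degree(dim_sizes: Dict[str, int], fsdp_dims: List[str]) -> int:
    """Number of ranks FSDP shards across: the product of the mesh dims it sees."""
    return prod(dim_sizes[name] for name in fsdp_dims)


# Hypothetical layout: 2-way dp_shard, 2-way ring attention, 4-way Ulysses.
sizes = {"dp_shard": 2, "cp_ring": 2, "cp_ulysses": 4}

# Current behavior: cp_ulysses is missing from the dp_shard_cp mesh.
current_degree = fsdp_shard_degree(sizes, ["dp_shard", "cp_ring"])

# With cp_ulysses included, FSDP shards across every Ulysses rank as well.
fixed_degree = fsdp_shard_degree(sizes, ["dp_shard", "cp_ring", "cp_ulysses"])

# The ratio is the factor by which per-rank sharded state is oversized today.
overhead_factor = fixed_degree // current_degree
```

In this layout the per-rank sharded state is four times larger than it would be if `cp_ulysses` were part of the FSDP mesh, which matches the Ulysses degree.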

## Solution Options

### Option 1: Include Ulysses in FSDP Mesh (Recommended)
Modify `parallel_dims.py` to include `cp_ulysses` in the `dp_shard_cp` mesh:

```python
# In parallel_dims.py, around line 96
if self.cp_ulysses > 1:
    cp_mesh_dim_names.append("cp_ulysses")
    dp_shard_cp_mesh_dim_names.append("cp_ulysses")  # Add this line
    dp_cp_mesh_dim_names.append("cp_ulysses")  # Add this line
```

This makes FSDP aware of the Ulysses dimension, so it shards model parameters across Ulysses ranks as well as across `dp_shard` and `cp_ring`.

### Option 2: Pre-shard Attention Modules
Modify the attention module initialization to account for Ulysses sharding:

```python
# In model.py Attention.__init__
def __init__(self, model_args: TransformerModelArgs, ulysses_degree: int = 1):
    super().__init__()
    # Adjust head counts for Ulysses sharding
    self.n_heads = model_args.n_heads // ulysses_degree
    self.n_kv_heads = (
        model_args.n_heads // ulysses_degree
        if model_args.n_kv_heads is None
        else model_args.n_kv_heads // ulysses_degree
    )
    # ... rest of initialization
```
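
Integer division silently truncates when a head count is not a multiple of the Ulysses degree, so Option 2 would benefit from a guard before the counts are adjusted. The following is a minimal sketch of such a check; `heads_per_ulysses_rank` is a hypothetical helper, not an existing function in `model.py`.

```python
from typing import Optional, Tuple


def heads_per_ulysses_rank(
    n_heads: int, n_kv_heads: Optional[int], ulysses_degree: int
) -> Tuple[int, int]:
    """Return (query heads, KV heads) per Ulysses rank, rejecting uneven splits."""
    # Mirror the fallback in Attention.__init__: KV heads default to n_heads.
    kv_heads = n_heads if n_kv_heads is None else n_kv_heads
    for name, count in (("n_heads", n_heads), ("n_kv_heads", kv_heads)):
        if count % ulysses_degree != 0:
            raise ValueError(
                f"{name}={count} is not divisible by ulysses_degree={ulysses_degree}"
            )
    return n_heads // ulysses_degree, kv_heads // ulysses_degree
```

For example, 32 query heads and 8 KV heads split cleanly at a Ulysses degree of 4, while a degree of 16 is rejected because 8 KV heads cannot be divided 16 ways.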

### Option 3: Custom FSDP Policy
Create a custom FSDP auto wrap policy that's aware of Ulysses sharding:

```python
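import torch.nn as nn

def ulysses_aware_fsdp_policy(module: nn.Module, ulysses_degree: int, rank: int) -> nn.Module:
    # `rank` is this process's index within the Ulysses group (0 <= rank < ulysses_degree)
    if isinstance(module, Attention):  # Attention is the class defined in model.py
        # Manually shard attention weights along the output (head) dimension.
        # Re-wrap the chunk in nn.Parameter: a plain tensor cannot replace a registered parameter.
        shard = module.wq.weight.chunk(ulysses_degree, dim=0)[rank]
        module.wq.weight = nn.Parameter(shard.detach().clone())
        # ... similar for wk, wv, wo
    return module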
```

## Implementation Steps

### For Option 1 (Recommended):

1. **Update parallel_dims.py**:
```python
# Around line 96
if self.cp_ulysses > 1:
    cp_mesh_dim_names.append("cp_ulysses")
    dp_shard_cp_mesh_dim_names.append("cp_ulysses")
    dp_cp_mesh_dim_names.append("cp_ulysses")
```

2. **Update mesh flattening**:
```python
# Around line 104
if dp_shard_cp_mesh_dim_names:
    # For 3D mesh (dp_shard, cp_ring, cp_ulysses), flatten appropriately
    mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
        mesh_dim_name="dp_shard_cp"
    )
```

3. **Verify SDPA handler compatibility**:
Ensure the SDPA handler can still extract the Ulysses dimension from the mesh for attention computation.

## Testing

1. Run with different Ulysses degrees and verify memory usage:
   - Ulysses=1: Baseline memory
   - Ulysses=2: Should be ~50% of baseline
   - Ulysses=4: Should be ~25% of baseline

2. Verify model convergence is unchanged

3. Check that attention computation still works correctly with the 2D context parallel mesh
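
The memory check in step 1 can be automated with a small comparison helper. The sketch below is one possible shape for it: the function name and the 15% tolerance are assumptions, and the measured values would come from whatever profiler the run uses (for example, the per-rank peak reported by `torch.cuda.max_memory_allocated()`).

```python
def memory_within_expectation(
    baseline_bytes: float,
    measured_bytes: float,
    ulysses_degree: int,
    rel_tol: float = 0.15,  # assumed tolerance; tune to the noise in your runs
) -> bool:
    """Check that measured memory is close to baseline / ulysses_degree."""
    expected = baseline_bytes / ulysses_degree
    return abs(measured_bytes - expected) <= rel_tol * expected


# Hypothetical numbers: 40 GB baseline at Ulysses=1.
baseline = 40e9
fixed_run_ok = memory_within_expectation(baseline, 21e9, ulysses_degree=2)
unfixed_run_ok = memory_within_expectation(baseline, 40e9, ulysses_degree=2)
```

A run that still allocates the full baseline at Ulysses=2 fails the check, which is the signature of the bug described above.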

## Expected Outcome

After the fix:
- Memory usage should scale as `base_memory / ulysses_degree`
- Each rank should only allocate memory for its portion of attention heads
- No change in model accuracy or convergence

## Temporary Workaround

Until fixed, users can:
1. Use only ring attention (set `cp_ulysses=1`)
2. Account for the memory overhead: `actual_memory = expected_memory * ulysses_degree`
3. Reduce batch size or sequence length to compensate 
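
For planning purposes, the overhead formula in item 2 can be turned into a quick capacity check. This is a rough sketch: the helper names, the GiB units, and the example figures are illustrative assumptions rather than measured values.

```python
def actual_memory_gib(expected_gib: float, ulysses_degree: int) -> float:
    """Apply the workaround formula: actual = expected * ulysses_degree."""
    return expected_gib * ulysses_degree


def fits_on_device(expected_gib: float, ulysses_degree: int, capacity_gib: float) -> bool:
    """True if the inflated footprint still fits in the device's memory."""
    return actual_memory_gib(expected_gib, ulysses_degree) <= capacity_gib


# Hypothetical case: 12 GiB expected per rank on an 80 GiB device.
fits_at_4 = fits_on_device(12.0, ulysses_degree=4, capacity_gib=80.0)  # 48 GiB needed
fits_at_8 = fits_on_device(12.0, ulysses_degree=8, capacity_gib=80.0)  # 96 GiB needed
```

When the check fails, the remaining options are the ones listed above: lower the Ulysses degree, or reduce batch size or sequence length.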