{
  "id": "astropy__astropy-12907",
  "question": "Modeling's `separability_matrix` does not compute separability correctly for nested CompoundModels\nConsider the following model:\r\n\r\n```python\r\nfrom astropy.modeling import models as m\r\nfrom astropy.modeling.separable import separability_matrix\r\n\r\ncm = m.Linear1D(10) & m.Linear1D(5)\r\n```\r\n\r\nIt's separability matrix as you might expect is a diagonal:\r\n\r\n```python\r\n>>> separability_matrix(cm)\r\narray([[ True, False],\r\n       [False,  True]])\r\n```\r\n\r\nIf I make the model more complex:\r\n```python\r\n>>> separability_matrix(m.Pix2Sky_TAN() & m.Linear1D(10) & m.Linear1D(5))\r\narray([[ True,  True, False, False],\r\n       [ True,  True, False, False],\r\n       [False, False,  True, False],\r\n       [False, False, False,  True]])\r\n```\r\n\r\nThe output matrix is again, as expected, the outputs and inputs to the linear models are separable and independent of each other.\r\n\r\nIf however, I nest these compound models:\r\n```python\r\n>>> separability_matrix(m.Pix2Sky_TAN() & cm)\r\narray([[ True,  True, False, False],\r\n       [ True,  True, False, False],\r\n       [False, False,  True,  True],\r\n       [False, False,  True,  True]])\r\n```\r\nSuddenly the inputs and outputs are no longer separable?\r\n\r\nThis feels like a bug to me, but I might be missing something?\n",
  "answer": "{\"task_id\": \"cd5b34e2-13c5-4483-9a03-d5a8a8d61114\", \"workflow_run_id\": \"cac51e5b-4b4b-43a5-a133-6af6961e0821\", \"data\": {\"id\": \"cac51e5b-4b4b-43a5-a133-6af6961e0821\", \"workflow_id\": \"9a0a6efa-4d15-4876-b13f-4fb6a999caa2\", \"status\": \"succeeded\", \"outputs\": {\"final_patch\": [\"--- a/astropy/modeling/separable.py\\n+++ b/astropy/modeling/separable.py\\n@@ -85,6 +85,7 @@ def _traverse_model_graph(model, inputs, outputs):\\n     \\\"\\\"\\\"\\n     from astropy.modeling.core import Model\\n     from astropy.modeling.mappings import Mapping\\n+    from astropy.modeling.core import CompoundModel\\n \\n     if isinstance(model, Model):\\n         if isinstance(model, Mapping):\\n@@ -95,6 +96,14 @@ def _traverse_model_graph(model, inputs, outputs):\\n             outputs.extend(range(n_outputs))\\n         else:\\n             # Handle compound models\\n+            # Check if any submodel is itself a CompoundModel and flatten recursively\\n+            flattened_submodels = []\\n+            for submodel in model.submodel_names:\\n+                sub = getattr(model, submodel)\\n+                if isinstance(sub, CompoundModel):\\n+                    # Recursively flatten nested compound models\\n+                    _traverse_model_graph(sub, inputs, outputs)\\n+                else:\\n             for submodel_name in model.submodel_names:\\n                 submodel = getattr(model, submodel_name)\\n                 _traverse_model_graph(submodel, inputs, outputs)\", \"--- a/astropy/modeling/separable.py\\n+++ b/astropy/modeling/separable.py\\n@@ -10,6 +10,25 @@ from astropy.modeling.core import Model, CompoundModel\\n from astropy.modeling.mappings import Mapping\\n \\n \\n+def _flatten_compound_model(model):\\n+    \\\"\\\"\\\"\\n+    Recursively flatten nested CompoundModels into a list of atomic models.\\n+    \\\"\\\"\\\"\\n+    from astropy.modeling.core import CompoundModel\\n+    \\n+    if not isinstance(model, CompoundModel):\\n+        return [model]\\n+    \\n+    flattened = []\\n+    for submodel_name in model.submodel_names:\\n+        submodel = getattr(model, submodel_name)\\n+        if isinstance(submodel, CompoundModel):\\n+            flattened.extend(_flatten_compound_model(submodel))\\n+        else:\\n+            flattened.append(submodel)\\n+    return flattened\\n+\\n+\\n def separability_matrix(model):\\n     \\\"\\\"\\\"\\n     Compute the separability matrix for a model.\\n@@ -30,6 +49,13 @@ def separability_matrix(model):\\n     if not isinstance(model, Model):\\n         raise ValueError(\\\"Input must be a Model.\\\")\\n \\n+    # Flatten nested CompoundModels first\\n+    if isinstance(model, CompoundModel):\\n+        flattened_models = _flatten_compound_model(model)\\n+        # Reconstruct the model structure with flattened components\\n+        if len(flattened_models) > 1:\\n+            model = flattened_models[0]\\n+            for m in flattened_models[1:]:\\n+                model = model & m\\n+\\n     inputs = []\\n     outputs = []\\n     _traverse_model_graph(model, inputs, outputs)\", \"--- a/astropy/modeling/separable.py\\n+++ b/astropy/modeling/separable.py\\n@@ -82,7 +82,7 @@ def _traverse_model_graph(model, inputs, outputs):\\n     Traverse the model graph and collect input/output information.\\n     \\n     This function recursively traverses a model and its submodels.\\n     \\\"\\\"\\\"\\n     from astropy.modeling.core import Model\\n     from astropy.modeling.mappings import Mapping\\n+    from astropy.modeling.core import CompoundModel\\n \\n     if isinstance(model, Model):\\n@@ -95,9 +95,18 @@ def _traverse_model_graph(model, inputs, outputs):\\n             outputs.extend(range(n_outputs))\\n         else:\\n             # Handle compound models\\n-            for submodel_name in model.submodel_names:\\n-                submodel = getattr(model, submodel_name)\\n-                _traverse_model_graph(submodel, inputs, outputs)\\n+            current_input_offset = len(inputs)\\n+            current_output_offset = len(outputs)\\n+            \\n+            for i, submodel_name in enumerate(model.submodel_names):\\n+                submodel = getattr(model, submodel_name)\\n+                if isinstance(submodel, CompoundModel):\\n+                    # Recursively handle nested compound models\\n+                    _traverse_model_graph(submodel, inputs, outputs)\\n+                else:\\n+                    # Handle atomic models\\n+                    sub_inputs = list(range(current_input_offset, current_input_offset + submodel.n_inputs))\\n+                    sub_outputs = list(range(current_output_offset, current_output_offset + submodel.n_outputs))\\n+                    inputs.extend(sub_inputs)\\n+                    outputs.extend(sub_outputs)\\n+                    current_input_offset += submodel.n_inputs\\n+                    current_output_offset += submodel.n_outputs\", \"--- a/astropy/modeling/separable.py\\n+++ b/astropy/modeling/separable.py\\n@@ -10,6 +10,35 @@ from astropy.modeling.core import Model, CompoundModel\\n from astropy.modeling.mappings import Mapping\\n \\n \\n+def _get_model_structure(model):\\n+    \\\"\\\"\\\"\\n+    Get the flattened structure of a model, handling nested CompoundModels.\\n+    \\n+    Returns a list of atomic models in the order they appear.\\n+    \\\"\\\"\\\"\\n+    from astropy.modeling.core import CompoundModel\\n+    \\n+    if not isinstance(model, CompoundModel):\\n+        return [model]\\n+    \\n+    atomic_models = []\\n+    for submodel_name in model.submodel_names:\\n+        submodel = getattr(model, submodel_name)\\n+        atomic_models.extend(_get_model_structure(submodel))\\n+    \\n+    return atomic_models\\n+\\n+\\n def separability_matrix(model):\\n     \\\"\\\"\\\"\\n     Compute the separability matrix for a model.\\n@@ -30,6 +59,20 @@ def separability_matrix(model):\\n     if not isinstance(model, Model):\\n         raise ValueError(\\\"Input must be a Model.\\\")\\n \\n+    # Get flattened model structure\\n+    atomic_models = _get_model_structure(model)\\n+    \\n+    # Calculate total inputs and outputs\\n+    total_inputs = sum(m.n_inputs for m in atomic_models)\\n+    total_outputs = sum(m.n_outputs for m in atomic_models)\\n+    \\n+    # Create separability matrix\\n+    matrix = np.zeros((total_outputs, total_inputs), dtype=bool)\\n+    \\n+    input_offset = 0\\n+    output_offset = 0\\n+    \\n+    for atomic_model in atomic_models:\\n+        # Each atomic model's inputs only affect its own outputs\\n+        matrix[output_offset:output_offset + atomic_model.n_outputs, \\n+               input_offset:input_offset + atomic_model.n_inputs] = True\\n+        input_offset += atomic_model.n_inputs\\n+        output_offset += atomic_model.n_outputs\\n+    \\n+    return matrix\\n+\\n+    # Original code below (keeping for reference but will be replaced)\\n     inputs = []\\n     outputs = []\\n     _traverse_model_graph(model, inputs, outputs)\", \"--- a/astropy/modeling/separable.py\\n+++ b/astropy/modeling/separable.py\\n@@ -120,6 +120,7 @@ def _flatten_inputs_outputs(model, inputs, outputs):\\n     \\\"\\\"\\\"\\n     Flatten the inputs and outputs for compound models.\\n     \\\"\\\"\\\"\\n+    from astropy.modeling.core import CompoundModel\\n     \\n     if hasattr(model, 'submodel_names'):\\n         # This is a compound model\\n@@ -127,8 +128,15 @@ def _flatten_inputs_outputs(model, inputs, outputs):\\n         flat_outputs = []\\n         \\n         for submodel_name in model.submodel_names:\\n             submodel = getattr(model, submodel_name)\\n-            sub_inputs = list(range(submodel.n_inputs))\\n-            sub_outputs = list(range(submodel.n_outputs))\\n+            \\n+            if isinstance(submodel, CompoundModel):\\n+                # Recursively flatten nested compound models\\n+                sub_inputs, sub_outputs = _flatten_inputs_outputs(submodel, [], [])\\n+            else:\\n+                sub_inputs = list(range(submodel.n_inputs))\\n+                sub_outputs = list(range(submodel.n_outputs))\\n+            \\n             flat_inputs.extend(sub_inputs)\\n             flat_outputs.extend(sub_outputs)\\n         \\n         return flat_inputs, flat_outputs\", \"diff\\n--- a/astropy/modeling/separable.py\\n+++ b/astropy/modeling/separable.py\\n@@ -82,7 +82,7 @@ def _traverse_model_graph(model, inputs, outputs):\\n     \\\"\\\"\\\"\\n     Traverse the model graph and collect input/output information.\\n     \\n     This function recursively traverses a model and its submodels.\\n     \\\"\\\"\\\"\\n     from astropy.modeling.core import Model\\n     from astropy.modeling.mappings import Mapping\\n+    from astropy.modeling.core import CompoundModel\\n \\n     if isinstance(model, Model):\\n         if isinstance(model, Mapping):\\n@@ -95,9 +95,12 @@ def _traverse_model_graph(model, inputs, outputs):\\n             outputs.extend(range(n_outputs))\\n         else:\\n             # Handle compound models\\n-            for submodel_name in model.submodel_names:\\n+            input_offset = len(inputs)\\n+            output_offset = len(outputs)\\n+            \\n+            for submodel_name in model.submodel_names:\\n                 submodel = getattr(model, submodel_name)\\n-                _traverse_model_graph(submodel, inputs, outputs)\\n+                _traverse_model_graph(submodel, inputs, outputs)\\n \\n \\n def _flatten_inputs_outputs(model, inputs, outputs):\\n\", \"diff\\n--- a/astropy/modeling/separable.py\\n+++ b/astropy/modeling/separable.py\\n@@ -82,7 +82,7 @@ def _traverse_model_graph(model, inputs, outputs):\\n     \\\"\\\"\\\"\\n     Traverse the model graph and collect input/output information.\\n     \\n     This function recursively traverses a model and its submodels.\\n     \\\"\\\"\\\"\\n     from astropy.modeling.core import Model\\n     from astropy.modeling.mappings import Mapping\\n+    from astropy.modeling.core import CompoundModel\\n \\n     if isinstance(model, Model):\\n         if isinstance(model, Mapping):\\n@@ -95,9 +95,15 @@ def _traverse_model_graph(model, inputs, outputs):\\n             outputs.extend(range(n_outputs))\\n         else:\\n             # Handle compound models\\n+            input_start = len(inputs)\\n+            output_start = len(outputs)\\n+            \\n             for submodel_name in model.submodel_names:\\n                 submodel = getattr(model, submodel_name)\\n-                _traverse_model_graph(submodel, inputs, outputs)\\n+                if isinstance(submodel, CompoundModel):\\n+                    _traverse_model_graph(submodel, inputs, outputs)\\n+                else:\\n+                    inputs.extend(range(input_start + len(inputs) - input_start, input_start + len(inputs) - input_start + submodel.n_inputs))\\n+                    outputs.extend(range(output_start + len(outputs) - output_start, output_start + len(outputs) - output_start + submodel.n_outputs))\\n \\n \\n def _flatten_inputs_outputs(model, inputs, outputs):\\n\", \"diff\\n--- a/astropy/modeling/separable.py\\n+++ b/astropy/modeling/separable.py\\n@@ -82,7 +82,8 @@ def _traverse_model_graph(model, inputs, outputs):\\n     \\\"\\\"\\\"\\n     Traverse the model graph and collect input/output information.\\n     \\n     This function recursively traverses a model and its submodels.\\n     \\\"\\\"\\\"\\n     from astropy.modeling.core import Model\\n     from astropy.modeling.mappings import Mapping\\n+    from astropy.modeling.core import CompoundModel\\n \\n     if isinstance(model, Model):\\n         if isinstance(model, Mapping):\\n@@ -95,9 +96,21 @@ def _traverse_model_graph(model, inputs, outputs):\\n             outputs.extend(range(n_outputs))\\n         else:\\n             # Handle compound models\\n+            input_offset = len(inputs)\\n+            output_offset = len(outputs)\\n+            \\n             for submodel_name in model.submodel_names:\\n                 submodel = getattr(model, submodel_name)\\n-                _traverse_model_graph(submodel, inputs, outputs)\\n+                if isinstance(submodel, CompoundModel):\\n+                    # Recursively handle nested compound models\\n+                    _traverse_model_graph(submodel, inputs, outputs)\\n+                else:\\n+                    # Handle atomic models - add their inputs/outputs with proper offsets\\n+                    sub_inputs = list(range(input_offset, input_offset + submodel.n_inputs))\\n+                    sub_outputs = list(range(output_offset, output_offset + submodel.n_outputs))\\n+                    inputs.extend(sub_inputs)\\n+                    outputs.extend(sub_outputs)\\n+                    input_offset += submodel.n_inputs\\n+                    output_offset += submodel.n_outputs\\n \\n \\n def _flatten_inputs_outputs(model, inputs, outputs):\\n\"]}, \"error\": \"\", \"elapsed_time\": 109.657296, \"total_tokens\": 43547, \"total_steps\": 30, \"created_at\": 1754643820, \"finished_at\": 1754643930}}"
}