from __future__ import annotations

from contextlib import nullcontext
from typing import Any

import numpy as np
import torch

from worm_objectives.corr import corr_loss_and_dLtdv_numpy, corr_loss_and_dLtdv_torch
from worm_train_config import WormTrainConfig


def compute_corr_loss_and_dLtdv(
    output_vs: np.ndarray,
    target: np.ndarray,
    *,
    lr_start: int,
    lr_end: int,
    cfg: WormTrainConfig,
    epoch_prof: Any | None = None,
) -> tuple[float, np.ndarray | torch.Tensor]:
    """
    Compute the corr objective and its gradient dL/dv.

    Dispatches to the torch or numpy implementation of the correlation objective
    depending on ``cfg.corr_use_torch``. When ``epoch_prof`` is truthy, the call is
    wrapped in ``epoch_prof.phase(...)`` under the name ``"objective_corr_torch"``
    or ``"objective_corr_numpy"`` respectively.

    This is factored out of the training implementation to keep the main epoch step
    readable. Behavior should remain identical to the inlined logic.

    Args:
        output_vs: Passed as the first positional argument to the objective.
        target: Passed as the second positional argument to the objective.
        lr_start: Forwarded positionally to the objective; presumably the start of
            an index range — TODO confirm against ``worm_objectives.corr``.
        lr_end: Forwarded positionally to the objective; presumably the end of that
            range — TODO confirm.
        cfg: Training config. Reads ``corr_use_torch`` and, on the torch path,
            ``corr_torch_device`` (converted with ``str``).
        epoch_prof: Optional profiler whose ``phase(name)`` is used as a context
            manager. Skipped when falsy.

    Returns:
        ``(mean_error, dLtdv)``. ``mean_error`` is always a Python ``float``.
        ``dLtdv`` is whatever the selected backend returns (annotated as
        ``np.ndarray | torch.Tensor``).
    """
    use_torch = bool(cfg.corr_use_torch)
    phase_name = "objective_corr_torch" if use_torch else "objective_corr_numpy"
    # Truthiness (not `is not None`) deliberately matches the original inlined check.
    phase = epoch_prof.phase(phase_name) if epoch_prof else nullcontext()

    with phase:
        if use_torch:
            mean_error_t, dLtdv = corr_loss_and_dLtdv_torch(
                output_vs,
                target,
                lr_start,
                lr_end,
                device=str(cfg.corr_torch_device),
            )
            # The tensor -> float conversion stays inside the phase, as in the
            # original inlined code, so it is attributed to this profiled phase.
            return float(mean_error_t.detach().cpu().item()), dLtdv
        mean_error, dLtdv = corr_loss_and_dLtdv_numpy(output_vs, target, lr_start, lr_end)
    return float(mean_error), dLtdv
