import argparse
import sys
from pathlib import Path
import types

import torch

# Make the private implementation package importable: ``private_impl`` lives
# in the directory two levels above the one that contains this script.
ROOT = Path(__file__).resolve().parents[2]
PRIVATE_ROOT = ROOT / "private_impl"
if not any(entry == str(PRIVATE_ROOT) for entry in sys.path):
    # Prepend so it is searched before any same-named installed package.
    sys.path.insert(0, str(PRIVATE_ROOT))

# Provide lightweight stubs for optional dependencies used only in training.
# The stand-in module is registered under the real module name so imports of
# these metric functions resolve; calling any of them raises immediately.
# NOTE(review): the check only looks at sys.modules, so the stub is installed
# even if torchmetrics is installed but not yet imported.
if "torchmetrics.functional" not in sys.modules:

    def _not_available(*args, **kwargs):
        raise RuntimeError("torchmetrics is required for training only.")

    tm_func = types.ModuleType("torchmetrics.functional")
    vars(tm_func).update(
        dict.fromkeys(
            (
                "scale_invariant_signal_noise_ratio",
                "signal_noise_ratio",
                "signal_distortion_ratio",
                "scale_invariant_signal_distortion_ratio",
            ),
            _not_available,
        )
    )
    sys.modules["torchmetrics.functional"] = tm_func

# Stand-in for ``asteroid.losses``: the names exist so imports resolve, but
# using any of them raises. As with the torchmetrics stub, the check only
# looks at sys.modules, not at whether asteroid is installed.
if "asteroid.losses" not in sys.modules:

    class _LossStub:
        # Constructing the wrapper is allowed; invoking it as a loss is not.
        def __init__(self, *_args, **_kwargs):
            pass

        def __call__(self, *_args, **_kwargs):
            raise RuntimeError("asteroid.losses is required for training only.")

    def pairwise_neg_sisdr(*args, **kwargs):
        raise RuntimeError("asteroid.losses is required for training only.")

    asteroid_losses = types.ModuleType("asteroid.losses")
    asteroid_losses.PITLossWrapper = _LossStub
    asteroid_losses.pairwise_neg_sisdr = pairwise_neg_sisdr
    sys.modules["asteroid.losses"] = asteroid_losses

from spatial_encoder_impl.spatial_encoder import SpatialEncoder  # type: ignore  # noqa: E402


def load_model(ckpt_path: Path) -> SpatialEncoder:
    """Build a ``SpatialEncoder`` and load its weights from *ckpt_path*.

    The checkpoint may be either a bare state dict or a dict that stores one
    under ``"state_dict"``. A leading ``"model."`` on any key is stripped so
    the names line up with the encoder's own parameter names.

    Loading is non-strict, so keys that don't match are skipped instead of
    raising. Skipped model keys keep their freshly initialised values, so every
    missing or unexpected key is reported on stderr. Without that report, a
    wrong checkpoint could be exported with no visible sign of trouble.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        The encoder with the checkpoint weights applied, switched to eval mode.
    """
    # Architecture hyperparameters are fixed here; they must describe the same
    # architecture the checkpoint was saved from (tensor shape mismatches make
    # load_state_dict raise even when strict=False).
    model = SpatialEncoder(
        dim_input=4,
        dim_hidden=96,
        num_layers=8,
    )
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    # Newer PyTorch releases default ``weights_only`` to True, which may reject
    # checkpoints that also pickle non-tensor objects — TODO confirm against
    # the checkpoints this script is used with.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    prefix = "model."
    # Drop a leading "model." from keys; other keys pass through unchanged.
    new_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    incompatible = model.load_state_dict(new_state, strict=False)
    if incompatible.missing_keys:
        print(
            f"Warning: {len(incompatible.missing_keys)} encoder key(s) not found "
            f"in checkpoint (left at initial values): "
            f"{', '.join(incompatible.missing_keys)}",
            file=sys.stderr,
        )
    if incompatible.unexpected_keys:
        print(
            f"Warning: {len(incompatible.unexpected_keys)} checkpoint key(s) not "
            f"used by the encoder: {', '.join(incompatible.unexpected_keys)}",
            file=sys.stderr,
        )
    model.eval()
    return model


def main(argv=None):
    """Export a spatial encoder checkpoint to a TorchScript file.

    Parses the command line, loads the checkpoint with ``load_model``, traces
    the encoder's ``forward_as_encoder`` on random input of the requested
    duration, and saves the traced module to ``--output`` (creating parent
    directories as needed).

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps normal command-line behaviour.
    """
    parser = argparse.ArgumentParser(description="Export spatial encoder to TorchScript")
    parser.add_argument("--ckpt", type=Path, required=True, help="Path to spatial encoder checkpoint")
    parser.add_argument("--output", type=Path, required=True, help="Target TorchScript path")
    parser.add_argument("--sample-length", type=int, default=30, help="Audio length in seconds for tracing")
    parser.add_argument("--sample-rate", type=int, default=16000, help="Sample rate")
    args = parser.parse_args(argv)
    # A non-positive value yields an empty dummy tensor (or a negative-dimension
    # error from torch.randn) that fails deep inside tracing; reject it early.
    if args.sample_length <= 0:
        parser.error("--sample-length must be a positive number of seconds")
    if args.sample_rate <= 0:
        parser.error("--sample-rate must be a positive number of samples per second")

    model = load_model(args.ckpt)

    class Wrapper(torch.nn.Module):
        # torch.jit.trace traces ``forward``; this adapter routes it to the
        # encoder's ``forward_as_encoder`` entry point instead.
        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, audios: torch.Tensor) -> torch.Tensor:
            return self.encoder.forward_as_encoder(audios)

    wrapper = Wrapper(model)
    # Random input of shape (1, 2, seconds * rate); presumably
    # (batch, channels, samples) — confirm against forward_as_encoder.
    dummy = torch.randn(1, 2, args.sample_length * args.sample_rate)
    # Tracing records the ops executed for this one input. If the encoder
    # branches on input length in Python, the exported graph may be specialised
    # to this length. strict=False lets the tracer record mutable container
    # (list/dict) outputs.
    with torch.no_grad():
        traced = torch.jit.trace(wrapper, dummy, strict=False)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    traced.save(str(args.output))
    print(f"Saved TorchScript spatial encoder to {args.output}")


# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
