from typing import Any

import torch

from .base import FusingFunction


class Complex(FusingFunction):
    """Fusing function based on the complex (Hadamard) product of two tensors.

    The real and imaginary components of ``s * r`` are computed explicitly
    from the ``.real`` / ``.imag`` views of the operands and returned
    concatenated along the last dimension.
    """

    def __init__(self):
        """Initialise the base fusing function; this class adds no state."""
        super().__init__()

    def __call__(self, s: torch.Tensor, r: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Return the real and imaginary parts of ``s * r``, concatenated.

        Args:
            s: First operand. Its ``.real`` and ``.imag`` views are read, so
                it is presumably a complex-dtype tensor (to be confirmed
                against the callers).
            r: Second operand, used the same way as ``s``. The two operands
                are combined with element-wise ``*``.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            ``torch.cat`` of the real part followed by the imaginary part of
            the product, joined along ``dim=-1``.
        """
        # Pull out the component views once so the product formula reads cleanly.
        s_re, s_im = s.real, s.imag
        r_re, r_im = r.real, r.imag

        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        real_part = s_re * r_re - s_im * r_im
        imag_part = s_re * r_im + s_im * r_re

        return torch.cat((real_part, imag_part), dim=-1)
