"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

from typing import List
import numpy as np
from uniSI.utils import list2numpy, numpy2list

class Receiver_freq:
    """
    Frequency-domain receiver container.

    Stores receiver positions (grid indices), one type tag per receiver and an
    opaque frequency-domain data object.

    Args:
      nf (int): Number of frequency samples for the receiver data
      df (float): Frequency interval (Hz)

    Notes:
      1. Frequency coordinates for receiver data are: df, 2*df, ..., nf*df
      2. Receiver locations are added via add_receiver(s) methods.
      3. set_data is used to store frequency-domain data (e.g. p_fd).
      4. Receiver types are validated case-insensitively but stored exactly
         as passed in.
    """

    # Receiver types accepted by add_receiver / add_receivers.
    _VALID_TYPES = ("pr", "vx", "vz", "vy", "pr_spec")

    def __init__(self, nf: int, df: float) -> None:
        self.nf = nf
        self.df = df
        self.loc_x = []    # x coordinate of every receiver, in insertion order
        self.loc_z = []    # z coordinate of every receiver, in insertion order
        self.locs = []     # cache of the last array built by get_loc()
        self.type = []     # type tag of every receiver, parallel to loc_x/loc_z
        self.num = 0       # number of receivers added so far
        self.data = None   # payload stored by set_data()

    @classmethod
    def _check_type(cls, rcv_type: str) -> None:
        """Raise ValueError unless rcv_type (case-insensitive) is supported."""
        if rcv_type.lower() not in cls._VALID_TYPES:
            raise ValueError("Receiver type must be one of: pr, vx, vz, vy, pr_spec")

    def __repr__(self):
        """
        Return a multi-line, human-readable summary of the receiver set.

        Best effort: if the summary cannot be built (most commonly because no
        receiver has been added yet, so min()/max() have nothing to reduce),
        the short "empty" form is returned instead of raising.
        """
        try:
            rcv_x = np.array(self.loc_x)
            rcv_z = np.array(self.loc_z)
            xmin, xmax = rcv_x.min(), rcv_x.max()
            zmin, zmax = rcv_z.min(), rcv_z.max()
            lines = [
                "Frequency Domain Receiver:\n",
                f"  Frequency samples : {self.nf} with df = {self.df:.2f} Hz\n",
                f"  Receiver number   : {self.num}\n",
                f"  Receiver types    : {self.get_type(unique=True)}\n",
                f"  Receiver x range  : {xmin} - {xmax} (grids)\n",
                f"  Receiver z range  : {zmin} - {zmax} (grids)\n",
            ]
            info = "".join(lines)
        except Exception:
            # Exception (not a bare except) so that KeyboardInterrupt and
            # SystemExit are never swallowed while printing the object.
            info = "Frequency Domain Receiver:\n  empty\n"
        return info

    def add_receivers(self, rcv_x: np.array, rcv_z: np.array, rcv_type: str) -> None:
        """
        Add multiple receivers of the same type.

        The coordinate arrays are flattened before being stored, so arrays of
        any dimensionality contribute one receiver per element.

        Args:
          rcv_x: x coordinates (numpy array)
          rcv_z: z coordinates (numpy array)
          rcv_type: receiver type ('pr', 'vx', 'vz', 'vy', 'pr_spec')

        Raises:
          ValueError: if rcv_x and rcv_z differ in shape, or rcv_type is not
            one of the supported types.
        """
        if rcv_x.shape != rcv_z.shape:
            raise ValueError("Receiver Error: Inconsistent number of receivers in X and Z directions")
        self._check_type(rcv_type)

        flat_x = rcv_x.reshape(-1)
        flat_z = rcv_z.reshape(-1)
        # Count the flattened elements, not len(rcv_x): for multi-dimensional
        # input len() only reports the first axis, which would leave num/type
        # shorter than loc_x/loc_z.
        rcv_n = len(flat_x)

        # Convert both arrays before touching any attribute so that a failing
        # conversion cannot leave the receiver lists partially updated.
        new_x = numpy2list(flat_x)
        new_z = numpy2list(flat_z)

        self.loc_x.extend(new_x)
        self.loc_z.extend(new_z)
        self.type.extend([rcv_type] * rcv_n)
        self.num += rcv_n

    def add_receiver(self, rcv_x: int, rcv_z: int, rcv_type: str) -> None:
        """
        Add a single receiver.

        Args:
          rcv_x: x coordinate (int)
          rcv_z: z coordinate (int)
          rcv_type: receiver type ('pr', 'vx', 'vz', 'vy', 'pr_spec')

        Raises:
          ValueError: if rcv_type is not one of the supported types.
        """
        self._check_type(rcv_type)
        self.loc_x.append(rcv_x)
        self.loc_z.append(rcv_z)
        self.type.append(rcv_type)
        self.num += 1

    def get_loc(self):
        """
        Return receiver locations as [x, z] 2D array.

        Each row is one receiver: column 0 is x, column 1 is z. A copy of the
        result is also kept in self.locs.
        """
        rcv_x = list2numpy(self.loc_x).reshape(-1, 1)
        rcv_z = list2numpy(self.loc_z).reshape(-1, 1)
        rcv_loc = np.hstack((rcv_x, rcv_z))
        self.locs = rcv_loc.copy()
        return rcv_loc

    def get_type(self, unique: bool = False) -> List[str]:
        """
        Return the receiver types.

        The value is whatever list2numpy produces from the type list
        (presumably a NumPy array rather than a plain list - TODO confirm
        against uniSI.utils).

        Args:
          unique: If True, return each distinct type once, in the order it
            was first added.
        """
        if unique:
            # dict.fromkeys de-duplicates while keeping first-seen order; a
            # set would give an order that can change from run to run.
            return list2numpy(list(dict.fromkeys(self.type)))
        return list2numpy(self.type)

    def set_data(self, data):
        """
        Set frequency-domain receiver data, e.g., p_fd.

        The object is stored as-is (no copy, no shape validation).

        Args:
          data: numpy array or torch.Tensor,
                recommended shape (num_receivers, nf) or (num_receivers, nf, ...).
        """
        self.data = data

    def get_data(self):
        """Return stored frequency-domain receiver data (None if never set)."""
        return self.data