"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

import numpy as np
from typing import List, Optional
from uniSI.utils import list2numpy, numpy2list
from uniSI.view import plot_wavelet
import matplotlib.pyplot as plt
from scipy import fft

class Source_Freq(object):
    """Seismic Source class (frequency domain).

    Each source is stored as a grid location (``loc_x``, ``loc_z``), a type
    string, a complex spectrum sampled at ``freq_list`` and a 3x3 moment
    tensor. Per-source data live in parallel Python lists that the ``get_*``
    accessors convert to arrays on demand.
    """

    def __init__(self, freq_list: np.ndarray, df: float, f0: float, t0: float) -> None:
        """
        Parameters:
            freq_list : float array
                Selected frequency components (Hz)
            df : float
                Frequency interval (Hz)
            f0 : float
                Dominant frequency (Hz)
            t0 : float
                Time shift; only stored on the instance by this class
        """
        self.freq_list = freq_list
        self.df = df
        self.f0 = f0
        self.nf = len(freq_list)
        self.t0 = t0

        self.loc_x = []          # x grid index of each source
        self.loc_z = []          # z grid index of each source
        self.loc = []            # (num, 2) location array cached by get_loc()
        self.type = []           # type string of each source
        self.spectrum = []       # complex spectrum of each source (first dim = nf)
        self.moment_tensor = []  # 3x3 moment tensor of each source
        self.num = 0             # number of sources added so far

    @staticmethod
    def _default_moment_tensor() -> np.ndarray:
        """Return a fresh 3x3 identity moment tensor.

        A new array is built on every call so that sources added with the
        default never share (and cannot accidentally mutate) one array object.
        """
        return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

    def __repr__(self):
        """Summarise frequency band, source count/types and location ranges."""
        try:
            src_x = list2numpy(self.loc_x)
            src_z = list2numpy(self.loc_z)
            xmin, xmax = src_x.min(), src_x.max()
            zmin, zmax = src_z.min(), src_z.max()
            # np.min/np.max also accept a plain list for freq_list
            fmin, fmax = np.min(self.freq_list), np.max(self.freq_list)
            info = "Frequency-domain Seismic Source:\n"
            info += f"  Frequency range: {fmin:.2f}-{fmax:.2f} Hz\n"
            info += f"  Frequency components: {self.nf}\n"
            info += f"  Source number : {self.num}\n"
            info += f"  Source types  : {self.get_type(unique=True)}\n"
            info += f"  Source x range: {xmin} - {xmax} (grids)\n"
            info += f"  Source z range: {zmin} - {zmax} (grids)\n"
        except (ValueError, TypeError, AttributeError):
            # No sources yet (min() of an empty array raises ValueError) or
            # otherwise unprintable state: fall back to the short summary.
            info = "Frequency-domain Seismic Source:\n  empty\n"
        return info

    def add_sources(self,
                    src_x: np.ndarray,
                    src_z: np.ndarray,
                    src_spectrum: np.ndarray,  # Complex spectrum [nf]
                    src_type: Optional[str] = 'mt',
                    src_mt: Optional[np.ndarray] = None) -> None:
        """Add multiple sources with the same spectrum and moment tensor.

        Parameters:
            src_x, src_z : int array-likes of identical shape
                Grid indices of the sources; any shape is accepted and flattened.
            src_spectrum : complex array
                Spectrum whose first dimension must equal ``nf``; tiled once per source.
            src_type : str
                Source type; only ``'mt'`` (case-insensitive) is accepted.
            src_mt : (3, 3) array, optional
                Moment tensor shared by all new sources; defaults to the identity.

        Raises:
            ValueError: if x/z shapes differ, the spectrum length does not
                match ``nf``, or the source type is unsupported.
        """
        src_x = np.asarray(src_x)
        src_z = np.asarray(src_z)
        if src_x.shape != src_z.shape:
            raise ValueError("Source x/z locations must have same shape")
        if src_spectrum.shape[0] != self.nf:
            raise ValueError("Spectrum must match frequency components")
        if src_type.lower() not in ["mt"]:
            raise ValueError("Invalid source type")
        if src_mt is None:
            src_mt = self._default_moment_tensor()
        # Count every location (not just the first axis) so that 2-D location
        # grids keep loc_x/loc_z/type/spectrum/moment_tensor the same length.
        src_n = src_x.size
        expanded_spectrum = np.tile(src_spectrum, (src_n, 1))  # [src_n, nf]
        expanded_mt = np.tile(src_mt, (src_n, 1, 1))  # [src_n, 3, 3]
        self.loc_x.extend(src_x.ravel().tolist())
        self.loc_z.extend(src_z.ravel().tolist())
        self.type.extend([src_type] * src_n)
        self.spectrum.extend(expanded_spectrum)
        self.moment_tensor.extend(expanded_mt)
        self.num += src_n

    def add_source(self,
                   src_x: int,
                   src_z: int,
                   src_spectrum: np.ndarray,  # Complex spectrum [nf]
                   src_type: Optional[str] = 'mt',
                   src_mt: Optional[np.ndarray] = None) -> None:
        """Add a single source with a complex spectrum.

        Parameters:
            src_x, src_z : int
                Grid indices of the source.
            src_spectrum : complex array
                Spectrum whose first dimension must equal ``nf``; stored as given.
            src_type : str
                Source type label (not validated here).
            src_mt : (3, 3) array, optional
                Moment tensor; defaults to a fresh identity tensor.

        Raises:
            ValueError: if the spectrum length does not match ``nf``.
        """
        if src_spectrum.shape[0] != self.nf:
            raise ValueError("Spectrum must match frequency components")
        if src_mt is None:
            src_mt = self._default_moment_tensor()
        self.loc_x.append(src_x)
        self.loc_z.append(src_z)
        self.type.append(src_type)
        self.spectrum.append(src_spectrum)
        self.moment_tensor.append(src_mt)
        self.num += 1

    def get_spectrum(self, index: int = 0) -> np.ndarray:
        """Return the complex spectrum of source ``index``.

        If the stored spectra do not stack into a 2-D array, the whole
        converted array is returned instead.
        """
        all_spectra = list2numpy(self.spectrum)
        if len(all_spectra.shape) == 2:
            return all_spectra[index]
        return all_spectra

    def plot_spectrum(self, index: int = 0, savepath: Optional[str] = None, df=1, nf=None, min_freq=0.1, fmax=100):
        """Plot the amplitude spectrum (and its inverse FFT) of source ``index``."""
        spectrum = self.get_spectrum(index)
        self.plot_spectrum_(spectrum, savepath, df, nf, min_freq=min_freq, fmax=fmax)

    def get_loc(self):
        """Return source locations as an ``(num, 2)`` array of ``[x, z]`` and cache a copy in ``self.loc``."""
        src_x = list2numpy(self.loc_x).reshape(-1, 1)
        src_z = list2numpy(self.loc_z).reshape(-1, 1)
        src_loc = np.hstack((src_x, src_z))
        self.loc = src_loc.copy()
        return src_loc

    def get_wavelet(self):
        """Return the time-domain wavelets stored in ``self.wavelet``.

        This class never assigns ``wavelet`` itself; the caller must set it.

        Raises:
            AttributeError: if ``self.wavelet`` has not been assigned.
        """
        if not hasattr(self, "wavelet"):
            raise AttributeError(
                "Source_Freq has no 'wavelet' attribute; assign it before calling get_wavelet()"
            )
        return list2numpy(self.wavelet)

    def get_moment_tensor(self):
        """Return the source moment tensors (converted with ``list2numpy``)."""
        return list2numpy(self.moment_tensor)

    def get_type(self, unique=False) -> List[str]:
        """Return the source types.

        With ``unique=True`` each type appears once, in order of first
        appearance (deterministic, unlike iterating a ``set``).
        """
        types = list(dict.fromkeys(self.type)) if unique else self.type
        return list2numpy(types)

    def plot_wavelet(self, index=0, **kwargs):
        """Plot wavelet ``index`` against ``self.t``.

        Both ``self.t`` and ``self.wavelet`` must have been assigned by the
        caller; this class does not create them.

        Raises:
            AttributeError: if ``self.t`` (or ``self.wavelet``) is missing.
        """
        if not hasattr(self, "t"):
            raise AttributeError(
                "Source_Freq has no time axis 't'; assign 't' and 'wavelet' before calling plot_wavelet()"
            )
        tlist = self.t
        wavelet = self.get_wavelet()[index]
        # Inside a method this name resolves to the module-level plotting
        # function imported at the top of the file, not to this method.
        plot_wavelet(tlist, wavelet, **kwargs)

    def plot_spectrum_(self, spectrum, savepath: Optional[str] = None, df=1, nf=None, min_freq=0.1, fmax=100) -> np.ndarray:
        """
        Plot frequency domain amplitude spectrum and its time-domain inverse FFT.

        The spectrum is placed on the rfft frequency grid ``k * df`` for
        ``k = 0 .. nf-1`` (``nt = 2 * (nf - 1)`` time samples). Only bins with
        ``min_freq <= f <= fmax`` are filled; all other bins are zero. If the
        spectrum length differs from the number of bins in that band, it is
        assumed to be evenly sampled over ``[min_freq, fmax]`` and is linearly
        interpolated (real and imaginary parts separately) onto those bins.

        Parameters:
            spectrum : complex array-like
            savepath : path to save the figure to; if None the figure is shown
            df : frequency spacing of the rfft grid (Hz)
            nf : number of rfft bins; defaults to ``len(spectrum)``
            min_freq, fmax : frequency band (Hz) the spectrum occupies

        Returns:
            The time-domain signal of length ``nt``.

        Raises:
            ValueError: if ``nf < 2`` or ``df <= 0`` (no valid time grid).
        """
        spectrum = np.asarray(spectrum)
        if nf is None:
            nf = len(spectrum)
        if nf < 2:
            raise ValueError("nf must be at least 2 to build a time axis (nt = 2 * (nf - 1))")
        if df <= 0:
            raise ValueError("df must be positive")
        nt = 2 * (nf - 1)
        dt = 1 / (nt * df)
        full_freq_axis = np.fft.rfftfreq(nt, d=dt)
        trunc_idx = np.where((full_freq_axis >= min_freq) & (full_freq_axis <= fmax))[0]
        trunc_freq_axis = full_freq_axis[trunc_idx]
        expected_len = len(trunc_idx)
        if len(spectrum) != expected_len:
            print("Warning: Provided spectrum length (%d) does not match expected truncated frequency range length (%d)."
                  " Interpolating spectrum to the expected frequency points." % (len(spectrum), expected_len))
            provided_freq_axis = np.linspace(min_freq, fmax, len(spectrum))
            spectrum_interp_real = np.interp(trunc_freq_axis, provided_freq_axis, spectrum.real)
            spectrum_interp_imag = np.interp(trunc_freq_axis, provided_freq_axis, spectrum.imag)
            truncated_spectrum = spectrum_interp_real + 1j * spectrum_interp_imag
        else:
            truncated_spectrum = spectrum
        amp_spectrum = np.abs(truncated_spectrum)
        full_spectrum = np.zeros(nf, dtype=complex)
        full_spectrum[trunc_idx] = truncated_spectrum
        time_signal = np.fft.irfft(full_spectrum, n=nt)
        time_axis = np.arange(nt) * dt

        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        axes[0].plot(trunc_freq_axis, amp_spectrum, 'b-', lw=2)
        axes[0].set_xlabel("Frequency (Hz)")
        axes[0].set_ylabel("Amplitude")
        axes[0].set_title("Frequency Domain Amplitude Spectrum")
        axes[0].grid(True)
        axes[1].plot(time_axis, time_signal, 'r-', lw=2)
        axes[1].set_xlabel("Time (s)")
        axes[1].set_ylabel("Amplitude")
        axes[1].set_title("Time Domain Source Signal")
        axes[1].grid(True)
        fig.tight_layout()
        if savepath is not None:
            fig.savefig(savepath)
            # Release the figure; otherwise every saved plot stays registered
            # with pyplot and memory grows across repeated calls.
            plt.close(fig)
        else:
            plt.show()
        return time_signal