"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""


import numpy as np
from typing import List, Optional
import numpy as np
from uniSI.utils import list2numpy,numpy2list
from uniSI.view import plot_wavelet

class SHSource(object):
    """Container for the sources of an SH-wave simulation.

    Sources are accumulated with :meth:`add_source` / :meth:`add_sources` and
    read back through the ``get_*`` accessors. Only ``'force'`` sources are
    accepted. Per-source data is kept in parallel lists (one entry per source).
    """

    def __init__(self, nt: int, dt: float, f0: float) -> None:
        """Create an empty source container.

        Parameters:
        -----------
            nt: number of time samples of every source wavelet
            dt: time sampling interval (``__repr__`` prints ``dt * 1000`` as ms)
            f0: frequency value stored on the instance (not used by this class)
        """
        self.nt = nt
        self.dt = dt
        self.f0 = f0
        self.t = np.arange(nt) * dt   # time axis used by plot_wavelet
        self.loc_x = []               # x coordinate of each source
        self.loc_z = []               # z coordinate of each source
        self.loc = []                 # refreshed by get_loc()
        self.type = []                # source type of each source
        self.wavelet = []             # wavelet (time function) of each source
        self.force_direction = []     # force direction vector of each source
        self.num = 0                  # number of sources added so far

    def __repr__(self):
        """Return a printable summary of the stored sources."""
        empty_info = "SH Wave Source:\n" + "  empty\n"
        # Nothing stored yet: report it explicitly instead of relying on an
        # exception raised while taking min/max of an empty array.
        if not self.loc_x or not self.loc_z:
            return empty_info
        try:
            src_x = list2numpy(self.loc_x)
            src_z = list2numpy(self.loc_z)
            xmin = src_x.min()
            xmax = src_x.max()
            zmin = src_z.min()
            zmax = src_z.max()

            info = f"SH Wave Source:\n"
            info += f"  Source wavelet: {self.nt} samples at {self.dt * 1000:.2f} ms\n"
            info += f"  Source number : {self.num}\n"
            info += f"  Source types  : {self.get_type(unique=True)}\n"
            info += f"  Source x range: {xmin} - {xmax} (grids)\n"
            info += f"  Source z range: {zmin} - {zmax} (grids)\n"
        except Exception:
            # repr() is best-effort and must not raise; unlike a bare
            # ``except`` this lets KeyboardInterrupt/SystemExit propagate.
            info = empty_info
        return info

    def _check_source_args(self, src_wavelet, src_type, src_direction):
        """Validate the arguments shared by add_source and add_sources.

        ``None`` for ``src_type`` / ``src_direction`` selects the documented
        defaults (``'force'`` and ``[0, 1, 0]``).

        Returns:
        --------
            (src_type, src_direction) with the defaults filled in.

        Raises:
        -------
            ValueError: unsupported source type, wavelet length different
                from ``self.nt``, or a direction that does not have length 3.
        """
        if src_type is None:
            src_type = 'force'
        if src_direction is None:
            src_direction = np.array([0, 1, 0])
        if src_type.lower() not in ["force"]:
            raise ValueError(
                "Source type must be force for SH waves"
            )
        if src_wavelet.shape[0] != self.nt:
            raise ValueError(
                "Source wavelet must have the same length as the number of time samples"
            )
        if len(src_direction) != 3:
            raise ValueError("Force direction must be a 3D vector")
        return src_type, src_direction

    def add_sources(self,
            src_x: np.array,
            src_z: np.array,
            src_wavelet: np.ndarray,
            src_type: Optional[str] = 'force',
            src_direction: Optional[np.ndarray] = np.array([0, 1, 0])
        ) -> None:
        """add multiple sources with same wavelet
        Parameters:
        -----------
            src_x: source x coordinates
            src_z: source z coordinates
            src_wavelet: source time function
            src_type: 'force' for SH waves
            src_direction: force direction, default [0,1,0] for y-direction

        Raises:
        -------
            ValueError: if ``src_x`` and ``src_z`` differ in shape, or if any
                of the checks shared with ``add_source`` fails.
        """
        src_x = np.asarray(src_x)
        src_z = np.asarray(src_z)
        if src_x.shape != src_z.shape:
            raise ValueError(
                "Source location along x and z direction must have the same shape"
            )
        src_type, src_direction = self._check_source_args(
            src_wavelet, src_type, src_direction
        )

        # Count the flattened coordinates: len(src_x) would only count the
        # first axis and desynchronise the per-source lists for N-d input.
        flat_x = src_x.reshape(-1)
        flat_z = src_z.reshape(-1)
        src_n = flat_x.size

        # add source
        self.loc_x.extend(numpy2list(flat_x))
        self.loc_z.extend(numpy2list(flat_z))
        self.type.extend([src_type] * src_n)
        self.wavelet.extend(np.ones((src_n, self.nt)) * src_wavelet)
        # One independent copy per source, so an in-place edit of one entry
        # cannot leak into the others or into the shared default argument.
        self.force_direction.extend(np.array(src_direction) for _ in range(src_n))
        self.num += src_n

    def add_source(self,
            src_x: int,
            src_z: int,
            src_wavelet: np.ndarray,
            src_type: Optional[str] = 'force',
            src_direction: Optional[np.ndarray] = np.array([0, 1, 0])
        ) -> None:
        """Append single source
        Parameters:
        -----------
            src_x: source x coordinate
            src_z: source z coordinate
            src_wavelet: source time function (first axis of length ``nt``)
            src_type: 'force' for SH waves
            src_direction: force direction, default [0,1,0]

        Raises:
        -------
            ValueError: unsupported source type, wrong wavelet length, or a
                direction that does not have length 3.
        """
        src_type, src_direction = self._check_source_args(
            src_wavelet, src_type, src_direction
        )

        # add source
        self.loc_x.append(src_x)
        self.loc_z.append(src_z)
        self.type.append(src_type)
        self.wavelet.append(src_wavelet)
        # Store a copy: the default argument is a single array shared by all
        # calls and must not be aliased between sources.
        self.force_direction.append(np.array(src_direction))
        self.num += 1

    def get_force_direction(self):
        """Return the force directions"""
        return list2numpy(self.force_direction)

    def get_loc(self):
        """Return the source locations as an (n_sources, 2) array of (x, z).

        A copy of the result is also stored in ``self.loc``.
        """
        src_x = list2numpy(self.loc_x).reshape(-1, 1)
        src_z = list2numpy(self.loc_z).reshape(-1, 1)
        src_loc = np.hstack((src_x, src_z))
        self.loc = src_loc.copy()
        return src_loc

    def get_wavelet(self):
        """Return the source wavelets
        """
        return list2numpy(self.wavelet)

    def get_moment_tensor(self):
        """Return the source moment tensors.

        This class never populates ``moment_tensor`` itself (only force
        sources can be added), so the value is returned only if something
        else has set the attribute.

        Raises:
        -------
            AttributeError: if no ``moment_tensor`` attribute has been set.
        """
        moment_tensor = getattr(self, "moment_tensor", None)
        if moment_tensor is None:
            raise AttributeError(
                "SHSource has no moment_tensor: only 'force' sources are supported"
            )
        return list2numpy(moment_tensor)

    def get_type(self, unique=False) -> List[str]:
        """Return the source types, converted with ``list2numpy``.

        Parameters:
        -----------
            unique: if True, return each distinct type once, in the order
                in which it was first added.
        """
        if unique:
            # dict.fromkeys de-duplicates while keeping a deterministic
            # first-seen order (a plain set has no stable order).
            return list2numpy(list(dict.fromkeys(self.type)))
        return list2numpy(self.type)

    def plot_wavelet(self, index=0, **kwargs):
        """Plot the wavelet of source ``index`` against the time axis.

        Extra keyword arguments are forwarded to the imported
        ``plot_wavelet`` helper.
        """
        tlist = self.t
        wavelet = self.get_wavelet()[index]
        plot_wavelet(tlist, wavelet, **kwargs)