"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

from .base import Misfit
import torch
import matplotlib.pyplot as plt

class Misfit_waveform_L2(Misfit):
    '''Waveform L2-norm difference misfit (Tarantola, 1984)

    The misfit is the sum, over every slice taken along dimension 1 of the
    residual ``obs - syn``, of ``sqrt(sum(residual**2 * dt))``.

    Parameters:
    -----------
        dt (float)      : Time sampling interval used to weight the squared
                          residuals. Defaults to 1.
    '''
    def __init__(self, dt=1) -> None:
        super().__init__()
        self.dt = dt

    def forward(self, obs: torch.Tensor, syn: torch.Tensor) -> torch.Tensor:
        '''Compute the L2-norm waveform misfit between observed and synthetic data.

        Args:
            obs (Tensor): Observed waveform. Must have at least 2 dimensions,
                because the reduction runs along ``dim=1``.
            syn (Tensor): Synthetic waveform, broadcast-compatible with ``obs``.

        Returns:
            Tensor: Scalar (0-dim) L2-norm misfit loss.
        '''
        # Residual between observed and synthetic data.
        rsd = obs - syn

        # dt-weighted squared residual, reduced along dim 1.
        # NOTE(review): dim 1 is presumably the time-sample axis of a
        # (shot, time, receiver) layout -- confirm against the caller.
        energy = torch.sum(rsd * rsd * self.dt, dim=1)

        # Taking sqrt directly at energy == 0 yields inf * 0 = NaN in the
        # backward pass (d sqrt(x)/dx is infinite at 0), and a single slice
        # with zero residual is enough to turn the whole gradient into NaN.
        # Zero entries are therefore swapped for 1 before the sqrt and reset
        # to 0 afterwards: forward values are unchanged and the gradient
        # contributed by those entries is exactly zero.
        is_zero = energy == 0
        root = torch.sqrt(energy.masked_fill(is_zero, 1))
        loss = torch.sum(root.masked_fill(is_zero, 0))

        # The sum of square roots is already non-negative; abs is kept from
        # the original implementation as a harmless safeguard.
        loss = torch.abs(loss)

        return loss
