"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

from .base import Misfit
import torch

class Misfit_global_correlation(Misfit):
    """Global correlation misfit function.

    Each trace of the observed and synthetic data is scaled to unit L2 norm
    along dim 1, a zero-lag correlation coefficient is computed per trace,
    and the negated coefficients are summed into a scalar loss.

    Parameters:
    ------------
        dt (float)      : scale factor applied to every per-trace residual
                          before summation (default 1)
    """
    def __init__(self,dt=1) -> None:
        super().__init__()
        self.dt = dt

    def forward(self, obs, syn):
        """Compute the global correlation misfit between observed and synthetic waveforms.

        Args:
            obs (Tensor): Observed waveform, shape (N, T, M). Dim 1 is the axis
                that is normalised and reduced over (presumably time samples —
                confirm against the caller); dim 2 indexes traces.
            syn (Tensor): Synthetic waveform, same shape as ``obs``.

        Returns:
            Tensor: Scalar loss ``sum(-corr * dt)`` over all (N, M) traces.

        Notes:
            * Traces whose observed or synthetic L2 norm is not strictly
              positive (all-zero / NaN traces) contribute exactly 0 to the
              loss and receive a zero gradient, instead of leaking NaN into
              the backward pass through a 0/0 normalisation.
            * The numerator is the un-centred mean product of the normalised
              traces while the denominator uses ``torch.var`` (mean-centred,
              unbiased). NOTE(review): this matches the historical behaviour
              of this class and is kept unchanged; confirm it is intended.
        """
        eps = 1e-8

        # Per-trace L2 norms along dim 1.
        obs_norm = obs.norm(dim=1, keepdim=True)  # Shape: (N, 1, M)
        syn_norm = syn.norm(dim=1, keepdim=True)  # Shape: (N, 1, M)

        # A trace can only be normalised when its norm is strictly positive.
        # (NaN norms compare False here and are therefore excluded as well.)
        obs_ok = obs_norm > 0
        syn_ok = syn_norm > 0

        # Substitute a harmless divisor of 1 where the norm is unusable so
        # that neither the forward nor the backward pass ever evaluates 0/0.
        obs_normalized = obs / torch.where(obs_ok, obs_norm, torch.ones_like(obs_norm))
        syn_normalized = syn / torch.where(syn_ok, syn_norm, torch.ones_like(syn_norm))

        valid = (obs_ok & syn_ok).squeeze(1)  # Shape: (N, M)

        # All traces at once: reduce over dim 1 instead of looping over dim 2.
        cov = torch.mean(obs_normalized * syn_normalized, dim=1)  # Shape: (N, M)
        var_prod = torch.var(obs_normalized, dim=1) * torch.var(syn_normalized, dim=1)

        # sqrt has an infinite derivative at 0; give excluded traces a benign
        # argument so their (already zero) upstream gradient stays finite.
        var_prod = torch.where(valid, var_prod, torch.ones_like(var_prod))

        # The small constant keeps the division finite when a variance is zero.
        corr = cov / (torch.sqrt(var_prod) + eps)

        # Excluded traces, and any coefficient that still evaluates to NaN,
        # count as zero correlation.
        corr = torch.where(valid & ~torch.isnan(corr), corr, torch.zeros_like(corr))

        rsd = -corr  # Shape: (N, M); inherits dtype and device from the inputs
        loss = torch.sum(rsd * self.dt)
        return loss