from .base import Misfit
import torch

class Misfit_Attenuation(Misfit):
    '''Attenuation imaging misfit.

    Compares the log-amplitudes of observed and synthetic data: the
    residual ``log(|obs| + eps) - log(|syn| + eps)`` is reduced with an
    L2 norm along dimension 1, and the resulting norms are summed into a
    single scalar.
    '''
    def __init__(self) -> None:
        super().__init__()

    def forward(self, obs: torch.Tensor, syn: torch.Tensor) -> torch.Tensor:
        '''Compute the log-amplitude misfit between ``obs`` and ``syn``.

        Args:
            obs (torch.Tensor): Observed data. Must have at least two
                dimensions, because the norm is taken along dim 1
                (presumably the sample axis of each trace -- confirm
                against the caller).
            syn (torch.Tensor): Synthetic data, broadcast-compatible
                with ``obs``.

        Returns:
            torch.Tensor: Scalar tensor holding the sum over all
            remaining dimensions of the L2 norm (taken along dim 1) of
            the log-amplitude residual.
        '''
        # Small offset that keeps log() finite where the amplitude is zero.
        eps = 1e-10

        log_obs = torch.log(torch.abs(obs) + eps)
        log_syn = torch.log(torch.abs(syn) + eps)
        rsd = log_obs - log_syn

        # Use vector_norm rather than sqrt(sum(rsd * rsd)): the hand-written
        # form back-propagates inf * 0 = NaN through sqrt whenever a row of
        # the residual is exactly zero (e.g. a perfectly fitted trace),
        # which poisons the whole gradient. vector_norm yields the same
        # forward value with a zero subgradient at that point.
        row_norm = torch.linalg.vector_norm(rsd, ord=2, dim=1)

        return torch.sum(row_norm)