from .base import Misfit
import torch
import torch.nn.functional as F

class Misfit_traveltime(Misfit):
    """
    Travel-time misfit function.

    Every trace of the observed and synthetic data is normalized by its peak
    absolute amplitude, the two are cross-correlated, and the lag of the
    correlation peak (found with a sharp softmax so that it stays
    differentiable) is taken as the travel-time shift tau.  The misfit is
    0.5 * sum(tau ** 2).

    Parameters:
    -----------
    dt (float): Time sampling interval; shifts in samples are multiplied by it.
    beta (float): Stored on the instance but not used by this implementation.
    """
    def __init__(self, dt=1, beta=10) -> None:
        super().__init__()
        self.dt = dt
        self.beta = beta

    @staticmethod
    def _normalize_traces(data: torch.Tensor) -> torch.Tensor:
        """Scale every trace so that its peak absolute amplitude along dim 1 is 1."""
        peak = torch.max(torch.abs(data), dim=1, keepdim=True)[0]
        # All-zero (dead) traces are divided by 1 instead of 0: they stay zero
        # and, unlike a 0/0 that is patched afterwards with nan_to_num, the
        # backward pass does not produce NaN gradients for them.
        safe_peak = torch.where(peak > 0, peak, torch.ones_like(peak))
        return data / safe_peak

    # Input shape: Batch (srcn), waveform (nt), traces (rcvn), e.g., [40, 3000, 200].
    def forward(self, obs: torch.Tensor, syn: torch.Tensor) -> torch.Tensor:
        """
        Compute the travel-time misfit between observed and synthetic data.

        Parameters:
        -----------
        obs: Tensor, shape [B, L, T], observed waveforms.
        syn: Tensor, shape [B, L, T], synthetic waveforms.

        Returns:
        --------
        loss: Scalar tensor, 0.5 * sum((tau * dt) ** 2) over all traces.
        """
        # Report NaN values in the inputs; they are replaced by zeros below.
        if torch.isnan(obs).any():
            print("Observed data contains NaN values.")
        if torch.isnan(syn).any():
            print("Synthetic data contains NaN values.")

        # nan_to_num also maps +/-inf to the largest finite values, so the
        # normalization below only ever sees finite numbers and cannot create
        # new NaNs; no second clean-up pass is needed afterwards.
        obs = self._normalize_traces(torch.nan_to_num(obs, nan=0.0))
        syn = self._normalize_traces(torch.nan_to_num(syn, nan=0.0))

        # Travel-time shifts in samples, scaled by the time sampling interval
        taus = self.findtau(obs, syn) * self.dt
        # Total misfit: 0.5 * sum(tau^2)
        return 0.5 * torch.sum(taus ** 2)

    def findtau(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the travel-time shift (tau) of every trace in a batch.

        For each trace the cross-correlation
            c[s] = sum_l target[l + s] * output[l],   s = -(L-1), ..., L-1
        is evaluated and tau is the softmax-weighted average of the lags s,
        with a softmax sharp enough to approximate the argmax.  A positive tau
        therefore means that `target` is delayed with respect to `output`.

        Parameters:
        -----------
        output: Tensor, shape [B, L, T]
                B: Batch size, L: Sequence length, T: Number of traces.
        target: Tensor, shape [B, L, T]

        Returns:
        --------
        taus: Tensor, shape [B, T] containing the travel-time shift (tau), in
              samples, for each trace.
        """
        B, L, T = output.shape
        if B == 0 or T == 0:
            # Nothing to correlate; return an empty result of the right shape.
            return output.new_zeros((B, T))

        # Inverse temperature of the softmax: large values make the softmax
        # output close to one-hot at the correlation peak.
        sharpness = 10000.0

        # conv1d needs both operands in one dtype; promote like `*` would.
        dtype = torch.promote_types(output.dtype, target.dtype)

        # Move time to the last axis and treat each (batch, trace) pair as its
        # own channel, so that a single grouped conv1d correlates every pair
        # independently without building the [2L-1, L] sliding-window tensor.
        signal = target.to(dtype).permute(0, 2, 1).reshape(1, B * T, L)
        kernel = output.to(dtype).permute(0, 2, 1).reshape(B * T, 1, L)

        # conv1d is a cross-correlation; zero padding of L-1 on both sides
        # yields all 2L-1 lags.  Shape: [1, B*T, 2L-1] -> [B, T, 2L-1]
        xcorr = F.conv1d(signal, kernel, padding=L - 1, groups=B * T)
        xcorr = xcorr.reshape(B, T, 2 * L - 1)

        # Soft argmax over the lag axis
        weights = F.softmax(xcorr * sharpness, dim=-1)
        # Lag of each correlation sample; index k corresponds to lag k - (L-1)
        lags = torch.arange(-(L - 1), L, device=weights.device, dtype=weights.dtype)
        return (weights * lags).sum(dim=-1)  # Shape: [B, T]