"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

import numpy as np 
import matplotlib.pyplot as plt
import torch
from typing import Optional, List, Union
from uniSI.survey import Survey
from uniSI.utils import gpu2cpu, tensor2numpy
from uniSI.view import plot_waveform2D, plot_waveform_wiggle, plot_waveform_trace

class SeismicData_Freq():
    """Container for frequency-domain seismic records and their acquisition metadata.

    The frequency axis is ``f = arange(nf) * df``. A matching time axis is
    derived from it with ``nt = 2 * (nf - 1)`` samples (the default output
    length of ``numpy.fft.irfft`` for ``nf`` bins; a single bin maps to one
    sample), total time ``T = 1 / df`` and interval ``dt = T / nt``.

    Args:
        survey: Optional survey object. When given, ``nf`` and ``df`` are read
            from ``survey.receiver`` and the explicit ``nf``/``df`` arguments
            are ignored.
        source_num, receiver_num, source_loc, receiver_loc: Acquisition
            metadata, stored as given (also when ``survey`` is provided).
        nf: Number of frequency samples.
        df: Frequency sampling interval in Hz.

    If ``nf`` or ``df`` is ``None`` (and no survey is given), all derived axes
    are ``None``; such an empty container can be populated with :meth:`load`.
    """

    def __init__(self,
                 survey: Optional["Survey"] = None,
                 source_num: Optional[int] = None,
                 receiver_num: Optional[int] = None,
                 source_loc: Optional[List[float]] = None,
                 receiver_loc: Optional[List[float]] = None,
                 nf: Optional[int] = None,
                 df: Optional[float] = None,
                 ):
        self.survey = survey
        self.data = None

        # Always defined so that __repr__ and save() work whether or not a
        # survey is supplied.
        # NOTE(review): the survey may also carry source/receiver geometry;
        # its attribute names are not visible here, so they are not read.
        self.src_num = source_num
        self.rcv_num = receiver_num
        self.src_loc = source_loc
        self.rcv_loc = receiver_loc

        if survey is not None:
            nf = survey.receiver.nf
            df = survey.receiver.df
        self._set_axes(nf, df)

    def _set_axes(self, nf, df):
        """Store ``nf``/``df`` and derive the frequency axis and matching time axis."""
        self.nf = nf
        self.df = df
        if nf is None or df is None:
            # Sampling not known yet (e.g. an empty container for load()).
            self.f = self.nt = self.T = self.dt = self.t = None
            return
        self.f = np.arange(nf) * df
        # Length of the real time signal whose rfft has nf bins.
        self.nt = 2 * (nf - 1) if nf > 1 else 1
        self.T = 1.0 / df
        self.dt = self.T / self.nt
        self.t = np.linspace(0, self.T - self.dt, self.nt)

    def __repr__(self):
        def fmt(value, spec, scale=1.0):
            # Tolerate axes that are not set yet (None).
            return "None" if value is None else format(value * scale, spec)

        return (
            "Seismic Data (Frequency Domain Version):\n"
            f"  Source number     : {self.src_num}\n"
            f"  Receiver number   : {self.rcv_num}\n"
            f"  Frequency samples : {self.nf} samples at {fmt(self.df, '.2f')} Hz resolution\n"
            "  -> Corresponding Time Domain:\n"
            f"     Total time T    : {fmt(self.T, '.4f')} s\n"
            f"     Time samples    : {self.nt} samples at {fmt(self.dt, '.2f', 1000)} ms interval\n"
        )

    def add_noise(self, noise_level=0.01, *, plot: bool = True):
        """Add Gaussian noise in place to the ``[0, :, 0]`` slice of ``self.data["p"]``.

        The noise scale is ``noise_level`` times the maximum absolute value of
        the whole ``"p"`` array. Real and imaginary noise parts are independent
        standard-normal draws from NumPy's global RNG, multiplied by that scale.
        Only the ``[0, :, 0]`` slice is modified; the rest of the array is
        untouched. For a real-valued array only the real part of the noise can
        be stored, so it is kept and the imaginary part is discarded.

        Args:
            noise_level: Noise amplitude relative to the peak absolute amplitude.
            plot: If True (the default), show a figure comparing the slice
                before and after adding noise.

        Raises:
            ValueError: If no ``"p"`` record has been stored yet.
        """
        if self.data is None or "p" not in self.data:
            raise ValueError("No 'p' data recorded; call record_data() or load() first.")
        record = self.data["p"]

        # Compute amplitude for noise scaling (over the whole array).
        noise_amp = noise_level * np.max(np.abs(record))

        # NOTE(review): indexing [0, :, 0] requires an array of rank >= 3; which
        # axes are shot/frequency/receiver is not visible here - confirm.
        original_data = record[0, :, 0].copy()

        complex_noise = noise_amp * (np.random.randn(*original_data.shape)
                                     + 1j * np.random.randn(*original_data.shape))
        noisy_data = original_data + complex_noise
        if not np.iscomplexobj(record):
            # A real array cannot hold the imaginary part; drop it explicitly
            # instead of relying on numpy's lossy cast (ComplexWarning).
            noisy_data = noisy_data.real
        record[0, :, 0] = noisy_data
        noisy_data = record[0, :, 0]

        if plot:
            # Plot original and noisy data (real and imaginary parts).
            fig, axs = plt.subplots(2, 1, figsize=(10, 4), sharex=True)
            axs[0].plot(np.real(original_data), label="Real Part")
            axs[0].plot(np.imag(original_data), label="Imaginary Part", linestyle="--")
            axs[0].set_title("Original Data")
            axs[0].legend()
            axs[1].plot(np.real(noisy_data), label="Real Part")
            axs[1].plot(np.imag(noisy_data), label="Imaginary Part", linestyle="--")
            axs[1].set_title("Noisy Data")
            axs[1].legend()
            plt.tight_layout()
            plt.show()

    def record_data(self, data: dict):
        """Convert every value of ``data`` to a NumPy copy and store the dict.

        Each value goes through ``gpu2cpu`` and ``tensor2numpy`` and is copied.
        Note: the conversion is written back into the caller's ``data`` dict
        in place, and ``self.data`` is that same dict object.
        """
        for key, value in data.items():
            # Replacing values of existing keys while iterating is safe (no resize).
            data[key] = tensor2numpy(gpu2cpu(value)).copy()
        self.data = data

    def save(self, path: str):
        """Save seismic data and metadata to an ``.npz`` file.

        ``self.data`` (a dict) and any ``None`` metadata are stored as pickled
        0-d object arrays, so :meth:`load` reads the file with
        ``allow_pickle=True``. ``np.savez`` appends ``.npz`` to ``path`` if it
        is missing.
        """
        data_save = {
            'data': self.data,
            'src_loc': self.src_loc,
            'rcv_loc': self.rcv_loc,
            'src_num': self.src_num,
            'rcv_num': self.rcv_num,
            'nf': self.nf,
            'df': self.df,
            'nt': self.nt,
            'dt': self.dt,
            't': self.t,
            'f': self.f
        }
        np.savez(path, **data_save)

    @staticmethod
    def _unwrap_scalar(value):
        """Return the Python object held by a 0-d array; other values unchanged."""
        if isinstance(value, np.ndarray) and value.ndim == 0:
            return value.item()
        return value

    def load(self, path: str):
        """Load seismic data and metadata from an ``.npz`` file written by :meth:`save`.

        Scalar and ``None`` entries (stored as 0-d arrays) are unwrapped back to
        Python values, and the total time ``T`` is recomputed from ``df``.
        """
        # SECURITY: allow_pickle=True unpickles stored objects; only load files
        # from trusted sources.
        with np.load(path, allow_pickle=True) as archive:
            def get(key):
                return self._unwrap_scalar(archive[key])

            self.data = get('data')
            self.src_loc = get('src_loc')
            self.rcv_loc = get('rcv_loc')
            self.src_num = get('src_num')
            self.rcv_num = get('rcv_num')
            self.nf = get('nf')
            self.df = get('df')
            self.f = get('f')
            self.nt = get('nt')
            self.dt = get('dt')
            self.t = get('t')
        # T is not stored in the file; keep it consistent with the loaded df.
        self.T = None if self.df is None else 1.0 / self.df
        return