"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

from typing import Optional, Union, List
import os
import math
import torch
import numpy as np
from tqdm.notebook import tqdm
from uniSI.model import AbstractModel
from uniSI.propagator import SHPropagator, GradProcessor
from uniSI.survey import SeismicData
from uniSI.inversion.misfit import Misfit, Misfit_NIM
from uniSI.inversion.regularization import Regularization
from uniSI.inversion.optimizer import NLCG
from uniSI.utils import numpy2tensor
from uniSI.view import plot_model
from uniSI.inversion.multiScaleProcessing import lpass
import time
class SHInversion(torch.nn.Module):
    def __init__(self, propagator: SHPropagator, model: AbstractModel,
                 optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler,
                 loss_fn: Union[Misfit, torch.autograd.Function],
                 obs_data: SeismicData,
                 gradient_processor: Union[GradProcessor, List[GradProcessor]] = None,
                 regularization_fn: Optional[Regularization] = None,
                 regularization_weights_x: Optional[List[float]] = None,
                 regularization_weights_z: Optional[List[float]] = None,
                 waveform_normalize: Optional[bool] = True,
                 cache_result: Optional[bool] = True,
                 save_fig_epoch: Optional[int] = -1,
                 save_fig_path: Optional[str] = "",
                 min_improvement: Optional[float] = 1e-4):
        """Set up an SH (vy-component) inversion of the model's ``vs`` and ``rho``.

        Args:
            propagator: wave propagator; supplies ``device``, ``dtype``, ``src_n`` and ``dt``.
            model: model exposing ``vs`` / ``rho`` plus ``get_bound`` / ``get_model``.
            optimizer: optimizer over the model parameters.
            scheduler: learning-rate scheduler, stepped once per iteration.
            loss_fn: data misfit (``Misfit``, ``Misfit_NIM`` or an autograd Function).
            obs_data: observed data; only ``obs_data.data["vy"]`` is used.
            gradient_processor: a single GradProcessor, or a list indexed by
                parameter (``forward`` uses 0 for vs and 1 for rho).
            regularization_fn: optional regularization term.
            regularization_weights_x: ``[vs, rho]`` weights written to
                ``regularization_fn.alphax``; ``None`` means ``[0, 0]``.
            regularization_weights_z: ``[vs, rho]`` weights written to
                ``regularization_fn.alphaz``; ``None`` means ``[0, 0]``.
            waveform_normalize: divide observed and synthetic data by their
                peak absolute amplitude along dim 1 before comparing.
            cache_result: keep per-iteration model/gradient snapshots.
            save_fig_epoch: save figures every N iterations; -1 disables saving.
            save_fig_path: directory for saved figures (created if missing).
            min_improvement: plateau threshold used for early stopping.
        """
        super().__init__()
        self.propagator = propagator
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        self.regularization_fn = regularization_fn
        # Fresh lists per instance: a shared mutable default would leak in-place
        # edits (e.g. a weight schedule) into every instance created later.
        self.regularization_weights_x = [0, 0] if regularization_weights_x is None else regularization_weights_x
        self.regularization_weights_z = [0, 0] if regularization_weights_z is None else regularization_weights_z
        self.obs_data = obs_data
        self.gradient_processor = gradient_processor
        self.device = self.propagator.device
        self.dtype = self.propagator.dtype

        self.waveform_normalize = waveform_normalize
        obs_vy = numpy2tensor(self.obs_data.data["vy"], self.dtype).to(self.device)
        if self.waveform_normalize:
            obs_vy = self._normalize(obs_vy)
        self.obs_vy = obs_vy

        # Clamp / colour-bar limits for vs and rho.
        self.vs_min, self.vs_max = self._resolve_bounds("vs", padding=100, water_floor=300)
        self.rho_min, self.rho_max = self._resolve_bounds("rho", padding=500, water_floor=1000)

        self.cache_result = cache_result
        self.iter_vs, self.iter_rho = [], []
        self.iter_vs_grad, self.iter_rho_grad = [], []
        self.iter_loss = []
        self.min_improvement = min_improvement

        self.save_fig_epoch = save_fig_epoch
        self.save_fig_path = save_fig_path
        if self.save_fig_path:
            os.makedirs(self.save_fig_path, exist_ok=True)

    def _resolve_bounds(self, name, padding, water_floor):
        """Return ``(lower, upper)`` limits for model parameter ``name``.

        If the model defines no bound at all, the limits are the current
        model's min/max widened by ``padding``. Otherwise the model's bound is
        used, with the lower limit replaced by ``water_floor`` when the model
        has a water-layer mask.
        """
        bound = self.model.get_bound(name)
        lower, upper = bound[0], bound[1]
        if lower is None and upper is None:
            values = self.model.get_model(name)
            return values.min() - padding, values.max() + padding
        if self.model.water_layer_mask is not None:
            lower = water_floor
        return lower, upper
    
    def check_early_stopping(self, loss_history, min_improvement, window=50):
        """Decide whether the loss curve has plateaued.

        Compares the mean of the last ``window`` losses with the mean of the
        ``window`` losses before them. The absolute difference is reported as
        the "improvement", so a rise in loss counts the same as a drop.

        Args:
            loss_history: sequence of per-iteration losses.
            min_improvement: stop when the improvement falls below this value.
            window: size of each averaging window (default 50, so 100 losses
                are required before stopping can trigger).

        Returns:
            ``(stop, improvement)``. With fewer than ``2 * window`` losses the
            result is ``(False, nan)``. It is always a pair so callers can unpack it.

        Raises:
            ValueError: if ``window`` is not positive.
        """
        if window <= 0:
            raise ValueError(f"window must be positive, got {window}")
        if len(loss_history) < 2 * window:
            return False, float("nan")
        recent_avg = np.mean(loss_history[-window:])
        previous_avg = np.mean(loss_history[-2 * window:-window])
        improvement = np.abs(previous_avg - recent_avg)
        return bool(improvement < min_improvement), improvement

    def _normalize(self, data):
        """Divide ``data`` by its maximum absolute value along dim 1.

        Slices whose absolute values sum to zero (all-zero slices) are divided
        by 1 instead of 0, so they stay zero rather than becoming NaN.
        """
        magnitude = torch.abs(data)
        peak = torch.max(magnitude, dim=1, keepdim=True).values
        is_silent = torch.sum(magnitude, dim=1, keepdim=True) == 0
        return data / peak.masked_fill(is_silent, 1)
    
    def calculate_loss(self, synthetic_waveform, observed_waveform, normalization, loss_fn, cutoff_freq=None, propagator_dt=None):
        """Compute the data misfit between synthetic and observed waveforms.

        Args:
            synthetic_waveform: modelled data (normalized here if ``normalization``).
            observed_waveform: observed data. It is not normalized here; the
                caller is expected to pass already-prepared data.
            normalization: whether to apply ``_normalize`` to the synthetic data.
            loss_fn: a ``Misfit`` (called via ``forward``), a ``Misfit_NIM``
                (called via ``apply`` with its ``p``/``trans_type``/``theta``),
                or any autograd Function (called via ``apply``).
            cutoff_freq: if given, both waveforms are low-pass filtered with ``lpass``.
            propagator_dt: time step, used to derive the sampling rate for ``lpass``.

        Raises:
            TypeError: if ``cutoff_freq`` is given without ``propagator_dt``.
        """
        if normalization:
            synthetic_waveform = self._normalize(synthetic_waveform)
        if cutoff_freq is not None:
            if propagator_dt is None:
                raise TypeError("propagator_dt is required when cutoff_freq is set")
            # Round instead of truncating: 1/dt can land just below an integer
            # in floating point, and int() would then drop a whole sample/sec.
            sampling_rate = int(round(float(1 / propagator_dt)))
            synthetic_waveform, observed_waveform = lpass(synthetic_waveform, observed_waveform, cutoff_freq, sampling_rate)
        if isinstance(loss_fn, Misfit):
            return loss_fn.forward(synthetic_waveform, observed_waveform)
        if isinstance(loss_fn, Misfit_NIM):
            return loss_fn.apply(synthetic_waveform, observed_waveform, loss_fn.p, loss_fn.trans_type, loss_fn.theta)
        return loss_fn.apply(synthetic_waveform, observed_waveform)
    
    def calculate_regularization_loss(self, model_param, weight_x, weight_z, regularization_fn):
        """Return the regularization penalty for ``model_param``.

        For a trainable parameter, the weights are written onto
        ``regularization_fn.alphax`` / ``alphaz`` (a side effect on the shared
        regularizer) and ``regularization_fn.forward`` is evaluated if either
        weight is positive. Otherwise a scalar zero on the parameter's device
        is returned.
        """
        if not model_param.requires_grad:
            return torch.tensor(0.0, device=model_param.device)
        regularization_fn.alphax = weight_x
        regularization_fn.alphaz = weight_z
        if regularization_fn.alphax > 0 or regularization_fn.alphaz > 0:
            return regularization_fn.forward(model_param)
        return torch.tensor(0.0, device=model_param.device)

    def process_gradient(self, parameter, forw, idx=None):
        """Post-process ``parameter.grad`` in place with the configured GradProcessor.

        ``self.gradient_processor`` is either a single GradProcessor, applied to
        every parameter, or a sequence indexed by ``idx`` (``forward`` passes 0
        for vs and 1 for rho). The gradient is left untouched when no processor
        is configured for this parameter, or when the parameter has no gradient.

        Args:
            parameter: model tensor whose ``.grad`` is replaced.
            forw: accumulated forward wavefield (numpy), forwarded to the processor.
            idx: index into ``self.gradient_processor`` when it is a sequence.
        """
        processor = self.gradient_processor
        if processor is not None and not isinstance(processor, GradProcessor):
            processor = processor[idx]
        # The default gradient_processor is None; indexing it used to raise TypeError.
        if processor is None or parameter.grad is None:
            return
        with torch.no_grad():
            grads = parameter.grad.detach()
            vmax = torch.max(parameter.detach())
            grads = processor.forward(nz=self.model.nz, nx=self.model.nx, vmax=vmax, grad=grads, forw=forw)
            parameter.grad = grads.to(self.propagator.device)

    def save_figure(self, i, data, model_type="vs"):
        """Save a plot of ``data`` for iteration ``i`` every ``save_fig_epoch`` iterations.

        Nothing is written in any of these cases:
        - ``save_fig_epoch`` is ``None`` or not positive (the default -1 disables saving);
        - ``i`` is not a multiple of ``save_fig_epoch``;
        - ``save_fig_path`` does not exist.

        Figures go to ``<save_fig_path>/<model_type>_<i>.<ext>``:
        - "vs" / "rho": PNG, colour range ``[vs_min, vs_max]`` / ``[rho_min, rho_max]``;
        - any ``model_type`` starting with "grad": PDF, 'seismic' colormap;
        - anything else: PNG, 'coolwarm' colormap.
        """
        # Guards against `i % 0` (ZeroDivisionError) and `i % None` (TypeError).
        if self.save_fig_epoch is None or self.save_fig_epoch <= 0:
            return
        if i % self.save_fig_epoch != 0 or not os.path.exists(self.save_fig_path):
            return
        if model_type == "vs":
            style, ext = {"vmin": self.vs_min, "vmax": self.vs_max}, "png"
        elif model_type == "rho":
            style, ext = {"vmin": self.rho_min, "vmax": self.rho_max}, "png"
        elif model_type.startswith("grad"):
            style, ext = {"cmap": 'seismic'}, "pdf"
        else:
            style, ext = {"cmap": 'coolwarm'}, "png"
        plot_model(data, title=f"Iteration {i}",
                   dx=self.model.dx, dz=self.model.dz,
                   save_path=os.path.join(self.save_fig_path, f"{model_type}_{i}.{ext}"),
                   show=False, **style)
    
    def forward(self,
                iteration: int,
                batch_size: Optional[int] = None,
                checkpoint_segments: Optional[int] = 1,
                start_iter=0,
                cutoff_freq=None):
        """Run ``iteration`` steps of mini-batch SH inversion.

        Each iteration:
        - shuffles the shots and splits them into batches of ``batch_size``;
        - accumulates data (+ regularization) gradients over all batches;
        - post-processes the vs/rho gradients;
        - takes one optimizer step and one scheduler step;
        - rolls back NaN entries of vs and clamps vs to ``[vs_min, vs_max]``;
        - records the loss, plus model/gradient snapshots and figures when
          ``cache_result`` is set.

        The loop stops early when the loss becomes NaN, or when
        ``check_early_stopping`` reports a plateau. The plateau check only runs
        once at least 100 losses have been recorded.

        Args:
            iteration: number of iterations to run.
            batch_size: shots per propagator call; ``None`` or a value above the
                shot count means a single batch containing every shot.
            checkpoint_segments: passed through to ``propagator.forward``.
            start_iter: index of the first iteration (for logging and figure names).
            cutoff_freq: optional low-pass cutoff applied in ``calculate_loss``.
        """
        n_shots = self.propagator.src_n
        if batch_size is None or batch_size > n_shots:
            batch_size = n_shots
        n_batches = math.ceil(n_shots / batch_size)
        end_iter = start_iter + iteration
        print("Iteration Range: {} - {}".format(start_iter, end_iter))
        for i in range(start_iter, end_iter):
            start_time = time.time()

            self.optimizer.zero_grad()
            loss_batch = 0
            shot_indices = np.random.permutation(n_shots)
            for batch in range(n_batches):
                begin_index = batch * batch_size
                end_index = min((batch + 1) * batch_size, n_shots)
                if batch % 5 == 0:
                    print(f"Batch: {batch+1}/{n_batches}, Random Shots: {begin_index} to {end_index}")
                shot_index = shot_indices[begin_index:end_index]
                record_waveform = self.propagator.forward(shot_index=shot_index, checkpoint_segments=checkpoint_segments)
                rcv_vy = record_waveform["vy"]
                # NOTE(review): forward_closure reads "forward_wavefield_vy" but this reads
                # "forward_wavefield_v"; confirm which key SHPropagator actually returns.
                forward_wavefield_vy = record_waveform["forward_wavefield_v"]
                # Sum forward wavefields over batches; consumed by the gradient processor.
                if batch == 0:
                    forw = forward_wavefield_vy.cpu().detach().numpy()
                else:
                    forw += forward_wavefield_vy.cpu().detach().numpy()
                data_loss = self.calculate_loss(rcv_vy, self.obs_vy[shot_index], self.waveform_normalize, self.loss_fn, cutoff_freq, self.propagator.dt)
                if self.regularization_fn is not None:
                    reg_loss_vs = self.calculate_regularization_loss(self.model.vs, self.regularization_weights_x[0], self.regularization_weights_z[0], self.regularization_fn)
                    reg_loss_rho = self.calculate_regularization_loss(self.model.rho, self.regularization_weights_x[1], self.regularization_weights_z[1], self.regularization_fn)
                    loss_batch += data_loss.item() + (reg_loss_vs + reg_loss_rho).item()
                    loss = data_loss + reg_loss_vs + reg_loss_rho
                else:
                    loss_batch += data_loss.item()
                    loss = data_loss
                # Gradients accumulate across batches; one optimizer step per iteration.
                loss.backward()
            if self.model.get_requires_grad("vs"):
                self.process_gradient(self.model.vs, forw=forw, idx=0)
            if self.model.get_requires_grad("rho"):
                self.process_gradient(self.model.rho, forw=forw, idx=1)
            # Snapshot vs outside autograd so NaNs produced by the step can be rolled back.
            previous_vs = self.model.vs.detach().clone()
            self.optimizer.step()
            self.scheduler.step()
            if self.model.get_requires_grad("vs"):
                with torch.no_grad():
                    new_vs = torch.where(torch.isnan(self.model.vs), previous_vs, self.model.vs)
                    new_vs = torch.clamp(new_vs, self.vs_min, self.vs_max)
                    self.model.vs.copy_(new_vs)
            self.model.forward()
            self.iter_loss.append(loss_batch)
            if self.cache_result:
                temp_vs = self.model.vs.cpu().detach().numpy()
                temp_rho = self.model.rho.cpu().detach().numpy()
                self.iter_vs.append(temp_vs)
                self.iter_rho.append(temp_rho)
                self.save_figure(i, temp_vs, model_type="vs")
                self.save_figure(i, temp_rho, model_type="rho")
                if self.model.get_requires_grad("vs"):
                    grads_vs = self.model.vs.grad.cpu().detach().numpy()
                    self.save_figure(i, grads_vs, model_type="grad_vs")
                    self.iter_vs_grad.append(grads_vs)
                if self.model.get_requires_grad("rho"):
                    grads_rho = self.model.rho.grad.cpu().detach().numpy()
                    self.save_figure(i, grads_rho, model_type="grad_rho")
                    self.iter_rho_grad.append(grads_rho)
            print(f"Iteration {i+1}/{end_iter}, Loss: {loss_batch:.4f}, LR: {self.optimizer.param_groups[0]['lr']}, Time: {time.time()-start_time:.2f}s")
            # A NaN loss cannot recover; stop immediately rather than after iteration 100.
            if np.isnan(loss_batch):
                print(f"Loss is nan, ending training at iteration {i+1}")
                break
            # Gate on recorded history (not on i): resumed runs with start_iter >= 100
            # may have fewer than the 100 losses check_early_stopping needs.
            if len(self.iter_loss) >= 100:
                early_stopping, improvement = self.check_early_stopping(self.iter_loss, self.min_improvement)
                print(f"Improvement: {improvement:.8f}")
                if early_stopping:
                    # Same 50/50 windows that check_early_stopping compares.
                    recent_avg = np.mean(self.iter_loss[-50:])
                    prev_avg = np.mean(self.iter_loss[-100:-50])
                    ratio = (prev_avg - recent_avg) / prev_avg if prev_avg != 0 else float("nan")
                    print(f"Early stopping triggered at iteration {i+1}")
                    print(f"Recent avg loss: {recent_avg:.4f}")
                    print(f"Prev avg loss: {prev_avg:.4f}")
                    print(f"Improvement ratio: {ratio:.6f}")
                    break

    def forward_closure(self,
                        iteration: int,
                        batch_size: Optional[int] = None,
                        checkpoint_segments: Optional[int] = 1,
                        start_iter=0,
                        cutoff_freq=None):
        """Run ``iteration`` closure-driven optimizer steps over all shots.

        Intended for optimizers whose ``step`` accepts a closure (presumably
        the imported ``NLCG``; confirm). Each ``optimizer.step(closure=...)``
        call evaluates ``closure``, which:
        - zeroes the gradients;
        - runs every shot in fixed, non-shuffled batches of ``batch_size``;
        - accumulates the data (+ regularization) loss and gradients;
        - sums the forward wavefields into ``self.forw``;
        - post-processes the vs/rho gradients;
        - returns the total loss as a Python float.

        ``self.true_epoch`` counts closure evaluations. It can exceed the
        number of outer iterations if the optimizer evaluates the closure
        more than once per step (optimizer-dependent).

        Compared with ``forward``, this loop does not roll back NaNs or clamp
        vs, has no early stopping, and appends to ``iter_loss`` only when
        ``cache_result`` is set.
        """
        n_shots = self.propagator.src_n
        if batch_size is None or batch_size > n_shots:
            batch_size = n_shots
        pbar_epoch = tqdm(range(start_iter, start_iter+iteration), position=0, leave=False, colour='green', ncols=80)
        self.true_epoch = 0
        self.forw = None
        for i in pbar_epoch:
            def closure():
                # Full-data loss and gradient evaluation; may be called repeatedly by the optimizer.
                self.optimizer.zero_grad()
                loss_batch = 0
                pbar_batch = tqdm(range(math.ceil(n_shots / batch_size)), position=1, leave=False, colour='red', ncols=80)
                for batch in pbar_batch:
                    # Sequential shot ranges; the last batch absorbs any remainder.
                    begin_index = 0 if batch == 0 else batch * batch_size
                    end_index = n_shots if batch == math.ceil(n_shots / batch_size) - 1 else (batch+1) * batch_size
                    shot_index = np.arange(begin_index, end_index)
                    record_waveform = self.propagator.forward(shot_index=shot_index, checkpoint_segments=checkpoint_segments)
                    rcv_vy = record_waveform["vy"]
                    # NOTE(review): `forward` reads "forward_wavefield_v" while this reads
                    # "forward_wavefield_vy"; confirm which key SHPropagator returns.
                    forward_wavefield_vy = record_waveform["forward_wavefield_vy"]
                    # Sum forward wavefields over batches for the gradient processor.
                    if batch == 0:
                        self.forw = forward_wavefield_vy.cpu().detach().numpy()
                    else:
                        self.forw += forward_wavefield_vy.cpu().detach().numpy()
                    syn_vy = rcv_vy
                    data_loss = self.calculate_loss(syn_vy, self.obs_vy[shot_index], self.waveform_normalize, self.loss_fn, cutoff_freq, self.propagator.dt)
                    if self.regularization_fn is not None:
                        reg_loss_vs = self.calculate_regularization_loss(self.model.vs, self.regularization_weights_x[0], self.regularization_weights_z[0], self.regularization_fn)
                        reg_loss_rho = self.calculate_regularization_loss(self.model.rho, self.regularization_weights_x[1], self.regularization_weights_z[1], self.regularization_fn)
                        loss_batch += data_loss.item() + (reg_loss_vs + reg_loss_rho).item()
                        loss = data_loss + reg_loss_vs + reg_loss_rho
                    else:
                        loss_batch += data_loss.item()
                        loss = data_loss
                    # Gradients accumulate across batches within one closure call.
                    loss.backward()
                    # The shot-range description is only shown when there is a single batch.
                    if math.ceil(n_shots / batch_size) == 1:
                        pbar_batch.set_description(f"Shot: {begin_index} to {end_index}")
                self.true_epoch += 1
                if self.model.get_requires_grad("vs"):
                    self.process_gradient(self.model.vs, forw=self.forw, idx=0)
                if self.model.get_requires_grad("rho"):
                    self.process_gradient(self.model.rho, forw=self.forw, idx=1)
                # Plain float total loss over all shots.
                return loss_batch

            # The value returned by optimizer.step is logged as this iteration's loss.
            loss_batch = self.optimizer.step(closure=closure)
            self.scheduler.step()
            self.model.forward()
            if self.cache_result:
                temp_vs = self.model.vs.cpu().detach().numpy()
                temp_rho = self.model.rho.cpu().detach().numpy()
                self.iter_vs.append(temp_vs)
                self.iter_rho.append(temp_rho)
                self.iter_loss.append(loss_batch)
                self.save_figure(i, temp_vs, model_type="vs")
                self.save_figure(i, temp_rho, model_type="rho")
                if self.model.get_requires_grad("vs"):
                    grads_vs = self.model.vs.grad.cpu().detach().numpy()
                    self.save_figure(i, grads_vs, model_type="grad_vs")
                    self.iter_vs_grad.append(grads_vs)
                if self.model.get_requires_grad("rho"):
                    grads_rho = self.model.rho.grad.cpu().detach().numpy()
                    self.save_figure(i, grads_rho, model_type="grad_rho")
                    self.iter_rho_grad.append(grads_rho)
            pbar_epoch.set_description("Iter:{},Loss:{:.4}".format(i+1, loss_batch))