"""
================================================================================
ADFWI BASELINE (Modified for ICLR 2026 Submission)
--------------------------------------------------------------------------------
This code is based on the ADFWI framework by LiuFeng (SJTU, https://github.com/liufeng2317/ADFWI),
originally released under the MIT License. This version has been modified for ICLR 2026.
Original Author: LiuFeng (SJTU) | Email: liufeng2317@sjtu.edu.cn
================================================================================
"""

from typing import Optional,Union,List
import os
import math
import torch
import numpy as np
#from tqdm import tqdm
from tqdm.notebook import tqdm
from uniSI.model       import AbstractModel
from uniSI.propagator  import AcousticPropagator,GradProcessor
from uniSI.survey      import SeismicData
from uniSI.inversion.misfit  import Misfit,Misfit_NIM
from uniSI.inversion.regularization import Regularization
from uniSI.inversion.optimizer import NLCG
from uniSI.utils       import numpy2tensor
from uniSI.view        import plot_model
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import time
# Import-time capability check: touching `torch.sparse.spsolve` raises
# AttributeError when the installed PyTorch does not expose it. On success a
# status line is printed; on failure a message is printed and the interpreter
# exits with status 1 (note: this happens as a side effect of importing this
# module).
# NOTE(review): presumably the frequency-domain solver behind
# `model.fdfd_oneFreq` relies on spsolve - confirm against the model class.
try: from torch import sparse; sparse.spsolve; print("✅ torch.sparse.spsolve available")
except (ImportError, AttributeError): print("❌ torch.sparse.spsolve not available"); exit(1)

from uniSI.inversion.multiScaleProcessing import lpass

class SHInversion_Freq(torch.nn.Module):
    """Acoustic Full waveform inversion class
    """
    def __init__(self, model:AbstractModel,
                 optimizer:torch.optim.Optimizer,scheduler:torch.optim.lr_scheduler,
                 loss_fn:Union[Misfit,torch.autograd.Function],
                 obs_data:SeismicData,
                 gradient_processor: Union[GradProcessor,List[GradProcessor]] = None,
                 regularization_fn:Optional[Regularization]                   = None,
                 regularization_weights_x:Optional[List[Union[float]]]        = None, # vs/Q in x direction
                 regularization_weights_z:Optional[List[Union[float]]]        = None, # vs/Q in z direction
                 waveform_normalize:Optional[bool]                            = True,
                 cache_result:Optional[bool]                                  = True,
                 save_fig_epoch:Optional[int]                                 = -1,
                 save_fig_path:Optional[str]                                  = "",
                 min_improvement:Optional[float]                              = 1e-4,
                 Q_min:float                                                  = None,
                 Q_max:float                                                  = None,
                 grad_processor_depth_weight:Optional[bool]                     = True
                ):
        """Store the inversion components and prepare the observed data.

        Parameters:
        --------------
            model (Model)                                   : the model class; its ``device`` and ``dtype`` are copied onto this object
            optimizer (torch.optim.Optimizer)               : the pytorch optimizer
            scheduler (torch.optim.scheduler)               : the pytorch learning rate decay scheduler
            loss_fn   (Misfit or torch.autograd.Function)   : the misfit function
            obs_data  (SeismicData)                         : the observed dataset; ``obs_data.data["p"]`` is
                                                              converted to a complex128 tensor on the model device
            gradient_processor (GradProcessor or list)      : the gradient processor(s) used by ``process_gradient``
            regularization_fn (Regularization)              : the regularization function
            regularization_weights_x (list)                 : [vs, Q] weights in x direction, ``None`` means ``[0, 0]``
                                                              (``forward`` only reads index 1)
            regularization_weights_z (list)                 : [vs, Q] weights in z direction, ``None`` means ``[0, 0]``
                                                              (``forward`` only reads index 1)
            waveform_normalize (bool)   : normalize the waveform or not, default True
            cache_result (bool)         : save the temp result of the inversion or not
            save_fig_epoch (int)        : figure saving interval in iterations, -1 disables saving
            save_fig_path (str)         : directory for the saved figures; created when missing,
                                          an empty string means "no directory"
            min_improvement (float)     : early-stopping threshold passed to ``check_early_stopping``
            Q_min, Q_max (float)        : colour limits used when plotting the Q model
            grad_processor_depth_weight (bool) : value assigned to ``gradient_processor.depth_weight``
                                                 in ``process_gradient``
        """
        super().__init__()
        self.model                      = model
        self.optimizer                  = optimizer
        self.scheduler                  = scheduler
        self.loss_fn                    = loss_fn
        self.regularization_fn          = regularization_fn
        # A fresh list per instance: a literal ``[0,0]`` default in the
        # signature would be one object shared by every instance.
        self.regularization_weights_x   = [0, 0] if regularization_weights_x is None else regularization_weights_x
        self.regularization_weights_z   = [0, 0] if regularization_weights_z is None else regularization_weights_z
        self.obs_data                   = obs_data
        self.gradient_processor         = gradient_processor
        self.device                     = self.model.device
        self.dtype                      = self.model.dtype
        self.Q_min                      = Q_min
        self.Q_max                      = Q_max
        self.grad_processor_depth_weight = grad_processor_depth_weight

        # Same capability check as at module import: abort when the installed
        # PyTorch does not expose torch.sparse.spsolve.
        try:
            from torch import sparse
            sparse.spsolve
        except (ImportError, AttributeError):
            print("❌ torch.sparse.spsolve not available")
            exit(1)

        # observed data
        self.waveform_normalize = waveform_normalize
        obs_p   = self.obs_data.data["p"]
        obs_p   = numpy2tensor(obs_p, dtype=torch.complex128).to(self.device)
        if self.waveform_normalize:
            # Normalised once here over the whole of axis 1.
            obs_p = self._normalize(obs_p)
        self.obs_p = obs_p

        # save result
        self.cache_result   = cache_result
        self.iter_vs, self.iter_Q = [],[]
        self.iter_vs_grad, self.iter_Q_grad = [],[]
        self.iter_loss      = []

        self.min_improvement = min_improvement

        # save figure
        self.save_fig_epoch = save_fig_epoch
        self.save_fig_path  = save_fig_path
        # Create the figure directory when one was given. The default ""
        # must be skipped: os.makedirs("") raises FileNotFoundError.
        if self.save_fig_path and not os.path.exists(self.save_fig_path):
            os.makedirs(self.save_fig_path, exist_ok=True)
    
    def check_early_stopping(self, loss_history, min_improvement, min_iter):
        if len(loss_history) < min_iter:  
            return False
        loss_history_tensor = torch.tensor(loss_history)
        recent_avg = torch.mean(loss_history_tensor[-min_iter//2:])
        previous_avg = torch.mean(loss_history_tensor[-min_iter:-min_iter//2])
        
        improvement = torch.abs(previous_avg - recent_avg)

        print("check_early_stopping Done")
        
        return improvement < min_improvement, improvement

    def _normalize(self,data):
        mask = torch.sum(torch.abs(data),axis=1,keepdim=True) == 0
        max_val = torch.max(torch.abs(data),axis=1,keepdim=True).values
        max_val = max_val.masked_fill(mask, 1)
        data = data/max_val
        return data
    
    # misfits calculation
    def calculate_loss(self, synthetic_waveform, observed_waveform, normalization, loss_fn):
        """Evaluate the misfit between a synthetic and an observed waveform.

        When ``normalization`` is truthy the synthetic waveform is first passed
        through ``self._normalize``; the observed waveform is used as given.
        The way ``loss_fn`` is invoked depends on its type:

            Misfit      -> ``loss_fn.forward(syn, obs)``
            Misfit_NIM  -> ``loss_fn.apply(syn, obs, p, trans_type, theta)`` with
                           the three extra arguments read from ``loss_fn``
            otherwise   -> ``loss_fn.apply(syn, obs)``
        """
        syn = self._normalize(synthetic_waveform) if normalization else synthetic_waveform

        if isinstance(loss_fn, Misfit):
            return loss_fn.forward(syn, observed_waveform)
        if isinstance(loss_fn, Misfit_NIM):
            return loss_fn.apply(
                syn,
                observed_waveform,
                loss_fn.p,
                loss_fn.trans_type,
                loss_fn.theta,
            )
        return loss_fn.apply(syn, observed_waveform)
    
    # regularization calculation
    def calculate_regularization_loss(self, model_param, weight_x, weight_z, regularization_fn):
        """
        Generalized function to calculate regularization loss for a given parameter.
        """
        regularization_loss = torch.tensor(0.0, device=model_param.device)
        # Check if the parameter requires gradient
        if model_param.requires_grad:
            # Set the regularization weights for x and z directions
            regularization_fn.alphax = weight_x
            regularization_fn.alphaz = weight_z
            # Calculate regularization loss if any weight is greater than zero
            if regularization_fn.alphax > 0 or regularization_fn.alphaz > 0:
                regularization_loss = regularization_fn.forward(model_param)
        return regularization_loss

    # gradient precondition
    def process_gradient(self, parameter,forw,idx=None):
        with torch.no_grad():
            grads = parameter.grad.detach()
            vmax = torch.max(parameter.detach())
            # Apply gradient processor
            if isinstance(self.gradient_processor, GradProcessor):
                self.gradient_processor.depth_weight = self.grad_processor_depth_weight
                grads = self.gradient_processor.forward(nz=self.model.nz, nx=self.model.nx, vmax=vmax, grad=grads, forw=forw)
            else:
                grads = self.gradient_processor[idx].forward(nz=self.model.nz, nx=self.model.nx, vmax=vmax, grad=grads, forw=forw)
            # Convert grads back to tensor and assign
            grads_tensor = grads.to(self.model.device)
            parameter.grad = grads_tensor

    def save_figure(self,i,data,model_type="vs"):
        if self.save_fig_epoch == -1:
            pass
        elif i%self.save_fig_epoch == 0:
            if os.path.exists(self.save_fig_path):
                if model_type == "vs":
                    plot_model(data,title=f"Iteration {i}",
                            dx=self.model.dx,dz=self.model.dz,
                            vmin=self.vs_min,vmax=self.vs_max,
                            save_path=os.path.join(self.save_fig_path,f"{model_type}_{i}.png"),show=False,cmap='coolwarm')
                elif model_type == "Q":
                    plot_model(data,title=f"Iteration {i}",
                            dx=self.model.dx,dz=self.model.dz,
                            vmin=self.Q_min,vmax=self.Q_max,
                            save_path=os.path.join(self.save_fig_path,f"{model_type}_{i}.png"),show=False,cmap='coolwarm')
                elif model_type == "grad_vs":
                    plot_model(data,title=f"Iteration {i}",
                            dx=self.model.dx,dz=self.model.dz,
                            save_path=os.path.join(self.save_fig_path,f"{model_type}_{i}.png"),show=False,cmap='coolwarm',
                            vmin=-torch.abs(data).max(),vmax=torch.abs(data).max())
                elif model_type == "grad_Q":
                    plot_model(data,title=f"Iteration {i}",
                            dx=self.model.dx,dz=self.model.dz,
                            save_path=os.path.join(self.save_fig_path,f"{model_type}_{i}.pdf"),show=False,cmap='seismic',
                            vmin=-np.abs(data).max(),vmax=np.abs(data).max())
        return
    
    def forward(self,
                iteration:int,
                batch_size:Optional[int]            = None,
                checkpoint_segments:Optional[int]   = 1 ,
                start_iter                          = 0,
                cutoff_freq                         = None,
                ):
        """Run the inversion loop for the Q model.

        Every iteration solves all ``model.nf`` frequencies, accumulates the
        misfit (plus the optional Q regularization) over the frequency
        batches, back-propagates once and performs one optimizer and one
        scheduler step. The loop ends early when the loss becomes NaN or when
        ``check_early_stopping`` reports no further improvement.

        Parameters:
        ------------
            iteration (int)             : the number of iterations to run
            batch_size (int)            : the number of frequencies for each batch; None, a
                                          non-positive value or a value larger than ``model.nf``
                                          means use all the frequencies in one batch
            checkpoint_segments (int)   : accepted for interface compatibility, not used in this method
            start_iter (int)            : index of the first iteration (used for logging, file names
                                          and the early-stopping gate)
            cutoff_freq                 : accepted for interface compatibility, not used in this method
        """
        # epoch
        print("Iteration Range: {} - {}".format(start_iter,start_iter+iteration))
        for i in range(start_iter, start_iter + iteration):
            start_time = time.time()

            self.model.fdfd_setup()
            device = self.model.device
            ns = self.model.ns
            nf = self.model.nf
            # Non-positive sizes would give an empty batch range and leave
            # total_loss a plain float, on which .backward() fails.
            if batch_size is None or batch_size <= 0 or batch_size > nf:
                batch_size = nf

            # Synthetic data at the receivers, filled frequency by frequency.
            # (A full-wavefield buffer of shape (n_total, nf, ns) used to be
            # allocated here as well; it was never read, so it is gone.)
            rec_pf  = torch.zeros((self.model.nr, nf, ns), dtype=torch.complex128, device=device)

            total_loss = 0.0
            self.optimizer.zero_grad()
            n_batches = math.ceil(nf / batch_size)
            for batch in range(n_batches):
                begin_index = batch * batch_size
                end_index = min(nf, (batch + 1) * batch_size)
                freq_index = np.arange(begin_index, end_index)

                for m in freq_index:
                    # sol shape: (n_total, ns)
                    sol = self.model.fdfd_oneFreq(m)
                    rec_pf[:, m, :] = sol[self.model.rec_ind, :]

                # NOTE(review): with waveform_normalize the synthetic slice is
                # normalised over the frequencies of this batch only, whereas
                # obs_p was normalised over all frequencies in __init__ -
                # confirm this is intended when batch_size < nf.
                data_loss = self.calculate_loss(
                    rec_pf[:, freq_index, :],
                    self.obs_p[:, freq_index, :],
                    self.waveform_normalize,
                    self.loss_fn
                )
                if self.regularization_fn is not None:
                    # NOTE(review): the Q regularization term is added once per
                    # frequency batch, so its total weight grows with the
                    # number of batches - confirm this is intended.
                    regularization_loss_Q = self.calculate_regularization_loss(
                        self.model.Q,
                        self.regularization_weights_x[1],
                        self.regularization_weights_z[1],
                        self.regularization_fn
                    )
                    loss_batch = data_loss +  regularization_loss_Q
                else:
                    loss_batch = data_loss

                total_loss = total_loss+loss_batch
                print(f"   Frequency Batch {batch+1}/{n_batches} solved. Loss of this Batch: {loss_batch.item():.4f}")

            # If loss is nan, end of training like early stopping. Checked
            # before the update so the model is not overwritten by a NaN step.
            if torch.isnan(total_loss):
                self.iter_loss.append(total_loss.detach().item())
                print(f"Loss is nan, end of training at iteration {i+1}")
                break

            total_loss.backward()

            self.optimizer.step()

            self.scheduler.step()
            lr_after = self.optimizer.param_groups[0]['lr']

            self.iter_loss.append(total_loss.detach().item())

            torch.cuda.empty_cache()

            if self.cache_result:
                # model
                temp_Q  = self.model.Q.cpu().detach().numpy()
                self.iter_Q.append(temp_Q)
                self.save_figure(i,temp_Q    , model_type="Q")
                # gradient
                if self.model.get_requires_grad("Q"):
                    grads_Q  = self.model.Q.grad.cpu().detach().numpy()
                    self.save_figure(i,grads_Q   , model_type="grad_Q")
                    self.iter_Q_grad.append(grads_Q)

            self.true_epoch = 0
            print(f"Iteration {i+1}/{start_iter+iteration}, Loss: {total_loss:.4f}, Learning rate: {lr_after}, Time: {time.time()-start_time:.2f}s")

            min_iter =50
            # The second condition matters when start_iter > 0 on a fresh loss
            # history: the iteration index is already past min_iter while
            # fewer than min_iter losses have been recorded.
            if i >= min_iter and len(self.iter_loss) >= min_iter:
                early_stopping, improvement = self.check_early_stopping(self.iter_loss, self.min_improvement,min_iter)
                print(f"Improvement: {improvement:.8f}")
                if early_stopping:
                    print(f"Early stopping triggered at iteration {i+1}")
                    print(f"Recent average loss: {np.mean(self.iter_loss[-5:]):.4f}")
                    print(f"Previous average loss: {np.mean(self.iter_loss[-10:-5]):.4f}")
                    print(f"Improvement: {((np.mean(self.iter_loss[-10:-5]) - np.mean(self.iter_loss[-5:])) / np.mean(self.iter_loss[-10:-5])):.6f}")
                    break
                