# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Loss functions used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

from networkx import sigma
import torch
from torch_utils import persistence
from einops import rearrange
from torch_utils import training_stats
from training.evaluation_utils import compute_pde_loss
#----------------------------------------------------------------------------
# Loss function corresponding to the variance preserving (VP) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VPLoss:
    """Denoising loss for the variance preserving (VP) formulation.

    A time value is drawn uniformly between ``epsilon_t`` and 1 for every
    sample, mapped to a noise level through :meth:`sigma`, and the squared
    denoising error is weighted by ``1 / sigma**2``.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        # Coefficients of the exponent used by `sigma()` and the smallest
        # time value that can be sampled.
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.epsilon_t = epsilon_t

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the weighted squared denoising error (not reduced).

        Args:
            net: Callable invoked as ``net(x, sigma, labels, augment_labels=...)``.
            images: Batch of clean samples; the leading dimension is the batch.
            labels: Conditioning forwarded to ``net`` unchanged.
            augment_pipe: Optional callable returning
                ``(augmented_images, augment_labels)``.
        """
        batch_size = images.shape[0]
        u = torch.rand([batch_size, 1, 1, 1], device=images.device)
        # u in [0, 1) maps to t in (epsilon_t, 1].
        noise_level = self.sigma(1 + u * (self.epsilon_t - 1))
        weight = 1 / noise_level ** 2

        if augment_pipe is None:
            clean, aug_cond = images, None
        else:
            clean, aug_cond = augment_pipe(images)

        noise = torch.randn_like(clean) * noise_level
        denoised = net(clean + noise, noise_level, labels, augment_labels=aug_cond)
        return weight * ((denoised - clean) ** 2)

    def sigma(self, t):
        """Noise level at time ``t``: sqrt(exp(0.5*beta_d*t^2 + beta_min*t) - 1)."""
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t ** 2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()

#----------------------------------------------------------------------------
# Loss function corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VELoss:
    """Denoising loss for the variance exploding (VE) formulation.

    The noise level of each sample is interpolated geometrically between
    ``sigma_min`` and ``sigma_max`` with a uniformly drawn exponent, and the
    squared denoising error is weighted by ``1 / sigma**2``.
    """

    def __init__(self, sigma_min=0.02, sigma_max=100):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the weighted squared denoising error (not reduced).

        Args:
            net: Callable invoked as ``net(x, sigma, labels, augment_labels=...)``.
            images: Batch of clean samples; the leading dimension is the batch.
            labels: Conditioning forwarded to ``net`` unchanged.
            augment_pipe: Optional callable returning
                ``(augmented_images, augment_labels)``.
        """
        batch_size = images.shape[0]
        u = torch.rand([batch_size, 1, 1, 1], device=images.device)
        ratio = self.sigma_max / self.sigma_min
        noise_level = self.sigma_min * (ratio ** u)
        weight = 1 / noise_level ** 2

        if augment_pipe is None:
            clean, aug_cond = images, None
        else:
            clean, aug_cond = augment_pipe(images)

        noise = torch.randn_like(clean) * noise_level
        denoised = net(clean + noise, noise_level, labels, augment_labels=aug_cond)
        return weight * ((denoised - clean) ** 2)

#----------------------------------------------------------------------------
# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).

@persistence.persistent_class
class EDMLoss:
    """EDM denoising loss with optional channel-wise conditioning.

    The noise level is ``exp(P_mean + P_std * N(0, 1))`` per sample and the
    squared denoising error is weighted by
    ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, debug=False):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.debug = debug  # stored for callers; not consulted in this class

    def __call__(self, net, images, labels=None, augment_pipe=None):
        """Return the weighted squared denoising error (not reduced).

        Args:
            net: Callable invoked as ``net(x, sigma, labels, augment_labels=...)``.
            images: Batch of clean targets; dim 0 is the batch and dim 1 the
                channels.
            labels: Optional conditioning tensor that can be concatenated to
                ``images`` along dim 1. ``None`` means no conditioning and is
                forwarded to ``net`` as ``None``.
            augment_pipe: Optional callable returning
                ``(augmented_tensor, augment_labels)``. It is applied to
                images and labels jointly so both receive the same
                augmentation.
        """
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2

        # Stack targets and conditioning so one augmentation covers both.
        x_dim = images.size(1)
        all_images = images if labels is None else torch.cat((images, labels), dim=1)
        if augment_pipe is not None:
            # Use the augmented tensor: the network is told about the
            # augmentation through `augment_labels`, so its inputs must
            # actually carry it.
            all_images, augment_labels = augment_pipe(all_images)
        else:
            augment_labels = None

        # Split the (possibly augmented) stack back into targets and conditioning.
        y = all_images[:, :x_dim]
        if labels is not None:
            labels = all_images[:, x_dim:]

        n = torch.randn_like(y) * sigma
        D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
        return weight * ((D_yn - y) ** 2)

#----------------------------------------------------------------------------
# EDM Loss function with PDE residual 


@persistence.persistent_class
class EDMLossResidual:
    """EDM denoising loss for a network that also receives a PDE residual.

    The residual is evaluated on the noisy sample *before* the forward pass
    and handed to the network through the ``pde_residual`` keyword. This is
    the only strategy implemented here; ``pde_residual_step_mode`` is stored
    but never consulted.
    """

    def __init__(self, noise_src, sampler, P_mean=-1.2, P_std=1.2, sigma_data=0.5, 
                 pde_residual_step_mode="two_step", dataset_obj=None, 
                 normalize_pde_residual=False, training_mode='conditional', guided_pde_residual_mode=False, debug=False):
        """
        Args:
            noise_src (str): "gauss" for i.i.d. Gaussian noise, "grf" to draw
                the noise from ``sampler.sample(batch_size)``.
            sampler: Object exposing ``sample(n)``; only used when
                ``noise_src`` is "grf".
            P_mean, P_std: Mean and standard deviation of the normal
                distribution from which log(sigma) is drawn.
            sigma_data: Enters the loss weight and is the target standard
                deviation when the residual is normalized.
            pde_residual_step_mode (str): Stored on the instance only; the
                loss always computes the residual once, before the forward
                pass.
            dataset_obj: Object used for tracker logging only. Must provide
                ``denorm_output`` / ``denorm_input`` in conditional mode and
                ``denorm_tensor`` in unified mode.
            normalize_pde_residual (bool): Standardize the residual per sample
                and rescale it to ``sigma_data`` before it reaches the network.
            training_mode (str): 'conditional' or 'unified'.
            guided_pde_residual_mode (bool): Unified mode only. Evaluate the
                residual on the noisy fields with the masked ground truth
                (``images * mask``) blended in where the mask is set.
            debug (bool): Print tensor shapes while computing the loss.
        """
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        print("Sigma data: ", self.sigma_data)
        self.noise_src = noise_src
        self.sampler = sampler
        self.pde_residual_step_mode = pde_residual_step_mode
        self.sigma_decay = 0.9 #TODO make this a hyperparameter
        self.debug = debug
        self.training_mode = training_mode
        self.pde_res_tracker = None  # Will be set from training loop
        self.dataset_obj = dataset_obj  # Store the dataset object
        self.normalize_pde_residual = normalize_pde_residual
        self.guided_pde_residual_mode = guided_pde_residual_mode
        print("Guided PDE residual mode:", self.guided_pde_residual_mode)

    def set_pde_res_tracker(self, tracker):
        """Set the PDE residual tracker"""
        self.pde_res_tracker = tracker

    def __call__(self, net, images, pde_loss_fn, pde_direction, labels=None, mask=None, augment_pipe=None):
        """Return the weighted squared denoising error (not reduced).

        Args:
            net: Callable invoked as ``net(x, sigma, labels, pde_residual=...,
                mask=..., augment_labels=...)``.
            images: Batch of clean targets; dim 0 is the batch and dim 1 the
                channels.
            pde_loss_fn, pde_direction: Forwarded to ``compute_pde_loss``.
            labels: Conditioning tensor. Required in conditional mode; in
                unified mode it is ignored and replaced by ``images * mask``.
            mask: Required in unified mode; also forwarded to ``net``.
            augment_pipe: Optional callable returning
                ``(augmented_tensor, augment_labels)``, applied to images and
                labels jointly.

        Raises:
            ValueError: Unknown ``training_mode`` or ``noise_src``, or a
                missing ``labels`` / ``mask`` for the active training mode.
        """
        # Validate up front; otherwise a bad configuration only surfaces
        # later as an unbound local or an opaque tensor TypeError.
        if self.training_mode == 'unified':
            if mask is None:
                raise ValueError("mask is required when training_mode is 'unified'")
            # In unified mode, labels are the masked version of the images.
            labels = images * mask
        elif self.training_mode == 'conditional':
            if labels is None:
                raise ValueError("labels are required when training_mode is 'conditional'")
        else:
            raise ValueError(f"Unknown training mode: {self.training_mode}")

        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        if self.debug:
            print("sigma shape: ", sigma.shape)
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        if self.debug:
            print("weight shape: ", weight.shape)

        # Stack targets and conditioning so one augmentation covers both.
        x_dim = images.size(1)
        all_images = torch.cat((images, labels), dim=1)
        if augment_pipe is not None:
            # Use the augmented tensor: the network is told about the
            # augmentation through `augment_labels`, so its inputs must
            # actually carry it.
            # NOTE(review): `mask` is not passed through the pipe, so in
            # unified mode it may no longer line up with the augmented
            # fields - confirm before combining the two.
            all_images, augment_labels = augment_pipe(all_images)
        else:
            augment_labels = None
        y = all_images[:, :x_dim]
        labels = all_images[:, x_dim:]

        n = self._sample_noise(y, sigma)
        noisy_y = y + n
        if self.debug:
            print("shape of n: ", n.shape)
            print("shape of y+n: ", noisy_y.shape)

        pde_residual = self._input_pde_residual(
            noisy_y, labels, mask, pde_loss_fn, pde_direction, images.device)
        if self.debug:
            print("Residual computed before forward pass, shape:", pde_residual.shape)

        # Normalize the residual to have std = sigma_data
        if self.normalize_pde_residual:
            if self.debug:
                print("Normalizing PDE residual")
            pde_residual = self.get_normalized_pde_residual(pde_residual, self.sigma_data)

        D_yn = net(noisy_y, sigma, labels, pde_residual=pde_residual, mask=mask, augment_labels=augment_labels)
        if self.debug:
            print("shape of D_yn: ", D_yn.shape)

        loss = weight * ((D_yn - y) ** 2)
        if self.pde_res_tracker is not None:
            self._log_to_tracker(D_yn, y, labels, mask, sigma, pde_loss_fn, pde_direction)
        return loss

    def _sample_noise(self, y, sigma):
        """Draw noise shaped like ``y`` from the configured source, scaled by ``sigma``."""
        if self.noise_src == "gauss":
            return torch.randn_like(y) * sigma
        if self.noise_src == "grf":
            # Gaussian Random Field noise
            return self.sampler.sample(y.size(0)) * sigma
        raise ValueError(f"Unknown noise source: {self.noise_src}")

    def _input_pde_residual(self, noisy_y, labels, mask, pde_loss_fn, pde_direction, device):
        """Evaluate the PDE residual on the noisy sample, with a channel axis added."""
        if self.training_mode == 'conditional':
            return compute_pde_loss(pde_loss_fn, pde_direction, noisy_y, labels, device=device).unsqueeze(1)

        # Unified mode: channel 0 and channel 1 of the sample are the two PDE fields.
        pde_input = noisy_y[:, 0:1]
        pde_output = noisy_y[:, 1:2]
        if self.guided_pde_residual_mode:
            # Keep the noisy values where the mask is 0 and add the masked
            # ground truth (`labels` is images * mask).
            pde_input = (1 - mask[:, 0:1]) * pde_input + labels[:, 0:1]
            pde_output = (1 - mask[:, 1:2]) * pde_output + labels[:, 1:2]
        return compute_pde_loss(pde_loss_fn, pde_direction, pde_input, pde_output,
                                device=device, training_mode=self.training_mode).unsqueeze(1)

    def _log_to_tracker(self, D_yn, y, labels, mask, sigma, pde_loss_fn, pde_direction):
        """Compute the residual of the denormalized prediction and send it to the tracker."""
        unified = self.training_mode == 'unified'
        with torch.no_grad():
            if unified:
                D_yn_denorm = self.dataset_obj.denorm_tensor(D_yn)
                y_denorm = self.dataset_obj.denorm_tensor(y)
                labels_denorm = self.dataset_obj.denorm_tensor(labels)
                output_pde_residual = compute_pde_loss(pde_loss_fn, pde_direction, D_yn_denorm[:, 0:1], D_yn_denorm[:, 1:2], training_mode=self.training_mode).unsqueeze(1)
            else:
                D_yn_denorm = self.dataset_obj.denorm_output(D_yn)
                labels_denorm = self.dataset_obj.denorm_input(labels)
                y_denorm = self.dataset_obj.denorm_output(y)
                output_pde_residual = compute_pde_loss(pde_loss_fn, pde_direction, D_yn_denorm, labels_denorm).unsqueeze(1)
        if self.debug:
            print("Residual computed after forward pass, shape:", output_pde_residual.shape)

        # The step is 0 until set_current_step() has been called.
        current_step = getattr(self, '_current_step', 0)
        self.pde_res_tracker.log_residuals(output_pde_residual, sigma, current_step)

        if unified:
            self.pde_res_tracker.log_training_data_unified(
                ground_truth_a=y_denorm[:, 0:1],
                ground_truth_u=y_denorm[:, 1:2],
                model_input_a=labels_denorm[:, 0:1] * mask[:, 0:1].detach().cpu(),
                model_input_u=labels_denorm[:, 1:2] * mask[:, 1:2].detach().cpu(),
                predictions_a=D_yn_denorm[:, 0:1],
                predictions_u=D_yn_denorm[:, 1:2],
                mask_a=mask[:, 0:1],
                mask_u=mask[:, 1:2],
                pde_residuals=output_pde_residual,
                sigmas=sigma.squeeze(),
                step=current_step,
                direction=pde_direction,
                pde_loss_fn=pde_loss_fn
            )
        else:
            # For forward problems: input=labels (a), ground_truth=images (u), prediction=D_yn_denorm (u_pred)
            # For inverse problems: input=images (u), ground_truth=labels (a), prediction=D_yn_denorm (a_pred)
            if self.debug:
                print("Conditional mode for logging")
            self.pde_res_tracker.log_training_data(
                input_data=labels_denorm,
                ground_truth=y_denorm,
                predictions=D_yn_denorm,
                pde_residuals=output_pde_residual,
                sigmas=sigma.squeeze(),
                step=current_step,
                direction=pde_direction,
                pde_loss_fn=pde_loss_fn
            )

    def set_current_step(self, step):
        """Set current training step for logging"""
        self._current_step = step

    def get_normalized_pde_residual(self, residual, sigma_data=0.5):
        """
        Normalize PDE residual to have standard deviation matching sigma_data

        Args:
            residual: PDE residual tensor (B, 1, H, W)
            sigma_data: Target standard deviation (typically 0.5 or 1.0)

        Returns:
            Normalized residual with std = sigma_data
        """
        # Per-sample spatial normalization first (mean=0, std=1); the small
        # constant keeps the division finite for a spatially constant residual.
        mean = residual.mean(dim=[2, 3], keepdim=True)
        std = residual.std(dim=[2, 3], keepdim=True) + 1e-8
        normalized = (residual - mean) / std

        # Scale to target sigma_data
        return normalized * sigma_data

#----------------------------------------------------------------------------
# EDM Loss function wrapper

class EDMLossWrapper:
    """Select one of the EDM losses by name and call it with a uniform signature.

    ``loss_type`` (case-insensitive) picks the implementation:
      - "edm": plain ``EDMLoss``. It always uses Gaussian noise, so
        ``noise_src`` and ``sampler`` are not forwarded.
      - "edm_residual": ``EDMLossResidual``, which also receives the stored
        ``pde_loss_fn`` and ``pde_direction`` on every call.
    """

    def __init__(self, loss_type="edm", noise_src="grf", sampler=None, pde_loss_fn=None, 
                 pde_direction=None, training_mode=None, pde_residual_step_mode="two_step", sigma_data=None, 
                 dataset_obj=None, normalize_pde_residual=False, guided_pde_residual_mode=False, debug=False):
        """
        Args:
            loss_type (str): "edm" or "edm_residual".
            noise_src, sampler, pde_residual_step_mode, dataset_obj,
            normalize_pde_residual, guided_pde_residual_mode: Forwarded to
                ``EDMLossResidual``.
            pde_loss_fn, pde_direction: Stored and passed to the residual loss
                on every call.
            training_mode, sigma_data: Forwarded only when not ``None`` so the
                wrapped loss keeps its own default otherwise.
            debug (bool): Print which loss is being called.

        Raises:
            ValueError: ``loss_type`` is not one of the supported names.
        """
        self.loss_type = loss_type.lower()
        self.noise_src = noise_src
        self.sampler = sampler
        self.training_mode = training_mode
        self.guided_pde_residual_mode = guided_pde_residual_mode
        self.pde_loss_fn = pde_loss_fn
        self.pde_direction = pde_direction
        self.debug = debug
        self.sigma_data = sigma_data

        # Passing None explicitly would override the wrapped loss's defaults
        # with a value it cannot compute with, so only forward what was set.
        optional = {}
        if sigma_data is not None:
            optional["sigma_data"] = sigma_data

        if self.loss_type == "edm":
            # EDMLoss.__init__ has no noise_src parameter.
            self.loss_fn = EDMLoss(debug=debug, **optional)
        elif self.loss_type == "edm_residual":
            if training_mode is not None:
                optional["training_mode"] = training_mode
            self.loss_fn = EDMLossResidual(
                noise_src=self.noise_src,
                sampler=self.sampler,
                pde_residual_step_mode=pde_residual_step_mode,
                dataset_obj=dataset_obj,
                normalize_pde_residual=normalize_pde_residual,
                guided_pde_residual_mode=guided_pde_residual_mode,
                debug=debug,
                **optional,
            )
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

    def __call__(self, net, images, labels, masks=None, augment_pipe=None):
        """Evaluate the selected loss; ``masks`` is only used by "edm_residual"."""
        if self.debug:
            print("In the Loss Wrapper")
        if self.loss_type == "edm":
            if self.debug:
                print("Calling regular edm loss")
            return self.loss_fn(net, images, labels, augment_pipe)
        elif self.loss_type == "edm_residual":
            if self.debug:
                print("calling edm residual loss")
            return self.loss_fn(net, images, self.pde_loss_fn, self.pde_direction, labels, masks, augment_pipe)
