# File adapted from PIDM

from einops import rearrange
import os
from pathlib import Path
import numpy as np
import torch

# Module-level default device: CUDA when torch reports it available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def fix_seeds(seed=42):
    """Seed every global RNG used in training for reproducibility.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus all CUDA devices).

    Args:
        seed: integer seed applied to every generator.
    """
    import random  # local import keeps this helper self-contained

    # Python's stdlib RNG is commonly used by data augmentation code, so it must be
    # seeded too for runs to be repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible CUDA device (a no-op without CUDA).
    torch.cuda.manual_seed_all(seed)

def noop(*args, **kwargs):
    """Accept any positional/keyword arguments, ignore them, and return None."""
    return None

def exists(x):
    """Return True for any value other than ``None`` (falsy values such as 0 count as existing)."""
    return not (x is None)

def default(val, d):
    """Return *val* unless it is ``None``; otherwise fall back to *d*.

    If the fallback *d* is callable it is invoked (lazily, only when needed) and its
    result returned; otherwise *d* itself is returned.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d

class EMA(object):
    """Exponential moving average (EMA) of a module's trainable parameters.

    Shadow values are updated as ``shadow = mu * shadow + (1 - mu) * param`` and are
    keyed by the names from ``module.named_parameters()``. Only parameters with
    ``requires_grad=True`` are tracked.
    """

    def __init__(self, mu=0.999):
        self.mu = mu        # decay: weight given to the previous shadow value
        self.shadow = {}    # parameter name -> EMA tensor
        self.backup = {}    # parameter name -> live weights saved by ema(..., backup=True)

    def register(self, module):
        """Initialise the shadow with a copy of every trainable parameter of *module*."""
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        """Blend the current trainable parameters of *module* into the shadow.

        Every trainable parameter name must have been registered (KeyError otherwise).
        """
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module, backup=True):
        """Overwrite *module*'s trainable parameters with their shadow (EMA) values.

        Args:
            module: module whose parameters are replaced in place.
            backup: if True, first save the live weights to ``self.backup`` so that
                ``restore`` can put them back.
        """
        for name, param in module.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                if backup:
                    self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name].data)

    def restore(self, module):
        """Copy the weights saved by ``ema(module, backup=True)`` back into *module*, then clear the backup."""
        assert hasattr(self, 'backup')
        for name, param in module.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data.copy_(self.backup[name])
        self.backup = {}

    def ema_copy(self, module):
        """Return a new instance of *module*'s class carrying the EMA weights.

        Assumes ``type(module)(module.config)`` rebuilds the architecture and that
        ``module.config.device`` is the target device. *module* itself is not modified.
        """
        module_copy = type(module)(module.config).to(module.config.device)
        module_copy.load_state_dict(module.state_dict())
        # backup=False: the copy is throwaway. Backing it up would overwrite
        # self.backup and make a pending restore() of the original module put back
        # the copy's weights instead of the original live ones.
        self.ema(module_copy, backup=False)
        return module_copy

    def state_dict(self):
        """Return the shadow dict itself (not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the shadow dict with *state_dict* (stored by reference)."""
        self.shadow = state_dict

def save_model(model, train_iterations, output_save_dir):
    """Save ``model.state_dict()`` as ``<output_save_dir>/model/checkpoint_<train_iterations>.pt``.

    The file contains ``{'model': state_dict}``. The ``model/`` sub-directory is
    created if needed.

    Args:
        model: module whose ``state_dict()`` is saved.
        train_iterations: value embedded in the checkpoint file name.
        output_save_dir: ``str`` or path-like run directory.
    """
    model_dir = Path(output_save_dir, 'model')
    os.makedirs(model_dir, exist_ok=True)

    # Build the file path with pathlib: plain string concatenation raised TypeError
    # when output_save_dir was a pathlib.Path (which os.makedirs accepts).
    checkpoint_path = model_dir / ('checkpoint_' + str(train_iterations) + '.pt')
    save_obj = dict(
        model=model.state_dict()
    )
    with open(checkpoint_path, 'wb') as f:
        torch.save(save_obj, f)
    print(f'\ncheckpoint saved to {output_save_dir}/.')

def load_model(path, model, strict = True):
    """Load the ``'model'`` state dict stored at *path* into *model*.

    Loading is best-effort: a ``RuntimeError`` from ``model.load_state_dict``
    (e.g. missing/unexpected keys or shape mismatches) is reported and swallowed.
    In that case *model* may be partially updated.

    Args:
        path: checkpoint file containing a dict with a ``'model'`` entry.
        model: module that receives the weights.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The same *model* object.
    """
    # Map to CPU to avoid extra GPU memory usage in the main process when using Accelerate.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    with open(path, 'rb') as f:
        loaded_obj = torch.load(f, map_location='cpu')
    try:
        model.load_state_dict(loaded_obj['model'], strict=strict)
    except RuntimeError as err:
        # Keep the non-fatal behaviour, but say why it failed and do not claim success.
        print(f'Failed loading state dict: {err}')
    else:
        print('\nCheckpoint loaded from {}'.format(path))

    return model

def estimate_dataset_std_sample(dataset, sample_size=1000):
    """Estimate the standard deviation of all entries over a random subset of *dataset*.

    Args:
        dataset: object supporting ``len()`` and integer indexing. ``dataset[i]``
            returns a tensor/array-like (including scalars), or a tuple/list whose
            first element is one.
        sample_size: maximum number of items drawn, without replacement, using
            torch's global RNG.

    Returns:
        float: ``torch.std`` (unbiased) of every element of the sampled items pooled
        together. Integer/bool data are cast to float first.

    Raises:
        ValueError: if no item was sampled (e.g. empty dataset or ``sample_size == 0``).
    """
    indices = torch.randperm(len(dataset))[:sample_size]

    pooled = []
    # Plain ints rather than 0-d tensors, so any __getitem__ (e.g. dict-backed) works.
    for idx in indices.tolist():
        sample = dataset[idx]
        if isinstance(sample, (tuple, list)):
            sample = sample[0]
        # Flatten each item: the std is over all elements, so shapes need not match
        # and scalar / numpy samples are accepted.
        pooled.append(torch.as_tensor(sample).reshape(-1))

    if not pooled:
        raise ValueError('estimate_dataset_std_sample: no samples drawn from dataset')

    values = torch.cat(pooled)
    if not values.is_floating_point():
        values = values.float()  # torch.std requires a floating-point tensor
    return torch.std(values).item()

def binning_function(num_bins, t, values, t_max=1.0):
    """
    Bins values into fixed-width bins based on corresponding time values.

    This function creates num_bins evenly spaced bins from 0 to t_max, and for each bin,
    calculates the mean of the values whose corresponding time value (t) falls
    within that bin. Bins are half-open, ``[lo, hi)``, except the last one, which is
    closed so that ``t == t_max`` is counted. Time values outside ``[0, t_max]`` are
    ignored. NaN values are excluded from the calculation.
    If no non-NaN values fall within a particular bin, that bin's value is set to 0.

    Parameters:
    -----------
    num_bins : int
        Number of bins to create. Bins will be evenly spaced from 0 to t_max.
    t : torch.Tensor
        Time values with one entry per sample, e.g. of shape (n_samples,),
        (n_samples, 1) or (n_samples, 1, 1, 1); flattened to (n_samples,).
        Used to determine which bin each value belongs to.
    values : torch.Tensor
        Values to bin, of shape (n_samples,). These are averaged within each bin.
    t_max : float, optional
        Maximum time value for binning. Default is 1.0.

    Returns:
    --------
    torch.Tensor
        Tensor of shape (num_bins,) containing the mean of values in each bin.
        Bins with no values are set to 0.

    Notes:
    ------
    - Inputs are detached, so no autograd graph is kept.
    - The function is designed to work across multiple batches, which is why
      empty bins are set to 0 rather than NaN
    """
    bins = torch.linspace(0, t_max, num_bins + 1)
    # Any layout with one time per sample -> (n_samples,), aligned with `values`.
    t = t.detach().reshape(-1)
    values = values.detach()

    binned_values = torch.zeros(num_bins)
    last = num_bins - 1
    for i in range(num_bins):
        lower, upper = bins[i], bins[i + 1]
        # Close the last bin on the right so t == t_max is not dropped.
        below_upper = (t <= upper) if i == last else (t < upper)
        selected = values[(t >= lower) & below_upper]
        selected = selected[~torch.isnan(selected)]
        if selected.numel() > 0:
            binned_values[i] = selected.mean()
        # Otherwise the bin stays 0: results are accumulated across batches, so an
        # empty bin must contribute 0 rather than NaN.
    return binned_values

def evaluate_log_likelihood(
        model: torch.nn.Module, 
        samples: torch.Tensor,
        tmax: float = 80.,
        tmin: float = 0.002,
        rho: float = 7,
        num_steps: int = 100,
        **net_fwd_kwargs
    ):
    """
    Estimate the log likelihood of ``samples`` under a diffusion model.

    Integrates the probability-flow ODE ``dx/dt = (x - D(x, t)) / t`` from ``tmin`` to
    ``tmax`` with forward Euler steps, where ``D`` is the first output of ``model``.
    The divergence of the drift is estimated with Hutchinson's trace estimator (one
    Rademacher probe per step, so the result is stochastic). The final state is
    scored under ``N(0, tmax**2 I)``.

    Args:
        model: denoiser called as ``model(x, t, **net_fwd_kwargs)``. It must return a
            tuple whose first element is the denoised estimate with the shape of ``x``.
        samples: input samples ``[batch_size, ...]``.
        tmax: maximum diffusion time (end of integration).
        tmin: minimum diffusion time (start of integration).
        rho: exponent controlling the time schedule.
        num_steps: number of time points (``num_steps - 1`` Euler steps); needs >= 2.
        **net_fwd_kwargs: extra keyword arguments forwarded to ``model``.

    Returns:
        torch.Tensor of shape ``(batch_size,)``: log-likelihood estimate per sample.
    """
    # Generate time steps from tmin to tmax
    ts = tmin ** (1/rho) + np.arange(num_steps)/(num_steps-1) * (tmax ** (1/rho) - tmin ** (1/rho))
    ts = ts ** rho

    # dimensionality of one sample
    d = np.prod(samples.shape[1:])
    batch_size = samples.shape[0]
    sample_dims = tuple(range(1, samples.ndim))
    # Per-sample time shaped (b, 1, ..., 1) to broadcast against x:
    # (b, 1, 1, 1) for image batches, (b, 1) for flat vectors.
    t_shape = (batch_size,) + (1,) * (samples.ndim - 1)

    with torch.enable_grad():
        current_samples = samples.clone().detach().requires_grad_(True)

        # Accumulated change in log density along the trajectory, one entry per sample.
        log_likelihood = torch.zeros(batch_size, device=samples.device)

        # Forward ODE from data (tmin) towards noise (tmax).
        for i in range(1, ts.shape[0]):
            t_prev = torch.ones(t_shape, device=samples.device) * ts[i-1]
            t = torch.ones(t_shape, device=samples.device) * ts[i]

            # Drift of the probability-flow ODE from the denoiser output.
            x_hat, _ = model(current_samples, t_prev, **net_fwd_kwargs)
            f = -(x_hat - current_samples)/t_prev

            # Hutchinson estimate of div f: eps^T (df/dx) eps with a Rademacher probe.
            epsilon = torch.randint(0, 2, current_samples.shape, device=current_samples.device).float() * 2 - 1
            vjp = torch.autograd.grad(f, current_samples, epsilon, create_graph=False, retain_graph=False)[0]
            trace_est = (vjp * epsilon).sum(dim=sample_dims)

            # log p_0(x_0) = log p_T(x_T) + integral of div f dt, so the divergence is added.
            # Flatten dt to (b,): multiplying (b,) by (b, 1, 1, 1) would broadcast to
            # (b, 1, 1, b) and corrupt the shape of the result.
            dt = t - t_prev
            log_likelihood = log_likelihood + trace_est * dt.reshape(-1)

            # Forward Euler step (deterministic, no noise).
            with torch.no_grad():
                current_samples = current_samples + f * dt

            current_samples = current_samples.detach().requires_grad_(True)

        # Add log probability of the final noise distribution (Gaussian with variance tmax^2).
        final_samples = current_samples.detach()
        log_likelihood = log_likelihood - 0.5 * d * np.log(2 * np.pi * tmax**2) - \
                        torch.sum(final_samples**2, dim=sample_dims) / (2 * tmax**2)

        return log_likelihood

def wasserstein_distance(real_images, gen_images):
    """
    Calculate a simplified Wasserstein-2 distance between image distributions.

    Each channel is summarised by its mean and (unbiased) standard deviation over the
    batch and all trailing axes, i.e. treated as an independent 1-D Gaussian. For 1-D
    Gaussians, W2^2 = (mu_real - mu_gen)^2 + (sigma_real - sigma_gen)^2. The per-channel
    terms are summed and the square root is returned.

    Args:
        real_images, gen_images: tensors laid out as ``(N, C, ...)``, e.g. ``(N, C, H, W)``.
            Any channel count and any number of trailing spatial axes are accepted.

    Returns:
        float: the approximate distance.
    """
    def _channel_stats(images):
        # Reduce over the batch axis and every axis after the channel axis (dim 1).
        reduce_dims = [0, *range(2, images.ndim)]
        return images.mean(dim=reduce_dims), images.std(dim=reduce_dims)

    real_means, real_stds = _channel_stats(real_images)
    gen_means, gen_stds = _channel_stats(gen_images)

    mean_distance = torch.sum((real_means - gen_means) ** 2)
    std_distance = torch.sum((real_stds - gen_stds) ** 2)

    return torch.sqrt(mean_distance + std_distance).item()

def compute_mmd(real_images, gen_images, kernel_multiplier=1.0):
    """
    Compute the (biased) squared Maximum Mean Discrepancy between two batches.

    Each sample is flattened to a vector, and at most the first 1000 samples of each
    batch are used. The RBF kernel is ``k(a, b) = exp(-||a - b||^2 / (2 sigma^2))``.
    Its bandwidth ``sigma`` comes from the median heuristic on the real-vs-real squared
    distances, scaled by ``kernel_multiplier``. If that gives ``sigma == 0``, a unit
    bandwidth is used instead of producing NaN.

    Args:
        real_images, gen_images: tensors with the batch on dim 0 (any device; gradients
            are ignored).
        kernel_multiplier: scale factor applied to the median-heuristic bandwidth.

    Returns:
        float: ``mean(k_xx) + mean(k_yy) - 2 * mean(k_xy)``.
    """
    def _pairwise_sq_dists(a, b):
        # ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 a_i.b_j via one matrix product,
        # instead of materialising an (n, n, D) difference array. Clip at 0 because
        # rounding can make the expansion slightly negative.
        sq = (a * a).sum(axis=1)[:, None] + (b * b).sum(axis=1)[None, :] - 2.0 * (a @ b.T)
        return np.maximum(sq, 0.0)

    # Subsample for efficiency: the first n rows of each batch, n capped at 1000.
    n = min(1000, real_images.shape[0])
    # Flatten, drop autograd/GPU state, and use float64 to limit cancellation error.
    x = real_images.reshape(real_images.shape[0], -1)[:n].detach().cpu().numpy().astype(np.float64)
    y = gen_images.reshape(gen_images.shape[0], -1)[:n].detach().cpu().numpy().astype(np.float64)

    d_xx = _pairwise_sq_dists(x, x)
    # Median heuristic for the kernel width (diagonal zeros included).
    sigma = np.sqrt(np.median(d_xx)) * kernel_multiplier
    if sigma == 0:
        # Degenerate width (e.g. most real samples coincide) would give 0/0 = NaN.
        sigma = 1.0
    scale = 2 * sigma ** 2

    xx = np.exp(-d_xx / scale)
    yy = np.exp(-_pairwise_sq_dists(y, y) / scale)
    xy = np.exp(-_pairwise_sq_dists(x, y) / scale)

    mmd = np.mean(xx) + np.mean(yy) - 2 * np.mean(xy)

    return float(mmd)

class EDM2Loss:
    """EDM2-style denoising loss weighted by a per-sample log-variance from the network.

    Noise levels are drawn log-normally: ``ln(sigma) ~ N(P_mean, P_std^2)``. The
    element-wise loss is ``weight / exp(logvar) * (D(x + n) - x)^2 + logvar`` with
    ``weight = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.
    """

    def __init__(self, P_mean=-2, P_std=1.2, sigma_data=0.5):
        self.P_mean, self.P_std, self.sigma_data = P_mean, P_std, sigma_data

    def __call__(self, net, in_data, n_bins=32, log_bins=False, return_gs=False, **net_fwd_kwargs):
        """Return ``(mean_loss, bins)`` for one batch.

        ``net`` is called as ``net(noisy, sigma, return_grads=log_bins, return_logvar=True,
        **net_fwd_kwargs)`` and must return ``(denoised, grads, logvar)``. When
        ``log_bins`` is True, ``bins`` holds the per-sample norms of ``grads`` binned by
        ``sigma`` into ``n_bins`` bins; otherwise it is None. ``return_gs`` is accepted
        but unused.
        """
        batch = in_data.shape[0]
        log_sigma = torch.randn([batch, 1, 1, 1], device=in_data.device) * self.P_std + self.P_mean
        sigma = log_sigma.exp()                                                      # (b,1,1,1)
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        noise = torch.randn_like(in_data) * sigma                                    # (b,c,x,y)

        denoised, grads, logvar = net(
            in_data + noise, sigma,
            return_grads=log_bins, return_logvar=True, **net_fwd_kwargs,
        )
        squared_error = (denoised - in_data) ** 2
        mean_loss = ((weight / logvar.exp()) * squared_error + logvar).mean()

        if not log_bins:
            return mean_loss, None
        grad_norms = rearrange(grads, "b c x y -> b c (x y)").norm(dim=[1, 2])
        return mean_loss, binning_function(n_bins, sigma, grad_norms)

def edm_sampler(
    net, noise, gnet=None,
    num_steps=100, sigma_min=0.002, sigma_max=80, rho=7, guidance=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    dtype=torch.float32, randn_like=torch.randn_like,
    **net_kwargs
):
    """
    Generate samples with the EDM Heun (2nd-order) sampler, with optional churn and guidance.

    Args:
        net: denoiser called as ``net(x, t, **net_kwargs)``. It must return a 2-tuple
            whose first element is the denoised estimate.
        noise: starting noise; multiplied by ``sigma_max`` to form the initial state.
        gnet: guiding denoiser called as ``gnet(x, t)`` (without ``net_kwargs``). It may
            return a tensor or a tuple whose first element is the estimate. It is only
            called when ``guidance != 1``.
        num_steps: number of noise levels from ``sigma_max`` down to ``sigma_min``;
            a final level of 0 is appended.
        sigma_min, sigma_max, rho: parameters of the time discretisation
            ``(sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho``.
        guidance: the denoised estimate is ``gnet_out.lerp(net_out, guidance)``.
        S_churn, S_min, S_max, S_noise: with ``S_churn > 0``, noise is temporarily
            increased at levels within ``[S_min, S_max]``.
        dtype: dtype of the time grid and the sample trajectory.
        randn_like: noise generator used for churn.
        **net_kwargs: forwarded to ``net`` only.

    Returns:
        The final (detached) sample tensor, same shape as ``noise``.
    """
    def _estimate(out):
        # Accept either a bare tensor or a tuple whose first element is the estimate.
        return out[0] if isinstance(out, tuple) else out

    # Guided denoiser.
    def denoise(x, t):
        Dx, _ = net(x, t, **net_kwargs)
        Dx = Dx.to(dtype)
        if guidance == 1:
            return Dx
        ref_Dx = _estimate(gnet(x, t)).to(dtype)
        return ref_Dx.lerp(Dx, guidance)

    # Time step discretization. max(..., 1) avoids 0/0 = NaN when num_steps == 1, in
    # which case the only level is sigma_max; for num_steps >= 2 it is unchanged.
    step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / max(num_steps - 1, 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = noise.to(dtype) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        if S_churn > 0 and S_min <= t_cur <= S_max:
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        else:
            t_hat = t_cur
            x_hat = x_cur

        # Euler step.
        d_cur = ( (x_hat - denoise(x_hat, t_hat)) / t_hat ).detach()
        x_next = ( x_hat + (t_next - t_hat) * d_cur ).detach()

        # Apply 2nd order correction (skipped on the last step, where t_next == 0).
        if i < num_steps - 1:
            d_prime = ( (x_next - denoise(x_next, t_next)) / t_next ).detach()
            x_next = ( x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) ).detach()

    return x_next
