# ---------------------------------------------------------------
# This file has been modified from the following sources:
# Source:
# 1. https://github.com/NVlabs/LSGM/blob/main/util/ema.py (NVIDIA License)
# 2. https://github.com/NVlabs/denoising-diffusion-gan/blob/main/train_ddgan.py (NVIDIA License)
# 3. https://github.com/nhartland/KL-divergence-estimators (MIT License)
# ---------------------------------------------------------------
import os
import warnings
import numpy as np
import torch
from torch.optim import Optimizer
import torch.nn.functional as F
from scipy.spatial import KDTree
from torchvision.utils import save_image
import json
from torch.distributions import Beta
import torchvision

# ------------------------
# Initialize Transport Maps
# ------------------------
def initialize_fwd(netG, optimizerG, dataset, args, exp_path):
    """Warm-start ``netG`` as a source-to-target map using mini-batch OT pairing.

    Each iteration draws ``src, tgt = dataset.sample()``, builds the pairwise
    mean-squared cost between the flattened batches, solves an entropic OT
    (Sinkhorn) problem on it, re-orders ``tgt`` by the row-wise argmax of the
    plan and regresses ``netG(src, z)`` onto the re-ordered targets with a
    squared-error loss scaled by ``args.tau``.

    Args:
        netG: network called as ``netG(src, latent_z)``. Its inputs are moved
            to ``args.device``; the network itself is not moved here.
        optimizerG: optimizer whose ``step()`` updates ``netG``.
        dataset: object whose ``sample()`` returns a ``(src, tgt)`` pair; both
            are reshaped with ``args.batch_size`` as the leading dimension.
        args: namespace providing ``device``, ``init_num_iterations``,
            ``batch_size``, ``nz`` and ``tau``.
        exp_path: directory that receives an ``init_<iteration>.png`` preview
            every 100 iterations.

    Returns:
        ``netG`` (the same object, updated in place).
    """
    if args.init_num_iterations <= 0:
        return netG

    import ot  # imported lazily: only needed when the warm-start actually runs

    device = args.device
    for i in range(args.init_num_iterations):
        netG.zero_grad()
        src, tgt = dataset.sample()
        latent_z = torch.randn((args.batch_size, args.nz))

        # Pairwise cost between every source sample and every target sample.
        flat_src = src.reshape(args.batch_size, -1)
        flat_tgt = tgt.reshape(args.batch_size, -1)
        M = ((flat_src[:, None, :] - flat_tgt[None, :, :]) ** 2).mean(2)
        uniform = torch.ones(args.batch_size)

        # If the first solve raises, retry once with a larger regularisation.
        # ``Exception`` (not a bare except) keeps Ctrl-C / SystemExit working.
        try:
            A = ot.sinkhorn(uniform, uniform, M, 0.01)
        except Exception:
            A = ot.sinkhorn(uniform, uniform, M, 0.1)
        tgt = tgt[torch.argmax(A, dim=1)]

        src_dev, z_dev = src.to(device), latent_z.to(device)
        residual = netG(src_dev, z_dev) - tgt.to(device)
        loss = args.tau * (residual ** 2).reshape(args.batch_size, -1).sum(1)
        loss.mean().backward()
        optimizerG.step()

        if (i + 1) % 100 == 0:
            with torch.no_grad():
                # Feed the device copies: the raw batch was never moved to
                # ``device`` and would not match the inputs used for training.
                preview = netG(src_dev, z_dev).detach().cpu()
                save_image(0.5 * preview + 0.5, os.path.join(exp_path, f'init_{i+1}.png'))
    return netG

def initialize_bwd(netG, optimizerG, dataset, args, exp_path):
    """Warm-start ``netG`` as a target-to-source map using mini-batch OT pairing.

    Identical to the forward warm-start except that the pair returned by
    ``dataset.sample()`` is unpacked in swapped order (``tgt, src``), so the
    network is trained to map the second element of the pair onto the first.

    Args:
        netG: network called as ``netG(src, latent_z)``. Its inputs are moved
            to ``args.device``; the network itself is not moved here.
        optimizerG: optimizer whose ``step()`` updates ``netG``.
        dataset: object whose ``sample()`` returns a pair of batches; both are
            reshaped with ``args.batch_size`` as the leading dimension.
        args: namespace providing ``device``, ``init_num_iterations``,
            ``batch_size``, ``nz`` and ``tau``.
        exp_path: directory that receives an ``init_<iteration>.png`` preview
            every 100 iterations.

    Returns:
        ``netG`` (the same object, updated in place).
    """
    if args.init_num_iterations <= 0:
        return netG

    import ot  # imported lazily: only needed when the warm-start actually runs

    device = args.device
    for i in range(args.init_num_iterations):
        netG.zero_grad()
        # Swapped unpacking: the roles of the two batches are reversed.
        tgt, src = dataset.sample()
        latent_z = torch.randn((args.batch_size, args.nz))

        # Pairwise cost between every source sample and every target sample.
        flat_src = src.reshape(args.batch_size, -1)
        flat_tgt = tgt.reshape(args.batch_size, -1)
        M = ((flat_src[:, None, :] - flat_tgt[None, :, :]) ** 2).mean(2)
        uniform = torch.ones(args.batch_size)

        # If the first solve raises, retry once with a larger regularisation.
        # ``Exception`` (not a bare except) keeps Ctrl-C / SystemExit working.
        try:
            A = ot.sinkhorn(uniform, uniform, M, 0.01)
        except Exception:
            A = ot.sinkhorn(uniform, uniform, M, 0.1)
        tgt = tgt[torch.argmax(A, dim=1)]

        src_dev, z_dev = src.to(device), latent_z.to(device)
        residual = netG(src_dev, z_dev) - tgt.to(device)
        loss = args.tau * (residual ** 2).reshape(args.batch_size, -1).sum(1)
        loss.mean().backward()
        optimizerG.step()

        if (i + 1) % 100 == 0:
            with torch.no_grad():
                # Feed the device copies: the raw batch was never moved to
                # ``device`` and would not match the inputs used for training.
                preview = netG(src_dev, z_dev).detach().cpu()
                save_image(0.5 * preview + 0.5, os.path.join(exp_path, f'init_{i+1}.png'))
    return netG
# ------------------------
# Sampler
# ------------------------
class TimeSampler:
    """Samples interpolation times and builds linear interpolants.

    ``args.time_sample`` selects the time distribution (checked in this order):

    * ``'uniform'``       -- U[0, 1)
    * contains ``'beta'`` -- ``'beta_<a>_<b>'``, Beta(a, b)
    * contains ``'discrete'`` -- ``'discrete_<n>'``, uniform over
      ``{0, 1/n, ..., 1}``
    * contains ``'t'``    -- e.g. ``'t0.3'``; the text after the first ``'t'``
      is parsed as the start time and times are drawn from U[start, 1)
    """

    def __init__(self, args):
        self.time_sample = args.time_sample
        if 'beta' in self.time_sample:
            _, alpha, beta = self.time_sample.split('_')[:3]
            # Scalar parameters: the number of draws is chosen at sampling
            # time instead of being frozen to ``args.batch_size`` here.
            self.beta = Beta(torch.tensor(float(alpha)), torch.tensor(float(beta)))
        elif 'discrete' in self.time_sample:
            self.num_timesteps = int(self.time_sample.split('_')[1])
        elif 't' in self.time_sample:
            self.start_time = float(self.time_sample.split('t')[1])

    def sample_t(self, batch_size):
        """Return a 1-D tensor with ``batch_size`` sampled times.

        Raises:
            ValueError: if ``time_sample`` matches none of the known modes.
        """
        mode = self.time_sample
        if mode == 'uniform':
            return torch.rand(batch_size)
        if 'beta' in mode:
            return self.beta.sample((batch_size,))
        if 'discrete' in mode:
            steps = torch.randint(high=self.num_timesteps + 1, size=(batch_size,))
            return steps.float() / self.num_timesteps
        if 't' in mode:
            return torch.rand(batch_size) * (1 - self.start_time) + self.start_time
        raise ValueError(f'Unknown time_sample mode: {mode!r}')

    def __call__(self, t, x0, x1):
        """Return the linear interpolant ``(1 - t) * x0 + t * x1``.

        ``t`` is viewed as ``(t.size(0), 1, ..., 1)`` so that it broadcasts
        over every non-batch dimension of ``x0``.
        """
        trailing_ones = (1,) * (len(x0.shape) - 1)
        expanded_t = t.view(t.size(0), *trailing_ones)
        return (1 - expanded_t) * x0 + expanded_t * x1


def cost_V(V_net, t, x, y):
    """Return the per-sample cost ``||x - y||^2 + (1 - t) * V_net(x, t)``.

    Args:
        V_net: callable evaluated as ``V_net(x, t)``; expected to yield one
            value per sample (e.g. shape ``(B, 1)``).
        t: time tensor with one value per sample -- ``(B,)`` or with extra
            singleton dimensions such as ``(B, 1, 1, 1)`` -- or a single value.
        x, y: tensors of identical shape whose first dimension is the batch.

    Returns:
        Tensor of shape ``(B,)``.
    """
    batch = x.size(0)
    squared_distance = ((x - y).reshape(batch, -1) ** 2).sum(dim=1)
    # Flatten both factors to 1-D before multiplying. Combining a squeezed
    # (B,) potential with a (B, 1, 1, 1) time tensor would otherwise broadcast
    # to (B, 1, 1, B) instead of giving one value per sample.
    potential = V_net(x, t).reshape(-1)
    time_weight = (1 - t).reshape(-1)
    return squared_distance + potential * time_weight

def rank_relaxation(V_net, t, x, x0, y, M=1e-4, sigma=1e-7, num_samples=8):
    """Hinge ("rank") penalty comparing the cost at ``x0`` with nearby points.

    With ``c(z) = cost_V(V_net, t, z, y)`` the result is the average, over
    ``num_samples`` random perturbations, of

        max(0, M + c(x0) - c(x0 + noise)),

    where ``noise = (sigma / 2) * u`` and ``u`` is drawn with
    ``torch.rand_like`` (uniform on [0, 1), so the perturbation is
    non-negative in every coordinate).

    Args:
        V_net: potential network, evaluated through ``cost_V``.
        t: time tensor forwarded to ``cost_V``.
        x: unused; kept so existing call sites keep working.
        x0: candidate minimiser (presumably A_t(y) -- confirm with the caller).
        y: target batch forwarded to ``cost_V``.
        M: margin added inside the hinge.
        sigma: scale of the perturbation (the noise is multiplied by sigma / 2).
        num_samples: number of perturbations averaged.

    Returns:
        Tensor with the shape returned by ``cost_V`` (one value per sample);
        the integer ``0`` when ``num_samples`` is 0.
    """
    loss = 0
    cost_x0 = cost_V(V_net, t, x0, y)

    # Loop-invariant pieces of the hinge.
    floor = torch.zeros_like(cost_x0)
    margin = M * torch.ones_like(cost_x0)
    noise_scale = sigma * (1 / 2)

    for _ in range(num_samples):
        noisy_x0 = x0 + noise_scale * torch.rand_like(x0)
        hinge = torch.max(floor, margin + cost_x0 - cost_V(V_net, t, noisy_x0, y))
        loss += (1 / num_samples) * hinge

    return loss

def grad_relaxation(V_net, t, x0, y):
    """Return the per-sample L2 norm of the gradient of ``cost_V`` w.r.t. ``x0``.

    The gradient is taken with ``create_graph=True`` so the returned norm can
    itself be back-propagated through.

    Side effect: ``x0.requires_grad_(True)`` is applied to the caller's tensor.

    Args:
        V_net: potential network, evaluated through ``cost_V``.
        t: time tensor forwarded to ``cost_V``.
        x0: input batch with the batch as first dimension (any rank >= 1).
        y: target batch forwarded to ``cost_V``.

    Returns:
        Tensor of shape ``(B,)`` holding ``||grad_x0 cost||_2`` (the norm, not
        its square).
    """
    x0.requires_grad_(True)

    cost = cost_V(V_net, t, x0, y)

    grad_cost = torch.autograd.grad(
        cost, x0,
        grad_outputs=torch.ones_like(cost),
        create_graph=True,
        retain_graph=True,
    )[0]  # same shape as x0

    # Reduce over every non-batch dimension, whatever the input rank is.
    squared_norm = grad_cost.pow(2).reshape(grad_cost.size(0), -1).sum(dim=1)
    return torch.sqrt(squared_norm)


# ------------------------
# Dump parsed arguments (argparse namespace) to JSON
# ------------------------
def arg2json(input_args, dir):
    """Write the attributes of ``input_args`` to ``<dir>/config.json``.

    Args:
        input_args: object whose ``vars()`` mapping is JSON-serialisable
            (e.g. an ``argparse.Namespace``).
        dir: existing directory that receives ``config.json``.
    """
    config_path = f'{dir}/config.json'
    with open(config_path, 'w') as handle:
        json.dump(vars(input_args), handle, indent=4)



# ------------------------
# Select Phi_star
# ------------------------
def select_phi(name):
    """Return the element-wise function registered under ``name``.

    Supported names:
        'linear'   -- identity
        'kl'       -- exp(x)
        'chi'      -- 0.25 * y**2 + y with y = relu(x + 2) - 2
        'softplus' -- softplus(x)

    Raises:
        NotImplementedError: for any other name.
    """
    if name == 'linear':
        return lambda x: x

    if name == 'kl':
        return lambda x: torch.exp(x)

    if name == 'chi':
        def chi_phi(x):
            shifted = F.relu(x + 2) - 2
            return 0.25 * shifted**2 + shifted
        return chi_phi

    if name == 'softplus':
        return lambda x: F.softplus(x)

    raise NotImplementedError

# ------------------------
# EMA
# ------------------------
class EMA(Optimizer):
    """Optimizer wrapper that keeps an exponential moving average of parameters.

    EMA code adapted from https://github.com/NVlabs/LSGM/blob/main/util/ema.py

    The wrapper shares ``state`` and ``param_groups`` with the wrapped
    optimizer and stores the running average of every parameter under the
    ``'ema'`` key of that parameter's optimizer state.

    Args:
        opt: the optimizer to wrap.
        ema_decay: decay factor of the running average; a value ``<= 0``
            disables EMA tracking.
    """

    def __init__(self, opt, ema_decay):
        # Optimizer.__init__ is not called: this object only mirrors the
        # wrapped optimizer's ``state`` / ``param_groups``. Inherited methods
        # that need attributes set up by the base constructor are therefore
        # overridden below to delegate to the wrapped optimizer.
        self.ema_decay = ema_decay
        self.apply_ema = self.ema_decay > 0.
        self.optimizer = opt
        self.state = opt.state
        self.param_groups = opt.param_groups

    def step(self, *args, **kwargs):
        """Run the wrapped optimizer's step, then update the EMA weights.

        Only parameters whose ``grad`` is not ``None`` are tracked. Returns
        whatever the wrapped ``step`` returns.
        """
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        decay = self.ema_decay
        for group in self.optimizer.param_groups:
            # Same-shaped parameters are stacked so each shape needs a single
            # fused update. Buckets are rebuilt for every group: re-using
            # them would append to already-stacked tensors and average the
            # earlier groups a second time.
            buckets = {}
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.optimizer.state[p]

                # State initialization
                if 'ema' not in state:
                    state['ema'] = p.data.clone()

                bucket = buckets.setdefault(p.shape, {'owners': [], 'data': [], 'ema': []})
                bucket['owners'].append(p)
                bucket['data'].append(p.data)
                bucket['ema'].append(state['ema'])

            for bucket in buckets.values():
                stacked_data = torch.stack(bucket['data'], dim=0)
                stacked_ema = torch.stack(bucket['ema'], dim=0)
                stacked_ema.mul_(decay).add_(stacked_data, alpha=1. - decay)

                # Plain integer indexing also works for 0-dim parameters,
                # where ``[idx, :]`` would raise.
                for idx, p in enumerate(bucket['owners']):
                    self.optimizer.state[p]['ema'] = stacked_ema[idx]

        return retval

    def zero_grad(self, *args, **kwargs):
        """Delegate to the wrapped optimizer's ``zero_grad``."""
        return self.optimizer.zero_grad(*args, **kwargs)

    def state_dict(self):
        """Return the wrapped optimizer's state dict (includes the EMA tensors)."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer and re-sync references."""
        self.optimizer.load_state_dict(state_dict)
        # Loading may rebind the wrapped optimizer's containers, so mirror
        # them again to keep both objects pointing at the same data.
        self.state = self.optimizer.state
        self.param_groups = self.optimizer.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """ This function swaps parameters with their ema values. It records original parameters in the ema
        parameters, if store_params_in_ema is true."""

        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn('swap_parameters_with_ema was called when there is no EMA weights.')
            return

        for group in self.optimizer.param_groups:
            for p in group['params']:
                if not p.requires_grad:
                    continue
                state = self.optimizer.state[p]
                ema = state['ema']
                if store_params_in_ema:
                    # Keep the live weights so a second swap restores them.
                    state['ema'] = p.data.detach()
                p.data = ema.detach()


# ------------------------
# Get Model
# ------------------------
# Get pretrained model
def get_model(args, ckpt_path):
    """Build an NCSNpp generator, load weights from ``ckpt_path`` and freeze it.

    The network is wrapped in ``torch.nn.DataParallel`` *before* the state
    dict is loaded, so the checkpoint is presumably saved from a
    DataParallel-wrapped model (keys prefixed with ``module.``) -- confirm
    against the training script.

    Args:
        args: configuration object passed to ``NCSNpp``.
        ckpt_path: path of the checkpoint given to ``torch.load``.

    Returns:
        The DataParallel-wrapped network with ``requires_grad = False`` on
        every parameter.
    """
    from models.ncsnpp_generator_adagn import NCSNpp

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    netG = NCSNpp(args).to(device)

    # Use up to four GPUs, but never more than are visible: a fixed
    # [0, 1, 2, 3] fails on machines with fewer devices. ``None`` (no GPU
    # visible) leaves the choice to DataParallel.
    device_ids = list(range(min(4, torch.cuda.device_count()))) or None
    netG = torch.nn.DataParallel(netG, device_ids=device_ids)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    netG.load_state_dict(checkpoint)
    for p in netG.parameters():
        p.requires_grad = False
    return netG



# ------------------------
# KL divergence
# ------------------------
def knn_distance(point, sample, k):
    """Return the Euclidean distance at index ``k`` of the sorted distances
    from ``point`` to every row of ``sample``.

    Index 0 is the smallest distance, so when ``point`` is itself a row of
    ``sample`` (distance 0 at index 0) index ``k`` is its k-th nearest
    neighbour. Works for points of any dimensionality.
    """
    offsets = sample - point
    distances = np.linalg.norm(offsets, axis=1)
    ordered = np.sort(distances)
    return ordered[k]


def verify_sample_shapes(s1, s2, k):
    """Assert that both samples are 2-D ``[N, D]`` arrays with the same ``D``.

    ``k`` is accepted for call compatibility and is not inspected.
    """
    rank_1, rank_2 = len(s1.shape), len(s2.shape)
    # Both samples must be matrices ...
    assert rank_1 == 2 and rank_2 == 2
    # ... living in the same feature space.
    assert s1.shape[1] == s2.shape[1]


def scipy_estimator(s1, s2, k=1):
    """Estimate D(P|Q) from samples via k-nearest-neighbour distances (scipy KDTree).

    Args:
        s1: (N_1, D) sample drawn from distribution P.
        s2: (N_2, D) sample drawn from distribution Q.
        k: number of neighbours considered (default 1).

    Returns:
        The estimated divergence D(P|Q).
    """
    verify_sample_shapes(s1, s2, k)

    num_p, num_q = len(s1), len(s2)
    dim = float(s1.shape[1])
    size_term = np.log(num_q / (num_p - 1))

    # Distances from every s1 point to its neighbours in s2, and to its
    # neighbours within s1 (k + 1 there, since s1 is queried against itself).
    dist_to_q, _ = KDTree(s2).query(s1, k)
    dist_to_p, _ = KDTree(s1).query(s1, k + 1)

    # KDTree.query returns a 1-D array for k == 1 and one column per
    # neighbour for k > 1; pick the farthest requested neighbour either way.
    kth_to_q = dist_to_q[::, -1] if k > 1 else dist_to_q
    kth_to_p = dist_to_p[::, -1]

    log_ratio_sum = np.sum(np.log(kth_to_q / kth_to_p))
    return size_term + (dim / num_p) * log_ratio_sum

import torch
import torch.nn.functional as F
from pytorch_msssim import ssim
from cleanfid import fid, features
from piq import LPIPS

# Shared piq.LPIPS instance, constructed once when this module is imported.
# It is the default `model` argument of calculate_batch_lpips_piq, and
# reduction='none' keeps one value per sample instead of a batch aggregate.
lpips_metric = LPIPS(reduction='none', replace_pooling = True)

def calculate_ssim(img1, img2, data_range=255.0):
    """Compute SSIM between two images or two batches of images.

    Args:
        img1, img2: torch.Tensor of shape (N, C, H, W) or (C, H, W); a 3-D
            ``img1`` causes both inputs to get a leading batch dimension.
        data_range: value range of the inputs, forwarded to ``ssim``.

    Returns:
        Per-image SSIM values (``size_average=False``), shape (N,).
    """
    is_single_image = img1.dim() == 3
    if is_single_image:
        img1, img2 = img1.unsqueeze(0), img2.unsqueeze(0)

    per_image_ssim = ssim(img1, img2, data_range=data_range, size_average=False)
    return per_image_ssim

def calculate_batch_psnr(pred, target, max_val=255.0):
    """Compute the PSNR of every image in a batch.

    Args:
        pred, target: torch.Tensor of identical shape, batch first
            (e.g. (N, C, H, W)); any memory layout is accepted.
        max_val: maximum possible pixel value of the inputs.

    Returns:
        Tensor of shape (N,) with ``10 * log10(max_val**2 / (mse + 1e-8))``.

    Raises:
        ValueError: if ``pred`` and ``target`` differ in shape.
    """
    if pred.shape != target.shape:
        raise ValueError("Shape mismatch: pred and target must be the same")

    squared_error = F.mse_loss(pred, target, reduction='none')
    # reshape (not view): the element-wise result can inherit a
    # non-contiguous layout from its inputs (e.g. channels-last), which
    # ``view`` rejects.
    mse = squared_error.reshape(squared_error.size(0), -1).mean(dim=1)

    # The small epsilon keeps identical images from dividing by zero.
    return 10 * torch.log10(max_val**2 / (mse + 1e-8))

def to_numpy_imgs(imgs):
    """Convert a batch of image tensors to uint8 NumPy images.

    Args:
        imgs: torch.Tensor of shape (N, 3, H, W). If any value is negative
            the batch is treated as [-1, 1] data and mapped to [0, 1] first;
            otherwise it is taken to be [0, 1] already.

    Returns:
        np.ndarray of shape (N, H, W, 3), dtype uint8, values in [0, 255]
        (fractions are truncated by the integer cast, not rounded).
    """
    has_negative_values = imgs.min() < 0
    if has_negative_values:
        imgs = (imgs.clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
    as_uint8 = (imgs.clamp(0, 1) * 255).to(torch.uint8)
    channels_last = as_uint8.permute(0, 2, 3, 1)
    return channels_last.cpu().numpy()

def calculate_fid_score(fake_imgs, real_imgs):
    """Compute FID between two in-memory image batches.

    Args:
        fake_imgs, real_imgs: torch.Tensor of shape (N, 3, H, W), in
            [-1, 1] or [0, 1]; both are converted with ``to_numpy_imgs``.

    Returns:
        Whatever ``fid.compute_fid`` returns (expected: a scalar FID score).
    """
    # NOTE(review): confirm that the installed clean-fid `compute_fid`
    # accepts the `real_images` / `fake_images` keywords and
    # model_name='inception'.
    fake_np, real_np = (to_numpy_imgs(batch) for batch in (fake_imgs, real_imgs))
    score = fid.compute_fid(real_images=real_np, fake_images=fake_np, model_name='inception')
    return score


def calculate_fid(path1: str, path2: str, batch_size: int = 50, device: str = 'cuda'):
    """Compute, print and return the clean-FID between two image folders.

    Args:
        path1, path2: existing directories of images.
        batch_size: batch size forwarded to ``fid.compute_fid``.
        device: ``'cuda'`` to use the GPU when one is available; any other
            value (or no available GPU) falls back to ``'cpu'``.

    Returns:
        The FID value returned by ``fid.compute_fid``.

    Raises:
        AssertionError: if either path is not a directory.
    """
    assert os.path.isdir(path1), f"not a directory: {path1}"
    assert os.path.isdir(path2), f"not a directory: {path2}"

    use_cuda = device == 'cuda' and torch.cuda.is_available()
    current_device = 'cuda' if use_cuda else 'cpu'

    fid_value = fid.compute_fid(path1, path2, mode="clean",
                                num_workers=0,
                                batch_size=batch_size,
                                device=current_device)
    print(f"\nFID between '{path1}' and '{path2}' : {fid_value:.4f}")
    # Hand the score back so callers can log or compare it, not just read it
    # from stdout.
    return fid_value

def save_images_batch(batch_image_tensor, save_path, counter, normalize=True):
    """Save every image of a batch as ``<counter:04d>.png`` inside ``save_path``.

    Args:
        batch_image_tensor: tensor whose first dimension indexes the images.
        save_path: output directory.
        counter: file index used for the first image; incremented per image.
        normalize: when true, each image is mapped with ``0.5 * (x + 1)``
            before saving.

    Returns:
        The file index following the last one written.

    Existing files are reported on stdout and then overwritten.
    """
    # NOTE(review): save_image is always called with normalize=True, which
    # rescales each image by its own min/max -- confirm this is intended.
    num_images = batch_image_tensor.size(0)
    for offset in range(num_images):
        file_index = counter + offset
        output_filepath = os.path.join(save_path, f"{file_index:04d}.png")

        if os.path.isfile(output_filepath):
            print("File already exists:", output_filepath)

        image = batch_image_tensor[offset]
        if normalize:
            image = 0.5 * (image + 1)
        torchvision.utils.save_image(image.cpu(), output_filepath, normalize=True)

    return counter + num_images


def calculate_batch_lpips_piq(pred, target, device, model=lpips_metric):
    """Compute LPIPS between ``pred`` and ``target`` with a piq.LPIPS model.

    Args:
        pred, target: (B, 3, H, W) float tensors. They are passed to the
            model as-is and are not moved to ``device`` here.
        device: device the model is moved to before evaluation.
        model: piq.LPIPS instance; defaults to the module-level
            ``lpips_metric``.

    Returns:
        The model output -- one LPIPS value per sample, shape (B,), with the
        default ``reduction='none'`` model.
    """
    # NOTE(review): the original docstring states a [-1, 1] input range;
    # confirm the value range piq.LPIPS expects.
    metric_on_device = model.to(device)
    return metric_on_device(pred, target)