# ---------------------------------------------------------------
# This file has been modified from the following sources:
# Source:
# 1. https://github.com/NVlabs/LSGM/blob/main/util/ema.py (NVIDIA License)
# 2. https://github.com/NVlabs/denoising-diffusion-gan/blob/main/train_ddgan.py (NVIDIA License)
# 3. https://github.com/nhartland/KL-divergence-estimators (MIT License)
# ---------------------------------------------------------------

import warnings
import numpy as np
import torch
from torch.optim import Optimizer
import torch.nn.functional as F
import math

# ------------------------
# Select Phi_star
# ------------------------
def select_phi(name):
    '''
    Return the elementwise function registered under `name`.

    name : one of 'linear', 'kl', 'chi', 'softplus'
        linear   : x
        kl       : exp(x) - 1
        chi      : y^2 / 4 + y  with  y = relu(x + 2) - 2
        softplus : 2 * softplus(x) - 2 * softplus(0)

    Raises NotImplementedError for any other name.
    '''
    def _identity(x):
        return x

    def _exp_minus_one(x):
        return torch.exp(x) - 1

    def _clipped_quadratic(x):
        # relu(x + 2) - 2 equals max(x, -2).
        clipped = F.relu(x + 2) - 2
        return 0.25 * clipped**2 + clipped

    def _centered_softplus(x):
        # Alternative scaled form kept for reference:
        #   (2*F.softplus(x/5) - 2*F.softplus(0*x))*5
        # softplus(0*x) evaluates softplus at zero with the shape of x.
        return 2 * F.softplus(x) - 2 * F.softplus(0 * x)

    registry = (
        ('linear', _identity),
        ('kl', _exp_minus_one),
        ('chi', _clipped_quadratic),
        ('softplus', _centered_softplus),
    )
    for key, phi in registry:
        if name == key:
            return phi

    raise NotImplementedError

# ------------------------
# Sampler
# ------------------------
class Sampler:
    '''
    Sampler on the uniform time grid ts = [0, h, 2h, ..., 1] with h = 1/num_timesteps.

    It provides
      * noisy interpolation between a source x and a target y
        (`__call__`, `sample_next`, `sample_pair`), and
      * step-by-step simulation driven by the input-gradient of a network
        (`sample_next_by_value`, `sample_by_value`).

    `args` must expose: num_timesteps, sigma, alpha, time_sample, batch_size, device.
    '''

    def __init__(self, args):
        self.num_timesteps = args.num_timesteps
        self.h = 1 / args.num_timesteps
        self.sigma = args.sigma
        self.alpha = args.alpha
        self.ts = [self.h * k for k in range(self.num_timesteps + 1)]

        # Unnormalized weight of every grid time: constant for 'uniform',
        # otherwise linearly increasing (1, 2, ..., num_timesteps + 1).
        if args.time_sample == 'uniform':
            weights = [1] * (args.num_timesteps + 1)
        else:
            weights = [i + 1 for i in range(args.num_timesteps + 1)]
        self.p = np.array(weights)

        # Normalized probabilities over the grid without the final time t = 1.
        self.t_prob = self.p[:-1] / sum(self.p[:-1])
        self.batch_size = args.batch_size
        self.device = args.device

    def retrieve_prob_t(self, t):
        '''
        Return, for each entry of t, the probability with which that grid time
        is drawn when the last time is excluded (i.e. self.t_prob[k] for t = k*h).

        t : tensor of grid times in [0, 1)
        returns : float32 tensor on self.device, same shape as t
        '''
        # Round to the nearest grid index instead of flooring: in float32,
        # t / h for t = k*h can land slightly below k, and flooring would then
        # select index k - 1.
        idxs = torch.round(t.detach() / self.h).long().cpu().numpy()
        probs = np.asarray(self.t_prob[idxs])
        return torch.tensor(probs, dtype=torch.float32).to(self.device)

    def sample_t(self, batch_size=None, not_sample_last_t=False):
        '''
        Sample t (size: batch_size) from self.ts
        where self.ts = [0, h, 2h, ..., 1] for h=1/num_timesteps
        if not_sample_last_t is True, we sample from [0, h, 2h, ..., (num_timesteps-1)h] (Not sample 1)

        batch_size : number of draws; defaults to self.batch_size
        returns : float32 tensor of shape (batch_size,) on self.device
        '''
        if batch_size is None:
            batch_size = self.batch_size
        if not_sample_last_t:
            t = np.random.choice(self.ts[:-1], batch_size, p=self.t_prob)
        else:
            t = np.random.choice(self.ts, batch_size, p=self.p / sum(self.p))
        return torch.tensor(t, dtype=torch.float32, device=self.device)

    def __call__(self, x, y, t=None, return_ts=True):
        '''
        x : source data
        y : generated target data
        t : Time to sample. If None, one time per sample of x is drawn.
        return_ts : return t (ignored when t is given: only y_t is returned then)

        y_t = (1-t)x + ty + sigma * sqrt(t(1-t)) * z
        where z ~ N(0,I)
        '''
        if t is None:
            # Draw as many times as there are samples in x, so that a batch
            # whose size differs from args.batch_size still broadcasts.
            t = self.sample_t(batch_size=x.size(0))
        else:
            return_ts = False

        z = torch.randn_like(x)
        # Reshape t to (B, 1, ..., 1) so it broadcasts over the data dimensions.
        expanded_t = t.view(t.size(0), *(1,) * (len(x.shape) - 1))
        noise_std = torch.sqrt(expanded_t * (1 - expanded_t)) * self.sigma
        y_t = (1 - expanded_t) * x + expanded_t * y + noise_std * z
        if not return_ts:
            return y_t
        return y_t, t

    def sample_next(self, y_t, y, t):
        '''
        y_t : sample from time t
        y : generated sample
        t : current time (must not contain 1)
        sample next sample. Let s = t+h. Then,
                y_s = (1-s)/(1-t) y_t + h/(1-t) y + sigma * sqrt(h (1-s)/(1-t)) * z
                where z ~ N(0, I)
        '''
        assert not torch.any(t == 1), 't must not contain the final time 1'
        t = t.view(t.size(0), *(1,) * (len(y.shape) - 1))
        one_minus_t = 1 - t
        # 1 - s is mathematically >= 0, but float rounding of t + h can push it
        # slightly below zero at the last step, which would make the sqrt NaN.
        one_minus_s = torch.clamp(1 - (t + self.h), min=0)
        eta = torch.randn_like(y)
        mean = one_minus_s / one_minus_t * y_t + self.h / one_minus_t * y
        noise_std = torch.sqrt(self.h * one_minus_s / one_minus_t) * self.sigma
        return mean + noise_std * eta

    def sample_pair(self, x, y, t=None, return_ts = True):
        '''
        Sample (t, y_t). Then, sample (t+h, y_{t+h}) | (t, y_t)

        t : times to use; if None, one time per sample of x is drawn from the
            grid without the final time.
        returns : (y_t, y_{t+h}, t, t+h) if return_ts else (y_t, y_{t+h})
        '''
        if t is None:
            t = self.sample_t(batch_size=x.size(0), not_sample_last_t=True)
        # With t given explicitly, __call__ returns only y_t.
        y_t = self(x, y, t)
        y_tph = self.sample_next(y_t, y, t)
        if return_ts:
            return y_t, y_tph, t, t + self.h
        return y_t, y_tph

    def sample_next_by_value(self, netD, y_t, t, noise=True):
        '''
        Sample from the value function
        y_{t+h} = y_t - alpha * nabla(netD(y_t, t)) h + sigma * sqrt(h) * z

        netD : callable (y_t, t) -> tensor; its sum is differentiated w.r.t. y_t
        y_t : current sample; marked as requiring grad in place if it does not already
        t : current time
        noise : add the sigma * sqrt(h) * z term when True
        returns : (y_{t+h}, t + h)
        '''
        # A tensor that already requires grad (leaf or not) can be
        # differentiated directly; only plain leaves need the flag set.
        if not y_t.requires_grad:
            y_t.requires_grad_(True)

        # Gradients are needed here even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            V_t = netD(y_t, t)
            dVdx = torch.autograd.grad(V_t.sum(), y_t)[0]

        y_tph = y_t - self.alpha * dVdx * self.h
        if noise:
            z = torch.randn_like(y_t)
            y_tph = y_tph + self.sigma * math.sqrt(self.h) * z
        tph = t + self.h
        return y_tph, tph

    def sample_by_value(self, netD, x, return_traj=False):
        '''
        Run sample_next_by_value over the whole grid, starting from x at t = 0.
        Noise is added at every step except the last one.

        x: source data
        return_traj : if True, return the stacked states
                      (num_timesteps + 1 entries, detached) instead of the final one
        '''
        if return_traj:
            traj = [x.detach()]

        last_step = self.num_timesteps - 1
        for i, t_k in enumerate(self.ts[:-1]):
            t = torch.full((x.shape[0],), t_k, device=x.device)
            x, _ = self.sample_next_by_value(netD, x, t, noise=i < last_step)

            if return_traj:
                traj.append(x.detach())

        if return_traj:
            return torch.stack(traj)
        return x


# ------------------------
# EMA
# ------------------------
class EMA(Optimizer):
    '''
    Optimizer wrapper that keeps an exponential moving average (EMA) of the
    parameters in the wrapped optimizer's per-parameter state under the key 'ema'.

    EMA Codes adapted from https://github.com/NVlabs/LSGM/blob/main/util/ema.py

    Optimizer.__init__ is intentionally not called: this object shares `state`
    and `param_groups` with the wrapped optimizer. For that reason the methods
    that would otherwise rely on attributes created by Optimizer.__init__
    (state_dict, load_state_dict, zero_grad) are delegated to the wrapped optimizer.
    '''

    def __init__(self, opt, ema_decay):
        '''
        opt : optimizer to wrap
        ema_decay : EMA decay factor; EMA is disabled when it is not > 0
        '''
        self.ema_decay = ema_decay
        self.apply_ema = self.ema_decay > 0.
        self.optimizer = opt
        self.state = opt.state
        self.param_groups = opt.param_groups
        # Mirrors the wrapped optimizer's defaults for code that inspects them.
        self.defaults = getattr(opt, 'defaults', {})

    def step(self, *args, **kwargs):
        '''
        Run the wrapped optimizer's step, then update the EMA of every
        parameter that has a gradient:
            ema <- ema_decay * ema + (1 - ema_decay) * param
        The EMA is initialized to the (already updated) parameter on first use.

        returns : whatever the wrapped optimizer's step returns
        '''
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        for group in self.optimizer.param_groups:
            # Bucket parameters so each bucket is updated with one stacked op.
            # Buckets are rebuilt for every group; sharing them across groups
            # would mix already-stacked tensors with the lists being collected.
            buckets = {}
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.optimizer.state[p]

                # State initialization
                if 'ema' not in state:
                    state['ema'] = p.data.clone()

                key = (p.shape, p.dtype, p.device)
                members, values, averages = buckets.setdefault(key, ([], [], []))
                members.append(p)
                values.append(p.data)
                averages.append(state['ema'])

            for members, values, averages in buckets.values():
                stacked_ema = torch.stack(averages, dim=0)
                stacked_ema.mul_(self.ema_decay).add_(
                    torch.stack(values, dim=0), alpha=1. - self.ema_decay)
                # Each parameter's EMA becomes a view into the stacked tensor.
                # Plain integer indexing also works for 0-dim parameters.
                for idx, p in enumerate(members):
                    self.optimizer.state[p]['ema'] = stacked_ema[idx]

        return retval

    def zero_grad(self, *args, **kwargs):
        '''Clear gradients through the wrapped optimizer.'''
        return self.optimizer.zero_grad(*args, **kwargs)

    def state_dict(self):
        '''Return the wrapped optimizer's state dict (it includes the 'ema' entries).'''
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        '''Load `state_dict` into the wrapped optimizer and re-share its containers.'''
        self.optimizer.load_state_dict(state_dict)
        # Loading may replace the optimizer's state / param_groups objects, so
        # point this wrapper at the current ones again.
        self.state = self.optimizer.state
        self.param_groups = self.optimizer.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """ This function swaps parameters with their ema values. It records original parameters in the ema
        parameters, if store_params_in_ema is true.

        Parameters without an EMA entry (they never received a gradient in
        `step`) are left untouched."""

        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn('swap_parameters_with_ema was called when there is no EMA weights.')
            return

        for group in self.optimizer.param_groups:
            for p in group['params']:
                if not p.requires_grad:
                    continue
                # .get avoids creating an empty state entry as a side effect.
                state = self.optimizer.state.get(p)
                if not state or 'ema' not in state:
                    continue
                ema = state['ema']
                if store_params_in_ema:
                    tmp = p.data.detach()
                    p.data = ema.detach()
                    state['ema'] = tmp
                else:
                    p.data = ema.detach()

