import numpy as np
from scipy.spatial.distance import pdist, squareform
from itertools import combinations

import seaborn as sns
import matplotlib.pyplot as plt

def dxkxy_rbf(theta, pairwise_dists, h_rep):
    """Repulsive SVGD term for the Gaussian kernel ``k = exp(-D / h_rep)``.

    Row ``i`` of the result equals
    ``(2 / h_rep) * sum_j k_ij * (theta[i] - theta[j])``.

    Parameters
    ----------
    theta : ndarray, shape (n, d)
        Particle positions.
    pairwise_dists : ndarray, shape (n, n)
        Pairwise distance matrix ``D``; ``HSVGD.hsvgd_kernel`` passes squared
        Euclidean distances here.
    h_rep : float
        Kernel bandwidth.

    Returns
    -------
    ndarray, shape (n, d)
    """
    weights = np.exp(-pairwise_dists / h_rep)
    row_mass = np.sum(weights, axis=1)
    # theta[i] * sum_j k_ij  -  sum_j k_ij * theta[j], all dimensions at once
    repulsion = theta * row_mass[:, np.newaxis] - np.matmul(weights, theta)
    return 2 * repulsion / h_rep

def dxkxy_rbf_p(theta, h_rep, p):
    """Repulsive SVGD term for the p-norm RBF kernel.

    The kernel is ``k(x, y) = exp(-||x - y||_p**2 / (2 * h_rep**2))``. Row
    ``i`` of the result is ``sum_j grad_{x_j} k(x_j, x_i)``, i.e.::

        (1 / h_rep**2) * sum_j k_ij * ||u_ij||_p**(2 - p)
                                    * sign(u_ij) * |u_ij|**(p - 1)

    with ``u_ij = x_i - x_j`` (sign/abs/power taken elementwise).

    Parameters
    ----------
    theta : ndarray, shape (n, d)
        Particle positions.
    h_rep : float
        Kernel bandwidth.
    p : float
        Norm order; expected to satisfy ``p >= 1``.

    Returns
    -------
    ndarray, shape (n, d)
    """
    # diff[i, j] = x_i - x_j
    diff = theta[:, np.newaxis, :] - theta[np.newaxis, :, :]
    p_norm = np.linalg.norm(diff, ord=p, axis=2)
    Kxy = np.exp(-p_norm**2 / h_rep**2 / 2)

    # k_ij * ||u_ij||_p**(2 - p). Coincident pairs (including i == j) have
    # u_ij == 0 and contribute nothing, so keep them at zero instead of
    # evaluating 0**(2 - p), which is inf for p > 2.
    weight = np.zeros_like(p_norm)
    nonzero = p_norm > 0
    weight[nonzero] = Kxy[nonzero] / p_norm[nonzero] ** (p - 2)

    # sign(u) * |u|**(p - 1) equals u * |u|**(p - 2) for u != 0. Unlike
    # u * u**(p - 2) it keeps the sign for odd p and is finite for
    # non-integer p and for u == 0.
    direction = np.sign(diff) * np.abs(diff) ** (p - 1)

    dxkxy = np.einsum('ij,ijk->ik', weight, direction)
    return dxkxy / h_rep**2

def dxkxy_rbf_inf(theta, h_rep):
    """Repulsive SVGD term for the infinity-norm RBF kernel.

    The kernel is ``k(x, y) = exp(-||x - y||_inf**2 / (2 * h_rep**2))``. For
    each pair, only the dimension of greatest absolute difference
    ``m = argmax_k |u_ij[k]|`` contributes, with ``u_ij = x_i - x_j``::

        result[i, m] += k_ij * u_ij[m] / h_rep**2

    Because ``u_ji = -u_ij``, particle ``j`` receives the opposite
    contribution, so each pair pushes its two particles apart.

    Parameters
    ----------
    theta : ndarray, shape (n, d)
        Particle positions.
    h_rep : float
        Kernel bandwidth.

    Returns
    -------
    ndarray, shape (n, d)
    """
    # diff[i, j] = x_i - x_j
    diff = theta[:, np.newaxis, :] - theta[np.newaxis, :, :]
    inf_norm = np.linalg.norm(diff, ord=np.inf, axis=2)
    Kxy = np.exp(-inf_norm**2 / h_rep**2 / 2)

    # Dimension of greatest distance per pair (first one on ties). It is the
    # same for (i, j) and (j, i) since |diff| is symmetric.
    d_ind = np.argmax(np.abs(diff), axis=2)
    onehot = d_ind[:, :, np.newaxis] == np.arange(theta.shape[1])

    # Keep only the dominant component of each u_ij, weight by k_ij and sum
    # over j.
    dxkxy = np.einsum('ij,ijk->ik', Kxy, diff * onehot)
    return dxkxy / h_rep**2

class HSVGD():
    """Stein variational gradient descent with a configurable repulsive term.

    The driving term ``K @ lnpgrad`` always uses an RBF kernel whose bandwidth
    comes from the median heuristic. The repulsive term is a weighted sum of
    kernel gradients selected through the ``k_rep`` keyword argument (see
    ``hsvgd_kernel``).
    """

    def __init__(self):
        pass

    def hsvgd_kernel(self, theta, **kwargs):
        r"""Calculate the kernel matrix and the repulsive gradient: K, \nabla_x k.

        Parameters
        ----------
        theta : ndarray, shape (n, d)
            Current particle positions.
        **kwargs
            Must contain ``k_rep``, a mapping whose ``'repulsive'`` entry is an
            iterable of component dicts. Each component has:

            - ``'kernel'``: ``'rbf'``, ``'rbf_p'`` or ``'rbf_inf'``;
            - ``'h_factor'``: multiplier on the median-heuristic bandwidth;
            - ``'weight'``: scale of this component's contribution;
            - ``'p'``: the norm order, for ``'rbf_p'`` only.

        Returns
        -------
        Kxy_grad : ndarray, shape (n, n)
            ``exp(-||x_i - x_j||^2 / h)``, where ``h`` comes from the median
            heuristic.
        dxkxy : ndarray, shape (n, d)
            Weighted sum of the repulsive gradients of all components.

        Raises
        ------
        KeyError
            If ``k_rep`` or a required component key is missing.
        ValueError
            If a component names an unknown kernel.
        """
        # pdist gives Euclidean distances; square them for the RBF kernel.
        pairwise_dists = squareform(pdist(theta)) ** 2

        # Median heuristic: h = median(D) / log(n + 1).
        h_med = np.median(pairwise_dists) / np.log(theta.shape[0] + 1)
        if h_med <= 0:
            # The median distance is zero, e.g. for a single particle or when
            # most particles coincide. h = 0 would make the kernel 0/0 = nan.
            h_med = 1.0
        h_grad = h_med
        Kxy_grad = np.exp(-pairwise_dists / h_grad)

        # Compute the kernel gradient for the repulsive term
        dxkxy = np.zeros(theta.shape)
        for c in kwargs['k_rep']['repulsive']:
            h_rep = h_grad * c['h_factor']
            kernel = c['kernel']
            if kernel == 'rbf':
                dxkxy_next = dxkxy_rbf(theta, pairwise_dists, h_rep)
            elif kernel == 'rbf_p':
                dxkxy_next = dxkxy_rbf_p(theta, h_rep, c['p'])
            elif kernel == 'rbf_inf':
                dxkxy_next = dxkxy_rbf_inf(theta, h_rep)
            else:
                # Fail loudly: otherwise the previous component's gradient
                # would silently be reused for this one.
                raise ValueError(
                    "Unknown repulsive kernel %r; expected 'rbf', 'rbf_p' "
                    "or 'rbf_inf'" % (kernel,)
                )
            dxkxy += dxkxy_next * c['weight']

        return (Kxy_grad, dxkxy)

    def update(self, x0, lnprob, n_iter = 1000, stepsize = 1e-3, bandwidth = -1, alpha = 0.9, debug = False, store_history=False, annealing=False, **kwargs):
        """Run ``n_iter`` HSVGD steps starting from ``x0``.

        Parameters
        ----------
        x0 : ndarray, shape (n, d)
            Initial particles. The array is copied, not modified.
        lnprob : callable
            Maps the current particles ``theta`` (n, d) to an array used as the
            log-density gradient (``lnpgrad``). Presumably it has the same
            shape as ``theta``.
        n_iter : int
            Number of update steps.
        stepsize : float
            Step size applied to the AdaGrad-normalised direction.
        bandwidth : float
            Currently unused. The median heuristic in ``hsvgd_kernel`` is
            always applied.
        alpha : float
            Decay factor of the running mean of squared gradients.
        debug : bool
            Print progress every 1000 iterations.
        store_history : bool
            If true, set ``self.theta_history`` to an array of shape
            ``((n_iter + 1) * n, d + 1)``. Each block of ``n`` rows is one
            snapshot, and the last column holds the iteration index (0 for
            ``x0``).
        annealing : str or bool
            ``'hyperbolic'`` scales the driving term by
            ``tanh((1.3 * it / n_iter) ** 5)``. Any other value disables
            annealing.
        **kwargs
            Forwarded to ``hsvgd_kernel`` and must include ``k_rep``.

        Returns
        -------
        ndarray
            The particles after the final step.

        Raises
        ------
        ValueError
            If ``x0`` or ``lnprob`` is None.
        """
        # Check input
        if x0 is None or lnprob is None:
            raise ValueError('x0 or lnprob cannot be None!')

        n_particles = x0.shape[0]
        theta = np.copy(x0)
        if store_history:
            # Collect snapshots and stack once at the end. Re-vstacking every
            # iteration would copy the whole history each time.
            history = [np.hstack([np.copy(x0), np.zeros((n_particles, 1))])]

        # adagrad with momentum
        fudge_factor = 1e-6
        historical_grad = 0
        for it in range(n_iter):
            if debug and (it + 1) % 1000 == 0:
                print('iter ' + str(it + 1))

            lnpgrad = lnprob(theta)
            # calculating the kernel matrix
            kxy, dxkxy = self.hsvgd_kernel(theta, **kwargs)
            driving = np.matmul(kxy, lnpgrad)
            if annealing == 'hyperbolic':
                # gamma ramps from 0 at it = 0 towards ~1 near the end of the
                # run, so repulsion dominates early on.
                p = 5
                gamma = np.tanh((1.3 * (it / n_iter)) ** p)
                driving = gamma * driving
            grad_theta = (driving + dxkxy) / n_particles

            # adagrad
            if it == 0:
                historical_grad = historical_grad + grad_theta ** 2
            else:
                historical_grad = alpha * historical_grad + (1 - alpha) * (grad_theta ** 2)
            adj_grad = np.divide(grad_theta, fudge_factor + np.sqrt(historical_grad))
            theta = theta + stepsize * adj_grad

            if store_history:
                history.append(
                    np.hstack([theta, np.zeros((n_particles, 1)) + it + 1])
                )

        if store_history:
            self.theta_history = np.vstack(history)

        return theta