import numpy as np
import scipy
from scipy.linalg import sqrtm

# Large integer sentinel; Experiments.iteration_complexity starts at this value.
INF = 9999999

class Experiments():
    """Run ensemble sampling experiments and collect per-iteration statistics.

    Every sampler method (``OverdampedLangevin``, ``UnderdampedLangevin``,
    ``HFHR``, ``HFHR2``) re-initialises the ensemble, records each statistic
    in ``stats_fn`` once before the first step and once after every step,
    stores a copy of the final positions under ``results[name]['samples']``
    and registers ``name`` in ``methods_to_compare`` so ``plot_stats`` can
    draw it.  Running a method again under the same name replaces the
    earlier record instead of appending to it.
    """

    def __init__(self, config):
        """Read the experiment settings from ``config``.

        Keys read from ``config``:
            'distribution': object exposing ``dV(q)``.
            'dimension': number of columns D of the ensemble arrays.
            'total ensemble': number of rows N of the ensemble arrays.
            'number of iterations': steps performed by each sampler.
            'initial condition': dict with 'q' (added to an (N, D) array of
                zeros) and 'p' (one of 'zero', 'gradient', 'random').
            'stats_fn': dict mapping a statistic name to a callable of q.
            'threshold': a statistic below this value requests a stop.
            'benchmark': when truthy, stop requests are ignored.
        """
        self.config = config

        self.dist = config['distribution']
        self.D = config['dimension']
        self.N = config['total ensemble']
        self.n_iter = config['number of iterations']
        self.IC = config['initial condition']
        self.stats_fn = config['stats_fn']
        self.threshold = config['threshold']
        self.stop = False
        self.iteration_complexity = INF
        self.benchmark = config['benchmark']

        self.results = {}
        self.methods_to_compare = []

    def Wiener(self):
        """Return an (N, D) array of independent standard normal draws."""
        return np.random.randn(self.N, self.D)

    def record_stats(self, method):
        """Append the current value of every statistic to ``results[method]``.

        Sets ``self.stop`` as soon as any statistic is below
        ``self.threshold``.
        """
        if method not in self.results:
            self.results[method] = {key: [] for key in self.stats_fn}

        history = self.results[method]
        for key, fn in self.stats_fn.items():
            val = fn(self.q)
            history[key].append(val)
            if val < self.threshold:
                self.stop = True

    def initialize(self):
        """Reset positions, momenta and the early-stopping state.

        Raises:
            ValueError: if ``IC['p']`` is not 'zero', 'gradient' or 'random'.
        """
        # A stop flag or iteration count left over from a previous run would
        # make the next non-benchmark run terminate (or report) incorrectly.
        self.stop = False
        self.iteration_complexity = INF

        self.q = np.zeros((self.N, self.D)) + self.IC['q']
        p_mode = self.IC['p']
        if p_mode == 'zero':
            self.p = np.zeros((self.N, self.D))
        elif p_mode == 'gradient':
            self.p = -self.dist.dV(self.q)
        elif p_mode == 'random':
            self.p = np.random.randn(self.N, self.D)
        else:
            # Previously fell through silently, leaving self.p unset or stale.
            raise ValueError(
                "initial condition 'p' must be 'zero', 'gradient' or "
                "'random', got {!r}".format(p_mode))

    def _start_run(self, name):
        """Reset the state, clear any old history for ``name``, record step 0."""
        self.initialize()
        self.results[name] = {key: [] for key in self.stats_fn}
        self.record_stats(name)

    def _finish_run(self, name):
        """Store the final positions and register ``name`` for plotting once."""
        self.results[name]['samples'] = self.q.copy()
        if name not in self.methods_to_compare:
            self.methods_to_compare.append(name)

    def _stop_requested(self, it):
        """Return True (recording ``it``) if a non-benchmark run should stop."""
        if not self.benchmark and self.stop:
            self.iteration_complexity = it
            return True
        return False

    def OverdampedLangevin(self, h):
        """Run ``n_iter`` overdamped Langevin steps of size ``h`` as 'OLD'."""
        name = 'OLD'
        self._start_run(name)
        for _ in range(self.n_iter):
            self.q += -h * self.dist.dV(self.q) + np.sqrt(2*h) * self.Wiener()
            self.record_stats(name)
        self._finish_run(name)

    def UnderdampedLangevin(self, h, gamma):
        """Run ``n_iter`` underdamped Langevin steps with friction ``gamma``.

        Positions are advanced first; the momentum update then uses the
        gradient at the new positions.
        """
        sigma = np.sqrt(2 * gamma)
        name = 'UL (gamma={})'.format(gamma)
        self._start_run(name)
        for _ in range(self.n_iter):
            self.q += self.p * h
            self.p += -(self.dist.dV(self.q) + gamma * self.p) * h + np.sqrt(h) * sigma * self.Wiener()
            self.record_stats(name)
        self._finish_run(name)

    def HFHR(self, h, alpha=None, gamma=None):
        """Run the HFHR integrator with step size ``h``.

        When alpha=0, this integrator degenerates to KLMC (ULD) algorithm.

        Args:
            h: step size.
            alpha: defaults to ``2*h`` when None.
            gamma: friction coefficient.  When None, the time-dependent
                value ``3/t`` named in the run label is used, with
                ``t = it*h`` the elapsed time at the end of step ``it``.
                NOTE(review): the choice of ``t`` is an assumption made to
                avoid dividing by zero at t=0 -- confirm against the
                intended schedule.

        Unless ``benchmark`` is set, the run stops at the first iteration
        after a statistic fell below ``threshold``; that iteration number is
        stored in ``iteration_complexity``.
        """
        dist = self.dist
        if alpha is None:
            alpha = 2*h

        gamma_label = '3/t' if gamma is None else gamma
        if alpha == 0:
            name = r'ULD($\gamma$={})'.format(gamma_label)
        else:
            name = r'HFHR($\alpha$={}, $\gamma$={})'.format(alpha, gamma_label)

        self._start_run(name)
        for it in range(1, self.n_iter + 1):
            # gamma=None used to reach np.exp(-gamma*h) and raise TypeError.
            gamma_t = 3.0 / (it * h) if gamma is None else gamma
            decay = np.exp(-gamma_t*h)
            self.q = self.q + h*(self.p - alpha*dist.dV(self.q)) + np.sqrt(h)*np.sqrt(2*alpha)*self.Wiener()
            self.p = decay*self.p - (1 - decay)/gamma_t*dist.dV(self.q) + np.sqrt(1 - np.exp(-2*gamma_t*h))*self.Wiener()
            self.record_stats(name)
            if self._stop_requested(it):
                break
        self._finish_run(name)

    def phi_flow_exact(self, q, p, t, gamma):
        """Advance (q, p) over time ``t`` by the friction-plus-noise part.

        The two noise terms are mixed through ``self.M``, so ``prepare_M``
        must have been called with the same ``t`` and ``gamma`` beforehand.
        Returns the new ``(q, p)``.
        """
        noise_q_tmp = np.random.randn(*q.shape)
        noise_p_tmp = np.random.randn(*p.shape)

        noise_q = self.M[0][0] * noise_q_tmp + self.M[0][1] * noise_p_tmp
        noise_p = self.M[1][0] * noise_q_tmp + self.M[1][1] * noise_p_tmp

        q = q + (1 - np.exp(-gamma*t)) / gamma * p + noise_q
        p = np.exp(-gamma*t) * p + noise_p
        return q, p

    def psi_flow_approx(self, q, p, t, gamma, alpha):
        """Advance (q, p) over time ``t`` by the gradient-driven part.

        ``gamma`` is accepted but not used.  The momentum update evaluates
        the gradient at the already-updated positions.  Returns the new
        ``(q, p)``.
        """
        q = q - alpha * self.dist.dV(q) * t + np.sqrt(2 * alpha * t) * np.random.randn(*q.shape)
        p = p - self.dist.dV(q) * t
        return q, p

    def prepare_M(self, t, gamma):
        """Store in ``self.M`` a square root of the 2x2 noise covariance.

        The covariance is ``[[var_L, E_LK], [E_LK, var_K]]`` computed from
        ``t`` and ``gamma``; ``phi_flow_exact`` uses ``M`` to turn two
        independent normal draws into the correlated (q, p) noise.
        """
        decay = np.exp(-gamma*t)
        var_L = (2*t*gamma + 4*decay - np.exp(-2*gamma*t) - 3) / gamma**2
        var_K = 1 - np.exp(-2*gamma*t)
        E_LK = (1 - decay)**2 / gamma
        root = sqrtm(np.array([[var_L, E_LK], [E_LK, var_K]]))
        # sqrtm returns a complex array when round-off makes the (nearly
        # singular for small t) covariance slightly indefinite; keep the
        # real part so q and p never become complex.
        if np.iscomplexobj(root):
            root = np.real(root)
        self.M = root

    def HFHR2(self, h, alpha=1, gamma=2):
        """Run the split HFHR scheme: phi(h/2), psi(h), phi(h/2) per step.

        NOTE: the run label has the same format as the one ``HFHR`` uses
        for the same ``alpha`` and ``gamma``, so the two share a results
        entry and the later run replaces the earlier one.

        Unless ``benchmark`` is set, the run stops at the first iteration
        after a statistic fell below ``threshold``; that iteration number is
        stored in ``iteration_complexity``.
        """
        half = h / 2
        self.prepare_M(half, gamma)

        name = r'HFHR($\alpha$={}, $\gamma$={})'.format(alpha, gamma)
        self._start_run(name)
        for it in range(1, self.n_iter + 1):
            self.q, self.p = self.phi_flow_exact(self.q, self.p, half, gamma)
            self.q, self.p = self.psi_flow_approx(self.q, self.p, h, gamma, alpha)
            self.q, self.p = self.phi_flow_exact(self.q, self.p, half, gamma)
            self.record_stats(name)
            if self._stop_requested(it):
                break
        self._finish_run(name)

    def get_nrows_ncols(self, n_plots):
        """Return a ``(n_rows, n_cols)`` grid for ``n_plots`` subplots.

        Prints a message and returns None when ``n_plots`` exceeds 9.
        """
        if n_plots < 4:
            return (1, n_plots)
        if n_plots == 4:
            return (2, 2)
        if n_plots <= 6:
            return (2, 3)
        if n_plots <= 8:
            return (2, 4)
        if n_plots == 9:
            return (3, 3)
        print('Too many statistics to plot!')
        return None

    def plot_stats(self, scale='log'):
        """Plot every recorded statistic against the iteration number.

        One subplot per statistic, one line per entry of
        ``methods_to_compare``.  With ``scale='log'`` absolute values are
        drawn on a logarithmic y axis.  Nothing is drawn when there are no
        statistics or too many for ``get_nrows_ncols``.

        NOTE(review): this uses the name ``plt`` (matplotlib.pyplot), for
        which no import is visible in this module; it must be available in
        module scope before calling.

        Raises:
            ValueError: if ``scale`` is neither 'linear' nor 'log'.
        """
        if scale not in ('linear', 'log'):
            raise ValueError(
                "scale must be 'linear' or 'log', got {!r}".format(scale))

        layout = self.get_nrows_ncols(len(self.stats_fn))
        if layout is None or 0 in layout:
            return
        n_rows, n_cols = layout
        _, axes = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(10 * n_cols, 6 * n_rows))

        for index, stat in enumerate(self.stats_fn):
            ax = axes[index // n_cols][index % n_cols]
            for method in self.methods_to_compare:
                values = self.results[method][stat]
                # One record per iteration starting at 0, so an early-stopped
                # run ends where it stopped instead of being stretched to
                # n_iter.
                x = np.arange(len(values))
                if scale == 'linear':
                    ax.plot(x, values, label=method)
                else:
                    ax.semilogy(x, np.abs(values), label=method)
            ax.set_xlabel('number of iterations')
            ax.set_ylabel(stat)
            ax.legend()

        plt.tight_layout()
    
class Distribution:
    """Target whose gradient is a scaled identity term plus a row-wise softmax."""

    def __init__(self, sigma):
        """Store the scale ``sigma`` used by the quadratic term."""
        self.sigma = sigma

    def dV(self, q):
        """Return ``q / sigma**2`` plus the softmax of ``q`` along axis 1.

        This is the gradient in q of
        ``||q||^2 / (2 sigma^2) + log(sum(exp(q)))`` taken row by row.
        """
        # Subtract the row maximum before exponentiating so large entries do
        # not overflow; the softmax is unchanged by this shift.
        weights = np.exp(q - np.max(q, axis=1, keepdims=True))
        softmax = weights / weights.sum(axis=1, keepdims=True)
        quadratic_part = q / self.sigma**2
        return quadratic_part + softmax

