import numpy as np
import scipy
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import re
import pickle
import time
import os
from scipy.linalg import sqrtm

sns.set(style="darkgrid")


class Experiments():
    """Run and compare Langevin-type samplers on a common target.

    ``config`` must provide the keys:

    * ``'distribution'`` -- object exposing ``dV(q)``; it is called on the
      ``(N, D)`` array of particle positions and must return an array that
      broadcasts against it.
    * ``'dimension'`` -- ``D``, the dimension of each particle.
    * ``'total ensemble'`` -- ``N``, the number of particles.
    * ``'number of iterations'`` -- number of steps every sampler takes.
    * ``'initial condition'`` -- dict with ``'q'`` (added onto an ``(N, D)``
      zero array) and ``'p'`` (``'zero'``, ``'gradient'`` or ``'random'``).
    * ``'stats_fn'`` -- mapping ``name -> fn(q)`` of statistics evaluated on
      the positions before the first step and after every step.

    Every sampler stores its statistic histories plus its final positions
    (under the extra key ``'samples'``) in ``self.results[name]`` and appends
    ``name`` to ``self.methods_to_compare``.  Running the same sampler twice
    with the same parameters appends to the existing histories.
    """

    def __init__(self, config):
        self.config = config

        self.dist = config['distribution']
        self.D = config['dimension']
        self.N = config['total ensemble']
        self.n_iter = config['number of iterations']
        self.IC = config['initial condition']
        self.stats_fn = self.config['stats_fn']

        # results[method][stat] -> list of recorded values;
        # results[method]['samples'] -> final (N, D) positions.
        self.results = {}
        # Method names in the order they were run; read by plot_stats.
        self.methods_to_compare = []

    def Wiener(self):
        """Return an ``(N, D)`` array of i.i.d. standard normal draws."""
        return np.random.randn(self.N, self.D)

    def record_stats(self, method):
        """Evaluate every statistic on the current positions under ``method``."""
        if method not in self.results:
            self.results[method] = {key: [] for key in self.stats_fn}

        for key, fn in self.stats_fn.items():
            self.results[method][key].append(fn(self.q))

    def initialize(self):
        """Reset ``self.q`` and ``self.p`` from the configured initial condition.

        Raises:
            ValueError: if ``initial condition['p']`` is not ``'zero'``,
                ``'gradient'`` or ``'random'``.
        """
        self.q = np.zeros((self.N, self.D)) + self.IC['q']
        p_mode = self.IC['p']
        if p_mode == 'zero':
            self.p = np.zeros((self.N, self.D))
        elif p_mode == 'gradient':
            self.p = -self.dist.dV(self.q)
        elif p_mode == 'random':
            self.p = np.random.randn(self.N, self.D)
        else:
            # An unknown mode used to leave a stale (or missing) self.p behind.
            raise ValueError(
                "initial condition 'p' must be 'zero', 'gradient' or 'random', "
                "got {!r}".format(p_mode))

    def OverdampedLangevin(self, h):
        """Euler-Maruyama for dq = -dV(q) dt + sqrt(2) dW with step ``h``."""
        self.initialize()
        self.record_stats('OLD')
        for _ in range(self.n_iter):
            self.q += -h * self.dist.dV(self.q) + np.sqrt(2*h) * self.Wiener()
            self.record_stats('OLD')

        self.results['OLD']['samples'] = self.q.copy()
        self.methods_to_compare.append('OLD')

    def UnderdampedLangevin(self, h, gamma):
        """Semi-implicit Euler for underdamped Langevin with friction ``gamma``.

        ``q`` is advanced with the old momentum; ``p`` is then updated with
        the gradient at the new ``q`` and noise of scale ``sqrt(2 * gamma)``.
        """
        sigma = np.sqrt(2 * gamma)
        name = 'UL (gamma={})'.format(gamma)
        self.initialize()
        self.record_stats(name)
        for _ in range(self.n_iter):
            self.q += self.p * h
            self.p += -(self.dist.dV(self.q) + gamma * self.p) * h + np.sqrt(h) * sigma * self.Wiener()
            self.record_stats(name)

        self.results[name]['samples'] = self.q.copy()
        self.methods_to_compare.append(name)

    def HFHR(self, h, alpha=None, gamma=None):
        """Run HFHR: an Euler q-step followed by an exact-friction p-step.

        Per iteration::

            q <- q + h (p - alpha dV(q)) + sqrt(2 alpha h) xi_1
            p <- e^{-gamma h} p - (1 - e^{-gamma h}) / gamma dV(q) + sqrt(1 - e^{-2 gamma h}) xi_2

        Args:
            h: step size.
            alpha: strength of the overdamped correction; defaults to
                ``2 * h``.  ``alpha == 0`` is recorded under the label ULD.
            gamma: friction coefficient (required).  ``gamma == 0`` uses the
                limit ``(1 - e^{-gamma h}) / gamma -> h`` instead of 0/0.

        Raises:
            ValueError: if ``gamma`` is None.
        """
        if gamma is None:
            # The old code labelled this case "gamma=3/t" but then evaluated
            # np.exp(-None * h) inside the loop, crashing with a TypeError.
            raise ValueError('HFHR requires an explicit friction coefficient gamma')
        if alpha is None:
            alpha = 2*h
        dV = self.dist.dV

        self.initialize()
        if alpha == 0:
            name = r'ULD($\gamma$={})'.format(gamma)
        else:
            name = r'HFHR($\alpha$={}, $\gamma$={})'.format(alpha, gamma)

        # Loop-invariant coefficients of the q/p updates.
        q_noise = np.sqrt(h)*np.sqrt(2*alpha)
        decay = np.exp(-gamma*h)
        drift = h if gamma == 0 else (1-np.exp(-gamma*h))/gamma
        p_noise = np.sqrt(1-np.exp(-2*gamma*h))

        self.record_stats(name)
        for _ in range(self.n_iter):
            self.q = self.q + h*(self.p-alpha*dV(self.q))+q_noise*self.Wiener()
            self.p = decay*self.p-drift*dV(self.q)+p_noise*self.Wiener()
            self.record_stats(name)

        self.results[name]['samples'] = self.q.copy()
        self.methods_to_compare.append(name)

    def phi_flow_exact(self, q, p, t, gamma):
        """Exactly integrate the force-free Langevin flow over time ``t``.

        Solves dq = p dt, dp = -gamma p dt + sqrt(2 gamma) dW.  The correlated
        (q, p) noise uses ``self.M``, which ``prepare_M(t, gamma)`` must have
        set for the same ``t`` and ``gamma``.
        """
        noise_q_tmp = np.random.randn(*q.shape)
        noise_p_tmp = np.random.randn(*p.shape)

        noise_q = self.M[0][0] * noise_q_tmp + self.M[0][1] * noise_p_tmp
        noise_p = self.M[1][0] * noise_q_tmp + self.M[1][1] * noise_p_tmp

        q = q + (1 - np.exp(-gamma*t)) / gamma * p + noise_q
        p = np.exp(-gamma*t) * p + noise_p
        return q, p

    def psi_flow_approx(self, q, p, t, gamma, alpha):
        """Euler step of the gradient part over time ``t``.

        ``q`` takes an overdamped step of strength ``alpha``; ``p`` is then
        kicked by the gradient at the new ``q``.  ``gamma`` is unused and is
        kept for signature symmetry with ``phi_flow_exact``.
        """
        q = q - alpha * self.dist.dV(q) * t + np.sqrt(2 * alpha * t) * np.random.randn(*q.shape)
        p = p - self.dist.dV(q) * t
        return q, p

    def prepare_M(self, t, gamma):
        """Store in ``self.M`` the symmetric square root of the (q, p) noise covariance.

        ``var_L``, ``var_K`` and ``E_LK`` are the variance of the position
        noise, the variance of the momentum noise and their covariance for
        the flow integrated by ``phi_flow_exact`` over time ``t``.
        """
        var_L = (2*t*gamma+4*np.exp(-gamma*t)-np.exp(-2*gamma*t)-3) / gamma**2
        var_K = 1 - np.exp(-2*gamma*t)
        E_LK = (1 - np.exp(-gamma*t))**2 / gamma
        # sqrtm can return a complex dtype (e.g. when round-off makes the
        # tiny var_L slightly negative); the noise must be real.
        self.M = np.real(sqrtm(np.array([[var_L, E_LK], [E_LK, var_K]])))

    def HFHR2(self, h, alpha=1, gamma=2):
        """Run HFHR with a symmetric (Strang) splitting.

        Each step is phi(h/2) -> psi(h) -> phi(h/2): ``phi_flow_exact``
        handles the friction/noise part, ``psi_flow_approx`` the gradient part.
        """
        self.prepare_M(h / 2, gamma)

        self.initialize()
        name = r'HFHR($\alpha$={}, $\gamma$={})'.format(alpha, gamma)
        self.record_stats(name)
        for _ in range(self.n_iter):
            self.q, self.p = self.phi_flow_exact(self.q, self.p, h / 2, gamma)
            self.q, self.p = self.psi_flow_approx(self.q, self.p, h, gamma, alpha)
            self.q, self.p = self.phi_flow_exact(self.q, self.p, h / 2, gamma)
            self.record_stats(name)

        self.results[name]['samples'] = self.q.copy()
        self.methods_to_compare.append(name)

    def get_nrows_ncols(self, n_plots):
        """Return the ``(n_rows, n_cols)`` subplot grid for ``n_plots`` panels.

        Up to 9 panels use hand-picked layouts.  Larger counts fall back to a
        near-square grid; the old code returned None there, which made
        ``plot_stats`` fail on tuple unpacking.
        """
        if n_plots < 4:
            return (1, n_plots)
        if n_plots == 4:
            return (2, 2)
        if n_plots <= 6:
            return (2, 3)
        if n_plots <= 8:
            return (2, 4)
        if n_plots == 9:
            return (3, 3)
        n_cols = int(np.ceil(np.sqrt(n_plots)))
        n_rows = int(np.ceil(n_plots / n_cols))
        return (n_rows, n_cols)

    def plot_stats(self, scale='log'):
        """Plot every recorded statistic, with one curve per method.

        Args:
            scale: ``'log'`` plots ``|value|`` on a log y-axis; ``'linear'``
                plots the raw values.

        Raises:
            ValueError: for any other ``scale`` (checked before plotting).
        """
        if scale not in ('linear', 'log'):
            raise ValueError("scale must be 'linear' or 'log', got {!r}".format(scale))

        n_rows, n_cols = self.get_nrows_ncols(len(self.stats_fn))
        fig, axes = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(10 * n_cols, 6 * n_rows))

        for index, stat in enumerate(self.stats_fn):
            ax = axes[index // n_cols][index % n_cols]
            for method in self.methods_to_compare:
                values = self.results[method][stat]
                # One record per iteration plus the initial state.
                x = np.linspace(0, self.n_iter, len(values))
                if scale == 'linear':
                    ax.plot(x, values, label=method)
                else:
                    ax.semilogy(x, np.abs(values), label=method)
            ax.set_xlabel('number of iterations')
            ax.set_ylabel(stat)
            ax.legend()

        plt.tight_layout()
    
class Distribution():
    """Target whose potential gradient is ``q / sigma**2 + softmax(q)``.

    Up to an additive constant, this is the gradient of
    ``V(q) = |q|^2 / (2 sigma^2) + logsumexp(q)``.  ``sigma`` sets the width
    of the quadratic part.
    """

    def __init__(self, sigma):
        self.sigma = sigma

    def dV(self, q):
        """Return the potential gradient evaluated row by row (softmax along axis 1).

        Each row's maximum is subtracted before exponentiating, so the
        softmax cannot overflow for large entries.
        """
        row_peak = np.max(q, axis=1, keepdims=True)
        weights = np.exp(q - row_peak)
        softmax = weights / weights.sum(axis=1, keepdims=True)
        return q / self.sigma**2 + softmax

def get_error(path):
    """Return the l2-error recorded on the last non-blank line of a log file.

    The log format is the one written by this script's ``__main__`` section:
    one ``iteration<TAB>error`` row per step, each ending in a newline.  The
    first decimal number on the last non-blank line is returned.  The old
    version read ``lines[-2]``, which selected the penultimate iteration
    because no trailing blank line is written.

    Raises:
        ValueError: if the file has no non-blank lines, or the last one
            contains no decimal number.
    """
    with open(path, 'r') as f:
        lines = [line for line in f if line.strip()]
    if not lines:
        raise ValueError('log file {!r} is empty'.format(path))

    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    match = re.search(r'\d+\.\d+', lines[-1])
    if match is None:
        raise ValueError('no error value found on last line of {!r}'.format(path))
    return float(match.group(0))

if __name__ == '__main__':
    # Measure how the HFHR bias ||E_mu x - E_{mu_inf} x||_2 grows with the
    # dimension d, and compare it with a sqrt(d) reference line.
    np.random.seed(0)

    gamma = 2
    h = 0.1
    T = 10
    alpha = 1
    ds = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]

    # When True, skip the simulations and only re-plot from the logs that an
    # earlier run wrote to log_dir.
    load_previous_results = True
    log_dir = 'experiment-log/'
    log_name_template = 'verify_dependence_on_d_fixed_n_samples(n_samples=1000, d={}).txt'
    img_dir = 'img/'
    fig_name = 'verify_dependence_on_d'

    dist = Distribution(sigma=1)
    config = {
        'distribution': dist,
        'dimension': 1,
        'total ensemble': int(1e3),
        # round() guards against T / h landing just below an integer.
        'number of iterations': int(round(T / h)),
        'initial condition': {'q': 1, 'p': 'zero'},
        'stats_fn': {},
    }

    if not load_previous_results:
        os.makedirs(log_dir, exist_ok=True)
        for d in ds:
            # Reference mean: ULD (alpha=0) with a small step, 2000 steps of
            # h=5e-3, i.e. the same horizon T=10.
            start = time.time()
            benchmark = Experiments({
                'distribution': dist,
                'dimension': d,
                'total ensemble': int(1e3),
                'number of iterations': 2000,
                'initial condition': {'q': 1, 'p': 'zero'},
                'stats_fn': {}
            })
            benchmark.HFHR(h=5e-3, alpha=0, gamma=gamma)
            end = time.time()
            benchmark_mean = np.mean(benchmark.results[r'ULD($\gamma$={})'.format(gamma)]['samples'], axis=0)
            print(f'dimension={d} benchmark finished, {end - start:.2f}s elapsed.')

            start = time.time()
            config['dimension'] = d
            # Bind the reference as a default argument so the closure cannot
            # see a later iteration's benchmark_mean.
            config['stats_fn'] = {
                'error': lambda samples, ref=benchmark_mean: np.linalg.norm(np.mean(samples, axis=0) - ref)
            }
            exp = Experiments(config)
            exp.HFHR2(h=h, alpha=alpha, gamma=gamma)
            end = time.time()
            print(f'dimension={d} exp finished, {end - start:.2f}s elapsed.\n')

            log_path = os.path.join(log_dir, log_name_template.format(d))
            with open(log_path, 'w') as logger:
                logger.write('--- Experiment Configuration ---\n\n')
                logger.write('n_samples: {}\n'.format(config['total ensemble']))
                logger.write(f'gamma: {gamma}\n')
                logger.write(f'alpha: {alpha}\n')
                logger.write(f'T: {T}\n')
                logger.write(f'h: {h}\n')
                logger.write('\n' + '-' * 80 + '\n\n')

                header = '{:>10s}\t{:>10s}\n'.format('iteration', 'l2-error')
                logger.write(header)

                template = '{:>10d}\t{:>10.3f}\n'
                errors = exp.results[r'HFHR($\alpha$={}, $\gamma$={})'.format(alpha, gamma)]['error']
                for i, e in enumerate(errors):
                    logger.write(template.format(i, e))

    fig, ax = plt.subplots(figsize=(12, 8))

    error = [get_error(os.path.join(log_dir, log_name_template.format(d))) for d in ds]
    ax.scatter(ds, error, label='HFHR Algorithm')
    x = np.array(ds)
    # 0.043 appears to be a hand-chosen prefactor that overlays the
    # sqrt(d) reference on the data — NOTE(review): confirm.
    ax.plot(x, 0.043 * np.sqrt(x), ls=':', label=r'$O(\sqrt{d})$')

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('Dimension', fontsize=48)
    ax.set_ylabel(r'$||\mathbb{E}_{\mu_\infty}x - \mathbb{E}_\mu x||_2$', fontsize=48)
    ax.legend(fontsize=36)
    ax.tick_params(axis='both', which='major', labelsize=18)

    # savefig does not create missing directories.
    os.makedirs(img_dir, exist_ok=True)
    fig.savefig(os.path.join(img_dir, fig_name + '.png'), bbox_inches='tight')
    fig.savefig(os.path.join(img_dir, fig_name + '.pdf'), bbox_inches='tight')