import numpy as np
import scipy
import re
import os
import pickle
import time
import argparse
from scipy.linalg import sqrtm


class Experiments():
    """Run ensemble simulations of several Langevin-type samplers.

    Every sampler evolves an ensemble of ``N`` particles in ``D`` dimensions
    (positions ``self.q`` and momenta ``self.p``, both of shape ``(N, D)``),
    evaluates the configured statistics on the positions before the first
    step and after every step, and stores them together with the final
    positions in ``self.results[<method label>]``.
    """

    def __init__(self, config):
        """Read the experiment settings from ``config``.

        Args:
            config: Mapping with the keys ``'distribution'`` (object exposing
                ``dV(q)``), ``'dimension'``, ``'total ensemble'``,
                ``'number of iterations'``, ``'initial condition'`` and
                ``'stats_fn'`` (mapping from a statistic name to a callable
                that receives the position array).
        """
        self.config = config

        self.dist = config['distribution']
        self.D = config['dimension']
        self.N = config['total ensemble']
        self.n_iter = config['number of iterations']
        self.IC = config['initial condition']
        self.stats_fn = self.config['stats_fn']

        # method label -> {statistic name: list of values, 'samples': array}
        self.results = {}
        # labels of the methods that have been run, in execution order
        self.methods_to_compare = []

    def Wiener(self):
        """Return an ``(N, D)`` array of independent standard normal draws."""
        return np.random.randn(self.N, self.D)

    def record_stats(self, method):
        """Append every configured statistic of ``self.q`` under ``method``.

        Args:
            method: Label used as the key into ``self.results``; the entry is
                created on first use.
        """
        if method not in self.results:
            self.results[method] = {key: [] for key in self.stats_fn}

        for key, fn in self.stats_fn.items():
            self.results[method][key].append(fn(self.q))

    def initialize(self):
        """Reset positions to the initial condition and momenta to zero."""
        self.q = np.zeros((self.N, self.D)) + self.IC
        self.p = np.zeros((self.N, self.D))

    def OverdampedLangevin(self, h):
        """Sample with the Euler-Maruyama scheme for overdamped Langevin.

        Args:
            h: Step size.

        Results are stored under the label ``'OLD'``.
        """
        self.initialize()
        self.record_stats('OLD')
        for it in range(1, self.n_iter + 1):
            self.q += -h * self.dist.dV(self.q) + np.sqrt(2*h) * self.Wiener()
            self.record_stats('OLD')

        self.results['OLD']['samples'] = self.q.copy()
        self.methods_to_compare.append('OLD')

    def UnderdampedLangevin(self, h, gamma):
        """Sample with an Euler-type scheme for underdamped Langevin.

        Each step first moves the positions with the current momenta and
        then updates the momenta using the gradient at the new positions.

        Args:
            h: Step size.
            gamma: Friction coefficient; the noise amplitude is
                ``sqrt(2 * gamma)``.

        Results are stored under the label ``'UL (gamma=<gamma>)'``.
        """
        sigma = np.sqrt(2 * gamma)
        name = 'UL (gamma={})'.format(gamma)
        self.initialize()
        self.record_stats(name)
        for it in range(1, self.n_iter + 1):
            self.q += self.p * h
            self.p += -(self.dist.dV(self.q) + gamma * self.p) * h + np.sqrt(h) * sigma * self.Wiener()
            self.record_stats(name)

        self.results[name]['samples'] = self.q.copy()
        self.methods_to_compare.append(name)

    def HFHR(self, h, alpha=None, gamma=None):
        """Sample with the first-order HFHR scheme.

        Args:
            h: Step size.
            alpha: Coefficient of the extra gradient/noise term in the
                position update. Defaults to ``2 * h``. ``alpha == 0`` only
                changes the label to ``ULD(...)``.
            gamma: Friction coefficient. When ``None`` the time-dependent
                schedule ``gamma_t = 3 / t`` is used, which is the schedule
                named by the label assigned to that case. Previously this
                default raised ``TypeError`` inside the loop.
                NOTE(review): ``t`` is taken as ``it * h`` (time at the end
                of the current step) -- confirm the intended time origin.

        Progress is printed whenever ``it * 100`` is a multiple of the number
        of iterations.
        """
        dist = self.dist
        if alpha is None:
            alpha = 2*h

        self.initialize()
        if alpha == 0:
            name = r'ULD($\gamma$={})'.format(gamma)
        elif gamma is None:
            name = r'HFHR($\alpha$={}, $\gamma$=3/t)'.format(alpha)
        else:
            name = r'HFHR($\alpha$={}, $\gamma$={})'.format(alpha, gamma)
        self.record_stats(name)
        for it in range(1, self.n_iter + 1):
            # Constant friction if supplied, otherwise the 3/t schedule.
            gamma_t = 3.0 / (it * h) if gamma is None else gamma

            self.q = self.q + h*(self.p-alpha*dist.dV(self.q))+np.sqrt(h)*np.sqrt(2*alpha)*self.Wiener()
            # The momentum update uses the gradient at the *updated* positions.
            self.p = (np.exp(-gamma_t*h)*self.p
                      - (1-np.exp(-gamma_t*h))/gamma_t*dist.dV(self.q)
                      + np.sqrt(1-np.exp(-2*gamma_t*h))*self.Wiener())
            self.record_stats(name)

            if it * 100 % self.n_iter == 0:
                print('finished {:.1f}%'.format(it / self.n_iter * 100))

        self.results[name]['samples'] = self.q.copy()
        self.methods_to_compare.append(name)

    def phi_flow_exact(self, q, p, t, gamma):
        """Advance ``(q, p)`` through the friction/noise sub-flow.

        Two independent standard normal arrays are mixed with ``self.M`` so
        the position and momentum noises are correlated. ``self.M`` must have
        been set by ``prepare_M`` for the same ``t`` and ``gamma``.

        Args:
            q: Position array.
            p: Momentum array.
            t: Duration of the sub-step.
            gamma: Friction coefficient.

        Returns:
            Tuple ``(q, p)`` of new arrays; the inputs are not modified.
        """
        noise_q_tmp = np.random.randn(*q.shape)
        noise_p_tmp = np.random.randn(*p.shape)

        noise_q = self.M[0][0] * noise_q_tmp + self.M[0][1] * noise_p_tmp
        noise_p = self.M[1][0] * noise_q_tmp + self.M[1][1] * noise_p_tmp

        q = q + (1 - np.exp(-gamma*t)) / gamma * p + noise_q
        p = np.exp(-gamma*t) * p + noise_p
        return q, p

    def psi_flow_approx(self, q, p, t, gamma, alpha):
        """Advance ``(q, p)`` through the gradient sub-flow with one Euler step.

        Args:
            q: Position array.
            p: Momentum array.
            t: Duration of the sub-step.
            gamma: Unused; kept so the signature parallels ``phi_flow_exact``.
            alpha: Coefficient of the gradient/noise term in the position
                update.

        Returns:
            Tuple ``(q, p)`` of new arrays; the inputs are not modified.
        """
        q = q - alpha * self.dist.dV(q) * t + np.sqrt(2 * alpha * t) * np.random.randn(*q.shape)
        # NOTE(review): the momentum kick uses the gradient at the updated q,
        # not at the q the step started from -- confirm this is intended.
        p = p - self.dist.dV(q) * t
        return q, p

    def prepare_M(self, t, gamma):
        """Precompute the noise-mixing matrix used by ``phi_flow_exact``.

        Stores in ``self.M`` the matrix square root of the symmetric 2x2
        matrix ``[[var_L, E_LK], [E_LK, var_K]]`` built from ``t`` and
        ``gamma``.

        Args:
            t: Duration of the sub-step the matrix is prepared for.
            gamma: Friction coefficient.
        """
        var_L = (2*t*gamma+4*np.exp(-gamma*t)-np.exp(-2*gamma*t)-3) / gamma**2
        var_K = 1 - np.exp(-2*gamma*t)
        E_LK = (1 - np.exp(-gamma*t))**2 / gamma
        root = sqrtm(np.array([[var_L, E_LK], [E_LK, var_K]]))
        # Drop round-off-sized imaginary parts so a complex-typed result
        # cannot silently turn the state arrays complex; genuinely complex
        # roots are left untouched.
        self.M = np.real_if_close(root)

    def HFHR2(self, h, alpha=1, gamma=2):
        """Sample with the splitting-based HFHR scheme.

        Each step composes a half step of ``phi_flow_exact``, a full step of
        ``psi_flow_approx`` and another half step of ``phi_flow_exact``.

        Args:
            h: Step size.
            alpha: Coefficient passed to ``psi_flow_approx``.
            gamma: Friction coefficient.
        """
        # The mixing matrix only depends on (h / 2, gamma): compute it once.
        self.prepare_M(h / 2, gamma)

        self.initialize()
        name = r'HFHR($\alpha$={}, $\gamma$={})'.format(alpha, gamma)
        self.record_stats(name)
        for it in range(1, self.n_iter + 1):
            self.q, self.p = self.phi_flow_exact(self.q, self.p, h / 2, gamma)
            self.q, self.p = self.psi_flow_approx(self.q, self.p, h, gamma, alpha)
            self.q, self.p = self.phi_flow_exact(self.q, self.p, h / 2, gamma)
            self.record_stats(name)

        self.results[name]['samples'] = self.q.copy()
        self.methods_to_compare.append(name)

    def get_nrows_ncols(self, n_plots):
        """Choose a subplot grid for ``n_plots`` panels.

        Args:
            n_plots: Number of panels to arrange.

        Returns:
            ``(n_rows, n_cols)`` for up to nine panels; ``None`` (after
            printing a message) when there are more.
        """
        if n_plots < 4:
            n_rows, n_cols = 1, n_plots
        elif n_plots == 4:
            n_rows, n_cols = 2, 2
        elif n_plots <= 6:
            n_rows, n_cols = 2, 3
        elif n_plots <= 8:
            n_rows, n_cols = 2, 4
        elif n_plots == 9:
            n_rows, n_cols = 3, 3
        else:
            print('Too many statistics to plot!')
            return None
        return (n_rows, n_cols)
    
class Distribution():
    """Target whose gradient is a scaled identity plus a row-wise softmax."""

    def __init__(self, sigma):
        """Store the scale ``sigma`` used by the quadratic part of ``dV``."""
        self.sigma = sigma

    def dV(self, q):
        """Return ``q / sigma**2 + softmax(q)``, softmax taken along axis 1.

        Args:
            q: Array with (at least) two axes; each slice along axis 1 is
                treated as one point.

        Returns:
            Array with the same shape as ``q``.
        """
        # Subtract the row-wise maximum before exponentiating so the largest
        # exponent is zero; the shift cancels in the normalised ratio.
        row_peak = np.max(q, axis=1, keepdims=True)
        weights = np.exp(q - row_peak)
        softmax = weights / weights.sum(axis=1, keepdims=True)

        quadratic_part = q / self.sigma**2
        return quadratic_part + softmax

def get_error(f_path):
    """Extract step sizes and errors from a text log.

    A line contributes only when it contains exactly three decimal tokens of
    the form ``digits.digits``; the first token is taken as the step size and
    the last as the error. All other lines are skipped.

    Args:
        f_path: Path of the text file to parse.

    Returns:
        Tuple ``(hs, es)`` of equally long lists of floats, in file order.
    """
    # Raw string: '\d' and '\.' are not valid Python string escapes and
    # trigger a warning when written in a plain literal.
    decimal_token = re.compile(r'\d+\.\d+')

    hs = []
    es = []
    with open(f_path, 'r') as f:
        for line in f:
            tokens = decimal_token.findall(line)
            if len(tokens) != 3:
                continue

            hs.append(float(tokens[0]))
            es.append(float(tokens[-1]))

    return hs, es

if __name__ == '__main__':
    # Run the splitting-based HFHR sampler for a range of step sizes and save
    # the recorded trajectory of the ensemble mean for each of them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()
    seed = args.seed

    np.random.seed(seed)

    d = 2
    T = 48
    n_samples = int(1e6)
    gamma = 2
    alpha = 1
    init = 1
    hs = [2, 1] + [1.0 / 2**x for x in range(1, 8)]

    dist = Distribution(sigma=1)

    # Create the output directory up front so a missing folder cannot make
    # np.save fail after the first (long) simulation has already finished.
    os.makedirs('results', exist_ok=True)

    for h in hs:
        config = {
            'distribution': dist,
            'dimension': d,
            'total ensemble': n_samples,
            # every run covers the same time horizon T
            'number of iterations': int(T / h),
            'initial condition': init,
            'stats_fn': {'mean': lambda samples: np.mean(samples, axis=0) },
        }
        exp = Experiments(config)
        exp.HFHR2(h=h, alpha=alpha, gamma=gamma)
        # The key must match the label HFHR2 builds from (alpha, gamma).
        sample_mean = exp.results[r'HFHR($\alpha$={}, $\gamma$={})'.format(alpha, gamma)]['mean']
        np.save(f'results/sample_mean|h={h}|d={d}|n_samples={n_samples}|part={seed}.npy', sample_mean)