import numpy as np
import scipy
import pandas as pd
import pickle
import time
import os
import argparse
from scipy.linalg import sqrtm
from utils import Experiments, Distribution

from generate_config import gammas, hs, d, T, n_samples, threshold, initial_value


def run_exp(data, algorithm):
    """Run one sampler experiment and append one result row to ``data``.

    NOTE: besides its arguments, this function reads module-level names:
    ``args`` (parsed CLI arguments: ``seed``, ``alpha``), ``gamma`` and ``h``
    (the current values of the sweep loop in the ``__main__`` block), and
    ``d``, ``T``, ``n_samples``, ``threshold``, ``initial_value`` imported
    from ``generate_config``.

    Args:
        data: dict of column name -> list. One value is appended to each of
            'seed', 'dimension', 'T', 'n_samples', 'tolerance',
            'initial_value', 'alpha', 'gamma', 'h', 'algorithm',
            'iteration_complexity', 'final_error' and 'computing_time', and
            only after the experiment finished, so a failed run never leaves
            the columns with different lengths.
        algorithm: ``'ULD'`` runs ``Experiments.HFHR`` with ``alpha=0``;
            ``'HFHR'`` runs ``Experiments.HFHR2`` with ``alpha=args.alpha``.

    Raises:
        ValueError: if ``algorithm`` is neither ``'ULD'`` nor ``'HFHR'``.
    """
    if algorithm not in ('ULD', 'HFHR'):
        # Fail before doing any work. Previously an unknown name fell through
        # both branches and crashed with UnboundLocalError on the error value,
        # after most of the row had already been appended to ``data``.
        raise ValueError(
            f"unknown algorithm {algorithm!r}; expected 'ULD' or 'HFHR'")

    dist = Distribution(sigma=1)
    # Reference mean that the final ensemble mean is compared against. The
    # file name encodes the dimension, so use the configured ``d`` rather than
    # a hard-coded 10; the two must agree for the error norm to be meaningful.
    benchmark_mean = np.load(f"true_mean/true_mean|d={d}|n_samples=10000.npy")

    # perf_counter is monotonic and high-resolution, which suits durations.
    start = time.perf_counter()
    config = {
        'distribution': dist,
        'dimension': d,
        'total ensemble': n_samples,
        'number of iterations': int(T / h),
        'initial condition': {'q': initial_value, 'p': 'zero'},
        # Euclidean distance between the ensemble mean and the benchmark mean.
        'stats_fn': {'error': lambda samples: np.linalg.norm(
            np.mean(samples, axis=0) - benchmark_mean)},
        'threshold': threshold,
        'benchmark': False,
    }
    exp = Experiments(config)
    if algorithm == 'ULD':
        exp.HFHR(h=h, alpha=0, gamma=gamma)
        label = r'ULD($\gamma$={})'.format(gamma)
    else:  # 'HFHR' (validated above)
        exp.HFHR2(h=h, alpha=args.alpha, gamma=gamma)
        label = r'HFHR($\alpha$={}, $\gamma$={})'.format(args.alpha, gamma)
    # Last recorded value of the 'error' statistic for this run.
    final_error = exp.results[label]['error'][-1]
    elapsed = time.perf_counter() - start

    row = {
        'seed': args.seed,
        'dimension': d,
        'T': T,
        'n_samples': n_samples,
        'tolerance': threshold,
        'initial_value': initial_value,
        'alpha': args.alpha,
        'gamma': gamma,
        'h': h,
        'algorithm': algorithm,
        'iteration_complexity': exp.iteration_complexity,
        'final_error': final_error,
        'computing_time': elapsed,
    }
    for column, value in row.items():
        data[column].append(value)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description="Sweep (gamma, h) for ULD/HFHR and pickle the results.")
    # Both are required: every row records the seed and alpha, and the output
    # file name embeds them. Omitting --alpha would previously pass None into
    # HFHR2, and omitting --seed would silently produce an unreproducible run
    # labelled 'seed=None'.
    parser.add_argument("--seed", type=int, required=True, help="seed")
    parser.add_argument("--alpha", type=float, required=True,
                        help="alpha passed to HFHR (0 also runs ULD)")
    args = parser.parse_args()

    np.random.seed(args.seed)

    # Insertion order here is the column order of the saved DataFrame.
    columns = (
        'seed',
        'dimension',
        'T',
        'n_samples',
        'tolerance',
        'initial_value',
        'alpha',
        'gamma',
        'h',
        'iteration_complexity',
        'final_error',
        'computing_time',
        'algorithm',
    )
    data = {column: [] for column in columns}

    # NOTE: run_exp reads ``args``, ``gamma`` and ``h`` as module globals, so
    # these loop variables must stay at module scope under these exact names.
    for gamma in gammas:
        for h in hs:
            if args.alpha == 0:
                # With alpha == 0, also run ULD for the same (gamma, h).
                run_exp(data, 'ULD')
            run_exp(data, 'HFHR')

    log_dir = f'experiment-log/init_value={initial_value}-n_samples={n_samples}-tolerance={threshold}/'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)

    pd.DataFrame(data).to_pickle(
        os.path.join(log_dir, f"seed={args.seed}|alpha={args.alpha}.pkl"),
        protocol=4)