import numpy as np
import random
import math
import os
import argparse
import importlib
from algorithms.SGD_TS import *
from algorithms.AutoTuning import *
data_generator = importlib.import_module('algorithms.data_generator')
import concurrent.futures

# Command-line interface: one option, -data/--data, selecting the dataset.
parser = argparse.ArgumentParser(description='simulations')
parser.add_argument(
    '-data', '--data',
    type=str,
    default='simulations',
    help='can be movielens',
)
# Parsed once at import time; the value is read below as args.data.
args = parser.parse_args()

# ---- experiment configuration ----
T = 10000  # passed as T to the bandit and to SGD_TS; also the length of every regret array
rep = 10  # number of independent repetitions
J = [0, 0.01, 0.1, 1, 10]  # tuning set for the exploration parameter \alpha
etas = [0.01, 0.1, 1, 10]  # tuning set for the step size
algo = 'sgdts'  # prefix of the SGD_TS method names and name of the results sub-directory
datatype = args.data

# Per-dataset settings handed to the bandit constructors (d is also the length
# of the true parameter vector and part of the data file names).
if datatype == 'simulations':
    sigma, d, K = 0.1, 10, 100
elif datatype == 'movielens':
    sigma, d, K = 0.5, 20, 1000
else:
    # Fail fast: without this, sigma/d/K stay undefined and the script only
    # crashes later with a NameError, after result directories were created.
    raise ValueError(
        "unsupported --data value {!r}; expected 'simulations' or 'movielens'".format(datatype)
    )

# Candidate values handed to the tuning methods: step size plus two
# exploration parameters that share the same candidate set J.
paras = {
    'eta0': etas,
    'alpha1': J,
    'alpha2': J,
}
print('tuning set of explore1, explore2, step size are {}'.format(paras))

# Output directory for this dataset/algorithm pair: results/<datatype>/<algo>/.
path = 'results/' + datatype + '/' + algo + '/'
# makedirs creates every missing level in one call; exist_ok=True avoids the
# check-then-create race of the former exists()/mkdir() sequence when several
# runs start at the same time.
os.makedirs(path, exist_ok=True)

if datatype == 'movielens':
    # Locations of the preprocessed user / movie matrices for dimension d.
    users_file = 'data/{}_users_matrix_d{}'.format(datatype, d)
    movies_file = 'data/{}_movies_matrix_d{}'.format(datatype, d)

    # check real data files exist:
    if not (os.path.isfile(users_file) and os.path.isfile(movies_file)):
        print("{holder} data does not exist, will run preprocessing for {holder} data now. If you are running experiments for netflix data, then preprocessing might take a long time".format(holder=datatype))
        from data.preprocess_data import *
        # Look the preprocessing routine up by name in the module namespace
        # (populated by the star import above) instead of eval()-ing a string.
        process = globals()["process_{}_data".format(datatype)]
        process(d)
        print("real data processing done")

    users = np.loadtxt(users_file)
    fv = np.loadtxt(movies_file)

    # Fixed seed so the same true parameters are drawn on every run.
    np.random.seed(0)
    # One true parameter per repetition: the mean of 100 user rows sampled
    # without replacement.
    thetas = np.zeros((rep, d))
    for i in range(rep):
        thetas[i, :] = np.mean(users[np.random.choice(len(users), 100, replace=False), :], axis=0)

# Symmetric bounds +-1/sqrt(d) used when drawing the simulated true parameter.
ub = 1 / math.sqrt(d)
lb = -ub

# Six independent zero-initialised arrays of length T, one per tuning strategy.
(reg_syndicated, reg_tl, reg_op,
 reg_tl_combined, reg_corral, reg_corral_combined) = (np.zeros(T) for _ in range(6))

# Switch between the process-pool branch and the plain loop below.
parallel = False

# Strategy key -> suffix appended to `algo` to form the SGD_TS method name
# that is looked up with getattr.
methods = dict(
    auto='_syndicated',
    op='_op',
    tl_combined='_tl_combined',
    corral='_corral',
    corral_combined='_corral_combined',
)

if parallel:
# parallel version:
    def func(i):
        """Run repetition ``i`` of every tuning strategy on a freshly built bandit.

        NumPy's global RNG is seeded with ``i + 1``, the bandit for the
        configured ``datatype`` is built, and the SGD_TS methods named
        ``algo`` + suffix (suffixes taken from ``methods``) are called in a
        fixed order.

        Returns:
            A 6-tuple of length-T arrays ordered as
            (tl, op, syndicated, tl_combined, corral, corral_combined).
        """
        np.random.seed(i + 1)

        if datatype == 'simulations':
            # Draw the true parameter first, then the features (RNG order matters).
            theta = np.random.uniform(lb, ub, d)
            fv = np.random.uniform(-1, 1, (T, K, d))
            bandit = data_generator.context_logistic(K, -1, 1, T, d, sigma, true_theta=theta, fv=fv)
        elif datatype == 'movielens':
            fv = np.loadtxt("data/{}_movies_matrix_d{}".format(datatype, d))
            theta = thetas[i, :]
            bandit = data_generator.movie_logistic(K, T, d, sigma, true_theta=theta, fv=fv)
        bandit.build_bandit()
        print(i, ": ", end = " ")

        learner = SGD_TS(bandit, T)
        # Resolve each strategy key to the bound method `algo + suffix`.
        runners = {key: getattr(learner, algo + suffix) for key, suffix in methods.items()}

        # The strategies run in this exact order because they share the RNG stream.
        regret_tl = np.zeros(T)
        regret_tl += runners['auto']({'eta0': paras['eta0']})  # tl: step-size candidates only

        regret_tl_combined = np.zeros(T)
        regret_tl_combined += runners['tl_combined'](paras)  # tl_combined

        regret_syndicated = np.zeros(T)
        regret_syndicated += runners['auto'](paras)  # syndicated: all candidate sets

        regret_op = np.zeros(T)
        regret_op += runners['op']({'eta0': paras['eta0']})  # op

        regret_corral = np.zeros(T)
        regret_corral += runners['corral'](etas)

        regret_corral_combined = np.zeros(T)
        regret_corral_combined += runners['corral_combined'](etas, J)

        print("op {}, tl {}, syndicated {}, combined {}, corral {}, corral_combined {}".format(
            regret_op[-1], regret_tl[-1], regret_syndicated[-1], regret_tl_combined[-1],
            regret_corral[-1], regret_corral_combined[-1]))
        return (regret_tl, regret_op, regret_syndicated,
                regret_tl_combined, regret_corral, regret_corral_combined)

    # Fan the repetitions out over worker processes; map() returns the
    # per-repetition tuples in repetition order.
    with concurrent.futures.ProcessPoolExecutor() as executor:
        results = list(executor.map(func, range(rep)))

    # Position of each curve inside the tuple returned by func.
    curve_order = ('tl', 'op', 'syndicated', 'tl_combined', 'corral', 'corral_combined')
    # Average every curve over the repetitions.
    result = {
        name: sum(run[pos] for run in results) / rep
        for pos, name in enumerate(curve_order)
    }

    # Show the last five entries of each averaged curve.
    print('{0}: reg_tl: {1}'.format(algo, result['tl'][-5:]))
    print('{0}: reg_op: {1}'.format(algo, result['op'][-5:]))
    print('{0}: reg_syndicated: {1}'.format(algo, result['syndicated'][-5:]))
    print('{0}: reg_tl_combined: {1}'.format(algo, result['tl_combined'][-5:]))
    print('{0}: reg_corral: {1}'.format(algo, result['corral'][-5:]))
    print('{0}: reg_corral_combined: {1}'.format(algo, result['corral_combined'][-5:]))

    # One text file per strategy under the results directory.
    for name, curve in result.items():
        np.savetxt(path + name, curve)

else:
    # sequential (non-parallel) version
     # Sequential driver: run the repetitions one after another and, after each
     # one, overwrite the result files with the running average so far.
     #
     # The reg_* totals are the module-level zero arrays created before this
     # branch. They must NOT be reset inside the loop: dividing by (i + 1)
     # below is only a mean if the totals accumulate over all repetitions.
     for i in range(rep):
         np.random.seed(i + 1)
         if datatype == 'simulations':
             # Draw the true parameter first, then the features (RNG order matters).
             theta = np.random.uniform(lb, ub, d)
             fv = np.random.uniform(-1, 1, (T, K, d))
             context_logistic = data_generator.context_logistic
             bandit = context_logistic(K, -1, 1, T, d, sigma, true_theta=theta, fv=fv)
         elif datatype == 'movielens':
             fv = np.loadtxt("data/{}_movies_matrix_d{}".format(datatype, d))
             context_logistic = data_generator.movie_logistic
             theta = thetas[i, :]
             bandit = context_logistic(K, T, d, sigma, true_theta=theta, fv=fv)
         bandit.build_bandit()
         print(i, ": ", end=" ")
         algo_class = SGD_TS(bandit, T)

         # Resolve each strategy key to the bound method `algo + suffix`.
         fcts = {
             k: getattr(algo_class, algo + suffix)
             for k, suffix in methods.items()
         }

         # Regret curves of THIS repetition only (call order is kept because
         # the strategies share the RNG stream).
         run_tl = np.zeros(T)
         run_tl += fcts['auto']({'eta0': paras['eta0']})  # tl
         run_tl_combined = np.zeros(T)
         run_tl_combined += fcts['tl_combined'](paras)  # tl_combined
         run_syndicated = np.zeros(T)
         run_syndicated += fcts['auto'](paras)  # syndicated
         run_op = np.zeros(T)
         run_op += fcts['op']({'eta0': paras['eta0']})  # op
         run_corral = np.zeros(T)
         run_corral += fcts['corral'](etas)
         run_corral_combined = np.zeros(T)
         run_corral_combined += fcts['corral_combined'](etas, J)

         # Fold this repetition into the totals across repetitions.
         reg_tl += run_tl
         reg_tl_combined += run_tl_combined
         reg_syndicated += run_syndicated
         reg_op += run_op
         reg_corral += run_corral
         reg_corral_combined += run_corral_combined

         # Progress line reports the final regret of the current repetition.
         print("op {}, tl {}, syndicated {}, combined {}, corral {}, corral_combined {}".format(
             run_op[-1], run_tl[-1], run_syndicated[-1], run_tl_combined[-1], run_corral[-1],
             run_corral_combined[-1]))

         # Running mean over the i + 1 repetitions finished so far.
         result = {
             'tl': reg_tl / (i + 1),
             'op': reg_op / (i + 1),
             'syndicated': reg_syndicated / (i + 1),
             'tl_combined': reg_tl_combined / (i + 1),
             'corral': reg_corral / (i + 1),
             'corral_combined': reg_corral_combined / (i + 1)
         }
         for k, v in result.items():
             np.savetxt(path + k, v)
