#%%
import numpy as np
from matplotlib import pyplot as plt
from numpy.random import default_rng

import networkx as nx
import pygsp
import network_lasso as nl
from joblib import delayed, Parallel
import pickle
from datetime import datetime
plt.close('all')

import policy
from bandit import MultiTaskContextualBandit
from experiment import bandit_multitask_experiment
import utils
import config as co



# Assign each of the co.n_users users to one of co.n_clusters clusters.
# The per-cluster weights are drawn from the shared generator `co.random`, so
# reordering co.random calls in this script changes the generated data.
# NOTE(review): the exact meaning of `weights` / `imbalance` is defined in
# utils.create_cluster_indices (not visible here).
cluster_inds = utils.create_cluster_indices(n_points= co.n_users, n_clusters=  co.n_clusters, 
                                            weights= co.random.uniform(0,1, size=co.n_clusters), 
                                            imbalance= co.imbalance)

# Stochastic block model over the users. `z` pins the cluster assignment
# computed above; seed=0 makes the graph itself reproducible.
# Presumably p / q are the intra- / inter-cluster edge probabilities
# (pygsp convention) — confirm against config.
G = pygsp.graphs.StochasticBlockModel(N=co.n_users , k= co.n_clusters , p= co.p , 
                                      q= co.q, z= cluster_inds, connected = True, seed= 0)
# Dense copy of the graph's weight matrix; this is what the agents receive.
Adj = G.W.toarray()

# Rebind G to a networkx graph built from the same adjacency
# (the CLUB / SCLUB agents are given this networkx version).
G = nx.from_numpy_array(Adj)


# Visual sanity check of the block structure of the adjacency matrix.
plt.figure()
plt.imshow(Adj)

# generate the signal over the clusters
# generate the signal over the clusters: one random unit-norm parameter
# vector per cluster (rows of Theta_per_cluster) ...
Theta_per_cluster = co.random.standard_normal((co.n_clusters,co.dim))
Theta_per_cluster /= np.linalg.norm(Theta_per_cluster, axis= 1, keepdims= True)
# ... then every user inherits the vector of its cluster via row indexing.
# Assumes cluster_inds holds one integer cluster label per user, giving
# Theta shape (co.n_users, co.dim) — TODO confirm against utils.
Theta = Theta_per_cluster[cluster_inds]


# if co.dim == 2:
#     plt.figure()
#     plt.scatter(*Theta.T)
#     plt.xlim(np.min(Theta[:,0]), np.max(Theta[:,0]))
#     plt.ylim(np.min(Theta[:,1]), np.max(Theta[:,1]))


#%% generating context vectors
# Fixed pool of co.n_arms_all unit-norm context vectors drawn from co.random.
# NOTE(review): this global X is not referenced by the visible code —
# context_sampler below draws fresh contexts from its own `rng` argument.
# It is kept because removing it would change the state of co.random seen
# by any later draws from that generator.
X =  co.random.standard_normal(size= (co.n_arms_all, co.dim))
X /= np.linalg.norm(X, axis= 1, keepdims= True) # normalize

def context_sampler(t, dim, n_arms, rng):
    """Draw a fresh set of unit-norm context vectors, one per arm.

    Parameters
    ----------
    t : int
        Round index. Unused: contexts are drawn independently every round.
    dim : int
        Dimension of each context vector.
    n_arms : int
        Number of context vectors (arms) to draw.
    rng : numpy.random.Generator
        Source of randomness; only ``rng.standard_normal`` is called.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_arms, dim)`` whose rows have unit Euclidean norm
        (normalized Gaussian vectors, i.e. uniform on the unit sphere).
    """
    # Honour the requested shape instead of reading global config, so the
    # sampler stays correct for whatever dim / n_arms the bandit asks for.
    contexts = rng.standard_normal(size=(n_arms, dim))
    return contexts / np.linalg.norm(contexts, axis=1, keepdims=True)

#%%
# Coupling weight passed as `l_net` to the "NL_alpha1" network-lasso agent
# (semantics defined in policy.NetworkLassoAgent). "NL_alpha01" uses a fixed
# l_net of 0.1.
l_net = 1.0

# Competing policies, keyed by the display name stored with the results and
# shown in the plot legend. Insertion order matters: the experiment loop
# iterates agent_dict.values() agent-major, which fixes the first axis of `res`.
# NOTE: an earlier variant passed a row-normalized adjacency
# (Adj / Adj.sum(axis=1, keepdims=True)) to the network-lasso agents.
agent_dict = {
    "NL_alpha1": policy.NetworkLassoAgent(adjacency=Adj, l_net=l_net,
                                          K_norm="F_bound", weighting="degree"),
    "NL_alpha01": policy.NetworkLassoAgent(adjacency=Adj, l_net=0.1,
                                           K_norm="F_bound", weighting="degree"),
    "GraphUCB": policy.GraphLinUCB(laplacian=utils.random_walk_laplacian(Adj)),
    "LinUcbITL": policy.LinUcbItl(),
    "LinUcbOracle": policy.LinUcbClusterOracle(cluster_inds=cluster_inds, l_reg=co.eps),
    # CLUB / SCLUB take the networkx graph G rather than the dense Adj.
    "CLUB": policy.CLUB(G=G),
    "SCLUB": policy.SCLUB(G=G, horizon=co.horizon),
    "GOBLin": policy.GOBLin(Adj),
}


#% bandit
# Multi-task contextual bandit with one parameter row of Theta per user,
# co.n_arms arms per round and contexts produced by context_sampler.
# `std= co.sigma` is presumably the reward-noise standard deviation.
# NOTE(review): the meaning of context_generator=2 is defined in bandit.py
# (not visible here) — TODO document it.
bandit = MultiTaskContextualBandit(Theta, n_arms= co.n_arms, context_sampler= context_sampler, 
                                    std= co.sigma, context_generator= 2)
# Disabled alternative: a pre-drawn user sequence (the experiment call passes users= None instead).
# users =co.random.integers(low= 0, high= bandit.n_users, size= co.horizon) # sequence of users


# Derive independent per-repetition seeds from the SeedSequence behind the
# shared generator `co.random`. Prefer the public `seed_seq` property (newer
# NumPy) and fall back to the private attribute on versions that lack it;
# both refer to the same SeedSequence object.
_bit_generator = co.random.bit_generator
ss = getattr(_bit_generator, "seed_seq", None)
if ss is None:
    ss = _bit_generator._seed_seq
parallelizer = Parallel(n_jobs=-1, verbose=10)
spawns = ss.spawn(co.repetitions)

# One job per (agent, seed) pair, agent-major / seed-minor. joblib returns the
# results in submission order, which the reshape below relies on. Under
# joblib's default process-based backend each job works on its own copy of
# `bandit` and `agent`.
res = parallelizer(
    delayed(bandit_multitask_experiment)(bandit, agent, co.horizon, seed, users=None)
    for agent in agent_dict.values()
    for seed in spawns
)

# Assumes each experiment returns two length-horizon series (presumably agent
# and oracle rewards) — TODO confirm against experiment.py.
res = np.array(res).reshape(len(agent_dict), co.repetitions, 2, co.horizon)
#%
from pathlib import Path

# Persist agent names plus the raw result array for later postprocessing.
# NOTE: a timestamp (datetime.now().strftime(...)) was an earlier alternative
# name; avoid ':' in it if revived, since it is invalid in Windows filenames.
filename = co.experiment_name
res_dict = {"agents_names": list(agent_dict.keys()),
            "results": res}
# Create the output directory up front: a missing ./results would otherwise
# raise only after the full (expensive) parallel run, losing its results.
results_dir = Path("./results")
results_dir.mkdir(parents=True, exist_ok=True)
with open(results_dir / f"{filename}.pkl", 'wb') as file:
    pickle.dump(res_dict, file)
# %
# %%
# Split the saved results back into agent rewards, oracle rewards and names.
rewards, rewards_oracle, agent_names = utils.postprocess(res_dict)
# Instantaneous regret: how much reward each agent lost versus the oracle.
regret_raw = rewards_oracle - rewards
# Running total over axis 2 (the time axis) gives cumulative regret.
regret_cumul = np.cumsum(regret_raw, axis=2)
# Map each agent name to its cumulative-regret curves.
regret_cumul_dict = {name: curves for name, curves in zip(agent_names, regret_cumul)}

#%
# Plot mean cumulative regret per agent with a +/- one-std band; the mean
# and std are taken over axis 0 of each agent's curves (presumably the
# repetitions — confirm against utils.postprocess).
plt.figure()
for agent_name, regret in regret_cumul_dict.items():
    # Variants whose name ends in "noUCB" are left out of the figure.
    if agent_name.endswith("noUCB"):
        continue
    reg_mean = np.mean(regret, axis=0)
    reg_std = np.std(regret, axis=0)
    plt.plot(reg_mean, label=agent_name)
    # NOTE(review): alpha=0.01 makes the band almost invisible — confirm intended.
    plt.fill_between(
        np.arange(len(reg_mean)),
        reg_mean + reg_std,
        reg_mean - reg_std,
        alpha=0.01,
    )
plt.legend()
plt.savefig("regret.pdf", format="pdf")
# %%
