#%%
import os
import pickle
from datetime import datetime

import networkx as nx
import numpy as np
import pygsp
from joblib import delayed, Parallel
from matplotlib import pyplot as plt
from numpy.random import default_rng

import network_lasso as nl

# Close any figures that are still open (e.g. from a previous run of this
# cell-based script).
plt.close('all')

import policy
from bandit import MultiTaskContextualBandit
from experiment import bandit_multitask_experiment
import utils
import config as co

# Run one complete experiment per configuration dictionary. Only co.exp_proto
# is active; co.fig_b, co.fig_c and co.fig_d can be appended to the list.
for exp_dict in [co.exp_proto]:

    # Cluster assignment of the users, with randomly drawn cluster weights.
    # cluster_inds is used below both as the block assignment of the graph
    # and as a row index into Theta_per_cluster.
    cluster_weights = co.random.uniform(0, 1, size=exp_dict["n_clusters"])
    cluster_inds = utils.create_cluster_indices(
        n_points=exp_dict["n_users"],
        n_clusters=exp_dict["n_clusters"],
        weights=cluster_weights,
        imbalance=exp_dict["imbalance"],
    )

    # Stochastic block model over the users (fixed seed 0), converted to a
    # dense adjacency matrix and to a networkx graph for the agents.
    sbm = pygsp.graphs.StochasticBlockModel(
        N=exp_dict["n_users"],
        k=exp_dict["n_clusters"],
        p=exp_dict["p"],
        q=exp_dict["q"],
        z=cluster_inds,
        connected=True,
        seed=0,
    )
    Adj = sbm.W.toarray()
    G = nx.from_numpy_array(Adj)

    # Visual check of the adjacency structure.
    plt.figure()
    plt.imshow(Adj)

    # One unit-norm parameter vector per cluster; every user gets the vector
    # of its cluster, so Theta has one row per entry of cluster_inds.
    Theta_per_cluster = co.random.standard_normal(
        (exp_dict["n_clusters"], exp_dict["dim"])
    )
    cluster_norms = np.linalg.norm(Theta_per_cluster, axis=1, keepdims=True)
    Theta_per_cluster = Theta_per_cluster / cluster_norms
    Theta = Theta_per_cluster[cluster_inds]


    # if exp_dict["dim"] == 2:
    #     plt.figure()
    #     plt.scatter(*Theta.T)
    #     plt.xlim(np.min(Theta[:,0]), np.max(Theta[:,0]))
    #     plt.ylim(np.min(Theta[:,1]), np.max(Theta[:,1]))


    #%% generating context vectors
    # Pool of co.n_arms_all context vectors (one per row), scaled to unit
    # Euclidean norm. Nothing in the visible script reads this pool (the
    # active context_sampler draws its own contexts), but the draw still
    # advances co.random.
    X = co.random.standard_normal(size=(co.n_arms_all, exp_dict["dim"]))
    X = X / np.linalg.norm(X, axis=1, keepdims=True)

    def context_sampler(t, dim, n_arms, rng):
        """Draw a fresh set of Gaussian context vectors with unit-norm rows.

        The arguments ``t``, ``dim`` and ``n_arms`` are accepted but not
        used: the output size is taken from ``co.n_arms`` and
        ``exp_dict["dim"]``, both read from the enclosing scope at call time.

        Parameters
        ----------
        t, dim, n_arms
            Ignored.
        rng
            Random generator exposing ``standard_normal(size=...)``
            (presumably a ``numpy.random.Generator`` -- confirm with caller).

        Returns
        -------
        Array with ``co.n_arms`` rows and ``exp_dict["dim"]`` columns, each
        row divided by its Euclidean norm.
        """
        # Disabled earlier variants returned rows of the fixed pool X instead,
        # e.g. X[:50] or X[inds_to_choose[t]].
        contexts = rng.standard_normal(size=(co.n_arms, exp_dict["dim"]))
        row_norms = np.linalg.norm(contexts, axis=1, keepdims=True)
        return contexts / row_norms

    #%%
    # Not referenced by any active line in the visible script; the disabled
    # NetworkLasso agents below hard-code their own l_net values.
    l_net = 1.0

    # Agents to compare, keyed by the label that is stored with the results.
    # Insertion order matters: the results array is laid out in the order of
    # agent_dict.values(). Disabled agents are kept as comments so they can
    # be switched back on.
    agent_dict = {}
    # agent_dict["NL_alpha1"] = policy.NetworkLassoAgent(
    #     adjacency=Adj,  # /Adj.sum(axis=1, keepdims=True),
    #     l_net=1.0, K_norm="F_bound", weighting="degree")
    # agent_dict["NL_alpha01"] = policy.NetworkLassoAgent(
    #     adjacency=Adj,  # /Adj.sum(axis=1, keepdims=True),
    #     l_net=0.1, K_norm="F_bound", weighting="degree")
    # agent_dict["GraphUCB"] = policy.GraphLinUCB(laplacian=utils.random_walk_laplacian(Adj))
    # agent_dict["LinUcbITL"] = policy.LinUcbItl()
    # agent_dict["LinUcbOracle_corrected"] = policy.LinUcbClusterOracle(cluster_inds=cluster_inds)
    agent_dict["CLUB"] = policy.CLUB(G=G)
    # agent_dict["SCLUB"] = policy.SCLUB(G=G, horizon=exp_dict["horizon"])
    # agent_dict["GOBLin"] = policy.GOBLin(Adj)
    # agent_dict["OLS"] = policy.LinUcbItl(use_ucb=False, l_reg=co.eps)
    # agent_dict["OLS-Oracle"] = policy.LinUcbClusterOracle(
    #     cluster_inds=cluster_inds, l_reg=co.eps, use_ucb=False)
    # agent_dict["SA-LASSO"] = policy.SALasso_ITL()
    # agent_dict["Trace-Norm"] = policy.TraceNormAgent(l_nuc=0.01)
    agent_dict["LOCB"] = policy.LOCB()

    # Appended to the results file name to tell runs apart.
    experiment_suffix = "_LOCB"


    #% bandit
    # Environment shared by all agents and repetitions: one parameter vector
    # per user (rows of Theta), contexts supplied by context_sampler.
    bandit = MultiTaskContextualBandit(Theta, n_arms=co.n_arms,
                                       context_sampler=context_sampler,
                                       std=co.sigma, context_generator=2)
    # users =co.random.integers(low= 0, high= bandit.n_users, size= exp_dict["horizon"]) # sequence of users

    # Derive one child seed per repetition from the seed sequence behind
    # co.random. Prefer the public `seed_seq` property (NumPy >= 1.25) and
    # fall back to the private attribute on older NumPy versions.
    bit_gen = co.random.bit_generator
    ss = getattr(bit_gen, "seed_seq", None)
    if ss is None:
        ss = bit_gen._seed_seq
    parallelizer = Parallel(n_jobs=1, verbose=10)
    spawns = ss.spawn(co.repetitions)

    # One job per (agent, repetition seed) pair. Agents form the outer loop,
    # so the flat result list is ordered agent-major and can be reshaped to
    # (agent, repetition, ...).
    # NOTE(review): the same agent instance is passed to every repetition --
    # confirm that bandit_multitask_experiment resets or copies it, otherwise
    # state may carry over between repetitions.
    jobs = [
        delayed(bandit_multitask_experiment)(
            bandit, agent, exp_dict["horizon"], seed,
            users=None, progress_bar=False,
        )
        for agent in agent_dict.values()
        for seed in spawns
    ]
    res = parallelizer(jobs)

    # Each run is expected to yield 2 * horizon values; the meaning of the
    # two rows is defined by bandit_multitask_experiment.
    result_shape = (len(agent_dict), co.repetitions, 2, exp_dict["horizon"])
    res = np.array(res).reshape(result_shape)
    #%
    # Bundle the agent labels with the results array; the labels are listed
    # in agent_dict order.
    res_dict = {"agents_names": list(agent_dict.keys()),
                "results": res}

    # The file name encodes the experiment configuration. The subscripts use
    # single quotes because reusing the enclosing double quote inside an
    # f-string is a SyntaxError before Python 3.12.
    experiment_name = (
        f"u{exp_dict['n_users']}d{exp_dict['dim']}"
        f"h{exp_dict['horizon']}c{exp_dict['n_clusters']}"
        f"i{exp_dict['imbalance']}p{exp_dict['p']}q{exp_dict['q']}"
    )

    # Make sure the output directory exists so a finished run is not lost to
    # a FileNotFoundError at the very last step.
    os.makedirs("./results", exist_ok=True)
    with open(f"./results/{experiment_name + experiment_suffix}.pkl", 'wb') as file:
        pickle.dump(res_dict, file)
    # %
    # %%