
#%%
import numpy as np
from matplotlib import pyplot as plt
from numpy.random import default_rng

import networkx as nx
import pygsp
import network_lasso as nl
from joblib import delayed, Parallel
import pickle
from datetime import datetime
# Close every open matplotlib figure before the script continues
# (presumably to clear figures left over from an earlier interactive run).
plt.close('all')

import policy
from bandit import MultiTaskContextualBandit
from experiment import bandit_multitask_experiment
import utils

# %
# ---- Global experiment settings -------------------------------------------

# Name of the floating-point dtype used by the experiment.
dtype = 'float32'

# Arm counts (presumably: size of the full arm pool vs. number of arms
# actually made available -- TODO confirm against the bandit code).
n_arms_all = 500
n_arms = 50

# Numerical constants (sigma is presumably the reward-noise level and eps a
# small tolerance -- TODO confirm where they are consumed).
sigma = 0.01
eps = 1e-9

# Seeded generator so that runs are reproducible.
random = default_rng(seed=0)

# Number of repetitions (presumably independent runs per configuration).
repetitions = 10



# Experiment configurations: a prototype setup, then one per figure panel (a)-(d)

# Prototype configuration (judging by its name and its small "dim", presumably
# a quick setup for trying things out -- TODO confirm).
# NOTE(review): "p"/"q" look like intra-/inter-cluster edge probabilities of a
# clustered user graph -- confirm against the code that consumes this dict.
exp_proto = {
    "horizon": 2000,
    "n_users": 50,
    "dim": 2,
    "n_clusters": 4,
    "imbalance": 1.0,
    "p": 0.8,
    "q": 0.2,
}

# Configuration named after figure panel (a).
# NOTE(review): "p"/"q" look like intra-/inter-cluster edge probabilities of a
# clustered user graph -- confirm against the code that consumes this dict.
fig_a = {
    "horizon": 3000,
    "n_users": 100,
    "dim": 20,
    "n_clusters": 8,
    "imbalance": 1.0,
    "p": 0.4,
    "q": 0.1,
}

# Configuration named after figure panel (b).
# NOTE(review): "p"/"q" look like intra-/inter-cluster edge probabilities of a
# clustered user graph -- confirm against the code that consumes this dict.
fig_b = {
    "horizon": 1000,
    "n_users": 100,
    "dim": 10,
    "n_clusters": 8,
    "imbalance": 1.0,
    "p": 0.5,
    "q": 0.1,
}

# Configuration named after figure panel (c).
# NOTE(review): "p"/"q" look like intra-/inter-cluster edge probabilities of a
# clustered user graph -- confirm against the code that consumes this dict.
fig_c = {
    "horizon": 3000,
    "n_users": 50,
    "dim": 80,
    "n_clusters": 5,
    "imbalance": 1.0,
    "p": 0.8,
    "q": 0.2,
}

# Configuration named after figure panel (d).
# NOTE(review): "p"/"q" look like intra-/inter-cluster edge probabilities of a
# clustered user graph -- confirm against the code that consumes this dict.
fig_d = {
    "horizon": 5000,
    "n_users": 200,
    "dim": 20,
    "n_clusters": 10,
    "imbalance": 1.0,
    "p": 0.5,
    "q": 0.05,
}

# experiment_name = f"u{n_users}d{dim}h{horizon}c{n_clusters}i{imbalance}p{p}q{q}"
# %%
