import numpy as np
from tqdm import tqdm

# Module-level NumPy generator created without an explicit seed. The experiment
# functions in this file do not read it: each one builds its own generator from
# its `seed` argument.
# NOTE(review): this name is the same as the stdlib `random` module; confirm
# that no code sharing this namespace expects the stdlib module under this name.
random = np.random.default_rng()

def bandit_multitask_experiment(bandit, player, horizon, seed, users=None, progress_bar=False):
    """Simulate `player` on `bandit` for up to `horizon` rounds.

    Each round draws the next context matrix from `bandit.context_generator`,
    lets the player choose an arm for the current user, feeds the noisy reward
    back to the player, and records the noiseless reward of the chosen arm
    together with the best noiseless reward available in that round.

    Args:
        bandit: Environment object. Must provide `context_generator`
            (an iterator, or None to have one created here),
            `create_context_generator(horizon, rng)`,
            `create_noise_table(horizon, rng)`, `noise_table`, `Theta`
            (indexable by user) and, when `users` is None, `n_users`.
        player: Policy object providing `initialize(bandit)`,
            `play_arm(cxt_mat, u, t)` and `update(x_pull, y_pull, u, t)`.
        horizon: Maximum number of rounds; rounds are numbered from 1.
        seed: Seed for the local `np.random.default_rng` used for contexts,
            noise and (if needed) the user sequence.
        users: Optional iterable of user indices, one per round. When None,
            users are drawn uniformly from `range(bandit.n_users)`. The loop
            stops at the shorter of `horizon` and `users`.
        progress_bar: If True, wrap the round loop in a tqdm progress bar.

    Returns:
        A pair of arrays `(expected_reward, oracle_expected_reward)` with one
        entry per round played.

    Side effects:
        `bandit.context_generator` is set to None once the loop ends, also
        when the loop is aborted by an exception, so a partially consumed
        generator is never reused by a later run.
    """
    # TODO: also separate regret by user
    rng = np.random.default_rng(seed)
    if bandit.context_generator is None:
        bandit.create_context_generator(horizon, rng)

    bandit.create_noise_table(horizon, rng)
    expected_reward = []  # noiseless reward of the arm the player selected
    oracle_expected_reward = []  # best noiseless reward in the same round

    player.initialize(bandit)
    if users is None:
        users = rng.integers(low=0, high=bandit.n_users, size=horizon)

    horizon_user_range = zip(range(1, horizon + 1), users)
    if progress_bar:
        # zip() has no len(), so tqdm needs the round count passed explicitly
        # to be able to draw a bar and an ETA.
        horizon_user_range = tqdm(horizon_user_range, total=horizon)

    try:
        for t, u in horizon_user_range:
            cxt_mat = next(bandit.context_generator)  # contexts for this round

            # Noiseless reward of every arm for user u.
            rewards_true = cxt_mat @ bandit.Theta[u]

            arm_pull = player.play_arm(cxt_mat, u, t)

            x_pull = cxt_mat[arm_pull].T
            # Rounds are 1-based while the noise table is 0-based.
            y_pull = rewards_true[arm_pull] + bandit.noise_table[t - 1]

            # Let the player update its estimate with the noisy observation.
            player.update(x_pull, y_pull, u, t)

            # Store history for plotting
            expected_reward.append(rewards_true[arm_pull])
            oracle_expected_reward.append(np.max(rewards_true))
    finally:
        # Force the next experiment to build a fresh context generator.
        bandit.context_generator = None
    return np.array(expected_reward), np.array(oracle_expected_reward)


def bandit_multitask_experiment_debug(bandit, player, horizon, seed, users=None):
    """Verbose variant of the bandit experiment loop, printing per-round diagnostics.

    Runs `player` against `bandit` for up to `horizon` rounds. Every round it
    prints the arm the player selected and the difference between two ways of
    computing that arm's noiseless reward (`rewards_true[arm_pull]` versus
    `bandit.Theta[u] @ x_pull`), which serves as a consistency check.

    Args:
        bandit: Environment object. Must provide `context_generator`
            (an iterator, or None to have one created here),
            `create_context_generator(horizon, rng)`,
            `create_noise_table(horizon, rng)`, `noise_table`, `Theta`
            (indexable by user) and, when `users` is None, `n_users`.
        player: Policy object providing `initialize(bandit)`,
            `play_arm(cxt_mat, u, t)` and `update(x_pull, y_pull, u, t)`.
        horizon: Maximum number of rounds; rounds are numbered from 1.
        seed: Seed for the local `np.random.default_rng` used for contexts,
            noise and (if needed) the user sequence.
        users: Optional iterable of user indices, one per round. When None,
            users are drawn uniformly from `range(bandit.n_users)`. The loop
            stops at the shorter of `horizon` and `users`.

    Returns:
        A pair of arrays `(expected_reward, oracle_expected_reward)` with one
        entry per round played. Here the first array is computed as
        `bandit.Theta[u] @ x_pull`.

    Side effects:
        Prints two lines per round. `bandit.context_generator` is set to None
        once the loop ends, also when the loop is aborted by an exception, so
        a partially consumed generator is never reused by a later run.
    """
    # TODO: also separate regret by user
    rng = np.random.default_rng(seed)
    if bandit.context_generator is None:
        bandit.create_context_generator(horizon, rng)

    bandit.create_noise_table(horizon, rng)
    expected_reward = []  # noiseless reward of the arm the player selected
    oracle_expected_reward = []  # best noiseless reward in the same round

    player.initialize(bandit)
    if users is None:
        users = rng.integers(low=0, high=bandit.n_users, size=horizon)

    try:
        for t, u in zip(range(1, horizon + 1), users):
            cxt_mat = next(bandit.context_generator)  # contexts for this round

            # Noiseless reward of every arm for user u.
            rewards_true = cxt_mat @ bandit.Theta[u]

            arm_pull = player.play_arm(cxt_mat, u, t)
            print(arm_pull)

            x_pull = cxt_mat[arm_pull].T
            # Rounds are 1-based while the noise table is 0-based.
            y_pull = rewards_true[arm_pull] + bandit.noise_table[t - 1]

            # Let the player update its estimate with the noisy observation.
            player.update(x_pull, y_pull, u, t)

            # Consistency check: both expressions describe the chosen arm's
            # noiseless reward, so the printed gap should be negligible.
            print(rewards_true[arm_pull] - bandit.Theta[u] @ x_pull)
            expected_reward.append(bandit.Theta[u] @ x_pull)
            oracle_expected_reward.append(np.max(rewards_true))
    finally:
        # Force the next experiment to build a fresh context generator.
        bandit.context_generator = None
    return np.array(expected_reward), np.array(oracle_expected_reward)