import numpy as np
from sklearn.linear_model import Lasso
import argparse,os
from lasso_bandit import LassoBandit
from rdl_etc import RDLBandit
from HOPE import HOPE
from gen_data import generate_diag_matrix, generate_data


def generate_sparse_vectors(rhos, p, seed=0, norm_value=1):
    """Draw one random sparse vector per sparsity level, rescaled to a fixed L2 norm.

    For each ``rho`` in ``rhos`` a length-``p`` vector is built with
    ``floor(rho * p)`` non-zero entries at positions sampled uniformly without
    replacement. Those entries are filled with standard-normal draws, and the
    vector is rescaled so its L2 norm equals ``norm_value``. An all-zero vector
    (``rho * p < 1``) is returned unscaled.

    NOTE: this reseeds NumPy's *global* RNG with ``seed`` (``np.random.seed``)
    so that results are reproducible and identical to earlier runs; any code
    that relies on the global RNG afterwards is affected.

    Parameters
    ----------
    rhos : iterable of float
        Fraction of non-zero entries for each vector; each must lie in [0, 1].
    p : int
        Length of every vector.
    seed : int, default 0
        Seed passed to ``np.random.seed`` before any draws.
    norm_value : float, default 1
        Target L2 norm of each non-zero vector.

    Returns
    -------
    numpy.ndarray of shape (len(rhos), p)
        One vector per row, in the order of ``rhos``.

    Raises
    ------
    ValueError
        If any ``rho`` lies outside [0, 1] (or is NaN).
    """
    rhos = list(rhos)
    for rho in rhos:
        if not 0 <= rho <= 1:
            raise ValueError(f"each rho must be in [0, 1], got {rho!r}")

    np.random.seed(seed)
    # Pre-allocating gives shape (0, p) for empty input instead of (0,).
    vectors = np.zeros((len(rhos), p))
    for row, rho in zip(vectors, rhos):
        # Round away floating-point noise before flooring: a plain int(rho * p)
        # turns e.g. 0.29 * 100 == 28.999999999999996 into 28 instead of 29.
        num_non_zero = int(np.floor(np.round(rho * p, 9)))
        # Draw order (choice, then normal) matches the original so seeded
        # outputs stay reproducible.
        non_zero_indices = np.random.choice(p, num_non_zero, replace=False)
        row[non_zero_indices] = np.random.normal(loc=0.0, scale=1.0, size=num_non_zero)
        current_norm = np.linalg.norm(row, ord=2)
        if current_norm > 0:
            row *= norm_value / current_norm
    return vectors



def compute_regret(y_true, select_arm_list):
    """Return the cumulative regret of a sequence of arm choices.

    Regret in round ``t`` is the best reward available in that round minus the
    reward of the arm actually chosen. Only the rounds for which a choice is
    given are scored, so a choice list shorter than the horizon is allowed.

    Parameters
    ----------
    y_true : array_like, shape (n_arms, T)
        Reward of each arm (rows) in each round (columns).
    select_arm_list : sequence of int
        ``select_arm_list[t]`` is the arm pulled in round ``t``.

    Returns
    -------
    numpy.ndarray of shape (len(select_arm_list),)
        Running sum of per-round regret.

    Raises
    ------
    IndexError
        If an arm index is out of range or more choices than rounds are given.
    """
    y_true = np.asarray(y_true)
    arms = np.asarray(select_arm_list)
    n_rounds = len(arms)
    if n_rounds == 0:
        # An empty list becomes a float array, which NumPy refuses as an index.
        return np.zeros(0, dtype=y_true.dtype)
    # Restrict the per-round optimum to the rounds actually played so the two
    # operands have matching lengths.
    best_reward = np.max(y_true[:, :n_rounds], axis=0)
    obtained_reward = y_true[arms, np.arange(n_rounds)]
    return np.cumsum(best_reward - obtained_reward)

def _parse_args(argv=None):
    """Parse and validate command-line options for the mixed-sparsity experiment.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Parsed options. Exits via ``parser.error`` if ``-num_sparse`` is not in
        ``[0, n]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", type=int, default=5, help="number of arms")
    parser.add_argument("-T", type=int, default=500, help="number of rounds")
    parser.add_argument("-T0", type=int, default=125, help="number of T0 rounds")
    parser.add_argument("-p", type=int, default=200, help="dimension of arms")
    parser.add_argument("-num_sparse", type=int, default=3, help="number of sparse arms")
    parser.add_argument("-scase", type=int, default=1, help="sparse case")
    parser.add_argument("-nscase", type=int, default=2, help="nonsparse case")
    parser.add_argument('-seed', type=int, default=42, help="random seed")
    parser.add_argument("-rho_sparse", type=float, default=0.1,
                        help="fraction of non-zero coefficients for sparse arms")
    parser.add_argument("-rho_nonsparse", type=float, default=0.9,
                        help="fraction of non-zero coefficients for non-sparse arms")
    parser.add_argument("-data_root", type=str, default="/data",
                        help="directory under which the dataset folder is created")
    args = parser.parse_args(argv)
    if not 0 <= args.num_sparse <= args.n:
        parser.error(f"-num_sparse must be between 0 and -n ({args.n}), got {args.num_sparse}")
    return args


def main(argv=None):
    """Generate a mixed sparse/non-sparse dataset, save it, and run the bandits.

    The first ``num_sparse`` arms get coefficient vectors with sparsity
    ``rho_sparse``; the remaining arms use ``rho_nonsparse``. The generated data
    is written to ``<data_root>/mixed_n{n}_nonsparsecase{nscase}_sparsecase{scase}/``.
    LassoBandit, RDLBandit and HOPE are then each run (explore, then commit).

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns
    -------
    dict
        Maps ``"lasso"``, ``"rdl"`` and ``"hope"`` to their cumulative-regret arrays.
    """
    args = _parse_args(argv)
    num_sparse = args.num_sparse
    num_nonsparse = args.n - num_sparse

    rhos = np.full(args.n, args.rho_sparse)
    rhos[num_sparse:] = args.rho_nonsparse

    # Call order kept as in the original (Sigma2 first) in case
    # generate_diag_matrix consumes the global RNG.
    Sigma2 = generate_diag_matrix(args.p, args.T, args.n, seed=args.seed, case=args.nscase)[0]
    Sigma1 = generate_diag_matrix(args.p, args.T, args.n, seed=args.seed, case=args.scase)[0]

    betas = generate_sparse_vectors(rhos, args.p, args.seed, 1)
    # NOTE(review): both groups are generated with the same seed; confirm that
    # sharing the random stream between them is intended.
    Gen_data1, y_true1 = generate_data(betas[:num_sparse], Sigma1, args.p, args.T, num_sparse, args.seed, 0.01)
    Gen_data2, y_true2 = generate_data(betas[num_sparse:], Sigma2, args.p, args.T, num_nonsparse, args.seed, 0.01)

    # NOTE(review): assumes generate_data returns Python lists (one entry per
    # arm), so `+` concatenates sparse and non-sparse arms; with ndarrays this
    # would add elementwise instead — confirm.
    Data = Gen_data1 + Gen_data2
    ytrue = np.array(y_true1 + y_true2)

    data_fold = os.path.join(
        args.data_root,
        f"mixed_n{args.n}_nonsparsecase{args.nscase}_sparsecase{args.scase}",
    )
    os.makedirs(data_fold, exist_ok=True)
    np.savez(
        os.path.join(data_fold, f"nsparse_{num_sparse}_seed{args.seed}.npz"),
        betas=betas, Data=Data, Y_true=ytrue, Sigmas=[Sigma1, Sigma2],
    )

    # Run each algorithm in the original order: construct, explore, commit.
    regrets = {}
    for name, bandit_cls in (("lasso", LassoBandit), ("rdl", RDLBandit), ("hope", HOPE)):
        bandit = bandit_cls(n_arms=args.n, n_features=args.p, data=Data, T0=args.T0, T=args.T)
        bandit.explore()
        bandit.commit()
        regrets[name] = compute_regret(ytrue, bandit.arm_sample_idxs)

    for name, regret in regrets.items():
        final = regret[-1] if len(regret) else 0.0
        print(f"{name} final cumulative regret: {final}")
    return regrets


if __name__ == "__main__":
    main()

