import numpy as np
from sklearn.linear_model import Lasso
import argparse,os
from gen_data import gen_data
from pwe_estimate import lasso_prediction_result

def compute_regret(y_true, select_arm_list):
    """Return the cumulative regret of a sequence of arm choices.

    Args:
        y_true: 2-D array indexed as ``[arm, round]``; it must have at least
            ``len(select_arm_list)`` columns' worth of rounds that line up
            with the selections.
        select_arm_list: Arm index chosen at each round, in round order.

    Returns:
        1-D array whose ``t``-th entry is the running sum, up to and
        including round ``t``, of (best value over arms - value of the
        chosen arm).
    """
    round_ids = np.arange(len(select_arm_list))
    # Best achievable value per round, taken across the arm axis.
    best_per_round = np.max(y_true, axis=0)
    # Pair each round with the arm picked in that round.
    chosen_per_round = y_true[select_arm_list, round_ids]
    instant_regret = best_per_round - chosen_per_round
    return np.cumsum(instant_regret)
class LassoBandit:
    """Two-phase bandit: round-robin exploration, then a one-shot commit.

    ``explore`` collects the first ``T0`` rounds by cycling through the arms;
    ``commit`` scores every remaining round for every arm with
    ``lasso_prediction_result`` and picks the arm with the highest score.

    Attributes set by the constructor:
        n_arms, n_features, data, T0, T: stored as given.
        reward: initialised to 0.
        history: one list per arm of the rows sampled during exploration
            (converted to arrays at the end of ``explore``).
        results: one entry per arm, filled by ``commit`` with that arm's
            flattened scores.
        arm_sample_idxs: arm index chosen at each round so far.
    """

    def __init__(self, n_arms, n_features, data, T0, T):
        # `data` is indexed as data[arm][round]; each row's last column is
        # dropped before prediction in `commit` (presumably the observed
        # reward -- confirm against gen_data).
        self.data = data
        self.n_arms = n_arms
        self.n_features = n_features
        self.T0 = T0
        self.T = T
        self.reward = 0
        self.arm_sample_idxs = []
        self.history = [[] for _ in range(n_arms)]
        self.results = [[] for _ in range(n_arms)]

    def explore(self):
        """Sample arms in round-robin order for the first ``T0`` rounds.

        Round ``t`` pulls arm ``t % n_arms`` and stores that arm's row for
        round ``t``. Afterwards every per-arm history is turned into an
        array, so this method is meant to be called once.
        """
        for round_idx in range(self.T0):
            arm = round_idx % self.n_arms
            observed_row = self.data[arm][round_idx]
            self.history[arm].append(observed_row)
            self.arm_sample_idxs.append(arm)
        self.history = [np.array(rows) for rows in self.history]

    def commit(self):
        """Choose an arm for every round from ``T0`` onwards.

        For each arm, the rows after the exploration phase (minus their last
        column) are scored by ``lasso_prediction_result`` using that arm's
        exploration history; the per-round argmax over arms is appended to
        ``arm_sample_idxs``.
        """
        for arm in range(self.n_arms):
            remaining_rows = self.data[arm][self.T0:]
            features = remaining_rows[:, :-1]
            scores = lasso_prediction_result(self.history[arm], features)
            self.results[arm] = scores.reshape(-1)
        score_matrix = np.array(self.results)
        best_arms = np.argmax(score_matrix, axis=0).tolist()
        # Rebind to a fresh list rather than extending in place.
        self.arm_sample_idxs = self.arm_sample_idxs + best_arms

if __name__ == "__main__":
    # Command-line driver: load (or generate) a dataset, run the
    # explore-then-commit Lasso bandit on it, and save the cumulative
    # regret curve as a PNG in the current working directory.
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", type=int, default=5, help="number of arms")
    parser.add_argument("-T", type=int, default=500, help="number of rounds")
    parser.add_argument("-T0", type=int, default=125, help="number of T0 rounds")
    parser.add_argument("-p", type=int, default=200, help="dimension of arms")
    parser.add_argument("-rho", type=float, default=0.1, help="sparsity ratio")
    parser.add_argument("-case", type=int, default=1, help="sigma setting")
    parser.add_argument('-seed', type=int, default=0, help="random seed")

    args = parser.parse_args()
    data_path = f"/data/p_{args.p}_T_{args.T}_n_{args.n}_rho_{args.rho}_seed{args.seed}_case_{args.case}.npz"

    # Generate the dataset on first use.
    # NOTE(review): this assumes gen_data writes to exactly `data_path` -- confirm.
    if not os.path.exists(data_path):
        gen_data(args.p, args.T, args.n, args.rho, args.seed, args.case)

    # np.load on an .npz returns an NpzFile holding an open file handle;
    # the context manager closes it once both arrays have been read.
    with np.load(data_path) as data:
        Data = data['Data']
        ytrue = data['Y_true']

    bandit = LassoBandit(n_arms=args.n, n_features=args.p, data=Data, T0=args.T0, T=args.T)

    bandit.explore()
    bandit.commit()
    select_arm_list = bandit.arm_sample_idxs

    cumulative_regret = compute_regret(ytrue, select_arm_list)

    import matplotlib.pyplot as plt

    plt.plot(cumulative_regret)

    plt.savefig(f'lasso_p{args.p}_T{args.T}_n{args.n}_rho{args.rho}_seed{args.seed}_case{args.case}.png', bbox_inches='tight')
