import numpy as np
from sklearn.linear_model import Lasso
import argparse, os
from gen_data import gen_data
import cvxpy as cp
from sklearn.linear_model import LassoCV
# from pwe_estimate import lasso_prediction_result
import tqdm, pickle
import warnings

# Silence every warning category process-wide for the rest of the run. Because this
# line executes after the imports above, warnings emitted while importing them are
# not affected.
# NOTE(review): presumably meant to hide noisy library warnings; a narrower filter
# (specific category/module) would avoid masking unrelated problems — confirm intent.
warnings.filterwarnings('ignore')


def compute_regret(y_true, select_arm_list):
    """Return the cumulative regret of a sequence of arm choices.

    Parameters
    ----------
    y_true : array-like
        2-D array indexed as ``y_true[arm, round]`` giving the reward of every
        arm at every round.
    select_arm_list : sequence of int
        Arm chosen at rounds ``0, 1, ..., k - 1``.

    Returns
    -------
    numpy.ndarray
        Length-``k`` array whose entry ``t`` is the summed gap between the best
        arm's reward and the chosen arm's reward over rounds ``0..t``.
        An empty choice sequence yields an empty array.
    """
    y_true = np.asarray(y_true)
    arms = np.asarray(select_arm_list, dtype=int)
    n_rounds = len(arms)
    # Compare only against the rounds actually played: taking the max over all
    # of y_true's columns fails to broadcast when fewer choices than columns
    # are given (including the empty case).
    best = np.max(y_true[:, :n_rounds], axis=0)
    chosen = y_true[arms, np.arange(n_rounds)]
    return np.cumsum(best - chosen)


class LinUCBBandit:
    def __init__(self, n_arms, n_features, data, T, alpha=1.0):
        self.n_arms = n_arms
        self.n_features = n_features
        self.data = data
        self.T = T
        self.alpha = alpha

        self.reward = 0
        self.history = [[] for _ in range(n_arms)]
        self.results = [[] for _ in range(n_arms)]
        self.arm_sample_idxs = []

        self.A = [np.identity(n_features) for _ in range(n_arms)]
        self.b = [np.zeros(n_features) for _ in range(n_arms)]
        self.theta = [np.zeros(n_features) for _ in range(n_arms)]

        self.total_reward = 0
        self.cumulative_rewards = []
        self.arm_counts = np.zeros(n_arms)
        self.regret = []

    def run_experiment(self):
        for t in range(0, self.T):
            data_t = [entry[t, :] for entry in self.data]
            X_t = [entry[t, :-1] for entry in self.data]
            Y_t = [entry[t, -1] for entry in self.data]
            ucb_scores = np.zeros(self.n_arms)

            for a in range(self.n_arms):
                x = X_t[a]
                A_inv = np.linalg.inv(self.A[a])
                self.theta[a] = A_inv.dot(self.b[a])

                estimate = self.theta[a].dot(x)
                confidence = self.alpha * np.sqrt(x.T.dot(A_inv).dot(x))
                ucb_scores[a] = estimate + confidence

            chosen_arm = np.argmax(ucb_scores)
            self.arm_sample_idxs.append(chosen_arm)

            reward = Y_t[chosen_arm]
            context = X_t[chosen_arm]

            self.A[chosen_arm] += np.outer(context, context)
            self.b[chosen_arm] += reward * context

        return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", type=int, default=5, help="number of arms")
    parser.add_argument("-T", type=int, default=500, help="number of rounds")
    # NOTE(review): -T0 is parsed but not referenced anywhere below.
    parser.add_argument("-T0", type=int, default=125, help="number of T0 rounds")
    parser.add_argument("-p", type=int, default=200, help="dimension of arms")
    parser.add_argument("-rho", type=float, default=0.1, help="sparsity ratio")
    parser.add_argument("-case", type=int, default=2, help="sigma setting")
    parser.add_argument('-seed', type=int, default=0, help="random seed")
    # The paths and repetition count used to be hard-coded; the defaults keep
    # the previous values.
    parser.add_argument("-data_path", type=str, default="Your data path",
                        help="path of the .npz file holding 'Data' and 'Y_true'")
    parser.add_argument("-save_path", type=str, default="your save pth",
                        help="output file for the pickled list of cumulative-regret curves")
    parser.add_argument("-reps", type=int, default=10,
                        help="number of independent noisy repetitions")

    args = parser.parse_args()
    data_path = args.data_path

    if not os.path.exists(data_path):
        # NOTE(review): gen_data's return value is ignored, so this only works if
        # gen_data writes its output to data_path — confirm against gen_data.
        gen_data(args.p, args.T, args.n, args.rho, args.seed, args.case)

    # Close the NpzFile's underlying file handle. Indexing it reads each array
    # into memory, so the arrays stay valid after the with-block.
    with np.load(data_path) as data:
        clean_data = data['Data']
        ytrue = data['Y_true']
    results = []

    for i in range(args.reps):
        print("Running experiment %d/%d" % (i + 1, args.reps))
        np.random.seed(i)
        # Perturb a fresh copy of the clean rewards on every repetition. The old
        # code wrote the noisy rewards back into the loaded array, so noise
        # accumulated: repetition i carried i + 1 stacked noise draws.
        Data = clean_data.copy()
        for arm_data in Data:
            epsilon = np.random.normal(loc=0, scale=np.sqrt(0.01), size=args.T)
            arm_data[:, -1] = arm_data[:, -1] + epsilon

        bandit = LinUCBBandit(n_arms=args.n, n_features=args.p, data=Data, T=args.T)
        bandit.run_experiment()

        # Regret is measured against data['Y_true'], which is not perturbed here.
        cumulative_regret = compute_regret(ytrue, bandit.arm_sample_idxs)
        results.append(cumulative_regret)

    with open(args.save_path, 'wb') as f:
        pickle.dump(results, f)
