import numpy as np
from sklearn.linear_model import Lasso
import argparse, os
from gen_data import gen_data
import cvxpy as cp
from sklearn.linear_model import LassoCV
# from pwe_estimate import lasso_prediction_result
import tqdm
import warnings
import multiprocessing, pickle

warnings.filterwarnings('ignore')


def sparse_regression(X, Y, chosen_arm, K):
    """Fit one cross-validated Lasso model per arm and stack the coefficients.

    Args:
        X: 2-D array of contexts, one row per observation.
        Y: array of observed rewards aligned with the rows of ``X``.
        chosen_arm: array giving the arm index pulled for each observation.
        K: number of arms, i.e. number of columns in the result.

    Returns:
        Array of shape ``(X.shape[1], K)``; column ``k`` holds the Lasso
        coefficients for arm ``k`` and stays all-zero when that arm has no
        observations.
    """
    n_features = X.shape[1]
    coefficients = np.zeros((n_features, K))

    for arm in range(K):
        rows = np.where(chosen_arm == arm)[0]
        # Arms that were never pulled keep their zero column.
        if rows.size:
            # NOTE(review): cv=5 presumably needs at least five samples for
            # this arm -- confirm against the scikit-learn documentation.
            model = LassoCV(cv=5)
            model.fit(X[rows], Y[rows].reshape(-1))
            coefficients[:, arm] = model.coef_

    return coefficients


def select_optimal_arms(estimated_rewards, threshold):
    """Return the indices of arms whose estimate is close to the best one.

    An arm is kept when its estimated reward is at least the maximum
    estimated reward minus half of ``threshold``.

    Args:
        estimated_rewards: array of per-arm reward estimates.
        threshold: width of the acceptance band; only half of it is used.

    Returns:
        Array of indices (along the first axis) of the retained arms.
    """
    cutoff = np.max(estimated_rewards) - threshold / 2
    near_best = estimated_rewards >= cutoff
    return np.nonzero(near_best)[0]



def estimate_beta(X_history, Y_history, chosen_arm_history, estimator, explore_history=None, optimal_arms=None,
                  n_arms=None):
    """Run ``estimator`` on a filtered subset of the logged rounds.

    At most one filter is applied, in this priority order:

    1. ``explore_history`` given: keep rounds whose flag is truthy.
    2. ``optimal_arms`` given: keep rounds whose pulled arm is in it.
    3. otherwise: keep every round.

    Args:
        X_history: sequence of context vectors, one per logged round.
        Y_history: sequence of observed rewards, one per logged round.
        chosen_arm_history: sequence of pulled arm indices, one per round.
        estimator: callable ``estimator(X, Y, arms, K)`` receiving the
            filtered arrays and the number of arms.
        explore_history: optional sequence of per-round flags.
        optimal_arms: optional iterable of arm indices to keep.
        n_arms: optional number of arms passed to ``estimator``. When omitted
            it is inferred as ``max(filtered arms) + 1``, which yields fewer
            columns than the true arm count if the highest-numbered arm is
            absent from the filtered rounds.

    Returns:
        Whatever ``estimator`` returns.

    Raises:
        ValueError: if no round survives the filter.
    """
    arms = np.asarray(chosen_arm_history)

    if explore_history is not None:
        keep = np.flatnonzero(np.asarray(explore_history, dtype=bool))
    elif optimal_arms is not None:
        # list(...) so that sets and other iterables are compared element-wise.
        keep = np.flatnonzero(np.isin(arms, list(optimal_arms)))
    else:
        keep = np.arange(len(arms))

    arms_kept = arms[keep]
    if arms_kept.size == 0:
        raise ValueError("estimate_beta: no rounds left after filtering, cannot fit any arm")

    if n_arms is None:
        n_arms = int(arms_kept.max()) + 1

    X_kept = np.asarray(X_history)[keep]
    Y_kept = np.asarray(Y_history)[keep]
    return estimator(X_kept, Y_kept, arms_kept, n_arms)

def generate_explore_times(K, q, T):
    """Build the forced-exploration schedule for each of ``K`` arms.

    Exploration happens in blocks of ``K * q`` consecutive time steps; block
    ``n`` starts at ``(2 ** n - 1) * K * q``. Inside a block the arms are
    interleaved, so arm ``i`` owns the steps at offsets ``i, i + K, ...``.

    Block ``n`` is generated when ``2 ** n <= T // (K * q)``. The first block
    is always generated and truncated at ``T``, so a horizon shorter than one
    block yields a partial schedule instead of failing.
    NOTE(review): a later block that starts before ``T`` but would not finish
    is still skipped, exactly as before -- confirm that this is intended.

    Args:
        K: number of arms.
        q: number of forced samples per arm in each block (integer).
        T: horizon; only times strictly below ``T`` are returned.

    Returns:
        List of ``K`` lists; entry ``i`` holds the increasing time steps at
        which arm ``i`` is scheduled.
    """
    T_explore = [[] for _ in range(K)]
    block = K * q
    if block <= 0 or T <= 0:
        # Nothing can be scheduled; also avoids dividing by a zero block size.
        return T_explore

    # Integer replacement for int(log2(T // block) + 1); bit_length() is 0
    # when T < block, hence the floor of one block.
    n_blocks = max(1, int(T // block).bit_length())

    for arm, times in enumerate(T_explore):
        for n in range(n_blocks):
            start = (2 ** n - 1) * block
            times.extend(t for t in range(start + arm, start + block, K) if t < T)
    return T_explore


def compute_regret(y_true, select_arm_list):
    """Return the cumulative regret of a sequence of arm choices.

    Args:
        y_true: 2-D array indexed as ``[arm, round]``.
        select_arm_list: arm index chosen in each round, in round order.

    Returns:
        1-D array whose entry ``t`` is the running sum, up to round ``t``, of
        the gap between the best value in that round's column and the value
        of the chosen arm.
    """
    rounds = np.arange(len(select_arm_list))
    obtained = y_true[select_arm_list, rounds]
    best_possible = y_true.max(axis=0)
    per_round_gap = best_possible - obtained
    return per_round_gap.cumsum()


class BastaniLassoBandit:
    """Lasso bandit with scheduled forced exploration, replayed on fixed data.

    Rounds that fall on the exploration schedule pull the scheduled arm. All
    other rounds shortlist arms with the estimate fitted on forced samples
    and then pick, within that shortlist, the arm preferred by an estimate
    fitted on every sample seen so far.
    """

    def __init__(self, n_arms, n_features, data, T0, T, h, q):
        """Store the experiment configuration.

        Args:
            n_arms: number of arms.
            n_features: number of context features per arm.
            data: per-arm sequence of 2-D arrays; row ``t`` holds the features
                followed by the reward in the last column.
            T0: stored as given; not read by ``run_experiment``.
            T: number of rounds to play.
            h: threshold handed to ``select_optimal_arms``.
            q: forced samples per arm per exploration block (cast to ``int``).
        """
        # Problem description.
        self.n_arms = n_arms
        self.n_features = n_features
        self.data = data
        # Horizon and tuning parameters.
        self.T0 = T0
        self.T = T
        self.h = h
        self.q = int(q)
        # Bookkeeping containers.
        self.reward = 0
        self.history = [[] for _ in range(n_arms)]
        self.results = [[] for _ in range(n_arms)]
        self.arm_sample_idxs = []

    def run_experiment(self):
        """Play ``self.T`` rounds and log the pulled arm of every round.

        The pulled arms are appended to ``self.arm_sample_idxs``.

        Returns:
            Always ``0``.
        """
        contexts, rewards = [], []
        pulled_arms, forced_flags = [], []
        forced_beta = np.zeros((self.n_features, self.n_arms))
        schedule = generate_explore_times(self.n_arms, self.q, self.T)

        def scheduled_arm(step):
            # First arm whose exploration list contains `step`, or -1 if none.
            return next((arm for arm, times in enumerate(schedule) if step in times), -1)

        def record(X_t, Y_t, arm, forced):
            # Log the context and reward of the pulled arm plus bookkeeping.
            contexts.append(X_t[arm])
            rewards.append(Y_t[arm])
            pulled_arms.append(arm)
            self.arm_sample_idxs.append(arm)
            forced_flags.append(forced)

        for t in range(self.T):
            X_t = np.array([arm_data[t, :-1] for arm_data in self.data])
            Y_t = [arm_data[t, -1] for arm_data in self.data]

            # The schedule is looked up one step ahead of the loop index.
            forced_arm = scheduled_arm(t + 1)

            if forced_arm == -1:
                # Exploitation: shortlist with the forced-sample estimate ...
                shortlist_scores = (X_t * forced_beta.T).sum(axis=1)
                shortlist = select_optimal_arms(shortlist_scores, self.h)
                # ... then decide with an estimate refitted on all samples.
                all_sample_beta = estimate_beta(contexts, rewards, pulled_arms, sparse_regression)
                final_scores = (X_t * all_sample_beta.T).sum(axis=1)
                winner = shortlist[np.argmax(final_scores[shortlist])]
                record(X_t, Y_t, winner, False)
                continue

            record(X_t, Y_t, forced_arm, True)
            # Refit the forced-sample estimate once the next round is no
            # longer an exploration round.
            if scheduled_arm(t + 2) == -1:
                forced_beta = estimate_beta(contexts, rewards, pulled_arms, sparse_regression,
                                            explore_history=forced_flags)

        return 0



if __name__ == "__main__":
    # Command-line entry point: load (or first generate) a data set, run the
    # bandit on it and save a plot of the cumulative regret.
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", type=int, default=5, help="number of arms")
    parser.add_argument("-T", type=int, default=500, help="number of rounds")
    parser.add_argument("-T0", type=int, default=125, help="number of T0 rounds")
    parser.add_argument("-p", type=int, default=200, help="dimension of arms")
    parser.add_argument("-rho", type=float, default=0.1, help="sparsity ratio")
    parser.add_argument("-case", type=int, default=2, help="sigma setting")
    parser.add_argument('-seed', type=int, default=0, help="random seed")
    # -bandith is passed to the bandit as h, -banditq as q.
    parser.add_argument("-bandith", type=float, default=0.01, help="bandith")
    parser.add_argument("-banditq", type=float, default=3, help="banditq")

    args = parser.parse_args()
    # NOTE(review): absolute path under /data -- confirm this matches where
    # gen_data writes its output.
    data_path=f"/data/p_{args.p}_T_{args.T}_n_{args.n}_rho_{args.rho}_seed{args.seed}_case_{args.case}.npz"

    # Generate the data set only when the expected file is missing.
    if not os.path.exists(data_path):
        gen_data(args.p,args.T,args.n,args.rho,args.seed,args.case)

    data = np.load(data_path)
    Data = data['Data']
    ytrue= data['Y_true']


    basbandit = BastaniLassoBandit(n_arms=args.n, n_features=args.p, data=Data, T0=args.T0, T=args.T,
                                       h=args.bandith, q=args.banditq)
    basbandit.run_experiment()
    # One pulled arm index per round, in round order.
    select_arm_list = basbandit.arm_sample_idxs

    cumulative_regret = compute_regret(ytrue,select_arm_list)

    # Imported here so the plotting backend is only loaded when run as a script.
    import matplotlib.pyplot as plt

    plt.plot(cumulative_regret)

    plt.savefig(f'bastani_p{args.p}_T{args.T}_n{args.n}_rho{args.rho}_seed{args.seed}_case{args.case}.png', bbox_inches='tight')
