import argparse
import datetime
import math
import os
import pickle
import torch
import numpy as np
from dvutils.data import Custom_Dataset, load_data
from dvutils.models_defined import model_dict
from tqdm import tqdm
from joblib import Parallel, delayed
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import torch
import gpytorch
import shutil
from sklearn.decomposition import PCA

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact Gaussian-process regressor with a constant mean and an RBF kernel.

    The RBF kernel is used directly, i.e. it is not wrapped in a
    ``ScaleKernel``.
    """

    def __init__(self, train_x, train_y, likelihood):
        """Register the training data, the likelihood and the mean/kernel modules."""
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.RBFKernel()

    def forward(self, x):
        """Return the multivariate normal built from the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

def train_gp(train_x, train_y, noise=0.1, length_scale=0.1, device=None):
    """Build an ``ExactGPModel`` with fixed hyperparameters on the given data.

    No hyperparameter optimisation is run here: the likelihood noise and the
    kernel lengthscale are simply initialised to the supplied values.

    Args:
        train_x: Sequence of tensors that is stacked with ``torch.vstack`` to
            form the training inputs.
        train_y: Training targets; converted to a float32 tensor.
        noise: Value assigned to ``likelihood.noise_covar.noise``.
        length_scale: Value assigned to ``covar_module.lengthscale``.
        device: Device on which the model, likelihood and training tensors
            are placed. Defaults to CUDA when available, otherwise CPU, so
            the function no longer fails on hosts without a GPU.

    Returns:
        Tuple ``(model, likelihood)``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # initialize likelihood and model
    train_x = torch.vstack(train_x).to(device)
    train_y = torch.tensor(train_y, dtype=torch.float32).to(device)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = ExactGPModel(train_x, train_y, likelihood)
    model = model.to(device)
    likelihood = likelihood.to(device)

    hypers = {
        'likelihood.noise_covar.noise': torch.tensor(noise),
        'covar_module.lengthscale': torch.tensor(length_scale),
    }

    model.initialize(**hypers)

    return model, likelihood


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Process which dataset to run')
    # One row per command-line option: (name, help text, type, default).
    # Every option is registered under both ``-name`` and ``--name``.
    _option_table = [
        ('num_dp', 'Number of data points to do data valuation', int, 1000),
        ('num_eval', 'Number of data points in validation set', int, 1000),
        ('gpus', 'Specify the gpu to use to compute data Shapley', str, '0123'),
        ('batch_size', 'Batch size', int, 200),
        ('epochs', 'Number of epochs to train the model', int, 10),
        ('dataset', 'Which dataset to use', str, "mnist"),
        ('model', 'Type of model to use', str, "MLP"),
        ('metric', 'Metric for model performance', str, "acc"),
        ('exp', 'Name for the experiment', str, "new_exp"),
        ('seed', 'Seed for data valuation', int, 123),
        ('length_scale', 'Lengthscale for RBF kernel', float, 5),
        ('num_query', 'The total number of queries', int, 1000),
        ('init_query', 'The number of queries for initialization', int, 100),
        ('m', 'The top-m data values we want to identify', int, 20),
        ('beta', 'Hyperparameter for the upper bound', float, 1.0),
        ('noise', 'The variance of the noise', float, 0.1),
        ('cpuonly', 'Use only cpu', int, 1),
        ('useorder', 'Whether to use order to specify the value of beta', str, "SE"),
        ('feature', 'What kind of feature to use for Gaussian process', str, "MLP"),
        ('normalizey', 'Whether to normalize y', str, "NO"),
        ('lr', 'Learning rate', float, 0.001),
        ('smallset', 'Whether to focus on small set or not the number indicate the fraction', float, 1),
        ('singlequery', 'The number of queries being done when one point is selected', int, 10),
        ('best_k', 'The number of data points selected in each iter', int, 1),
        ('pca', 'Whether to use PCA to reduce dimension', int, 0),
        ('noisenum', 'The number of noisy data points', int, 0),
    ]
    for _flag, _help, _type, _default in _option_table:
        parser.add_argument(f'-{_flag}', f'--{_flag}', help=_help, type=_type, default=_default)

    # Parsed options as a plain dict keyed by option name.
    cmd_args = vars(parser.parse_args())
    # Devices used for the neural-network (re)training jobs.
    if cmd_args['cpuonly']:
        gpus = [torch.device("cpu")]
    else:
        # Each character of the --gpus string is one CUDA device index.
        gpus = [torch.device(f"cuda:{tmp}") for tmp in cmd_args['gpus']]
    # NOTE(review): not referenced anywhere else in the visible script.
    gpu_sequnce = list(range(len(cmd_args['gpus']))) * (cmd_args['num_dp'] // len(gpus) + 1)

    # Device for the GP feature tensor: CUDA when present (as before),
    # otherwise CPU instead of failing on an unconditional ``.cuda()``.
    gp_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # set seed
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # arguments
    device = gpus[0] # ['cpu', 'cuda']
    metric = cmd_args['metric'] # ['neg_loss', 'acc']

    # parameter for NN training
    model_fn, Model_Train = model_dict[cmd_args['model']]
    lr = cmd_args['lr']
    optimizer_fn = torch.optim.Adam
    if cmd_args['model'] == "MLP_R":
        loss_fn = torch.nn.MSELoss
    else:
        loss_fn = torch.nn.CrossEntropyLoss
    batch_size = cmd_args['batch_size']
    epochs = cmd_args['epochs']
    num_query = cmd_args['num_query']
    num_dp = cmd_args['num_dp']
    seed = cmd_args['seed']
    # load dataset
    flatten = 'CNN' not in cmd_args['model']
    trdata, valdata, trlabel, vallabel = load_data(dataset=cmd_args['dataset'], numdp=cmd_args['num_dp'], flatten=flatten)
    valdata, vallabel = valdata[:cmd_args['num_eval']], vallabel[:cmd_args['num_eval']]

    tr_set = Custom_Dataset(torch.tensor(trdata, dtype=torch.float32), torch.tensor(trlabel, dtype=torch.int64), device=device, return_idx=False)
    val_set = Custom_Dataset(torch.tensor(valdata, dtype=torch.float32), torch.tensor(vallabel, dtype=torch.int64), device=device, return_idx=False)

    # select the data points for data valuation
    selected_idxs = np.random.choice(len(tr_set), cmd_args['num_dp'], replace=False)
    selected_data, selected_label = trdata[selected_idxs], trlabel[selected_idxs]
    if cmd_args['noisenum'] > 0:
        # Corrupt a random subset of the selected points with N(0, 2) noise.
        index_noise = np.random.choice(len(selected_data), cmd_args['noisenum'], replace=False)
        selected_data[index_noise] = selected_data[index_noise] + np.random.normal(0,2,selected_data[index_noise].shape)

    # Features fed to the GP; overwritten below by PCA and/or network features.
    selected_feature = selected_data
    if cmd_args['pca']:
        pca = PCA(n_components = 32, random_state=0)
        selected_data = pca.fit_transform(selected_data)
        valdata = pca.transform(valdata)
        selected_feature = selected_data

    if cmd_args['model'] == "Logistic":
        selected_label = np.array([str(tmp) for tmp in selected_label])
        vallabel = np.array([str(tmp) for tmp in vallabel])
    elif "MLP" in cmd_args['model'] or "CNN" in cmd_args['model']:
        # Regression model uses float targets, every other model integer labels.
        label_dtype = torch.float32 if cmd_args['model'] == "MLP_R" else torch.long
        selected_data, selected_label = torch.tensor(selected_data, dtype=torch.float32), torch.tensor(selected_label, dtype=label_dtype)
        valdata, vallabel = torch.tensor(valdata, dtype=torch.float32), torch.tensor(vallabel, dtype=label_dtype)

        # One copy of the data per device. These lists are built for the CPU
        # case as well (they were previously left undefined under --cpuonly,
        # which raised NameError just below and in one_process_evaluation).
        selected_data_gpu = [selected_data.to(device_) for device_ in gpus]
        selected_label_gpu = [selected_label.to(device_) for device_ in gpus]
        valdata_gpu = [valdata.to(device_) for device_ in gpus]
        vallabel_gpu = [vallabel.to(device_) for device_ in gpus]

        # One trainer per concurrent evaluation job, assigned to devices round-robin.
        model_train_gpu = [
            Model_Train(
                model_fn, optimizer_fn, loss_fn, lr, batch_size, epochs, gpus[(i % len(gpus))],
                val_data=valdata_gpu[(i % len(gpus))], val_label=vallabel_gpu[(i % len(gpus))],
                train_data=selected_data_gpu[(i % len(gpus))], train_label=selected_label_gpu[(i % len(gpus))],
            )
            for i in range(cmd_args['singlequery'] * cmd_args['best_k'])
        ]

        # Train once on all selected points and use the network's features for the GP.
        model_train = Model_Train(model_fn, optimizer_fn, loss_fn, lr, batch_size, epochs, device)
        model_train.fit_e(selected_data, selected_label)
        loss, acc = model_train.evaluate_e(valdata, vallabel)
        print("Loss: {:.4f}, Acc: {:.4f}".format(loss, acc))
        selected_feature = model_train.model.get_feature(selected_data.to(device)).detach().cpu().numpy()

    # NOTE(review): for MLP/CNN models this repeats the fit performed just
    # above; it is kept so that `model_train` exists for every model type and
    # so that the sequence of training runs is unchanged.
    model_train = Model_Train(model_fn, optimizer_fn, loss_fn, lr, batch_size, epochs, device)
    model_train.fit_e(selected_data, selected_label)
    loss, acc = model_train.evaluate_e(valdata, vallabel)
    print("Loss: {:.4f}, Acc: {:.4f}".format(loss, acc))
    selected_feature = torch.tensor(selected_feature, dtype=torch.float32).to(gp_device)
    # Re-seed NumPy with the user-provided seed for the valuation phase.
    np.random.seed(cmd_args['seed'])

    def compute_gap_upperbound_list(n, data, gp, gp_likelihood, beta, m, best_k):
        """Pick the next points to query from the GP's gap upper bounds.

        The GP posterior is evaluated on ``data``. The ``m`` indices with the
        largest posterior mean form the current top set J; all other indices
        in ``range(n)`` form its complement. For each pair (j in J, i not in
        J) the bound is

            B_ij = mean_i - mean_j + beta * sqrt(var_i + var_j - 2 * cov_ij)

        where ``var`` is the diagonal of the posterior covariance (per-point
        variances, not the trace, which is a single scalar sum).

        Args:
            n: Number of candidate points; indices are ``range(n)``.
            data: Inputs passed to ``gp``.
            gp: GP model; switched to eval mode.
            gp_likelihood: Likelihood; switched to eval mode.
            beta: Multiplier of the gap standard deviation.
            m: Size of the top set.
            best_k: Number of highest-bound pairs to turn into queries.

        Returns:
            Tuple ``(select_idx, y_pred)``: the list of chosen indices (one
            per considered pair) and the posterior mean as a NumPy array.
        """
        gp.eval()
        gp_likelihood.eval()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            observed_pred = gp(data)
            y_pred = observed_pred.mean
            cov = observed_pred.covariance_matrix

        j_t = torch.argsort(-y_pred)[:m].tolist()
        top_set = set(j_t)
        # Ascending order keeps the complement deterministic.
        non_j_t = [i for i in range(n) if i not in top_set]

        var = torch.diagonal(cov)
        # Broadcast to a (len(j_t), len(non_j_t)) grid of pairwise quantities.
        gap_var = var[j_t].unsqueeze(1) + var[non_j_t].unsqueeze(0) - 2 * cov[j_t, :][:, non_j_t]
        # Guard against tiny negative values from numerical error before sqrt.
        gap_std = torch.sqrt(torch.clamp(gap_var, min=0.0))
        B_ij = y_pred[non_j_t].unsqueeze(0) - y_pred[j_t].unsqueeze(1) + float(beta) * gap_std

        best_k_idx = torch.argsort(-B_ij.flatten())[:best_k].tolist()
        select_idx = []
        for flat_idx in best_k_idx:
            # Row index of the flattened grid is j (top set), column index is i.
            best_i = non_j_t[flat_idx % len(non_j_t)]
            best_j = j_t[flat_idx // len(non_j_t)]
            if best_i in select_idx or best_j in select_idx:
                # One member of the pair is already scheduled: take i when j
                # is the scheduled one, otherwise take j.
                select_idx_ = best_i if best_j in select_idx else best_j
            else:
                # Otherwise query the member with the larger posterior variance.
                select_idx_ = best_i if (var[best_i] >= var[best_j]) else best_j
            select_idx.append(select_idx_)

        print(select_idx)
        return select_idx, y_pred.cpu().numpy()

    def one_process_evaluation(selected_idx_, idx, gpu_idx):
        """Measure the marginal contribution of point ``idx`` to one subset.

        A model is fitted on the subset ``selected_idx_`` and evaluated on the
        validation data, then fitted and evaluated again with ``idx`` appended.

        Args:
            selected_idx_: Indices of the subset without ``idx``.
            idx: Index of the point whose contribution is measured.
            gpu_idx: Job number; selects the trainer and, modulo the number
                of devices, the per-device data copies (MLP/CNN models only).

        Returns:
            ``acc_new - acc_old`` when the metric is ``"acc"``, otherwise
            ``loss_old - loss_new``.
        """
        is_torch_model = "MLP" in cmd_args['model'] or "CNN" in cmd_args['model']
        if is_torch_model:
            slot = gpu_idx % len(gpus)
            train_x, train_y = selected_data_gpu[slot], selected_label_gpu[slot]
            eval_x, eval_y = valdata_gpu[slot], vallabel_gpu[slot]
            trainer = model_train_gpu[gpu_idx]
        else:
            train_x, train_y = selected_data, selected_label
            eval_x, eval_y = valdata, vallabel
            trainer = model_train

        def _fit_and_score(subset):
            # Fit on the subset, then return (loss, acc) on the validation data.
            trainer.fit_e(train_x[subset], train_y[subset])
            return trainer.evaluate_e(eval_x, eval_y)

        loss_before, acc_before = _fit_and_score(selected_idx_)
        loss_after, acc_after = _fit_and_score(np.append(selected_idx_, idx))

        if cmd_args['metric'] == "acc":
            return acc_after - acc_before
        return loss_before - loss_after
    
    def one_data_point_evaluation(idx):
        """Estimate the value of point ``idx`` by averaging marginal contributions.

        ``singlequery`` random subsets of the other points are drawn (subset
        sizes are sampled from ``range(max_set_size)``), the marginal
        contribution of ``idx`` to each subset is evaluated in parallel, and
        the mean is returned.

        Args:
            idx: Index of the data point to evaluate.

        Returns:
            Mean marginal contribution over the sampled subsets.
        """
        # Materialised once; every subset is drawn from the same candidate list.
        other_points = list(set(range(cmd_args['num_dp'])) - set([idx]))
        max_set_size = math.floor((cmd_args['num_dp'] - 1) * cmd_args['smallset'])
        all_size_cur = np.random.choice(max_set_size, cmd_args['singlequery'])
        all_selected_idx_ = [np.random.choice(other_points, size_cur, replace=False) for size_cur in all_size_cur]
        # The job number is passed as `gpu_idx`, which one_process_evaluation
        # requires (it was previously omitted, raising TypeError in the worker).
        all_query = Parallel(n_jobs=cmd_args['singlequery'])(
            delayed(one_process_evaluation)(all_selected_idx_[m_], idx, m_)
            for m_ in range(cmd_args['singlequery'])
        )

        marginal_contribs = np.mean(all_query)
        return marginal_contribs

    def one_evaluation(idxs):
        """Evaluate a batch of points and return one averaged value per point.

        For every index in ``idxs``, ``singlequery`` random subsets of the
        remaining points are drawn. All (subset, point) jobs are run through
        ``one_process_evaluation`` in parallel with at most 10 workers, and
        the results are averaged per point.

        Args:
            idxs: Iterable of data-point indices to evaluate.

        Returns:
            NumPy array with one mean marginal contribution per entry of
            ``idxs``; the sign is flipped when ``noisenum`` is positive.
        """
        repeats = cmd_args['singlequery']
        targets = []
        coalitions = []
        for point in idxs:
            pool = list(set(range(cmd_args['num_dp'])) - set([point]))
            size_cap = math.floor((cmd_args['num_dp'] - 1) * cmd_args['smallset'])
            # Sizes are drawn first, then one subset per size, in order.
            sizes = np.random.choice(size_cap, repeats)
            for size in sizes:
                coalitions.append(np.random.choice(pool, size, replace=False))
            targets.extend([point] * repeats)

        workers = min(len(targets), 10)
        jobs = (
            delayed(one_process_evaluation)(coalitions[job], targets[job], job)
            for job in range(len(targets))
        )
        results = Parallel(n_jobs=workers)(jobs)

        # compute average for each idx
        averaged = np.mean(np.array(results).reshape(-1, repeats), axis=1)
        return -averaged if cmd_args['noisenum'] > 0 else averaged
    
    # history[0]: GP inputs (feature tensors); history[1]: observed values.
    history = [[],[]]
    dv_result = []          # GP posterior mean recorded at every query round
    dv_result_query = []    # per-point mean of observed values at every round
    query_history = [[] for _ in range(num_dp)]
    idx_history = []
    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    if cmd_args['noisenum'] > 0:
        local_folder = "./GPGAPE-TORCH-NOISE/"
        target_folder = "./GPGAPE-TORCH-NOISE/"
    else:
        local_folder = "./GPGAPE-TORCH/"
        target_folder = "./GPGAPE-TORCH/"
    # Check if the folder already exists
    if not os.path.exists(local_folder):
        # Create the folder
        os.makedirs(local_folder)
        print("Folder created.")
    else:
        print("Folder already exists.")
    save_name = local_folder + "data_shapley_{}_seed{}_numdp{}_nquery{}_noise{}_lengthscale{}_m{}_beta{}{}_{}_{}_lr{}_{}_{}_{}.pkl".format(cmd_args['dataset'], cmd_args['seed'], cmd_args['num_dp'], cmd_args['num_query'], cmd_args['noise'], cmd_args['length_scale'], cmd_args['m'], cmd_args['beta'], cmd_args['useorder'], cmd_args['metric'], cmd_args['model'], cmd_args['lr'], cmd_args['exp'], cmd_args['normalizey'], time_str)

    # Fail early on an unknown --normalizey value instead of hitting an
    # undefined `scaler` after the initialization phase has already run.
    scaler_classes = {"MINMAX": MinMaxScaler, "NORMAL": StandardScaler}
    if cmd_args['normalizey'] != "NO" and cmd_args['normalizey'] not in scaler_classes:
        raise ValueError("Unknown normalizey option: {!r} (expected NO, MINMAX or NORMAL)".format(cmd_args['normalizey']))

    def _gp_targets(values):
        """Return the observed values as a 1-D array, rescaled if requested."""
        targets = np.array(values)
        if cmd_args['normalizey'] == "NO":
            return targets
        # A fresh scaler is fitted on the full history every time.
        scaler = scaler_classes[cmd_args['normalizey']]()
        return scaler.fit_transform(targets.reshape(-1, 1)).reshape(-1)

    # initialization phase
    init_idxs = np.random.choice(num_dp, cmd_args['init_query'])
    for idx_ in tqdm(init_idxs, desc='Initialization phase'):
        y = one_evaluation([idx_])
        history[0] += [selected_feature[idx_]]
        history[1] += y.tolist()
        query_history[idx_] += y.tolist()

        idx_history.append(idx_)

    y_ = _gp_targets(history[1])
    gp_model, gp_likelihood = train_gp(history[0], y_, noise=cmd_args['noise'], length_scale=cmd_args['length_scale'])

    for query in tqdm(range(num_query)):
        # Exploration weight: grows with the round number unless fixed by --beta.
        if cmd_args['useorder'] == 'FS':
            beta_ = 1 + np.sqrt(np.log(query + 1))
        elif cmd_args['useorder'] == 'SE':
            beta_ = 1 + np.sqrt(np.log(query + 1) ** 3)
        else:
            beta_ = cmd_args['beta']
        query_idx, y_pred = compute_gap_upperbound_list(num_dp, selected_feature, gp_model, gp_likelihood, beta_, cmd_args['m'], cmd_args['best_k'])
        y = one_evaluation(query_idx)
        history[0] += [selected_feature[tmp_id] for tmp_id in query_idx]
        history[1] += y.tolist()
        for tmp_idx, tmp_y in zip(query_idx, y.tolist()):
            query_history[tmp_idx] += [tmp_y]
        idx_history += query_idx
        # Rebuild the GP on the full history, including the new observations.
        y_ = _gp_targets(history[1])
        gp_model, gp_likelihood = train_gp(history[0], y_, noise=cmd_args['noise'], length_scale=cmd_args['length_scale'])
        dv_result.append(y_pred)
        # Points that were never queried have an empty history (mean is NaN).
        dv_result_query.append([np.mean(tmp) for tmp in query_history])

    with open(save_name, 'wb') as handle:
        pickle.dump({"selected_idxs":selected_idxs, "dv_result":dv_result, "dv_result_query":dv_result_query, "args": cmd_args, "query_history":query_history}, handle, protocol=pickle.HIGHEST_PROTOCOL)
    # Only move when the destination differs: moving a file into the folder
    # it already lives in makes shutil.move raise "Destination path ... already exists".
    if os.path.abspath(target_folder) != os.path.abspath(local_folder):
        os.makedirs(target_folder, exist_ok=True)
        shutil.move(save_name, target_folder)