import argparse
import datetime
import math
import os
import pickle
from sklearn.metrics import pairwise_distances
import torch
import numpy as np
from dvutils.data import Custom_Dataset, load_data
from dvutils.models_defined import model_dict
from tqdm import tqdm
from joblib import Parallel, delayed
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import torch
import gpytorch
import shutil

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regression model with a constant mean and an RBF kernel.

    The kernel is a bare ``RBFKernel`` (it is not wrapped in a
    ``ScaleKernel``).
    """

    def __init__(self, train_x, train_y, likelihood):
        """Register the training data and likelihood, then build the modules."""
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.RBFKernel()

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

def train_gp(train_x, train_y, noise=0.1, length_scale=0.1, device=None):
    """Build an exact GP on the given observations with fixed hyperparameters.

    Despite the name, no optimisation loop is run: the likelihood noise and
    the kernel length scale are simply set to the supplied values.

    Args:
        train_x: Sequence of tensors that ``torch.vstack`` stacks into the
            training inputs.
        train_y: Training targets, anything ``torch.tensor`` can convert.
        noise: Value assigned to the Gaussian likelihood noise.
        length_scale: Value assigned to the RBF kernel length scale.
        device: Device to place the data, model and likelihood on. When
            ``None`` (the default) CUDA is used if it is available and the
            CPU otherwise, so CUDA hosts behave as before while CPU-only
            hosts no longer crash.

    Returns:
        Tuple ``(model, likelihood)``, both on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # initialize likelihood and model
    train_x = torch.vstack(train_x).to(device)
    train_y = torch.tensor(train_y, dtype=torch.float32).to(device)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = ExactGPModel(train_x, train_y, likelihood)
    model = model.to(device)
    likelihood = likelihood.to(device)

    hypers = {
        'likelihood.noise_covar.noise': torch.tensor(noise),
        'covar_module.lengthscale': torch.tensor(length_scale),
    }

    model.initialize(**hypers)

    return model, likelihood

if __name__ == "__main__":

    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Process which dataset to run')
    parser.add_argument('-num_dp', '--num_dp', type=int, default=100,
                        help='Number of data points to do data valuation')
    parser.add_argument('-gpus', '--gpus', type=str, default='0123',
                        help='Specify the gpu to use to compute data Shapley')
    parser.add_argument('-batch_size', '--batch_size', type=int, default=200,
                        help='Batch size')
    parser.add_argument('-epochs', '--epochs', type=int, default=10,
                        help='Number of epochs to train the model')
    parser.add_argument('-model', '--model', type=str, default="MLP",
                        help='Type of model to use')
    parser.add_argument('-metric', '--metric', type=str, default="acc",
                        help='Metric for model performance')
    parser.add_argument('-exp', '--exp', type=str, default="debug",
                        help='Name for the experiment')
    parser.add_argument('-seed', '--seed', type=int, default=123,
                        help='Seed for data valuation')
    parser.add_argument('-num_query', '--num_query', type=int, default=1000,
                        help='The total number of queries')
    parser.add_argument('-init_query', '--init_query', type=int, default=100,
                        help='The number of queries for initialization')
    parser.add_argument('-m', '--m', type=int, default=70,
                        help='The top-m data values we want to identify')
    parser.add_argument('-beta', '--beta', type=float, default=1.0,
                        help='Hyperparameter for the upper bound')
    parser.add_argument('-noise', '--noise', type=float, default=0.1,
                        help='The variance of the noise')
    parser.add_argument('-cpuonly', '--cpuonly', type=int, default=1,
                        help='Use only cpu')
    parser.add_argument('-useorder', '--useorder', type=str, default="SE",
                        help='Whether to use order to specify the value of beta')
    parser.add_argument('-feature', '--feature', type=str, default="MLP",
                        help='What kind of feature to use for Gaussian process')
    parser.add_argument('-normalizey', '--normalizey', type=str, default="NO",
                        help='Whether to normalize y')
    parser.add_argument('-tau', '--tau', type=float, default=50,
                        help='tau for data valuation')
    parser.add_argument('-length_scale', '--length_scale', type=float, default=10,
                        help='length scale for Gaussian process')
    parser.add_argument('-lr', '--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('-smallset', '--smallset', type=float, default=1,
                        help='Whether to focus on small set or not the number indicate the fraction')
    parser.add_argument('-singlequery', '--singlequery', type=int, default=10,
                        help='The number of queries being done when one point is selected')
    cmd_args = vars(parser.parse_args())

    # Fail early with a readable message: the top-m set and its complement
    # must both be non-empty for the gap computation, and the GP needs at
    # least one initial observation.
    if not 0 < cmd_args['m'] < cmd_args['num_dp']:
        parser.error("--m must satisfy 0 < m < num_dp")
    if cmd_args['init_query'] < 1:
        parser.error("--init_query must be at least 1")

    if cmd_args['cpuonly']:
        gpus = [torch.device("cpu")]
    else:
        # Each character of --gpus is treated as one CUDA device index.
        gpus = [torch.device(f"cuda:{tmp}") for tmp in cmd_args['gpus']]
    gpu_sequnce = gpus * (cmd_args['num_dp'] // len(gpus) + 1)

    # set seed
    # The --seed argument used to be parsed but ignored (all RNGs were seeded
    # with 0). It is now honoured; pass `--seed 0` to reproduce older runs.
    seed = cmd_args['seed']
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    # arguments
    device = gpus[0] # ['cpu', 'cuda']
    metric = cmd_args['metric'] # ['neg_loss', 'acc']

    # parameter for NN training
    model_fn, Model_Train = model_dict[cmd_args['model']]
    lr = cmd_args['lr']
    optimizer_fn = torch.optim.Adam
    loss_fn = torch.nn.CrossEntropyLoss
    batch_size = cmd_args['batch_size']
    epochs = cmd_args['epochs']
    num_query = cmd_args['num_query']
    num_dp = cmd_args['num_dp']

    # load dataset
    trdata, valdata, trlabel, vallabel = load_data(dataset='mnist', numdp=cmd_args['num_dp'])

    tr_set = Custom_Dataset(torch.tensor(trdata, dtype=torch.float32), torch.tensor(trlabel, dtype=torch.int64), device=device, return_idx=False)
    val_set = Custom_Dataset(torch.tensor(valdata, dtype=torch.float32), torch.tensor(vallabel, dtype=torch.int64), device=device, return_idx=False)

    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size * 5, shuffle=False)

    # select data points for data valuation
    selected_idxs = np.random.choice(len(tr_set), cmd_args['num_dp'], replace=False)
    data_selected = torch.utils.data.Subset(tr_set, selected_idxs)
    # `.cpu()` is a no-op for CPU tensors; it keeps `.numpy()` from failing if
    # the dataset hands back tensors that live on a CUDA device.
    selected_x = torch.vstack([tmp[0] for tmp in data_selected]).cpu().numpy()
    selected_y = torch.vstack([tmp[1].reshape(1, 1) for tmp in data_selected]).cpu().numpy().reshape(-1)
    val_x = torch.vstack([tmp[0] for tmp in val_set]).cpu().numpy()
    train_loader = torch.utils.data.DataLoader(data_selected, batch_size=batch_size, shuffle=False)

    # NOTE(review): `model_train` is not used below; it is kept because its
    # construction may consume random state - confirm before removing.
    model_train = Model_Train(model_fn, optimizer_fn, loss_fn, lr, batch_size, epochs, device)

    if cmd_args['feature'] == "MLP":
        # Train an MLP on the selected points and replace the raw inputs with
        # the features returned by its `get_feature` method.
        model_fn_, Model_Train_ = model_dict['MLP']
        model_train_ = Model_Train_(model_fn_, optimizer_fn, loss_fn, lr, batch_size, epochs, device)
        model_train_.fit(train_loader, val_loader, verbose=False)
        loss, acc = model_train_.evaluate(val_loader)
        print("Loss: {:.4f}, Acc: {:.4f}".format(loss, acc))
        with torch.no_grad():
            selected_x = model_train_.model.get_feature(torch.tensor(selected_x, dtype=torch.float32).to(device)).detach().cpu().numpy()
            val_x = model_train_.model.get_feature(torch.tensor(val_x, dtype=torch.float32).to(device)).detach().cpu().numpy()

    ### Calculate the ground truth shapley value
    # Compute the pairwise distance between select_x and val_x
    pair_dis = pairwise_distances(selected_x, val_x)
    # tf_mask[i, j] is True when validation point j lies within tau of
    # selected point i.
    tf_mask = pair_dis <= cmd_args['tau']

    def compute_gap_upperbound_list(n, data, gp, gp_likelihood, beta, m):
        """Pick the next point to query from the GP posterior gap upper bounds.

        The ``m`` points with the largest posterior mean form the current
        top set ``J``. For every pair (j in J, i not in J) the bound
        ``B_ij = mu_i - mu_j + beta * std(f_i - f_j)`` is evaluated, where
        ``Var(f_i - f_j) = k_ii + k_jj - 2 * k_ij``. Of the pair with the
        largest bound, the point with the larger posterior variance is
        returned.

        The previous implementation used ``torch.trace`` (the sum of all
        variances in each group) in place of the per-point variances, so the
        uncertainty term was almost the same for every pair.

        Args:
            n: Number of candidate points; indices ``0..n-1`` are considered.
            data: Inputs passed to ``gp`` to obtain the posterior.
            gp: GP model; it is switched to eval mode and called on ``data``.
            gp_likelihood: Likelihood of ``gp``; only switched to eval mode.
            beta: Multiplier of the gap standard deviation.
            m: Size of the top set.

        Returns:
            Tuple ``(select_idx, y_pred)``: the index to query next and the
            posterior mean as a NumPy array.

        Raises:
            ValueError: If the top set or its complement is empty.
        """
        gp.eval()
        gp_likelihood.eval()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            observed_pred = gp(data)
            y_pred = observed_pred.mean
            cov = observed_pred.covariance_matrix

        j_t = torch.argsort(-y_pred)[:m].tolist()
        top_members = set(j_t)
        non_j_t = [i for i in range(n) if i not in top_members]
        if not j_t or not non_j_t:
            raise ValueError(
                "top-m set and its complement must both be non-empty "
                f"(n={n}, m={m})"
            )

        # Per-point posterior variances, broadcast to an (|J|, |not J|) grid.
        var = torch.diagonal(cov)
        var_top = var[j_t].unsqueeze(1)
        var_rest = var[non_j_t].unsqueeze(0)
        jt_non_jt_cov = cov[j_t, :][:, non_j_t]
        # Clamp: round-off can push the variance of the gap slightly below 0,
        # which would turn the square root into NaN.
        gap_std = torch.sqrt(torch.clamp(var_top + var_rest - 2 * jt_non_jt_cov, min=0.0))

        B_ij = y_pred[non_j_t].unsqueeze(0) - y_pred[j_t].unsqueeze(1) + beta * gap_std

        # argmax returns a flat row-major index into the (|J|, |not J|) grid.
        best_idx = int(torch.argmax(B_ij))
        best_i = non_j_t[best_idx % len(non_j_t)]
        best_j = j_t[best_idx // len(non_j_t)]
        select_idx = best_i if var[best_i] >= var[best_j] else best_j
        print(select_idx)
        return select_idx, y_pred.cpu().numpy()
    
    def one_process_evaluation(selected_idx_, idx):
        """Return the utility gain from adding ``idx`` to ``selected_idx_``.

        The utility of a coalition is the number of columns of ``tf_mask``
        whose count of True entries over the coalition's rows exceeds 1.
        """
        with_point = np.append(selected_idx_, idx)
        hits_without = np.sum(tf_mask[selected_idx_], axis=0)
        hits_with = np.sum(tf_mask[with_point], axis=0)
        utility_without = np.sum(hits_without > 1)
        utility_with = np.sum(hits_with > 1)
        return utility_with - utility_without
    
    def one_data_point_evaluation(idx):
        """Estimate the value of point ``idx`` by sampling random coalitions.

        ``cmd_args['singlequery']`` coalitions of the other points are drawn:
        first a size, uniformly from ``0..max_set_size`` inclusive, then that
        many distinct members. The marginal contribution of ``idx`` to each
        coalition is evaluated and the mean is returned.

        The size range used to stop at ``max_set_size - 1``, so the largest
        admissible coalition was never sampled and ``max_set_size == 0``
        raised inside ``np.random.choice``.

        Args:
            idx: Index (within the selected points) of the point to evaluate.

        Returns:
            Mean marginal contribution over the sampled coalitions.
        """
        n_samples = cmd_args['singlequery']
        # Built once instead of once per sampled coalition.
        other_points = list(set(range(cmd_args['num_dp'])) - {idx})
        # Never ask for more members than there are other points.
        max_set_size = min(
            math.floor((cmd_args['num_dp'] - 1) * cmd_args['smallset']),
            len(other_points),
        )
        all_size_cur = np.random.choice(max_set_size + 1, n_samples)
        all_selected_idx_ = [
            np.random.choice(other_points, size_cur, replace=False)
            for size_cur in all_size_cur
        ]
        all_query = Parallel(n_jobs=n_samples)(
            delayed(one_process_evaluation)(coalition, idx)
            for coalition in all_selected_idx_
        )
        return np.mean(all_query)
        
    
    # Put the GP inputs on CUDA when it is available (as before) and fall back
    # to the CPU otherwise instead of failing on `.cuda()`.
    gp_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    selected_x = torch.tensor(selected_x, dtype=torch.float32).to(gp_device)

    def _normalize_targets(values, mode):
        """Return the observed values as an array, rescaled according to ``mode``.

        ``"NO"`` leaves them untouched, ``"MINMAX"`` applies ``MinMaxScaler``
        and ``"NORMAL"`` applies ``StandardScaler``; anything else is rejected.
        """
        targets = np.array(values)
        if mode == "NO":
            return targets
        if mode == "MINMAX":
            scaler = MinMaxScaler()
        elif mode == "NORMAL":
            scaler = StandardScaler()
        else:
            raise ValueError(
                "normalizey must be one of 'NO', 'MINMAX', 'NORMAL', got {!r}".format(mode)
            )
        return scaler.fit_transform(targets.reshape(-1, 1)).reshape(-1)

    # history[0]: GP inputs of every query so far, history[1]: observed values.
    history = [[], []]
    dv_result = []
    dv_result_query = []
    query_history = [[] for _ in range(num_dp)]
    idx_history = []
    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    local_folder = "./GPGAPE_GT/"
    target_folder = "./GPGAPE_GT/"
    # Check if the folder already exists
    if not os.path.exists(local_folder):
        # Create the folder
        os.makedirs(local_folder)
        print("Folder created.")
    else:
        print("Folder already exists.")
    save_name = local_folder + "data_shapley_mnist_numdp{}_nquery{}_noise{}_lengthscale{}_m{}_beta{}{}_{}_{}_{}_{}_{}.pkl".format(cmd_args['num_dp'], cmd_args['num_query'], cmd_args['noise'], cmd_args['length_scale'], cmd_args['m'], cmd_args['beta'], cmd_args['useorder'], cmd_args['metric'], cmd_args['model'], cmd_args['exp'], cmd_args['normalizey'], time_str)

    # initialization phase: query randomly chosen points (with replacement)
    init_idxs = np.random.choice(num_dp, cmd_args['init_query'])
    for idx_ in init_idxs:
        y = one_data_point_evaluation(idx_)
        history[0].append(selected_x[idx_])
        history[1].append(y)
        query_history[idx_].append(y)
        idx_history.append(idx_)

    y_ = _normalize_targets(history[1], cmd_args['normalizey'])
    gp_model, gp_likelihood = train_gp(history[0], y_, noise=cmd_args['noise'], length_scale=cmd_args['length_scale'])

    for query in tqdm(range(num_query)):
        # Exploration weight: grows with the query count for 'FS' / 'SE',
        # otherwise the fixed --beta value is used.
        if cmd_args['useorder'] == 'FS':
            beta_ = 1 + np.sqrt(np.log(query + 1))
        elif cmd_args['useorder'] == 'SE':
            beta_ = 1 + np.sqrt(np.log(query + 1) ** 3)
        else:
            beta_ = cmd_args['beta']
        query_idx, y_pred = compute_gap_upperbound_list(num_dp, selected_x, gp_model, gp_likelihood, beta_, cmd_args['m'])
        y = one_data_point_evaluation(query_idx)
        history[0].append(selected_x[query_idx])
        history[1].append(y)
        query_history[query_idx].append(y)
        idx_history.append(query_idx)

        # Rebuild the GP on all observations collected so far.
        y_ = _normalize_targets(history[1], cmd_args['normalizey'])
        gp_model, gp_likelihood = train_gp(history[0], y_, noise=cmd_args['noise'], length_scale=cmd_args['length_scale'])

        # y_pred is the posterior mean computed before this iteration's query.
        dv_result.append(y_pred)
        # Points that were never queried get NaN explicitly; np.mean([]) gives
        # the same value but emits a RuntimeWarning for each of them.
        dv_result_query.append([np.mean(tmp) if tmp else np.nan for tmp in query_history])

    with open(save_name, 'wb') as handle:
        pickle.dump({"selected_idxs":selected_idxs, "dv_result":dv_result, "dv_result_query":dv_result_query, "args": cmd_args, "query_history":query_history}, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Moving the file
    # shutil.move raises shutil.Error when the destination directory already
    # holds a file of that name, which is always the case while the target
    # folder is the folder the file was written to - only move when they differ.
    if os.path.abspath(target_folder) != os.path.abspath(local_folder):
        os.makedirs(target_folder, exist_ok=True)
        shutil.move(save_name, target_folder)
