import numpy as np
import os
import warnings
import random

from sklearn.cluster import k_means

warnings.filterwarnings("ignore")
import sklearn.datasets as datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from scipy import stats
from six.moves import cPickle
import sklearn.metrics.pairwise as Kernel

from scipy.io import loadmat
import pickle


def labels_to_reward_matrix(labels, A):
    """Build a one-hot reward matrix with one row per label.

    :param labels: Indexable sequence of class labels; each entry is converted with ``int``
        and used as the column index of its row.
    :param A: Number of columns (arms) of the returned matrix.
    :return: Array of shape ``(len(labels), A)`` holding 1 at ``[i, int(labels[i])]`` and 0
        everywhere else.
    """
    n_rows = len(labels)
    # The arm count is only reported next to the observed label range; it is not enforced.
    print(A, max(labels) + 1)

    columns = np.fromiter((int(labels[row]) for row in range(n_rows)), dtype=np.intp, count=n_rows)
    one_hot = np.zeros((n_rows, A))
    one_hot[np.arange(n_rows), columns] = 1
    return one_hot


def labels_to_reward_matrix_partial_reward(labels, A, cluster_per_digit, partial_reward_val=0.5):
    """Build a one-hot reward matrix for clustered labels.

    NOTE(review): despite the name, no partial reward is written. ``partial_reward_val`` is
    not used and the digit index derived from ``cluster_per_digit`` is computed but ignored,
    so the result is a plain one-hot encoding of ``labels``.

    :param labels: Indexable sequence of class labels; ``max(labels) + 1`` must equal ``A``.
    :param A: Number of columns (arms) of the returned matrix.
    :param cluster_per_digit: Divisor applied to the labels via ``np.divmod``.
    :param partial_reward_val: Currently unused.
    :return: Array of shape ``(len(labels), A)`` holding 1 at ``[i, int(labels[i])]`` and 0
        everywhere else.
    """
    n_rows = len(labels)
    assert A == (max(labels) + 1)

    # Quotient of each label by the cluster count; kept for parity although it is not used.
    digit_ids, _remainders = np.divmod(labels, cluster_per_digit)

    columns = np.fromiter((int(labels[row]) for row in range(n_rows)), dtype=np.intp, count=n_rows)
    one_hot = np.zeros((n_rows, A))
    one_hot[np.arange(n_rows), columns] = 1
    return one_hot


def get_arm_descriptor_from_contexts_samples(A, context_matrix, labels, sample_num=10):
    """Describe every arm by the mean of its first ``sample_num`` contexts.

    :param A: Number of arms; arms are the integers ``0 .. A - 1``.
    :param context_matrix: 2-D array with one context per row.
    :param labels: Array giving the arm label of every row of ``context_matrix``.
    :param sample_num: Maximum number of contexts (in row order) averaged per arm.
    :return: Array of shape ``(A, context_matrix.shape[1])``; row ``a`` is the column-wise
        mean of the selected contexts labelled ``a``.
    """
    n_features = context_matrix.shape[1]
    descriptors = np.zeros([A, n_features])
    for arm in range(A):
        # Row indices of this arm's contexts, truncated to the first sample_num matches.
        arm_rows = np.where(labels == arm)[0][:sample_num]
        descriptors[arm, :] = np.mean(context_matrix[arm_rows, :], axis=0)
    return descriptors


def get_kernel_embedding_similarity_from_contexts_samples(A, context_matrix, labels, bandwidth, sample_num=350):
    """Estimate pairwise arm similarity from RBF kernel mean embeddings of the arms' contexts.

    For every arm, up to ``sample_num`` contexts carrying that arm's label are selected in
    row order. The mean RBF kernel value between the samples of two arms is the inner
    product of their empirical kernel mean embeddings. From these inner products the squared
    embedding distance ``K[i, i] + K[j, j] - 2 * K[i, j]`` is formed and mapped to a
    similarity ``exp(-bw_simi * distance)``.

    Arms with fewer than ``sample_num`` contexts are handled with the samples they have
    (the previous fixed-size reshape raised in that case).

    :param A: Number of arms; arms are the integers ``0 .. A - 1``.
    :param context_matrix: 2-D array with one context per row.
    :param labels: Array giving the arm label of every row of ``context_matrix``.
    :param bandwidth: Unused; the kernel parameters are the constants set in the body.
    :param sample_num: Maximum number of contexts used per arm.
    :return: ``(A, A)`` array of arm similarities.
    :raises ValueError: If an arm in ``0 .. A - 1`` has no context in ``labels``.
    """
    bw_emb = 1  # gamma of the RBF kernel used for the embeddings
    bw_simi = 10  # scale applied to the squared embedding distance

    # Select every arm's samples once instead of once per arm pair.
    arm_samples = []
    for arm in range(A):
        rows = np.where(labels == arm)[0][:sample_num]
        if rows.size == 0:
            raise ValueError("No context sample found for arm {}".format(arm))
        arm_samples.append(context_matrix[rows, :])

    kernel_embedding_m = np.zeros([A, A])
    for a_i in range(A):
        for a_j in range(A):
            kernel_embedding_m[a_i, a_j] = np.mean(
                Kernel.rbf_kernel(arm_samples[a_i], arm_samples[a_j], gamma=bw_emb))

    # Squared distance between embeddings, then an exponential similarity on top of it.
    self_similarity = np.diag(kernel_embedding_m)
    squared_distance = self_similarity[:, None] + self_similarity[None, :] - 2 * kernel_embedding_m
    arm_similarity_m = np.exp(-squared_distance * bw_simi)

    print("Similarity BW: ", bw_emb, bw_simi)
    print("Similarity matrix: ", arm_similarity_m)

    return arm_similarity_m


# ============================================================ Generating training data
# Key suffixes (one per bandit algorithm) under which each arm's initial data is stored.
_TRAIN_KEY_SUFFIXES = ('NEWKTLEstUCB', 'KTLUCB', 'KTLEstUCB', 'LinUCB', 'PoolUCB')


def _store_arm_training_data(DataXY, aa, features, rewards):
    """Store independent copies of one arm's initial features and rewards for every algorithm."""
    arm_id = str(int(aa))
    for suffix in _TRAIN_KEY_SUFFIXES:
        DataXY['Train_Datasets_' + suffix + arm_id] = np.copy(features)
        DataXY['Train_Labels_' + suffix + arm_id] = np.copy(rewards)


def _collect_mnist_data(A, N, noise_portion):
    """Load the MNIST context/label files and build the data dictionary for ``data_flag == 77``."""
    cluster_per_digit = 1
    print("Cluster per digit: ", cluster_per_digit)

    Features = np.load('./Dataset/Mnist_context_matrix_{}.npy'.format(noise_portion))
    Labels = np.load('./Dataset/Mnist_label_matrix_{}.npy'.format(noise_portion))
    Clean_Labels = np.load('./Dataset/Mnist_label_matrix_{}_clean.npy'.format(noise_portion))

    print("Overall: ", np.unique(Labels, return_counts=True))

    # The first A * N rows initialise the arms; all remaining rows form the test stream.
    n_init = A * N
    idx = np.arange(n_init)
    Features_train = Features[idx, :]
    Labels_train = Labels[idx]
    Labels_train_clean = Clean_Labels[idx]

    Features_test = Features[n_init:, :]
    Labels_test = Labels[n_init:]
    Labels_test_clean = Clean_Labels[n_init:]

    # One-hot reward matrices for the noisy and the clean labels.
    Labels_train_matrix = labels_to_reward_matrix(Labels_train, A)
    Labels_test_matrix = labels_to_reward_matrix(Labels_test, A)
    Labels_train_matrix_clean = labels_to_reward_matrix(Labels_train_clean, A)
    Labels_test_matrix_clean = labels_to_reward_matrix(Labels_test_clean, A)

    DataXY = dict()
    for aa in range(A):
        # Arm aa is initialised with the aa-th consecutive chunk of N rows; its reward is 1
        # exactly where the row's label equals the arm index.
        rows = slice(aa * N, (aa + 1) * N)
        XTrain = Features_train[rows, :]
        LabelsTrain = Labels_train[rows]
        YTrain = np.zeros([N])
        YTrain[LabelsTrain == aa] = 1
        print(LabelsTrain, aa, YTrain)
        _store_arm_training_data(DataXY, aa, XTrain, YTrain)

    DataXY['Testfeatures'] = np.copy(Features_test)
    DataXY['theta'] = 0
    DataXY['armTest'] = np.copy(Labels_test)
    DataXY['d'] = Features_test.shape[1]
    print("Shape: ", Features_train.shape)

    DataXY['TrainContexts'] = Features_train
    DataXY['TestContexts'] = Features_test
    DataXY['TrainLabels'] = Labels_train_matrix
    DataXY['TestLabels'] = Labels_test_matrix
    DataXY['TrainLabels_clean'] = Labels_train_matrix_clean
    DataXY['TestLabels_clean'] = Labels_test_matrix_clean

    DataXY['NoOfArms'] = A
    return DataXY


def _collect_recommendation_data(A, N, items_per_step, file_prefix, noise_portion, noise_intensity,
                                 with_similarity):
    """Load a recommendation dataset (MovieLens / Amazon file layout) and build its data dictionary."""
    context_matrix = np.load(
        (file_prefix + '_context_matrix_{}_{}.npy').format(noise_portion, noise_intensity))
    reward_matrix = np.load(
        (file_prefix + '_reward_matrix_{}_{}.npy').format(noise_portion, noise_intensity))
    reward_matrix_clean = np.load(
        (file_prefix + '_reward_matrix_{}_{}_clean.npy').format(noise_portion, noise_intensity))
    with open(file_prefix + '_category_dict.pickle', 'rb') as pk_file:
        # NOTE: unpickling can execute arbitrary code; only load trusted dataset files.
        category_dict = pickle.load(pk_file)
    similarity_m = np.load(file_prefix + '_category_simi_matrix.npy') if with_similarity else None

    DataXY = dict()
    for aa in range(A):
        # Each arm starts with a single sample: row aa, item aa % 10 of the context matrix
        # (the item index is hard-coded and does not depend on items_per_step).
        XTrain = normalize(context_matrix[aa, aa % 10, :].reshape(1, -1), axis=1)
        YTrain = np.copy(reward_matrix[aa, aa % 10]).reshape(1)
        _store_arm_training_data(DataXY, aa, XTrain, YTrain)

    # Split the (row, item) -> category mapping into the first A rows and the rest, re-basing
    # the remaining rows to start at 0.
    # NOTE(review): this split happens at row A while the matrices below are split at row
    # A * N; the two agree only when N == 1.
    init_category_dict = {}
    shifted_category_dict = {}
    for i in range(A):
        for j in range(items_per_step):
            init_category_dict[(i, j)] = category_dict.pop((i, j))
    for i in range(A, context_matrix.shape[0]):
        for j in range(items_per_step):
            shifted_category_dict[(i - A, j)] = category_dict.pop((i, j))

    n_init = A * N
    DataXY['context_matrix'] = context_matrix[n_init:, :items_per_step, :]
    DataXY['reward_matrix'] = reward_matrix[n_init:, :items_per_step]
    DataXY['reward_matrix_clean'] = reward_matrix_clean[n_init:, :items_per_step]

    DataXY['initContext'] = context_matrix[:n_init, :items_per_step, :]
    DataXY['init_reward_matrix'] = reward_matrix[:n_init, :items_per_step]
    DataXY['init_reward_matrix_clean'] = reward_matrix_clean[:n_init, :items_per_step]

    DataXY['theta'] = 0
    DataXY['init_Category_Dict'] = init_category_dict
    DataXY['Category_Dict'] = shifted_category_dict
    if with_similarity:
        DataXY['genre_similarity'] = similarity_m
    DataXY['d'] = int(context_matrix.shape[2])
    DataXY['item_pool_size'] = items_per_step
    print("d: ", DataXY['d'])
    DataXY['NoOfArms'] = A
    return DataXY


def TrainDataCollect(data_flag, A, N_valid, N, T, RandomSeedNumber, RunNumber, Main_Program_flag,
                     noise_portion, noise_intensity,
                     items_per_step=10,
                     READ_PREVIOUS=True, MULTI_RUNS=False, bandwidth=None):
    """Load a dataset from disk and build the data dictionary used by the bandit algorithms.

    For every arm ``a`` the dictionary holds the initial samples under
    ``'Train_Datasets_<ALG><a>'`` / ``'Train_Labels_<ALG><a>'`` for each algorithm suffix
    (NEWKTLEstUCB, KTLUCB, KTLEstUCB, LinUCB, PoolUCB), plus dataset-specific test data,
    ``'d'`` (context dimension), ``'NoOfArms'`` and ``'data_flag'``.

    :param data_flag: Dataset selector: 77 (MNIST files under ``./Dataset``), 12 (MovieLens
        files under ``./Matrices_Category``) or 13 (Amazon files under
        ``./Matrices_Category/Amazon``).
    :param A: Number of arms.
    :param N_valid: Unused in this function.
    :param N: Number of initial samples per arm; the first ``A * N`` rows are set aside for
        initialisation.
    :param T: Unused in this function.
    :param RandomSeedNumber: Unused in this function.
    :param RunNumber: Unused in this function.
    :param Main_Program_flag: Unused in this function.
    :param noise_portion: Value inserted into the data file names.
    :param noise_intensity: Value inserted into the data file names (MovieLens / Amazon only).
    :param items_per_step: Number of leading items kept per row (MovieLens / Amazon only).
    :param READ_PREVIOUS: Unused in this function.
    :param MULTI_RUNS: Unused in this function.
    :param bandwidth: Unused in this function.
    :return: The populated data dictionary.
    :raises ValueError: If ``data_flag`` is not one of the supported values.
    """
    if data_flag == 77:
        DataXY = _collect_mnist_data(A, N, noise_portion)
    elif data_flag == 12:
        DataXY = _collect_recommendation_data(
            A, N, items_per_step, './Matrices_Category/MovieLens',
            noise_portion, noise_intensity, with_similarity=True)
    elif data_flag == 13:
        DataXY = _collect_recommendation_data(
            A, N, items_per_step, './Matrices_Category/Amazon/Amazon',
            noise_portion, noise_intensity, with_similarity=False)
    else:
        # Previously an unsupported flag surfaced as an UnboundLocalError on DataXY.
        raise ValueError(
            "Unsupported data_flag {!r}; expected 77 (MNIST), 12 (MovieLens) or 13 (Amazon)".format(
                data_flag))

    DataXY['data_flag'] = data_flag
    return DataXY


# ===========================================================

def AllDataCollect(DataXY, algorithm_flag):
    """Stack the training data of all arms for one algorithm.

    :param DataXY: Data dictionary holding ``'NoOfArms'`` and, for every arm ``i``, the
        entries ``'Train_Datasets_<ALG><i>'`` (2-D features) and ``'Train_Labels_<ALG><i>'``
        (1-D rewards), where ``<ALG>`` is the key suffix of ``algorithm_flag``.
    :param algorithm_flag: One of ``'KTL-UCB-TaskSim'``, ``'KTL-UCB-TaskSimEst'``,
        ``'NEW-KTL-UCB-TaskSimEst'``, ``'Lin-UCB-Ind'`` or ``'Lin-UCB-Pool'``.
    :return: Tuple ``(total_samples, samples_per_task, y, X_total)``: the total number of
        samples, a float array with the sample count of every arm, all rewards concatenated
        in arm order, and all feature rows concatenated in arm order.
    :raises ValueError: If ``algorithm_flag`` is unknown or ``DataXY['NoOfArms']`` is below 1.
    """
    key_suffix_by_algorithm = {
        'KTL-UCB-TaskSim': 'KTLUCB',
        'KTL-UCB-TaskSimEst': 'KTLEstUCB',
        'NEW-KTL-UCB-TaskSimEst': 'NEWKTLEstUCB',
        'Lin-UCB-Ind': 'LinUCB',
        'Lin-UCB-Pool': 'PoolUCB',
    }
    if algorithm_flag not in key_suffix_by_algorithm:
        raise ValueError("Unknown algorithm_flag {!r}; expected one of {}".format(
            algorithm_flag, sorted(key_suffix_by_algorithm)))
    suffix = key_suffix_by_algorithm[algorithm_flag]

    A = DataXY['NoOfArms']
    if A < 1:
        # Without any arm the feature dimension of X_total cannot be determined.
        raise ValueError("DataXY['NoOfArms'] must be at least 1, got {!r}".format(A))

    arm_features = [np.copy(DataXY['Train_Datasets_' + suffix + str(i)]) for i in range(A)]
    arm_labels = [np.copy(DataXY['Train_Labels_' + suffix + str(i)]) for i in range(A)]

    # Per-arm sample counts (kept as a float array of shape (A,)) and their total.
    samples_per_task = np.zeros(A)
    for i, features in enumerate(arm_features):
        samples_per_task[i] = features.shape[0]
    total_samples = sum(features.shape[0] for features in arm_features)

    # Fill the stacked reward vector and feature matrix arm by arm.
    y = np.zeros(total_samples)
    X_total = np.zeros([total_samples, arm_features[-1].shape[1]])
    rr = 0
    for features, labels in zip(arm_features, arm_labels):
        end = rr + labels.shape[0]
        y[rr:end] = labels
        X_total[rr:end, :] = features
        rr = end

    return total_samples, samples_per_task, y, X_total


def AddData(DataXY, arm_tt, algorithm_flag, X_test, reward_test, tt):
    """Append one observed (context, reward) pair to an arm's training data.

    ``DataXY`` is updated in place: the context and reward are appended to the selected
    arm's training set, the context is stored as the algorithm's latest test context, and
    the arm index is appended to the algorithm's history of selected arms (the history is
    restarted when ``tt == 0``).

    :param DataXY: Data dictionary holding the arm's ``'Train_Datasets_<ALG><arm>'`` and
        ``'Train_Labels_<ALG><arm>'`` entries.
    :param arm_tt: Index of the arm that was selected.
    :param algorithm_flag: One of ``'KTL-UCB-TaskSim'``, ``'KTL-UCB-TaskSimEst'``,
        ``'NEW-KTL-UCB-TaskSimEst'``, ``'Lin-UCB-Ind'`` or ``'Lin-UCB-Pool'``.
    :param X_test: Context row(s) to append; must be stackable below the arm's feature
        matrix along axis 0.
    :param reward_test: Observed reward for the context.
    :param tt: Round index; 0 starts a new selected-arm history.
    :return: The same ``DataXY`` object, after the update.
    :raises ValueError: If ``algorithm_flag`` is unknown (raised before anything is modified).
    """
    key_suffix_by_algorithm = {
        'KTL-UCB-TaskSim': 'KTLUCB',
        'KTL-UCB-TaskSimEst': 'KTLEstUCB',
        'NEW-KTL-UCB-TaskSimEst': 'NEWKTLEstUCB',
        'Lin-UCB-Ind': 'LinUCB',
        'Lin-UCB-Pool': 'PoolUCB',
    }
    if algorithm_flag not in key_suffix_by_algorithm:
        raise ValueError("Unknown algorithm_flag {!r}; expected one of {}".format(
            algorithm_flag, sorted(key_suffix_by_algorithm)))
    suffix = key_suffix_by_algorithm[algorithm_flag]

    train_dataset = 'Train_Datasets_' + suffix + str(arm_tt)
    train_labels = 'Train_Labels_' + suffix + str(arm_tt)
    test_label = 'Test_Labels_' + suffix
    last_roundXTest = 'Test_Datasets_' + suffix

    # Extend the arm's training set with the new observation.
    DataXY[train_dataset] = np.append(np.copy(DataXY[train_dataset]), X_test, axis=0)
    DataXY[train_labels] = np.append(np.copy(DataXY[train_labels]), np.ones([1]) * reward_test, axis=0)
    DataXY[last_roundXTest] = np.copy(X_test)

    # Record which arm was played this round.
    selected_now = np.ones([1]) * arm_tt
    if tt == 0:
        armSelectedTT = selected_now
    else:
        armSelectedTT = np.append(np.copy(DataXY[test_label]), selected_now, axis=0)
    DataXY[test_label] = armSelectedTT.astype(int)

    return DataXY
