import numpy as np
import scipy as sp
import torch
import torch.nn as nn
import torch.optim as optim
import DataCreate.DataCreate as DC
import argparse
import copy
import time
from datetime import datetime
import sys
from sklearn.preprocessing import normalize
import random


# Logger
# Recording console output
class Logger(object):
    """Tee-like stream that mirrors everything written to it into a log file.

    An instance is meant to be assigned to ``sys.stdout``: each ``write`` goes
    both to a timestamped file under ``./R-Neural-UCB-logs/`` and to the
    stream that was ``sys.stdout`` when the logger was created.
    """

    def __init__(self, stdout, dataset, portion, intensity):
        """Open the log file and remember the console stream.

        Args:
            stdout: Stream handed in by the caller; stored as ``self.out``.
            dataset: Dataset label embedded in the log file name.
            portion: Noise portion embedded in the log file name.
            intensity: Noise intensity embedded in the log file name.
        """
        import os  # local import: only needed here, to create the log directory

        now = datetime.now()
        dt_string = now.strftime("%d-%m-%Y_%H-%M-%S")
        self.terminal = sys.stdout
        # open(..., "w") does not create missing parent directories, so make
        # sure the log directory exists before opening the file.
        os.makedirs("./R-Neural-UCB-logs", exist_ok=True)
        self.log = open("./R-Neural-UCB-logs/logfile_" + dt_string +
                        "_R-Neural-UCB-logs_{}_{}_{}.log".format(dataset, portion, intensity), "w")
        self.out = stdout
        print("date and time =", dt_string)

    def write(self, message):
        """Write ``message`` to the log file (flushed immediately) and to the console."""
        self.log.write(message)
        self.log.flush()
        self.terminal.write(message)

    def flush(self):
        """Flush both sinks so callers such as ``print(..., flush=True)`` take effect."""
        self.log.flush()
        self.terminal.flush()


def set_seed(seed):
    """Seed every random number generator this script draws from.

    Seeds Python's ``random`` module (used for category sampling), NumPy and
    PyTorch (CPU and, when available, all CUDA devices), and switches cuDNN
    to its deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    # The stdlib generator must be seeded too: it is independent of NumPy's.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class Network(nn.Module):
    """Fully connected network with one ReLU hidden layer and a scalar output."""

    def __init__(self, dim, hidden_size=100):
        """Build the two linear layers.

        Args:
            dim: Size of the input vector.
            hidden_size: Width of the hidden layer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.input_dim = dim

        self.fc1 = nn.Linear(dim, hidden_size)
        self.activate = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)

    def init_param_paper(self):
        """Replace both weight matrices with a mirrored Gaussian initialisation.

        The first layer becomes block-diagonal with the same random block in
        the top-left and bottom-right corners; the output layer is a random
        half-vector followed by its negation. Biases are left untouched.
        Both ``hidden_size`` and ``input_dim`` need to be even for the block
        shapes to line up.
        """
        half_hidden = self.hidden_size // 2
        half_input = self.input_dim // 2

        # NOTE(review): np.random.normal's `scale` is a standard deviation;
        # confirm 4/hidden_size (and 2/hidden_size below) are not meant as variances.
        shared_block = np.random.normal(loc=0, scale=(4 / self.hidden_size),
                                        size=(half_hidden, half_input))
        first_layer = np.zeros((self.hidden_size, self.input_dim))
        first_layer[:half_hidden, :half_input] = shared_block
        first_layer[half_hidden:, half_input:] = shared_block
        self.fc1.weight = nn.Parameter(torch.from_numpy(first_layer).float())

        output_half = np.random.normal(loc=0, scale=(2 / self.hidden_size), size=(1, half_hidden))
        # Second half mirrors the first with opposite sign.
        second_layer = np.concatenate((output_half, -1 * output_half), axis=1)
        self.fc2.weight = nn.Parameter(torch.from_numpy(second_layer).float())

    def forward(self, x):
        """Return the scalar prediction for ``x`` (linear -> ReLU -> linear)."""
        hidden = self.fc1(x)
        hidden = self.activate(hidden)
        return self.fc2(hidden)


class NeuralUCBDiag:
    """Neural UCB bandit learner using a diagonal gradient-outer-product matrix.

    The learner keeps a history of (context, reward) pairs, a scalar-output
    network ``func`` and three vectors with one entry per network parameter:

    * ``U``        - ``lamdba`` plus the summed squared gradients of chosen arms,
    * ``raw_U``    - the summed squared gradients alone,
    * ``lambda_m`` - the constant ``lamdba`` vector.

    For every candidate arm ``select`` briefly retrains the network with an
    arm-specific weight decay, scores the arm as prediction plus confidence
    width, and then restores the shared network parameters.
    """

    def __init__(self, dim, lamdba=1, nu=1, hidden=100, sample_num=1000, base_lr=1e-2, device=None):
        """Create the network and the per-parameter statistics.

        Args:
            dim: Input dimension of the network.
            lamdba: Regularisation strength (initial value of ``U`` and the
                base weight decay used for training).
            nu: Multiplier applied inside the confidence width.
            hidden: Hidden-layer width of the network.
            sample_num: Maximum number of SGD steps per ``train`` call.
            base_lr: SGD learning rate.
            device: Torch device (or device string) to run on. ``None`` picks
                the current CUDA device when CUDA is available, else the CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.func = Network(dim, hidden_size=hidden).to(self.device)
        self.input_dim = dim
        self.context_list = []
        self.reward = []
        self.lr_list = []
        self.lamdba = lamdba
        self.base_lr = base_lr
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)
        print("Total param num: ", self.total_param)
        self.U = lamdba * torch.ones((self.total_param,), device=self.device)
        self.lambda_m = lamdba * torch.ones((self.total_param,), device=self.device)
        self.raw_U = torch.zeros((self.total_param,), device=self.device)
        self.nu = nu
        self.sample_num = sample_num

    def init_context_list(self, init_context, init_rewards, A, long_vector_flag=False):
        """Seed the history with one (context, reward) pair per arm.

        Pair ``i`` is taken from ``init_context[i, i, :]`` and
        ``init_rewards[i, i]``.

        Args:
            init_context: Array indexed as ``[sample, arm, feature]``.
            init_rewards: Array indexed as ``[sample, arm]``.
            A: Number of arms (and of initial samples consumed).
            long_vector_flag: When true, embed each context into a zero vector
                of length ``self.input_dim`` at the slot belonging to arm ``i``.
        """
        for i in range(A):
            context, reward = init_context[i, i, :], init_rewards[i, i]
            if long_vector_flag:
                init_dim = self.input_dim // A
                long_context = np.zeros((1, self.input_dim))
                long_context[0, i * init_dim:(i + 1) * init_dim] = context
                context = long_context
            self.context_list.append(torch.from_numpy(context.reshape(1, -1)).float())
            self.reward.append(reward)
            self.lr_list.append(0.1)

    def init_kernel_matrix(self, A):
        """Fold the first ``A`` stored contexts into ``U`` and train after each.

        For each context the squared network gradient is added to ``U`` and
        one shuffled pass over the history (capped at ``sample_num`` steps) is
        run with weight decay ``lamdba``.

        Args:
            A: Number of stored contexts to process.
        """
        print("Initializing kernel matrix...")
        for c_i in range(A):
            # Update kernel matrix with initial contexts
            tensor = self.context_list[c_i].to(self.device)
            tensor = tensor.squeeze(dim=0)
            mu = self.func(tensor)

            self.func.zero_grad()
            mu.backward(retain_graph=True)

            g = torch.cat([p.grad.flatten().detach() for p in self.func.parameters()])
            self.U += g * g

            optimizer = optim.SGD(self.func.parameters(), lr=self.base_lr, weight_decay=self.lamdba)
            length = len(self.reward)
            index = np.arange(length)
            np.random.shuffle(index)
            cnt = 0
            tot_loss = 0

            for idx in index:
                c = self.context_list[idx]
                r = self.reward[idx]
                optimizer.zero_grad()
                delta = self.func(c.to(self.device)) - r
                loss = delta * delta
                loss.backward()
                optimizer.step()
                tot_loss += loss.item()
                cnt += 1
                if cnt >= self.sample_num:
                    print("Avg loss: ", tot_loss / self.sample_num)
                    break

    # --------------------------------------------------------------------------------------------------------

    def select(self, contexts, t, arm_exploitation_coef=1, weighted_lr_alpha=0.5):
        """Pick the arm with the highest upper-confidence score.

        Args:
            contexts: 2-D NumPy array with one row per candidate arm.
            t: Current round index; only used to print diagnostics every
                100 rounds.
            arm_exploitation_coef: Multiplier on the network prediction.
            weighted_lr_alpha: Forwarded as ``alpha_coef`` to
                ``derive_arm_gradients_and_weights``.

        Returns:
            Tuple ``(arm, grad_norm)``: the chosen row index and the norm of
            that arm's gradient vector.
        """
        # Snapshot of the shared parameters; each arm's temporary training
        # starts from, and is rolled back to, this state. A state_dict copy is
        # enough because a fresh SGD optimizer is created on every train().
        snapshot = copy.deepcopy(self.func.state_dict())
        tensor = torch.from_numpy(contexts).float().to(self.device)
        g_list, weight_array = self.derive_arm_gradients_and_weights(contexts, alpha_coef=weighted_lr_alpha)

        sampled = []
        for i in range(contexts.shape[0]):
            this_tensor = tensor[i, :]
            g = g_list[i]
            # Plain float keeps the tensor arithmetic below independent of NumPy scalars.
            this_weight = float(weight_array[i])
            self.train(weight=this_weight, new_sample_num=100)
            with torch.no_grad():
                fx = self.func(this_tensor)

            # CB square
            sigma2 = self.lamdba * self.nu * g * g / (self.lambda_m + (this_weight * self.raw_U))
            sigma = torch.sqrt(torch.sum(sigma2))

            sample_r = (arm_exploitation_coef * fx.item()) + sigma.item()

            if (t + 1) % 100 == 0:
                print("Arm: {}, weight: {}, f_hat: {}, UCB: {}".format(i, this_weight, fx.item(), sigma.item()))

            sampled.append(sample_r)

            # Undo the arm-specific training before evaluating the next arm.
            self.func.load_state_dict(snapshot)
        arm = np.argmax(sampled)
        self.U += (g_list[arm] * g_list[arm])
        self.raw_U += (g_list[arm] * g_list[arm])

        return arm, g_list[arm].norm().item()

    def derive_arm_gradients_and_weights(self, contexts, alpha_coef=0.5):
        """Compute each arm's parameter gradient and its training weight.

        The raw weight of an arm is ``sqrt(sum(lamdba * g^2 / U))``; the
        returned weight is ``min(1, alpha_coef * min(raw) / raw)``.

        Args:
            contexts: 2-D NumPy array with one row per candidate arm.
            alpha_coef: Scale applied to the ratio of raw weights.

        Returns:
            Tuple ``(g_list, weight_array)``: a list of flattened gradient
            tensors (one per arm) and a NumPy array of weights.
        """
        tensor = torch.from_numpy(contexts).float().to(self.device)
        mu = self.func(tensor)
        g_list = []
        weight_list = []
        for fx in mu:
            self.func.zero_grad()
            # The graph is shared by all rows of `mu`, so it must be retained.
            fx.backward(retain_graph=True)

            g = torch.cat([p.grad.flatten().detach() for p in self.func.parameters()])
            g_list.append(g)

            weight2 = self.lamdba * g * g / self.U
            this_weight = torch.sqrt(torch.sum(weight2)).item()
            weight_list.append(this_weight)

        weight_array = np.array(weight_list)
        weight_array = np.minimum(np.ones(weight_array.shape),
                                  alpha_coef * (np.min(weight_array) / weight_array))

        return g_list, weight_array

    # --------------------------------------------------------------------------------------------------------

    def save_info(self, context, reward):
        """Append one observed (context, reward) pair to the history."""
        self.context_list.append(torch.from_numpy(context.reshape(1, -1)).float())
        self.reward.append(reward)

    def train(self, weight=None, new_sample_num=None):
        """Run per-sample SGD on the squared error over the stored history.

        Passes over the shuffled history repeat until either the step budget
        is exhausted or a full pass reaches a mean loss of at most ``1e-3``.

        Args:
            weight: Optional arm weight; weight decay becomes
                ``lamdba / weight`` instead of ``lamdba``.
            new_sample_num: Optional step budget overriding ``self.sample_num``.

        Returns:
            Mean loss over the budgeted steps, or the mean loss of the last
            full pass when it converged first. ``0.0`` when the history is
            empty.
        """
        decay = self.lamdba if weight is None else self.lamdba / weight
        optimizer = optim.SGD(self.func.parameters(), lr=self.base_lr, weight_decay=decay)
        sample_num = self.sample_num if new_sample_num is None else new_sample_num

        length = len(self.reward)
        if length == 0:
            # Nothing to fit yet; the pass-average below would divide by zero.
            return 0.0
        index = np.arange(length)
        np.random.shuffle(index)
        cnt = 0
        tot_loss = 0

        while True:
            batch_loss = 0
            for idx in index:
                c = self.context_list[idx]
                r = self.reward[idx]

                optimizer.zero_grad()
                delta = self.func(c.to(self.device)) - r
                loss = delta * delta
                loss.backward()
                optimizer.step()
                batch_loss += loss.item()
                tot_loss += loss.item()
                cnt += 1
                if cnt >= sample_num:
                    return tot_loss / sample_num

            if batch_loss / length <= 1e-3:
                return batch_loss / length


# Output -> X:(A, A*dim)
def generate_vec(t, context, num_dim, num_arm):
    """Spread the arm contexts of round ``t`` over disjoint feature slots.

    Row ``a`` of the result is zero everywhere except columns
    ``[a * num_dim, (a + 1) * num_dim)``, which hold ``context[t, a, :]``.

    Args:
        t: Round index into the first axis of ``context``.
        context: Array indexed as ``[round, arm, feature]``.
        num_dim: Number of features per arm.
        num_arm: Number of arms.

    Returns:
        Array of shape ``(num_arm, num_arm * num_dim)``.
    """
    stacked = np.zeros((num_arm, num_arm * num_dim))
    # Each `row` is a view into `stacked`, so slice assignment fills it in place.
    for arm, row in enumerate(stacked):
        start = arm * num_dim
        row[start:start + num_dim] = context[t, arm, :]

    return stacked


# Generate long vector for contexts
def generate_long_vec_category(category_list, context, init_dim, num_arm):
    """Embed one context into a long vector once per listed category.

    Row ``i`` of the result is zero except for the slot
    ``[c * init_dim, (c + 1) * init_dim)`` with ``c = category_list[i]``
    (categories are 0-based), which receives ``context``.

    Args:
        category_list: Category indices, one output row each.
        context: Feature vector written into every selected slot.
        init_dim: Length of a single slot.
        num_arm: Number of slots per row.

    Returns:
        Array of shape ``(len(category_list), init_dim * num_arm)``.
    """
    long_rows = np.zeros([len(category_list), init_dim * num_arm])
    # Rows are views of `long_rows`; writing to a row updates the matrix.
    for row, arm_index in zip(long_rows, category_list):
        start = arm_index * init_dim
        row[start:start + init_dim] = context

    return long_rows


if __name__ == '__main__':
    """
    Parameters:
    Offset: 100 + 0.1 + extend + not double
    """
    # Experiment driver: parse the command line, load one dataset through
    # DataCreate, build normalised context tensors, then run the bandit for
    # T rounds while tracking cumulative regret on the clean rewards.

    parser = argparse.ArgumentParser(description='R-NeuralUCB')
    # nu value: 0.01 / 0.001 / 0.0001
    parser.add_argument('--nu', type=float, default=0.01, metavar='v', help='nu for control variance')
    parser.add_argument('--lamdba', type=float, default=0.001, metavar='l', help='lambda for regularization')
    parser.add_argument('--device_index', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--data_flag', type=int, default=12)
    # hidden size:
    parser.add_argument('--hidden', type=int, default=200, help='network hidden size')
    # Arm Exploitation:
    parser.add_argument('--arm_exploitation_coef', type=float, default=1)
    #
    parser.add_argument('--weighted_lr_alpha', type=float, default=1)
    parser.add_argument('--base_lr', type=float, default=1e-2)
    #
    parser.add_argument('--noise_portion', type=float, default=0.2)
    parser.add_argument('--noise_intensity', type=float, default=0.3)
    # Number of arms requested from the data loader. Every dataset flag named
    # below uses 10; exposing it keeps `A` defined for the remaining flags.
    parser.add_argument('--num_arms', type=int, default=10, help='number of arms passed to the data loader')

    args = parser.parse_args()

    data_flag = args.data_flag
    # Display names of the flags that have one; any other flag is logged as None.
    dataset_names = {77: 'MNIST', 12: 'MovieLens', 13: 'Amazon'}
    data_flag_multiclass = dataset_names.get(data_flag)
    A = args.num_arms

    # Dataset families that share a preprocessing path further down.
    classification_flags = (5, 7, 77)
    category_flags = (10, 12, 13)

    # From here on everything printed is also written to the log file.
    sys.stdout = Logger(sys.stdout, data_flag_multiclass, args.noise_portion, args.noise_intensity)
    print(args)
    set_seed(seed=args.seed)

    # =======================
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_index)

    # --------------------------------
    items_per_step = 10

    #
    multi_estimators = False
    sample_num = 1000

    # Dataset
    N, N_valid = 1, 10
    T = 10000

    print("Data set: ", data_flag_multiclass)
    print(multi_estimators, sample_num, items_per_step)

    randomSeedsTest = np.array([15486101, 15486511, 15486883, 15487271,
                                15486139, 15486517, 15486893, 15487291,
                                15486157, 15486533, 15486907, 15487309,
                                15486173, 15486557, 15486917, 15487313,
                                15486181, 15486571, 15486929, 15487319,
                                15486193, 15486589, 15486931, 15487331,
                                15486209, 15486649, 15486953, 15487361,
                                15486221, 15486671, 15486967, 15487399,
                                15486227, 15486673, 15486997, 15487403,
                                15486241, 15486703, 15487001, 15487429,
                                15486257, 15486707, 15487007, 15487457,
                                15486259, 15486719, 15487019, 15487469])

    RunNumber = 0
    Main_Program_flag = 1

    # Get the train data. This is just one example assigned to each arm randomly when N = 1 (cold start)
    Basic_DataXY = DC.TrainDataCollect(data_flag, A, N_valid, N, T, randomSeedsTest[RunNumber], RunNumber,
                                       Main_Program_flag,
                                       noise_portion=args.noise_portion, noise_intensity=args.noise_intensity,
                                       items_per_step=items_per_step)

    # ==================================================================================================================
    if data_flag in classification_flags:
        # ==============================================================================================
        assert A == Basic_DataXY['NoOfArms']

        # Real-world classification tasks -- MNIST
        Features_train = Basic_DataXY['TrainContexts']
        Features_test = Basic_DataXY['TestContexts']
        Labels_train_matrix = Basic_DataXY['TrainLabels']
        Labels_test_matrix = Basic_DataXY['TestLabels']
        Labels_train_matrix_clean = Basic_DataXY['TrainLabels_clean']
        Labels_test_matrix_clean = Basic_DataXY['TestLabels_clean']
        A = Basic_DataXY['NoOfArms']

        input_dim = Features_train.shape[1]

        # -----------------------------------------------------------------------
        X = np.zeros((T, A, A * input_dim))
        init_X = np.zeros((A * N, A, A * input_dim))

        # (T, A, A * d): arm j sees the normalised sample in its own feature slot
        for i in range(T):
            # The normalised sample does not depend on the arm; compute it once.
            normalized_vec = normalize(Features_test[i, :].reshape(1, -1))
            for j in range(A):
                X[i, j, j * input_dim: (j + 1) * input_dim] = normalized_vec

        # init contexts
        for i in range(A * N):
            normalized_vec = normalize(Features_train[i, :].reshape(1, -1))
            for j in range(A):
                init_X[i, j, j * input_dim: (j + 1) * input_dim] = normalized_vec

        # Reward matrix
        rewards_m = np.copy(Labels_test_matrix)
        init_reward_m = np.copy(Labels_train_matrix)
        #
        rewards_m_clean = np.copy(Labels_test_matrix_clean)
        init_reward_m_clean = np.copy(Labels_train_matrix_clean)

    elif data_flag in category_flags:
        context_matrix = Basic_DataXY['context_matrix']
        init_context_matrix = Basic_DataXY['initContext']
        category_dict = Basic_DataXY['Category_Dict']
        init_category_dict = Basic_DataXY['init_Category_Dict']
        #
        rewards_data_m = np.copy(Basic_DataXY['reward_matrix'])
        init_reward_data_m = np.copy(Basic_DataXY['init_reward_matrix'])
        #
        rewards_data_m_clean = np.copy(Basic_DataXY['reward_matrix_clean'])
        init_reward_data_m_clean = np.copy(Basic_DataXY['init_reward_matrix_clean'])

        # Number of categories
        A = Basic_DataXY['NoOfArms']

        input_dim = context_matrix.shape[2]

        # -----------------------------------------------------------------------
        X = np.zeros((T, items_per_step, input_dim))
        init_X = np.zeros((A * N, items_per_step, input_dim))
        #
        rewards_m = np.zeros((T, items_per_step))
        init_reward_m = np.zeros((A * N, items_per_step))
        #
        rewards_m_clean = np.zeros((T, items_per_step))
        init_reward_m_clean = np.zeros((A * N, items_per_step))

        # (T, items_per_step, d) -----------------------------------------------------------------
        for i in range(T):
            for j in range(items_per_step):
                normalized_vec = normalize(context_matrix[i, j, :].reshape(1, -1))
                #
                X[i, j, :] = normalized_vec
                rewards_m[i, j] = rewards_data_m[i, j]
                rewards_m_clean[i, j] = rewards_data_m_clean[i, j]

        # init contexts
        for i in range(A * N):
            for j in range(items_per_step):
                normalized_vec = normalize(init_context_matrix[i, j, :].reshape(1, -1))
                #
                init_X[i, j, :] = normalized_vec
                init_reward_m[i, j] = init_reward_data_m[i, j]
                init_reward_m_clean[i, j] = init_reward_data_m_clean[i, j]
    else:
        # Other data sets ==============================================================================================
        input_dim = Basic_DataXY['userContext'].shape[1]

        X = np.zeros((T, A, input_dim))
        init_X = np.zeros((A * N, A, input_dim))

        user_matrix = Basic_DataXY['userContext']
        arm_matrix = Basic_DataXY['armContext']

        init_user_matrix = Basic_DataXY['initUserContext']

        # Context of (user i, arm j) is the normalised element-wise product.
        for i in range(T):
            for j in range(A):
                normalized_vec = normalize(np.multiply(user_matrix[i, :], arm_matrix[j, :]).reshape(1, -1))

                X[i, j, :] = normalized_vec

        #
        for i in range(A * N):
            for j in range(A):
                normalized_vec = normalize(np.multiply(init_user_matrix[i, :], arm_matrix[j, :]).reshape(1, -1))

                init_X[i, j, :] = normalized_vec

        #
        rewards_m = np.copy(Basic_DataXY['reward_matrix'])
        init_reward_m = np.copy(Basic_DataXY['init_reward_matrix'])
        #
        rewards_m_clean = np.copy(Basic_DataXY['reward_matrix_clean'])
        init_reward_m_clean = np.copy(Basic_DataXY['init_reward_matrix_clean'])

    # ==================================================================================================================
    algorithm_flag = 'R-Neural-UCB'

    print("nu value: ", args.nu)

    # The classification datasets already use long (A * d) context vectors.
    if data_flag in classification_flags:
        matrix_dim = input_dim * A
    else:
        matrix_dim = input_dim

    #
    n_UCB = NeuralUCBDiag(matrix_dim, args.lamdba, args.nu, hidden=args.hidden, sample_num=sample_num,
                          base_lr=args.base_lr)
    n_UCB.init_context_list(init_context=init_X, init_rewards=init_reward_m, A=A, long_vector_flag=multi_estimators)

    start_time = time.time()
    regrets = []
    summ = 0
    s_count = 0
    for t in range(T):
        # get context and rewards
        if multi_estimators:
            # NOTE(review): this path is disabled above (multi_estimators = False).
            # The long vectors built here have input_dim * A columns while the
            # network is created with matrix_dim inputs - confirm before enabling.
            if data_flag in category_flags:
                context = np.empty([0, input_dim * A])
                rwd_list = []
                clean_rwd_list = []
                for a_i in range(items_per_step):
                    normalized_vec = X[t, a_i, :]
                    category_list = list(Basic_DataXY['Category_Dict'][tuple([t, a_i])])
                    sampled_category = [random.choice(category_list)]
                    this_contexts = generate_long_vec_category(sampled_category, normalized_vec, input_dim, A)

                    # For each category of this arm, add an context and a reward
                    context = np.concatenate([context, this_contexts], axis=0)
                    rwd_list += [float(rewards_m[t, a_i])] * this_contexts.shape[0]
                    clean_rwd_list += [float(rewards_m_clean[t, a_i])] * this_contexts.shape[0]
                rwd = np.array(rwd_list).reshape(-1, )
                # Clean rewards are required below for the regret computation.
                clean_rwd = np.array(clean_rwd_list).reshape(-1, )
            else:
                context, rwd = generate_vec(t=t, context=X, num_dim=input_dim, num_arm=A), rewards_m[t, :].reshape((A,))
                clean_rwd = rewards_m_clean[t, :].reshape((A,))
        else:
            if data_flag in category_flags:
                context, rwd = X[t, :, :].reshape((items_per_step, input_dim)), \
                               rewards_m[t, :].reshape((items_per_step,))
                clean_rwd = rewards_m_clean[t, :].reshape((items_per_step,))
            else:
                context, rwd = X[t, :, :].reshape((A, matrix_dim)), rewards_m[t, :].reshape((A,))
                clean_rwd = rewards_m_clean[t, :].reshape((A,))

        # Select arm
        arm_select, nrm = n_UCB.select(context, t, arm_exploitation_coef=args.arm_exploitation_coef,
                                       weighted_lr_alpha=args.weighted_lr_alpha)
        r = rwd[arm_select]
        best_arm = np.argwhere(rwd == np.max(rwd)).flatten()
        if arm_select in best_arm:
            s_count += 1
        # Regret is measured on the clean rewards; the learner only sees `r`.
        reg = np.max(clean_rwd) - clean_rwd[arm_select]
        summ += reg

        #
        n_UCB.save_info(context[arm_select], r)
        # Train every round for the first 10000 rounds, then every 100th round.
        if t < 10000 or t % 100 == 0:
            loss = n_UCB.train()
        regrets.append(summ)

        if (t + 1) % 100 == 0:
            print('Time {}: regret_sum: {:.3f}, loss: {:.3e}, nrm: {:.3e}'
                  .format(t + 1, summ, loss, nrm))
            print("Algorithm: ", algorithm_flag, ", Step: ", t + 1, "/", T, ", Time elapsed: ",
                  time.time() - start_time)
            print("Selected arm: {}, best arm: {}".format(arm_select, best_arm))
            # t is 0-based, so t + 1 rounds have been played at this point.
            print("Accuracy of ", algorithm_flag + ": ", str(s_count / (t + 1)))
