# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Contextual algorithm based on bootstrapping neural networks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from bandits.core.bandit_algorithm import BanditAlgorithm
from bandits.core.contextual_dataset import ContextualDataset
from bandits.algorithms.neural_bandit_model import NeuralBanditModel

import random


def freq_to_idx(freq, is_shuffle=True):
  """Expands a vector of per-index counts into a flat array of indices.

  Index i appears int(freq[i]) times in the result. Counts are truncated
  towards zero by int(), and non-positive counts contribute nothing.

  Args:
    freq: 1-D sequence (typically a numpy array) of counts, one per index.
    is_shuffle: If True, the expanded indices are shuffled with the `random`
      module; otherwise they come out in ascending order.

  Returns:
    Integer numpy array of indices. It is empty when every count is zero.
  """
  out = []
  for i, count in enumerate(freq):
    out.extend([i] * int(count))
  if is_shuffle:
    random.shuffle(out)
  # Force an integer dtype: np.array([]) defaults to float64, and numpy refuses
  # to use a float array as an index, so an empty result would be unusable.
  return np.array(out, dtype=int)

def idx_to_freq(idx, n_total):
  """Tallies how many times each index occurs in `idx`.

  Args:
    idx: Iterable of integer indices into a vector of length `n_total`.
    n_total: Length of the returned count vector.

  Returns:
    Float numpy array of length `n_total` whose entries are incremented once
    for every occurrence of the corresponding index in `idx`.
  """
  counts = np.zeros(n_total)
  for position in idx:
    counts[position] = counts[position] + 1
  return counts

def extract_resample(data, select_idx):
  """Pulls the rows picked by `select_idx` out of a dataset-like object.

  Args:
    data: Object exposing `contexts` and `rewards` (indexed as 2-D arrays
      here) and `actions` (a sequence with one entry per row).
    select_idx: Integer index array; repeated entries yield repeated rows.

  Returns:
    Tuple (contexts, actions, rewards) holding the selected rows. `actions`
    is returned as a Python list, the other two as array slices.
  """
  picked_contexts = data.contexts[select_idx, :]
  # `actions` is converted to an array first so it can be fancy-indexed.
  picked_actions = list(np.array(data.actions)[select_idx])
  picked_rewards = data.rewards[select_idx, :]
  return picked_contexts, picked_actions, picked_rewards


class BootstrappedBNNSampling(BanditAlgorithm):
  """Thompson Sampling algorithm based on training several neural networks.

  Holds `hparams.q` NeuralBanditModel instances, each with its own
  ContextualDataset. `action` picks one model at random and acts greedily on
  its prediction; `update` buffers new observations and, every
  `resample_freq` steps, pushes a resampled batch of them into each model's
  dataset before the periodic re-training.
  """

  def __init__(self, name, hparams, optimizer='RMS', centroid_learn=False,
               centroid_eval_num=100, gamma=0.5):
    """Creates a BootstrappedBNNSampling object based on a specific optimizer.

    Args:
      name: Name given to the instance.
      hparams: Hyperparameters for each individual model. Fields read by this
        class: training_freq, training_epochs, q (number of models that are
        independently trained), p (stored on the instance but not used by
        the code in this class), context_dim, num_actions, buffer_s,
        initial_pulls and reset_lr.
      optimizer: Neural network optimization algorithm.
      centroid_learn: If True, models are scored on resampled data at every
        resampling step; the scores decide both which data each model
        receives and how likely each model is to be picked in `action`.
        If False, every model simply receives a bootstrap resample and models
        are picked uniformly.
      centroid_eval_num: Number of resampled index sets drawn per resampling
        step when `centroid_learn` is True.
      gamma: Threshold factor; a model only receives its own perturbed batch
        when its normalized weight exceeds gamma / q.
    """

    self.name = name
    self.hparams = hparams
    self.optimizer_n = optimizer

    self.training_freq = hparams.training_freq
    self.training_epochs = hparams.training_epochs
    self.t = 0

    self.q = hparams.q
    self.p = hparams.p

    self.gamma = gamma

    self.centroid_learn = centroid_learn
    # Resampling happens on the same schedule as training. A value of 0
    # disables resampling, because update() checks this for truthiness.
    self.resample_freq = hparams.training_freq
    self.centroid_eval_num = centroid_eval_num
    # Start from uniform weights; update() replaces them with win counts.
    self.centroid_weights = np.ones(self.q)/(1.*self.q)
    # Populated by update() when centroid_learn is on; initialised here so
    # the attributes exist before the first resampling step.
    self.normalized_centroid_weights = np.ones(self.q)/(1.*self.q)
    self.centroid_perturbs = {i: [] for i in range(self.q)}
    # NOTE(review): not referenced anywhere else in this class; kept in case
    # external code relies on it.
    self.centroid_place_holder = ContextualDataset(hparams.context_dim,
                                                   hparams.num_actions,
                                                   hparams.buffer_s)

    # One training dataset per model.
    self.datasets = [
        ContextualDataset(hparams.context_dim,
                          hparams.num_actions,
                          hparams.buffer_s)
        for _ in range(self.q)
    ]

    # Every observation ever seen (only written to in this class).
    self.all_observed = ContextualDataset(hparams.context_dim,
                                          hparams.num_actions,
                                          hparams.buffer_s)

    # Observations collected since the last resampling step.
    self.partial_observed = ContextualDataset(hparams.context_dim,
                                              hparams.num_actions,
                                              hparams.buffer_s)

    self.bnn_boot = [
        NeuralBanditModel(optimizer, hparams, '{}-{}-bnn'.format(name, i))
        for i in range(self.q)
    ]

  def action(self, context):
    """Selects action for context based on Thompson Sampling using one BNN.

    During the first num_actions * initial_pulls steps the actions are played
    round-robin. Afterwards one model is drawn (uniformly, or in proportion
    to `centroid_weights` when `centroid_learn` is True) and the action with
    the highest predicted value under that model is returned.

    Args:
      context: Context vector; reshaped to (1, hparams.context_dim).

    Returns:
      Index of the chosen action.
    """

    if self.t < self.hparams.num_actions * self.hparams.initial_pulls:
      return self.t % self.hparams.num_actions

    if self.centroid_learn:
      elements = list(range(self.q))
      # eps is an additive smoothing term and tau a temperature; with the
      # values below the probabilities are just the normalized weights.
      eps = 0.
      tau = 1.
      # Models that never won have weight 0; log(0) = -inf is intended here
      # (it maps back to probability 0), so silence the divide warning.
      with np.errstate(divide='ignore'):
        probability = np.exp(np.log(self.centroid_weights + eps)/tau)
      probability = probability/np.sum(probability)
      model_index = np.random.choice(elements, 1, p=probability)[0]
    else:
      model_index = np.random.randint(self.q)

    with self.bnn_boot[model_index].graph.as_default():
      c = context.reshape((1, self.hparams.context_dim))
      output = self.bnn_boot[model_index].sess.run(
          self.bnn_boot[model_index].y_pred,
          feed_dict={self.bnn_boot[model_index].x: c})
      return np.argmax(output)

  def update(self, context, action, reward):
    """Buffers the observation; periodically resamples data and re-trains.

    Every `resample_freq` steps the buffered observations are distributed to
    the per-model datasets and the buffer is cleared. Every `training_freq`
    steps each model is trained for `training_epochs` on its own dataset.

    Args:
      context: Observed context.
      action: Action that was played.
      reward: Reward that was received.
    """

    self.t += 1
    self.all_observed.add(context, action, reward)
    self.partial_observed.add(context, action, reward)

    if self.resample_freq and self.t % self.resample_freq == 0:
      if self.centroid_learn:
        self._resample_with_centroids()
      else:
        self._resample_bootstrap()
      # The buffered points have been handed out; start a fresh buffer.
      self.partial_observed.remove_all_data()

    if self.t % self.training_freq == 0:
      for i in range(self.q):
        if self.hparams.reset_lr:
          self.bnn_boot[i].assign_lr()
        self.bnn_boot[i].train(self.datasets[i], self.training_epochs)

  def _resample_bootstrap(self):
    """Gives every model its own with-replacement resample of the buffer."""
    n_total = len(self.partial_observed.actions)
    for i in range(self.q):
      select_idx = np.random.choice(n_total, n_total)
      new_contexts, new_actions, new_rewards = extract_resample(
          self.partial_observed, select_idx)
      self.datasets[i].add_batch_data(new_contexts, new_actions, new_rewards)

  def _resample_with_centroids(self):
    """Scores models on resampled buffers and hands out data accordingly."""
    n_total = len(self.partial_observed.actions)
    self.centroid_weights = np.zeros(self.q)
    self.centroid_perturbs = {i: [] for i in range(self.q)}

    # NOTE(review): evaluate() is indexed below with an array of datapoint
    # indices, so it presumably returns one loss per buffered datapoint --
    # confirm against NeuralBanditModel.evaluate.
    centroid_loss_on_each_data = [
        self.bnn_boot[i].evaluate(self.partial_observed)
        for i in range(self.q)
    ]

    for _ in range(self.centroid_eval_num):
      # Draw a perturbed (with-replacement) dataset and credit the model
      # with the lowest total loss on it.
      select_idx = np.random.choice(n_total, n_total)
      centroid_cost = [
          np.sum(centroid_loss_on_each_data[i][select_idx])
          for i in range(self.q)
      ]
      winner = centroid_cost.index(min(centroid_cost))
      self.centroid_perturbs[winner] += list(select_idx)
      self.centroid_weights[winner] += 1.

    self.normalized_centroid_weights = (
        self.centroid_weights * 1. / np.sum(self.centroid_weights))

    threshold = self.gamma / self.q
    for i in range(self.q):
      wins = self.centroid_weights[i]
      own_idx = np.array([], dtype=int)
      if wins > 0:
        # Average each datapoint's count over the perturbed datasets this
        # model won. The divisor must be this model's own win count, not the
        # win count of whichever model won the last draw above. Rounding up
        # keeps every datapoint that appeared at least once.
        freq = idx_to_freq(self.centroid_perturbs[i], n_total)
        own_idx = freq_to_idx(np.ceil(freq / wins))

      if own_idx.shape[0] and self.normalized_centroid_weights[i] > threshold:
        select_idx = own_idx
      else:
        # Models that won too rarely (or never) get the whole buffer as is.
        select_idx = np.arange(n_total)
      new_contexts, new_actions, new_rewards = extract_resample(
          self.partial_observed, select_idx)
      self.datasets[i].add_batch_data(new_contexts, new_actions, new_rewards)
