"""
This file defines several different different decision making algorithms
"""
import abc

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class BaseDecisionMaker(metaclass=abc.ABCMeta):
    """Abstract interface shared by every decision maker in this module.

    A decision maker chooses one of ``number_of_arms`` discrete actions.
    Concrete subclasses implement :meth:`_step` (pick an action) and
    :meth:`update` (learn from an observed reward); callers drive them
    through the public :meth:`step`.

    ``settings`` must expose ``dm_lr``, ``exploration`` and
    ``exploration_gamma``; subclasses may read further fields from it.
    """

    name = 'BaseDM'

    def __init__(self, number_of_arms, settings, *args, **kwargs):
        """Store the configuration and allocate per-arm storage.

        Args:
            number_of_arms: number of discrete actions to choose from.
            settings: configuration object read for ``dm_lr``,
                ``exploration`` and ``exploration_gamma``.
            *args, **kwargs: accepted and ignored.
        """
        self.settings = settings
        self.number_of_arms = number_of_arms
        self.out_dim = number_of_arms  # size of the action space
        self.learning_rate = settings.dm_lr
        self.exploration = settings.exploration
        self.exploration_gamma = settings.exploration_gamma

        # Per-arm buffers (three independent zero arrays); each subclass
        # decides which of them it actually uses.
        self.values, self.critic, self.policy = (
            np.zeros(number_of_arms) for _ in range(3)
        )

        self.time = 0  # incremented once per step() call

    def step(self):
        """Advance the clock, decay exploration, then choose an action.

        Returns:
            Whatever the subclass's :meth:`_step` returns.
        """
        self.time = self.time + 1
        self.reduce_exploration()
        chosen = self._step()
        return chosen

    @abc.abstractmethod
    def _step(self):
        """Choose and return an action; implemented by subclasses."""

    @abc.abstractmethod
    def update(self, action, reward, next_action):
        """Learn from ``reward`` received for ``action``; implemented by subclasses."""

    @property
    def probabilities(self):
        """Per-arm action probabilities; the base class reports all zeros."""
        return [0.0 for _ in range(self.out_dim)]

    def reduce_exploration(self):
        """Multiply ``exploration`` by ``gamma / (gamma + time)``.

        NOTE(review): the factor is applied to the already-decayed value, so
        the decay compounds across calls (a running product of factors)
        rather than following ``initial * gamma / (gamma + time)`` —
        confirm this schedule is intended.
        """
        gamma = self.exploration_gamma
        factor = gamma / (gamma + self.time)
        self.exploration = self.exploration * factor


class QLearning(BaseDecisionMaker):
    """Epsilon-greedy tabular Q-learning over ``number_of_arms`` actions.

    Learners can be chained: ``next_qlearner`` is the decision maker whose
    ``values`` bootstrap this learner's target. ``None`` makes the target the
    immediate reward only.
    """

    name = 'QLearning'

    def __init__(self, number_of_arms, settings, next_qlearner, *args, **kwargs):
        """Create the learner.

        Args:
            number_of_arms: number of discrete actions.
            settings: configuration object passed to ``BaseDecisionMaker``.
            next_qlearner: learner whose ``values`` provide the bootstrap
                estimate, or ``None`` for a terminal learner.
        """
        BaseDecisionMaker.__init__(self, number_of_arms, settings)
        self.next_state = next_qlearner

    def _step(self):
        """Pick a uniformly random arm with probability ``exploration``,
        otherwise the arm with the highest value (first one on ties).
        """
        # Use the decayed rate that BaseDecisionMaker.step() maintains via
        # reduce_exploration(); settings.exploration is only the initial,
        # never-decayed value, so reading it would disable the decay.
        if np.random.random() < self.exploration:
            action = np.random.randint(self.out_dim)
        else:
            action = np.argmax(self.values)
        return action

    def update(self, action, reward, next_action):
        """Apply one Q-learning update for ``action`` and return the TD error.

        Args:
            action: index of the arm whose value is updated.
            reward: observed reward; a ``torch.Tensor`` is converted with
                ``.item()``.
            next_action: unused by Q-learning (kept for a uniform interface).

        Returns:
            ``target - values[action]``, where ``target`` is ``reward`` plus,
            for non-terminal learners, ``max(next_state.values)``
            (no discounting).
        """
        if isinstance(reward, torch.Tensor):
            reward = reward.item()
        if self.next_state is None:
            error = reward - self.values[action]
        else:
            error = np.max(self.next_state.values) + reward - self.values[action]
        self.values[action] = self.values[action] + self.learning_rate * error
        return error


class SARSA(QLearning):
    """On-policy counterpart of ``QLearning``.

    Action selection is inherited unchanged; the update bootstraps from the
    next learner's value for the action actually taken there
    (``next_action``) rather than from its best value.
    """

    name = 'SARSA'

    def update(self, action, reward, next_action):
        """Apply one SARSA update for ``action`` and return the TD error.

        Args:
            action: index of the arm whose value is updated.
            reward: observed reward; a ``torch.Tensor`` is converted with
                ``.item()``.
            next_action: arm chosen in ``self.next_state``; asserted to be
                non-None unless this learner is terminal.

        Returns:
            ``target - values[action]``, where ``target`` is ``reward`` plus,
            for non-terminal learners, ``next_state.values[next_action]``
            (no discounting).
        """
        if isinstance(reward, torch.Tensor):
            reward = reward.item()
        if self.next_state is not None:
            assert next_action is not None
            target = self.next_state.values[next_action] + reward
        else:
            target = reward
        error = target - self.values[action]
        self.values[action] += self.learning_rate * error
        return error


class REINFORCE(BaseDecisionMaker):
    """Softmax policy over arms whose logits are stored in ``self.policy``."""

    name = 'REINFORCE'

    @property
    def probabilities(self):
        """Softmax of ``self.policy``.

        The logits are shifted by their maximum first. This does not change
        the result mathematically, but keeps ``np.exp`` from overflowing to
        ``inf`` (which made the probabilities NaN) once a logit grows past
        roughly 709, the float64 overflow point of ``exp``.
        """
        exp_shifted = np.exp(self.policy - np.max(self.policy))
        return exp_shifted / np.sum(exp_shifted)

    def _step(self):
        """Sample one arm index from ``probabilities``."""
        # A single multinomial draw is a one-hot vector; argmax recovers the arm.
        action = np.argmax(np.random.multinomial(1, pvals=self.probabilities))
        return action

    def update(self, action, reward, next_action):
        """Move the logit of ``action`` and return the applied signal.

        The returned value is ``-log(p[action]) * reward``; the logit moves by
        ``learning_rate`` times it. ``next_action`` is unused.

        NOTE(review): for a softmax policy the score-function gradient with
        respect to the chosen logit is ``(1 - p[action]) * reward`` and the
        other logits move too; confirm this surrogate update is intended.
        """
        # Accept tensor rewards the same way QLearning/SARSA do.
        if isinstance(reward, torch.Tensor):
            reward = reward.item()
        gradient = -np.log(self.probabilities[action]) * reward
        self.policy[action] = self.policy[action] + self.learning_rate * gradient
        return gradient


class EXP3(BaseDecisionMaker):
    """EXP3 bandit: exponential weights mixed with uniform exploration.

    ``values`` is a ``(2, number_of_arms)`` table: row 0 holds the sampling
    distribution computed by the last ``_step`` call and row 1 counts how
    many updates each arm has received. The mixing rate is read from
    ``settings.exp3_gamma``.
    """

    name = 'EXP3'

    def __init__(self, number_of_arms, settings, *args, **kwargs):
        BaseDecisionMaker.__init__(self, number_of_arms, settings)
        self._number_of_arms = number_of_arms
        self.reset()

    def _step(self):
        """Refresh the sampling distribution and choose an arm.

        Any arm that has never been updated is returned first (lowest index
        wins), overriding the sampled arm, until every arm has received an
        update.
        """
        gamma = self.settings.exp3_gamma
        self.values[0, :] = ((1 - gamma) * self.weights / np.sum(self.weights)
                             + gamma / self._number_of_arms)
        p = self.values[0, :] / self.values[0, :].sum()
        try:
            action = np.random.choice(self._number_of_arms, p=p)
        except ValueError:
            # np.random.choice rejects invalid p (e.g. NaN); fall back to arm 0.
            print('Error in weights:')
            print(self.values)
            action = 0
        unvisited = np.where(self.values[1, :] == 0.)
        return unvisited[0][0] if unvisited[0].size > 0 else action

    def update(self, action, reward, next_action):
        """Apply the exponential-weights update for ``action``.

        ``reward`` may be a ``torch.Tensor`` (converted with ``.item()``);
        ``next_action`` is unused. Always returns ``0.``.
        """
        if isinstance(reward, torch.Tensor):
            reward = reward.item()
        xhat = np.zeros((1, self._number_of_arms))
        # Importance-weighted estimate: divide by the probability this arm
        # had in the distribution computed by the last _step call.
        xhat[0, action] = reward / self.values[0, action]
        self.weights = self.weights * np.exp(
            self.settings.exp3_gamma * xhat / self._number_of_arms)
        # _step only uses weights / sum(weights), so dividing by the largest
        # weight leaves the distribution unchanged while keeping the running
        # product from overflowing to inf (which turned p into NaN and made
        # np.random.choice fail on every later step).
        peak = np.max(self.weights)
        if np.isfinite(peak) and peak > 0:
            self.weights = self.weights / peak
        self.values[1, action] += 1.
        return 0.

    @property
    def probabilities(self):
        """Row 0 of ``values``: the distribution from the last ``_step`` call."""
        return self.values[0, :]

    def reset(self):
        """Clear the probability/count table, reset weights to 1, and reset ``time``."""
        self.values = np.zeros((2, self._number_of_arms))
        self.weights = np.ones((1, self._number_of_arms))
        self.time = 0
        return