from __future__ import division
import numpy as np
from utils import dirichlet, q_value_iteration, to_ssp_kernel, extended_value_iteration, compute_pivot_horizon, timeit
from copy import deepcopy


class SSPAgent(object):
    """
    Common base for the SSP agents defined in this module.

    It stores the environment handle, the most recent ``state`` / ``action``
    pair (both ``None`` until a subclass sets them) and three counters:

    * ``t`` -- step counter; starts at 1 and grows by one per ``update``.
    * ``k`` -- episode counter; grows by one per ``reset_episode``.
    * ``n`` -- integer visit counts of shape ``(nb_states, nb_actions)``.
    """

    def __init__(self, env):
        self.env = env
        self.error = False
        self.state, self.action = None, None
        self.k = 0
        self.t = 1
        self.n = self._blank_visit_counts()

    def _blank_visit_counts(self):
        """Return a fresh all-zero integer (nb_states, nb_actions) array."""
        shape = (self.env.nb_states, self.env.nb_actions)
        return np.zeros(shape, dtype=int)

    def act(self, state):
        """Choose an action for ``state``; concrete agents must override."""
        raise NotImplementedError

    def update(self, next_state, cost):
        """Count one more step and one more visit of the stored pair."""
        self.t += 1
        self.n[self.state, self.action] += 1

    def info(self):
        """Return a textual description; concrete agents must override."""
        raise NotImplementedError

    def reset_episode(self):
        """Advance the episode counter by one."""
        self.k += 1

    def reset(self):
        """Restore ``k``, ``t`` and ``n`` to their initial values."""
        self.n = self._blank_visit_counts()
        self.t = 1
        self.k = 0


# LCB-ADVANTAGE-SSP
class HQLAgent(SSPAgent):
    """
    Model-free agent that maintains a Q table updated from running sums.

    For every (state, action) pair it accumulates statistics of the value
    estimates observed at the next state and, each time the number of visits
    since the last refresh (``M``) reaches the pair's current threshold
    (``e``), recomputes the Q entry from two bonus-corrected estimates
    (labelled "type I" and "type II" in the code below).

    Parameters
    ----------
    env : environment; must expose ``nb_states``, ``nb_actions``, ``costs``
        and ``optimal_cost()``.
    iota : multiplier applied to both bonus terms.
    H : initial value of every threshold in ``e``; the threshold is then
        scaled by ``(1 + 1 / H)`` (truncated to int) after each refresh.
    refc : cap used by the reference-value snapshot rule in ``update``;
        the assertion requires ``refc == 2 ** int(log2(refc))``.
    """

    def __init__(self, env, iota=1.0, H=15, refc=10000):
        super().__init__(env=env)
        self.iota = iota
        # Largest entry of the environment's optimal cost; scales the bonuses.
        self.B = np.max(env.optimal_cost())
        # Smallest cost over all rows but the last; stored but not read
        # anywhere else in this class.
        self.cmin = np.min(env.costs[:-1])
        self.H = H
        assert refc == 2 ** int(np.log2(refc))
        self.refC = refc
        self.reset()

    def act(self, state):
        """Record ``state`` and return the action with the smallest Q value."""
        self.state = state
        self.action = np.argmin(self.Q[self.state])
        return self.action

    def reset(self):
        """Reset the base counters and zero every table kept by this agent."""
        super().reset()
        SA = (self.env.nb_states, self.env.nb_actions)
        self.N = np.zeros(SA)  # total visits of each pair
        self.M = np.zeros(SA)  # visits since the pair's last Q refresh
        self.Q = np.zeros(SA)
        self.V = np.zeros(self.env.nb_states)
        self.refV = np.zeros(self.env.nb_states)  # reference value snapshot
        self.refmu = np.zeros(SA)  # running sum of refV[next_state]; never cleared in update
        self.mu = np.zeros(SA)  # running sum of V - refV at next_state since last refresh
        self.nu = np.zeros(SA) # type II
        self.refsigma = np.zeros(SA)  # running sum of refV[next_state] ** 2
        self.sigma = np.zeros(SA)  # running sum of (V - refV) ** 2 since last refresh
        self.e = np.full(SA, self.H, dtype=int)  # per-pair refresh threshold on M
        self.pi = np.zeros(self.env.nb_states, dtype=int)  # greedy action; written in update, not read by act
 
    @timeit.time_f('HQL.update')
    def update(self, next_state, cost):
        """
        Fold the observed transition into the running sums and, when the
        pair's threshold is reached, refresh its Q entry.

        ``next_state`` indexes ``V`` / ``refV``; ``cost`` is the observed
        one-step cost that enters both Q estimates.
        """
        sa = (self.state, self.action)
        self.N[sa] += 1
        self.M[sa] += 1
        n, m = self.N[sa], self.M[sa]
        self.refmu[sa] += self.refV[next_state]
        self.mu[sa] += self.V[next_state] - self.refV[next_state]
        self.refsigma[sa] += self.refV[next_state] ** 2
        self.sigma[sa] += (self.V[next_state] - self.refV[next_state]) ** 2
        self.nu[sa] += self.V[next_state]
        if m == self.e[sa]:
            # Empirical variances from the running sums; the 1e-10 keeps the
            # argument of the square roots below strictly positive when the
            # variance is zero.
            refvar = self.refsigma[sa] / n - (self.refmu[sa] / n)**2 + 1e-10
            var = self.sigma[sa] / m - (self.mu[sa] / m)**2 + 1e-10
            # type I
            b = self.iota * (np.sqrt(refvar / n) + np.sqrt(var / m) + self.B / n + self.B / m) 
            Q1 = cost + self.refmu[sa] / n + self.mu[sa] / m - b
            # type II
            b = self.iota * self.B / np.sqrt(m)
            Q2 = cost + self.nu[sa] / m - b
            # Taking the max with the old entry means Q[sa] never decreases.
            self.Q[sa] = np.maximum(np.maximum(Q1, Q2), self.Q[sa])
            self.V[self.state] = np.min(self.Q[self.state])
            self.pi[self.state] = np.argmin(self.Q[self.state])
            # Clear the since-last-refresh sums; refmu / refsigma / N keep accumulating.
            self.nu[sa], self.mu[sa], self.sigma[sa], self.M[sa] = 0, 0, 0, 0
            self.e[sa] = int((1 + 1 / self.H) * self.e[sa])
        # Total visits of the current state over all actions.
        ns = np.sum(self.N[self.state])
        # Snapshot the reference value whenever ns divides refC (and does not
        # exceed it); with refC a power of two these are the powers of two.
        if 0 < ns <= self.refC and self.refC % ns == 0:
            self.refV[self.state] = self.V[self.state]
        super().update(next_state, cost)

    def info(self):
        """Return the agent's name as a one-line description."""
        return 'name = HQLAgent\n'


class SVIAgent(SSPAgent):
    """
    Agent that keeps empirical transition counts and refreshes one Q entry
    at a time.

    A pair's Q entry is recomputed from the empirical transition
    distribution whenever its visit count ``Nsa`` equals its trigger ``E``;
    the trigger is then pushed forward by ``1 + E // H``.
    """

    def __init__(self, env, H, iota):
        super().__init__(env=env)
        self.iota = iota
        self.H = H
        self.B = np.max(env.optimal_cost())
        self.reset()

    def reset(self):
        """Reset the base counters and rebuild every table from scratch."""
        super().reset()
        n_states, n_actions = self.env.nb_states, self.env.nb_actions
        self.pi = np.zeros(n_states, dtype=int)
        self.V = np.zeros(n_states)
        self.Q = np.zeros((n_states, n_actions))
        self.E = np.ones((n_states, n_actions), dtype=int)  # next refresh trigger
        self.Nsa = np.zeros((n_states, n_actions))  # visits per (s, a)
        self.Nt = np.zeros((n_states, n_actions, n_states))  # transition counts

    def act(self, state):
        """Remember ``state`` and play the stored greedy action for it."""
        self.state = state
        chosen = self.pi[state]
        self.action = chosen
        return chosen

    def info(self):
        """Return the agent's name as a one-line description."""
        return 'name = SVIAgent\n'

    @timeit.time_f('SVI.update')
    def update(self, next_state, cost):
        """Record the transition and refresh Q[s, a] when its trigger is hit."""
        pair = (self.state, self.action)
        self.Nt[pair][next_state] += 1
        self.Nsa[pair] += 1
        visits = self.Nsa[pair]
        if visits == self.E[pair]:
            # Empirical next-state distribution of the visited pair.
            p_hat = self.Nt[pair] / visits
            expected_v = np.dot(p_hat, self.V)
            variance = np.maximum(np.dot(p_hat, self.V ** 2) - expected_v ** 2, 0)
            bonus = np.maximum(6 * np.sqrt(variance * self.iota / visits),
                               36 * self.B * self.iota / visits)
            candidate = self.env.costs[pair] + expected_v - bonus
            # Keep the larger of the new estimate and the current entry.
            self.Q[pair] = np.maximum(candidate, self.Q[pair])
            row = self.Q[self.state]
            self.V[self.state] = np.min(row)
            self.pi[self.state] = np.argmin(row)
            # Move the trigger forward by 1 + floor(E / H).
            self.E[pair] += 1 + self.E[pair] // self.H
        super().update(next_state, cost)


class ULCAgent(SSPAgent):
    """
    Agent that plans a finite-horizon policy from upper and lower value
    estimates built on empirical transition counts.

    The policy ``pi`` is indexed by ``(h, state)`` where ``h`` is the step
    inside the current interval. An interval ends when ``h`` reaches
    ``H - 1`` or when the last state index (``nb_states - 1``) is entered;
    the policy is recomputed at that point.

    Parameters
    ----------
    env : environment; must expose ``nb_states``, ``nb_actions``, ``costs``
        and ``optimal_cost()``.
    H : planning horizon (number of steps per interval).
    iota : multiplier applied to the count-based part of the bonus.
    """

    def __init__(self, env, H, iota):
        super().__init__(env)
        self.H = H
        self.iota = iota
        self.B = np.max(env.optimal_cost())
        self.tc = 8 * self.B  # terminal cost
        self.hc = 9 * self.B  # loop-free maximum cost
        # reset() initialises every per-run quantity, including ``i``.
        self.reset()

    def act(self, state):
        """Record ``state`` and play the planned action for step ``h``."""
        self.state = state
        self.action = self.pi[self.h, self.state]
        return self.action

    def compute_pi(self):
        """
        Run backward induction over ``h = H-1, ..., 0``.

        Returns
        -------
        (uJ, lJ, pi) : upper values and lower values, both of shape
            ``(H + 1, nb_states)``, and the greedy policy of shape
            ``(H, nb_states)`` with respect to the lower Q values.
        """
        SA = (self.env.nb_states, self.env.nb_actions)
        uJ = np.zeros((self.H + 1, self.env.nb_states))
        lJ = np.zeros((self.H + 1, self.env.nb_states))
        # Every state except the last pays the terminal cost at the horizon.
        uJ[self.H, :-1] = self.tc
        lJ[self.H, :-1] = self.tc
        uQ = np.zeros((self.H, *SA))
        lQ = np.zeros((self.H, *SA))
        pi = np.zeros((self.H, self.env.nb_states), dtype=int)
        # Clamp counts at 1 so unvisited pairs do not divide by zero.
        Np = np.maximum(self.Nsa, 1)
        P = self.Nt / Np[:, :, None]
        for h in range(self.H - 1, -1, -1):
            # 1e-10 keeps the square root argument strictly positive.
            var = P.dot(lJ[h+1]**2) - (P.dot(lJ[h+1]))**2 + 1e-10
            b = self.iota * (np.sqrt(var/Np) + 1/Np) + 1/(16*self.H) * P.dot(uJ[h+1] - lJ[h+1])
            assert np.all(b >= 0) and np.all(uJ >= lJ)
            uQ[h] = self.env.costs + P.dot(uJ[h+1]) + b
            lQ[h] = self.env.costs + P.dot(lJ[h+1]) - b
            pi[h] = np.argmin(lQ[h], axis=-1)
            # Upper values are capped at hc, lower values floored at 0.
            uJ[h] = np.minimum(uQ[h, range(self.env.nb_states), pi[h]], self.hc)
            lJ[h] = np.maximum(lQ[h, range(self.env.nb_states), pi[h]], 0)
        return uJ, lJ, pi

    @timeit.time_f('ULC.update')
    def update(self, next_state, cost):
        """Record the transition and replan when the interval ends."""
        self.Nt[self.state, self.action, next_state] += 1
        self.Nsa[self.state, self.action] += 1
        horizon_reached = self.h == self.H - 1
        if horizon_reached or (next_state == self.env.nb_states - 1):
            if horizon_reached:
                self.n_hit += 1
            # Only the policy is kept; the value bounds are not used here.
            _, _, self.pi = self.compute_pi()
            self.h = 0
            self.i += 1
        else:
            self.h += 1
        super().update(next_state, cost)

    def reset(self):
        """
        Reset the base counters, the transition counts, the policy and the
        interval bookkeeping (``h``, ``n_hit`` and ``i``).
        """
        super().reset()
        self.Nt = np.zeros((self.env.nb_states, self.env.nb_actions, self.env.nb_states))  # transition
        self.Nsa = np.zeros((self.env.nb_states, self.env.nb_actions))
        self.pi = np.random.randint(self.env.nb_actions, size=(self.H, self.env.nb_states))
        self.h = 0
        self.n_hit = 0
        # Number of completed intervals; cleared here so that a reset()
        # between runs does not carry the previous run's count over.
        self.i = 0

    def info(self):
        """Return the agent's name as a one-line description."""
        return 'name = ULCAgent\n'

    def reset_episode(self):  # debug
        """Check that the interval step is back at 0, then count the episode."""
        assert self.h == 0
        super().reset_episode()


class QLearningAgent(SSPAgent):
    """
    This agent implements the standard Q-learning algorithm for SSP with epsilon-greedy exploration.

    Parameters
    ----------
    env : environment; must expose ``nb_states``, ``nb_actions`` and
        ``actions``.
    epsilon : probability of drawing the action uniformly from
        ``env.actions`` instead of playing the greedy one.
    """

    def __init__(self, env, epsilon):
        super().__init__(env=env)
        self.epsilon = epsilon
        self.q = np.zeros([self.env.nb_states, self.env.nb_actions])
        # The greedy policy must exist before the first act()/update();
        # previously it was only created by reset(), so using a freshly
        # constructed agent raised AttributeError.
        self.pi = np.zeros(self.env.nb_states, dtype=int)

    def act(self, state):
        """Record ``state`` and return an epsilon-greedy action for it."""
        self.state = state
        if np.random.rand() < self.epsilon:
            self.action = np.random.choice(self.env.actions)
        else:
            self.action = self.pi[self.state]
        return self.action

    @timeit.time_f('QL.update')
    def update(self, next_state, cost):
        """Apply one Q-learning step with step size 1 / n(state, action)."""
        # The base class increments n[state, action] and t; the step size
        # below relies on the already-incremented count.
        super().update(next_state, cost)
        s, a = self.state, self.action
        alpha = 1.0 / self.n[s, a]
        target = cost + np.min(self.q[next_state])
        self.q[s, a] = (1 - alpha) * self.q[s, a] + alpha * target
        self.pi[s] = np.argmin(self.q[s])

    def info(self):
        """Return the agent's name and its epsilon."""
        return 'name = QLearningAgent\n' + 'epsilon = {0}\n'.format(self.epsilon)

    def reset(self):
        """Reset the base counters, the Q table and the greedy policy."""
        super().reset()
        self.q = np.zeros([self.env.nb_states, self.env.nb_actions])
        self.pi = np.zeros(self.env.nb_states, dtype=int)


class SSPBernsteinAgent(SSPAgent):
    """
    SSP Bernstein agent, i.e. Algorithm 2 of https://arxiv.org/pdf/2002.09869.pdf

    The policy, the transition counts ``n_prime`` and the per-epoch counts
    ``n_episode`` are created by ``reset``, which therefore has to run
    before the agent is used.
    """

    def __init__(self, env, c):
        super().__init__(env=env)
        self.c = c

    def reset(self):
        """Reset the base counters, zero all counts and draw a random policy."""
        super().reset()
        n_states, n_actions = self.env.nb_states, self.env.nb_actions
        self.n_episode = np.zeros((n_states, n_actions), dtype=int)
        self.n_prime = np.zeros((n_states, n_actions, n_states), dtype=int)
        self.policy = np.random.choice(self.env.actions, size=n_states)

    def info(self):
        """Return the agent's name and its constant ``c``."""
        return 'name = SSPBernsteinAgent\n' + 'c = {0}\n'.format(self.c)

    def act(self, state):
        """Record ``state`` and play the current policy's action for it."""
        self.state = state
        self.action = self.policy[state]
        return self.action

    @timeit.time_f('BS.update')
    def update(self, next_state, cost):
        """Record the transition and replan when the next pair's count doubled."""
        super().update(next_state, cost)
        visited = (self.state, self.action)
        self.n_prime[visited][next_state] += 1
        self.n_episode[visited] += 1

        # Nothing to replan once the last state index has been entered.
        if next_state == self.env.nb_states - 1:
            return
        upcoming = (next_state, self.policy[next_state])
        if self.n_episode[upcoming] < 0.5 * self.n[upcoming]:
            return
        # The pair about to be played has at least half of its visits in the
        # current epoch: start a new epoch and recompute the policy.
        self.n_episode = np.zeros((self.env.nb_states, self.env.nb_actions), dtype=int)
        self.policy, _ = extended_value_iteration(self.env.costs, self.c, self.n, self.n_prime)


class UCSSPAgent(SSPAgent):
    """
    UC-SSP agent, i.e. Algorithm 1 of https://arxiv.org/pdf/1912.03517.pdf

    Two cost tables are used for planning: the environment's own costs
    (``first=True``) and ``sp_costs``, which is 1 everywhere except for an
    all-zero last row (``first=False``).
    """

    def __init__(self, env, c):
        super().__init__(env=env)
        self.c = c
        unit_costs = np.ones_like(self.env.costs)
        unit_costs[-1, :] = 0
        self.sp_costs = unit_costs

    def info(self):
        """Return the agent's name and its constant ``c``."""
        return 'name = UCSSPAgent\n' + 'c = {0}\n'.format(self.c)

    def act(self, state):
        """Record ``state`` and play the current policy's action for it."""
        self.state = state
        self.action = self.policy[state]
        return self.action

    def reset(self):
        """Reset the base counters and transition counts, then plan afresh."""
        super().reset()
        shape = (self.env.nb_states, self.env.nb_actions, self.env.nb_states)
        self.n_prime = np.zeros(shape, dtype=int)
        self.g = 0
        self.compute_policy(first=True)

    def reset_episode(self):
        """Count the episode and replan with the environment's costs."""
        super().reset_episode()
        self.compute_policy(first=True)

    def update(self, next_state, cost):
        """Record the transition; replan with unit costs once the budget is spent."""
        super().update(next_state, cost)
        self.n_prime[self.state, self.action, next_state] += 1
        self.remain_steps -= 1
        if self.remain_steps != 0:
            return
        self.g += 1
        self.compute_policy(first=False)

    @timeit.time_f('UCSSP.update')
    def compute_policy(self, first):
        """
        Recompute ``policy`` and the step budget ``remain_steps``.

        ``first`` selects the cost table and the precision: the environment
        costs with ``min(costs[:-1]) / (2 t)``, or ``sp_costs`` with
        ``1 / (2 t)``. The precision is never allowed below 1e-6.
        """
        if first:
            costs = self.env.costs
            raw_eps = np.min(costs[:-1]) / (2 * self.t)
        else:
            costs = self.sp_costs
            raw_eps = 1 / (2 * self.t)
        eps = max(raw_eps, 1e-6)
        gamma = 1e-6
        self.policy, tilp = extended_value_iteration(costs, self.c, self.n, self.n_prime, eps=eps)
        self.remain_steps = compute_pivot_horizon(self.policy, tilp, gamma)


class EBAgent(SSPAgent):
    """
    Implements EB-SSP, i.e., Algorithm 1 of Paper: https://arxiv.org/pdf/2104.11186.pdf

    The policy and the count tables are created by ``reset``, which
    therefore has to run before the agent is used.

    Parameters
    ----------
    env : environment; must expose ``nb_states``, ``nb_actions``, ``costs``
        and ``optimal_cost()``.
    iota : multiplier inside the exploration bonus.
    """

    def __init__(self, env, iota):
        super().__init__(env=env)
        self.B = np.max(env.optimal_cost())
        self.iota = iota

    def act(self, state):
        """Record ``state`` and play the current policy's action for it."""
        self.state = state
        self.action = self.policy[self.state]
        return self.action

    @timeit.time_f('EB.update')
    def update(self, next_state, cost):
        """
        Record the transition and replan when the pair's visit count is a
        power of two.

        The base-class bookkeeping runs exactly once per step; it used to be
        invoked twice, which advanced ``n`` and ``t`` by two per transition.
        """
        self.Nsa[self.state, self.action] += 1
        self.Nt[self.state, self.action, next_state] += 1
        n = self.Nsa[self.state, self.action]
        if not (n & (n - 1)):  # n is a power of 2
            self.j += 1
            # Precision halves with every replanning, floored at 1e-6.
            eps = max(0.5 ** self.j, 1e-6)
            self.compute_policy(eps)
        super().update(next_state, cost)

    def info(self):
        """Return the agent's name as a one-line description."""
        return 'name = EBAgent\n'

    def reset(self):
        """Reset the base counters, the replanning index, counts and policy."""
        super().reset()
        self.j = 0
        self.Nt = np.zeros((self.env.nb_states, self.env.nb_actions, self.env.nb_states), dtype=int)  # transition
        self.Nsa = np.zeros((self.env.nb_states, self.env.nb_actions), dtype=int)  # s, a
        self.policy = np.zeros(self.env.nb_states, dtype=int)

    def compute_policy(self, eps):
        """
        Iterate the bonus-corrected Bellman operator until successive value
        vectors differ by at most ``eps`` (max norm), then store the greedy
        policy with respect to the final Q values.
        """
        nb_states = self.env.nb_states
        # Clamp counts at 1 so unvisited pairs do not divide by zero.
        Np = np.maximum(self.Nsa, 1)

        # Everything below is independent of V, so it is built once instead
        # of on every sweep.
        goal_mass = np.zeros(nb_states)
        goal_mass[-1] = 1.0
        # Empirical kernel with one extra pseudo-transition to the last state.
        P = (self.Nt + goal_mass[None, None, :]) / (self.Nsa + 1)[:, :, None]
        bonus_floor = 36 * self.B * self.iota / Np
        bonus_extra = 2 * np.sqrt(2) * self.B * np.sqrt((nb_states + 1) * self.iota) / Np

        V = np.zeros(nb_states)
        while True:
            PV = np.dot(P, V)
            v = np.maximum(np.dot(P, V ** 2) - PV ** 2, 0)
            b = np.maximum(6 * np.sqrt(v * self.iota / Np), bonus_floor) + bonus_extra
            Q = np.maximum(self.env.costs + PV - b, 0)
            nV = np.min(Q, axis=-1)
            converged = np.abs(V - nV).max() <= eps
            V = nV
            if converged:
                break
        self.policy = Q.argmin(axis=-1)
