# -*- coding: utf-8 -*-
# safeLSVE.py

# algorithm implementation

import numpy as np

class safeTrainer(object):
    """Episodic trainer that mixes a baseline policy with an optimistic policy.

    The trainer keeps, for every step ``h`` of the horizon, a Gram matrix
    ``Lambda[h]`` (initialised to the identity and updated with the outer
    product of every visited feature vector) together with two tabular
    estimates built from regularised least squares on the stored
    trajectories:

    * ``UCBQ`` -- regression estimate plus ``beta_ucb`` times the bonus,
      clipped from above at ``env.H``;
    * ``Q`` -- regression estimate minus ``beta_lcb`` times the bonus,
      clipped from below at ``0``.

    The environment object is expected to expose ``d``, ``H``, ``S``, ``A``,
    a feature table ``phi`` indexed as ``phi[h, s, a]``, ``reset()`` returning
    a state and ``step(a)`` returning ``(next_state, reward)``.
    """

    # Maps the public algorithm names accepted by ``epoch_train`` to methods.
    _ALGORITHMS = {
        'StepMix': 'epoch_train_withrho',
        'StepNoMix': 'epoch_train_norho',
        'UCB': 'UCB_train',
        'nonmarkov': 'epoch_train_nonmarkov',
    }

    def __init__(self, env, baseline=(None, None), N = 100, alpha = 0.1):
        """Store the configuration and reset every estimate.

        Args:
            env: environment object (see the class docstring for the
                attributes and methods that are used).
            baseline: pair ``(b, baseline_actions)``. ``b`` is the scalar the
                pessimistic value is compared against (scaled by
                ``1 - alpha``); ``baseline_actions[h, s]`` is the action
                distribution sampled by ``baseline_policy``.
            N: stored on the instance; not read by any method of this class.
            alpha: fraction of ``b`` that may be given up; the threshold used
                by the safe variants is ``(1 - alpha) * b``.
        """
        super().__init__()

        self.env = env
        self.N = N
        self.alpha = alpha
        self.beta_ucb = 1
        self.beta_lcb = 1
        self.b = baseline[0]
        self.baseline_actions = baseline[1]

        # One d x d identity per step: shape [H, d, d].
        self.Lambda = np.tile(np.identity(self.env.d), (self.env.H, 1, 1))
        self.LambdaInverse = np.linalg.inv(self.Lambda)

        # Index H is a terminal slice that is never written and stays zero.
        self.UCBQ = np.zeros([self.env.H + 1, self.env.S, self.env.A])
        self.Q = np.zeros([self.env.H + 1, self.env.S, self.env.A])

        # One entry per played episode: H + 1 states, H actions, H rewards.
        self.trajectory_s = []
        self.trajectory_a = []
        self.trajectory_r = []

    def epoch_train(self, algo = 'StepMix'):
        """Run one episode with the selected algorithm.

        Args:
            algo: one of ``'StepMix'``, ``'StepNoMix'``, ``'UCB'`` or
                ``'nonmarkov'``.

        Returns:
            Whatever the selected method returns: a scalar for ``'UCB'`` and
            a ``(episode_return, h0, rho)`` tuple for the other algorithms.

        Raises:
            ValueError: if ``algo`` is not one of the supported names.
        """
        try:
            method_name = self._ALGORITHMS[algo]
        except KeyError:
            raise ValueError(
                "unknown algo {!r}; expected one of {}".format(
                    algo, sorted(self._ALGORITHMS))
            ) from None
        return getattr(self, method_name)()

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------

    def _exploration_bonuses(self):
        """Return, per step, sqrt(phi^T LambdaInverse[h] phi) for all (s, a).

        ``LambdaInverse`` is only refreshed at the end of a rollout, so these
        values are constant while the estimates of one epoch are computed.
        """
        return [
            np.sqrt(np.einsum('ijk,kl,ijl->ij',
                              self.env.phi[h], self.LambdaInverse[h], self.env.phi[h]))
            for h in range(self.env.H)
        ]

    def _ridge_weights(self, h, next_values):
        """Regress ``reward + next_values[next_state]`` on the step-h features.

        Args:
            h: step index.
            next_values: array indexed by state giving the value assigned to
                the successor state of each stored transition.

        Returns:
            The weight vector ``LambdaInverse[h] @ sum(phi * target)``.
        """
        weighted_targets = np.zeros(self.env.d)
        for states, actions, rewards in zip(self.trajectory_s,
                                            self.trajectory_a,
                                            self.trajectory_r):
            feature = self.env.phi[h, states[h], actions[h]]
            weighted_targets += feature * (rewards[h] + next_values[states[h + 1]])
        return np.dot(self.LambdaInverse[h], weighted_targets)

    def _update_optimistic_q(self, bonuses):
        """Recompute ``UCBQ[h]`` for every step, from the last step backwards."""
        for h in range(self.env.H - 1, -1, -1):
            # Targets bootstrap from the greedy value of the next step.
            w = self._ridge_weights(h, np.max(self.UCBQ[h + 1], axis=-1))
            self.UCBQ[h] = np.minimum(
                np.dot(self.env.phi[h], w) + self.beta_ucb * bonuses[h],
                self.env.H)

    def _update_pessimistic_q(self, h, next_values, bonus):
        """Recompute ``Q[h]`` from targets ``reward + next_values[next_state]``."""
        w = self._ridge_weights(h, next_values)
        self.Q[h] = np.maximum(
            np.dot(self.env.phi[h], w) - self.beta_lcb * bonus, 0)

    def _search_switch_step(self, bonuses):
        """Scan candidate switch steps ``h0`` from ``H`` down to ``0``.

        For each candidate the pessimistic value of "sample from
        ``baseline_actions`` before ``h0``, act greedily w.r.t. ``UCBQ`` from
        ``h0`` on" is evaluated; the scan stops at the first candidate whose
        mean initial value drops below ``(1 - alpha) * b``.

        Returns:
            Tuple ``(h0, value, last_value)``: the candidate at which the scan
            stopped, its mean initial value, and the mean initial value of
            the last candidate that met the threshold (``b`` if none did).
        """
        H, S = self.env.H, self.env.S
        threshold = (1 - self.alpha) * self.b
        V = np.zeros([H + 1, S])
        last_value = self.b
        for h0 in range(H, -1, -1):
            if h0 < H:
                # Q[h0] still holds the estimate from the previous candidate,
                # i.e. the value of acting at h0 and being greedy afterwards.
                greedy = np.argmax(self.UCBQ[h0], axis=-1).reshape([S, 1])
                V[h0] = np.take_along_axis(self.Q[h0], greedy, axis=-1).squeeze(-1)
            for h in range(h0 - 1, -1, -1):
                self._update_pessimistic_q(h, V[h + 1], bonuses[h])
                V[h] = np.einsum('ij,ij->i', self.baseline_actions[h], self.Q[h])
            value = np.mean(V[0])
            if value < threshold:
                break
            last_value = value
        return h0, value, last_value

    def _rollout(self, choose_action):
        """Play one episode, record it and refresh the Gram matrices.

        Args:
            choose_action: callable ``(h, s) -> action`` used at every step.

        Returns:
            The sum of the rewards collected during the episode.
        """
        s = self.env.reset()
        states, actions, rewards = [s], [], []
        for h in range(self.env.H):
            a = choose_action(h, s)
            actions.append(a)
            feature = self.env.phi[h, s, a]
            self.Lambda[h] += np.outer(feature, feature)
            s, r = self.env.step(a)
            states.append(s)
            rewards.append(r)

        self.trajectory_a.append(actions)
        self.trajectory_r.append(rewards)
        self.trajectory_s.append(states)

        self.LambdaInverse = np.linalg.inv(self.Lambda)

        return np.sum(rewards)

    def _greedy_action(self, h, s):
        """Return the action maximising ``UCBQ[h, s]``."""
        return np.argmax(self.UCBQ[h, s]).item()

    # ------------------------------------------------------------------
    # Algorithms
    # ------------------------------------------------------------------

    def epoch_train_withrho(self,):
        """Run one episode, randomising between both policies at step ``h0``.

        Steps before ``h0`` sample from the baseline, steps after ``h0`` are
        greedy w.r.t. ``UCBQ`` and step ``h0`` itself is greedy with
        probability ``rho`` and baseline otherwise.

        Returns:
            Tuple ``(episode_return, h0, rho)``.
        """
        bonuses = self._exploration_bonuses()
        self._update_optimistic_q(bonuses)
        h0, value, last_value = self._search_switch_step(bonuses)

        threshold = (1 - self.alpha) * self.b
        if value < threshold:
            # Interpolate between the last accepted candidate and the
            # rejected one so that the mixture sits on the threshold.
            rho = ((last_value - threshold) / (last_value - value)).item()
        else:
            rho = 1

        def policy(h, s):
            if h < h0:
                return self.baseline_policy(h, s)[1]
            greedy = self._greedy_action(h, s)
            if h > h0:
                return greedy
            # h == h0: draw the baseline action first, then the coin flip.
            _, fallback = self.baseline_policy(h, s)
            return greedy if np.random.uniform() < rho else fallback

        return self._rollout(policy), h0, rho

    def epoch_train_norho(self,):
        """Run one episode with a deterministic switch step (no mixing).

        When the scan stopped because a candidate fell below the threshold,
        the switch step is moved one step later than that candidate.

        Returns:
            Tuple ``(episode_return, h0, 1)``.
        """
        bonuses = self._exploration_bonuses()
        self._update_optimistic_q(bonuses)
        h0, value, _ = self._search_switch_step(bonuses)

        if value < (1 - self.alpha) * self.b:
            h0 += 1

        def policy(h, s):
            if h < h0:
                return self.baseline_policy(h, s)[1]
            return self._greedy_action(h, s)

        return self._rollout(policy), h0, 1

    def UCB_train(self,):
        """Run one episode that is greedy w.r.t. ``UCBQ`` at every step.

        Returns:
            The episode return (as the mean of a single-element list, which
            keeps the floating-point return type).
        """
        self._update_optimistic_q(self._exploration_bonuses())
        episode_return = self._rollout(self._greedy_action)
        return np.mean([episode_return])

    def epoch_train_nonmarkov(self,):
        """Run one episode that is entirely baseline or entirely greedy.

        The pessimistic value here uses the maximum of ``Q`` over actions at
        every step. With probability ``rho`` the whole episode is greedy
        w.r.t. ``UCBQ``; otherwise the whole episode samples the baseline.

        Returns:
            Tuple ``(episode_return, 0, rho)``.
        """
        bonuses = self._exploration_bonuses()
        self._update_optimistic_q(bonuses)

        V = np.zeros([self.env.H + 1, self.env.S])
        for h in range(self.env.H - 1, -1, -1):
            self._update_pessimistic_q(h, V[h + 1], bonuses[h])
            V[h] = np.max(self.Q[h], axis=-1)

        value = np.mean(V[0])
        if value < (1 - self.alpha) * self.b:
            rho = (self.alpha * self.b / (self.b - value)).item()
        else:
            rho = 1
        # Decided once, before the episode starts.
        use_baseline = np.random.uniform() > rho

        def policy(h, s):
            if use_baseline:
                return self.baseline_policy(h, s)[1]
            return self._greedy_action(h, s)

        return self._rollout(policy), 0, rho

    def baseline_policy(self, t, s):
        """Sample a baseline action for state ``s`` at step ``t``.

        Returns:
            Tuple ``(p, a)``: the distribution ``baseline_actions[t, s]`` and
            the action index drawn from it.
        """
        p = self.baseline_actions[t, s]
        a = np.random.choice(self.env.A, 1, p=p).item()

        return p, a