# -*- coding: utf-8 -*-
# Baseline.py

# training baselines

import numpy as np
from Env import FiniteStateFiniteActionLinearMDP

class Trainer(object):
    """Tabular softmax-policy trainer for finite-horizon, finite-state/action MDPs.

    The environment is used through the following interface only:
    ``env.H`` (horizon), ``env.S`` (number of states), ``env.A`` (number of
    actions), ``env.reset() -> s`` and ``env.step(a) -> (s, r)``.
    States and actions are used as integer indices into the Q-table.

    Attributes:
        env: the environment being trained on.
        Q: array of shape ``[H + 1, S, A]``; ``Q[h, s, a]`` is the estimated
            return of taking action ``a`` in state ``s`` at step ``h``.
            The extra slice ``Q[H]`` is never updated and stays at zero, so it
            acts as the terminal value.
        temprature_k: multiplier applied to the Q-values before the softmax
            in ``policy`` (larger means greedier).
        alpha: learning rate of the temporal-difference update.
    """

    def __init__(self, env):
        super().__init__()

        self.env = env
        # Zero initialisation for every (step, state, action) triple.
        # NOTE(review): if optimistic initialisation was intended, use
        # np.full(shape, env.H) instead - confirm with the author.
        self.Q = np.zeros([env.H + 1, env.S, env.A])

        self.temprature_k = 1
        self.alpha = 0.02

    @staticmethod
    def _softmax(logits):
        """Return a numerically stable softmax over the last axis of *logits*."""
        logits = np.asarray(logits, dtype=float)
        # Subtracting the row maximum leaves the result unchanged but keeps
        # np.exp from overflowing to inf (which would produce NaN probabilities).
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        weights = np.exp(shifted)
        return weights / np.sum(weights, axis=-1, keepdims=True)

    def policy(self, t, s):
        """Sample an action from the softmax policy at step *t* in state *s*.

        Returns:
            A tuple ``(p, a)`` where ``p`` is the length-``A`` probability
            vector and ``a`` is the sampled action index (a Python int).
        """
        p = self._softmax(self.temprature_k * self.Q[t, s])
        a = np.random.choice(self.env.A, 1, p=p).item()

        return p, a

    def full_train(self, total_epoch=1000000):
        """Run *total_epoch* episodes, updating ``self.Q`` after each one.

        Each episode is rolled out with ``policy`` and then replayed backwards;
        every visited ``(h, s_h, a_h)`` entry is moved towards
        ``r_h + Q[h + 1, s_{h+1}, a_{h+1}]``, i.e. it bootstraps from the
        action actually taken at the next step.

        Returns:
            A list with one float per episode: the mean over all entries of
            ``Q[0] * softmax(Q[0])``, recorded after that episode's update.
        """
        horizon = self.env.H
        v = []
        for _ in range(total_epoch):
            s = self.env.reset()
            states = [s]
            actions = []
            rewards = []
            for h in range(horizon):
                _, a = self.policy(h, s)
                s, r = self.env.step(a)
                actions.append(a)
                rewards.append(r)
                states.append(s)
            # Dummy action for the terminal step: Q[H] is all zeros, so the
            # bootstrap term of the last real step is 0 whatever index is used.
            actions.append(0)

            for h in range(horizon - 1, -1, -1):
                s_h, a_h = states[h], actions[h]
                target = rewards[h] + self.Q[h + 1, states[h + 1], actions[h + 1]]
                self.Q[h, s_h, a_h] = (1 - self.alpha) * self.Q[h, s_h, a_h] + self.alpha * target

            q0 = self.Q[0]  # [S, A]
            p0 = self._softmax(q0)
            # NOTE(review): this averages over S * A entries, so it equals the
            # state-averaged expected Q divided by A, and the softmax here does
            # not apply temprature_k. Kept as is to preserve the metric's scale.
            v.append(np.mean(q0 * p0).item())
        return v

    def save(self, dir='01'):
        """Write the Q-table to ``policies/<dir>_kt<temprature_k>.npy``.

        The target directory is created if it does not exist yet.
        """
        import os

        path = 'policies/' + dir + '_kt' + str(self.temprature_k) + '.npy'
        os.makedirs(os.path.dirname(path), exist_ok=True)
        np.save(path, self.Q)