# -*- coding: utf-8 -*-
# Offline.py

# Learn a baseline policy from collected trajectory data.

import numpy as np
import pickle

class offlineTrainer(object):
    """Learn a greedy policy from logged trajectories with an LCB (pessimism) penalty.

    Q-values are linear in the environment features: for each step ``h`` a ridge
    regression of ``r + max_a LCBQ[h + 1]`` onto ``phi[h, s, a]`` is solved, and an
    elliptical bonus ``beta * sqrt(phi^T Lambda^{-1} phi)`` is subtracted.

    Required ``env`` interface (as used below):
        ``H``, ``S``, ``A``, ``d``: horizon, number of states, number of actions and
        feature dimension; ``phi``: array indexed as ``phi[h, s, a] -> (d,)`` (so
        ``phi[h]`` is ``(S, A, d)``); ``reset() -> s``; ``step(a) -> (s, r)``.

    Attributes:
        trajectory_s / trajectory_a / trajectory_r: one list per episode holding the
            visited states (``H + 1`` per episode when collected here), the chosen
            actions and the received rewards (``H`` each).
        Lambda: ``(H, d, d)`` Gram matrices ``lamb * I + sum phi phi^T``.
        LambdaInverse: matrix inverses of ``Lambda``.
        LCBQ: ``(H + 1, S, A)`` pessimistic Q estimates; ``LCBQ[H]`` stays zero.
    """

    def __init__(self, env, trajectory=None, beta=0.1, lamb=1):
        """Store the environment and any pre-collected data, then build ``Lambda``.

        Args:
            env: environment object (see the class docstring for the interface).
            trajectory: optional ``[trajectory_s, trajectory_a, trajectory_r]``.
                The three containers are stored by reference, so
                ``collect_trajectories`` appends to the caller's lists.
            beta: multiplier on the LCB bonus (the pessimism penalty).
            lamb: ridge regularization added to every Gram matrix.
        """
        super().__init__()

        self.env = env

        # ``is None`` rather than ``== None``: with a NumPy array the latter is an
        # elementwise comparison whose truth value is ambiguous inside ``if``.
        if trajectory is None:
            self.trajectory_s = []
            self.trajectory_a = []
            self.trajectory_r = []
        else:
            self.trajectory_s = trajectory[0]
            self.trajectory_a = trajectory[1]
            self.trajectory_r = trajectory[2]

        self.beta = beta  # only used as the LCB bonus multiplier
        self.lamb = lamb

        self.update_lambda()

        self.LCBQ = np.zeros([self.env.H + 1, self.env.S, self.env.A])

    def collect_trajectories(self, policy, N):
        """Roll out ``policy`` for ``N`` episodes, store the data and refresh ``Lambda``.

        Args:
            policy: array indexed as ``policy[h, s]`` that yields a probability
                vector over the ``A`` actions for step ``h`` in state ``s``.
            N: number of episodes to collect.
        """
        for _ in range(N):
            s = self.env.reset()
            states = [s]
            actions = []
            rewards = []
            for h in range(self.env.H):
                a = np.random.choice(self.env.A, p=policy[h, s])
                actions.append(a)

                s, r = self.env.step(a)
                states.append(s)
                rewards.append(r)

            self.trajectory_a.append(actions)
            self.trajectory_r.append(rewards)
            self.trajectory_s.append(states)
        self.update_lambda()

    def LCB_policy(self):
        """Run pessimistic least-squares value iteration and return its greedy policy.

        Fills ``self.LCBQ`` backwards from ``h = H - 1`` down to ``0``::

            w_h     = Lambda_h^{-1} sum_i phi_h(s_i, a_i) (r_i + max_a LCBQ[h+1, s'_i, a])
            LCBQ[h] = max(phi_h . w_h - beta * sqrt(phi_h^T Lambda_h^{-1} phi_h), 0)

        Relies on ``self.LambdaInverse`` being current (see ``update_lambda``).

        Returns:
            ``(H, S, A)`` float array whose ``[h, s]`` row is one-hot at
            ``argmax_a LCBQ[h, s, a]`` (ties go to the lowest action index).

        NOTE(review): the estimate is clipped from below at 0 only; there is no
        upper truncation of Q. Confirm that this is intended for the reward scale.
        """
        H = self.env.H
        self.LCBQ[H] = 0
        for h in range(H - 1, -1, -1):
            # V_{h+1}(s) = max_a LCBQ[h+1, s, a], computed once per step rather
            # than once per sample.
            v_next = np.max(self.LCBQ[h + 1], axis=-1)

            w = np.zeros(self.env.d)
            for i, states in enumerate(self.trajectory_s):
                s1 = states[h]
                s2 = states[h + 1]
                a = self.trajectory_a[i][h]
                r = self.trajectory_r[i][h]
                w += self.env.phi[h, s1, a] * (r + v_next[s2])
            # Ridge-regression solution for step h.
            w = np.dot(self.LambdaInverse[h], w)

            phi_h = self.env.phi[h]  # (S, A, d)
            # Elliptical bonus sqrt(phi^T Lambda_h^{-1} phi) for every (s, a) pair.
            bonus = np.sqrt(np.einsum('ijk,kl,ijl->ij', phi_h, self.LambdaInverse[h], phi_h))
            self.LCBQ[h] = np.maximum(np.dot(phi_h, w) - self.beta * bonus, 0)

        # Greedy deterministic policy as one-hot rows of an identity matrix.
        greedy_actions = np.argmax(self.LCBQ[:H], axis=-1)  # (H, S)
        return np.identity(self.env.A)[greedy_actions]

    def update_lambda(self):
        """Recompute the per-step Gram matrices and their inverses from all stored data.

        ``Lambda[h] = lamb * I + sum_i phi[h, s_i[h], a_i[h]] phi[h, s_i[h], a_i[h]]^T``
        over every stored episode ``i``.
        """
        H, d = self.env.H, self.env.d
        # One copy of lamb * I per step, stacked into shape (H, d, d).
        self.Lambda = np.tile(self.lamb * np.identity(d), (H, 1, 1))

        for h in range(H):
            for i, states in enumerate(self.trajectory_s):
                feature = self.env.phi[h, states[h], self.trajectory_a[i][h]]
                self.Lambda[h] += np.outer(feature, feature)

        # np.linalg.inv inverts each (d, d) matrix of the stack independently.
        self.LambdaInverse = np.linalg.inv(self.Lambda)