# -*- coding: utf-8 -*-
# Env.py

# Randomly create LinearMDP environments

import os

import numpy as np

from utils import RandomSimplexVector

class FiniteStateFiniteActionLinearMDP:
    """Randomly generated finite-horizon linear MDP with finite states/actions.

    At step ``h``, for state ``s`` and action ``a``:

    * reward:      ``r = <phi[h, s, a], theta[h]>``
    * transition:  ``P(. | s, a) = mu[h] @ phi[h, s, a]`` (a length-``S``
      vector used as the sampling distribution of the next state).

    Attributes:
        H, S, A, d: horizon, number of states, number of actions and
            feature dimension.
        theta: reward parameters, shape ``(H, d)``.
        mu: transition parameters, shape ``(H, S, d)``.
        phi: feature vectors, shape ``(H, S, A, d)``.
        t, s: current step and current state; only defined after
            ``reset()`` has been called.
    """

    def __init__(self, H=3, S=50, A=50, d=5):
        """Draw a random environment with the given dimensions.

        Args:
            H: horizon (number of steps per episode).
            S: number of states.
            A: number of actions.
            d: feature dimension.
        """
        super().__init__()
        self.H = H
        self.S = S
        self.A = A
        self.d = d

        # theta: reward parameters, shape (H, d), i.i.d. uniform on [0, 1).
        self.theta = np.random.uniform(size=[H, d])

        # mu: shape (H, S, d) after the transpose.
        # NOTE(review): RandomSimplexVector lives in utils and is not visible
        # here; it is assumed to return an array of shape size + [d] whose
        # last axis sums to 1, so that mu[h] @ phi[h, s, a] is a probability
        # vector over the S next states -- confirm against utils.
        self.mu = RandomSimplexVector(d=S, size=[H, d]).transpose((0, 2, 1))

        # phi: feature vectors, shape (H, S, A, d).
        self.phi = RandomSimplexVector(d=d, size=[H, S, A])

    def reset(self,):
        """Start a new episode.

        Sets the step counter to 0 and draws the initial state uniformly
        from ``range(S)``.

        Returns:
            The initial state index.
        """
        self.t = 0
        self.s = np.random.randint(self.S)
        return self.s

    def step(self, action):
        """Apply ``action`` in the current state and advance one step.

        ``reset()`` must have been called first. Stepping past the last
        step index of ``phi`` raises ``IndexError`` from the array indexing.

        Args:
            action: index of the action to take.

        Returns:
            Tuple ``(next_state, reward)``.
        """
        feature = self.phi[self.t, self.s, action]
        r = np.dot(feature, self.theta[self.t])
        # Next-state distribution over the S states.
        p = np.dot(self.mu[self.t], feature)
        s = np.random.choice(self.S, 1, p=p)
        self.s = s.item()
        self.t += 1
        return self.s, r

    def save_env(self, dir='01'):
        """Write theta, mu and phi to ``envs/<dir>_{theta,mu,phi}.npy``.

        The target directory is created if it does not exist yet.

        Args:
            dir: file-name prefix inside the ``envs/`` directory.
        """
        prefix = 'envs/' + dir
        # np.save does not create missing directories.
        os.makedirs(os.path.dirname(prefix), exist_ok=True)
        np.save(prefix + '_theta.npy', self.theta)
        np.save(prefix + '_mu.npy', self.mu)
        np.save(prefix + '_phi.npy', self.phi)

    def load_env(self, dir='01'):
        """Load theta, mu and phi from ``envs/<dir>_{theta,mu,phi}.npy``.

        ``H``, ``S``, ``A`` and ``d`` are refreshed from the loaded ``phi`` so
        that sampling and planning stay consistent when the stored
        environment has different dimensions than this instance.

        Args:
            dir: file-name prefix inside the ``envs/`` directory.

        Raises:
            ValueError: if the loaded ``phi`` is not 4-dimensional.
        """
        prefix = 'envs/' + dir
        self.theta = np.load(prefix + '_theta.npy')
        self.mu = np.load(prefix + '_mu.npy')
        self.phi = np.load(prefix + '_phi.npy')
        self.H, self.S, self.A, self.d = self.phi.shape

    def _q_values(self, h, v_next):
        """Return the (S, A) table ``Q_h = r_h + P_h V_{h+1}``."""
        reward = np.dot(self.phi[h], self.theta[h])
        # E[V(s')] = (mu[h] @ phi)^T V = phi^T (mu[h]^T V): project V onto
        # feature space once instead of building every transition vector.
        projected = np.dot(v_next, self.mu[h])
        return reward + np.dot(self.phi[h], projected)

    @staticmethod
    def _softmax(logits):
        """Softmax over the last axis, shifted by the max to avoid overflow."""
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        weights = np.exp(shifted)
        return weights / np.sum(weights, axis=-1, keepdims=True)

    def baseline_gen(self, temprature_k=1):
        """Build a softmax (Boltzmann) policy by backward induction.

        At every step the policy is ``softmax(temprature_k * Q_h)`` where
        ``Q_h`` is computed against the value of that same policy at later
        steps.

        Args:
            temprature_k: multiplier applied to Q before the softmax; larger
                values make the policy greedier.

        Returns:
            Tuple ``(value, policy)``: the policy's value averaged over all
            initial states, and the policy as an ``(H, S, A)`` array of
            action probabilities.
        """
        Q = np.zeros([self.H, self.S, self.A])
        V = np.zeros([self.H + 1, self.S])
        policy = np.zeros([self.H, self.S, self.A])
        for h in range(self.H - 1, -1, -1):
            Q[h] = self._q_values(h, V[h + 1])
            policy[h] = self._softmax(temprature_k * Q[h])
            V[h] = np.sum(policy[h] * Q[h], axis=-1)
        return np.mean(V[0]), policy

    def best_gen(self,):
        """Compute the optimal policy by backward induction.

        Returns:
            Tuple ``(value, action)``: the optimal value averaged over all
            initial states, and the greedy policy as an ``(H, S, A)`` one-hot
            array (ties go to the lowest action index).
        """
        V = np.zeros([self.H + 1, self.S])
        action = np.zeros([self.H, self.S, self.A])
        states = np.arange(self.S)
        for h in range(self.H - 1, -1, -1):
            q = self._q_values(h, V[h + 1])
            action[h, states, np.argmax(q, axis=-1)] = 1
            V[h] = np.max(q, axis=-1)
        return np.mean(V[0]), action

    def value_gen(self, actions):
        """Evaluate a policy by backward induction.

        Args:
            actions: array-like of shape ``(H, S, A)`` whose entry
                ``[h, s]`` holds the action weights used in state ``s`` at
                step ``h``.

        Returns:
            The policy's value averaged over all initial states.
        """
        actions = np.asarray(actions)
        V = np.zeros([self.H + 1, self.S])
        for h in range(self.H - 1, -1, -1):
            q = self._q_values(h, V[h + 1])
            V[h] = np.sum(actions[h] * q, axis=-1)
        return np.mean(V[0])