import numpy as np
import torch
from utils import *
class RBMLE():
    """Agent for a finite MDP whose transition model is a theta-weighted combination of features.

    For every state/action pair the model entry is ``theta . func(s, a)``, where
    ``func`` is evaluated on one-hot encodings of ``s`` and ``a``. That entry is
    used as the next-state weight vector in value iteration. Actions are chosen
    greedily from the resulting Q-table. ``theta`` is re-estimated after every
    observed transition by the ``solve_rbmle_*`` helpers (defined elsewhere).

    All calls to ``func`` receive tensors moved to the GPU with ``.cuda()``, so a
    CUDA device is required.

    NOTE(review): the name suggests reward-biased maximum likelihood estimation;
    the estimator itself lives in the external solvers -- confirm there.
    """

    def __init__(self, S: int, A: int, K: int, R: np.ndarray, gamma: float, func: Network, alpha: float, eta: float, solving_approach='exact') -> None:
        """Initialize the agent with a random theta and its implied model.

        Args:
            S: number of states (one-hot encoded with ``np.eye(S)``).
            A: number of actions (one-hot encoded with ``np.eye(A)``).
            K: length of the parameter vector ``theta``.
            R: reward array, combined as ``R + gamma * V`` in value iteration;
                presumably shaped (S, A, S) -- TODO confirm.
            gamma: discount factor; V and Q start at ``1 / (1 - gamma)``.
            func: feature network called as ``func(state_tensor, action_tensor)``.
                Its output is contracted with ``theta`` via ``np.dot``, so it is
                presumably a (K, S) array -- TODO confirm.
            alpha: ``"log"`` for ``alpha(t) = eta * log(t)``; otherwise a number
                (or numeric string) p for ``alpha(t) = eta * t ** p``.
            eta: scale of the ``alpha`` schedule.
            solving_approach: ``'exact'`` or ``'approximated'``; selects the solver
                used in ``train``.
        """
        self.S = S
        self.A = A
        self.action_space = np.eye(A)  # row j is the one-hot encoding of action j
        self.state_space = np.eye(S)  # row i is the one-hot encoding of state i
        self.K = K
        self.R = R
        self.gamma = gamma
        self.func = func
        self.solving_approach = solving_approach  # 'exact' or 'approximated'
        if alpha == "log":
            self.alpha = lambda x: eta * np.log(x)
        else:
            self.alpha = lambda x: eta * x ** (float(alpha))
        self.phi_list = []  # feature vector of every observed transition, in order
        self.n_list = []
        # Random initial weights that are non-negative and sum to 1. Uniform draws
        # (the same scheme compute_mle uses) guarantee this; normalizing Gaussian
        # draws could give negative weights or divide by a sum close to zero.
        self.theta = np.random.rand(self.K)
        self.theta = self.theta / np.sum(self.theta)
        self.V = np.ones(self.S) / (1 - self.gamma)
        self.Q = np.ones((self.S, self.A)) / (1 - self.gamma)
        # Visit counts indexed [state][action][next_state]; not updated in this class.
        self.history_table = np.zeros((self.S, self.A, self.S), dtype=int)
        # Starts at the identity; train adds phi phi^T for every observed transition.
        self.sigma = np.identity(K)
        self.mle = np.random.randn(self.K)

        # Model implied by the initial theta.
        self.transitions = self._compute_transitions(self.theta)

    @staticmethod
    def _one_hot_tensor(one_hot_row: np.ndarray) -> torch.Tensor:
        """Wrap a one-hot row as a (1, dim) float32 CUDA tensor, the input format given to ``func``."""
        return torch.from_numpy(np.array([one_hot_row])).float().cuda()

    def _compute_transitions(self, theta: np.ndarray) -> list:
        """Return the model implied by ``theta`` as nested lists indexed ``[s][a]``.

        Entry ``[s][a]`` is ``theta . func(s, a)`` converted to a Python list.
        ``func`` is evaluated under ``torch.no_grad()``.
        """
        transitions = []
        for i in range(self.S):
            s = self._one_hot_tensor(self.state_space[i])
            row = []
            for j in range(self.A):
                a = self._one_hot_tensor(self.action_space[j])
                with torch.no_grad():
                    phi = self.func(s, a)
                row.append(np.dot(np.array([theta]), phi)[0].tolist())
            transitions.append(row)
        return transitions

    def select(self, state: int) -> int:
        """Return the greedy action for ``state`` under the current Q-table (first index on ties)."""
        action = np.argmax(self.Q[state])
        return action

    def train(self, state: int, action: int, reward: float, next_state: int) -> np.ndarray:
        """Record one observed transition, re-estimate theta, and re-plan.

        Args:
            state: index of the state the transition started from.
            action: index of the action taken.
            reward: observed reward; not read here (value iteration uses ``self.R``).
            next_state: index of the state reached.

        Returns:
            The (possibly unchanged) ``theta`` after this update.

        Raises:
            ValueError: if ``solving_approach`` is neither 'exact' nor 'approximated'
                (raised after the sample has already been recorded).
        """
        # NOTE(review): t counts samples *before* this one, so the first call uses
        # alpha(0); with alpha == "log" that is eta * log(0) = -inf. Confirm intended.
        t = len(self.phi_list)

        # Feature vector of the observed (state, action, next_state) transition.
        state_vec = self._one_hot_tensor(self.state_space[state])
        action_vec = self._one_hot_tensor(self.action_space[action])
        phi = self.func(state_vec, action_vec).T[next_state]
        self.sigma = self.sigma + np.matmul(phi.reshape(-1, 1), phi.reshape(1, -1))
        self.phi_list.append(phi)

        # Re-estimate theta.
        if self.solving_approach == 'exact':
            temp = solve_rbmle_exact(self.S, self.A, self.R, self.func, self.phi_list, next_state, self.alpha(t), self.gamma)
        elif self.solving_approach == 'approximated':
            # Bias vector: features of (next_state, greedy action) contracted with V.
            max_action = np.argmax(self.Q[next_state])
            state_vec = self._one_hot_tensor(self.state_space[next_state])
            action_vec = self._one_hot_tensor(self.action_space[max_action])
            phi = self.func(state_vec, action_vec)
            phi_v_bias = np.matmul(phi, self.V)
            temp = solve_rbmle_approximated(self.S, self.A, self.func, np.array(self.phi_list), self.alpha(t) * phi_v_bias)
        else:
            raise ValueError(self.solving_approach)
        # A solver may return None (presumably when it fails); keep the old theta then.
        if temp is not None:
            self.theta = temp

        # Rebuild the model from the new theta and re-plan on it.
        self.transitions = self._compute_transitions(self.theta)
        self.update_value_function(max_iterations=10**6, delta=10**-4)
        return self.theta

    def update_value_function(self, max_iterations, delta):
        """Run value iteration on ``self.transitions``, updating ``self.V`` and ``self.Q``.

        V is reset to ``1 / (1 - gamma)`` first. Each sweep computes
        ``Q[s, a] = sum_k P[s, a, k] * (R + gamma * V)[s, a, k]`` and
        ``V = max_a Q[s, a]``. Iteration stops once the largest absolute change
        in V is below ``delta``, or after ``max_iterations`` sweeps.
        """
        self.V = np.ones(self.S) / (1 - self.gamma)
        R = self.R
        # Convert the nested-list model to an array once instead of on every sweep.
        P = np.asarray(self.transitions)
        for _ in range(max_iterations):
            previous_value_fn = self.V.copy()
            self.Q = np.einsum('ijk,ijk -> ij', P, R + self.gamma * self.V)
            self.V = np.max(self.Q, axis=1)
            if np.max(np.abs(self.V - previous_value_fn)) < delta:
                break

    def compute_mle(self):
        """Return ``(mle, transitions)``: a theta estimate and the model it implies.

        With no recorded samples the estimate is a random non-negative vector
        normalized to sum to 1. Otherwise it comes from ``solve_mle`` (presumably
        a maximum-likelihood fit). No attribute of ``self`` is reassigned.
        """
        if not self.phi_list:
            mle = np.random.rand(self.K)
            mle = mle / np.sum(mle)
        else:
            mle = solve_mle(self.S, self.A, self.func, np.array(self.phi_list))
        return mle, self._compute_transitions(mle)
