import numpy as np 

# Scalar constants.
# NOTE(review): none of these are referenced by the class defined below in
# this module -- confirm whether other modules import them before removing.
G = 10
L = 1
M = 1
MAX_SPEED = 8

# (row, column) displacement applied by each action index:
#     0: no move
#     1: AU  (row - 1)
#     2: AD  (row + 1)
#     3: AL  (column - 1)
#     4: AR  (column + 1)
ACTIONMAP = np.array([
    [0, 0],
    [-1, 0],
    [1, 0],
    [0, -1],
    [0, 1],
])

# LaTeX arrow for each position 0..3 (up, down, left, right).
POLICYMAP = dict(enumerate((
    "$\\uparrow$",
    "$\\downarrow$",
    "$\\leftarrow$",
    "$\\rightarrow$",
)))

# Interior cells of the 5x5 grid that can never be occupied.
_BLOCKED_CELLS = [[2, 1], [2, 2], [2, 3], [3, 2]]

# Every occupiable (row, column) cell, listed in row-major order.
STATEMAP = np.array([
    [row, col]
    for row in range(5)
    for col in range(5)
    if [row, col] not in _BLOCKED_CELLS
])

# 1-based state index -> location (a row of STATEMAP), and the inverse lookup
# keyed by the string form of that location (e.g. "[0 1]").
STATE_I2L = dict(enumerate(STATEMAP, start=1))
STATE_L2I = {str(cell): index for index, cell in STATE_I2L.items()}

# Locations a move may not end on: the blocked interior cells followed by the
# ring of cells just outside the grid (row -1, row 5, column -1, column 5).
DISALLOWED_STATE = np.array(
    _BLOCKED_CELLS
    + [[-1, col] for col in range(5)]
    + [[5, col] for col in range(5)]
    + [[row, -1] for row in range(5)]
    + [[row, 5] for row in range(5)]
)


class catVSmonsters:
    """Transition and reward model for the cat-vs-monsters grid world.

    States are the 1-based indices defined by ``STATE_I2L`` / ``STATE_L2I``;
    actions are indices into ``ACTIONMAP`` (0 = no move, 1..4 = AU/AD/AL/AR).

    Attributes:
        r2: Value returned for state 2, and the reward for entering location
            (0, 1) when ``q == 3``.
        p_remain: Probability that the requested action is the one executed.
        p_to_3: Probability mass always added to action 3 (AL).
        p_to_4: Probability mass always added to action 4 (AR).
        p_to_0: Probability mass always added to action 0 (no move).
        gamma: Discount factor read by ``get_vs``. Only set by ``__init__``
            when passed explicitly; otherwise the caller must assign it
            before calling ``get_vs`` or ``draw_policy``.
        q: Variant selector read by ``get_pnsr`` (only the value 3 is
            special-cased). Only set by ``__init__`` when passed explicitly;
            otherwise the caller must assign it before calling ``get_pnsr``,
            ``get_vs`` or ``draw_policy``.
    """

    def __init__(self, reward='one', gamma=None, q=None):
        """Set the fixed model parameters.

        Args:
            reward: Accepted for backward compatibility; currently unused.
            gamma: Optional discount factor. When omitted, ``self.gamma`` is
                left unset so a value assigned elsewhere (by the caller or a
                subclass) is never overwritten.
            q: Optional variant selector, handled the same way as ``gamma``.
        """
        self.r2 = 1.

        self.p_remain = 0.7
        self.p_to_3 = 0.12
        self.p_to_4 = 0.12
        self.p_to_0 = 0.06

        # Assign only when supplied: existing code that sets these attributes
        # after (or, in a subclass, before) construction keeps working.
        if gamma is not None:
            self.gamma = gamma
        if q is not None:
            self.q = q

    def get_vs(self, s, a, table):
        """Return the one-step backed-up value of requesting ``a`` in ``s``.

        Args:
            s: 1-based state index (a key of ``STATE_I2L``).
            a: Requested action index, one of 1, 2, 3, 4.
            table: Current value estimates, indexed by 1-based state index
                (presumably a dict or an array of length >= 22 -- confirm
                against the caller).

        Returns:
            ``10.`` for state 21 and ``self.r2`` for state 2 (both returned
            without looking at ``a`` or ``table``); otherwise the expectation
            over executed actions of ``r + gamma * table[next_state]``.

        Raises:
            ValueError: If ``a`` is not in {1, 2, 3, 4} (from ``get_pa``).
            AttributeError: If ``gamma`` or ``q`` has not been set.
        """
        # States 21 (location (4, 4)) and 2 (location (0, 1)) have fixed values.
        if s == 21:
            return 10.
        if s == 2:
            return self.r2

        state = STATE_I2L[s]   # index to location
        pa = self.get_pa(a)    # probability of each executed action index

        new_v = 0
        for prob, move in zip(pa, ACTIONMAP):
            ns, r = self.get_pnsr(state, move)
            new_v += prob * (r + self.gamma * table[ns])
        return new_v

    def get_pa(self, a):
        """Return the distribution over executed actions when ``a`` is requested.

        Args:
            a: Requested action index, one of 1, 2, 3, 4.

        Returns:
            Length-5 float array; entry ``i`` is the probability that action
            ``i`` of ``ACTIONMAP`` is the one actually executed.

        Raises:
            ValueError: If ``a`` is not in {1, 2, 3, 4}.
        """
        if a not in [1, 2, 3, 4]:
            raise ValueError("Input 'a' must be in {1, 2, 3, 4}")

        pa = np.zeros(5)
        pa[a] += self.p_remain
        # NOTE(review): the two 0.12 slip shares always go to AL (3) and AR (4)
        # regardless of ``a``, so requesting AL or AR gives that action 0.82.
        # Confirm this is the intended noise model (as opposed to veering
        # relative to the requested heading).
        pa[3] += self.p_to_3
        pa[4] += self.p_to_4
        pa[0] += self.p_to_0
        return pa

    def get_pnsr(self, s, a):
        """Apply displacement ``a`` at location ``s``; return next state and reward.

        Args:
            s: Current location as a length-2 NumPy array (row, column).
            a: Displacement to apply, a row of ``ACTIONMAP``.

        Returns:
            Tuple ``(next_state_index, reward)``. The location is unchanged
            when ``s + a`` is listed in ``DISALLOWED_STATE``. The reward is
            ``-8.`` for ending on (0, 3) or (4, 1), ``self.r2`` for ending on
            (0, 1) when ``self.q == 3``, and ``-0.05`` otherwise.

        Raises:
            AttributeError: If ``q`` has not been set and the resulting
                location is neither (0, 3) nor (4, 1).
        """
        target = s + a
        # One broadcast comparison against every disallowed location.
        blocked = bool((DISALLOWED_STATE == target).all(axis=1).any())
        ns = s if blocked else target

        cell = tuple(ns)
        if cell in ((0, 3), (4, 1)):
            r = -8.
        elif self.q == 3 and cell == (0, 1):
            r = self.r2
        else:
            r = -0.05

        return STATE_L2I[str(ns)], r

    def draw_policy(self, table):
        """Print the greedy policy for ``table`` as a LaTeX ``tabular``.

        For each state 1..21 the action in {1, 2, 3, 4} with the largest
        ``get_vs`` value is chosen (ties go to the lowest action index) and
        shown as the matching ``POLICYMAP`` arrow. Grid cells that have no
        state index are printed as ``N/A``.

        Args:
            table: Value estimates indexed by 1-based state index, passed
                through to ``get_vs``.
        """
        policy = {}
        for i in range(1, 22):
            vas = [self.get_vs(i, a, table) for a in (1, 2, 3, 4)]
            # argmax position 0..3 corresponds to actions 1..4.
            policy[i] = POLICYMAP[np.array(vas).argmax()]

        # State index shown in each grid cell; None marks a cell with no state.
        layout = (
            (1, 2, 3, 4, 5),
            (6, 7, 8, 9, 10),
            (11, None, None, None, 12),
            (13, 14, None, 15, 16),
            (17, 18, 19, 20, 21),
        )

        print("\\begin{center}")
        print("\\begin{tabular}{|c|c|c|c|c|}")
        print("\\hline")
        for row in layout:
            cells = ["N/A" if idx is None else policy[idx] for idx in row]
            print(" & ".join(cells) + " \\\\")
            print("\\hline")
        print("\\end{tabular}")
        print("\\end{center}")








