import random
import numpy as np
import matplotlib.pyplot as plt
from pylab import xticks, yticks

from core.MDP import MDP
from utils.random_board import generate_random_board, generate_one_goal_board
from utils.utils import phi_exp_inv_factory, phi_exp_factory, phi_exp_combined_factory

class Grid_world():
    """Deterministic grid world built on top of ``core.MDP.MDP``.

    ``board`` is a 2-D array of cell codes:

    * ``0`` - free cell;
    * ``1`` - obstacle: a transition *into* it from another cell earns
      ``punish_reward`` (the cell is not blocking; staying in it earns 0);
    * ``2`` - target: a transition into it earns ``win_reward``; it is
      absorbing (every action stays there, with reward 0).

    States are grid positions ``(i, j)`` (row, column), indexed row-major as
    ``i * W + j``.  Actions are ``"up"``, ``"down"``, ``"left"``, ``"right"``
    and ``"stay"``; a move that would leave the grid keeps the agent in place.
    """

    # Algorithm names accepted by ``solve_mdp`` (as ``mode`` or ``mode["alg"]``).
    _SOLVER_ALGORITHMS = (
        "value_iteration", "policy_iteration",
        "projected_Q_descent", "policy_descent",
        "softmax", "softmax_adaptive", "softmax_temp", "softmax_NPG",
        "phi", "escort", "escort_normalized",
    )
    # Algorithm name -> ``mode`` argument forwarded to the MDP solver.
    _SOFTMAX_MODES = {"softmax_adaptive": "adaptive",
                      "softmax_temp": "temp",
                      "softmax_NPG": "NPG"}
    _ESCORT_MODES = {"escort_normalized": "normalized", "escort": "origin"}
    # Generic TD mode names.  Q-learning is the off-policy TD control method
    # and SARSA the on-policy one; this also makes the default
    # ``mode="Off-policy"`` of ``solve_mdp_using_TD_Learning`` valid.
    _TD_MODE_ALIASES = {"Off-policy": "Q-learning", "On-policy": "SARSA"}

    def __init__(self,
                 board,
                 gamma=.9,
                 win_reward=1,
                 punish_reward=-1):
        """Build the grid world and its underlying MDP.

        Args:
            board: 2-D array-like of cell codes (0 free, 1 obstacle, 2 target).
            gamma: Discount factor, forwarded to ``MDP``.
            win_reward: Reward of a transition into a target cell.
            punish_reward: Reward of a transition into an obstacle cell.
        """
        board = np.asarray(board)
        self.board = board.astype(np.uint8)
        self.H, self.W = board.shape
        self.gamma = gamma
        self.win_reward = win_reward
        self.punish_reward = punish_reward

        # States are (i, j) positions, flattened row-major to i * W + j.
        self.state_list = [(i, j) for i in range(self.H) for j in range(self.W)]
        self.pos2idx = {pos: idx for idx, pos in enumerate(self.state_list)}
        self.idx2pos = {idx: pos for pos, idx in self.pos2idx.items()}
        # Actions are strings; their index is the first axis of P / rewards.
        self.action_list = ["up", "down", "left", "right", "stay"]
        self.action2idx = {a: idx for idx, a in enumerate(self.action_list)}
        self.idx2action = {idx: a for a, idx in self.action2idx.items()}

        # Build the MDP.
        self.P, self.rewards = self.load_board()
        self.mdp = MDP(self.P, self.gamma, self.rewards)

    def load_board(self):
        """Build the transition and reward tensors for ``self.board``.

        Side effect: sets ``self.target_state_list`` and
        ``self.obstacle_state_list`` (lists of ``(i, j)`` tuples of ``int``).

        Returns:
            ``(P, rewards)``: two float arrays of shape
            ``(n_actions, H*W, H*W)``.  ``P[a, s, s2]`` is 1 when action ``a``
            takes state ``s`` to ``s2`` and 0 otherwise; ``rewards[a, s, s2]``
            is the reward of that transition.
        """
        board = self.board
        H, W = self.H, self.W
        pos2idx = self.pos2idx
        action2idx = self.action2idx

        # (row, col) displacement per action; unknown actions behave like "stay".
        offsets = {"up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}

        def move(s, a):
            """Return the successor of ``s`` under ``a``; hitting a wall keeps ``s``."""
            di, dj = offsets.get(a, (0, 0))
            i, j = s[0] + di, s[1] + dj
            return (i, j) if 0 <= i < H and 0 <= j < W else s

        def find_pos(element):
            """Return every ``(i, j)`` with ``board[i, j] == element``."""
            # int() so positions hash/compare exactly like the keys of pos2idx.
            return [(int(i), int(j)) for i, j in zip(*np.where(board == element))]

        target_state_list = find_pos(2)
        obstacle_state_list = find_pos(1)
        self.target_state_list = target_state_list
        self.obstacle_state_list = obstacle_state_list

        n_states = H * W
        P = np.zeros((len(self.action_list), n_states, n_states))
        for a in self.action_list:
            for s in self.state_list:
                P[action2idx[a], pos2idx[s], pos2idx[move(s, a)]] = 1
        # Targets are absorbing: every action leads back to the target itself.
        for target_state in target_state_list:
            t = pos2idx[target_state]
            P[:, t, :] = 0
            P[:, t, t] = 1

        rewards = np.zeros_like(P)
        # Entering a target pays win_reward; staying in it pays nothing.
        for target_state in target_state_list:
            t = pos2idx[target_state]
            rewards[:, :, t] = self.win_reward * (P[:, :, t] == 1).astype(np.float32)
            rewards[:, t, t] = 0
        # Entering an obstacle costs punish_reward; staying in it pays nothing.
        for obstacle_state in obstacle_state_list:
            o = pos2idx[obstacle_state]
            rewards[:, :, o] = self.punish_reward * (P[:, :, o] == 1).astype(np.float32)
            rewards[:, o, o] = 0

        return P, rewards

    @classmethod
    def _normalize_mode(cls, mode):
        """Return ``mode`` as a dict carrying a valid ``"alg"`` key.

        A bare string is wrapped as ``{"alg": mode}``; a dict is returned
        unchanged.  Raises ``ValueError`` for a missing or unknown algorithm.
        """
        if isinstance(mode, str):
            mode = {"alg": mode}
        try:
            alg = mode["alg"]
        except (KeyError, TypeError, IndexError) as err:
            raise ValueError("mode must be an algorithm name or a dict with an "
                             "'alg' key, got %r" % (mode,)) from err
        if alg not in cls._SOLVER_ALGORITHMS:
            raise ValueError("unknown algorithm %r; expected one of: %s"
                             % (alg, ", ".join(cls._SOLVER_ALGORITHMS)))
        return mode

    def solve_mdp(self,
                  mode="value_iteration",
                  max_iter=10000,
                  epsilon=1e-7,
                  step_size=1,
                  asynchronous=False,
                  init=False,
                  verbose=False,
                  need_return=False,
                  noise=None,
                  seed=21):
        """Solve the MDP with the selected algorithm.

        Args:
            mode: Algorithm name (one of ``_SOLVER_ALGORITHMS``) or a dict
                ``{"alg": name, ...}``.  ``"escort"`` / ``"escort_normalized"``
                additionally need ``mode["p"]``; ``"phi"`` needs ``mode["phi"]``.
            max_iter, epsilon, step_size, asynchronous, noise, seed: Forwarded
                to the chosen ``MDP`` solver (``epsilon`` and ``asynchronous``
                only to value iteration).
            init: If true, randomly re-initialise the policy and V first.
            verbose: NOTE: despite its name, ``True`` *suppresses* output: it
                skips the "Solving!" message and the final policy plot, and is
                forwarded to the solver as ``silence``.
            need_return: If true, return the solver's result instead of
                plotting the policy.

        Returns:
            The solver's return value when ``need_return`` is true, else None.

        Raises:
            ValueError: If ``mode`` names no known algorithm.
        """
        mode = self._normalize_mode(mode)
        alg = mode["alg"]

        if init:
            self.mdp.init_policy_and_V(random_init=True)

        if not verbose:
            print("Solving!")

        mdp = self.mdp
        # Keyword arguments shared by every iterative policy-update solver.
        descent_kwargs = dict(max_iter=max_iter, step_size=step_size,
                              need_return=need_return, silence=verbose,
                              noise=noise, seed=seed)
        if alg == "value_iteration":
            return_dict = mdp.value_iteration(epsilon=epsilon, max_iter=max_iter,
                                              asynchronous=asynchronous,
                                              need_return=need_return,
                                              silence=verbose, seed=seed)
        elif alg == "policy_iteration":
            return_dict = mdp.policy_iteration(max_iter=max_iter,
                                               need_return=need_return,
                                               silence=verbose, seed=seed)
        elif alg == "projected_Q_descent":
            return_dict = mdp.projected_Q_descent(**descent_kwargs)
        elif alg == "policy_descent":
            return_dict = mdp.projected_Q_descent(mode="policy_descent", **descent_kwargs)
        elif alg == "softmax":
            return_dict = mdp.softmax_descent(**descent_kwargs)
        elif alg in self._SOFTMAX_MODES:
            return_dict = mdp.softmax_descent(mode=self._SOFTMAX_MODES[alg],
                                              **descent_kwargs)
        elif alg in self._ESCORT_MODES:
            p = mode["p"]
            return_dict = mdp.escort_descent(mode=self._ESCORT_MODES[alg], p=p,
                                             **descent_kwargs)
        else:  # "phi" -- the only remaining validated name.
            phi = mode["phi"]
            return_dict = mdp.phi_policy_update(phi, **descent_kwargs)

        if need_return:
            return return_dict

        if not verbose:
            self.visualize_policy()

    def _target_state_indices(self):
        """Flat state indices of all target cells (passed as terminal states)."""
        return [self.pos2idx[target_state] for target_state in self.target_state_list]

    def _report_learning_curve(self, V_list, baseline_V, verbose):
        """Plot the mean of each recorded V against the baseline mean.

        Returns ``(V_mean_curve, mean(baseline_V))`` where ``V_mean_curve`` is a
        list holding the mean over states of each entry of ``V_list``.
        """
        V_mean_curve = np.mean(np.stack(V_list, axis=0), axis=1).tolist()
        baseline_mean = np.mean(baseline_V)
        if not verbose:
            plt.plot(V_mean_curve)
            plt.axhline(baseline_mean, color='red', alpha=.7)
            plt.show()
        plt.clf()

        if not verbose:
            self.visualize_policy()

        return V_mean_curve, baseline_mean

    def solve_mdp_using_MC_Learning(self,
                                    mode="Off-policy",
                                    max_iter=1000,
                                    epsilon=None,
                                    max_length=10000,
                                    verbose=False):
        """Run Monte-Carlo control and compare it with the exact solution.

        The MDP is first solved with ``solve_mdp`` (default algorithm) to get a
        baseline V, then the chosen MC control method is run.

        Args:
            mode: ``"Off-policy"`` or ``"On-policy"``.
            max_iter, epsilon, max_length: Forwarded to the MC control method.
            verbose: ``True`` suppresses plots (see ``solve_mdp``).

        Returns:
            ``(V_mean_curve, baseline_mean)`` -- see ``_report_learning_curve``.

        Raises:
            ValueError: For an unknown ``mode``.
        """
        if mode not in ("Off-policy", "On-policy"):
            raise ValueError("mode must be 'Off-policy' or 'On-policy', got %r" % (mode,))

        self.solve_mdp(verbose=verbose)
        # Snapshot the baseline: the learner below may update self.mdp.V.
        baseline_V = np.array(self.mdp.V)

        print("Solving!")

        if mode == "Off-policy":
            control = self.mdp.MC_off_policy_control
        else:
            control = self.mdp.MC_on_policy_control
        V_list = control(max_iter=max_iter,
                         epsilon=epsilon,
                         max_length=max_length,
                         terminate_state=self._target_state_indices(),
                         need_return=True,
                         seed=1)

        return self._report_learning_curve(V_list, baseline_V, verbose)

    def solve_mdp_using_TD_Learning(self,
                                    mode="Off-policy",
                                    max_iter=1000,
                                    step_size=.1,
                                    epsilon=None,
                                    verbose=False,
                                    plot_freq=100):
        """Run TD control (SARSA or Q-learning) and compare it with the exact solution.

        Args:
            mode: ``"SARSA"`` or ``"Q-learning"``.  The generic names
                ``"On-policy"`` (SARSA) and ``"Off-policy"`` (Q-learning, the
                default) are accepted as aliases.
            max_iter, step_size, epsilon, plot_freq: Forwarded to the learner.
            verbose: ``True`` suppresses plots (see ``solve_mdp``).

        Returns:
            ``(V_mean_curve, baseline_mean)`` -- see ``_report_learning_curve``.

        Raises:
            ValueError: For an unknown ``mode``.
        """
        if isinstance(mode, str):
            mode = self._TD_MODE_ALIASES.get(mode, mode)
        if mode not in ("SARSA", "Q-learning"):
            raise ValueError("mode must be 'SARSA', 'Q-learning', 'On-policy' "
                             "or 'Off-policy', got %r" % (mode,))

        self.solve_mdp(verbose=verbose)
        # Snapshot the baseline: the learner below may update self.mdp.V.
        baseline_V = np.array(self.mdp.V)

        print("Solving!")

        learner = self.mdp.SARSA if mode == "SARSA" else self.mdp.Q_learning
        V_list = learner(max_iter=max_iter,
                         epsilon=epsilon,
                         step_size=step_size,
                         terminate_state=self._target_state_indices(),
                         need_return=True,
                         seed=1,
                         plot_freq=plot_freq)

        return self._report_learning_curve(V_list, baseline_V, verbose)

    @staticmethod
    def _draw_block(ax, s, color):
        """Fill the unit cell at grid position ``s = (row, col)`` on ``ax``."""
        y, x = s
        ax.fill_between([x, x + 1], y, y + 1, facecolor=color)

    def visualize_policy(self):
        """Draw the deterministic policy ``self.mdp.policy`` on a new figure.

        Targets are green, obstacles red; each state gets an arrow for its
        action ("stay" draws nothing).  Calls ``plt.show()``.
        """
        # action -> ((x, y) start offset from cell centre, (dx, dy) arrow);
        # y grows downward because the y axis is inverted below.
        arrow_geometry = {
            "up": ((0, .25), (0, -.5)),
            "down": ((0, -.25), (0, .5)),
            "left": ((.25, 0), (-.5, 0)),
            "right": ((-.25, 0), (.5, 0)),
        }

        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_xlim(0, self.W); ax.set_ylim(0, self.H)
        ax.set_xticks(np.linspace(0, self.W - 1, self.W, endpoint=True))
        ax.set_yticks(np.linspace(0, self.H - 1, self.H, endpoint=True))
        ax.grid()
        ax.set_aspect('equal')
        ax.invert_yaxis()

        for target_state in self.target_state_list:
            self._draw_block(ax, target_state, "green")
        for obstacle_state in self.obstacle_state_list:
            self._draw_block(ax, obstacle_state, "red")

        for s_idx, a_idx in self.mdp.policy.items():
            geometry = arrow_geometry.get(self.idx2action[a_idx])
            if geometry is None:
                continue  # "stay": nothing to draw
            (off_x, off_y), (dx, dy) = geometry
            y, x = self.idx2pos[s_idx]
            ax.arrow(x + .5 + off_x, y + .5 + off_y, dx, dy,
                     length_includes_head=True, head_width=0.1, head_length=0.1,
                     fc='r', ec='b')

        plt.show()

    def visualize_prob_policy(self, ax=None, verbose=False):
        """Draw the stochastic policy ``self.mdp.prob_policy`` as arrows.

        Each non-target state gets one arrow whose direction and length are the
        probability-weighted mix of up/down/left/right (entries 0..3 of the
        state's action distribution).

        Args:
            ax: Axes to draw on; a new figure/axes is created when None.
            verbose: ``True`` skips ``plt.show()``.

        Returns:
            The axes drawn on.
        """

        def draw_one_arrow(s, policy):
            y, x = s
            y, x = y+.5, x+.5

            dx = 0.25*policy[2] - 0.25*policy[3]
            dy = 0.25*policy[0] - 0.25*policy[1]

            start_x = x + dx
            start_y = y + dy

            ddx = -.5*policy[2] + .5*policy[3]
            ddy = -.5*policy[0] + .5*policy[1]

            ax.arrow(start_x, start_y, ddx, ddy, length_includes_head=True,
                     head_width=0.1, head_length=0.1, fc='r', ec='b')

        # Only create a figure when the caller did not supply axes; otherwise a
        # stray empty figure would be created and become the current figure.
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        ax.set_xlim(0, self.W); ax.set_ylim(0, self.H)
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.grid()
        ax.set_aspect('equal')
        ax.invert_yaxis()

        for target_state in self.target_state_list:
            self._draw_block(ax, target_state, "green")
        for obstacle_state in self.obstacle_state_list:
            self._draw_block(ax, obstacle_state, "red")

        idx2pos = self.idx2pos
        prob_policy = self.mdp.prob_policy
        targets = set(self.target_state_list)
        for s_idx in range(self.mdp.S_size):
            pos = idx2pos[s_idx]
            if pos in targets:
                continue
            draw_one_arrow(pos, prob_policy[s_idx])

        if not verbose:
            plt.show()

        return ax

    def plot_V_curve_in_VI(self, plot_num=5, differ_epsilon=1e-2):
        """Plot, per state, the value of the greedy policy after each VI iterate.

        For every V produced by value iteration, the greedy policy is extracted
        and evaluated.  States whose curve ever decreases by more than
        ``differ_epsilon`` between consecutive iterates are printed, and up to
        ``plot_num`` of those curves are plotted separately.  Finally the MDP is
        re-solved with ``solve_mdp()`` (default algorithm).
        """
        self.mdp.init_policy_and_V(random_init=False)

        # NOTE(review): assumes value_iteration(need_return=True) returns the
        # sequence of V iterates -- confirm in core.MDP.
        V_list = self.mdp.value_iteration(need_return=True, silence=True)

        V_policy_list = []
        for V in V_list:
            self.mdp.set_V(V)
            self.mdp.extract_policy()
            self.mdp.evaluate_policy()
            V_policy_list.append(self.mdp.V.copy())

        # Plot the curve.
        exception_V_s_curve = []
        for s in range(self.mdp.S_size):
            V_s_curve = [V_policy[s] for V_policy in V_policy_list]
            non_decreasing = all(prev <= nxt + differ_epsilon
                                 for prev, nxt in zip(V_s_curve[:-1], V_s_curve[1:]))
            if not non_decreasing:
                print("Exception state: %d" % s)
                exception_V_s_curve.append(V_s_curve)
            plt.plot(V_s_curve)

        plt.show()

        plt.clf()
        for exception_V_s in exception_V_s_curve[:plot_num]:
            plt.plot(exception_V_s)

        plt.show()

        self.solve_mdp()

    def print_V(self):
        """Print ``self.mdp.V`` as a grid: W tab-separated values per row."""
        for ct, V_value in enumerate(self.mdp.V, start=1):
            print("%8f" % V_value, end='\t')
            if ct % self.W == 0:
                # print("\n") emits a blank line between rows.
                print("\n")

    def compute_delta(self):
        """Return ``self.mdp.compute_delta()``."""
        return self.mdp.compute_delta()
    

if __name__ == '__main__':

    # Example hand-written boards (0 free, 1 obstacle, 2 target):
    # board = np.array([
    #     [0, 0, 0, 0],
    #     [0, 0, 1, 0],
    #     [0, 1, 2, 0],
    #     [0, 1, 0, 0]
    # ], dtype=np.uint8)

    # board = np.array([
    #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    #     [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
    #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    #     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    #     [0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
    #     [0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
    #     [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    # ], dtype=np.uint8)

    # NOTE(review): this seeds only Python's `random`; if generate_random_board
    # draws from np.random, seed that as well -- confirm.
    random.seed(21)
    H = 10; W = 10
    board = generate_random_board(H, W, .1, .005)
    # board = generate_one_goal_board(H, W)

    gamma = .9
    win_reward = 1  # a plain number (a trailing comma would make it a tuple)
    punish_reward = -5

    one_grid_world = Grid_world(board,
                                gamma,
                                win_reward,
                                punish_reward)

    # one_grid_world.solve_mdp(mode={"alg": "policy_iteration"},
    #                          init=True)

    # solve_mdp reads mode["alg"], so pass the dict form.
    one_grid_world.solve_mdp(mode={"alg": "softmax"},
                             init=True, step_size=10)

    # V_list_1, _ = one_grid_world.solve_mdp_using_MC_Learning(mode="Off-policy", max_iter=10000, epsilon=.1, verbose=True)
    # V_list_2, _ = one_grid_world.solve_mdp_using_MC_Learning(mode="Off-policy", max_iter=10000, epsilon=.5, verbose=True)
    # V_list_3, baseline_V = one_grid_world.solve_mdp_using_MC_Learning(mode="Off-policy", max_iter=10000, epsilon=1, verbose=True)

    # plt.plot(V_list_1, color='red')
    # plt.plot(V_list_2, color='black')
    # plt.plot(V_list_3, color='green')
    # plt.axhline(baseline_V, color='blue')
    # plt.show()

    # V_list_1, _ = one_grid_world.solve_mdp_using_TD_Learning(mode="SARSA", max_iter=1000000, epsilon=.1, verbose=True, plot_freq=10000)
    # V_list_2, _ = one_grid_world.solve_mdp_using_TD_Learning(mode="SARSA", max_iter=1000000, epsilon=.5, verbose=True, plot_freq=10000)
    # V_list_3, _ = one_grid_world.solve_mdp_using_TD_Learning(mode="SARSA", max_iter=1000000, epsilon=1, verbose=True, plot_freq=10000)
    # V_list_4, _ = one_grid_world.solve_mdp_using_TD_Learning(mode="Q-learning", max_iter=1000000, epsilon=.1, verbose=True, plot_freq=10000)
    # V_list_5, _ = one_grid_world.solve_mdp_using_TD_Learning(mode="Q-learning", max_iter=1000000, epsilon=.5, verbose=True, plot_freq=10000)
    # V_list_6, baseline_V = one_grid_world.solve_mdp_using_TD_Learning(mode="Q-learning", max_iter=1000000, epsilon=1, verbose=True, plot_freq=10000)
    # plt.plot(V_list_1, '-.', color='red', label='SARSA(epsilon=0.1)')
    # plt.plot(V_list_2, '--', color='red', label='SARSA(epsilon=0.5)')
    # plt.plot(V_list_3, '-', color='red', label='SARSA(epsilon=1)')
    # plt.plot(V_list_4, '-.', color='black', label='Q-Learning(epsilon=0.1)')
    # plt.plot(V_list_5, '--', color='black', label='Q-Learning(epsilon=0.5)')
    # plt.plot(V_list_6, '-', color='black', label='Q-Learning(epsilon=1)')
    # plt.axhline(baseline_V, color='blue')
    # plt.legend()
    # plt.show()

    print("State Size: %d" % (H*W))

    Delta = one_grid_world.compute_delta()
    print("Delta: %f" % Delta)

    Delta_theory = gamma ** (H+W-3) * (1-gamma)
    print("Delta theory: %f" % Delta_theory)

    ori_upper_bound = ((H*W)*5 - 5) / (1-gamma) * np.log(1/(1-gamma))
    print("Ori upper bound: %f" % ori_upper_bound)

    upper_bound = np.log(2/(Delta_theory * (1-gamma))) / (1-gamma)
    print("Upper bound: %f" % upper_bound)