import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import wandb


#########################################################################################
def shortest(MAP):
    """
    All-pairs shortest path lengths between the free cells of a grid map.

    MAP is a newline-separated string of cell symbols (spaces are ignored);
    every cell whose symbol is not '1' counts as free.  Two free cells are
    neighbours when their Manhattan distance is exactly 1.  Distances are
    relaxed with the Floyd-Warshall recurrence; unreachable pairs stay at
    np.inf.

    Returns (dist, free_spaces) where dist maps (cell_a, cell_b) -> length
    and free_spaces is the list of free (row, col) tuples.
    """
    rows = [list(line) for line in MAP.replace(" ", "").split('\n')]
    grid = np.array(rows)
    free_spaces = [tuple(cell) for cell in np.argwhere(grid != '1')]

    def direct_cost(a, b):
        # 0 for the same cell, 1 for 4-neighbours, otherwise "not yet connected".
        gap = abs(a[0] - b[0]) + abs(a[1] - b[1])
        if gap == 0:
            return 0
        if gap == 1:
            return 1
        return np.inf

    dist = {(a, b): direct_cost(a, b) for a in free_spaces for b in free_spaces}

    for mid in free_spaces:
        for src in free_spaces:
            # dist[(src, mid)] cannot improve while `mid` is the pivot
            # (dist[(mid, mid)] is 0), so it is safe to read it once per src.
            to_mid = dist[(src, mid)]
            for dst in free_spaces:
                through = to_mid + dist[(mid, dst)]
                if through < dist[(src, dst)]:
                    dist[(src, dst)] = through

    return dist, free_spaces

def get_V_optimal(env, dist, free_spaces, gamma = 1):
    """
    Closed-form state values derived from shortest-path distances.

    For every free state and every goal in env.T_states the value candidate is
        env._get_reward(goal, 0) + C * env._get_reward(state, 0)
    where C is the distance d = dist[(state, goal)] when gamma == 1 and the
    geometric sum (1 - gamma**d) / (1 - gamma) otherwise.  The state's value is
    the best candidate over all goals.

    Returns a defaultdict (missing states read as 0) mapping state -> value.
    """
    def step_weight(d):
        if gamma == 1:
            return d
        return (1-(gamma**d))/(1-gamma)

    V = defaultdict(lambda: 0)
    for state in free_spaces:
        candidates = [
            env._get_reward(goal, 0) + step_weight(dist[(state, goal)]) * env._get_reward(state, 0)
            for goal in env.T_states
        ]
        V[state] = np.max(candidates)
    return V

def get_EV_optimal(env, dist, free_spaces, rmin_ = -100, gamma = 1):
    """
    Per-goal closed-form values derived from shortest-path distances.

    EV[state][goal] = env._get_reward(goal, 0) + C * env._get_reward(state, 0),
    with C = dist[(state, goal)] when gamma == 1 and
    (1 - gamma**d) / (1 - gamma) otherwise.

    rmin_ is accepted but not used by this function.

    Returns a nested defaultdict (missing entries read as 0).
    """
    def step_weight(d):
        if gamma == 1:
            return d
        return (1-(gamma**d))/(1-gamma)

    EV = defaultdict(lambda: defaultdict(lambda: 0))
    for state in free_spaces:
        for goal in env.T_states:
            weight = step_weight(dist[(state, goal)])
            goal_reward = env._get_reward(goal, 0)
            EV[state][goal] = goal_reward + weight * env._get_reward(state, 0)
    return EV

def V_equal(V1,V2,epsilon=1e-2):
    """
    Return True when, for every state in V1, V1 and V2 differ by at most epsilon.

    Only the keys of V1 are inspected; V2 is indexed with each of them.
    """
    return not any(abs(V1[state] - V2[state]) > epsilon for state in V1)

def EV_equal(EV1,EV2,epsilon=1e-2):
    """
    Return True when every EV1[state][goal] is within epsilon of EV2[state][goal].

    Only the (state, goal) pairs present in EV1 are inspected.
    """
    return not any(
        abs(EV1[state][goal] - EV2[state][goal]) > epsilon
        for state in EV1
        for goal in EV1[state]
    )


#########################################################################################

def evaluateQ(env, Q, V_opt, gamma=1, render=False):
    """
    Roll out the greedy policy of Q for at most 20 steps and score it.

    env follows the reset() -> (state, info) / step() -> 5-tuple convention.
    V_opt is accepted but not used here.

    Returns (G, S, t): the discounted return, a success flag (reward > 0 on
    the step that ended the episode, 0 if the episode never ended) and the
    index of the last step taken.
    """
    discounted_return = 0
    success = 0
    obs, _ = env.reset()
    for step_idx in range(20):
        if render:
            env.render()
            plt.pause(0.001)
        greedy_action = Q[obs].argmax()
        obs, reward, terminated, truncated, _ = env.step(greedy_action)
        discounted_return += (gamma**step_idx)*reward
        if terminated or truncated:
            success = reward > 0
            break
    return discounted_return, success, step_idx

def get_Vopt(env_eval, gamma, theta=1e-6, max_sweeps=None):
    """
    Compute the optimal state-value function of an environment by value iteration.

    The environment is reset once, then in-place sweeps over
    ``env_eval.unwrapped.states`` are repeated until the largest change in a
    sweep drops below ``theta``.

    Parameters:
    - env_eval: Environment whose ``unwrapped`` object exposes ``states`` (an
      iterable of hashable states), ``action_space.n`` and
      ``stepP(s, a)`` returning an iterable of 4-tuples ``(p, s_, r, _)``
      (probability, next state, reward, and a fourth field ignored here).
    - gamma: Discount factor applied to the value of the next state.
    - theta: Convergence threshold on the per-sweep maximum value change.
    - max_sweeps: Optional cap on the number of sweeps; ``None`` (default)
      iterates until convergence, which may never happen for an undiscounted
      problem whose values do not converge.

    Returns:
    - V_opt: defaultdict(float) mapping state -> optimal value estimate.
    """
    env_eval.reset()
    model = env_eval.unwrapped
    num_actions = model.action_space.n

    V_opt = defaultdict(float)
    sweeps = 0
    while max_sweeps is None or sweeps < max_sweeps:
        delta = 0
        for s in model.states:
            # Bellman optimality backup; uses values already updated in this sweep.
            max_q = max(
                sum(p * (r + gamma * V_opt[s_]) for p, s_, r, _ in model.stepP(s, a))
                for a in range(num_actions)
            )
            delta = max(delta, abs(max_q - V_opt[s]))
            V_opt[s] = max_q
        sweeps += 1
        if delta < theta:
            break
    return V_opt

def evaluate_regret(env, Q, V_opt, gamma, render=False):
    """
    Roll out the greedy policy of Q once per start goal (0 and 1) and report regret.

    Parameters:
    - env: Environment using reset() -> (state, info) and step() -> 5-tuple,
      whose step info carries the true reward under info["reward"].
      env.unwrapped must expose length, min_length, start_goal and
      start_states.  When env has a ``num_stack`` attribute the stacked-frame
      bookkeeping for passive/active counts is enabled.
    - Q: Mapping state -> array of action values (indexed with ``.argmax()``).
    - V_opt: Optimal value function (state -> value).  Pass None to have it
      computed with get_Vopt(env, gamma).
    - gamma: Discount factor used for the rollout return.
    - render: Accepted for interface compatibility; not used.

    Returns:
    - (R, G, S, T, passive_count, active_count), each averaged over the two
      rollouts: R is mean V_opt over env.unwrapped.start_states minus the mean
      discounted return G; S is the fraction of rollouts whose final reward
      was positive; T is the mean index of the last step.

    Side effects: env.unwrapped.start_goal is left at the last goal tried;
    env.unwrapped.min_length is temporarily set to env.unwrapped.length and
    restored afterwards, even if a rollout raises.
    """
    G, S, T, passive_count, active_count = 0, 0, 0, 0, 0
    base = env.unwrapped
    length = base.length
    min_length = base.min_length
    # Honour a caller-supplied value function instead of re-running value
    # iteration on every evaluation.
    if V_opt is None:
        V_opt = get_Vopt(env, gamma)
    try:
        for start_goal in [0, 1]:
            base.start_goal = start_goal
            base.min_length = length
            state, _ = env.reset()
            for t in range(1000):
                action = Q[state].argmax()
                state_, _, done, truncate, info = env.step(action)
                reward = info["reward"]
                G += (gamma**t)*reward
                if hasattr(env, "num_stack"):
                    # State is a flat stack of num_stack frames of s_size entries;
                    # the last frame is the current observation, the rest is memory.
                    s_size = base.observation_space.shape[0]
                    goal_obs = tuple(env.get_pos_obs(env.goal))
                    s_obs = state[-s_size:]
                    s_mem = [state[i*s_size:(i+1)*s_size] for i in range(0, env.num_stack-1)]
                    ns_mem = [state_[i*s_size:(i+1)*s_size] for i in range(0, env.num_stack-1)]
                    # Goal was in memory and got dropped.
                    if not done and (goal_obs in s_mem):
                        passive_count += (goal_obs not in ns_mem)
                    # Goal was the current observation and was not stored.
                    if not done and (s_obs == goal_obs) and (goal_obs not in s_mem):
                        active_count += (goal_obs not in ns_mem)
                if done or truncate:
                    S += reward > 0
                    break
                state = state_
            T += t
    finally:
        base.min_length = min_length
    # Regret: expected optimal start value minus the achieved mean return.
    R = np.mean([V_opt[s] for s in base.start_states]) - G/2
    return R, G/2, S/2, T/2, passive_count/2, active_count/2

#########################################################################################

def Q_learning(env, env_eval, V_optimal=None, gamma=1, epsilon=1, mask_type="no_stack", epsilon_type="edecay", alpha=1, maxiter=100, mean_episodes=100, save_path=False, p=True):
    """
    Tabular epsilon-greedy Q-learning with several update rules selected by mask_type.

    Parameters:
    - env: Training environment (reset() -> (state, info), step() -> 5-tuple).
      The step info must contain "reward"; the "*all_masked" variants also
      read info["pre_frames"] and info["obs"].  env.unwrapped must expose
      goal, get_pos_obs, length and active.
    - env_eval: Environment passed to get_Vopt.
      NOTE(review): the resulting V_opt is only referenced by commented-out
      evaluation code below, so this value iteration currently has no effect
      on the result.
    - V_optimal: If truthy, a convergence-based stop_cond is built.
      NOTE(review): stop_cond is never consulted; the loop is bounded by
      T < maxiter only.
    - gamma: Discount factor.
    - epsilon: Exploration rate; recomputed every step when
      epsilon_type == "edecay" (linear from 1 down to 0.05 over maxiter steps).
    - mask_type: One of "no_stack", "ca_masked", "ca_all_masked",
      "all_masked", "ca_all_history_masked", "all_history_masked"; any other
      value falls through to the plain Q-learning update.
    - alpha: Learning rate.
    - maxiter: Total number of environment steps (not episodes).
    - mean_episodes: Period, in steps, of Q snapshots / saving / printing
      (only checked on steps that are multiples of 100) and the window size
      of the printed means.
    - save_path: Path prefix for np.save, or a falsy value to skip saving.
    - p: Enables periodic saving and printing.

    Returns:
    - (Q, stats): the learned Q-table (defaultdict state -> action-value
      array) and the statistics dictionary.
    """

    qinit = 0
    Q = defaultdict(lambda: qinit+np.zeros(env.action_space.n))
    # if mask_type == "masked" or mask_type == "demir":
    #     if env.length>0: qinit = ((1/(env.length))*gamma*(1-gamma**(env.length))/(1-gamma))
    #     Q = defaultdict(lambda: qinit+np.zeros(env.action_space.n))
    env_eval.reset()
    V_opt = get_Vopt(env_eval, gamma)

    print('Path: ', save_path)
    # Lists seeded with [0] hold a running accumulator in their last slot.
    stats = {"T":0, "episode":[0], "regrets":[0], "returns":[0], "rewards":[], "success":[],"steps":[],"states":[],"learned":[], "mask_regret":[0], "passive_count":[0], "active_count":[0]}
    k=0  # completed episodes
    T=0  # total environment steps
    t=0  # steps since the last accumulator reset (discount exponent)
    passive_count=0
    active_count=0
    mask_regret=0
    rewards=0
    successes=0
    frames_max=0  # largest size reached by the `frames` buffer
    state, _ = env.reset()
    
    # Candidate frame histories, keyed by flattened state, used by the
    # "*_history_masked" updates.
    # NOTE(review): dict(state=...) creates the literal key "state", same as
    # the re-initialisation after an episode ends.
    frames={}
    if "masked" in mask_type: frames=dict(state = env.env.frames.copy())

    stop_cond = lambda k: k < maxiter
    if V_optimal:
        # NOTE(review): Q_V is not defined in the visible part of this file.
        stop_cond = lambda k: True if k%mean_episodes != 0 else not V_equal(V_optimal,Q_V(Q))

    while T<maxiter:        
        if epsilon_type=="edecay":
            fraction = min(T / (maxiter), 1.0)
            epsilon = 1 - 0.95*fraction
        # Epsilon-greedy action with random tie-breaking among maximal actions.
        if np.random.random() > epsilon:
            action = np.random.choice(np.flatnonzero(Q[state] == Q[state].max())) # Q[state].argmax() # 
        else:
            action = np.random.randint(env.action_space.n)      

        # Goal observation is read before stepping the environment.
        goal = tuple(env.unwrapped.get_pos_obs(env.unwrapped.goal))
        state_, reward, done, truncate, info = env.step(action)
        # Returns/success are tracked with info["reward"]; the Q update below
        # uses the `reward` returned by step().
        stats["returns"][-1] += (gamma**t)*info["reward"]
        stats["rewards"].append(reward)
        if info["reward"]!=0: stats["success"].append(info["reward"]>0)
        if hasattr(env,"num_stack"):
            # State is a flat stack of num_stack frames of s_size entries: the
            # last frame is the current observation, the earlier ones are memory.
            s_size = env.unwrapped.observation_space.shape[0]
            s_obs = state[-s_size:]
            ns_obs = state_[-s_size:]
            s_mem = [state[i*s_size:(i+1)*s_size] for i in range(0,env.num_stack-1)]
            ns_mem = [state_[i*s_size:(i+1)*s_size] for i in range(0,env.num_stack-1)]
            stats["mask_regret"][-1] += ((not done) and (goal not in ns_mem))
            if (not done) and (goal in s_mem): stats["passive_count"][-1] += (goal not in ns_mem)
            if (not done) and (s_obs==goal) and (goal not in s_mem): stats["active_count"][-1] += (goal not in ns_mem)
        if mask_type == "no_stack" and (tuple(state)==goal): stats["passive_count"][-1] += 1; stats["active_count"][-1] += 1; 
        if mask_type == "no_stack": stats["mask_regret"][-1] += 1

        if mask_type == "ca_masked":
            # Bootstrap only over next actions that share the mask just used.
            # presumably an action encodes (env_action, mask) — confirm against
            # env.split_action_mask / env.join_action_mask.
            env_action, mask = env.split_action_mask(action)
            G = 0 if done else np.max([Q[state_][env.join_action_mask(env_action_,mask)] for env_action_ in range(env.action_space.n//env.num_stack)])
            Q[state][action] += alpha*(reward + gamma*G - Q[state][action])
        elif mask_type == "ca_all_masked":
            # Update every mask choice for the taken env action, rebuilding the
            # next state each mask would have produced from info["pre_frames"].
            env_action, _ = env.split_action_mask(action)
            for mask in range(env.num_stack):
                _action = env.join_action_mask(env_action,mask)
                pre_frames = info["pre_frames"].copy(); del pre_frames[mask]; pre_frames.append(info["obs"])
                _state_ = tuple(np.array(pre_frames).flatten())
                G = 0 if done else np.max([Q[_state_][env.join_action_mask(env_action_,mask)] for env_action_ in range(env.action_space.n//env.num_stack)])
                Q[state][_action] += alpha*(reward + gamma*G - Q[state][_action])
        elif mask_type == "all_masked":
            # As "ca_all_masked", but bootstrapping over all next actions.
            env_action, _ = env.split_action_mask(action)
            for mask in range(env.num_stack):
                _action = env.join_action_mask(env_action,mask)
                pre_frames = info["pre_frames"].copy(); del pre_frames[mask]; pre_frames.append(info["obs"])
                _state_ = tuple(np.array(pre_frames).flatten())
                G = 0 if done else np.max(Q[_state_])
                Q[state][_action] += alpha*(reward + gamma*G - Q[state][_action])
        elif mask_type == "ca_all_history_masked":# and not done:
            # Apply the update from every frame history in the buffer, for every
            # mask, and collect the resulting histories as the next buffer.
            env_action, _ = env.split_action_mask(action)
            frames_={}
            for pre_frames in frames.values():
                _state = tuple(np.array(pre_frames).flatten())
                for mask in range(env.num_stack):
                    _action = env.join_action_mask(env_action,mask)
                    pre_frames_ = pre_frames.copy(); del pre_frames_[mask]; pre_frames_.append(info["obs"])
                    _state_ = tuple(np.array(pre_frames_).flatten()); frames_[_state_]=pre_frames_
                    G = 0 if done else np.max([Q[_state_][env.join_action_mask(env_action_,mask)] for env_action_ in range(env.action_space.n//env.num_stack)])
                    Q[_state][_action] += alpha*(reward + gamma*G - Q[_state][_action])
            frames = frames_            
        elif mask_type == "all_history_masked":# and not done:
            # As "ca_all_history_masked", but bootstrapping over all next actions.
            env_action, _ = env.split_action_mask(action)
            frames_={}
            for pre_frames in frames.values():
                _state = tuple(np.array(pre_frames).flatten())
                for mask in range(env.num_stack):
                    _action = env.join_action_mask(env_action,mask)
                    pre_frames_ = pre_frames.copy(); del pre_frames_[mask]; pre_frames_.append(info["obs"])
                    _state_ = tuple(np.array(pre_frames_).flatten()); frames_[_state_]=pre_frames_
                    G = 0 if done else np.max(Q[_state_])
                    Q[_state][_action] += alpha*(reward + gamma*G - Q[_state][_action])
            frames = frames_
        else:
            # Standard one-step Q-learning update.
            G = 0 if done else np.max(Q[state_])
            Q[state][action] += alpha*(reward + gamma*G - Q[state][action])
        
        state = state_
        t+=1; T += 1
        if done or truncate:                       
            state, _ = env.reset()
            if "masked" in mask_type: 
                frames={"state": env.env.frames.copy()}
            stats["steps"].append(T-stats["T"])
            stats["T"] = T
            stats["episode"].append(0)
            k+=1
        # Open fresh accumulator slots every (length + 2 + active) steps.
        # NOTE(review): this is tied to the step count, not to done/truncate —
        # presumably episodes have exactly this fixed length; confirm.
        if (T)%(env.unwrapped.length+2+env.unwrapped.active)==0:
            frames_max = max(frames_max,len(frames))
            stats["returns"].append(0)
            stats["mask_regret"].append(0)
            stats["passive_count"].append(0)
            stats["active_count"].append(0)
            t=0

        if (T)%100==0:
            # R, G, S, steps, _, _ = evaluate_regret(env_eval, Q, V_opt, gamma)
            # stats["regrets"].append(R)
            # stats["states"].append(len(list(Q.keys())))
            
            # Snapshot a copy of the Q-table.
            if (T)%(mean_episodes*1) == 0:
                q = {s:q.copy() for s,q in Q.items()}
                stats["learned"].append(q)
            if p and (T)%(mean_episodes*1) == 0:
                if save_path:
                    np.save(save_path, stats)
                    np.save(save_path+"-values", dict(Q))
                # wandb.log({"episode": k})
                # wandb.log({"regrets": R})
                # wandb.log({"returns": G})
                # wandb.log({"success": S})
                # wandb.log({"passive_count": stats["passive_count"][-1]})
                # wandb.log({"active_count": stats["active_count"][-1]})
                # wandb.log({"steps": steps})
                # wandb.log({"epsilon": epsilon})
                # wandb.log({"states": len(list(Q.keys()))})
                
                # mean_regret = evaluateRegret(env_eval, Q, gamma)
                # stats["regrets"].append(mean_regret)
                # Means over the last mean_episodes entries of each series.
                # NOTE(review): stats["regrets"] is never appended to in the
                # active code, so the printed regret stays at its initial 0.
                mean_passive_count = np.mean(stats["passive_count"][-mean_episodes:])
                mean_active_count = np.mean(stats["active_count"][-mean_episodes:])
                mean_mask_regret = np.mean(stats["mask_regret"][-mean_episodes:])
                mean_regret = np.mean(stats["regrets"][-mean_episodes:])
                mean_return = np.mean(stats["returns"][-mean_episodes:])
                mean_success = np.mean(stats["success"][-mean_episodes:])
                total_rewards = np.sum(stats["rewards"][-1000:])
                print('Steps: ', T, ' | Episode: ', k, ' | Regret: ', round(mean_regret,2), ' | Return: ', round(mean_return,2), ' | Rewards: ', total_rewards, ' | Success: ', round(mean_success,2),
                        ' | States: ', len(list(Q.keys())), ' | Passive regret: ', round(mean_passive_count,2), ' | Active regret: ', round(mean_active_count,2), ' | Mask regret: ', round(mean_mask_regret,2), "| frames buffer: ", frames_max)
    if save_path:
        np.save(save_path, stats)
    
    return Q, stats

#########################################################################################

def Double_Q_learning(env, V_optimal=None, gamma=1, epsilon=1, epsilon_type="edecay", alpha=1, maxiter=100, mean_episodes=100, save_path=False, p=True):
    """
    Tabular Double Q-learning with an epsilon-greedy behaviour policy.

    The environment is driven with the reset() -> (state, info) and
    step() -> (state, reward, done, truncate, info) convention, the same one
    evaluateQ uses on this env after every episode.

    Parameters:
    - env: Environment exposing action_space.n, reset() and step().
    - V_optimal: Optional reference value function.  When truthy, training
      stops at the first multiple of mean_episodes episodes at which
      V_equal(V_optimal, Q_V(Q1)) holds (Q_V must be defined in this module);
      otherwise training runs for maxiter episodes.
    - gamma: Discount factor (also used for the evaluation return).
    - epsilon: Exploration rate; recomputed every step when
      epsilon_type == "edecay" (linear from 1 to 0.1 over maxiter episodes).
    - alpha: Learning rate.
    - maxiter: Number of episodes (and the horizon of the epsilon decay).
    - mean_episodes: Period, in episodes, of snapshots / saving / logging.
    - save_path: Path for np.save, or a falsy value to skip saving.
    - p: Enables periodic saving, wandb logging and printing.

    Returns:
    - (Q1, stats): the first Q-table and the statistics dictionary.
    """

    Q1 = defaultdict(lambda: np.zeros(env.action_space.n))
    Q2 = defaultdict(lambda: np.zeros(env.action_space.n))

    stats = {"R":[], "T":0, "episode":[], "returns":[], "success":[],"steps":[],"states":[],"learned":[]}
    k = 0  # completed episodes
    t = 0  # steps in the current episode
    state, _ = env.reset()

    stop_cond = lambda k: k < maxiter
    if V_optimal:
        stop_cond = lambda k: True if k%mean_episodes != 0 else not V_equal(V_optimal,Q_V(Q1))

    while stop_cond(k):

        if epsilon_type == "edecay":
            fraction = min(float(k) / maxiter, 1.0)
            epsilon = 1 - 0.9*fraction
        if np.random.random() > epsilon:
            # Greedy w.r.t. Q1 + Q2 with random tie-breaking.
            q_sum = Q1[state] + Q2[state]
            action = np.random.choice(np.flatnonzero(q_sum == q_sum.max()))
        else:
            action = np.random.randint(env.action_space.n)
        state_, reward, done, truncate, _ = env.step(action)

        # Update one table, chosen at random; the action is selected by the
        # table being updated and evaluated by the other one.  Bootstrapping
        # is cut on termination only, as in Q_learning.
        if np.random.random() > 0.5:
            G = 0 if done else Q2[state_][Q1[state_].argmax()]
            Q1[state][action] += alpha*(reward + gamma*G - Q1[state][action])
        else:
            G = 0 if done else Q1[state_][Q2[state_].argmax()]
            Q2[state][action] += alpha*(reward + gamma*G - Q2[state][action])

        state = state_
        t += 1
        if done or truncate:
            # evaluateQ's third positional parameter is V_opt; gamma must be
            # passed in its own slot or the evaluation return is undiscounted.
            G, S, steps = evaluateQ(env, Q1, None, gamma)

            state, _ = env.reset()

            stats["T"] += t
            t = 0
            k += 1
            stats["episode"].append(k-1)
            stats["returns"].append(G)
            stats["success"].append(S)
            stats["steps"].append(steps)
            stats["states"].append(len(Q1))
            if (k-1)%mean_episodes == 0:
                stats["learned"].append({s: q.copy() for s, q in Q1.items()})
            if p and (k-1)%mean_episodes == 0:
                if save_path:
                    np.save(save_path, stats)
                wandb.log({"episode": k})
                wandb.log({"returns": G})
                wandb.log({"success": S})
                wandb.log({"steps": steps})
                wandb.log({"epsilon": epsilon})
                wandb.log({"states": len(Q1)})

                mean_return = np.mean(stats["returns"][-mean_episodes:])
                mean_success = np.mean(stats["success"][-mean_episodes:])
                print('Path: ', save_path, 'Episode: ', k-1, ' | Mean return: ', mean_return, ' | Mean success: ', mean_success,
                      ' | States: ', len(Q1))
    if save_path:
        np.save(save_path, stats)

    return Q1, stats

#########################################################################################

def evaluateSQ(env, Q, gamma=1, temperature=1, render=False):
    """
    Roll out the Boltzmann (softmax) policy of Q for at most 20 steps.

    env is driven with reset() -> state and step() -> (state, reward, done, info).

    Parameters:
    - env: Environment to evaluate in.
    - Q: Mapping state -> array of action values.
    - gamma: Discount factor for the returned return.
    - temperature: Softmax temperature; actions are sampled with probability
      proportional to exp(Q[state] / temperature).
    - render: When true, calls env.render(goal=True, agent=True) and
      plt.show() before every step.

    Returns:
    - (G, S, t): discounted return, 1 if the episode ended within the horizon
      else 0, and the index of the last step taken.
    """
    G = 0
    S = 0
    state = env.reset()
    for t in range(20):
        if render:
            env.render(goal=True, agent=True)
            # plt.pause(0.001)
            plt.show()
        # Shift by the maximum before exponentiating: the probabilities are
        # unchanged mathematically, but large Q / temperature no longer
        # overflows to inf and produces NaN probabilities.
        scaled = Q[state] / temperature
        weights = np.exp(scaled - np.max(scaled))
        action_probs = weights / np.sum(weights)
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
        state, reward, done, _ = env.step(action)
        G += (gamma**t)*reward
        if done:
            S = 1
            break
    return G, S, t

def Soft_Q_learning(env, V_optimal=None, gamma=1, epsilon=1, epsilon_type="edecay", temperature=1, alpha=1, maxiter=100, mean_episodes=100, save_path=False, p=True):
    """
    Tabular soft Q-learning (log-sum-exp backup) with epsilon-mixed Boltzmann exploration.

    env is driven with reset() -> state and step() -> (state, reward, done, info),
    the same convention evaluateSQ uses.
    NOTE(review): this differs from the (state, info) / 5-tuple convention
    used by Q_learning and evaluateQ in this file — confirm which env this
    function is meant for.

    Parameters:
    - env: Environment exposing action_space.n, reset() and step().
    - V_optimal: Optional reference value function.  When truthy, training
      stops at the first multiple of mean_episodes episodes at which
      V_equal(V_optimal, Q_V(Q)) holds (Q_V must be defined in this module);
      otherwise training runs for maxiter episodes.
    - gamma: Discount factor.
    - epsilon: Probability of a uniformly random action; recomputed every
      step when epsilon_type == "edecay" (linear from 1 to 0.1 over maxiter
      episodes).  Otherwise the action is sampled from softmax(Q / temperature).
    - temperature: Softmax / log-sum-exp temperature.
    - alpha: Learning rate.
    - maxiter: Number of episodes (and the horizon of the epsilon decay).
    - mean_episodes: Period, in episodes, of snapshots / saving / logging.
    - save_path: Path for np.save, or a falsy value to skip saving.
    - p: Enables periodic saving, wandb logging and printing.

    Returns:
    - (Q, stats): the learned Q-table and the statistics dictionary.
    """

    def soft_value(q_values):
        # temperature * log(sum(exp(q / temperature))), computed with the
        # maximum factored out so large q / temperature cannot overflow.
        scaled = q_values / temperature
        peak = np.max(scaled)
        return temperature * (peak + np.log(np.sum(np.exp(scaled - peak))))

    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    stats = {"R":[], "T":0, "episode":[], "returns":[], "success":[],"steps":[],"states":[],"learned":[]}
    k = 0  # completed episodes
    t = 0  # steps in the current episode
    state = env.reset()

    stop_cond = lambda k: k < maxiter
    if V_optimal:
        stop_cond = lambda k: True if k%mean_episodes != 0 else not V_equal(V_optimal,Q_V(Q))

    while stop_cond(k):
        if epsilon_type == "edecay":
            fraction = min(float(k) / maxiter, 1.0)
            epsilon = 1 - 0.9*fraction
        if np.random.random() > epsilon:
            # Boltzmann policy, max-shifted for numerical stability.
            scaled = Q[state] / temperature
            weights = np.exp(scaled - np.max(scaled))
            action_probs = weights / np.sum(weights)
            action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
        else:
            action = np.random.randint(env.action_space.n)
        state_, reward, done, _ = env.step(action)

        G = 0 if done else soft_value(Q[state_])
        Q[state][action] += alpha*(reward + gamma*G - Q[state][action])

        state = state_
        t += 1
        if done:
            G, S, steps = evaluateSQ(env, Q, gamma, temperature)

            state = env.reset()

            stats["T"] += t
            t = 0
            k += 1
            stats["episode"].append(k-1)
            stats["returns"].append(G)
            stats["success"].append(S)
            stats["steps"].append(steps)
            stats["states"].append(len(Q))
            if (k-1)%mean_episodes == 0:
                stats["learned"].append({s: q.copy() for s, q in Q.items()})
            if p and (k-1)%mean_episodes == 0:
                if save_path:
                    np.save(save_path, stats)
                wandb.log({"episode": k})
                wandb.log({"returns": G})
                wandb.log({"success": S})
                wandb.log({"steps": steps})
                wandb.log({"temperature": temperature})
                wandb.log({"states": len(Q)})

                mean_return = np.mean(stats["returns"][-mean_episodes:])
                mean_success = np.mean(stats["success"][-mean_episodes:])
                print('Path: ', save_path, 'Episode: ', k-1, ' | Mean return: ', mean_return, ' | Mean success: ', mean_success,
                      ' | States: ', len(Q))
    if save_path:
        np.save(save_path, stats)

    return Q, stats

#########################################################################################

def Monte_Carlo_0(env, V_optimal=None, gamma=1, epsilon=1, epsilon_type="edecay", alpha=1, maxiter=100, mean_episodes=100, save_path=False, p=True):
    """
    First-visit Monte Carlo control with sample-average updates and epsilon-greedy exploration.

    The environment is driven with the reset() -> (state, info) and
    step() -> (state, reward, done, truncate, info) convention, the same one
    evaluateQ uses on this env after every episode.

    Parameters:
    - env: Environment exposing action_space.n, reset() and step().
    - V_optimal: Optional reference value function.  When truthy, training
      stops at the first multiple of mean_episodes episodes at which
      V_equal(V_optimal, Q_V(Q)) holds (Q_V must be defined in this module);
      otherwise training runs for maxiter episodes.
    - gamma: Discount factor (also used for the evaluation return).
    - epsilon: Exploration rate; recomputed every step when
      epsilon_type == "edecay" (linear from 1 to 0.1 over maxiter episodes).
    - alpha: Accepted for signature parity with the other learners; unused.
    - maxiter: Number of episodes (and the horizon of the epsilon decay).
    - mean_episodes: Period, in episodes, of snapshots / saving / logging.
    - save_path: Path for np.save, or a falsy value to skip saving.
    - p: Enables periodic saving, wandb logging and printing.

    Returns:
    - (Q, stats): the learned Q-table and the statistics dictionary.
    """

    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    C = defaultdict(lambda: np.zeros(env.action_space.n))  # visit counts per (state, action)

    stats = {"R":[], "T":0, "episode":[], "returns":[], "success":[],"steps":[],"states":[],"learned":[]}
    k = 0  # completed episodes
    t = 0  # steps in the current episode
    episode = []   # full transitions of the current episode
    episode_ = []  # (state, action) pairs only, for the first-visit test
    state, _ = env.reset()

    stop_cond = lambda k: k < maxiter
    if V_optimal:
        stop_cond = lambda k: True if k%mean_episodes != 0 else not V_equal(V_optimal,Q_V(Q))

    while stop_cond(k):

        best_action = True
        if epsilon_type == "edecay":
            fraction = min(float(k) / maxiter, 1.0)
            epsilon = 1 - 0.9*fraction
        if np.random.random() > epsilon:
            action = np.random.choice(np.flatnonzero(Q[state] == Q[state].max()))
        else:
            action = np.random.randint(env.action_space.n)
            best_action = False
        state_, reward, done, truncate, _ = env.step(action)
        episode.append((state,action,best_action,reward,state_,done))
        episode_.append((state,action))

        state = state_
        t += 1
        if done or truncate:
            ### MC update: walk the episode backwards accumulating the return.
            G = 0
            episode.reverse()
            episode_.reverse()
            for i, (state,action,best_action,reward,_,_) in enumerate(episode):
                G = gamma*G + reward
                # After reversal, entries past i are earlier time steps, so the
                # pair is a first visit iff it does not occur among them.
                if (state,action) not in episode_[i+1:]:
                    Q[state][action] = Q[state][action]*C[state][action]/(C[state][action]+1) + G/(C[state][action]+1)
                    C[state][action] += 1
            ###

            # evaluateQ's third positional parameter is V_opt; gamma must be
            # passed in its own slot or the evaluation return is undiscounted.
            G, S, steps = evaluateQ(env, Q, None, gamma)

            state, _ = env.reset()

            stats["T"] += t
            t = 0
            k += 1
            episode = []
            episode_ = []
            stats["episode"].append(k-1)
            stats["returns"].append(G)
            stats["success"].append(S)
            stats["steps"].append(steps)
            stats["states"].append(len(Q))
            if (k-1)%mean_episodes == 0:
                stats["learned"].append({s: q.copy() for s, q in Q.items()})
            if p and (k-1)%mean_episodes == 0:
                if save_path:
                    np.save(save_path, stats)
                wandb.log({"episode": k})
                wandb.log({"returns": G})
                wandb.log({"success": S})
                wandb.log({"steps": steps})
                wandb.log({"epsilon": epsilon})
                wandb.log({"states": len(Q)})

                mean_return = np.mean(stats["returns"][-mean_episodes:])
                mean_success = np.mean(stats["success"][-mean_episodes:])
                print('Path: ', save_path, 'Episode: ', k-1, ' | Mean return: ', mean_return, ' | Mean success: ', mean_success,
                      ' | States: ', len(Q))
    if save_path:
        np.save(save_path, stats)

    return Q, stats

#########################################################################################

def Monte_Carlo_1(env, V_optimal=None, gamma=1, epsilon=1, epsilon_type="edecay", alpha=1, maxiter=100, mean_episodes=100, save_path=False, p=True):
    """
    Every-visit, constant-step-size Monte Carlo control with epsilon-greedy exploration.

    The environment is driven with the reset() -> (state, info) and
    step() -> (state, reward, done, truncate, info) convention, the same one
    evaluateQ uses on this env after every episode.

    Parameters:
    - env: Environment exposing action_space.n, reset() and step().
    - V_optimal: Optional reference value function.  When truthy, training
      stops at the first multiple of mean_episodes episodes at which
      V_equal(V_optimal, Q_V(Q)) holds (Q_V must be defined in this module);
      otherwise training runs for maxiter episodes.
    - gamma: Discount factor (also used for the evaluation return).
    - epsilon: Exploration rate; recomputed every step when
      epsilon_type == "edecay" (linear from 1 to 0.1 over maxiter episodes).
    - alpha: Step size of the update Q += alpha * (G - Q).
    - maxiter: Number of episodes (and the horizon of the epsilon decay).
    - mean_episodes: Period, in episodes, of snapshots / saving / logging.
    - save_path: Path for np.save, or a falsy value to skip saving.
    - p: Enables periodic saving, wandb logging and printing.

    Returns:
    - (Q, stats): the learned Q-table and the statistics dictionary.
    """

    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    stats = {"R":[], "T":0, "episode":[], "returns":[], "success":[],"steps":[],"states":[],"learned":[]}
    k = 0  # completed episodes
    t = 0  # steps in the current episode
    episode = []
    state, _ = env.reset()

    stop_cond = lambda k: k < maxiter
    if V_optimal:
        stop_cond = lambda k: True if k%mean_episodes != 0 else not V_equal(V_optimal,Q_V(Q))

    while stop_cond(k):

        best_action = True
        if epsilon_type == "edecay":
            fraction = min(float(k) / maxiter, 1.0)
            epsilon = 1 - 0.9*fraction
        if np.random.random() > epsilon:
            action = np.random.choice(np.flatnonzero(Q[state] == Q[state].max()))
        else:
            action = np.random.randint(env.action_space.n)
            best_action = False
        state_, reward, done, truncate, _ = env.step(action)
        episode.append((state,action,best_action,reward,state_,done))

        state = state_
        t += 1
        if done or truncate:
            ### MC update: walk the episode backwards accumulating the return.
            G = 0
            episode.reverse()
            for (state,action,best_action,reward,_,_) in episode:
                G = gamma*G + reward
                Q[state][action] += alpha*(G-Q[state][action])
            ###

            # evaluateQ's third positional parameter is V_opt; gamma must be
            # passed in its own slot or the evaluation return is undiscounted.
            G, S, steps = evaluateQ(env, Q, None, gamma)

            state, _ = env.reset()

            stats["T"] += t
            t = 0
            k += 1
            episode = []
            stats["episode"].append(k-1)
            stats["returns"].append(G)
            stats["success"].append(S)
            stats["steps"].append(steps)
            stats["states"].append(len(Q))
            if (k-1)%mean_episodes == 0:
                stats["learned"].append({s: q.copy() for s, q in Q.items()})
            if p and (k-1)%mean_episodes == 0:
                if save_path:
                    np.save(save_path, stats)
                wandb.log({"episode": k})
                wandb.log({"returns": G})
                wandb.log({"success": S})
                wandb.log({"steps": steps})
                wandb.log({"epsilon": epsilon})
                wandb.log({"states": len(Q)})

                mean_return = np.mean(stats["returns"][-mean_episodes:])
                mean_success = np.mean(stats["success"][-mean_episodes:])
                print('Path: ', save_path, 'Episode: ', k-1, ' | Mean return: ', mean_return, ' | Mean success: ', mean_success,
                      ' | States: ', len(Q))
    if save_path:
        np.save(save_path, stats)

    return Q, stats

#########################################################################################

def Monte_Carlo_2(env, V_optimal=None, gamma=1, epsilon=1, epsilon_type="edecay", alpha=1, maxiter=100, mean_episodes=100, save_path=False, p=True):
    """
    Off-policy Monte Carlo control with weighted importance sampling.

    Behaviour is epsilon-greedy; the target is the greedy policy of Q.  The
    environment is driven with the reset() -> (state, info) and
    step() -> (state, reward, done, truncate, info) convention, the same one
    evaluateQ uses on this env after every episode.

    Parameters:
    - env: Environment exposing action_space.n, reset() and step().
    - V_optimal: Optional reference value function.  When truthy, training
      stops at the first multiple of mean_episodes episodes at which
      V_equal(V_optimal, Q_V(Q)) holds (Q_V must be defined in this module);
      otherwise training runs for maxiter episodes.
    - gamma: Discount factor (also used for the evaluation return).
    - epsilon: Exploration rate; recomputed every step when
      epsilon_type == "edecay" (linear from 1 to 0.1 over maxiter episodes).
    - alpha: Accepted for signature parity with the other learners; unused.
    - maxiter: Number of episodes (and the horizon of the epsilon decay).
    - mean_episodes: Period, in episodes, of snapshots / saving / logging.
    - save_path: Path for np.save, or a falsy value to skip saving.
    - p: Enables periodic saving, wandb logging and printing.

    Returns:
    - (Q, stats): the learned Q-table and the statistics dictionary.
    """

    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    C = defaultdict(lambda: np.zeros(env.action_space.n))  # cumulative importance weights

    stats = {"R":[], "T":0, "episode":[], "returns":[], "success":[],"steps":[],"states":[],"learned":[]}
    k = 0  # completed episodes
    t = 0  # steps in the current episode
    episode = []
    state, _ = env.reset()

    stop_cond = lambda k: k < maxiter
    if V_optimal:
        stop_cond = lambda k: True if k%mean_episodes != 0 else not V_equal(V_optimal,Q_V(Q))

    while stop_cond(k):

        best_action = True
        if epsilon_type == "edecay":
            fraction = min(float(k) / maxiter, 1.0)
            epsilon = 1 - 0.9*fraction
        if np.random.random() > epsilon:
            action = np.random.choice(np.flatnonzero(Q[state] == Q[state].max()))
        else:
            action = np.random.randint(env.action_space.n)
            best_action = False
        state_, reward, done, truncate, _ = env.step(action)
        episode.append((state,action,best_action,reward,state_,done))

        state = state_
        t += 1
        if done or truncate:
            ### MC update: walk the episode backwards with importance weight W.
            G = 0
            W = 1
            episode.reverse()
            for (state,action,best_action,reward,_,_) in episode:
                G = gamma*G + reward
                C[state][action] += W
                Q[state][action] += (W/C[state][action])*(G-Q[state][action])
                # Behaviour probability of the recorded action, using the
                # epsilon in effect at the end of the episode and the
                # recorded greedy/exploratory flag.
                prob = 1-epsilon + epsilon/env.action_space.n if best_action else epsilon/env.action_space.n
                W = W/prob
                # Stop as soon as the taken action is not the target policy's
                # (greedy, ties broken at random) action.
                action_new = np.random.choice(np.flatnonzero(Q[state] == Q[state].max()))
                if action != action_new:
                    break
            ###

            # evaluateQ's third positional parameter is V_opt; gamma must be
            # passed in its own slot or the evaluation return is undiscounted.
            G, S, steps = evaluateQ(env, Q, None, gamma)

            state, _ = env.reset()

            stats["T"] += t
            t = 0
            k += 1
            episode = []
            stats["episode"].append(k-1)
            stats["returns"].append(G)
            stats["success"].append(S)
            stats["steps"].append(steps)
            stats["states"].append(len(Q))
            if (k-1)%mean_episodes == 0:
                stats["learned"].append({s: q.copy() for s, q in Q.items()})
            if p and (k-1)%mean_episodes == 0:
                if save_path:
                    np.save(save_path, stats)
                wandb.log({"episode": k})
                wandb.log({"returns": G})
                wandb.log({"success": S})
                wandb.log({"steps": steps})
                wandb.log({"epsilon": epsilon})
                wandb.log({"states": len(Q)})

                mean_return = np.mean(stats["returns"][-mean_episodes:])
                mean_success = np.mean(stats["success"][-mean_episodes:])
                print('Path: ', save_path, 'Episode: ', k-1, ' | Mean return: ', mean_return, ' | Mean success: ', mean_success,
                      ' | States: ', len(Q))
    if save_path:
        np.save(save_path, stats)

    return Q, stats

#########################################################################################

def get_binary_masks(factors):
    """
    Build binary-code masks over `factors` columns.

    With n = ceil(log2(factors)), column j holds the n-bit binary expansion
    of j, most significant bit in row 0. Entries are 1 (=True=rmax) or
    0 (=False=rmin).

    Returns an integer array with n rows and `factors` columns; when n is 0
    (factors == 1) the result is an empty 1-D array.
    """
    n_bits = int(np.ceil(np.log2(factors)))
    rows = [
        # Row for bit `shift` of every column index below `factors`.
        [(j >> shift) & 1 for j in range(2 ** n_bits) if j < factors]
        for shift in range(n_bits - 1, -1, -1)
    ]
    return np.array(rows, dtype=int)

def get_all_masks(factors):
    """
    Enumerate every binary mask of length `factors`.

    Returns a list of 2**factors integer arrays; entry t is the binary
    representation of t, left-padded with zeros to `factors` digits.
    """
    return [
        np.array([int(bit) for bit in bin(t)[2:].zfill(factors)])
        for t in range(2 ** factors)
    ]

def get_onehot_masks(factors):
    """Return the `factors` x `factors` integer identity matrix (row i keeps only factor i)."""
    return np.eye(factors, dtype=int)

def get_option(RQ, goal, epsilon=0.5):
    """
    Epsilon-greedy selection of an option index for `goal`.

    With probability governed by `epsilon` (a uniform draw <= epsilon) a
    uniformly random index into RQ[goal] is returned; otherwise the argmax
    of RQ[goal].
    """
    greedy = np.argmax(RQ[goal])
    explore = np.random.rand() <= epsilon
    return np.random.randint(len(RQ[goal])) if explore else greedy

def get_mask(MQ, option, epsilon=0.5):
    """
    Epsilon-greedy selection of a mask index for `option`.

    With probability governed by `epsilon` (a uniform draw <= epsilon) a
    uniformly random index into MQ[option] is returned; otherwise the argmax
    of MQ[option].
    """
    greedy = np.argmax(MQ[option])
    explore = np.random.rand() <= epsilon
    return np.random.randint(len(MQ[option])) if explore else greedy
    
def mask_state(mask, state):
    """
    Apply `mask` to a factored `state`.

    Factor i is kept when mask[i] is truthy and otherwise replaced by a tuple
    of zeros whose length is len(state) (the number of factors).

    Returns (tuple(mask), masked_state), usable as a dictionary key.
    """
    blank = (0,) * len(state)
    masked = tuple(state[i] if keep else blank for i, keep in enumerate(mask))
    return tuple(mask), masked


#########################################################################################

def MQ_Q_all(MQ_all, masks, states):
    """
    Project a masked Q-table onto the unmasked states, once per mask.

    For every mask a dict {state: values} is built, where values is the
    MQ_all entry stored under mask_state(mask, state), or np.zeros(4) when
    that key is absent.

    Returns the list of these dicts, in the order of `masks`.
    """
    def values_for(mask, state):
        key = mask_state(mask, state)
        return MQ_all[key] if key in MQ_all else np.zeros(4)

    return [
        {state: values_for(mask, state) for state in states}
        for mask in masks
    ]
def MQ_Q_all_render(MQ_all, masks, states, env, use_diff=1, Qr=None, goal=0):
    """
    Render one image per mask from a masked Q-table, plus two summary renders.

    MQ_all   : mapping keyed by (mask_tuple, masked_state) pairs, looked up via
               mask_state(); missing entries fall back to np.zeros(4).
    masks    : sequence of masks; masks[-1] is the reference mask for the
               difference view when Qr is falsy.
    states   : ignored -- it is overwritten with the states stored in MQ_all
               under the mask (1,1,1,1).
    env      : object providing render(...) and goals.
    use_diff : 0 renders the masked values directly (policy + values); any
               other value renders a difference against Qr or masks[-1].
    Qr       : optional reference table keyed by the state with the last entry
               of every factor set to 0.
    goal     : index into env.goals.

    Returns (Q_all, M_max, Q_max): the list of per-mask env.render outputs and
    the env.render outputs for the per-state max over actions / max over masks.
    """
    Q_all = []
    M_max = {}
    Q_max = {}
    # The caller-supplied `states` is discarded; only states recorded under the
    # all-ones mask are used.
    states=[]
    for m_,s in MQ_all.keys():
        if m_==(1,1,1,1):
            states.append(s)
    for m in masks:
        Q = {}
        for state in states:
            m_state = mask_state(m,state)
            if use_diff==0:
                Q[state] = MQ_all[m_state] if m_state in MQ_all else np.zeros(4)
            else:
                nm_state = mask_state(masks[-1],state)
                # Reference key: same factors with their last entry forced to 0.
                oracle_state = []
                for s in state:
                    oracle_state.append((*s[:-1],0))
                oracle_state = tuple(oracle_state)
                if Qr:
                    Q[state] = Qr[oracle_state] - MQ_all[m_state] if (m_state in MQ_all) and (oracle_state in Qr) else np.zeros(4)
                else:
                    Q[state] = MQ_all[nm_state] - MQ_all[m_state] if (m_state in MQ_all) and (nm_state in MQ_all) else np.zeros(4)
        if use_diff==0:
            Q_all.append(env.render(P=Q_P(Q,env.goals[goal]), V=Q_V(Q,env.goals[goal]), goal=env.goals[goal], mode = 'rgb_array'))
        else:
            # Difference view: only the value map is rendered, no policy.
            Q_all.append(env.render(V=Q_V(Q,env.goals[goal]), goal=env.goals[goal], mode = 'rgb_array'))
    for m,s in MQ_all.keys():
        if m==(1,1,1,1):
            state = s
            # NOTE(review): unguarded lookups -- a plain dict raises KeyError for a
            # missing masked state; a defaultdict would insert while MQ_all.keys()
            # is being iterated. Confirm every masked key exists.
            values = np.array([MQ_all[mask_state(masks[mask],state)] for mask in range(len(masks))])
            M_max[state] = values.max(axis=1)  # one entry per mask: max over the second axis
            Q_max[state] = values.max(axis=0)  # max over masks for each entry of the second axis
    M_max = env.render( M=Q_P(M_max,env.goals[goal]), goal=env.goals[goal], mode = 'rgb_array')
    Q_max = env.render( P=Q_P(Q_max,env.goals[goal]), V=Q_V(Q_max,env.goals[goal]), goal=env.goals[goal], mode = 'rgb_array')
    return Q_all, M_max, Q_max

def MQ_Q_all_sum(MQ_all_iter, masks, states, use_diff=1, Qr=None):
    """
    Summarise a sequence of masked Q-tables, one column per table.

    MQ_all_iter : sequence of masked Q-tables keyed by (mask_tuple, masked_state).
    masks       : sequence of masks; masks[-1] is the reference mask for the
                  difference statistic when Qr is falsy.
    states      : ignored -- it is overwritten with the states stored in
                  MQ_all_iter[-1] under the mask (1,1,1,1).
    use_diff    : 0 scores a mask by max values; any other value scores it by
                  the max difference against Qr or masks[-1]. The best mask is
                  only updated for use_diff values 0 (higher mean wins) and
                  1 (lower mean wins).
    Qr          : optional reference table keyed by the state with the last
                  entry of every factor set to 0.

    Returns (Q_all, mask_best, mask_best_v):
      Q_all       -- shape (len(masks), len(MQ_all_iter)), mean score over all states.
      mask_best   -- shape (len(masks[0]), len(MQ_all_iter)), best mask per table.
      mask_best_v -- shape (len(MQ_all_iter), 2), (mean, std) of the best mask's scores.
    """
    Q_all = np.zeros((len(masks),len(MQ_all_iter)))
    mask_best = np.zeros((len(masks[0]),len(MQ_all_iter)))
    mask_best_v = np.zeros((len(MQ_all_iter),2))
    # The caller-supplied `states` is discarded; only states recorded under the
    # all-ones mask in the last table are used.
    states=[]
    for m_,s in MQ_all_iter[-1].keys():
        if m_==(1,1,1,1):
            states.append(s)
    for i,MQ_all in enumerate(MQ_all_iter):
        m_best = None
        m_best_v = None
        for j,m in enumerate(masks):
            Q = []  # score of every state
            q = []  # score of the states that survive the (5,5) filter below
            for state in states:
                m_state = mask_state(m,state)
                if use_diff==0:
                    q.append(MQ_all[m_state].max() if m_state in MQ_all else 0)
                    Q.append(MQ_all[m_state].max() if m_state in MQ_all else 0)
                else:
                    nm_state = mask_state(masks[-1],state)
                    # Reference key: same factors with their last entry forced to 0.
                    oracle_state = []
                    for s in state:
                        oracle_state.append((*s[:-1],0))
                    oracle_state = tuple(oracle_state)
                    if Qr:
                        q.append((Qr[oracle_state] - MQ_all[m_state]).max() if (m_state in MQ_all) and (oracle_state in Qr) else 0)
                        Q.append((Qr[oracle_state] - MQ_all[m_state]).max() if (m_state in MQ_all) and (oracle_state in Qr) else 0)
                    else:
                        q.append((MQ_all[nm_state] - MQ_all[m_state]).max() if (m_state in MQ_all) and (nm_state in MQ_all) else 0)
                        Q.append((MQ_all[nm_state] - MQ_all[m_state]).max() if (m_state in MQ_all) and (nm_state in MQ_all) else 0)
                # Undo the append to q unless the first two entries of the third
                # factor are (5,5); Q keeps every state.
                if state[2][:2] != (5,5):
                    q.pop()
            q = np.array(q)
            if type(m_best) == type(None):
                # First mask seen for this table becomes the initial best.
                m_best = m
                m_best_v = q
            elif use_diff==0 and np.mean(q) > np.mean(m_best_v):
                m_best = m
                m_best_v = q
            elif use_diff==1 and np.mean(q) < np.mean(m_best_v):
                m_best = m
                m_best_v = q
            Q_all[j,i] = np.mean(Q) # Q
            # Rewritten on every mask iteration; the value left after the last
            # mask is the best one for table i.
            mask_best[:,i] = m_best
            mask_best_v[i] = np.mean(m_best_v), np.std(m_best_v)
    return Q_all, mask_best, mask_best_v

def MQ_Q(MQ, QM, masks, states, n_actions, start_position=None, mask_type="mask_states"):
    """
    Collapse a masked Q-table into per-state tables Q and M.

    mask_type selects how the mask is chosen for each state:
      "mask_options"  -- argmax of QM[state[-1]]; M[state] is QM[state[-1]].
      "mask_initials" -- argmax of QM[init], where init is the state itself
                         when start_position is truthy and the first key of
                         QM otherwise; M[state] is QM[init].
      "mask_states"   -- all masks are evaluated; M[state] is the max over the
                         second axis for each mask and Q[state] the max over masks.
    Any other mask_type leaves both tables empty. n_actions is not used.

    Returns (Q, M).
    """
    Q, M = {}, {}
    for state in states:
        if mask_type == "mask_options":
            scores = QM[state[-1]]
            M[state] = scores
            Q[state] = MQ[mask_state(masks[scores.argmax()], state)]
        elif mask_type == "mask_initials":
            init_state = state if start_position else list(QM.keys())[0]
            scores = QM[init_state]
            M[state] = scores
            Q[state] = MQ[mask_state(masks[scores.argmax()], state)]
        elif mask_type == "mask_states":
            stacked = np.array([MQ[mask_state(mask, state)] for mask in masks])
            M[state] = stacked.max(axis=1)
            Q[state] = stacked.max(axis=0)
    return Q, M

def Q_P(Q, goal):
    """
    Extract a greedy policy {cell: action} for `goal` from a factored Q-table.

    A state contributes when entries [2:4] of one of its first four factors
    equal `goal`, the last entry of every factor is 0, and its cell differs
    from `goal`. The cell is the first two entries of the first factor whose
    first two entries are not (0,0), or the whole first factor if there is none.
    """
    policy = {}
    for state, action_values in Q.items():
        if not any(goal == state[k][2:4] for k in range(4)):
            continue
        flags_clear = all([factor[-1] == 0 for factor in state])
        cell = next(
            (factor[:2] for factor in state if factor[:2] != (0, 0)),
            state[0],
        )
        if flags_clear and cell != goal:
            policy[cell] = np.argmax(action_values)
    return policy
def Q_V(Q, goal):
    """
    Extract greedy values {cell: max_a Q} for `goal` from a factored Q-table.

    A state contributes when entries [2:4] of one of its first four factors
    equal `goal`, the last entry of every factor is 0, and its cell differs
    from `goal`. The cell is the first two entries of the first factor whose
    first two entries are not (0,0), or the whole first factor if there is none.
    """
    values = {}
    for state, action_values in Q.items():
        if not any(goal == state[k][2:4] for k in range(4)):
            continue
        flags_clear = all([factor[-1] == 0 for factor in state])
        cell = next(
            (factor[:2] for factor in state if factor[:2] != (0, 0)),
            state[0],
        )
        if flags_clear and cell != goal:
            values[cell] = np.max(action_values)
    return values
def SQ_P(Q, goal, temperature):
    """
    Extract the most likely softmax action {cell: action} for `goal`.

    Q is keyed by (state, g) pairs, but values are always read from
    Q[state, goal]. A state contributes when entries [2:4] of one of its
    first four factors equal `goal` and the last entry of every factor is 0.
    The cell is the first two entries of the first factor whose first two
    entries are not (0,0), or the whole first factor if there is none.
    """
    policy = {}
    for state, _ in Q:
        if not any(goal == state[k][2:4] for k in range(4)):
            continue
        if not all([factor[-1] == 0 for factor in state]):
            continue
        cell = next(
            (factor[:2] for factor in state if factor[:2] != (0, 0)),
            state[0],
        )
        # Boltzmann weights at the given temperature, normalised to probabilities.
        weights = np.exp(Q[state, goal] / temperature)
        policy[cell] = (weights / np.sum(weights)).argmax() # np.random.choice(np.arange(len(action_probs)), p=action_probs)
    return policy
def SQ_V(Q, goal, temperature):
    """
    Extract soft values {cell: temperature * log(sum(exp(Q / temperature)))} for `goal`.

    Q is keyed by (state, g) pairs, but values are always read from
    Q[(state, goal)]. A state contributes when entries [2:4] of one of its
    first four factors equal `goal` and the last entry of every factor is 0.
    The cell is the first two entries of the first factor whose first two
    entries are not (0,0), or the whole first factor if there is none.
    """
    values = {}
    for state, _ in Q:
        if not any(goal == state[k][2:4] for k in range(4)):
            continue
        if not all([factor[-1] == 0 for factor in state]):
            continue
        cell = next(
            (factor[:2] for factor in state if factor[:2] != (0, 0)),
            state[0],
        )
        scaled = Q[(state, goal)] / temperature
        values[cell] = temperature * np.log(np.sum(np.exp(scaled)))
    return values

def save_EQ(EQ, name):
    """
    Save the nested table EQ[s][g] -> values to './storage/<name>' via np.save.

    The table is written as an object array of shape (len(EQ), 2): row i
    holds s in column 0 and the list [[g, EQ[s][g]], ...] in column 1, which
    is the (s, gv) row layout load_EQ unpacks.

    The object array is built explicitly instead of handing np.save a ragged
    nested list: NumPy >= 1.24 raises ValueError on that implicit conversion,
    and older versions could split the state tuple across array axes.

    np.save appends '.npy' to the file name when it is missing. The
    './storage' directory must already exist.
    """
    data = np.empty((len(EQ), 2), dtype=object)
    for row, s in enumerate(EQ):
        # Full integer indexing stores the Python object itself in the cell.
        data[row, 0] = s
        data[row, 1] = [[g, EQ[s][g]] for g in EQ[s]]
    np.save('./storage/{0}'.format(name), data, allow_pickle=True)

def load_EQ(name, actions = 5, data_dir='./storage'):
    """
    Load a table written by save_EQ from '<data_dir>/<name>'.

    Each loaded row is unpacked as (s, gv), where gv is a sequence of (g, v)
    pairs. The result is a defaultdict of defaultdicts: unseen states yield an
    empty inner table and unseen goals yield np.zeros(actions).

    NOTE: the file is read with allow_pickle=True, so only load trusted files.
    """
    def zero_values():
        return np.zeros(actions)

    def empty_goal_table():
        return defaultdict(zero_values)

    records = np.load(data_dir + '/{0}'.format(name), allow_pickle=True)
    EQ = defaultdict(empty_goal_table)
    for s, gv in records:
        EQ[s] = defaultdict(zero_values, {g: v for (g, v) in gv})
    return EQ

def get_grid_vfs(env, vf):
    """
    Draw a grid of value maps, one env.m x env.n tile per goal cell (x, y).

    env : object providing m, n, walls (collection of (i, j) cells) and factored.
    vf  : mapping keyed by (state, goal) pairs whose values support .max().

    For every non-wall cell (x, y) a tile is filled with
    vf[(state_, (x, y))].max() at the cell derived from each key's state;
    wall cells get a tile of ones in the overlay. The outermost ring of tiles
    is cropped before plotting.

    Returns the matplotlib figure (figure number 1, cleared before drawing).
    """
    vf_ = np.ones([env.m*env.m,env.n*env.n])
    grid = np.zeros([env.m*env.m,env.n*env.n,4])

    for x in range(env.m):
        for y in range(env.n):
            if (x,y) not in env.walls:
                # Overlay tile: only the last channel is set, at wall cells.
                img = np.zeros([env.m, env.n, 4])
                for (i,j) in env.walls:
                    img[i,j,-1] = 1.0
                grid[x*env.m:x*env.m+env.m,y*env.n:y*env.n+env.n] = img

                # Value tile: cells never written stay at -inf.
                img = np.zeros([env.m, env.n])+float("-inf")
                # Snapshot of the keys, taken again for every (x, y).
                states = list(vf.keys())
                for state in states:
                    state_ = state[0]
                    if not env.factored:
                        (i,j) = state[0]
                    else:
                        # Rebuild the key with every factor zeroed except the first
                        # one whose (i,j) is not (0,0); its third entry is set to 0.
                        state_ = [(0,0,0)]*len(state[0])
                        for r, (i,j,_) in enumerate(state[0]):
                            if (i,j) != (0,0):
                                state_[r] = (i,j,0)
                                break
                        state_ = tuple(state_)
                    # i, j are the values left by the unpacking/loop above.
                    img[i,j] = vf[(state_,(x,y))].max()
                    # if img[i,j]>0:
                    #     img[i,j] -= 7
                vf_[x*env.m:x*env.m+env.m,y*env.n:y*env.n+env.n] = img
            else:
                img = np.ones([env.m, env.n, 4])
                grid[x*env.m:x*env.m+env.m,y*env.n:y*env.n+env.n] = img
    # Drop the first and last row/column of tiles.
    vf = vf_[env.m:-env.m,env.n:-env.n]
    grid = grid[env.m:-env.m,env.n:-env.n]

    fig = plt.figure(1, figsize=(20, 20), dpi=60, facecolor='w', edgecolor='k')
    plt.clf()
    plt.xticks([])
    plt.yticks([])
    plt.axis("off")
    plt.grid(False)
    ax = fig.gca()

    cmap = 'RdBu_r'#'YlOrRd' if False else 'RdYlBu_r'
    # Value map first, then the overlay drawn on top with the same extent.
    ax.imshow(vf, origin="upper", cmap=cmap, extent=[0, env.n, env.m, 0])
    ax.imshow(grid, origin="upper", extent=[0, env.n, env.m, 0])

    return fig
#########################################################################################
