import os
import sys
sys.path.append(os.getcwd())

import jax
import jax.numpy as np
import numpy as onp

import src.utils.tab_rl as tab_rl
from functools import partial
import itertools

import src.envs.adapt_chain as adapt_chain

class SwitchStay():
    """Two-state, two-action "switch-stay" tabular MDP.

    Action 0 keeps the agent in its current state, action 1 moves it to the
    other state.  All rewards are shifted by +1 so that every return is
    positive.

    Attributes set by the constructor:
        P: transition tensor indexed as ``P[s, a, s']``.
        r: reward table indexed as ``r[s, a]``.
        gamma: discount factor, used both by the methods below and for the
            precomputed ``optimal_pi`` / ``optimal_Q``.
        pi1, pi2: two near-deterministic policies used for interpolation;
            ``almost_optimal_pi`` is an alias of ``pi2``.
    """

    def __init__(self):
        # switch-stay
        self.n_states = 2
        self.n_actions = 2
        # Single source of truth for the discount: the tab_rl calls below and
        # get_V / get_Q all read this value.
        self.gamma = 0.9

        self.P = onp.zeros((self.n_states, self.n_actions, self.n_states))
        self.r = onp.zeros((self.n_states, self.n_actions))

        # action 0 is stay, action 1 is switch
        self.r[0, 0] = 1
        self.r[0, 1] = -1
        self.r[1, 0] = 2
        self.r[1, 1] = 0
        # need to ensure all returns are positive so modify switch-stay by adding 1 to everything
        self.r += 1
        self.starting_states = [0]
        self.n_starting_states = 1

        self.P[0, 0, 0] = 1
        self.P[0, 1, 1] = 1
        self.P[1, 0, 1] = 1
        self.P[1, 1, 0] = 1
        self.name = "SwitchStay"

        # for interpolation
        self.optimal_pi = tab_rl.get_greedy_pi(tab_rl.get_optimal_Q(self.P, self.r, self.gamma))
        self.optimal_Q = tab_rl.get_Q(self.P, self.r, self.gamma, self.optimal_pi)
        self.pi1 = np.array([[0.99, 0.01], [0.01, 0.99]])
        self.pi2 = np.array([[0.01, 0.99], [0.99, 0.01]])
        self.almost_optimal_pi = self.pi2

    # define the following tab_rl functions so that we can use jax.vmap without hassle
    # `self` is marked static for jit, so P, r and gamma are baked in as
    # constants when each method is traced.
    @partial(jax.jit, static_argnums=(0,))
    def get_P_pi(self, pi):
        """Return the state-to-state transition matrix induced by policy ``pi``.

        ``pi[s, a]`` weights ``P[s, a, :]``; the action axis is summed out.
        """
        return np.sum(self.P * np.expand_dims(pi, axis = -1), axis = 1)

    @partial(jax.jit, static_argnums=(0,))
    def get_V(self, pi):
        """Return the state values of ``pi``: ``(I - gamma * P_pi)^-1 r_pi``."""
        P_pi = self.get_P_pi(pi)
        # expected one-step reward in each state under pi
        r_pi = np.sum(self.r * pi, axis = -1)
        return np.linalg.inv(np.eye(pi.shape[0]) - self.gamma * P_pi) @ r_pi

    @partial(jax.jit, static_argnums=(0,))
    def get_Q(self, pi):
        """Return the action values of ``pi``: ``r + gamma * E_{s'}[V(s')]``."""
        n_states = self.P.shape[0]
        V = self.get_V(pi)
        # expected next-state value for every (s, a) pair
        Vp =  np.sum(self.P * V.reshape((1, 1, n_states)), axis = -1)
        return self.r + self.gamma * Vp


class Grid():
    """Deterministic 4x4 grid world with two absorbing goal states.

    State ids are laid out row-major:

        0   1   2   3
        4   5   6   7
        8   9   10  11
        12  13  14  15

    States 3 and 12 are goals: every action taken there keeps the agent in
    place and yields zero reward.  Actions are

        0 = up, 1 = right, 2 = down, 3 = left

    and a move that would leave the grid keeps the agent where it is.
    """

    @staticmethod
    def _transition_matrix(side=4, goals=(3, 12)):
        """Build ``P[s, a, s']`` for a ``side`` x ``side`` grid with absorbing ``goals``."""
        # (row, col) deltas in action order up, right, down, left.  Row
        # indices grow downwards, so "up" is row - 1.
        deltas = ((-1, 0), (0, 1), (1, 0), (0, -1))
        n_states = side * side
        P = onp.zeros((n_states, len(deltas), n_states))
        for s in range(n_states):
            if s in goals:  # stay forever at goal
                P[s, :, s] = 1
                continue
            row, col = divmod(s, side)
            for a, (d_row, d_col) in enumerate(deltas):
                new_row, new_col = row + d_row, col + d_col
                # valid indices are 0 .. side - 1 inclusive
                if 0 <= new_row < side and 0 <= new_col < side:
                    P[s, a, new_row * side + new_col] = 1
                else:
                    P[s, a, s] = 1  # action makes you stay
        return P

    @staticmethod
    def _reward_table(n_states, n_actions, goals=(3, 12)):
        """Build ``r[s, a]``: small positive base reward plus goal-entry bonuses."""
        # add small positive number so that rewards are positive
        r = onp.zeros((n_states, n_actions)) + 0.001
        r[2, 1] = r[7, 0] = 20  # optimal action (moves into goal 3)
        r[8, 2] = r[13, 3] = 10  # suboptimal action (moves into goal 12)
        # but no rewards coming out of goal
        for g in goals:
            r[g, :] = 0.0
        return r

    def __init__(self):
        P = self._transition_matrix()
        n_states, n_actions, _ = P.shape
        r = self._reward_table(n_states, n_actions)

        # checks on the transition matrix
        assert onp.allclose(P.sum(axis=-1), 1.0), "rows of P must sum to 1"

        self.P = P
        self.r = r
        self.n_states = n_states
        self.n_actions = n_actions
        self.name = "Grid"
        self.gamma = 0.99

        self.pi1 = np.ones((n_states, n_actions)) / n_actions # uniform policy
        # Q is initialised randomly; presumably 50000 sweeps are enough for the
        # result not to depend on that initialisation -- depends on tab_rl.
        self.optimal_Q = tab_rl.action_value_iteration(P, r, self.gamma, onp.random.random((n_states, n_actions)), iters = 50000)
        self.optimal_pi = tab_rl.get_greedy_pi(self.optimal_Q)
        # Softmax (not a bare exp) so the result is a normalised policy and
        # exp(1 / 0.01) cannot overflow to inf in float32.
        self.almost_optimal_pi = jax.nn.softmax(self.optimal_pi / 0.01, axis = 1)
        self.pi2 = jax.nn.softmax(self.optimal_Q / 0.01, axis = 1)  # softmax optimal policy

class GridWorld():
    """Grid world parsed from a flat, row-major character map.

    Map characters: ``'w'`` is a wall, ``'s'`` a start cell, and any key of
    ``goal_r_dict`` a goal cell with that bonus reward; everything else is
    free space.  The map must be enclosed by walls: neighbouring cells are
    looked up without bounds checks.

    Besides the ``reset`` / ``step`` interface the class exposes the tabular
    model ``P[s, a, s']`` and ``r[s, a]``, where state ids are
    ``row * n_cols + col``.

    Args:
        _map_str: map characters, one per cell, rows concatenated.
        _dims: ``(n_rows, n_cols)``.
        name: label stored on the instance.
        step_r: reward for taking a step.
        wall_r: extra reward for bumping into a wall.
        goal_r_dict: map character -> extra reward for entering that goal;
            defaults to ``{"g": 0.0, "h": 0.0}``.
        obs_type: ``'ids'`` (integer state id) or ``'coords'`` (row, col).
        action_eps: action stochasticity.
    """

    def __init__(self, _map_str, _dims, name, step_r=-1.0, wall_r=0.0, goal_r_dict=None, obs_type='ids', action_eps=0.0):
        if goal_r_dict is None:
            # built per instance instead of sharing one mutable default dict
            goal_r_dict = {"g": 0.0, "h": 0.0}
        self._dims = _dims # grid dimensions
        self._map_str = _map_str
        self._step_r = step_r # reward for taking a step
        self._wall_r = wall_r # reward for hitting a wall
        self._goal_r_dict = goal_r_dict # reward for reaching the goal
        self._obs_type = obs_type # observation type: 'ids' or 'coords'
        self._action_eps = action_eps # action stochasticity
        self.name = name

        if obs_type == "coords":
            # for use with function-approximation
            self.obs_dim = 2
            self.action_dim = 4
        # (row, col) deltas: up, left, down, right
        self._actions = onp.array([[-1, 0], [0, -1], [1, 0], [0, 1]])
        n_cells = onp.prod(self._dims)
        self.P = onp.zeros((n_cells, len(self._actions), n_cells))
        self.r = onp.zeros((n_cells, len(self._actions))) + step_r
        self._parse_map(self._map_str)

    def _parse_map(self, map_str):
        """Fill the wall / goal / start tables from ``map_str`` and build P and r."""
        assert onp.prod(self._dims) == len(map_str), "map length must equal the number of cells"
        self._starts = []
        self._walls = onp.full(self._dims, False, dtype=bool)
        self._goals = onp.full(self._dims, False, dtype=bool)
        self._goals_r = onp.full(self._dims, 0, dtype=float)
        for i, c in enumerate(map_str):
            pos = divmod(i, self._dims[1])  # (row, col)
            if c == 'w':
                self._walls[pos] = True
            elif c == 's':
                self._starts.append(list(pos))
            elif c in self._goal_r_dict:
                self._goals[pos] = True
                self._goals_r[pos] = self._goal_r_dict[c]
        self._starts = onp.array(self._starts)
        self.build_P_r()

    def _move(self, s, a):
        """Return ``(next_cell, extra_reward)`` for deterministically taking ``a`` in cell ``s``."""
        s = onp.asarray(s)
        sp = s + self._actions[a]
        if self._walls[sp[0], sp[1]]:
            return s, self._wall_r  # bump: stay put
        if self._goals[sp[0], sp[1]]:
            return sp, self._goals_r[sp[0], sp[1]]
        return sp, 0.0

    def build_P_r(self):
        """Fill the transition tensor ``P[s, a, s']`` and expected reward ``r[s, a]``.

        The chosen action is executed with probability ``1 - action_eps``;
        each of the other actions with ``action_eps / (n_actions - 1)``.
        Wall and goal cells are absorbing, and their reward rows keep the
        plain step reward.
        """
        n_actions = self.n_actions
        eps = self._action_eps
        slip_p = eps / (n_actions - 1)

        # start from a clean slate so that a repeated call gives the same result
        self.P[...] = 0.0
        self.r[...] = self._step_r

        for s in itertools.product(range(self._dims[0]), range(self._dims[1])):
            sid = self.state_id(s)
            if self._walls[s] or self._goals[s]:
                self.P[sid, :, sid] = 1
                continue
            for a in range(n_actions):
                for taken in range(n_actions):
                    weight = (1 - eps) if taken == a else slip_p
                    sp, extra = self._move(s, taken)
                    # accumulate: several executed actions can end in the
                    # same cell (e.g. every bump leaves the agent in s)
                    self.P[sid, a, self.state_id(sp)] += weight
                    self.r[sid, a] += weight * extra

        # checks: every (s, a) row over all states must be a distribution
        bad = onp.argwhere(~onp.isclose(self.P.sum(axis=-1), 1.0))
        assert bad.size == 0, f"P at (s, a) = {tuple(bad[0])} does not sum to 1"

    def reset(self):
        """Sample a start cell uniformly at random and return its observation."""
        self._state = self._starts[onp.random.randint(self._starts.shape[0])]
        if self._obs_type == 'ids':
            self.s0 = self.state_id(self._state)
            return self.state_id(self._state)
        elif self._obs_type == 'coords':
            self.s0 = self._state
            return self._state

    def step(self, a):
        """Apply action ``a``; return ``(observation, reward, terminal, None)``."""
        if self._action_eps != 0.0:
            # NOTE(review): the resample below may return `a` itself, so the
            # sampled slip probability is action_eps * 3 / 4, not the
            # action_eps modelled in P -- confirm which one is intended.
            if onp.random.random() <= self._action_eps:
                a = onp.random.randint(self.n_actions)
        sp, extra = self._move(self._state, a)
        r = self._step_r + extra
        # entering a goal cell ends the episode
        T = bool(self._goals[sp[0], sp[1]])
        self._state = sp
        if self._obs_type == 'ids':
            return self.state_id(self._state), r, T, None
        elif self._obs_type == 'coords':
            return self._state, r, T, None

    def render(self):
        """Print the textual rendering produced by ``to_string``."""
        print(self.to_string())

    def to_string(self):
        """Return the map as text: 'a' agent, 'w' wall, 'g' goal, ' ' free cell."""
        # before the first reset() there is no agent to draw
        agent = getattr(self, "_state", None)
        rows = []
        for r in range(self._dims[0]):
            chars = []
            for c in range(self._dims[1]):
                if agent is not None and r == agent[0] and c == agent[1]:
                    chars.append('a')
                elif self._walls[r, c]:
                    chars.append('w')
                elif self._goals[r, c]:
                    chars.append('g')
                else:
                    chars.append(' ')
            rows.append("".join(chars) + '\n')
        return "\n" + "".join(rows)

    def state_id(self, s):
        """Return the row-major integer id of cell ``s = (row, col)``."""
        return s[0] * self._dims[1] + s[1]

    def to_coords(self, s):
        """Return the ``(row, col)`` cell of integer state id ``s``."""
        col = s % self._dims[1]
        row = s // self._dims[1]
        return (row, col)

    @property
    def n_states(self):
        return onp.prod(self._dims)

    @property
    def n_starting_states(self):
        return len(self._starts)

    @property
    def starting_states(self):
        return [self.state_id(s) for s in self._starts]

    @property
    def n_actions(self):
        return len(self._actions)


# wrapper for adaptchain
class AdaptChain:
    """Expose ``adapt_chain.AdaptChain`` (built with n = 50) through the
    attribute layout shared by the other environments in this file."""

    def __init__(self):
        chain = adapt_chain.AdaptChain(n = 50)
        self.adaptchain = chain

        self.name = "AdaptChain"
        self.gamma = 0.98
        self.n_starting_states = 1
        self.starting_states = [0]

        mdp = chain._mdp
        # Move the wrapped model's leading axis to the end; presumably this
        # turns an (s', s, a) layout into P[s, a, s'] -- TODO confirm against
        # adapt_chain.
        self.P = mdp._P.transpose((1, 2, 0))
        self.r = mdp._r

        self.n_states = chain.n_states
        self.n_actions = chain.n_actions