import torch
import numpy as np
import gym


class Network(torch.nn.Module):
    """MLP mapping a (state, action) pair to K softmax distributions over S entries.

    The state and action vectors are concatenated along dim 1, passed through
    two hidden layers (``fc1``, ``fc2``) with the chosen activation, and then
    through ``K = out_dim[1]`` independent linear heads of width
    ``S = out_dim[0]``. Each head's logits are divided by ``temperature`` and
    soft-maxed along dim 1.
    """

    def __init__(self, in_dim: int, out_dim: tuple[int, int], hidden_size: int = 20, activation: str = "ReLu", temperature=1, device=None) -> None:
        """Build the network.

        Args:
            in_dim: width of the concatenated (state, action) input.
            out_dim: ``(S, K)`` -- output width of each head and number of heads.
            hidden_size: width of both hidden layers.
            activation: ``"ReLu"`` or ``"Softmax"`` (softmax over dim 1).
            temperature: divisor applied to each head's logits before softmax.
            device: where to place the parameters. ``None`` (default) selects
                CUDA when it is available and the CPU otherwise.

        Raises:
            TypeError: if ``activation`` is not a string.
            ValueError: if ``activation`` is not a supported name, or if
                ``out_dim[1]`` (the number of heads) is less than 1.
        """
        super().__init__()

        # Store the hyperparameters.
        self.in_dim = in_dim
        self.out_dim = out_dim  # [0] -> S, [1] -> K.
        self.hidden_size = hidden_size
        self.temperature = temperature
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # forward() concatenates one output per head; with zero heads there is
        # nothing to return, so reject that configuration up front.
        if out_dim[1] < 1:
            raise ValueError("out_dim[1] (number of heads) must be at least 1")

        # Define the model structure (creation order kept: fc1, fc2, heads,
        # so seeded initialization is reproducible).
        self.fc1 = torch.nn.Linear(in_dim, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        # A ModuleList (not a plain list) registers the heads as submodules, so
        # they appear in parameters()/state_dict() and follow .to()/.cuda().
        self.final_layers = torch.nn.ModuleList(
            torch.nn.Linear(hidden_size, out_dim[0]) for _ in range(out_dim[1])
        )

        # Check the activation function; the exceptions must actually be raised,
        # otherwise self.activation is left undefined and forward() fails later.
        if not isinstance(activation, str):
            raise TypeError(str(type(activation)) + ' are not supported')
        if activation == "ReLu":
            self.activation = torch.nn.ReLU()
        elif activation == "Softmax":
            self.activation = torch.nn.Softmax(dim=1)
        else:
            raise ValueError(activation + ' are not supported')

        self.softmax = torch.nn.Softmax(dim=1)

        # Cast and move every registered parameter (including all heads) at once.
        self.float().to(self.device)

    def forward(self, state_vec: torch.Tensor, action_vec: torch.Tensor) -> np.ndarray:
        """Run the network and return the stacked head distributions.

        Args:
            state_vec: tensor concatenated with ``action_vec`` along dim 1;
                presumably shaped (batch, S) one-hot rows -- confirm with callers.
            action_vec: tensor with the same number of rows as ``state_vec``.

        Returns:
            NumPy array of shape ``(K * batch, S)``: the K heads' softmax outputs
            stacked along dim 0 in head order, detached and copied to the CPU.
        """
        # Move inputs to wherever the parameters currently live (this stays
        # correct even if the module was moved after construction).
        target = self.fc1.weight.device
        x = torch.cat((state_vec, action_vec), dim=1).to(target)
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))

        heads = [self.softmax(layer(x) / self.temperature) for layer in self.final_layers]
        return torch.cat(heads, dim=0).cpu().detach().numpy()


class MDP(gym.Env):
    """Finite MDP whose transition rows come from a theta-weighted feature map.

    For every (state, action) pair, ``func`` is called on the one-hot state and
    action tensors, and the stored transition row is ``np.dot([theta], phi)[0]``.
    Rewards ``R[s, a, s']`` are drawn uniformly from [0, 1). The optimal value
    function is computed once at construction and stored in ``v_star``.

    NOTE(review): the transition rows are not normalized when stored; ``step``
    normalizes before sampling, but ``value_iteration`` uses them as-is. This
    presumably assumes ``theta`` and each row of ``phi`` already sum to 1 --
    confirm with callers.
    """

    def __init__(self, S: int, A: int, func: torch.nn.Module, theta: np.ndarray, gamma=0.9):
        """Build the MDP and solve it for the optimal value function.

        Args:
            S: number of states.
            A: number of actions.
            func: callable taking a one-hot state tensor of shape (1, S) and a
                one-hot action tensor of shape (1, A), returning ``phi`` such
                that ``np.dot([theta], phi)[0]`` is a length-S row
                (e.g. a ``Network`` with ``out_dim=(S, len(theta))``).
            theta: mixture weights applied to ``phi``.
            gamma: discount factor.
        """
        # Store the hyperparameters.
        self.S = S
        self.A = A
        self.action_space = np.eye(A)
        self.state_space = np.eye(S)
        self.theta = theta
        self.R = np.random.rand(S, A, S)
        self.gamma = gamma
        self.func = func

        # Initialization (RNG call order kept: R first, then the start state).
        self.state = np.random.randint(S)  # uniformly random start state.
        # Visit counts per (state, action).
        self.history = np.zeros((S, A), dtype=int)

        # Compute the transition function. Inputs go to the device func's
        # parameters live on, instead of unconditionally to CUDA.
        device = self._infer_device(func)
        self.transitions = []
        for i in range(self.S):
            s = torch.from_numpy(np.array([self.state_space[i]])).float().to(device)
            row = []
            for j in range(self.A):
                a = torch.from_numpy(np.array([self.action_space[j]])).float().to(device)
                with torch.no_grad():
                    phi = self.func(s, a)
                row.append(np.dot(np.array([self.theta]), phi)[0].tolist())
            self.transitions.append(row)

        self.v_star = self.value_iteration(policy="optimal")

    @staticmethod
    def _infer_device(func) -> torch.device:
        """Return the device of func's first parameter, else CUDA if available, else CPU."""
        if isinstance(func, torch.nn.Module):
            for param in func.parameters():
                return param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def value_iteration(self, policy="optimal", max_iterations=10**6, delta=10**-3):
        """Run value iteration (``policy="optimal"``) or evaluate a fixed policy.

        Args:
            policy: ``"optimal"`` for Bellman-optimality backups; otherwise an
                object whose ``select(s)`` returns the action index for state s.
            max_iterations: upper bound on the number of sweeps (must be >= 1).
            delta: stop once the max absolute change between sweeps is below it.

        Returns:
            Length-S array of state values. In the optimal case, also sets
            ``self.opt_policy`` to the greedy action per state.

        Raises:
            ValueError: if ``max_iterations`` < 1 (no Q-values would exist).
        """
        if max_iterations < 1:
            raise ValueError("max_iterations must be at least 1")

        value_fn = np.zeros(self.S)
        # Convert the list-of-lists once instead of on every einsum call.
        R, P = self.R, np.asarray(self.transitions)
        for _ in range(max_iterations):
            previous_value_fn = value_fn.copy()
            # Q[s, a] = sum_{s'} P[s, a, s'] * (R[s, a, s'] + gamma * V[s'])
            Q = np.einsum('ijk,ijk -> ij', P, R + self.gamma * value_fn)

            if policy == "optimal":
                value_fn = np.max(Q, axis=1)
            else:
                for s in range(self.S):
                    value_fn[s] = Q[s][policy.select(s)]
            if np.max(np.abs(value_fn - previous_value_fn)) < delta:
                break

        # Record the greedy (optimal) policy from the final Q-values.
        if policy == "optimal":
            self.opt_policy = np.argmax(Q, axis=1)

        return value_fn

    def step(self, action: int):
        """Sample a transition from the current state; return (next_state, reward)."""
        self.history[self.state][action] += 1
        probs = np.asarray(self.transitions[self.state][action], dtype=float)

        # Get next state according to the (renormalized) transition row.
        next_state = np.random.choice(np.arange(self.S), p=probs / probs.sum())

        reward = self.R[self.state][action][next_state]
        self.state = next_state

        return next_state, reward

    def reset(self):
        """Draw a new uniformly random start state and clear the visit counts."""
        self.state = np.random.randint(self.S)
        self.history = np.zeros((self.S, self.A), dtype=int)
        return self.state


class MDP_hard(gym.Env):
    """Hand-built chain MDP with S states and 2 actions.

    Action 0 moves one state left deterministically (state 0 stays put).
    Action 1 moves right with probability 0.9. In interior states it stays or
    moves left with probability 0.05 each. At state 0 it stays with
    probability 0.1. At state S-1 it stays with probability 0.9 and moves left
    with probability 0.1.

    Rewards: 1 for (S-1, action 1, S-1) and 0.005 for (0, action 0, 0); all
    others are 0. Episodes start in the middle state ``int(S / 2)``.
    """

    def __init__(self, S=9, theta=None, gamma=0.9):
        """Build the chain and solve it for the optimal value function.

        Args:
            S: number of states in the chain.
            theta: 2-element weight vector; defaults to ``np.array([1.0, 1.0])``
                (a fresh array per instance, not a shared default object).
            gamma: discount factor.
        """
        if theta is None:
            theta = np.array([1.0, 1.0])

        # Store the hyperparameters.
        self.S = S
        self.A = 2
        self.theta = theta
        self.gamma = gamma

        # Initialization: deterministic start in the middle of the chain.
        self.state = int(S / 2)
        # Visit counts per (state, action).
        self.history = np.zeros((S, self.A), dtype=int)

        # Compute the transition tensor P[s, a, s'].
        self.transitions = np.zeros((S, self.A, S))
        for i in range(S):
            # Action 1: mostly move right.
            if i == S - 1:
                self.transitions[i][1][S - 1] = 0.9
                self.transitions[i][1][S - 2] = 0.1
            elif i == 0:
                self.transitions[i][1][i + 1] = 0.9
                self.transitions[i][1][i] = 0.1
            else:
                self.transitions[i][1][i + 1] = 0.9
                self.transitions[i][1][i] = 0.05
                self.transitions[i][1][i - 1] = 0.05
            # Action 0: deterministic move left, clamped at state 0.
            self.transitions[i][0][max(i - 1, 0)] = 1.0

        # Float dtype is required: with an int array the 0.005 below would be
        # truncated to 0 and the small reward at state 0 would vanish.
        self.R = np.zeros((S, self.A, S))
        self.R[S - 1][1][S - 1] = 1.0
        self.R[0][0][0] = 0.005

        self.v_star = self.value_iteration(policy="optimal")

    def func(self, s, a):
        """Return a (2, S) feature map whose theta-weighting gives P(s, a, .).

        ``s`` and ``a`` are tensors whose argmax (after ``.cpu()``) is the
        state/action index. For action 0, row 1 holds ``P[s, 0, :]`` and row 0
        is zero. For action 1, row 0 holds ``P[s, 1, :]`` and row 1 is zero.
        """
        s = np.argmax(s.cpu().numpy())
        a = np.argmax(a.cpu().numpy())
        phi = np.zeros((2, self.S))
        phi[0 if a else 1] = self.transitions[s][a]
        return phi

    def value_iteration(self, policy="optimal", max_iterations=10**6, delta=10**-3):
        """Run value iteration (``policy="optimal"``) or evaluate a fixed policy.

        Args:
            policy: ``"optimal"`` for Bellman-optimality backups; otherwise an
                object whose ``select(s)`` returns the action index for state s.
            max_iterations: upper bound on the number of sweeps (must be >= 1).
            delta: stop once the max absolute change between sweeps is below it.

        Returns:
            Length-S array of state values. In the optimal case, also sets
            ``self.opt_policy`` to the greedy action per state.

        Raises:
            ValueError: if ``max_iterations`` < 1 (no Q-values would exist).
        """
        if max_iterations < 1:
            raise ValueError("max_iterations must be at least 1")

        value_fn = np.zeros(self.S)
        R, P = self.R, self.transitions
        for _ in range(max_iterations):
            previous_value_fn = value_fn.copy()
            # Q[s, a] = sum_{s'} P[s, a, s'] * (R[s, a, s'] + gamma * V[s'])
            Q = np.einsum('ijk,ijk -> ij', P, R + self.gamma * value_fn)

            if policy == "optimal":
                value_fn = np.max(Q, axis=1)
            else:
                for s in range(self.S):
                    value_fn[s] = Q[s][policy.select(s)]
            if np.max(np.abs(value_fn - previous_value_fn)) < delta:
                break

        # Record the greedy (optimal) policy from the final Q-values.
        if policy == "optimal":
            self.opt_policy = np.argmax(Q, axis=1)

        return value_fn

    def step(self, action: int):
        """Sample a transition from the current state; return (next_state, reward)."""
        self.history[self.state][action] += 1
        probs = self.transitions[self.state][action]

        # Get next state according to the (renormalized) transition row.
        next_state = np.random.choice(np.arange(self.S), p=probs / np.sum(probs))

        reward = self.R[self.state][action][next_state]
        self.state = next_state

        return next_state, reward

    def reset(self):
        """Return to the middle start state and clear the visit counts."""
        self.state = int(self.S / 2)
        self.history = np.zeros((self.S, self.A), dtype=int)
        return self.state


if __name__ == '__main__':
    # Define hyperparameters here.
    S = 5
    A = 4
    K = 3  # len of theta
    gamma = 0.9
    seed = 1
    temperature = 1e-02
    T = 100
    Episodes = 3
    theta = np.array([0.2, 0.5, 0.3])

    # Seed both RNGs so network initialization is reproducible.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # NOTE(review): the `temperature` hyperparameter above is unused; the
    # network is built with 0.1 -- confirm which value is intended.
    func = Network(S + A, (S, len(theta)), temperature=0.1)
    # Feed inputs on whatever device the network's parameters live on.
    device = next(func.parameters()).device

    action_space = np.eye(A)
    observation_space = np.eye(S)
    state_vec = torch.from_numpy(np.array([observation_space[1]])).float().to(device)
    action_vec = torch.from_numpy(np.array([action_space[1]])).float().to(device)
    # Network.forward takes the state and action separately and concatenates
    # them itself; passing a single pre-concatenated tensor raises TypeError.
    phi = func(state_vec, action_vec)

    print(np.shape(phi), np.dot(np.array([theta]), phi))
    print(sum(np.dot(np.array([theta]), phi)[0]))
