# Some code adapted from https://github.com/david-lindner/idrl (as mentioned and cited in paper)

from typing import  Optional
import numpy as np
from preferences_offlineRL.envs.common import TabularMDP


class DoubleChain(TabularMDP):
    """Five-state chain MDP with two stochastic actions.

    States 0..4 form a line and the reward vector is ``[10, 0, 0, 0, 1]``.
    Action 0 steps to the right neighbour with probability 0.9 and to the
    left neighbour with probability 0.1; action 1 steps left or right with
    probability 0.5 each. A step that would leave the chain is clipped, so
    the agent stays on the end state instead.

    Args:
        discount_factor: Forwarded unchanged to ``TabularMDP``.
        episode_length: Forwarded unchanged to ``TabularMDP``.
    """

    def __init__(
        self,
        discount_factor: float,
        episode_length: int
    ):
        n_states = 5
        last_state = n_states - 1
        state_rewards = np.array([10, 0, 0, 0, 1])

        # One entry per action: (P(step right), P(step left)).
        step_probs = ((0.9, 0.1), (0.5, 0.5))
        n_actions = len(step_probs)

        dynamics = np.zeros((n_actions, n_states, n_states))
        for state in range(n_states):
            # Clip the neighbours so the end states bounce back onto themselves.
            right = min(last_state, state + 1)
            left = max(0, state - 1)
            for action, (p_right, p_left) in enumerate(step_probs):
                dynamics[action, state, right] = p_right
                dynamics[action, state, left] = p_left

        # Arguments are passed positionally, as TabularMDP expects them; the
        # empty list and the 2 after episode_length are presumably the
        # terminal states and the initial state -- confirm against TabularMDP.
        super().__init__(
            n_states,
            n_actions,
            state_rewards,
            dynamics,
            discount_factor,
            [],
            episode_length,
            2,
            observation_type="state",
            observation_noise=0,
        )



class JunctionMDP(TabularMDP):
    """Deterministic chain that forks into two random-walk branches.

    Layout (state indices grow from left to right)::

                   -> x <-> x <-> x <-> x <-> x <-> x <-> x <-> x
        x -> x -> |
                   -> x <-> x <-> x <-> x <-> x <-> x <-> x <-> x

    With ``N = len(rewards_chain)`` and ``M = len(rewards_junction_top)``:

    * states ``0 .. N-1`` are the left chain,
    * states ``N .. N+M-1`` are the upper branch,
    * states ``N+M .. N+2M-1`` are the lower branch.

    Args:
        rewards_chain: Rewards of the left chain; its length defines ``N``.
        rewards_junction_top: Rewards of the upper branch; its length
            defines ``M``.
        rewards_junction_bottom: Rewards of the lower branch; must also have
            length ``M`` (checked with an ``assert``).
        discount_factor: Forwarded unchanged to ``TabularMDP``.
        episode_length: Forwarded unchanged to ``TabularMDP``.
        init_agent_pos: Forwarded unchanged to ``TabularMDP``.
        observation_type: Forwarded unchanged to ``TabularMDP``.
        observation_noise: Forwarded unchanged to ``TabularMDP``.
    """

    def __init__(
        self,
        rewards_chain: np.ndarray,
        rewards_junction_top: np.ndarray,
        rewards_junction_bottom: np.ndarray,
        discount_factor: float,
        episode_length: int,
        init_agent_pos: Optional[int] = None,
        observation_type: str = "state",
        observation_noise: float = 0,
    ):
        chain_len = rewards_chain.shape[0]
        branch_len = rewards_junction_top.shape[0]
        assert rewards_junction_bottom.shape[0] == branch_len

        n_states = chain_len + 2 * branch_len
        transitions = np.zeros((2, n_states, n_states))

        # Inside the left chain both actions deterministically step right.
        for state in range(chain_len - 1):
            transitions[:, state, state + 1] = 1

        # The last chain state is the junction: action 0 enters the upper
        # branch, action 1 enters the lower branch.
        junction = chain_len - 1
        transitions[0, junction, chain_len] = 1
        transitions[1, junction, chain_len + branch_len] = 1

        # Each branch is an unbiased random walk that ignores the action.
        # Neighbours are clipped to the branch, and the use of += lets the
        # two halves accumulate on the same cell at the branch ends.
        for first in (chain_len, chain_len + branch_len):
            last = first + branch_len - 1
            for state in range(first, last + 1):
                transitions[:, state, min(last, state + 1)] += 0.5
                transitions[:, state, max(first, state - 1)] += 0.5

        # Rewards are laid out in state order: chain, upper branch, lower branch.
        all_rewards = np.array(
            [*rewards_chain, *rewards_junction_top, *rewards_junction_bottom]
        )

        super().__init__(
            n_states,
            2,
            all_rewards,
            transitions,
            discount_factor,
            [],
            episode_length,
            init_agent_pos,
            observation_type=observation_type,
            observation_noise=observation_noise,
        )


class QuadrupleChain(TabularMDP):
    """Nine-state MDP made of four two-step chains that meet at state 0.

    The chains are ``0-1-2``, ``0-3-4``, ``0-5-6`` and ``0-7-8``. State 2
    carries reward 1 and state 4 carries reward 10; every other reward is 0.

    Each action has a fixed set of moves, listed in ``moves_per_action``. A
    listed move succeeds with the action's success probability (0.5 for
    action 1, 0.9 for the others); otherwise the agent stays put. In every
    state without a listed move the action is a deterministic self-loop.

    Args:
        discount_factor: Forwarded unchanged to ``TabularMDP``.
        episode_length: Forwarded unchanged to ``TabularMDP``.
    """

    def __init__(
        self,
        discount_factor: float,
        episode_length: int
    ):
        N_states = 9
        N_actions = 4

        rewards = np.zeros((N_states,))
        rewards[2] = 1
        rewards[4] = 10

        # One entry per action: (success probability, ((source, target), ...)).
        moves_per_action = (
            (0.9, ((0, 1), (1, 2), (5, 0), (6, 5))),
            (0.5, ((0, 3), (3, 4), (7, 0), (8, 7))),
            (0.9, ((0, 5), (5, 6), (1, 0), (2, 1))),
            (0.9, ((0, 7), (7, 8), (3, 0), (4, 3))),
        )

        transitions = np.zeros((N_actions, N_states, N_states))
        for action, (p_move, edges) in enumerate(moves_per_action):
            # Default: the action leaves every state unchanged ...
            transitions[action] = np.eye(N_states)
            # ... except on the listed edges, where it moves with p_move and
            # stays with the complementary probability.
            for source, target in edges:
                transitions[action, source, target] = p_move
                transitions[action, source, source] = 1 - p_move

        super().__init__(
            N_states,
            N_actions,
            rewards,
            transitions,
            discount_factor,
            [],
            episode_length,
            0,
            observation_type="state",
            observation_noise=0,
        )

class NChain(TabularMDP):
    """Star of ``N_chains`` two-step chains that all start at state 0.

    Chain ``c`` consists of the states ``2c+1`` (next to the centre) and
    ``2c+2`` (chain end). There are ``2 * N_chains + 1`` states and one
    action per chain. State 2 (end of chain 0) carries reward 1 and state 4
    (end of chain 1) carries reward 10; all other rewards are 0.

    Action ``a`` advances along chain ``a`` (``0 -> 2a+1 -> 2a+2``) with
    success probability 0.5 for action 1 and 0.9 for every other action,
    staying put otherwise. In all remaining states an action is a
    deterministic self-loop, with one exception: action 0 additionally walks
    every other chain back towards the centre with probability 0.9.

    Args:
        N_chains: Number of chains (and actions). Must be at least 2 because
            the rewards are placed on states 2 and 4.
        discount_factor: Forwarded unchanged to ``TabularMDP``.
        episode_length: Forwarded unchanged to ``TabularMDP``.

    Raises:
        ValueError: If ``N_chains`` is smaller than 2.
    """

    def __init__(self,
                 N_chains: int,
                 discount_factor: float,
                 episode_length: int,
                 ):
        # Fail early with a readable message instead of an IndexError when
        # the reward states 2 and 4 would not exist.
        if N_chains < 2:
            raise ValueError(
                "NChain requires N_chains >= 2 because rewards are placed on "
                f"states 2 and 4, got N_chains={N_chains}"
            )

        N_states = 2 * N_chains + 1
        N_actions = N_chains

        rewards = np.zeros((N_states,))
        rewards[2] = 1
        rewards[4] = 10

        transitions = np.zeros((N_actions, N_states, N_states))
        for action in range(N_actions):
            # Action 1 (the chain ending in the large reward) is the only one
            # that succeeds with probability 0.5 instead of 0.9.
            p_move = 0.5 if action == 1 else 0.9
            near, far = 2 * action + 1, 2 * action + 2

            # Default: self-loop everywhere, then carve out the forward moves
            # along this action's own chain.
            transitions[action] = np.eye(N_states)
            transitions[action, 0, near] = p_move
            transitions[action, 0, 0] = 1 - p_move
            transitions[action, near, far] = p_move
            transitions[action, near, near] = 1 - p_move

        # Action 0 also moves the agent back towards the centre from every
        # chain other than chain 0.
        p_back = 0.9
        for chain in range(1, N_chains):
            near, far = 2 * chain + 1, 2 * chain + 2
            transitions[0, near, 0] = p_back
            transitions[0, near, near] = 1 - p_back
            transitions[0, far, near] = p_back
            transitions[0, far, far] = 1 - p_back

        super().__init__(
            N_states,
            N_actions,
            rewards,
            transitions,
            discount_factor,
            [],
            episode_length,
            0,
            observation_type="state",
            observation_noise=0,
        )



class StarMDP(TabularMDP):
    """Five-state, four-action MDP whose moves all pass through state 1.

    State 0 connects to state 1, and state 1 connects to states 0, 2, 3 and
    4. The reward vector is ``[0, 0, 6, -1, 10]``. Every (action, state)
    pair that is not listed in ``moves`` is a deterministic self-loop; a
    listed pair moves to its target with ``p_move`` and stays with
    ``p_stay``.

    Args:
        discount_factor: Forwarded unchanged to ``TabularMDP`` (default 1).
        episode_length: Forwarded unchanged to ``TabularMDP`` (default 3).
    """

    def __init__(
        self,
        discount_factor: float = 1,
        episode_length: int = 3
    ):
        N_states = 5
        N_actions = 4

        # (action, source, target, p_move, p_stay)
        moves = (
            (0, 0, 1, 0.9, 0.1),
            (0, 1, 2, 0.9, 0.1),
            (1, 1, 0, 0.9, 0.1),
            (2, 1, 4, 0.9, 0.1),
            (3, 1, 3, 0.9, 0.1),
            (1, 2, 1, 0.9, 0.1),
            (2, 3, 1, 0.9, 0.1),
            # NOTE(review): for state 4 the probabilities are the other way
            # round from every other move (0.9 stay / 0.1 move) -- confirm
            # that this is intended.
            (3, 4, 1, 0.1, 0.9),
        )

        # Start from "every action is a self-loop" and overwrite the moves.
        transitions = np.zeros((N_actions, N_states, N_states))
        transitions[:] = np.eye(N_states)
        for action, source, target, p_move, p_stay in moves:
            transitions[action, source, target] = p_move
            transitions[action, source, source] = p_stay

        rewards = np.array([0, 0, 6, -1, 10])

        super().__init__(
            N_states,
            N_actions,
            rewards,
            transitions,
            discount_factor,
            [],
            episode_length,
            initial_state=0,
            observation_type="state",
            observation_noise=0,
        )