from src.searchlight.headers import State
from src.searchlight.headers import InitialInferencer2
from typing import Any

'''
We define a few small inferencers used to test the search functionality.

Recall that inferencers have prediction functions of the following form:

predict(self, state: State) -> tuple[set, dict, set, dict, dict, dict]

args:
    state: current state
returns:
    actors: set of actors that may take actions at the state; if empty, the state is terminal
    policies: dict from actors to dict of action to probability
    actions: set of (joint) actions (tuples of tuples (actor, action))
    next_state_values: dict from next_state to dict from actor to the expected value of that next state for the actor
    intermediate_rewards: dict from (joint) actions to dict from actor to intermediate reward
    transitions: dict from (joint) actions to next states

The classes below override the `_predict` hook instead, which returns only
(policies, next_state_values, intermediate_rewards, transitions, <extra dict>);
the fifth element is always an empty dict here. Presumably InitialInferencer2
derives actors and actions from these -- confirm against its implementation.
'''

class ThreeState(InitialInferencer2):
    '''
    Single actor (player) with 3 states (root and two children):

    0 -> 1 (action 0, reward 0.0)
    0 -> 2 (action 1, reward 1.0)

    At the root, actor 0 picks each action with probability 0.5. Both
    children have an estimated value of 0.0 for actor 0 and have no
    outgoing actions.
    '''

    STATES = [State(0), State(1), State(2)]

    def _predict(self, state: State) -> tuple[dict, dict, dict[tuple[tuple[Any, Any],...],Any], dict[tuple[tuple[Any, Any],...],Any], dict]:
        '''
        Return the hard-coded prediction for `state`.

        Returns a 5-tuple (policies, next_state_values, intermediate_rewards,
        transitions, extra):
            policies: actor -> action -> probability
            next_state_values: next state -> actor -> value estimate
            intermediate_rewards: joint action -> actor -> reward
            transitions: joint action -> next state
            extra: always an empty dict here (its meaning is defined by
                InitialInferencer2 -- TODO confirm)
        Any state other than the root yields five empty dicts.
        '''
        # Joint actions are tuples of (actor, action) pairs; only actor 0 acts.
        if state == self.STATES[0]:
            policies = {0: {0: 0.5, 1: 0.5}}
            next_state_values = {self.STATES[1]: {0: 0.0}, self.STATES[2]: {0: 0.0}}
            intermediate_rewards = {((0, 0),): {0: 0.0}, ((0, 1),): {0: 1.0}}
            transitions = {((0, 0),): self.STATES[1], ((0, 1),): self.STATES[2]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        # Leaf state: nothing to act on, so every mapping is empty (fresh
        # dicts each call so callers cannot share mutable state).
        return dict(), dict(), dict(), dict(), dict()

class FiveState(InitialInferencer2):
    '''
    Single actor (player) with 5 states, graph as follows:

    0 -> 1, 2
    1 -> 3, 4

    From each non-leaf state, action 0 leads to the first child (reward 1.0)
    and action 1 to the second child (reward 2.0), each chosen with
    probability 0.5. Child value estimates for actor 0: states 1 and 2 are
    0.0; state 3 is 3.0; state 4 is 0.0. States 2, 3 and 4 have no
    outgoing actions.
    '''

    STATES = [State(0), State(1), State(2), State(3), State(4)]

    def _predict(self, state: State) -> tuple[dict, dict, dict[tuple[tuple[Any, Any],...],Any], dict[tuple[tuple[Any, Any],...],Any], dict]:
        '''
        Return the hard-coded prediction for `state`.

        Returns a 5-tuple (policies, next_state_values, intermediate_rewards,
        transitions, extra):
            policies: actor -> action -> probability
            next_state_values: next state -> actor -> value estimate
            intermediate_rewards: joint action -> actor -> reward
            transitions: joint action -> next state
            extra: always an empty dict here (its meaning is defined by
                InitialInferencer2 -- TODO confirm)
        Any state other than 0 and 1 yields five empty dicts.
        '''
        # Joint actions are tuples of (actor, action) pairs; only actor 0 acts.
        if state == self.STATES[0]:
            policies = {0: {0: 0.5, 1: 0.5}}
            next_state_values = {self.STATES[1]: {0: 0.0}, self.STATES[2]: {0: 0.0}}
            intermediate_rewards = {((0, 0),): {0: 1.0}, ((0, 1),): {0: 2.0}}
            transitions = {((0, 0),): self.STATES[1], ((0, 1),): self.STATES[2]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        if state == self.STATES[1]:
            policies = {0: {0: 0.5, 1: 0.5}}
            next_state_values = {self.STATES[3]: {0: 3.0}, self.STATES[4]: {0: 0.0}}
            intermediate_rewards = {((0, 0),): {0: 1.0}, ((0, 1),): {0: 2.0}}
            transitions = {((0, 0),): self.STATES[3], ((0, 1),): self.STATES[4]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        # Leaf state: nothing to act on, so every mapping is empty (fresh
        # dicts each call so callers cannot share mutable state).
        return dict(), dict(), dict(), dict(), dict()
    
class FourChain(InitialInferencer2):
    '''
    Single actor (player) with 4 states in a chain

    0 -> 1 -> 2 -> 3

    Each non-final state has exactly one action (action 0, probability 1.0).
    Step rewards: 0->1 is 1.0, 1->2 is 0.0, 2->3 is 3.0. Value estimates for
    actor 0 of the next state: state 1 is 2.0, state 2 is 1.0, state 3 is
    0.0. State 3 has no outgoing actions.
    '''

    STATES = [State(0), State(1), State(2), State(3)]

    def _predict(self, state: State) -> tuple[dict, dict, dict[tuple[tuple[Any, Any],...],Any], dict[tuple[tuple[Any, Any],...],Any], dict]:
        '''
        Return the hard-coded prediction for `state`.

        Returns a 5-tuple (policies, next_state_values, intermediate_rewards,
        transitions, extra):
            policies: actor -> action -> probability
            next_state_values: next state -> actor -> value estimate
            intermediate_rewards: joint action -> actor -> reward
            transitions: joint action -> next state
            extra: always an empty dict here (its meaning is defined by
                InitialInferencer2 -- TODO confirm)
        Any state other than 0, 1 and 2 yields five empty dicts.
        '''
        # The only joint action anywhere in the chain: actor 0 plays action 0.
        if state == self.STATES[0]:
            policies = {0: {0: 1.0}}
            next_state_values = {self.STATES[1]: {0: 2.0}}
            intermediate_rewards = {((0, 0),): {0: 1.0}}
            transitions = {((0, 0),): self.STATES[1]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        if state == self.STATES[1]:
            policies = {0: {0: 1.0}}
            next_state_values = {self.STATES[2]: {0: 1.0}}
            intermediate_rewards = {((0, 0),): {0: 0.0}}
            transitions = {((0, 0),): self.STATES[2]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        if state == self.STATES[2]:
            policies = {0: {0: 1.0}}
            next_state_values = {self.STATES[3]: {0: 0.0}}
            intermediate_rewards = {((0, 0),): {0: 3.0}}
            transitions = {((0, 0),): self.STATES[3]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        # End of the chain: nothing to act on, so every mapping is empty
        # (fresh dicts each call so callers cannot share mutable state).
        return dict(), dict(), dict(), dict(), dict()
    
class SevenState(InitialInferencer2):
    '''
    Single actor (player) with 7 states, graph as follows:

    0 -> 1, 2
    1 -> 3, 4
    2 -> 5, 6

    From each non-leaf state, action 0 leads to the first child and action 1
    to the second, each chosen with probability 0.5.
    Rewards: 0->1 and 0->2 are 1.0; 1->3 and 1->4 are 1.0; 2->5 is 2.0;
    2->6 is 0.0.
    Child value estimates for actor 0: state 1 is 2.0, state 2 is 0.0,
    states 3 and 4 are 10.0, states 5 and 6 are 0.0.
    States 3-6 have no outgoing actions.
    '''

    STATES = [State(0), State(1), State(2), State(3), State(4), State(5), State(6)]

    def _predict(self, state: State) -> tuple[dict, dict, dict[tuple[tuple[Any, Any],...],Any], dict[tuple[tuple[Any, Any],...],Any], dict]:
        '''
        Return the hard-coded prediction for `state`.

        Returns a 5-tuple (policies, next_state_values, intermediate_rewards,
        transitions, extra):
            policies: actor -> action -> probability
            next_state_values: next state -> actor -> value estimate
            intermediate_rewards: joint action -> actor -> reward
            transitions: joint action -> next state
            extra: always an empty dict here (its meaning is defined by
                InitialInferencer2 -- TODO confirm)
        Any state other than 0, 1 and 2 yields five empty dicts.
        '''
        # Joint actions are tuples of (actor, action) pairs; only actor 0 acts.
        if state == self.STATES[0]:
            policies = {0: {0: 0.5, 1: 0.5}}
            next_state_values = {self.STATES[1]: {0: 2.0}, self.STATES[2]: {0: 0.0}}
            intermediate_rewards = {((0, 0),): {0: 1.0}, ((0, 1),): {0: 1.0}}
            transitions = {((0, 0),): self.STATES[1], ((0, 1),): self.STATES[2]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        if state == self.STATES[1]:
            policies = {0: {0: 0.5, 1: 0.5}}
            next_state_values = {self.STATES[3]: {0: 10.0}, self.STATES[4]: {0: 10.0}}
            intermediate_rewards = {((0, 0),): {0: 1.0}, ((0, 1),): {0: 1.0}}
            transitions = {((0, 0),): self.STATES[3], ((0, 1),): self.STATES[4]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        if state == self.STATES[2]:
            policies = {0: {0: 0.5, 1: 0.5}}
            next_state_values = {self.STATES[5]: {0: 0.0}, self.STATES[6]: {0: 0.0}}
            intermediate_rewards = {((0, 0),): {0: 2.0}, ((0, 1),): {0: 0.0}}
            transitions = {((0, 0),): self.STATES[5], ((0, 1),): self.STATES[6]}
            return policies, next_state_values, intermediate_rewards, transitions, dict()
        # Leaf state: nothing to act on, so every mapping is empty (fresh
        # dicts each call so callers cannot share mutable state).
        return dict(), dict(), dict(), dict(), dict()

# Lookup table from a short string key to each test inferencer class defined
# above (presumably used to pick an inferencer by name -- confirm with callers).
INFERENCERS = dict(
    three_state=ThreeState,
    five_state=FiveState,
    four_chain=FourChain,
    seven_state=SevenState,
)