from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, TypeVar, Generic, Callable, Union, cast
import numpy as np

from mytypes import NumExpr




# Type variable for the state observed by a strategy (declared contravariant).
S = TypeVar('S', contravariant=True)
class Strategy(Generic[S], ABC):
    """Abstract base class for a player's strategy over states of type ``S``.

    Subclasses implement action selection (``choice`` / ``mixedAction``),
    the learning hook ``getReward`` and ``description``.

    The base class also offers optional bookkeeping of the actions taken in
    each state:
    - call ``enableSave`` first;
    - ``saveAction`` then records one play;
    - ``savedActions`` reports the per-state empirical action frequencies.
    """

    # None while recording is disabled; otherwise maps
    # state -> (total number of recorded plays, {action index: count}).
    _saved_actions: Union[None, Dict[S, Tuple[int, Dict[int, int]]]] = None

    @abstractmethod
    def choice(self, state: S) -> int:
        """Return the action (as an int) to play in ``state``."""
        raise NotImplementedError

    @abstractmethod
    def mixedAction(self, state: S) -> List[float]:
        """Return the list of action probabilities used in ``state``."""
        raise NotImplementedError

    @abstractmethod
    def getReward(self, state: S, next_state: S, action:int, reward:float) -> None:
        """Learning hook; receives ``state``, ``next_state``, the ``action`` played and its ``reward``."""
        raise NotImplementedError

    @abstractmethod
    def description(self) -> str:
        """Return a human-readable description of the strategy."""
        raise NotImplementedError

    def saveAction(self, state: S, action: int) -> None:
        """Count one play of ``action`` in ``state``; does nothing unless saving is enabled."""
        log = self._saved_actions
        if log is None:
            return
        # Create an empty (0, {}) record the first time a state is seen.
        total, counts = log.setdefault(state, (0, {}))
        counts[action] = counts.get(action, 0) + 1
        log[state] = (total + 1, counts)


    def savedActions(self) -> Dict[S, Any]:
        """Return, per recorded state, a numpy array of action frequencies.

        Entry ``i`` of each array is the fraction of recorded plays in that
        state that chose action ``i``. The arrays run from action 0 up to the
        largest action seen in that state.

        Raises:
            Exception: if ``enableSave`` has not been called.
        """
        if self._saved_actions is None:
            raise Exception("save not enabled")
        frequencies: Dict[S, Any] = {}
        for state, (total, counts) in self._saved_actions.items():
            dense_counts = [counts.get(i, 0) for i in range(max(counts) + 1)]
            frequencies[state] = np.array(dense_counts) / total
        return frequencies

    def enableSave(self):
        """Start recording actions, discarding anything recorded so far."""
        self._saved_actions = {}

    def getInformedReward(self, state:S, next_state: S, player: int, actions: List[int], reward:float):
        """Reward hook that sees every player's action.

        By default it forwards only this player's own action
        (``actions[player]``) to ``getReward``.
        """
        own_action = actions[player]
        return self.getReward(state, next_state, own_action, reward)

class ObliviousStrategy(Strategy[S]):
    """Non-learning strategy driven by a step counter and the current state.

    The mixed action in ``state`` is ``f(t, state)``, where ``t`` counts the
    ``getReward`` calls received so far. Rewards themselves are ignored.
    """

    def __init__(self, f: Callable[[int, S], List[float]]) -> None:
        self.f = f
        self.t = 0
        super().__init__()

    def choice(self, state: S) -> int:
        """Sample an action index according to ``mixedAction(state)``."""
        distr = self.mixedAction(state)
        # Passing an int population samples from range(len(distr)). Wrapping in
        # int() guarantees a plain Python int rather than a numpy integer, as
        # the signature promises.
        return int(np.random.choice(len(distr), p=distr))

    def mixedAction(self, state: S) -> List[float]:
        """Return ``f(t, state)`` for the current step counter ``t``."""
        return self.f(self.t, state)

    def getReward(self, state:S, next_state: S, action: int, reward: float) -> None:
        """Advance the step counter; no learning (oblivious strategy)."""
        self.t += 1

    def description(self) -> str:
        return "Oblivious " + str(self.f)

class StationaryStrategy(ObliviousStrategy[S]):
    """Time-independent mixed strategy given by a fixed state -> distribution table.

    ``mixedAction(state)`` always returns ``s[state]``, whatever the step
    counter is.
    """

    def __init__(self, s: Dict[S, List[NumExpr]]):
        self.stationary = s
        # ``cast`` is only a type-checker hint: the table entries are returned
        # unchanged, so NumExpr values are not converted to float here.
        super().__init__(lambda _, state: cast(List[float], s[state]))

    def description(self) -> str:
        return "stationary"

    def __repr__(self) -> str:
        """Show just the distribution when there is a single state, else the whole table."""
        ks:List[S] = list(self.stationary.keys())
        if len(ks) == 1:
            return self.stationary[ks[0]].__repr__()
        else:
            return self.stationary.__repr__()

    # NOTE(review): overriding __eq__ without __hash__ makes instances
    # unhashable (Python sets __hash__ to None) — confirm this is intended.
    def __eq__(self, o: object) -> bool:
        """Compare by strategy table; fall back to default equality if that fails."""
        # FIXME: for some reason it is not possible to check that o is a StationaryStrategy ??
        # Hence the duck-typed comparison. A missing ``stationary`` attribute,
        # or a table comparison that raises, falls back to the inherited
        # equality. Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt and SystemExit are not swallowed.
        try:
            return o.stationary == self.stationary # type: ignore
        except Exception:
            return super().__eq__(o)

class EncodedStrategy(Strategy[S]):
    """Adapter letting a ``Strategy[int]`` play on states of another type.

    Every state is passed through ``state_encode`` before the call is
    forwarded to the wrapped strategy ``s``. Actions and rewards are passed
    through unchanged.
    """

    def __init__(self, s: Strategy[int], state_encode: Callable[[S], int]) -> None:
        self.s = s
        self.state_encode = state_encode
        super().__init__()

    def choice(self, state: S) -> int:
        return self.s.choice(self.state_encode(state))

    def description(self) -> str:
        return self.s.description()

    def mixedAction(self, state: S) -> List[float]:
        return self.s.mixedAction(self.state_encode(state))

    def getReward(self, state: S, next_state: S, action: int, reward: float) -> None:
        return self.s.getReward(self.state_encode(state), self.state_encode(next_state), action, reward)

    def getRewardSarsa(self, state: S, next_state: S, action: int, next_action: int, reward: float) -> None:
        """Forward a SARSA-style update (which also carries ``next_action``)."""
        encoded_state = self.state_encode(state)
        encoded_next = self.state_encode(next_state)
        # Delegate to the wrapped strategy's own SARSA hook when it has one.
        # Strategy.getReward takes only four arguments, so passing it five
        # would raise a TypeError.
        sarsa_update = getattr(self.s, "getRewardSarsa", None)
        if sarsa_update is not None:
            return sarsa_update(encoded_state, encoded_next, action, next_action, reward)
        # Fallback, for wrapped strategies whose getReward accepts next_action.
        return self.s.getReward(encoded_state, encoded_next, action, next_action, reward)  # type: ignore

def stateEncode(maxInt: int) -> Callable[[List[int]], int]:
    """Return an encoder that packs a list of small ints into a single int.

    The list is read as a base-``maxInt`` number whose *first* element is the
    least significant digit::

        enc(s) == s[0] + s[1]*maxInt + s[2]*maxInt**2 + ...

    ``StateDecode(enc(s), maxInt, len(s))`` recovers ``s`` when every entry
    lies in ``[0, maxInt)``. Lists of any length are accepted, and the length
    may vary between calls.
    """
    def enc(s: List[int]) -> int:
        code = 0
        # Horner's rule, from the most significant digit down. It uses exact
        # Python integer arithmetic (no int64 overflow), and no weights are
        # cached for one particular list length.
        for digit in reversed(s):
            code = code * maxInt + digit
        return int(code)
    return enc

"""
Faster than stateEncode, but only valid for list of length 2
"""
def stateEncode2(maxInt: int) -> Callable[[List[int]], int]:
    def enc(s: List[int]) -> int:
        return s[0]*maxInt + s[1]
    return enc

def stateEncodeNone(l:Any) -> int:
    """Trivial encoder: map every state, whatever its value, to code 0."""
    encoded_state = 0  # all states collapse onto one shared code
    return encoded_state

def StateDecode(s: int, maxInt: int, length:int) -> list[int]:
    '''
    Decode an integer back into its list of base-``maxInt`` digits.

    Digits are produced least significant first, so this inverts
    ``stateEncode``. The result is right-padded with zeros up to ``length``.
    If ``s`` needs more than ``length`` digits, the list is longer than
    ``length``; it is never truncated.

    s : the non-negative integer to decode
    maxInt : the base, at least 2
    length : the minimum length of the list to be returned

    Raises ValueError if ``maxInt < 2`` or ``s < 0``.
    '''
    # Base 0 would divide by zero and base 1 would never terminate.
    if maxInt < 2:
        raise ValueError(f"maxInt must be at least 2, got {maxInt}")
    # A negative code would otherwise decode silently to all zeros.
    if s < 0:
        raise ValueError(f"cannot decode a negative state: {s}")
    digits: list[int] = []
    while s > 0:
        s, digit = divmod(s, maxInt)
        digits.append(digit)
    return digits + [0] * (length - len(digits))
