from abc import ABC, abstractmethod
from IPython.display import display, Markdown
from typing import Any, Callable, List, Generic, TypeVar, Tuple, cast, no_type_check
from types import MethodType
from sympy.matrices import Matrix
import z3

import numpy as np
import numpy.typing
import sympy as sym

from strategy import Strategy

from mytypes import NumExpr


# Type variable standing for the state type of a Game.
S_ = TypeVar("S_")
# Type variable standing for the vector type manipulated by an Arith backend.
A = TypeVar("A")
# A distribution over one player's actions (one float per action index).
distr = List[float]


class Arith(Generic[A], ABC):
    """Minimal vector-arithmetic interface used by ``Game.mixedRewardsGen``.

    A concrete implementation fixes the vector type ``A`` and supplies
    construction from a list of entries, addition of two vectors and
    scaling of a vector by a scalar.
    """

    @abstractmethod
    def array(self, a: List[NumExpr]) -> A:
        """Build a vector of type ``A`` holding the entries of ``a``."""
        raise NotImplementedError

    @abstractmethod
    def sum(self, a: A, b: A) -> A:
        """Return the sum of vectors ``a`` and ``b``."""
        raise NotImplementedError

    @abstractmethod
    def mul(self, c: NumExpr, a: A) -> A:
        """Return vector ``a`` scaled by the scalar ``c``."""
        raise NotImplementedError
    

class ArithNumPy(Arith[numpy.typing.NDArray[np.float64]]):
    """Numeric :class:`Arith` backend whose vectors are NumPy arrays."""

    def array(self, a: List[NumExpr]) -> numpy.typing.NDArray[np.float64]:
        """Convert the list of entries ``a`` into a NumPy array."""
        values = cast(List[float], a)
        return np.array(values)

    def sum(self, a: numpy.typing.NDArray[np.float64], b: numpy.typing.NDArray[np.float64]) -> numpy.typing.NDArray[np.float64]:
        """Element-wise addition of two arrays."""
        total = a + b
        return total

    def mul(self, c: NumExpr, a: numpy.typing.NDArray[np.float64]) -> numpy.typing.NDArray[np.float64]:
        """Scale array ``a`` by the scalar ``c``."""
        factor = cast(float, c)
        return factor * a

# Short alias for SymPy's Matrix class; ArithSymPy uses it as its vector type.
M = sym.Matrix

# Operations on column vectors
class ArithSymPy(Arith[M]):
    """Symbolic :class:`Arith` backend whose vectors are SymPy column matrices."""

    def array(self, a: List[NumExpr]) -> M:
        """Build an ``n x 1`` column matrix from the entries of ``a``."""
        # One single-element row per entry gives a column vector.
        return sym.Matrix([[a_] for a_ in a])

    def sum(self, a: M, b: M) -> M:
        """Element-wise addition of two column matrices."""
        return a + b

    def mul(self, c: NumExpr, a: M) -> M:
        """Scale the column matrix ``a`` by the scalar ``c``.

        ``c`` is passed through ``sympify`` so that plain Python numbers,
        such as the float probabilities built in ``Game.mixedRewardsGen``,
        become SymPy scalars before the multiplication.
        """
        # Fix: the original multiplied by the class ``M`` (sym.Matrix) rather
        # than by the vector argument ``a``.
        o: M = sym.sympify(c) * a
        return o

# Shared, stateless backend instances used by Game.mixedRewards (numeric)
# and Game.mixedRewardsSymbolic (symbolic).
arithnumpy, arithsympy = ArithNumPy(), ArithSymPy()

class Game(Generic[S_], ABC):
    """Abstract simultaneous-move game with ``maxPlayer`` players and states of type ``S_``.

    Subclasses implement :meth:`rewards` (next state and per-player rewards
    for a joint pure action) and :meth:`maxAction` (how many actions a
    player has in a state; valid actions are ``0 .. maxAction - 1``).
    ``recordActions`` / ``recordRewards`` are no-op hooks that
    ``enableLogging`` / ``enableSupport`` rebind on a given instance.
    """

    # NOTE: these list defaults live on the class, so they are shared by every
    # instance that has not rebound them; enableLogging() assigns fresh
    # per-instance lists before recording anything.
    savedActions: List[List[int]] = []
    savedSupport: Any
    savedRewards: List[List[float]] = []
    maxPlayer: int = -1

    def __init__(self, maxPlayer: int):
        self.maxPlayer = maxPlayer

    @abstractmethod
    def rewards(self, state: S_, actions: List[int]) -> Tuple[S_, List[NumExpr]]:
        """Play the joint pure action ``actions`` in ``state``.

        Returns the next state and one reward per player.
        """
        raise NotImplementedError

    @abstractmethod
    def maxAction(self, state: S_, player: int) -> int:
        """Return the number of actions available to ``player`` in ``state``."""
        raise NotImplementedError

    def mixedRewardsGen(self, state: S_, actions: List[List[float]], arith: Arith[A]) -> A:
        """Expected reward vector in ``state`` when every player mixes.

        ``actions[j][a]`` is the weight (probability) player ``j`` puts on
        action ``a``.  The expectation is accumulated through ``arith`` so the
        same code serves numeric and symbolic evaluation.
        """
        sizes = [self.maxAction(state, j) for j in range(self.maxPlayer)]
        total = 1
        for size in sizes:
            total *= size

        r: A = arith.array([0.] * self.maxPlayer)
        # Decode ``index`` as a mixed-radix number whose j-th digit (player 0
        # least significant) ranges over player j's own actions.  This visits
        # exactly the valid joint actions, in the same order as the previous
        # "max-radix then filter" loop, with exact integer division.
        for index in range(total):
            pureActions: List[int] = []
            rest = index
            for size in sizes:
                pureActions.append(rest % size)
                rest //= size
            proba = 1.
            for j, a in enumerate(pureActions):
                proba = proba * actions[j][a]
            rew = arith.array(self.rewards(state, pureActions)[1])
            r = arith.sum(r, arith.mul(proba, rew))
        return r

    def mixedRewards(self, state: S_, actions: List[List[float]]) -> numpy.typing.NDArray[np.float64]:
        """Numeric (NumPy) expected rewards; see :meth:`mixedRewardsGen`."""
        return self.mixedRewardsGen(state, actions, arithnumpy)

    def mixedRewardsSymbolic(self, state: S_, actions: List[List[float]]) -> Matrix:
        """Symbolic (SymPy column matrix) expected rewards; see :meth:`mixedRewardsGen`."""
        return self.mixedRewardsGen(state, actions, arithsympy)

    def recordActions(self, actions: List[int]) -> None:
        """Hook called with each joint action played; no-op by default."""
        pass

    def recordRewards(self, rewards: List[NumExpr]) -> None:
        """Hook called with each reward vector obtained; no-op by default."""
        pass

    def play(self, state: S_, players: list[Strategy[S_]]) -> S_:
        """Play one round and return only the next state."""
        return self.playWithRewards(state, players)[0]

    def playWithRewards(self, state: S_, players: list[Strategy[S_]]) -> Tuple[S_, List[NumExpr]]:
        """Play one round with ``players`` and return ``(next_state, rewards)``.

        Each strategy picks an action, the round is recorded through the
        hooks, and each strategy is informed of its own action and reward.
        """
        # slightly faster, but maybe not a good idea…
        # NOTE(review): this branch hands strategies the joint action as a
        # tuple, while the general branch hands them a list — kept as-is since
        # strategies may rely on it (e.g. a tuple is hashable).
        if self.maxPlayer == 2:
            p1, p2 = tuple(players)
            actions = (p1.choice(state), p2.choice(state))
            self.recordActions(list(actions))
            new_state, rewards = self.rewards(state, list(actions))
            self.recordRewards(rewards)
            p1.saveAction(state, actions[0])
            p1.getInformedReward(state, new_state, 0, actions, cast(float, rewards[0]))
            p2.saveAction(state, actions[1])
            p2.getInformedReward(state, new_state, 1, actions, cast(float, rewards[1]))
            return new_state, rewards
        else:
            actions = [p.choice(state) for p in players]
            self.recordActions(actions)
            new_state, rewards = self.rewards(state, actions)
            self.recordRewards(rewards)
            for i, (player, action, reward) in enumerate(zip(players, actions, rewards)):
                player.saveAction(state, action)
                player.getInformedReward(state, new_state, i, actions, cast(float, reward))
            return new_state, rewards

    def repeatPlay(self, state: S_, players: list[Strategy[S_]], n: int) -> S_:
        """Play ``n`` consecutive rounds from ``state`` and return the final state."""
        for _ in range(n):
            state = self.play(state, players)
        return state

    def averageActions(self) -> List[distr]:
        """Empirical action frequencies of each player over ``savedActions``.

        For each player, the result has one entry per action index from 0 up
        to the largest index that player was recorded playing.
        """
        playersAction: List[List[int]] = list(
            zip(*self.savedActions))  # type: ignore
        playersMax = [int(max(a))+1 for a in playersAction]
        playersActionMax: List[Tuple[List[int], int]] = list(
            zip(playersAction, playersMax))
        # Average of one-hot rows == relative frequency of each action.
        playersAverageAction = [np.average(np.array(
            [[0]*a + [1] + [0]*(m-a-1) for a in pa]), axis=0) for (pa, m) in playersActionMax]
        return playersAverageAction


def showStrategies(g: Game[None], model: z3.ModelRef):
    """Render one Markdown line per player from the z3 ``model``.

    For player ``j`` the line lists the model values of ``m{j}_{a}`` for each
    action ``a`` (presumably the mixed-strategy weights — confirm against the
    constraint builder), followed by the value of ``p{j}``.
    """
    for player in range(g.maxPlayer):
        actionCount: int = g.maxAction(None, player)
        weights = ", ".join(
            str(model[z3.Real(f"m{player}_{action}")]) for action in range(actionCount)
        )
        value = model[z3.Real(f"p{player}")]
        display(Markdown(f"player {player} ({weights}) : {value!s}"))


def doRecordActions(self: Game[Any], actions: List[int]) -> None:
    """Logging replacement for ``Game.recordActions``: append to ``savedActions``."""
    log = self.savedActions
    log.append(actions)


def doRecordRewards(self: Game[Any], rewards: List[float]) -> None:
    """Logging replacement for ``Game.recordRewards``: append to ``savedRewards``."""
    log = self.savedRewards
    log.append(rewards)

def doRecordActionsSupport(self: Game[Any], actions: List[int]) -> None:
    """Replacement for ``Game.recordActions`` counting joint actions of players 0 and 1.

    Increments cell ``(actions[0], actions[1])`` of ``self.savedSupport``.
    """
    row = actions[0]
    col = actions[1]
    self.savedSupport[row, col] += 1

@no_type_check
def enableLogging(g: Game[Any]) -> None:
    """Make ``g`` record every joint action and reward vector it produces.

    Gives ``g`` fresh per-instance ``savedActions`` / ``savedRewards`` lists
    and binds the appending hooks as its ``recordActions`` / ``recordRewards``.
    """
    g.savedActions = []
    g.savedRewards = []
    g.recordActions = MethodType(doRecordActions, g)
    g.recordRewards = MethodType(doRecordRewards, g)

"""
Only for two-players actions
"""
@no_type_check
def enableSupport(g: Game[Any], maxactp1:int, maxactp2:int) -> None:
    funcType = MethodType
    g.savedSupport = np.zeros([maxactp1, maxactp2])
    g.recordActions = funcType(doRecordActionsSupport, g)

# Type variable for the inner game's state in the recall wrappers below.
S = TypeVar("S")
# Type variable for the remembered (recalled) information.
R = TypeVar("R")
class ImperfectRecall(Game[Tuple[S, R]]):
    """Wrap game ``g`` so each state carries an extra memory value.

    A wrapped state is ``(inner_state, memory)``.  After each round the memory
    is replaced by ``f(previous_inner_state, actions, rewards)``, while ``g``
    itself only ever sees ``inner_state``.
    """

    def __init__(self, g: Game[S], f: Callable[[S, List[int], List[NumExpr]], R]):
        super().__init__(g.maxPlayer)
        self.g = g
        self.recall = f

    def rewards(self, state: Tuple[S, R], actions: List[int]) -> Tuple[Tuple[S, R], List[NumExpr]]:
        inner = state[0]
        nextInner, payoff = self.g.rewards(inner, actions)
        memory = self.recall(inner, actions, payoff)
        return (nextInner, memory), payoff

    def maxAction(self, state: Tuple[S, R], player: int) -> int:
        inner = state[0]
        return self.g.maxAction(inner, player)

class ImperfectRecall2(Game[R]):
    """Wrap game ``g`` so that its state is replaced by a recalled value.

    The wrapper's state is whatever ``f(next_inner_state, actions, rewards)``
    returns.  ``finv`` maps such a value back to an inner state of ``g``
    before each round and before each ``maxAction`` query.
    """

    def __init__(self, g: Game[S], f: Callable[[S, List[int], List[NumExpr]], R], finv: Callable[[R], S]):
        super().__init__(g.maxPlayer)
        self.g = g
        self.recall = f
        self.invrecall = finv

    def rewards(self, state: R, actions: List[int]) -> Tuple[R, List[NumExpr]]:
        innerState = self.invrecall(state)
        nextInner, payoff = self.g.rewards(innerState, actions)
        memory = self.recall(nextInner, actions, payoff)
        return memory, payoff

    def maxAction(self, state: R, player: int) -> int:
        innerState = self.invrecall(state)
        return self.g.maxAction(innerState, player)

def OneRecall(g: Game[None]) -> Game[List[int]]:
    """Wrap stateless game ``g`` so its state is the last joint action played."""
    def rememberJoint(_state, joint, _rewards):
        return joint

    def toInnerState(_memory):
        return None

    return ImperfectRecall2(g, rememberJoint, toInnerState)

def OneRecallInt(g: Game[None], base: int = 2) -> Game[int]:
    """Wrap stateless game ``g`` so its state is the last joint action, encoded as an int.

    The actions ``(a0, a1)`` of players 0 and 1 are encoded as
    ``a0 * base + a1``; actions of any further players are not encoded.
    The encoding is collision-free only when player 1 has at most ``base``
    actions.  The default ``base=2`` reproduces the historical encoding,
    which suits 2x2 games; pass a larger ``base`` for games with more actions.
    """
    return ImperfectRecall2(g, lambda x, y, z: y[0] * base + y[1], lambda _: None)

class NRecall(Game[list[int]]):
    """Wrap stateless game ``g`` so the state is a sliding window of past actions.

    The state is a flat list of actions.  Each round drops its first
    ``maxPlayer`` entries and appends the new joint action, so the window
    length stays fixed as long as each joint action has ``maxPlayer`` entries.
    """

    def __init__(self, g: Game[None]):
        self.g = g
        super().__init__(g.maxPlayer)

    def rewards(self, state: list[int], actions: List[int]) -> Tuple[list[int], List[NumExpr]]:
        window = state[self.g.maxPlayer:] + actions
        payoff = self.g.rewards(None, actions)[1]
        return window, payoff

    def maxAction(self, state: list[int], player: int) -> int:
        # The window does not affect the inner game's action sets.
        return self.g.maxAction(None, player)