from abc import ABC, abstractmethod

import numpy as np
import torch
from gym import spaces
from torch.distributions import Uniform

from games.base import MeanFieldGame, ChainedTupleDistribution
from simulator.mean_fields.base import MeanField
from solver.policy.graphon_policy import DiscretizedGraphonFeedbackPolicy
import pdb

class GraphonMeanFieldGame(MeanFieldGame, ABC):
    """
    Discrete-time graphon mean field game in which every agent state is extended by alpha,
    the agent's graph index. Extended states are tuples (alpha, state).
    """

    def __init__(self, agent_observation_space, agent_action_space, time_steps, initial_state_distribution):
        """
        Builds the alpha-extended observation space and initial state distribution and hands them to the base game
        :param agent_observation_space: observation space of a single agent, without the alpha component
        :param agent_action_space: action space, passed through unchanged
        :param time_steps: time horizon, passed through unchanged
        :param initial_state_distribution: distribution of the (non-extended) initial state
        """
        # The graph index alpha is a scalar in [0, 1] and becomes the first tuple component.
        alpha_space = spaces.Box(0, 1, shape=())
        extended_observation_space = spaces.Tuple((alpha_space, agent_observation_space))

        # alpha is drawn from Uniform(0, 1) and chained with the caller's initial state distribution.
        alpha_distribution = Uniform(torch.tensor([0.]), torch.tensor([1.]))
        extended_initial_distribution = ChainedTupleDistribution(alpha_distribution, initial_state_distribution)

        super().__init__(extended_observation_space, agent_action_space, time_steps,
                         extended_initial_distribution)


class FiniteGraphonMeanFieldGame(GraphonMeanFieldGame, ABC):
    """
    Graphon mean field game over a discrete, finite agent state space.
    Extended states are tuples (alpha, state) where state is an index below agent_observation_space[1].n.
    """

    def __init__(self, agent_observation_space, agent_action_space, time_steps, initial_state_distribution):
        """
        Forwards all arguments unchanged to GraphonMeanFieldGame
        :param agent_observation_space: observation space of a single agent, without the alpha component
        :param agent_action_space: action space
        :param time_steps: time horizon
        :param initial_state_distribution: distribution of the (non-extended) initial state
        """
        super().__init__(agent_observation_space, agent_action_space, time_steps, initial_state_distribution)

    def next_state(self, t, x, u, mu):
        """
        Samples the successor of extended state x; alpha is carried over and only the state component is redrawn
        :param t: time t
        :param x: extended state x, i.e. tuple (alpha, x)
        :param u: action u
        :param mu: mean field mu
        :return: tuple (alpha, sampled next state index)
        """
        alpha = x[0]
        num_states = self.agent_observation_space[1].n
        row = self.transition_probs(t, x, u, mu)
        # NOTE(review): `replace` is passed as None; for a single draw the sampling distribution
        # is the same with or without replacement.
        drawn = np.random.choice(range(num_states), size=1, replace=None, p=row)
        # size=1 yields a one-element array; .item() unwraps it to a plain Python scalar.
        return (alpha, drawn.item())

    def observation(self, t, x, u, mu, next_state):
        """
        Returns the observation after a transition, which is the next state itself
        :param t: time t (unused)
        :param x: extended state x (unused)
        :param u: action u (unused)
        :param mu: mean field mu (unused)
        :param next_state: the extended state reached by the transition
        :return: next_state, unchanged
        """
        return next_state

    @abstractmethod
    def reward(self, t, x, u, mu):
        """
        Reward hook to be implemented by subclasses
        :param t: time t
        :param x: extended state x
        :param u: action u
        :param mu: mean field mu
        """

    @abstractmethod
    def transition_probs(self, t, x, u, mu: MeanField):
        """
        Subclasses return one row of the transition probability matrix: the distribution over next states
        when action u is taken in state x under mean field ensemble mu
        :param t: time t
        :param x: extended state x, i.e. tuple (alpha, x)
        :param u: action u
        :param mu: mean field mu
        :return: row of the transition probability matrix
        """

    def transition_probability_matrix(self, t, policy: DiscretizedGraphonFeedbackPolicy, mu: MeanField):
        """
        Computes the policy-averaged transition probability matrix under mean field ensemble mu
        :param t: time t
        :param policy: policy whose pmf(t, x) weights the actions and whose action_space.n gives the action count
        :param mu: mean field mu
        :return: array indexed by [state, next state]
        """
        # NOTE(review): policy.pmf and transition_probs are called with bare state indices here,
        # not with (alpha, x) tuples as transition_probs documents - confirm this is intended.
        state_indices = np.arange(self.agent_observation_space[1].n)
        action_indices = np.arange(policy.action_space.n)

        # One pmf per state; presumably each is a vector over actions - TODO confirm.
        action_probs = np.array([policy.pmf(t, state) for state in state_indices])

        # Nested as [action][state] -> transition row, matching the 'ijk' operand below.
        rows_per_action = []
        for action in action_indices:
            rows_per_action.append([self.transition_probs(t, state, action, mu) for state in state_indices])
        transitions = np.array(rows_per_action)

        # Sum over the action index i: result[j, k] = sum_i action_probs[j, i] * transitions[i, j, k].
        return np.einsum('ji,ijk->jk', action_probs, transitions)
