import numpy as np
from gym.spaces import Discrete

from fast_marl_graphex import FastGraphexEnv


class RumorGraphex(FastGraphexEnv):
    """
    Models the Rumor game.

    Agent states (indices into the observation space):

    * ``UNINFECTED`` (0): has not received the rumor. It moves to ``DECIDING``
      with probability ``infection_rate * (fraction of SPREADING neighbours) *
      degree_factor(degree)``; otherwise it stays ``UNINFECTED``.
    * ``DECIDING`` (1): has the rumor and picks an action. Action 1 moves it to
      ``SPREADING`` and action 0 moves it to ``SILENT``.
    * ``SILENT`` (2): chose action 0. It returns to ``DECIDING`` on the next step.
    * ``SPREADING`` (3): chose action 1. It returns to ``DECIDING`` on the next
      step. Only these agents infect neighbours, and only these agents are rewarded.

    A ``SPREADING`` agent earns ``reward_infection`` times the fraction of its
    neighbours that are ``UNINFECTED``. It pays ``cost_infection`` times the
    fraction of its neighbours that are ``SILENT`` or ``SPREADING``.
    """

    UNINFECTED = 0
    DECIDING = 1
    SILENT = 2
    SPREADING = 3

    def __init__(self, graphex='separable', infection_rate: float = 0.3, time_steps: int = 50,
                 initial_infection_prob: float = 0.1, cost_infection: float = 0.8, reward_infection: float = 0.5,
                 **kwargs):
        """
        Args:
            graphex: Graphex specification, forwarded to ``FastGraphexEnv``.
            infection_rate: Multiplier on the per-step infection probability of
                an ``UNINFECTED`` agent.
            time_steps: Number of time steps, forwarded to ``FastGraphexEnv``.
            initial_infection_prob: Initial mass in ``DECIDING``. The remaining
                mass starts in ``UNINFECTED``.
            cost_infection: Weight of the penalty a spreader pays per fraction
                of already-infected (``SILENT``/``SPREADING``) neighbours.
            reward_infection: Weight of the reward a spreader earns per
                fraction of ``UNINFECTED`` neighbours.
            **kwargs: Forwarded unchanged to ``FastGraphexEnv``.
        """
        self.infection_rate = infection_rate
        self.initial_infection_prob = initial_infection_prob
        self.reward_infection = reward_infection
        self.cost_infection = cost_infection

        observation_space = Discrete(4)
        action_space = Discrete(2)

        # Initially infected agents start in DECIDING. SILENT and SPREADING start empty.
        mu_0 = np.array([1 - initial_infection_prob, initial_infection_prob]
                        + [0] * (observation_space.n - 2))

        super().__init__(graphex, observation_space, action_space, time_steps, mu_0, **kwargs)

    @staticmethod
    def _degree_factor(k):
        """Return ``2 / (1 + exp(-k/2)) - 1``: 0 at degree 0, tending to 1 as the degree grows."""
        return 2 / (1 + np.exp(-0.5 * k)) - 1

    def _neighbor_fraction(self, mask):
        """Return, per agent, the fraction of its neighbours for which ``mask`` is True.

        Agents with degree 0 get 0 instead of 0/0 = NaN. Without this guard an
        isolated agent would receive a NaN reward, because ``False * NaN`` is NaN.
        """
        counts = self.adj_matrix.multiply(mask).sum(axis=1).A.flatten()
        degrees = np.asarray(self.degrees)
        return np.divide(counts, degrees, out=np.zeros(counts.shape), where=degrees > 0)

    def _transition_matrix(self, p_infect):
        """Build ``P[u, x, x']`` for the dynamics above, given the infection probability of an UNINFECTED agent."""
        P = np.zeros((self.action_space.n, self.observation_space.n, self.observation_space.n))

        P[:, self.SPREADING, self.DECIDING] = 1
        P[:, self.SILENT, self.DECIDING] = 1
        P[1, self.DECIDING, self.SPREADING] = 1
        P[0, self.DECIDING, self.SILENT] = 1
        P[:, self.UNINFECTED, self.DECIDING] = p_infect
        P[:, self.UNINFECTED, self.UNINFECTED] = 1 - P[:, self.UNINFECTED, self.DECIDING]

        return P

    def _reward_matrix(self, frac_uninfected, frac_infected):
        """Build ``R[x, u]``. Only SPREADING agents are rewarded, and the reward is the same for both actions."""
        R = np.zeros((self.observation_space.n, self.action_space.n))

        R[self.SPREADING, :] += self.reward_infection * frac_uninfected
        R[self.SPREADING, :] -= self.cost_infection * frac_infected

        return R

    def next_states(self, t, xs, us):
        """Sample every agent's next state from current states ``xs`` and actions ``us``.

        Uses the global ``np.random`` generator. ``t`` is unused because the dynamics are time-homogeneous.
        """
        frac_spreading = self._neighbor_fraction(xs == self.SPREADING)
        p_infect = self.infection_rate * frac_spreading * self._degree_factor(self.degrees)
        infections = np.random.rand(self.num_agents) < p_infect

        # The state masks are mutually exclusive, so the sum selects exactly one term per agent.
        new_xs = (
            ((xs == self.SPREADING) | (xs == self.SILENT)) * self.DECIDING
            + (xs == self.DECIDING) * us * self.SPREADING
            + (xs == self.DECIDING) * (1 - us) * self.SILENT
            + (xs == self.UNINFECTED) * infections * self.DECIDING
        )

        return new_xs

    def reward(self, t, xs, us):
        """Return each agent's reward. It is non-zero only for SPREADING agents with at least one neighbour."""
        frac_infected = self._neighbor_fraction((xs == self.SILENT) | (xs == self.SPREADING))
        frac_uninfected = self._neighbor_fraction(xs == self.UNINFECTED)
        return (xs == self.SPREADING) * (self.reward_infection * frac_uninfected
                                         - self.cost_infection * frac_infected)

    def get_P_k_G(self, t, k, G):
        """Return joint transition matrices ``P[u, x, x']`` over actions U on X for degree ``k`` and neighbourhood ``G``.

        NOTE(review): ``G`` is indexed by state here. It is presumably the
        distribution of neighbour states; confirm against ``FastGraphexEnv``.
        """
        return self._transition_matrix(self.infection_rate * G[self.SPREADING] * self._degree_factor(k))

    def get_R_k_G(self, t, k, G):
        """Return the X x U array of expected rewards for degree ``k`` and neighbourhood ``G``."""
        return self._reward_matrix(G[self.UNINFECTED], G[self.SILENT] + G[self.SPREADING])

    def get_P_high(self, t, mu):
        """Return joint transition matrices over actions U on X for high degrees and mean field ``mu``.

        The degree factor tends to 1 as the degree grows, so it is omitted here.
        """
        return self._transition_matrix(self.infection_rate * mu[self.SPREADING])

    def get_R_high(self, t, mu):
        """Return the X x U array of expected rewards for high degrees and mean field ``mu``."""
        return self._reward_matrix(mu[self.UNINFECTED], mu[self.SILENT] + mu[self.SPREADING])
