import numpy as np
from gym.spaces import Discrete

from fast_marl_graphex import FastGraphexEnv


class SISGraphex(FastGraphexEnv):
    """
    Models the SIS game.

    Every agent has a binary state (0 = susceptible, 1 = infected) and a binary
    action (0 = no protection, 1 = protect). Infected agents recover with
    probability ``recovery_rate`` regardless of their action. Unprotected
    susceptible agents become infected with probability
    ``infection_rate * (infected neighbourhood fraction) * degree_factor(k)``,
    while protected susceptible agents always stay susceptible. Being infected
    costs ``cost_infection`` per step and protecting costs ``cost_action``.

    NOTE(review): ``self.adj_matrix``, ``self.degrees`` and ``self.num_agents``
    are not set in this class; they are presumably provided by
    ``FastGraphexEnv`` -- confirm against the base class.
    """

    def __init__(self, graphex='separable', infection_rate: float = 0.2, recovery_rate: float = 0.05,
                 time_steps: int = 500, initial_infection_prob: float = 0.5, cost_infection: float = 1,
                 cost_action: float = 0.5, **kwargs):
        """
        Store the SIS parameters and initialise the underlying environment.

        :param graphex: forwarded unchanged as the first argument of the base constructor
        :param infection_rate: base probability scale for a susceptible agent becoming infected
        :param recovery_rate: per-step probability that an infected agent recovers
        :param time_steps: forwarded unchanged to the base constructor
        :param initial_infection_prob: probability mass placed on state 1 in the initial distribution
        :param cost_infection: per-step cost subtracted from the reward while in state 1
        :param cost_action: per-step cost subtracted from the reward when action 1 is taken
        :param kwargs: forwarded unchanged to the base constructor
        """
        self.infection_rate = infection_rate
        self.recovery_rate = recovery_rate
        self.initial_infection_prob = initial_infection_prob
        self.cost_infection = cost_infection
        self.cost_action = cost_action

        observation_space = Discrete(2)
        action_space = Discrete(2)

        # Initial state distribution over (susceptible, infected).
        mu_0 = np.array([1 - initial_infection_prob, initial_infection_prob])

        super().__init__(graphex, observation_space, action_space, time_steps, mu_0, **kwargs)

    @staticmethod
    def _degree_factor(k):
        """
        Degree-dependent attenuation of the infection probability.

        Mathematically equal to tanh(k / 4): 0 for degree 0 and approaching 1
        for large degrees. The explicit logistic form is kept so that results
        stay numerically identical to earlier versions. Works element-wise on
        arrays as well as on scalars.
        """
        return 2 / (1 + np.exp(- 0.5 * k)) - 1

    def _infection_prob(self, infected_fraction, k):
        """Infection probability of an unprotected susceptible agent with degree k."""
        return self.infection_rate * infected_fraction * self._degree_factor(k)

    def _transition_tensor(self, infection_prob):
        """
        Build the transition array indexed as [action, state, next state].

        :param infection_prob: probability that an unprotected susceptible agent becomes infected
        """
        P = np.zeros((self.action_space.n, self.observation_space.n, self.observation_space.n))

        # Infected agents recover independently of the chosen action.
        P[:, 1, 0] = self.recovery_rate
        P[:, 1, 1] = 1 - P[:, 1, 0]
        # Unprotected susceptible agents may get infected.
        P[0, 0, 1] = infection_prob
        P[0, 0, 0] = 1 - P[0, 0, 1]
        # Protected susceptible agents stay susceptible with certainty.
        P[1, 0, 0] = 1

        return P

    def _reward_table(self):
        """Build the reward array indexed as [state, action]; costs are additive."""
        R = np.zeros((self.observation_space.n, self.action_space.n))

        R[1, :] -= self.cost_infection
        R[:, 1] -= self.cost_action

        return R

    def next_states(self, t, xs, us):
        """
        Sample the next state of every agent.

        :param t: current time step (unused; the dynamics are time-homogeneous)
        :param xs: current agent states, 0 = susceptible, 1 = infected
        :param us: current agent actions, 0 = no protection, 1 = protect
        :return: array of next agent states
        """
        # Row sums of the adjacency matrix with column j scaled by xs[j], i.e. the number
        # of infected neighbours per agent. np.asarray(...).ravel() accepts both the
        # np.matrix returned by sparse matrices and a plain ndarray.
        infected_neighbors = np.asarray(self.adj_matrix.multiply(xs).sum(axis=1)).ravel()
        degrees = np.asarray(self.degrees)

        # Isolated agents (degree 0) get an infected fraction of 0 instead of the NaN
        # (and RuntimeWarning) that 0 / 0 would produce; they could not be infected
        # before either, so the sampled dynamics are unchanged.
        G_infected = np.divide(infected_neighbors, degrees,
                               out=np.zeros(infected_neighbors.shape, dtype=float),
                               where=degrees > 0)

        # Keep the draw order (recoveries first, then infections) for seed reproducibility.
        recoveries = np.random.rand(self.num_agents) < self.recovery_rate
        infections = np.random.rand(self.num_agents) < self._infection_prob(G_infected, degrees)

        # Infected agents stay infected unless they recover; susceptible agents become
        # infected only if unprotected (us == 0). Protected susceptible agents contribute 0.
        new_xs = xs * (1 - recoveries) \
                 + (1 - xs) * (1 - us) * infections

        return new_xs

    def reward(self, t, xs, us):
        """
        Per-agent reward: minus the infection cost while infected, minus the action cost when protecting.

        :param t: current time step (unused)
        :param xs: current agent states
        :param us: current agent actions
        :return: array of per-agent rewards
        """
        rewards = - self.cost_infection * xs - self.cost_action * us
        return rewards

    def get_P_k_G(self, t, k, G):
        """
        Transition array for an agent with degree k and neighbourhood G.

        :param t: current time step (unused)
        :param k: degree of the agent
        :param G: neighbourhood descriptor; only ``G[1]`` is read, in the place where
            ``next_states`` uses the infected fraction of the neighbourhood
            (NOTE(review): confirm G is normalised rather than a vector of counts)
        :return: array of shape (actions, states, states)
        """
        # Return joint transition matrices over actions U on X for a given degree and neighborhood
        return self._transition_tensor(self._infection_prob(G[1], k))

    def get_R_k_G(self, t, k, G):
        """
        Reward array for an agent with degree k and neighbourhood G.

        The rewards depend on neither t, k nor G.

        :return: array of shape (states, actions)
        """
        # Return array X x U of expected rewards for a given degree and neighborhood
        return self._reward_table()

    def get_P_high(self, t, mu):
        """
        Transition array for high-degree agents given the mean field mu.

        No degree factor is applied here; it tends to 1 as the degree grows.

        :param t: current time step (unused)
        :param mu: mean field over states; only ``mu[1]`` (infected mass) is read
        :return: array of shape (actions, states, states)
        """
        # Return joint transition matrices over actions U on X for high degrees and given mf
        return self._transition_tensor(self.infection_rate * mu[1])

    def get_R_high(self, t, mu):
        """
        Reward array for high-degree agents given the mean field mu.

        The rewards depend on neither t nor mu.

        :return: array of shape (states, actions)
        """
        # Return array X x U of expected rewards for high degrees and given mf
        return self._reward_table()
