import numpy as np
from gym.spaces import Discrete

from fast_marl_graphex import FastGraphexEnv


class SIRGraphex(FastGraphexEnv):
    """
    Models the SIR (susceptible / infected / recovered) game on a graphex graph.

    States (observation space ``Discrete(3)``):
        0 = susceptible, 1 = infected, 2 = recovered (absorbing).
    Actions (action space ``Discrete(2)``):
        0 = do nothing, 1 = protect (costs ``cost_action``; a protecting
        susceptible agent cannot be infected this step).

    A susceptible, non-protecting agent of degree ``k`` whose fraction of
    infected neighbours is ``g`` becomes infected with probability
    ``infection_rate * g * (2 / (1 + exp(-0.5 * k)) - 1)``, clipped to [0, 1].
    The degree factor equals ``tanh(k / 4)``: it is 0 for isolated agents and
    tends to 1 for high degrees, which is the limit used by ``get_P_high``.
    Infected agents recover with probability ``recovery_rate``. Each step an
    agent pays ``cost_infection`` while infected and ``cost_action`` when it
    protects.
    """

    def __init__(self, graphex='separable', infection_rate: float = 0.05, recovery_rate: float = 0.01,
                 time_steps: int = 500, initial_infection_prob: float = 0.1, cost_infection: float = 1,
                 cost_action: float = 0.25,  **kwargs):
        """
        Args:
            graphex: Graphex specification, forwarded to ``FastGraphexEnv``.
            infection_rate: Scale of the per-step infection probability (see class docstring).
            recovery_rate: Per-step probability that an infected agent recovers.
            time_steps: Number of time steps, forwarded to ``FastGraphexEnv``.
            initial_infection_prob: Probability that an agent starts infected; all other
                agents start susceptible.
            cost_infection: Per-step cost of being infected.
            cost_action: Per-step cost of protecting (action 1).
            **kwargs: Forwarded unchanged to ``FastGraphexEnv``.

        Raises:
            ValueError: If ``recovery_rate`` or ``initial_infection_prob`` is outside [0, 1];
                either would otherwise yield negative probabilities in the transition
                matrices or in the initial distribution ``mu_0``.
        """
        if not 0.0 <= recovery_rate <= 1.0:
            raise ValueError(f"recovery_rate must lie in [0, 1], got {recovery_rate}")
        if not 0.0 <= initial_infection_prob <= 1.0:
            raise ValueError(f"initial_infection_prob must lie in [0, 1], got {initial_infection_prob}")

        self.infection_rate = infection_rate
        self.recovery_rate = recovery_rate
        self.initial_infection_prob = initial_infection_prob
        self.cost_infection = cost_infection
        self.cost_action = cost_action

        observation_space = Discrete(3)
        action_space = Discrete(2)

        # Initial state distribution over (susceptible, infected, recovered).
        mu_0 = np.array([1 - initial_infection_prob, initial_infection_prob, 0])

        super().__init__(graphex, observation_space, action_space, time_steps, mu_0, **kwargs)

    def _infection_probability(self, frac_infected, degree):
        """Return the infection probability of a susceptible, non-protecting agent.

        ``frac_infected`` (fraction of infected neighbours) and ``degree`` may be scalars
        or equal-shape arrays. The result is clipped to [0, 1] so that it is a valid
        probability even when ``infection_rate > 1``. Sampling via ``rand < p`` already
        behaved as if clipped, so the transition matrices now agree with the simulation.
        """
        # 2 / (1 + exp(-0.5 k)) - 1 == tanh(k / 4): 0 at k = 0, -> 1 as k -> inf.
        degree_factor = 2 / (1 + np.exp(-0.5 * degree)) - 1
        return np.clip(self.infection_rate * frac_infected * degree_factor, 0.0, 1.0)

    def _sir_transition_matrix(self, p_infect):
        """Build the (U, X, X) tensor ``P[u, x, x_next]`` for a given infection probability."""
        n_x = self.observation_space.n
        P = np.zeros((self.action_space.n, n_x, n_x))

        P[:, 2, 2] = 1                          # recovered is absorbing
        P[:, 1, 2] = self.recovery_rate         # infected -> recovered
        P[:, 1, 1] = 1 - self.recovery_rate     # infected stays infected
        P[0, 0, 1] = p_infect                   # susceptible, not protecting, gets infected
        P[0, 0, 0] = 1 - p_infect
        P[1, 0, 0] = 1                          # protecting keeps a susceptible agent susceptible

        return P

    def _sir_reward_matrix(self):
        """Build the (X, U) array of expected per-step rewards (independent of degree and mean field)."""
        R = np.zeros((self.observation_space.n, self.action_space.n))

        R[1, :] -= self.cost_infection  # being infected is costly regardless of action
        R[:, 1] -= self.cost_action     # protecting is costly regardless of state

        return R

    def next_states(self, t, xs, us):
        """Sample every agent's next state.

        Args:
            t: Current time step (unused; the dynamics are time-homogeneous).
            xs: Array of current agent states (0/1/2), one entry per agent.
            us: Array of agent actions (0/1), same length as ``xs``.

        Returns:
            Integer array of next states, same length as ``xs``.
        """
        # NOTE(review): assumes self.degrees is a 1-D array aligned with the rows of adj_matrix.
        degrees = np.asarray(self.degrees, dtype=float)

        # Infected neighbours per agent (row-weighted by adj_matrix). A mat-vec product
        # works for scipy sparse matrices and sparse arrays alike (the former `.A` on the
        # `.sum` result only worked when it returned np.matrix).
        infected_neighbours = np.asarray(self.adj_matrix @ (xs == 1).astype(float), dtype=float).ravel()

        # Isolated agents would give 0/0 = nan (plus a RuntimeWarning). Use 0 instead; they
        # still can never be infected, because the degree factor vanishes at k = 0.
        frac_infected = np.divide(infected_neighbours, degrees,
                                  out=np.zeros_like(infected_neighbours), where=degrees > 0)

        # Keep the original draw order (recoveries first, then infections) so results
        # under a fixed NumPy seed are unchanged.
        recoveries = np.random.rand(self.num_agents) < self.recovery_rate
        infections = np.random.rand(self.num_agents) < self._infection_probability(frac_infected, degrees)

        return np.select(
            [xs == 2, xs == 1, xs == 0],
            [
                2,                                         # recovered stays recovered
                np.where(recoveries, 2, 1),                # infected may recover
                np.where((us == 0) & infections, 1, 0),   # only unprotected susceptibles get infected
            ],
            default=0,
        )

    def reward(self, t, xs, us):
        """Per-agent reward: ``-cost_infection`` if infected, minus ``cost_action * us``."""
        return -self.cost_infection * (xs == 1) - self.cost_action * us

    def get_P_k_G(self, t, k, G):
        """Return joint transition matrices over actions U on X for degree ``k`` and neighbourhood ``G``.

        NOTE(review): ``G[1]`` is used as the fraction of infected neighbours, mirroring
        ``next_states``; presumably ``G`` is the neighbourhood state distribution — confirm
        against ``FastGraphexEnv``.
        """
        return self._sir_transition_matrix(self._infection_probability(G[1], k))

    def get_R_k_G(self, t, k, G):
        """Return the X x U array of expected rewards for a given degree and neighbourhood."""
        return self._sir_reward_matrix()

    def get_P_high(self, t, mu):
        """Return joint transition matrices over actions U on X for high degrees and mean field ``mu``.

        In the high-degree limit the degree factor tends to 1 and the neighbourhood
        infected fraction is taken as ``mu[1]``.
        """
        return self._sir_transition_matrix(np.clip(self.infection_rate * mu[1], 0.0, 1.0))

    def get_R_high(self, t, mu):
        """Return the X x U array of expected rewards for high degrees and a given mean field."""
        return self._sir_reward_matrix()
