import numpy as np
import torch
from gym.spaces import Discrete
from torch.distributions import Categorical

from games.graphon_mfg import FiniteGraphonMeanFieldGame


class CyberGraphon(FiniteGraphonMeanFieldGame):
    """
    Graphon mean field game modelling the Cybersecurity game.

    Every agent carries one of four local states, indexed in the order
    DI, DS, UI, US (0..3), and picks one of two actions. The transition
    structure below treats the first letter as a protection level (D / U)
    that the action can flip, and the second letter as a health status
    (I = infected, S = susceptible) driven by recovery and infection
    probabilities.
    """

    def __init__(self, graphon=(lambda x, y: 1-max(x, y)), time_steps: int = 50, mu_0=(0.25,) * 4,
                 q_rec_D: float = 0.3, q_rec_U: float = 0.2, lambda_wait: float = 0.3, v_H: float = 0.1,
                 z_inf_D: float = 0.05, z_inf_U: float = 0.1, beta_DD: float = 0.1, beta_UD: float = 0.2,
                 beta_DU: float = 0.7, beta_UU: float = 0.8, k_D: float = 0.7, k_I: float = 2.0, **kwargs):
        """
        Store the model parameters and initialise the underlying game.

        :param graphon: callable of two arguments giving the graphon value
        :param time_steps: horizon handed to the base class
        :param mu_0: probabilities of the four states DI, DS, UI, US at time 0
        :param q_rec_D: recovery probability in the D states
        :param q_rec_U: recovery probability in the U states
        :param lambda_wait: scales the action into the probability of flipping D <-> U
        :param v_H: multiplies z_inf_D / z_inf_U to give the direct infection probability
        :param z_inf_D: direct infection factor in state DS
        :param z_inf_U: direct infection factor in state US
        :param beta_DD: weight of the DI neighbourhood mass on agents in DS
        :param beta_UD: weight of the UI neighbourhood mass on agents in DS
        :param beta_DU: weight of the DI neighbourhood mass on agents in US
        :param beta_UU: weight of the UI neighbourhood mass on agents in US
        :param k_D: per-step cost charged in states DI and DS
        :param k_I: per-step cost charged in states DI and UI
        :param kwargs: accepted but not used by this constructor
        """
        # All parameters are attached to the instance before the base
        # class constructor runs.
        self.graphon = graphon
        self.lambda_wait = lambda_wait
        self.v_H = v_H
        self.q_rec_D, self.q_rec_U = q_rec_D, q_rec_U
        self.z_inf_D, self.z_inf_U = z_inf_D, z_inf_U
        self.beta_DD, self.beta_DU = beta_DD, beta_DU
        self.beta_UD, self.beta_UU = beta_UD, beta_UU
        self.k_D, self.k_I = k_D, k_I

        # States: DI DS UI US
        def sample_from_mu_0(x):
            # The argument is ignored: every agent starts from the same mu_0.
            return Categorical(probs=torch.tensor(mu_0))

        super().__init__(Discrete(4), Discrete(2), time_steps, sample_from_mu_0, graphon)

    def transition_probs_g(self, t, x, u, g):
        """
        Return the next-state distribution of an agent.

        :param t: time index, forwarded to ``g.evaluate_integral``
        :param x: indexable agent description; ``x[1]`` is the local state
            index 0..3 (presumably ``x[0]`` is the graphon index -- TODO confirm)
        :param u: action; multiplied by ``lambda_wait`` to obtain the
            probability of flipping between the D and U levels
        :param g: object providing ``evaluate_integral(t, predicate)``
            (assumed to return the neighbourhood mass of the entries
            satisfying the predicate -- verify against the base class)
        :return: row ``x[1]`` of the 4x4 transition matrix, ordered DI, DS, UI, US
        """
        # Neighbourhood masses of the two infected states, capped at 1.
        g_DI = min(1, g.evaluate_integral(t, lambda entry: entry[1] == 0))
        g_UI = min(1, g.evaluate_integral(t, lambda entry: entry[1] == 2))

        def infection_probability(direct, w_DI, w_UI):
            # Inclusion-exclusion expansion of
            #   1 - (1 - direct) * (1 - w_DI * g_DI) * (1 - w_UI * g_UI)
            return (direct
                    + w_DI * g_DI
                    + w_UI * g_UI
                    - direct * w_DI * g_DI
                    - direct * w_UI * g_UI
                    - w_DI * w_UI * g_DI * g_UI
                    + direct * w_DI * w_UI * g_DI * g_UI)

        def row(p_to_D, p_to_U, p_to_I, p_to_S):
            # Joint distribution over (level, health) in the order DI, DS, UI, US.
            return [p_to_D * p_to_I, p_to_D * p_to_S, p_to_U * p_to_I, p_to_U * p_to_S]

        inf_D = infection_probability(self.v_H * self.z_inf_D, self.beta_DD, self.beta_UD)
        inf_U = infection_probability(self.v_H * self.z_inf_U, self.beta_DU, self.beta_UU)
        rec_D = self.q_rec_D
        rec_U = self.q_rec_U

        # Probability of changing / keeping the current D-or-U level.
        flip = u * self.lambda_wait
        keep = 1 - flip

        matrix = np.array([
            row(keep, flip, 1 - rec_D, rec_D),  # from DI: may recover
            row(keep, flip, inf_D, 1 - inf_D),  # from DS: may get infected
            row(flip, keep, 1 - rec_U, rec_U),  # from UI: may recover
            row(flip, keep, inf_U, 1 - inf_U),  # from US: may get infected
        ])

        return matrix[x[1]]

    def reward_g(self, t, x, u, g):
        """
        Return the (non-positive) reward for the current state.

        ``k_D`` is charged in states 0 and 1 (DI, DS) and ``k_I`` in states
        0 and 2 (DI, UI); ``t``, ``u`` and ``g`` do not enter the reward.
        """
        state = x[1]
        in_D = (state == 0) + (state == 1)
        in_I = (state == 0) + (state == 2)
        return - self.k_D * in_D - self.k_I * in_I
