import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import scipy.sparse
import sys

from sklearn import svm
from sklearn.linear_model import LogisticRegression


from sklearn.metrics.cluster import normalized_mutual_info_score




class ReturnValue:
    """
    Result bundle of a SIR run together with two plotting helpers.

    Attributes (stored exactly as passed to the constructor):
    state: per-node, per-time-step state table of the simulation
    S, I, R: time series for the three compartments
    Sactive, Iactive, Ractive: second set of time series for the three
        compartments (the SIR function in this file fills them with counts
        restricted to the nodes present in each snapshot)
    """

    def __init__(self, state, S, I, R, Sactive, Iactive, Ractive):
        self.state = state
        self.S, self.I, self.R = S, I, R
        self.Sactive, self.Iactive, self.Ractive = Sactive, Iactive, Ractive

    def _draw_curves(self, s_curve, i_curve, r_curve, palette):
        """
        Draw one figure with the three curves and their dashed sum.

        palette is indexed 0..3: S, I, R and the sum, in that order.
        """
        plt.figure(figsize = (12,6))
        plt.plot(s_curve, label = "S", color = palette[0])
        plt.plot(i_curve, label = "I", color = palette[1])
        plt.plot(r_curve, label = "R", color = palette[2])
        # the dashed line is the sum of the three curves; it gets no legend entry
        total = s_curve + i_curve + r_curve
        plt.plot(total, color = palette[3], linestyle = "--")
        plt.legend(fontsize = 16, loc = 1)
        plt.show()

    def plotSIRall(self,  *args, **kwargs):
        """
        Plot the S, I and R series plus their sum.

        Optional keyword:
        colors (list of 4 colours): used for S, I, R and the sum, in this
            order. Default: ["r", "g", "b", "grey"]
        """
        palette = kwargs.get('colors', ["r", "g", "b", "grey"])
        self._draw_curves(self.S, self.I, self.R, palette)

    def plotSIR(self,  *args, **kwargs):
        """
        Plot the Sactive, Iactive and Ractive series plus their sum.

        Optional keyword:
        colors (list of 4 colours): used for S, I, R and the sum, in this
            order. Default: ["r", "g", "b", "grey"]
        """
        palette = kwargs.get('colors', ["r", "g", "b", "grey"])
        self._draw_curves(self.Sactive, self.Iactive, self.Ractive, palette)

        

def SIR(G, nodes, λ, η, s0):
    """
    This function runs a SIR simulation on the temporal graph G and counts,
    for every time step, how many nodes are in each of the three states.

    Use: result = SIR(G, nodes, λ, η, s0)

    Input:
    G (list of networkx graphs), input snapshot representation of the temporal graph
    nodes (array): set of all nodes appearing in the dynamical graph
    λ (float), probability of contagion
    η (float), probability of recovery
    s0 (array), vector containing the initial condition with S = 0, I = 1, R = 2

    Output:
    result (ReturnValue) with fields
        state: the dataframe returned by SIR_model
        S, I, R: arrays of length len(G), number of nodes in state 0, 1, 2
                 at each time step, counted over all nodes
        Sactive, Iactive, Ractive: same counts, restricted at each time step t
                 to the nodes of the snapshot G[t]
    """

    state = SIR_model(G, nodes, λ, η, s0)
    T = len(G)

    def count_per_step(columns, label):
        # number of entries equal to `label` in each per-time-step column
        return np.array([np.sum(column == label) for column in columns])

    all_nodes = [state[t] for t in range(T)]

    # One .loc[rows, column] lookup per snapshot, reused for the three counts.
    # The chained form state.loc[rows][t] would first copy every time column
    # of the selected rows just to read a single one of them.
    active_nodes = [state.loc[np.array(G[t].nodes), t] for t in range(T)]

    S = count_per_step(all_nodes, 0)
    I = count_per_step(all_nodes, 1)
    R = count_per_step(all_nodes, 2)

    Sactive = count_per_step(active_nodes, 0)
    Iactive = count_per_step(active_nodes, 1)
    Ractive = count_per_step(active_nodes, 2)

    return ReturnValue(state, S, I, R, Sactive, Iactive, Ractive)


def SIR_model(G, nodes, λ, η, s0):
    """
    This function runs a SIR simulation on the unweighted temporal graph G

    Use: state = SIR_model(G, nodes, λ, η, s0)

    Input: 
    G (list of networkx graphs), input snapshot representation of the temporal graph
    nodes (array): set of all nodes appearing in the dynamical graph
    λ (float), probability of contagion
    η (float), probability of recovery  
    s0 (array), vector containing the initial condition with S = 0, I = 1, R = 2 

    Output: 
    state (pandas dataframe), state[t].loc[i] indicates the state of node i at time t. 

    Update rule (the state at time t is computed from the snapshot G[t-1]):
    - a susceptible node that is in G[t-1] and has k infected neighbours there
      becomes infected with probability 1 - (1 - λ)**k
    - a susceptible node that is not in G[t-1] stays susceptible
    - an infected node recovers with probability η
    - a recovered node stays recovered
    """

    # boolean-mask / positional indexing below needs an array, not a plain list
    nodes = np.asarray(nodes)
    n = len(nodes)
    T = len(G)

    # this records the state of each node at each time step
    state = pd.DataFrame()
    state["ID"] = nodes
    state.set_index("ID", inplace = True)
    state[0] = s0

    for t in range(1,T):

        snapshot = G[t-1]
        previous = state[t-1].to_numpy()

        # positions (in `nodes`) of the S, I and R nodes at time t-1
        s_idx = np.flatnonzero(previous == 0)
        i_idx = np.flatnonzero(previous == 1)
        r_idx = np.flatnonzero(previous == 2)

        # only the S nodes present in the snapshot can be infected
        is_active = np.isin(nodes[s_idx], list(snapshot.nodes))
        s_active_idx = s_idx[is_active]

        # number of infected neighbours of each active S node
        infected = set(nodes[i_idx].tolist())
        neighbour_I = np.array([sum(1 for i in snapshot.neighbors(j) if i in infected)
                                for j in nodes[s_active_idx]])
        p = 1-(1 - λ)**(neighbour_I)

        # The new column is built as a plain array and stored with a single
        # assignment. Writing through state[t].loc[...] is chained assignment:
        # under pandas Copy-on-Write it updates a temporary and the result is lost.
        # Inactive S nodes keep the initial value 0.
        current = np.zeros(n).astype(int)

        # update the labels of the active S nodes
        current[s_active_idx] = np.random.binomial(1, p, len(s_active_idx))

        # update the labels of the I nodes (1 = still infected, 2 = recovered)
        current[i_idx] = np.random.binomial(1,η, len(i_idx))+1

        # update the labels of the R nodes
        current[r_idx] = 2

        state[t] = current

    return state
