import numpy as np
import matplotlib.pyplot as plt

class CSModel:
    def __init__(
            self,
            name,
            verbose=False,
            T=10.0,
            Nt=100,
            NS=4,
            beta_UU=0.3,
            beta_UD=0.4,
            beta_DU=0.3,
            beta_DD=0.4,
            v_H=0.6,
            lambda_speed=0.8,
            q_rec_D=0.5,
            q_rec_U=0.4,
            q_inf_D=0.4,
            q_inf_U=0.3,
            k_D=0.3,
            k_I=0.5,
            states=None
    ):
        """Store the model parameters, the time grid and the plotting styles.

        Args:
            name: Base name; used as the prefix of the figure and data files
                written by ``plot_results`` and ``save_results``.
            verbose: If True, ``picard_iteration`` prints its progress.
            T: Time horizon.
            Nt: Number of time steps; the step size is ``T / Nt``.
            NS: Number of states (size of the rate matrix and of each
                row of ``mu`` / ``u``).
            beta_UU, beta_UD, beta_DU, beta_DD: Coefficients multiplying the
                mass in the ``UI`` / ``DI`` states inside the ``US -> UI`` and
                ``DS -> DI`` rates (see ``get_lambda_t``).
            v_H: Factor multiplying ``q_inf_D`` / ``q_inf_U`` in those rates.
            lambda_speed: Rate of the ``D <-> U`` transitions that are only
                active when the control equals 1.
            q_rec_D, q_rec_U: Rates of ``DI -> DS`` and ``UI -> US``.
            q_inf_D, q_inf_U: Scaled by ``v_H`` in ``DS -> DI`` / ``US -> UI``.
            k_D: Running cost added in states ``DI`` and ``DS``.
            k_I: Running cost added in states ``DI`` and ``UI``.
            states: Mapping from state label to index. When omitted, a fresh
                ``{'DI': 0, 'DS': 1, 'UI': 2, 'US': 3}`` dict is created.
        """
        self.basename = name
        self.verbose = verbose

        self.T = T
        self.Nt = Nt
        self.Dt = T / Nt
        self.NS = NS
        self.beta_UU = beta_UU
        self.beta_UD = beta_UD
        self.beta_DU = beta_DU
        self.beta_DD = beta_DD
        self.v_H = v_H
        self.lambda_speed = lambda_speed
        self.q_rec_D = q_rec_D
        self.q_rec_U = q_rec_U
        self.q_inf_D = q_inf_D
        self.q_inf_U = q_inf_U
        self.k_D = k_D
        self.k_I = k_I
        # Build the default mapping per instance: a dict literal in the
        # signature would be one object shared (and mutable) across all models.
        if states is None:
            states = {'DI': 0, 'DS': 1, 'UI': 2, 'US': 3}
        self.states = states

        # Plotting styles (one entry per state index)
        self.colors = ['black', 'red', 'green', 'blue']
        self.linestyles = ['-', '--', ':', '-.']
        self.linewidths = [2, 2, 5, 2]

    # Utility functions for cybersecurity model
    def get_index(self, str):
        """Return the index stored for state label ``str`` in ``self.states``.

        Raises:
            ValueError: If the label is not a key of ``self.states``.
        """
        # NOTE: the parameter name shadows the builtin ``str``; it is kept
        # because callers may pass it by keyword.
        if str not in self.states:
            raise ValueError(f"State {str} not found in states dictionary.")
        return self.states[str]
        
    def get_state(self, iS):
        """Return the first state label in ``self.states`` whose index equals ``iS``.

        Raises:
            ValueError: If no label maps to ``iS``.
        """
        # A private sentinel distinguishes "no match" from any legitimate key.
        not_found = object()
        label = next(
            (key for key, value in self.states.items() if value == iS),
            not_found,
        )
        if label is not_found:
            raise ValueError(f"Index {iS} not found in states dictionary.")
        return label
    
    # Transition matrix
    def get_lambda_t(self, mu_t, alpha_t):
        lambda_matrix = np.zeros((self.NS, self.NS))
        lambda_matrix[self.get_index('DI'), self.get_index('DS')] = self.q_rec_D
        lambda_matrix[self.get_index('DS'), self.get_index('DI')] = self.v_H * self.q_inf_D + self.beta_DD * mu_t[self.get_index('DI')] + self.beta_UD * mu_t[self.get_index('UI')]
        lambda_matrix[self.get_index('UI'), self.get_index('US')] = self.q_rec_U
        lambda_matrix[self.get_index('US'), self.get_index('UI')] = self.v_H * self.q_inf_U + self.beta_UU * mu_t[self.get_index('UI')] + self.beta_DU * mu_t[self.get_index('DI')]
        if alpha_t == 1:
            lambda_matrix[self.get_index('DI'), self.get_index('UI')] = self.lambda_speed
            lambda_matrix[self.get_index('DS'), self.get_index('US')] = self.lambda_speed
            lambda_matrix[self.get_index('UI'), self.get_index('DI')] = self.lambda_speed
            lambda_matrix[self.get_index('US'), self.get_index('DS')] = self.lambda_speed
        for iS in range(0, self.NS):
            lambda_matrix[iS, iS] = - np.sum(lambda_matrix[iS])
        return lambda_matrix

    # Running cost for a given state (does not depend on the control)
    def running_cost_t(self, iS):
        """Return the running cost of state index ``iS``.

        ``k_D`` is charged in states ``DI`` and ``DS``; ``k_I`` is charged in
        states ``DI`` and ``UI``. State ``DI`` therefore pays both, and a
        state matching neither pays 0.
        """
        in_d_state = iS == self.get_index('DI') or iS == self.get_index('DS')
        in_i_state = iS == self.get_index('DI') or iS == self.get_index('UI')
        return (self.k_D if in_d_state else 0) + (self.k_I if in_i_state else 0)

    # Augmented final cost
    def final_cost(self, iS, c=1.):
        """Return the terminal cost of state index ``iS``.

        The cost is ``c`` in states ``DI`` and ``UI`` and ``0.`` otherwise.
        """
        in_i_state = iS == self.get_index('DI') or iS == self.get_index('UI')
        return c if in_i_state else 0.
    
    # Hamiltonian
    def get_Hamiltonian(self, iS, mu_t, u_t, alpha_t):
        return np.matmul(self.get_lambda_t(mu_t, alpha_t)[iS], u_t) + self.running_cost_t(iS)

    # Optimal control for given state, mu and u
    def get_alphahat_t(self, iS, mu_t, u_t):
        H_0 = self.get_Hamiltonian(iS, mu_t, u_t, 0)
        H_1 = self.get_Hamiltonian(iS, mu_t, u_t, 1)
        if H_0 <= H_1:
            return 0
        else:
            return 1

    def get_alphahat_t_vec(self, mu_t, u_t):
        alphahat = np.zeros(self.NS)
        for iS in range(self.NS):
            alphahat[iS] = self.get_alphahat_t(iS, mu_t, u_t)
        return alphahat
    
    def get_q_t(self, mu_t, u_t):
        q_t = np.zeros((self.NS, self.NS))
        alphahat = self.get_alphahat_t_vec(mu_t, u_t)
        for iS in range(self.NS):
            q_t[iS] = self.get_lambda_t(mu_t, alphahat[iS])[iS]
        return q_t
    
    # KFP step of Picard iteration
    def solve_KFP(self, mu, u):
        new_mu = np.zeros((self.Nt + 1, self.NS))
        new_mu[0] = mu[0]
        for it in range(0, self.Nt):
            q_t = self.get_q_t(mu[it], u[it])
            new_mu[it + 1] = np.matmul(new_mu[it], np.eye(self.NS) + self.Dt * q_t)
        return new_mu

    # HJB step of Picard iteration
    def solve_HJB(self, mu, cost_param=1):
        new_u = np.zeros((self.Nt + 1, self.NS))
        for iS in range(self.NS):
            new_u[self.Nt][iS] = self.final_cost(iS, c=cost_param)
        for it in range(self.Nt - 1, -1, -1):
            opt_H = np.zeros(self.NS)
            for iS in range(self.NS):
                opt_H[iS] = min(self.get_Hamiltonian(iS, mu[it + 1], new_u[it + 1], 0), self.get_Hamiltonian(iS, mu[it + 1], new_u[it + 1], 1))
            new_u[it] = new_u[it + 1] + self.Dt * opt_H
        return new_u

    # Full Picard iteration algorithm
    def picard_iteration(self, mu_0, u_T, cost_param=1, tol=1e-6):
        mu_init = np.zeros((self.Nt + 1, self.NS))
        u_init = np.zeros((self.Nt + 1, self.NS))
        for it in range(self.Nt + 1):
            mu_init[it] = mu_0
            u_init[it] = u_T
        u = u_init
        mu = mu_init
        new_mu = self.solve_KFP(mu, u)
        new_u = self.solve_HJB(new_mu, cost_param=cost_param)

        i = 1
        if self.verbose:
            print(f'Iteration {i}: \t tolerance = {tol}, \t np.norm(new_mu - mu) = {np.linalg.norm(new_mu - mu)}, \t np.norm(new_u - u) = {np.linalg.norm(new_u - u)}')
        while np.linalg.norm(new_mu - mu)>tol or np.linalg.norm(new_u - u)>tol:
            u = new_u
            mu = new_mu
            new_mu = self.solve_KFP(mu, u)
            new_u = self.solve_HJB(new_mu, cost_param=cost_param)
            i += 1
            if self.verbose:
                print(f'Iteration {i}, \t tolerance = {tol}, \t np.norm(new_mu - mu) = {np.linalg.norm(new_mu - mu)}, \t np.norm(new_u - u) = {np.linalg.norm(new_u - u)}') 
        return new_mu, new_u
    
    # Auxiliary plotting code
    def plot_results(self, mu, u):
        """Plot ``mu`` and ``u`` against time, one figure each, and save them.

        Each figure has one curve per state index, styled with the entries of
        ``self.colors`` / ``self.linestyles`` / ``self.linewidths``. The
        figures are saved as ``<basename>_mu.png`` and ``<basename>_u.png``
        at 200 dpi, shown, and then closed.

        Args:
            mu: Array indexed ``[time, state]`` with ``Nt + 1`` time levels.
            u: Array with the same indexing.
        """
        time_grid = np.linspace(0, self.T, self.Nt + 1)

        # (data, math symbol used in the labels, file-name suffix)
        figures = (
            (mu, r"\mu", '_mu.png'),
            (u, "u", '_u.png'),
        )
        for values, symbol, suffix in figures:
            fig, ax = plt.subplots()
            ax.margins(0.05)
            for iS in range(self.NS):
                ax.plot(
                    time_grid,
                    values[:, iS],
                    label=rf"${symbol}({self.get_state(iS)})$",
                    color=self.colors[iS],
                    linestyle=self.linestyles[iS],
                    linewidth=self.linewidths[iS]
                )
            ax.set_xlabel('Time ($t$)')
            ax.set_ylabel(rf'${symbol}$')
            ax.legend()

            figname = self.basename + suffix
            plt.savefig(figname, dpi=200)
            plt.show()
            plt.close(fig)

    def save_results(self, mu, u):
        filename = self.basename + '_data_and_solution.npz'
        np.savez(filename, T=self.T, Nt=self.Nt, NS=self.NS, beta_UU=self.beta_UU, beta_UD=self.beta_UD, beta_DU=self.beta_DU, beta_DD=self.beta_DD, \
            v_H=self.v_H, lambda_speed=self.lambda_speed, q_rec_D=self.q_rec_D, q_rec_U=self.q_rec_U, q_inf_D=self.q_inf_D, q_inf_U=self.q_inf_U, \
            k_D=self.k_D, k_I=self.k_I, \
            mu=mu, u=u)    
    