import numpy as np
import matplotlib.pyplot as plt

class QuadModel:
    """Finite-state mean-field game with a quadratic running cost.

    The model has ``d`` states on a uniform time grid of ``Nt`` steps over
    ``[0, T]``. An equilibrium is sought by Picard iteration, which alternates
    two explicit-Euler solves:

    - a forward Kolmogorov-Fokker-Planck (KFP) solve for the state
      distribution ``mu``;
    - a backward Hamilton-Jacobi-Bellman (HJB) solve for the value ``u``.

    NOTE(review): the cost, Hamiltonian and optimal control only appear to be
    mutually consistent for the default parameters. ``running_cost`` and
    ``get_Hamiltonian`` hard-code the target rate 2 (not ``self.a``).
    ``get_gamma_hat`` divides by ``(a_u - a_l) * b``, whereas minimising
    ``a * p + b * (a - 2)**2`` over ``a`` gives ``2 - p / (2 * b)``.
    ``get_Hamiltonian`` evaluates that unclipped minimiser even though the
    control is clipped to ``[a_l, a_u]``. Confirm intent before changing
    ``a``, ``a_l`` or ``a_u``.
    """

    def __init__(
            self,
            name,
            d,
            verbose=False,
            T=1.0,
            Nt=200,
            a=2,
            a_l=1,
            a_u=3,
            b=4
    ):
        """
        Args:
            name: Prefix for the figure files written by ``plot_results``.
            d: Number of states.
            verbose: If True, ``picard_iteration`` prints per-iteration residuals.
            T: Time horizon.
            Nt: Number of time steps; the time grid has ``Nt + 1`` points.
            a: Centre of the optimal-control formula in ``get_gamma_hat``.
            a_l: Lower bound the optimal control is clipped to.
            a_u: Upper bound the optimal control is clipped to.
            b: Weight of the quadratic running cost.
        """
        self.basename = name
        self.verbose = verbose

        self.T = T
        self.Nt = Nt
        self.Dt = T / Nt
        self.d = d
        self.a = a
        self.a_l = a_l
        self.a_u = a_u
        self.b = b
        self.states = np.arange(d)

        # Plotting styles; indexed modulo their length, so d may exceed 10.
        self.colors = ['black', 'red', 'green', 'blue', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive']
        self.linestyles = ['-', '--', ':', '-.', (0, (1, 1)), (0, (5, 1)), (0, (3, 5, 1, 5)), (0, (5, 10)), (0, (1, 10)), (0, (3, 1, 1, 1))]
        self.linewidths = [2, 2, 5, 2, 2, 2, 2, 2, 2, 2]

    # Running cost f
    def running_cost(self, x, a):
        """Return ``b * sum_{y != x} (a[y] - 2)**2``; the entry ``a[x]`` is excluded."""
        return self.b * (np.sum(np.square(a - 2)) - (a[x] - 2)**2)

    # Mean field cost F
    def mean_field_cost(self, x, eta):
        """Return the mean-field cost ``F(x, eta) = eta[x]``."""
        return eta[x]

    # Terminal cost g
    def final_cost(self, x, eta, theta):
        """Return the terminal cost ``g(x, eta) = theta[x] + eta[x]``."""
        return theta[x] + eta[x]

    # Hamiltonian H
    def get_Hamiltonian(self, x, p, eta):
        """Return the Hamiltonian at state ``x`` plus the mean-field cost.

        Args:
            x: Current state index.
            p: Sequence of length ``d``; the entry at index ``x`` is ignored.
                It is NOT modified.
            eta: Distribution passed to ``mean_field_cost``.

        Returns:
            ``sum_{y != x} (2 * p[y] - p[y]**2 / (4 * b)) + F(x, eta)``.
        """
        # Work on a copy: the previous alias (``mask_p = p``) zeroed the
        # caller's array in place.
        mask_p = np.array(p, dtype=float)
        mask_p[x] = 0
        H = np.sum(2 * mask_p - mask_p**2 / (4 * self.b))
        return H + self.mean_field_cost(x, eta)

    # Optimal control
    def get_gamma_hat(self, p):
        """Return the clipped optimal control, elementwise in ``p``.

        The value is ``clip(a - p / ((a_u - a_l) * b), a_l, a_u)``. ``p`` may
        be a scalar or an array of any shape.
        """
        return np.clip(-(p / ( (self.a_u - self.a_l) * self.b)) + self.a, self.a_l, self.a_u)

    # Transition rate from optimal control
    def get_q_t(self, mu_t, u_t):
        """Return the transition-rate matrix induced by the optimal control.

        Args:
            mu_t: Unused; kept for interface compatibility.
            u_t: Value function at one time step, length ``d``.

        Returns:
            A ``(d, d)`` float array ``q`` where:

            - ``q[x, y] = gamma_hat(u_t[y] - u_t[x])`` for ``y != x``;
            - ``q[x, x] = -sum_{y != x} q[x, y]``.

            Every row therefore sums to zero, so the KFP step conserves total
            probability mass.
        """
        u_t = np.asarray(u_t, dtype=float)
        # diffs[x, y] = u_t[y] - u_t[x]
        diffs = u_t[np.newaxis, :] - u_t[:, np.newaxis]
        q_t = self.get_gamma_hat(diffs)
        # gamma_hat evaluated at y == x is not a jump rate. Drop it before
        # forming the diagonal. Previously it was included, which gave each
        # row a sum of -gamma_hat(0) and leaked mass at every time step.
        np.fill_diagonal(q_t, 0.0)
        np.fill_diagonal(q_t, -q_t.sum(axis=1))
        return q_t

    # KFP step of Picard iteration
    def solve_KFP(self, mu, u):
        """Run the forward KFP step of the Picard iteration (explicit Euler).

        Args:
            mu: ``(Nt + 1, d)`` previous distribution iterate. ``mu[0]`` is
                the initial condition; row ``it`` is forwarded to ``get_q_t``.
            u: ``(Nt + 1, d)`` value-function iterate used to build the rates.

        Returns:
            ``(Nt + 1, d)`` array satisfying
            ``new_mu[it + 1] = new_mu[it] @ (I + Dt * q_t)``.
            Entries stay nonnegative only while ``1 + Dt * q_t[x, x] >= 0``.
        """
        new_mu = np.zeros((self.Nt + 1, self.d))
        new_mu[0] = mu[0]
        identity = np.eye(self.d)
        for it in range(self.Nt):
            q_t = self.get_q_t(mu[it], u[it])
            new_mu[it + 1] = new_mu[it] @ (identity + self.Dt * q_t)
        return new_mu

    # HJB step of Picard iteration
    def solve_HJB(self, mu, theta):
        """Run the backward HJB step of the Picard iteration (explicit Euler).

        Args:
            mu: ``(Nt + 1, d)`` distribution fed to the terminal and
                mean-field costs.
            theta: Per-state terminal-cost offsets passed to ``final_cost``.

        Returns:
            ``(Nt + 1, d)`` array with:

            - ``new_u[Nt, x] = final_cost(x, mu[-1], theta)``;
            - ``new_u[it, x] = new_u[it + 1, x]
              + Dt * H(x, new_u[it + 1] - new_u[it + 1, x], mu[it + 1])``.
        """
        new_u = np.zeros((self.Nt + 1, self.d))
        for x in range(self.d):
            new_u[self.Nt][x] = self.final_cost(x, eta=mu[-1], theta=theta)
        for it in range(self.Nt - 1, -1, -1):
            u_next = new_u[it + 1]
            opt_H = np.array([
                self.get_Hamiltonian(x, u_next - u_next[x], mu[it + 1])
                for x in range(self.d)
            ])
            new_u[it] = u_next + self.Dt * opt_H
        return new_u

    # Full Picard iteration algorithm
    def picard_iteration(self, mu_0, u_T, theta, tol=1e-6, max_iter=None):
        """Alternate KFP and HJB solves until both iterates stop changing.

        Args:
            mu_0: Initial distribution. It is broadcast to every time step as
                the starting guess for ``mu``.
            u_T: Value broadcast to every time step as the starting guess for
                ``u``. Note that ``solve_HJB`` sets its own terminal value from
                ``theta``.
            theta: Terminal-cost offsets passed to ``solve_HJB``.
            tol: Stop once the Frobenius norms of both updates are ``<= tol``.
            max_iter: Optional cap on the number of KFP/HJB passes. ``None``
                (the default) means no cap. Picard iteration is not guaranteed
                to converge, and without a cap it may loop forever. When the
                cap is hit, a ``RuntimeWarning`` is issued and the latest
                iterates are returned.

        Returns:
            ``(mu, u)``, each of shape ``(Nt + 1, d)``.
        """
        import warnings

        shape = (self.Nt + 1, self.d)
        mu = np.broadcast_to(np.asarray(mu_0, dtype=float), shape).copy()
        u = np.broadcast_to(np.asarray(u_T, dtype=float), shape).copy()
        new_mu = self.solve_KFP(mu, u)
        new_u = self.solve_HJB(new_mu, theta=theta)

        i = 1
        while True:
            err_mu = np.linalg.norm(new_mu - mu)
            err_u = np.linalg.norm(new_u - u)
            if self.verbose:
                # The first line uses ':' after the counter and later lines
                # use ','. This is kept verbatim from the existing output format.
                sep = ':' if i == 1 else ','
                print(f'Iteration {i}{sep} \t tolerance = {tol}, \t np.norm(new_mu - mu) = {err_mu}, \t np.norm(new_u - u) = {err_u}')
            # Phrased as "not (... > tol)" so NaN residuals stop the loop, as before.
            if not (err_mu > tol or err_u > tol):
                break
            if max_iter is not None and i >= max_iter:
                warnings.warn(
                    f'picard_iteration stopped after {i} iterations without '
                    f'reaching tol={tol}',
                    RuntimeWarning,
                    stacklevel=2,
                )
                break
            u, mu = new_u, new_mu
            new_mu = self.solve_KFP(mu, u)
            new_u = self.solve_HJB(new_mu, theta=theta)
            i += 1
        return new_mu, new_u

    # Auxiliary plotting code
    def _plot_series(self, values, symbol, suffix):
        """Plot each state's column of ``values`` and save it as ``<basename>_<suffix>.png``."""
        fig, ax = plt.subplots()
        ax.margins(0.05)
        times = np.linspace(0, self.T, self.Nt + 1)
        for x in range(self.d):
            ax.plot(
                times,
                values[:, x],
                label=rf"${symbol}({x})$",
                color=self.colors[x % len(self.colors)],
                linestyle=self.linestyles[x % len(self.linestyles)],
                linewidth=self.linewidths[x % len(self.linewidths)],
            )
        ax.set_xlabel('Time ($t$)')
        ax.set_ylabel(f'${symbol}$')
        ax.legend()

        fig.savefig(self.basename + '_' + suffix + '.png', dpi=200)
        plt.show()
        plt.close(fig)

    def plot_results(self, mu, u):
        """Plot ``mu`` and ``u`` over time, one curve per state.

        The figures are saved to ``<basename>_mu.png`` and ``<basename>_u.png``
        and also shown.
        """
        self._plot_series(mu, r'\mu', 'mu')
        self._plot_series(u, 'u', 'u')