from qpsolvers import solve_qp
import numpy as np
import torch

# Small positive constant for numerical safety: it guards a division in the
# step-size computation, serves as the conjugate-gradient residual cutoff, and
# is added to a matrix diagonal before the QP solve.
EPS = 1e-8

def flatGrad(y, x, retain_graph=False, create_graph=False):
    """Return the gradient of ``y`` with respect to ``x`` as one 1-D tensor.

    Args:
        y: Tensor to differentiate. No ``grad_outputs`` is passed to
            ``torch.autograd.grad``, so autograd expects a scalar here.
        x: Tensor or iterable of tensors (for example ``module.parameters()``)
            to differentiate with respect to.
        retain_graph: Forwarded to ``torch.autograd.grad``. Forced to ``True``
            whenever ``create_graph`` is set.
        create_graph: Forwarded to ``torch.autograd.grad``; set it when the
            returned gradient must itself be differentiable.

    Returns:
        The per-input gradients, each flattened and concatenated in the order
        of ``x``.
    """
    if create_graph:
        retain_graph = True
    grads = torch.autograd.grad(y, x, retain_graph=retain_graph, create_graph=create_graph)
    # reshape() instead of view(): view(-1) raises when a gradient comes back
    # with a non-contiguous memory layout, reshape() copies only when needed.
    return torch.cat([grad.reshape(-1) for grad in grads])


class TROptimizer:
    """Trust-region optimizer for an actor network.

    One ``step`` approximately solves ``H x = g`` with conjugate gradient, where
    ``g`` is the gradient of the objective and ``H`` is the Hessian of the KL
    term plus ``damping_coeff`` times the identity. The solution is scaled to
    the KL budget and the actor is then moved along it with a backtracking
    line search.
    """

    def __init__(self, actor, damping_coeff, num_conjugate, line_decay, max_kl, device) -> None:
        """Store the hyper-parameters and count the actor's parameters.

        Args:
            actor: Module whose ``parameters()`` are optimized in place.
            damping_coeff: Multiple of the identity added to the KL Hessian.
            num_conjugate: Maximum number of conjugate-gradient iterations.
            line_decay: Factor applied to the step size after each rejected
                line-search candidate; it must shrink the step (0 < value < 1)
                for the search to make progress.
            max_kl: Base KL budget for one update.
            device: Stored on the instance; not used by this class itself.
        """
        self.actor = actor
        self.damping_coeff = damping_coeff
        self.num_conjugate = num_conjugate
        self.line_decay = line_decay
        self.max_kl = max_kl
        self.device = device

        # total number of scalar parameters of the actor
        self.n_params = sum(param.numel() for param in self.actor.parameters())

    #################
    # public function
    #################

    def step(self, get_obj_kl, mu_kl=0.0):
        """Perform one trust-region update of the actor.

        Args:
            get_obj_kl: Zero-argument callable returning ``(objective, kl)``
                tensors evaluated with the actor's current parameters. It is
                called once with gradients enabled and then once per
                line-search candidate under ``torch.no_grad()``.
            mu_kl: Value used to shrink the KL budget (see ``_getMaxKL``).

        Returns:
            Tuple ``(objective, kl, max_kl, beta)``: objective and KL as Python
            floats at the parameters left in the actor, the KL budget used, and
            the accepted step fraction. ``beta`` is ``0.0`` when no step was
            taken (non-finite search direction, or the line search could not
            find an acceptable candidate); the actor is then left at its
            initial parameters.
        """
        # for adaptive kl
        max_kl = self._getMaxKL(mu_kl)

        # calculate gradient
        objective, kl = get_obj_kl()
        self._computeKLGrad(kl)
        g_tensor = flatGrad(objective, self.actor.parameters(), retain_graph=True)
        H_inv_g_tensor = self._conjugateGradient(g_tensor)
        # H @ (H^-1 g): the gradient as reproduced by the approximate CG solution
        approx_g_tensor = self._Hx(H_inv_g_tensor)

        with torch.no_grad():
            # calculate search direction, scaled so the quadratic KL model
            # 0.5 * delta^T H delta equals max_kl
            g_H_inv_g_tensor = torch.dot(approx_g_tensor, H_inv_g_tensor)
            nu = torch.sqrt(2.0*max_kl/(g_H_inv_g_tensor + EPS))
            delta_theta = nu*H_inv_g_tensor

            # A non-finite direction (e.g. sqrt of a negative curvature estimate
            # or a zero CG denominator) would write NaNs into the actor and the
            # search below could never accept a candidate: skip the update.
            if not bool(torch.isfinite(delta_theta).all()):
                return objective.item(), kl.item(), max_kl, 0.0

            # line search
            beta = 1.0
            init_theta = torch.cat([t.view(-1) for t in self.actor.parameters()]).clone().detach()
            init_objective = objective.clone().detach()
            while True:
                theta = beta*delta_theta + init_theta
                self._applyParams(theta)
                objective, kl = get_obj_kl()
                if kl <= 1.5*max_kl and objective >= init_objective:
                    break

                next_beta = beta*self.line_decay
                step_vanished = torch.equal(theta, init_theta)
                if step_vanished or not (0.0 < next_beta < beta):
                    # Further shrinking cannot help: either the candidate already
                    # equals the initial parameters, or line_decay no longer
                    # reduces the step. Give up instead of looping forever.
                    if not step_vanished:
                        self._applyParams(init_theta)
                        objective, kl = get_obj_kl()
                    beta = 0.0
                    break
                beta = next_beta
        return objective.item(), kl.item(), max_kl, beta

    ##################
    # private function
    ##################

    def _getMaxKL(self, mu_kl=0.0):
        """Return the KL budget reduced by a bonus derived from ``mu_kl``.

        The bonus is ``sqrt(mu_kl*(max_kl + 0.25*mu_kl)) - 0.5*mu_kl``; the
        result is clipped from below at zero. With ``mu_kl == 0`` the budget is
        ``self.max_kl`` unchanged.
        """
        kl_bonus = np.sqrt(mu_kl*(self.max_kl + 0.25*mu_kl)) - 0.5*mu_kl
        max_kl = np.clip(self.max_kl - kl_bonus, 0.0, np.inf)
        return max_kl

    def _computeKLGrad(self, kl):
        """Cache the flat, differentiable gradient of ``kl`` for ``_Hx``."""
        self._flat_grad_kl = flatGrad(kl, self.actor.parameters(), create_graph=True)

    def _applyParams(self, params):
        """Copy the flat vector ``params`` into the actor's parameters in order."""
        n = 0
        for p in self.actor.parameters():
            numel = p.numel()
            # write through .data so autograd does not track the assignment
            p.data.copy_(params[n:n + numel].view(p.shape))
            n += numel

    def _Hx(self, x):
        """Return ``(H_kl + damping_coeff*I) @ x`` as a Hessian-vector product.

        Differentiates ``grad_kl . x`` with respect to the parameters, so
        ``_computeKLGrad`` must have been called on the current graph first.
        """
        kl_x = torch.dot(self._flat_grad_kl, x.detach())
        # retain_graph: the cached KL gradient graph is reused for every product
        H_x = flatGrad(kl_x, self.actor.parameters(), retain_graph=True)
        return H_x + x*self.damping_coeff

    def _conjugateGradient(self, g):
        """Approximately solve ``H x = g`` with at most ``num_conjugate`` CG steps.

        Returns the zero vector immediately when ``g . g`` is below ``EPS`` and
        stops early once the squared residual norm drops below ``EPS``.
        """
        x = torch.zeros_like(g)
        r = g.clone()
        rs_old = torch.dot(r, r)
        if rs_old < EPS:
            return x
        p = g.clone()
        for i in range(self.num_conjugate):
            Ap = self._Hx(p)
            pAp = torch.dot(p, Ap)
            alpha = rs_old/pAp
            x += alpha*p
            if i == self.num_conjugate - 1:
                # last iteration: the residual update would not be used
                break
            r -= alpha*Ap
            rs_new = torch.dot(r, r)
            if rs_new < EPS:
                break
            p = r + (rs_new/rs_old)*p
            rs_old = rs_new
        return x


class MultiTROptimizer(TROptimizer):
    """Trust-region step for several objectives under several constraints.

    For every objective and constraint a direction ``H^-1 b`` is computed with
    conjugate gradient. The directions are then combined with non-negative
    weights obtained from a small QP. When a (non-entropy) constraint is
    violated, or the combined QP fails, only the constraint directions are
    used ("safety mode").
    """

    def __init__(
        self, actor,
        damping_coeff,
        num_conjugate,
        line_decay,
        max_tr_size,
        n_objs,
        con_thresholds,
        con_entropy,
        device) -> None:
        """Store hyper-parameters and allocate the per-direction buffers.

        Args:
            actor: Module whose parameters define the search space.
            damping_coeff: Multiple of the identity added to the KL Hessian.
            num_conjugate: Maximum number of conjugate-gradient iterations.
            line_decay: Forwarded to the base class; not used by ``step`` here.
            max_tr_size: Trust-region radius; the base-class KL budget is set
                to ``0.5*max_tr_size**2``.
            n_objs: Number of objectives.
            con_thresholds: One threshold per constraint; its length defines
                the number of constraints.
            con_entropy: If truthy, the last constraint is treated specially:
                it never triggers safety mode and has no "violated" case.
            device: Device on which the direction buffers are allocated.
        """
        self.max_tr_size = max_tr_size
        max_kl = 0.5*(self.max_tr_size**2)

        super().__init__(actor, damping_coeff, num_conjugate, line_decay, max_kl, device)
        self.n_objs = n_objs
        self.n_cons = len(con_thresholds)
        self.con_thresholds = con_thresholds
        self.con_entropy = con_entropy

        # inner variables: one row per objective, followed by one row per constraint
        n_rows = self.n_objs + self.n_cons
        self.B_tensor = torch.zeros((n_rows, self.n_params), device=self.device, dtype=torch.float32)
        self.H_inv_B_tensor = torch.zeros((n_rows, self.n_params), device=self.device, dtype=torch.float32)

    def step(self, get_obj_con_kl, preference, states_tensor, preferences_tensor):
        """Compute the update direction and evaluate the actor at the candidate.

        The actor's parameters are restored before returning, so this method
        does not itself change the actor; it returns the actor's outputs at
        ``theta - delta_theta``.

        Args:
            get_obj_con_kl: Zero-argument callable returning
                ``(objectives, constraints, kl)``; the first two are indexable
                by objective / constraint index and hold scalar tensors.
            preference: Array of per-objective weights; it is divided by its
                maximum before use.
            states_tensor: First argument passed to the actor.
            preferences_tensor: Second argument passed to the actor.

        Returns:
            ``(objectives, constraints, means, log_stds, stds, safety_mode)``
            where the middle three are the actor outputs at the candidate
            parameters and ``safety_mode`` tells whether only the constraints
            were used.

        Raises:
            AssertionError: If the constraint-only QP returns no solution.
        """
        # normalize so that the largest preference weight is 1
        preference = preference/np.max(preference)

        # calculate gradient
        objectives, constraints, kl = get_obj_con_kl()
        self._computeKLGrad(kl)

        # objectives are negated so that every row describes a quantity to decrease
        for obj_idx in range(self.n_objs):
            self._storeDirection(obj_idx, -objectives[obj_idx])

        con_vals = []
        safety_mode = False
        for con_idx in range(self.n_cons):
            self._storeDirection(self.n_objs + con_idx, constraints[con_idx])
            con_vals.append(constraints[con_idx].item())
            if self.con_entropy and con_idx == self.n_cons - 1:
                # the entropy constraint never triggers safety mode
                continue
            if con_vals[con_idx] > self.con_thresholds[con_idx]:
                safety_mode = True

        with torch.no_grad():
            # construct QP problem
            S_tensor = self.B_tensor@self.H_inv_B_tensor.T
            S_mat = S_tensor.detach().cpu().numpy().astype(np.float64)
            S_mat = 0.5*(S_mat + S_mat.T)
            # ===== to ensure S_mat is invertible ===== #
            # S_mat is symmetric here, so eigvalsh applies and always yields
            # real eigenvalues (eigvals may return a complex array, which
            # cannot be compared with 0.0).
            min_eig_val = min(0.0, np.min(np.linalg.eigvalsh(S_mat)))
            S_mat += np.eye(S_mat.shape[0])*(-min_eig_val + EPS)
            # ========================================= #
            c_scalars = []
            safe_c_scalars = []
            active_indices = []
            safe_active_indices = []
            for obj_idx in range(self.n_objs):
                b_H_inv_b = S_mat[obj_idx, obj_idx]
                c_scalar = self.max_tr_size*np.sqrt(np.clip(b_H_inv_b, 0.0, np.inf))
                c_scalars.append(c_scalar*preference[obj_idx])
                active_indices.append(obj_idx)
            for con_idx in range(self.n_cons):
                row = self.n_objs + con_idx
                b_H_inv_b = S_mat[row, row]
                c_scalar = self.max_tr_size*np.sqrt(np.clip(b_H_inv_b, 0.0, np.inf))
                const_value = con_vals[con_idx] - self.con_thresholds[con_idx]
                is_entropy = self.con_entropy and con_idx == self.n_cons - 1
                if const_value > 0.0 and not is_entropy: # for violated constraint
                    safe_active_indices.append(con_idx)
                    safe_c_scalar = c_scalar
                elif const_value + c_scalar >= 0.0: # for active constraint
                    active_indices.append(row)
                    safe_active_indices.append(con_idx)
                    safe_c_scalar = min(const_value, c_scalar)
                else: # for inactive constraint
                    # Also reached by an inactive entropy constraint, which
                    # previously left safe_c_scalar unassigned (stale value or
                    # UnboundLocalError). The entry is never selected by
                    # safe_active_indices, it only keeps the vector aligned.
                    safe_c_scalar = const_value
                safe_c_scalars.append(safe_c_scalar)
                c_scalars.append(const_value)
            safe_c_vector = np.array(safe_c_scalars)
            c_vector = np.array(c_scalars)

            # solve QP
            if not safety_mode:
                try:
                    # solve QP with constraints and objectives
                    delta_theta = self._searchDirection(
                        S_mat, c_vector, active_indices,
                        self.H_inv_B_tensor, "QP solver failed 1.")
                except Exception as err:
                    # if QP solver failed, then use safety mode
                    # (Exception, not a bare except, so KeyboardInterrupt and
                    # SystemExit still propagate; the cause is reported.)
                    print(f"QP solver failed. ({type(err).__name__}: {err})")
                    safety_mode = True

            if safety_mode:
                # solve QP with only constraints
                delta_theta = self._searchDirection(
                    S_mat[self.n_objs:, self.n_objs:], safe_c_vector, safe_active_indices,
                    self.H_inv_B_tensor[self.n_objs:], "QP solver failed 2.")

            # backup parameters
            init_theta = torch.cat([t.view(-1) for t in self.actor.parameters()]).clone().detach()

            try:
                # update distribution list
                self._applyParams(init_theta - delta_theta)
                means, log_stds, stds = self.actor(states_tensor, preferences_tensor)
            finally:
                # restore parameters, even if the forward pass raised
                self._applyParams(init_theta)

        return objectives, constraints, means, log_stds, stds, safety_mode

    def _storeDirection(self, row, loss):
        """Fill ``row`` of both buffers with ``H (H^-1 b)`` and ``H^-1 b`` for ``b = grad(loss)``."""
        b_tensor = flatGrad(loss, self.actor.parameters(), retain_graph=True)
        H_inv_b_tensor = self._conjugateGradient(b_tensor)
        approx_b_tensor = self._Hx(H_inv_b_tensor)
        self.B_tensor[row].data.copy_(approx_b_tensor)
        self.H_inv_B_tensor[row].data.copy_(H_inv_b_tensor)

    def _searchDirection(self, S_mat, c_vector, indices, H_inv_rows, fail_msg):
        """Solve the QP restricted to ``indices`` and return the scaled direction.

        Args:
            S_mat: Square matrix over all rows of ``H_inv_rows``.
            c_vector: Linear term, one entry per row of ``S_mat``.
            indices: Rows that take part in the QP; all others get weight 0.
            H_inv_rows: Tensor whose rows are combined with the QP weights.
            fail_msg: Message of the ``AssertionError`` raised when the solver
                returns ``None``.
        """
        sub_S_mat = S_mat[indices][:, indices]
        sub_c_vector = c_vector[indices]
        # NOTE(review): no ``solver=`` is passed; recent qpsolvers releases
        # appear to require one -- confirm against the installed version.
        sub_lam_vector = solve_qp(
            P=sub_S_mat, q=-sub_c_vector, lb=np.zeros_like(sub_c_vector))
        if sub_lam_vector is None:
            # explicit raise instead of ``assert`` so the check survives ``python -O``
            raise AssertionError(fail_msg)
        lam_vector = np.zeros_like(c_vector)
        for idx, active_idx in enumerate(indices):
            lam_vector[active_idx] = sub_lam_vector[idx]

        # find scaling factor: shrink only if the step would leave the trust region
        approx_tr_size = np.sqrt(np.dot(lam_vector, S_mat@lam_vector))
        scaling_factor = min(self.max_tr_size/approx_tr_size, 1.0)

        # find search direction
        lam_tensor = torch.tensor(lam_vector, device=self.device, dtype=torch.float32)
        return scaling_factor*(H_inv_rows.T@lam_tensor)