from qpsolvers import solve_qp
import numpy as np
import torch

from utils import cprint

# Small positive constant used for numerical safety: it is added to the diagonal
# shift of the QP matrix and to the estimated trust-region size before dividing.
EPS = 1e-8

def flatGrad(y, x, retain_graph=False, create_graph=False):
    """Return the gradient of ``y`` w.r.t. ``x`` as a single 1-D tensor.

    Args:
        y: Tensor to differentiate (passed as ``outputs`` to
            ``torch.autograd.grad``).
        x: Tensor or iterable of tensors to differentiate with respect to.
        retain_graph: Keep the autograd graph alive after the call.
        create_graph: Build a graph of the gradient itself so it can be
            differentiated again. Implies ``retain_graph=True``.

    Returns:
        The per-input gradients, each flattened, concatenated in the order
        the inputs were given.
    """
    grads = torch.autograd.grad(
        y,
        x,
        # A graph that is being extended must also be retained.
        retain_graph=retain_graph or create_graph,
        create_graph=create_graph,
    )
    # reshape (rather than view) also works when a gradient is non-contiguous.
    return torch.cat([grad.reshape(-1) for grad in grads])

class MultiTROptimizer:
    """Combine objective and constraint gradients into one bounded update.

    Each call to :meth:`step` stacks the flattened gradients of all objectives
    and constraints, solves a small non-negative QP for their mixing weights,
    rescales the resulting direction so that its estimated size does not
    exceed ``max_tr_size``, and evaluates the actor at the updated parameters.
    The actor's own parameters are left unchanged after the call.
    """

    def __init__(
            self,
            actor: torch.nn.Module,
            max_tr_size: float,
            n_objs: int,
            n_cons: int,
            con_thresholds: np.ndarray,
            device: torch.device) -> None:
        """Store the configuration and allocate the gradient buffer.

        Args:
            actor: Module whose parameters are differentiated and perturbed.
            max_tr_size: Upper bound used when scaling the search direction.
            n_objs: Number of objectives returned by ``get_obj_con``.
            n_cons: Number of constraints returned by ``get_obj_con``.
            con_thresholds: Per-constraint threshold, indexed by constraint.
            device: Device on which the gradient buffer is allocated.
        """
        # set parameters
        self.actor = actor
        self.device = device
        self.max_tr_size = max_tr_size
        self.n_objs = n_objs
        self.n_cons = n_cons
        self.con_thresholds = con_thresholds

        # count number of parameters
        self.n_params = sum(param.numel() for param in self.actor.parameters())

        # inner variables: one row of flattened gradient per objective/constraint
        self.B_tensor = torch.zeros(
            (self.n_objs + self.n_cons, self.n_params),
            device=self.device, dtype=torch.float32)

    ##################
    # public functions
    ##################

    def step(self, get_obj_con, preference, states_tensor, sampled_preferences_tensor):
        """Compute an update direction and evaluate the actor at the updated parameters.

        Args:
            get_obj_con: Zero-argument callable returning ``(objectives,
                constraints)``; both are indexed by integer and each entry is
                differentiated w.r.t. the actor parameters.
            preference: Indexable weights, one per objective, that scale the
                linear QP coefficient of each objective.
            states_tensor: First argument forwarded to the actor.
            sampled_preferences_tensor: Second argument forwarded to the actor.

        Returns:
            ``(means, log_stds, objectives, constraints)`` where ``means`` and
            ``log_stds`` are the first two actor outputs evaluated, without
            gradient tracking, at ``theta - delta_theta``; ``objectives`` and
            ``constraints`` are returned as produced by ``get_obj_con``.
        """
        # get objectives, constraints
        objectives, constraints = get_obj_con()

        # gradients of the negated objectives fill the first n_objs rows
        for obj_idx in range(self.n_objs):
            b_tensor = flatGrad(-objectives[obj_idx], self.actor.parameters(), retain_graph=True)
            self.B_tensor[obj_idx].data.copy_(b_tensor)

        # gradients of the constraints fill the remaining rows; keep their values
        con_vals = []
        for con_idx in range(self.n_cons):
            b_tensor = flatGrad(constraints[con_idx], self.actor.parameters(), retain_graph=True)
            self.B_tensor[self.n_objs + con_idx].data.copy_(b_tensor)
            con_vals.append(constraints[con_idx].item())

        with torch.no_grad():
            # construct QP problem: Gram matrix of the stacked gradients
            S_tensor = self.B_tensor@self.B_tensor.T
            S_mat = S_tensor.detach().cpu().numpy().astype(np.float64)
            S_mat = 0.5*(S_mat + S_mat.T)
            # ===== to ensure S_mat is invertible ===== #
            # S_mat is symmetric by construction, so use the symmetric
            # eigenvalue routine: it always returns real eigenvalues.
            min_eig_val = min(0.0, np.min(np.linalg.eigvalsh(S_mat)))
            S_mat += np.eye(S_mat.shape[0])*(-min_eig_val + EPS)
            # ========================================= #

            # linear coefficients: preference-weighted gradient scale for the
            # objectives, threshold violation for the constraints
            c_scalars = []
            for obj_idx in range(self.n_objs):
                b_H_inv_b = S_mat[obj_idx, obj_idx]
                c_scalar = self.max_tr_size*np.sqrt(np.clip(b_H_inv_b, 0.0, np.inf))
                c_scalars.append(preference[obj_idx]*c_scalar)
            for con_idx in range(self.n_cons):
                c_scalars.append(con_vals[con_idx] - self.con_thresholds[con_idx])
            c_vector = np.array(c_scalars)

            # solve QP
            try:
                # minimize 0.5*lam^T S lam - c^T lam  subject to  lam >= 0
                # NOTE(review): no `solver=` is passed; some qpsolvers releases
                # require it, in which case this raises and the fallback below
                # silently yields a zero update - confirm against the pinned version.
                con_lam_vector = solve_qp(
                    P=S_mat, q=-c_vector, lb=np.zeros_like(c_vector))
                if con_lam_vector is None:
                    raise ValueError("QP solver returned no solution")

                # shrink the step if its estimated size exceeds the limit
                approx_tr_size = np.sqrt(np.dot(con_lam_vector, S_mat@con_lam_vector)) + EPS
                scaling_factor = min(self.max_tr_size/approx_tr_size, 1.0)
            except Exception as e:
                # best-effort fallback: report the failure and take a zero step
                cprint(f"Error: {e}", bold=True, color='red')
                con_lam_vector = np.zeros_like(c_vector)
                scaling_factor = 0.0

            # find search direction
            lam_tensor = torch.tensor(con_lam_vector, device=self.device, dtype=torch.float32)
            delta_theta = scaling_factor*(self.B_tensor.T@lam_tensor)

            # backup parameters
            init_theta = torch.cat([t.view(-1) for t in self.actor.parameters()]).clone().detach()

            # evaluate the actor at the updated parameters, then always restore
            # the original ones, even if the forward pass raises
            self._applyParams(init_theta - delta_theta)
            try:
                means, log_stds, _ = self.actor(states_tensor, sampled_preferences_tensor)
            finally:
                self._applyParams(init_theta)

        return means, log_stds, objectives, constraints

    ###################
    # private functions
    ###################

    def _applyParams(self, params):
        """Copy the flat vector ``params`` into the actor parameters, in order."""
        offset = 0
        for p in self.actor.parameters():
            numel = p.numel()
            p.data.copy_(params[offset:offset + numel].view(p.shape))
            offset += numel

    def _applyGradParams(self, params):
        """Assign slices of the flat vector ``params`` as the actor parameters' ``.grad``."""
        offset = 0
        for p in self.actor.parameters():
            numel = p.numel()
            p.grad = params[offset:offset + numel].view(p.shape)
            offset += numel
