import gpytorch
import numpy as np
import torch
from botorch.models.model_list_gp_regression import ModelListGP
from scipy.linalg import null_space
from scipy.optimize import minimize
from pymoo.util.ref_dirs import get_reference_directions
from mobo_osd.helper import distance_to_line

class FirstOrderApproximation:
    '''
    Local first-order expansion of a candidate design, adapted from the
    first-order approximation of DGEMO (Konaković Luković et al., 2020),
    see section 4.2 of that paper.

    At ``x_opt``, the negated posterior means of the per-objective GPs in
    ``model_list`` are evaluated together with their jacobians and hessians.
    These are used to solve the KKT conditions for the dual variables. The
    null space of the resulting linearized system gives design-space
    directions along which ``x_opt`` is expanded into random samples inside
    the unit box [0, 1]^n_var.

    Usage: ``FirstOrderApproximation(model_list, x_opt).do()`` returns an array
    of shape (n_samples, n_var) whose first row is ``x_opt``. If the expansion
    fails for any reason, only ``x_opt`` is returned.
    '''

    def __init__(
            self,
            model_list: "ModelListGP",
            x_opt: np.ndarray,
        ):
        '''
        Input:
            model_list: one model per objective; model_list.models[i] models objective i
            x_opt: design to expand, shape = (n_var,); the expansion treats [0, 1]^n_var as the feasible box
        '''
        self.model_list = model_list
        self.x_opt = x_opt
        self.num_objectives = len(model_list.models)
        # input dimension, read from the training inputs of the first model
        self.dim = model_list.models[0].train_inputs[0].shape[-1]

    def _get_box_const_value_jacobian_hessian(self, x, bounds):
        '''
        Getting the value, jacobian and hessian of active box constraints.
        Input:
            x: a design sample, shape = (n_var,)
            bounds: problem's lower and upper bounds, shape = (2, n_var)
        Output:
            G: value of active box constraints (always 0), shape = (n_active_const,)
            DG: jacobian matrix of active box constraints (1/-1 at active locations, otherwise 0), shape = (n_active_const, n_var)
            HG: hessian matrix of active box constraints (always 0), shape = (n_active_const, n_var, n_var)
            (None, None, None) is returned when no box constraint is active.
        '''
        active_idx, upper_active_idx, _ = self._get_active_box_const(x, bounds)
        n_active_const, n_var = len(active_idx), len(x)

        if n_active_const == 0:
            # no active constraints
            return None, None, None

        G = np.zeros(n_active_const)
        # +1 for an active upper bound, -1 for an active lower bound; an index
        # active at both bounds (lower == upper) is treated as upper active
        signs = np.where(np.isin(active_idx, upper_active_idx), 1.0, -1.0)
        DG = np.zeros((n_active_const, n_var))
        DG[np.arange(n_active_const), active_idx] = signs
        HG = np.zeros((n_active_const, n_var, n_var))
        return G, DG, HG

    def _get_active_box_const(self, x, bounds):
        '''
        Getting the indices of active box constraints.
        Input:
            x: a design sample, shape = (n_var,)
            bounds: problem's lower and upper bounds, shape = (2, n_var)
        Output:
            active_idx: indices of all active constraints
            upper_active_idx: indices of upper active constraints
            lower_active_idx: indices of lower active constraints
        '''
        eps = 1e-8 # epsilon value to determine 'active'
        upper_active = bounds[1] - x < eps
        lower_active = x - bounds[0] < eps
        active = np.logical_or(upper_active, lower_active)
        return np.flatnonzero(active), np.flatnonzero(upper_active), np.flatnonzero(lower_active)

    def _get_kkt_dual_variables(self, F, G, DF, DG):
        '''
        Optimizing for dual variables alpha and beta in KKT conditions, see section 4.2, proposition 4.5.
        Input:
            Given a design sample,
            F: performance value, shape = (n_obj,)
            G: active constraints, shape = (n_active_const,), or None when no constraint is active
            DF: jacobian matrix of performance, shape = (n_obj, n_var)
            DG: jacobian matrix of active constraints, shape = (n_active_const, n_var), or None
            where n_var = D, n_obj = d, n_active_const = K' in the original paper
        Output:
            alpha_opt, beta_opt: optimized dual variables, shapes = (n_obj,) and (n_active_const,)
        Note: the random initialization of alpha uses the global numpy RNG, and the
        SLSQP result is returned without checking res.success.
        '''
        # NOTE: use min-norm solution for solving alpha then determine beta instead?
        n_obj = len(F)
        n_active_const = len(G) if G is not None else 0

        '''
        Optimization formulation:
            To optimize the last line of (2) in section 4.2, we change it to a quadratic optimization problem by:
            find x to let Ax = 0 --> min_x (Ax)^2
            where x means [alpha, beta] and A means [DF, DG].
            Constraints: alpha >= 0, beta >= 0, sum(alpha) = 1.
            NOTE: we currently ignore the constraint beta * G = 0 because G will always be 0 with only box constraints, but add that constraint will result in poor optimization solution (?)
        '''
        if n_active_const > 0: # when there are active constraints

            def fun(x, n_obj=n_obj, DF=DF, DG=DG):
                alpha, beta = x[:n_obj], x[n_obj:]
                objective = alpha @ DF + beta @ DG
                return 0.5 * objective @ objective

            def jac(x, n_obj=n_obj, DF=DF, DG=DG):
                alpha, beta = x[:n_obj], x[n_obj:]
                objective = alpha @ DF + beta @ DG
                return np.vstack([DF, DG]) @ objective

            const = {'type': 'eq',
                'fun': lambda x, n_obj=n_obj: np.sum(x[:n_obj]) - 1.0,
                'jac': lambda x, n_obj=n_obj: np.concatenate([np.ones(n_obj), np.zeros_like(x[n_obj:])])}

        else: # when there's no active constraint

            def fun(x, DF=DF):
                objective = x @ DF
                return 0.5 * objective @ objective

            def jac(x, DF=DF):
                objective = x @ DF
                return DF @ objective

            const = {'type': 'eq',
                    'fun': lambda x: np.sum(x) - 1.0,
                    'jac': np.ones_like}

        # alpha and beta are both constrained to be non-negative
        bounds = np.array([[0.0, np.inf]] * (n_obj + n_active_const))

        # NOTE: we use random value to initialize alpha for now, maybe consider the location of F we can get a more accurate initialization
        alpha_init = np.random.random(len(F))
        alpha_init /= np.sum(alpha_init)
        beta_init = np.zeros(n_active_const) # zero initialization for beta
        x_init = np.concatenate([alpha_init, beta_init])

        # do optimization using SLSQP
        res = minimize(fun, x_init, method='SLSQP', jac=jac, bounds=bounds, constraints=const)
        x_opt = res.x
        alpha_opt, beta_opt = x_opt[:n_obj], x_opt[n_obj:]
        return alpha_opt, beta_opt

    def _mean_value_jacobian_hessian(self):
        '''
        Evaluate the negated posterior means at x_opt and their derivatives.
        The negation presumably converts the models' maximization convention
        into the minimization convention of the KKT formulation; confirm.
        Output:
            F: negated posterior means, expected shape = (n_obj,)
            DF: jacobian of F w.r.t. x, expected shape = (n_obj, n_var)
            HF: hessian of each F_i w.r.t. x, expected shape = (n_obj, n_var, n_var)
        '''
        def get_mean_for_grad(_x_torch, index):
            # scalar negated posterior mean of objective `index`, for autograd
            posterior = self.model_list.models[index].posterior(_x_torch)
            return (-posterior.mean).sum()

        x_candidate_torch = torch.tensor(self.x_opt, dtype=torch.double, requires_grad=True).unsqueeze(0)
        with torch.no_grad(), gpytorch.settings.cholesky_jitter(double=1e-1):
            mean_pred = -self.model_list.posterior(x_candidate_torch).mean

        cand_grad, cand_hess = [], []
        for i in range(self.num_objectives):
            # bind `i` explicitly; the derivatives are only used as values, so no
            # higher-order graph (create_graph) is needed
            objective_i = lambda x, i=i: get_mean_for_grad(x, i)
            cand_grad.append(torch.autograd.functional.jacobian(objective_i, x_candidate_torch).squeeze())
            cand_hess.append(torch.autograd.functional.hessian(objective_i, x_candidate_torch).squeeze())

        F = mean_pred.detach().numpy().squeeze()
        DF = torch.stack(cand_grad).detach().numpy().squeeze()
        HF = torch.stack(cand_hess).detach().numpy()
        return F, DF, HF

    def _tangent_directions(self, DF, HF, DG, HG, alpha, beta):
        '''
        Design-space directions from the null space of the linearized KKT system.
        Input:
            DF, HF: objective jacobian / hessians, shapes = (n_obj, n_var) / (n_obj, n_var, n_var)
            DG, HG: active-constraint jacobian / hessians, or None when no constraint is active
            alpha, beta: KKT dual variables, shapes = (n_obj,) / (n_active_const,)
        Output:
            d_x: directions, shape = (n_var, n_obj - 1), scaled to unit Frobenius norm
        Raises:
            ValueError: if the design-space part of the null space is zero, or
                fewer than n_obj - 1 directions are available
        '''
        eps = 1e-8 # threshold for numerical zeros
        n_obj = len(alpha)
        n_var = self.dim
        n_active_const = len(DG) if DG is not None else 0

        # weighted sum of hessians (symmetric, so the transpose layout is harmless)
        if n_active_const > 0:
            H = HF.T @ alpha + HG.T @ beta
        else:
            H = HF.T @ alpha

        # variables are ordered [d_alpha, d_beta, d_x]; first row enforces sum(d_alpha) = 0
        alpha_const = np.concatenate([np.ones(n_obj), np.zeros(n_active_const + n_var)])
        if n_active_const > 0:
            comp_slack_const = np.column_stack([np.zeros((n_active_const, n_obj + n_active_const)), DG])
            DxHx = np.vstack([alpha_const, comp_slack_const, np.column_stack([DF.T, DG.T, H])])
        else:
            DxHx = np.vstack([alpha_const, np.column_stack([DF.T, H])])
        directions = null_space(DxHx)

        # eliminate numerical error
        directions[np.abs(directions) < eps] = 0.0

        # only the design-space block is used; the d_alpha / d_beta rows are discarded
        d_x = directions[-n_var:]
        if np.linalg.norm(d_x) < eps: # direction is a zero vector
            raise ValueError("Direction is a zero vector")
        direction_dim = d_x.shape[1]

        if direction_dim > n_obj - 1:
            # more than d-1 directions to explore: randomly choose d-1 *distinct* sub-directions
            indices = np.random.choice(direction_dim, n_obj - 1, replace=False)
            while np.linalg.norm(d_x[:, indices]) < eps:
                indices = np.random.choice(direction_dim, n_obj - 1, replace=False)
            d_x = d_x[:, indices]
        elif direction_dim < n_obj - 1:
            # less than d-1 directions to explore, do not expand the point
            raise ValueError("Less than d-1 directions to explore")

        # normalize direction (Frobenius norm of the whole direction matrix)
        d_x /= np.linalg.norm(d_x)
        return d_x

    def _expand_samples(self, x_samples, d_x, bounds, n_grid_sample=100):
        '''
        Append random in-box samples x_opt + d_x @ c, with c drawn uniformly from [0, 1)^(n_obj-1).
        Input:
            x_samples: samples collected so far, shape = (n, n_var)
            d_x: directions, shape = (n_var, n_obj - 1)
            bounds: [lower_bound, upper_bound], each of shape (n_var,)
            n_grid_sample: maximum number of rows returned
        Output:
            samples, shape = (<= n_grid_sample, n_var); at most 11 sampling rounds are tried
        '''
        # NOTE: Adriana's code also checks if such direction has been expanded, but maybe unnecessary
        lower_bound, upper_bound = bounds[0], bounds[1]

        # grid sampling on expanded surface (NOTE: more delicate sampling scheme?)
        d_x = d_x * np.expand_dims(upper_bound - lower_bound, axis=1)
        n_directions = d_x.shape[1]
        loop_count = 0 # avoid infinite loop when it's hard to get valid samples
        while len(x_samples) < n_grid_sample:
            coefficients = np.random.random((n_grid_sample, 1, n_directions))
            curr_dx_samples = np.sum(np.expand_dims(d_x, axis=0) * coefficients, axis=-1)
            curr_x_samples = np.expand_dims(self.x_opt, axis=0) + curr_dx_samples
            # keep only samples inside the box
            inside = np.logical_and((curr_x_samples <= upper_bound).all(axis=1), (curr_x_samples >= lower_bound).all(axis=1))
            x_samples = np.vstack([x_samples, curr_x_samples[inside]])
            loop_count += 1
            if loop_count > 10:
                break
        return x_samples[:n_grid_sample]

    def _first_order_approximation(
            self,
        ):
        '''
        Expand x_opt along the local tangent directions.
        Output:
            x_samples: shape = (n_samples, n_var), first row is x_opt, at most 100 rows.
            On any failure (e.g. a degenerate null space) only x_opt is returned.
        '''
        import logging

        x_samples = np.array([self.x_opt])
        unit_bounds = [np.zeros(self.dim), np.ones(self.dim)]
        try:
            F, DF, HF = self._mean_value_jacobian_hessian()
            G, DG, HG = self._get_box_const_value_jacobian_hessian(self.x_opt, unit_bounds)
            alpha, beta = self._get_kkt_dual_variables(F, G, DF, DG)
            d_x = self._tangent_directions(DF, HF, DG, HG, alpha, beta)
            x_samples = self._expand_samples(x_samples, d_x, unit_bounds, n_grid_sample=100)
        except Exception as e:
            # deliberate best effort: fall back to the original candidate, but leave a trace
            logging.getLogger(__name__).debug(
                "First order approximation failed (%s); keeping the original candidate.", e)
        return x_samples

    def do(self):
        '''Run the expansion; see _first_order_approximation.'''
        return self._first_order_approximation()


class ParetoEstimationSet():
    def __init__(self, n, xs, ys, boundary_points, ortho_direction):
        '''
        A data structure that stores data points approximating the PF.
        Each cell represents a region of the objective space around one line.
        The line passes through a point of the approximated CHIM (the convex
        hull of boundary_points) and follows ortho_direction, the OSD direction
        shared by all cells. The approximated CHIM is discretized into n points,
        giving one cell per point.
        To store more data points, i.e. a denser approximation of the PF,
        increase n or max_sample_per_cell.
        Note: this data structure differs from that of DGEMO (Konaković Luković
        et al., 2020). It only stores data and is not used for graph cut, so it
        does not need a carefully designed edge map for each number of objectives.
        Input:
            n: number of cells, i.e. number of discretized CHIM points
            xs: initial designs, aligned with ys
            ys: objective values of xs
            boundary_points: CHIM vertices, shape = (n_obj, n_obj) (one row per
                reference-direction component, as required by the np.dot below)
            ortho_direction: direction vector tiled to every cell; presumably
                shape = (n_obj,) — confirm against callers
        '''
        num_objectives = boundary_points.shape[1]
        # n convex-combination weights, seeded from the global numpy RNG
        convex_set = get_reference_directions("energy", num_objectives, n, seed=np.random.randint(1e6))
        discretized_chim = np.dot(convex_set, boundary_points)
        ortho_directions = np.tile(ortho_direction, (n, 1))

        # per-cell storage; index k holds the data assigned to line k
        self.ids = [[] for _ in range(n)]
        self.y = [[] for _ in range(n)]
        self.x = [[] for _ in range(n)]
        self.t = [[] for _ in range(n)]
        self.upoint = np.atleast_2d(discretized_chim)  # point on each cell's line
        self.unormal = np.atleast_2d(ortho_directions)  # direction of each cell's line
        self.max_sample_per_cell = 10 # increase to store more points per cell
        # distribute data points to each cell; initial data gets patch id -1
        self.insert_data(xs, ys, (-np.ones(len(xs), dtype=int)).tolist())
        self.update_cell()
    
    @property
    def n_samples(self):
        '''Total number of stored samples over all cells (0 when there are no cells).'''
        # summing lengths avoids np.hstack, which raises on an empty list of cells
        return sum(len(cell_ids) for cell_ids in self.ids)
    
    @property
    def n_cell(self):
        '''Number of cells (one per discretized CHIM point).'''
        cells = self.y
        return len(cells)
    
    @property
    def n_obj(self):
        '''Number of objectives, i.e. the column count of the per-cell line points.'''
        return np.shape(self.upoint)[1]
    
    @property
    def n_empty_cells(self):
        '''Number of cells that currently hold no objective values.'''
        return sum(1 for cell in self.y if len(cell) == 0)
    
    @staticmethod
    def project_point_to_line(_f, _point_on_line, _line_normal):
        _f = np.atleast_2d(_f)
        pa = _f - _point_on_line
        ba = np.atleast_2d(_line_normal)
        t = (pa*ba).sum(1) / (ba*ba).sum(1)
        projected_point = _point_on_line + t.reshape(-1, 1)*ba
        return t, projected_point
    
    def update_cell(self):
        '''
        Sort every non-empty cell by its projection parameter t in descending
        order and keep at most self.max_sample_per_cell entries.
        y, x, t and ids are reordered together and stored back as Python lists.
        '''
        non_empty_cells = np.where([len(ys) > 0 for ys in self.y])[0]
        for cell_idx in non_empty_cells:
            # argsort is ascending: reverse for largest-t-first, then truncate to capacity
            keep = np.argsort(np.array(self.t[cell_idx]))[::-1][:self.max_sample_per_cell]
            self.y[cell_idx] = np.array(self.y[cell_idx])[keep].tolist()
            self.x[cell_idx] = np.array(self.x[cell_idx])[keep].tolist()
            self.t[cell_idx] = np.array(self.t[cell_idx])[keep].tolist()
            self.ids[cell_idx] = np.array(self.ids[cell_idx])[keep].tolist()

    def insert_data(self, xs, ys, ids: list):
        '''
        Assign each (x, y, id) triple to the cell whose line is closest to y,
        as measured by distance_to_line against (self.upoint, self.unormal).
        A design x that is already stored in that cell is skipped.
        Cells are not truncated here; update_cell does that.
        Input:
            xs: designs, one entry per sample
            ys: objective values aligned with xs
            ids: identifiers aligned with xs (__init__ passes -1 for initial data)
        '''
        for x, y, patch_id in zip(xs, ys, ids):
            dist = distance_to_line(y, self.upoint, self.unormal)
            argmin = np.argmin(dist)
            stored_x = self.x[argmin]
            # skip exact duplicates; asarray on both sides keeps the comparison
            # element-wise even when x or the stored entries are plain lists
            if len(stored_x) > 0 and np.any(np.all(np.asarray(stored_x) == np.asarray(x), axis=1)):
                continue
            self.y[argmin].append(y)
            self.x[argmin].append(x)
            self.ids[argmin].append(patch_id)
            ts = ParetoEstimationSet.project_point_to_line(y, self.upoint[argmin], self.unormal[argmin])[0]
            self.t[argmin].append(ts.item())
        ...