import logging
import os
import time
from copy import deepcopy

import numpy as np
import torch
import torch.multiprocessing as tmp
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.utils.multi_objective.box_decompositions import \
    FastNondominatedPartitioning
from botorch.utils.sampling import draw_sobol_samples
from pymoo.indicators.hv import Hypervolume
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
from pymoo.util.ref_dirs import get_reference_directions
from scipy.optimize import minimize
from sklearn.preprocessing import StandardScaler

from _test_functions.objective_function import ObjectiveFunction
from mobo_osd.gp import estimate_mean_and_std, train_gp
from mobo_osd.helper import distance_to_line, get_nadir_ideal_points
from mobo_osd.pareto_estimation import (FirstOrderApproximation,
                                        ParetoEstimationSet)
from mobo_osd.qHVI import qHypervolumeImprovement


class MOBO_OSD:
    def __init__(self, 
                 objective_function: ObjectiveFunction, 
                 n_maxeval: int,
                 batch_size: int = 4,
                 n_init: int = 20,
                 n_beta: int = 20,
                 **kwargs):
        # This class assume minimization
        self.objective_function = objective_function
        self.kwargs = kwargs
        self.batch_size = batch_size
        self.n_init = n_init
        self.n_beta = n_beta
        self.n_maxeval = n_maxeval
        self.delta = 1.96

        # Some checks
        assert np.allclose(objective_function.bounds[:, 0], 0) and np.allclose(objective_function.bounds[:, 1], 1), 'Please scale the input to [0, 1]'

        # Observation history
        self.observed_x = np.zeros((0, self.objective_function.input_dims))     # This is always unit cube
        self.observed_f = np.zeros((0, self.objective_function.num_objectives)) # The true observed f values
        self.observed_f_scaled = np.zeros((0, self.objective_function.num_objectives)) # The scaled observed f values

    @property
    def num_objectives(self):
        return self.objective_function.num_objectives

    @property
    def input_dims(self):
        return self.objective_function.input_dims
    
    def _compute_boundary_points(self):
        ideal_point = np.min(self.observed_f_scaled, axis=0)
        nadir_point = np.max(self.observed_f_scaled, axis=0)
        boundary_points = np.repeat(np.atleast_2d(ideal_point), self.num_objectives, axis=0)
        for i in range(self.num_objectives):
            boundary_points[i][i] = nadir_point[i]
        return boundary_points
    
    def _compute_osd(self, boundary_points):
        ideal_point = np.min(self.observed_f_scaled, axis=0)
        anchor_values_positive = boundary_points - ideal_point
        quasi_normal = -np.dot(anchor_values_positive.T, np.ones((self.num_objectives, 1))).flatten()
        quasi_normal = quasi_normal / np.linalg.norm(quasi_normal)
        return quasi_normal
    
    def _optimize_osb_subprob_each_batch(
            self,
            idx_ref_dir,
            dim, 
            u_beta_points, 
            ortho_search_directions, 
            model_list, 
            observed_x,
            observed_f,
            queue: tmp.Queue,
        ):
        list_x = []
        list_idx = []
        for (idx, u_beta_point, ortho_search_direction) in zip(
            idx_ref_dir,
            u_beta_points,
            ortho_search_directions,
        ):
            x_samples = self._optimize_osb_subprob(
                dim,
                u_beta_point,
                ortho_search_direction,
                model_list,
                observed_x,
                observed_f,
            )
            list_x.append(x_samples)
            list_idx.append((np.ones(len(x_samples), dtype=int)*idx).tolist())
        queue.put((list_idx, list_x))

    def _optimize_osb_subprob(
            self,
            dim, 
            u_beta_point, 
            ortho_search_direction, 
            model_list,
            observed_x,
            observed_f,
        ):
        # Initialize candidates: 
        # 1 closest point to the line, 
        # 1 non-dominated closest point to the line, 
        # 2 random points
        nd_front = NonDominatedSorting().do(observed_f, only_non_dominated_front=True)
        dist = distance_to_line(observed_f, u_beta_point, ortho_search_direction)
        top_k = np.argsort(dist)[0]
        top_k_nd = nd_front[np.argsort(dist[nd_front])[0]]
        if top_k != top_k_nd:
            top_k = np.hstack([top_k, top_k_nd])
        x_trials = deepcopy(observed_x[top_k])
        x_trials = np.vstack([x_trials, np.random.rand(2, dim)])
        # Optimize the subproblem for each initial point
        ress = []
        for x0 in x_trials:
            res = self._minimize_osd_subproblem(
                x0=x0,
                u_beta_point=u_beta_point,
                ortho_search_direction=ortho_search_direction,
                model_list=model_list,
            )
            ress.append(res)
        # Do a bi-objective selection based on lambda value and distance to line (Appendix A.4)
        if len(ress) > 1:
            # Estimate distance to line
            res_xs = np.array([res.x for res in ress])
            temp_fxs, _ = estimate_mean_and_std(res_xs, model_list)
            dist = distance_to_line(temp_fxs, u_beta_point, ortho_search_direction)
            # Construct set of candidates (lambda, distance)
            candidates = [[res.fun, d] for res, d in zip(ress, dist)]
            candidates = np.array(candidates)
            candidates_nadir, candidates_ideal = get_nadir_ideal_points(candidates, minimize=True)
            candidates_refpoint = candidates_nadir + (candidates_nadir - candidates_ideal) * 0.1
            # Compute hypervolume contribution
            hv_sol = Hypervolume(ref_point=candidates_refpoint)
            full_hv = hv_sol(candidates)
            hvc = []
            for i in range(len(candidates)):
                candidates_wo_sol = np.delete(candidates, i, axis=0)
                hv_sol_wo_sol = Hypervolume(ref_point=candidates_refpoint)
                hvc.append(full_hv-hv_sol_wo_sol(candidates_wo_sol))
            hvc = np.array(hvc)
            # Maximize hypervolume contribution
            argmax_hvc = np.argmax(hvc)
            res = ress[argmax_hvc]
        else:
            res = ress[0]
        x_opt = res.x

        # Use a Pareto Front Estimation technique to approximate PF around x_opt
        # In this work, we use First Order Approximation from Schulz et al., 2018
        pfe = FirstOrderApproximation(
            x_opt=x_opt,
            model_list=model_list,
        )    
        x_samples = pfe.do()
        return x_samples

    def _minimize_osd_subproblem(
        self,
        x0,
        u_beta_point,
        ortho_search_direction, 
        model_list,   
    ):
        dim = len(x0)
        bound_np = np.column_stack([np.zeros(dim), np.ones(dim)]) # unit cube

        def fun(x, model_list: ModelListGP, u_beta_point, ortho_search_direction):
            # Main objective of Eq. 2
            x_torch = torch.tensor(np.atleast_2d(x), dtype=torch.double, requires_grad=True)
            posterior = model_list.posterior(x_torch)
            point_to_proj = -posterior.mean
            point_on_line = torch.tensor(np.atleast_2d(u_beta_point), dtype=torch.double)
            line_normal = torch.tensor(np.atleast_2d(ortho_search_direction), dtype=torch.double)
            pa = point_to_proj - point_on_line
            ba = line_normal
            t = (pa*ba).sum(1) / (ba*ba).sum(1)
            d1 = -t
            d1.backward()
            grad = x_torch.grad
            return d1.detach().numpy().item(), grad.detach().numpy().reshape(-1)

        def contraints(x, model_list: ModelListGP, u_beta_point, ortho_search_direction):
            # Constraints of Eq. 2
            x_torch = torch.tensor(np.atleast_2d(x), dtype=torch.double, requires_grad=True)
            with torch.no_grad():
                posterior = model_list.posterior(x_torch)
                point_to_proj = -posterior.mean
            std = posterior.variance.clamp_min(1e-9).sqrt()
            point_on_line = torch.tensor(np.atleast_2d(u_beta_point), dtype=torch.double)
            line_normal = torch.tensor(np.atleast_2d(ortho_search_direction), dtype=torch.double)
            pa = point_to_proj - point_on_line
            ba = line_normal
            t = (pa*ba).sum(1) / (ba*ba).sum(1)

            ucb = point_to_proj + self.delta*std
            lcb = point_to_proj - self.delta*std
            project_point = point_on_line + t.reshape(-1, 1)*ba
            upper_constraint_val = ucb - project_point
            lower_constraint_val = project_point - lcb
            g = torch.hstack((upper_constraint_val, lower_constraint_val))
            return g.detach().numpy().reshape(-1)

        def jac_contraints(x, model_list: ModelListGP, u_beta_point, ortho_search_direction):
            # Jacobian of constraints of Eq. 2
            x_torch = torch.tensor(np.atleast_2d(x), dtype=torch.double, requires_grad=True)
            posterior = model_list.posterior(x_torch)
            point_to_proj = -posterior.mean
            std = posterior.variance.clamp_min(1e-9).sqrt()
            point_on_line = torch.tensor(np.atleast_2d(u_beta_point), dtype=torch.double)
            line_normal = torch.tensor(np.atleast_2d(ortho_search_direction), dtype=torch.double)
            pa = point_to_proj - point_on_line
            ba = line_normal
            t = (pa*ba).sum(1) / (ba*ba).sum(1)
            ucb = point_to_proj + self.delta*std
            lcb = point_to_proj - self.delta*std
            project_point = point_on_line + t.reshape(-1, 1)*ba
            upper_constraint_val = ucb - project_point
            lower_constraint_val = project_point - lcb
            g = torch.hstack((upper_constraint_val, lower_constraint_val)).sum(0)
            grads = []
            gradient_component = torch.zeros(g.shape[0])
            for i in range(g.shape[0]):
                gradient_component[i] = 1.0
                g.backward(gradient_component, retain_graph=True)
                grads.append(x_torch.grad.detach().clone().numpy().flatten())
                x_torch.grad.zero_()
                gradient_component[i] = 0.0
            return np.array(grads)
        
        args = (model_list, np.atleast_2d(u_beta_point), np.atleast_2d(ortho_search_direction))
        constraints_list = [{
            'type': 'ineq',
            'fun': contraints,
            'args': args,
            'jac': jac_contraints,
        }]

        res = minimize(
            fun=fun,
            x0=x0,
            args=args,
            bounds=bound_np,
            constraints=constraints_list,
            method='SLSQP',
            jac=True,
        )

        return res

    def select_by_hvi_kb(
        self,
        candidate_x: np.ndarray,
        original_gp: ModelListGP,
        labels: np.ndarray,
        objective_ref_point: np.ndarray,
    ):
        candidate_x_remain = deepcopy(candidate_x)
        x_vals = []
        f_vals = []
        observed_f_virtual = deepcopy(self.observed_f_scaled)
        observed_x_virtual = deepcopy(self.observed_x)
        for i in range(self.batch_size):
            logging.info(f'\t*** Start batch: {i+1}/{self.batch_size}: ***')
            start_time = time.time()
            assert candidate_x.shape[0] == labels.shape[0]
            if len(x_vals) > 0:
                model_list_virtual = train_gp(observed_x_virtual, observed_f_virtual)
            else:
                model_list_virtual = original_gp
            if len(candidate_x) == 0:
                candidate_x = deepcopy(candidate_x_remain)
            candidate_f = estimate_mean_and_std(candidate_x, model_list_virtual)[0]
            logging.info(f'\tRetrain GP and Estimate mean and std: {time.time() - start_time:.2f}s')
            start_hvc = time.time()
            qehvi_ref_point = -torch.tensor(objective_ref_point, dtype=torch.double)
            partitioning = FastNondominatedPartitioning(
                ref_point=qehvi_ref_point,
                Y=-torch.tensor(observed_f_virtual, dtype=torch.double),
            )
            qHVI = qHypervolumeImprovement(
                model=model_list_virtual,
                ref_point=qehvi_ref_point,
                partitioning=partitioning,
            )
            max_batch_size = 50 if self.num_objectives > 9 else 2048
            X_candidate = torch.tensor(candidate_x, dtype=torch.double)
            hvc = np.zeros(0)
            for xx in X_candidate.split(max_batch_size):
                hvc_i = qHVI(xx.unsqueeze(1)).detach().cpu().numpy()
                hvc = np.concatenate([hvc, hvc_i])
            # hvc = _qHVI(torch.tensor(_candidate_x, dtype=torch.double).unsqueeze(1).to(device=device)).detach().cpu().numpy()
            logging.info(f'\tHVI computation: {time.time() - start_hvc:.2f}s')
            if np.sum(hvc) == 0:
                idx_best = np.random.choice(candidate_f.shape[0])
            else:
                idx_best = np.argmax(hvc)
            label_best = labels[idx_best]
            cand_x_best = candidate_x[idx_best]
            cand_f_best = candidate_f[idx_best]
            x_vals.append(cand_x_best)
            f_vals.append(cand_f_best)
            observed_f_virtual = np.vstack((observed_f_virtual, candidate_f[idx_best]))
            observed_x_virtual = np.vstack((observed_x_virtual, candidate_x[idx_best]))
            idx_to_keep = np.where(labels != label_best)[0] # remove points in the same exploration space
            candidate_x = candidate_x[idx_to_keep]
            candidate_f = candidate_f[idx_to_keep]
            labels = labels[idx_to_keep]
            candidate_x_remain = np.delete(candidate_x_remain, idx_best, axis=0)
            logging.info(f'\tEnd batch {i+1}/{self.batch_size}: {time.time() - start_time:.2f}s')
    
        return np.array(x_vals), np.array(f_vals)

    def run_optimization(self):
        tmp.set_start_method('spawn', force=True)
        start_run = time.time()

        # Draw random initial points
        x_init = draw_sobol_samples(
            bounds=torch.tensor(self.objective_function.bounds.tolist(), dtype=torch.double).t(), 
            n=self.n_init, 
            q=1
        ).squeeze(1).numpy()
        f_init, _ = self.objective_function(x_init)
        self.observed_x = np.vstack((self.observed_x, x_init))
        self.observed_f = np.vstack((self.observed_f, f_init))
        logging.info(f'Initialization: {time.time() - start_run:.2f}s')

        # Main optimization loop
        while len(self.observed_x) < self.n_maxeval:
            start_iter = time.time()
            # scale output to zero mean and unit variance
            y_scaler = StandardScaler()
            y_scaler.fit(self.observed_f)
            self.observed_f_scaled = y_scaler.transform(self.observed_f)
            objective_ref_point = y_scaler.transform(np.atleast_2d(self.objective_function.ref_point)).squeeze()
            
            # Approxiated CHIM via boundary points and orthogonal search directions
            boundary_points = self._compute_boundary_points()
            ortho_search_directions = self._compute_osd(boundary_points)
            convex_points = get_reference_directions("energy", self.num_objectives, self.n_beta, seed=np.random.randint(1e6))
            all_u_beta_points = np.dot(convex_points, boundary_points)
            logging.info(f'Iter {len(self.observed_f)}/{self.n_maxeval} - Approximate CHIM: {time.time() - start_iter:.2f}s')
            check_point = time.time()

            # Train GP models
            model_list = train_gp(self.observed_x, self.observed_f_scaled)
            logging.info(f'Iter {len(self.observed_f)}/{self.n_maxeval} - Train GP: {time.time() - check_point:.2f}s')
            check_point = time.time()

            # Do parallel optimization for mobo-osd-subproblems
            if os.environ.get('SLURM_CPUS_PER_TASK') is not None:
                ncpus = int(os.environ['SLURM_CPUS_PER_TASK'])
            else:
                ncpus = tmp.cpu_count()
            batch_idx_ref_points = np.array_split(np.arange(len(all_u_beta_points)), ncpus)
            batch_u_beta_points = np.array_split(all_u_beta_points, ncpus)
            queue = tmp.Queue()
            process_count = 0
            for idxs_batch_i, u_beta_points_batch_i in zip(batch_idx_ref_points, batch_u_beta_points):
                p = tmp.Process(
                    target=self._optimize_osb_subprob_each_batch, 
                    args=(
                        idxs_batch_i, 
                        self.input_dims, 
                        u_beta_points_batch_i, 
                        np.tile(ortho_search_directions, (len(u_beta_points_batch_i), 1)), # all osd are similar
                        deepcopy(model_list), 
                        self.observed_x, 
                        self.observed_f_scaled, 
                        queue,
                    )
                )
                p.start()
                process_count += 1

            all_candidates = []
            all_idx = []
            for _ in range(process_count):
                idx, x = queue.get()
                all_idx.extend(idx)
                all_candidates.extend(x)
            ...
            x_samples_all = np.vstack(all_candidates)
            y_samples_all = estimate_mean_and_std(x_samples_all, model_list)[0]
            idx_patch = np.hstack(all_idx).tolist()
            logging.info(f'Iter {len(self.observed_f)}/{self.n_maxeval} - Solve MOBO-OSD subproblems: {time.time() - check_point:.2f}s')
            check_point = time.time()
            
            # Update Pareto Front Estimation set
            pfs = ParetoEstimationSet(
                n=100, # increase to store denser points
                xs=self.observed_x, 
                ys=self.observed_f_scaled, 
                boundary_points=boundary_points, 
                ortho_direction=ortho_search_directions
            )
            pfs.insert_data(x_samples_all, y_samples_all, idx_patch)
            pfs.update_cell()
            # Get the Estimated Pareto Set
            approx_x = np.concatenate([pfs_x for pfs_x in pfs.x if len(pfs_x) > 0], axis=0)
            approx_y = np.concatenate([pfs_y for pfs_y in pfs.y if len(pfs_y) > 0], axis=0)
            labels = np.concatenate([pfs_id for pfs_id in pfs.ids if len(pfs_id) > 0], axis=0)
            logging.info(f'Iter {len(self.observed_f)}/{self.n_maxeval} - Pareto Front Estimation: {time.time() - check_point:.2f}s')
            check_point = time.time()

            # Use qHVI to select points from the Estimated Pareto Set
            if len(approx_x) < self.batch_size:
                logging.info(f"{len(self.observed_f_scaled)}/{self.n_maxeval}: Not enough points in pareto set")
                x_val = deepcopy(approx_x)
                n_remaining = self.batch_size - len(approx_x)
                if n_remaining > 0:
                    x_val = np.vstack((x_val, np.random.uniform(0, 1, (n_remaining, self.dim))))
            else:  
                approx_x = np.array(approx_x)[np.array(labels) >= 0] # remove original points
                labels = np.array(labels)[np.array(labels) >= 0]
                x_val, _ = self.select_by_hvi_kb(
                    approx_x,
                    model_list,
                    labels,
                    objective_ref_point,
                )
                pred_means, pred_stds = estimate_mean_and_std(np.atleast_2d(x_val), model_list)
            logging.info(f'Iter {len(self.observed_f)}/{self.n_maxeval} - Aqf HVI: {time.time() - check_point:.2f}s')
            check_point = time.time()

            # Evaluate the selected points, and add to the observation history
            f_val, _ = self.objective_function(np.atleast_2d(x_val))
            self.observed_x = np.vstack((self.observed_x, x_val))
            self.observed_f = np.vstack((self.observed_f, f_val))
            # Compute and print hypervolume (non-scaled objective value)
            hv_pymoo = Hypervolume(ref_point=self.objective_function.ref_point)
            logging.info(f"##### [FINISH] Observed {len(self.observed_f)}/{self.n_maxeval} - HV={hv_pymoo(self.observed_f.squeeze()) :.4f}, iter time: {time.time() - start_iter:.2f}s #####\n")
            ...
        return self.observed_x, self.observed_f