#!/usr/bin/env python
# coding: utf-8
"""
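Utility functions: solving optimization instances in the forward/backward pass (optionally from a cached solution pool), checking solutions, and sampling from the Sum-of-Gamma distribution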
"""

import numpy as np
from pyepo import EPO
from pyepo.utlis import getArgs
from gurobipy import GRB

def _solve_in_pass(cp, optmodel, processes, pool):
    """
    Solve one optimization instance per cost vector in ``cp``.

    Args:
        cp: cost vectors, one per instance
        optmodel: optimization model; its objective is reset with
            ``setObj`` for each instance on the single-process path
        processes (int): number of processors; 1 solves sequentially with
            ``optmodel`` itself, any other value dispatches to ``pool``
        pool: process pool exposing ``amap`` (only used when processes != 1)

    Returns:
        tuple: (solutions, objective values) -- Python lists on the
        single-process path, np.ndarray on the parallel path
    """
    n_instances = len(cp)
    if processes != 1:
        # workers cannot share the model object: ship its class and
        # constructor args so each worker rebuilds its own copy
        cls = type(optmodel)
        model_args = getArgs(optmodel)
        results = pool.amap(_solveWithObj4Par, cp,
                            [model_args] * n_instances,
                            [cls] * n_instances).get()
        solutions = np.array([r[0] for r in results])
        objectives = np.array([r[1] for r in results])
        return solutions, objectives
    # sequential: reuse the single model, swapping the objective each time
    solutions, objectives = [], []
    for k in range(n_instances):
        optmodel.setObj(cp[k])
        s, o = optmodel.solve()
        solutions.append(s)
        objectives.append(o)
    return solutions, objectives

def _cspo_solve_in_pass(cp, selected_models, processes, pool, warm_start):
    """
    A function to solve optimization in the forward/backward pass for CSPO

    Instance ``i`` is solved with its own model ``selected_models[i]`` and
    cost vector ``cp[i]``.

    Args:
        cp: predicted cost vectors, one per instance
        selected_models (list): optimization models, one per instance
        processes (int): number of processors; 1 solves sequentially in this
            process, any other value dispatches to ``pool``
        pool: process pool exposing ``amap`` (only used when processes != 1)
        warm_start (bool): single-core path only; after an optimal solve of a
            model that has a ``_model`` attribute, store the optimal variable
            values as that model's start values for its next solve

    Returns:
        tuple: solutions and objective values (lists on the single-core
        path, np.ndarray on the multi-core path)
    """
    # number of instance
    ins_num = len(cp)
    # single-core
    if processes == 1:
        sol = []
        obj = []
        for i in range(ins_num):
            # solve with this instance's own model
            optmodel = selected_models[i]
            optmodel.setObj(cp[i])
            solp, objp = optmodel.solve()
            sol.append(solp)
            obj.append(objp)
            # warm start: reuse this optimum as the next start point; models
            # without an underlying Gurobi ``_model`` are skipped
            if warm_start:
                grb_model = getattr(optmodel, "_model", None)
                if grb_model is not None and grb_model.status == GRB.OPTIMAL:
                    grb_vars = grb_model.getVars()
                    # read every value before writing any start value
                    values = [v.X for v in grb_vars]
                    # positional assignment: unlike a dict keyed by varName,
                    # this stays correct if variable names are duplicated
                    for v, val in zip(grb_vars, values):
                        v.Start = val
                    grb_model.update()
    # multi-core (warm_start is not applied here)
    else:
        # each worker rebuilds its instance's model from class + args
        model_type = [type(m) for m in selected_models]
        args = [getArgs(m) for m in selected_models]
        # parallel computing
        res = pool.amap(_solveWithObj4Par, cp, args, model_type).get()
        # get res
        sol = np.array([r[0] for r in res])
        obj = np.array([r[1] for r in res])
    return sol, obj


def _cache_in_pass(cp, optmodel, solpool):
    """
    A function to use solution pool in the forward/backward pass

    Instead of calling a solver, each instance picks the pooled solution with
    the best objective value under its cost vector.

    Args:
        cp (np.ndarray): cost vectors, one row per instance
        optmodel: an optimization model, or a sequence of models whose first
            element's ``modelSense`` decides min vs. max
        solpool (np.ndarray): cached solutions, one row per solution

    Returns:
        tuple: best pooled solution per instance (np.ndarray) and its
        objective value (np.ndarray)

    Raises:
        ValueError: if the model sense is neither EPO.MINIMIZE nor
            EPO.MAXIMIZE (previously this surfaced as UnboundLocalError)
    """
    # accept a single model as well as a list of models
    model = optmodel if hasattr(optmodel, "modelSense") else optmodel[0]
    # objective of every pooled solution for every instance
    solpool_obj = cp @ solpool.T
    if model.modelSense == EPO.MINIMIZE:
        ind = np.argmin(solpool_obj, axis=1)
    elif model.modelSense == EPO.MAXIMIZE:
        ind = np.argmax(solpool_obj, axis=1)
    else:
        raise ValueError("Invalid modelSense: {}".format(model.modelSense))
    # objective of the chosen pooled solution, one per instance
    obj = np.take_along_axis(solpool_obj, ind.reshape(-1, 1), axis=1).reshape(-1)
    sol = solpool[ind]
    return sol, obj

# def _cspo_cache_in_pass(cp, selected_models, solpool):
#     """
#     A function to use solution pool in the forward/backward pass
#     """
#     # number of instance
#     ins_num = len(cp)
#     # best solution in pool
#     solpool_obj = cp @ solpool.T
#     if optmodel.modelSense == EPO.MINIMIZE:
#         ind = np.argmin(solpool_obj, axis=1)
#     if optmodel.modelSense == EPO.MAXIMIZE:
#         ind = np.argmax(solpool_obj, axis=1)
#     obj = np.take_along_axis(solpool_obj, ind.reshape(-1,1), axis=1).reshape(-1)
#     sol = solpool[ind]
#     return sol, obj

def _cspo_cache_in_pass(cp, selected_models, processes, pool, warm_start):
    """
    A function to use solution pool in the forward/backward pass for CSPO

    For each instance, if its model carries a non-empty ``solpool``, the
    pooled solution with the best objective under ``cp[i]`` is returned;
    otherwise the model is solved from scratch.

    Args:
        cp (np.ndarray): predicted cost vectors
        selected_models (list): list of optimization models, one per instance
        processes (int): number of processors (unused here; kept for
            signature parity with _cspo_solve_in_pass)
        pool (ProcessPool): process pool object (unused here)
        warm_start (bool): whether to set the chosen pooled solution as the
            start point of the model's ``_model`` (when present)

    Returns:
        tuple: solutions (np.ndarray) and objectives (np.ndarray)
    """
    # number of instances
    ins_num = len(cp)
    sol = []
    obj = []

    for i in range(ins_num):
        optmodel = selected_models[i]
        solpool = getattr(optmodel, "solpool", None)

        # use the pool only when it actually holds solutions; argmin/argmax
        # on an empty pool would raise
        if solpool is not None and len(solpool) > 0:
            # objective of every pooled solution under this cost vector
            solpool_obj = cp[i] @ solpool.T

            # find best solution based on model sense
            if optmodel.modelSense == EPO.MINIMIZE:
                ind = np.argmin(solpool_obj)
            else:  # MAXIMIZE
                ind = np.argmax(solpool_obj)

            best_sol = solpool[ind]
            best_obj = solpool_obj[ind]

            # optionally seed the underlying model with the chosen solution
            if warm_start:
                grb_model = getattr(optmodel, "_model", None)
                if grb_model is not None:
                    # NOTE(review): assumes solpool columns follow the
                    # getVars() order -- confirm against the model builder.
                    # Positional indexing (not a varName-keyed dict) keeps
                    # duplicate variable names from collapsing.
                    for j, v in enumerate(grb_model.getVars()):
                        v.Start = best_sol[j]
                    grb_model.update()
        else:
            # no cached solutions: solve from scratch
            optmodel.setObj(cp[i])
            best_sol, best_obj = optmodel.solve()

        sol.append(best_sol)
        obj.append(best_obj)

    return np.array(sol), np.array(obj)

def _solveWithObj4Par(cost, args, model_type):
    """
    Rebuild a model inside a worker process and solve it for one cost vector.

    Args:
        cost (np.ndarray): cost of objective function
        args (dict): keyword arguments used to construct the model
        model_type (ABCMeta): optModel class to instantiate

    Returns:
        tuple: optimal solution and objective value, as returned by solve()
    """
    worker_model = model_type(**args)
    worker_model.setObj(cost)
    # unpack explicitly so a malformed solve() result fails here
    solution, objective = worker_model.solve()
    return solution, objective


def _check_sol(c, w, z):
    """
    A function to check solution is correct

    For each instance, recompute the objective ``c[i] @ w[i]`` and compare it
    with the reported value ``z[i]`` using a relative tolerance of 1e-3.

    Args:
        c: cost vectors, one per instance
        w: solutions, one per instance
        z: reported objective values; entries may be scalars or length-1 rows

    Raises:
        AssertionError: if any recomputed objective does not match ``z[i]``
    """
    ins_num = len(z)
    for i in range(ins_num):
        actual = c[i] @ w[i]
        # the +1e-3 in the denominator avoids division by zero when z[i] == 0
        if abs(z[i] - actual) / (abs(z[i]) + 1e-3) >= 1e-3:
            # z[i] may be a length-1 row (z of shape (n, 1)) or a scalar (1-D z);
            # indexing a scalar would raise and hide the real failure
            z_val = z[i][0] if np.ndim(z[i]) > 0 else z[i]
            raise AssertionError(
                "Solution {} does not match the objective value {}.".
                format(actual, z_val))


class sumGammaDistribution:
    """
    Sampler for the Sum-of-Gamma distribution.

    Each output element is the sum, over i = 1..n_iterations, of a Gamma
    variate with shape ``1/kappa`` and scale ``kappa/i``, shifted by
    ``-log(n_iterations)`` and divided by ``kappa``.

    Args:
        kappa (float): distribution parameter
        n_iterations (int): number of Gamma terms in the sum
        seed (int): seed of the internal np.random.RandomState
    """
    def __init__(self, kappa, n_iterations=10, seed=135):
        self.κ = kappa
        self.n_iterations = n_iterations
        self.rnd = np.random.RandomState(seed)

    def sample(self, size):
        """
        Draw samples of the given size.

        Args:
            size: output shape, forwarded to RandomState.gamma

        Returns:
            samples with the requested size
        """
        kappa = self.κ
        n_terms = self.n_iterations
        # the generator is consumed in order k = 1..n_terms, so the RNG
        # stream is drawn exactly as a sequential loop would draw it
        total = sum(self.rnd.gamma(1 / kappa, kappa / k, size)
                    for k in range(1, n_terms + 1))
        return (total - np.log(n_terms)) / kappa
