import math
import torch
from scipy.stats import qmc
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
import logging

class SALOptimizer():
    """Choose the next measurement position by multi-start PyGRANSO optimization.

    Each call to :meth:`optimize` runs PyGRANSO ``optimization_iter`` times
    (plus at most one retry after an infeasible run) on the objective built by
    ``ObjectiveFactory``. It keeps the best run whose inequality constraints
    are all satisfied. The constraints are:

    * the box ``[xmin, xmax]``;
    * the user ``constraints`` on the GP's predictive mean / standard deviation;
    * when ``max_dist`` is set, a bound on the distance to the last row of
      ``gp.x``.
    """

    def __init__(self,
                 xmin, xmax, var_in, ObjectiveFactory, constraints,
                 optimization_iter, failed_optimizations_relaxation_iter_bound, failed_optimization_sigma_relaxation,
                 opts, device,
                 max_dist=None,
                 feasibility_tol=1e-4):
        """
        Args:
            xmin, xmax: Lower / upper bounds of the decision variables. They are
                indexed element-wise and ``xmin.shape[0]`` is the dimension, so
                presumably 1-D tensors of equal length.
            var_in: PyGRANSO ``var_spec``; ``var_in['x'][0]`` is used as the
                number of decision variables.
            ObjectiveFactory: Callable ``(gp, t, device) -> objective`` where
                ``objective(x_with_time)`` is the value PyGRANSO minimizes.
            constraints: Sequence of ``c`` with ``c[0]`` an offset and ``c[1]`` /
                ``c[2]`` coefficients applied to the means / standard deviations
                of the GP output (see ``optimization_function``).
            optimization_iter: Number of PyGRANSO runs per ``optimize`` call.
            failed_optimizations_relaxation_iter_bound: Stored only.
            failed_optimization_sigma_relaxation: Stored only.
                NOTE(review): neither of the two values above is read anywhere
                in this class.
            opts: PyGRANSO options struct; ``opts.x0`` is overwritten per run.
            device: Torch device used for random starting points.
            max_dist: Optional upper bound on the distance to the last row of
                ``gp.x`` (free coordinates only).
            feasibility_tol: A run counts as feasible iff every inequality
                constraint value is ``<= feasibility_tol``.
        """
        self.xmin = xmin
        self.xmax = xmax
        self.xdim = self.xmin.shape[0]
        self.var_in = var_in
        self.ObjectiveFactory = ObjectiveFactory
        self.constraints = constraints
        self.optimization_iter = optimization_iter
        self.failed_optimizations_relaxation_iter_bound = failed_optimizations_relaxation_iter_bound
        self.failed_optimization_sigma_relaxation = failed_optimization_sigma_relaxation
        self.opts = opts
        self.device = device
        self.max_dist = max_dist
        self.feasibility_tol = feasibility_tol

        self.logger = logging.getLogger('log_sal')

        # Width of the search box; used to draw uniform random starting points.
        self.xdiff = self.xmax - self.xmin

        # Per-call histories, one entry appended by every optimize() call.
        self.fn_evals = []
        self.iters = []
        self.cis = []
        self.newxs = []
        self.optima = []
        # Lifetime counters across all optimize() calls.
        self.changed_optima = 0
        self.failed_optimization_runs = 0

    def _is_feasible(self, ci):
        """Return True iff every inequality-constraint value in ``ci`` is <= ``feasibility_tol``."""
        return bool(torch.all(ci <= self.feasibility_tol).item())

    @staticmethod
    def _last_query_position(gp):
        """Return the last row of ``gp.x`` restricted to the decision-variable columns.

        Column 0 (presumably time, since ``t`` is prepended to the decision
        variables before calling the GP) and the columns listed in
        ``gp.previous_inputs_list`` are dropped. The remaining column indices
        are sorted so that entry ``i`` lines up with decision variable ``i``.
        Iterating a set directly gives no ordering guarantee.
        """
        excluded = set(gp.previous_inputs_list) | {0}
        free_cols = sorted(set(range(gp.x.shape[1])) - excluded)
        return gp.x[-1, free_cols]

    def optimize(self, t, gp, starting_point=None, max_dist=None):
        """Run the multi-start optimization and return the chosen position.

        Args:
            t: Time value; concatenated in front of the decision variables with
                ``torch.cat``, so it must be a tensor.
            gp: Model used by the objective and the constraints.
            starting_point: Optional initial point for the very first run; a
                1-D tensor is reshaped into a column. All other runs (including
                a retry) start uniformly at random in ``[xmin, xmax)``.
            max_dist: If given, replaces ``self.max_dist`` (persistently).

        Returns:
            The best feasible solution as a 1-D numpy array, or ``None`` when
            no run ended feasible (left for the caller to handle).

        Side effects:
            Overwrites ``self.opts.x0``, updates ``changed_optima`` /
            ``failed_optimization_runs`` and appends one entry to each of
            ``cis``, ``newxs``, ``iters``, ``fn_evals`` and ``optima``.
        """
        function = self.optimization_function(t, gp)

        if max_dist is not None:
            self.max_dist = max_dist

        # find optimized measurement position (take safety into consideration)
        optimum = math.inf
        soln = None
        iters = 0
        fn_evals = 0
        j = 0
        failed_optimizations = 0
        while j < self.optimization_iter:  # do the amount of allowed optimization runs
            self.logger.info(f'start optimization run {j}')
            if j == 0 and failed_optimizations == 0 and starting_point is not None:
                self.opts.x0 = starting_point
            else:
                self.opts.x0 = (self.xdiff * torch.rand((self.xdim), device=self.device, dtype=torch.double) + self.xmin)
            # PyGRANSO expects x0 as a column vector.
            if len(self.opts.x0.shape) == 1:
                self.opts.x0 = self.opts.x0.unsqueeze(1)
            with torch.profiler.record_function("PyGranso"):
                sol = pygranso(var_spec=self.var_in, combined_fn=function, user_opts=self.opts)
            iters += sol.iters
            fn_evals += sol.fn_evals
            self.logger.info("optimization run complete")

            if not self._is_feasible(sol.final.ci):
                self.logger.warning("failed optimization!")
                self.logger.warning("ci: %s", sol.final.ci)
                self.logger.warning("x: %s", sol.final.x)
                self.logger.warning("f: %s", sol.final.f)
                failed_optimizations += 1
                self.failed_optimization_runs += 1
                # Only the first infeasible run of this call is repeated (the
                # run index is not advanced); later failures consume a run.
                if failed_optimizations < 2:
                    j -= 1
            elif sol.final.f < optimum:
                soln = sol
                optimum = soln.final.f
                self.changed_optima += 1
                self.logger.info("found new optimum")
            j += 1

        if soln is None:
            ci = "??ci??"
            newx = None  # This will be dealt with outside
        else:
            ci = soln.final.ci.cpu().numpy()[:, 0]
            newx = soln.final.x.squeeze(1).cpu().numpy()
        self.logger.info("time: %d   opti_iters: %d   fn: %d   opt: %f   ci: %s" % (
            t,
            iters,
            fn_evals,
            optimum,
            ci)
        )
        self.cis.append(ci)
        self.newxs.append(newx)
        self.iters.append(iters)
        self.fn_evals.append(fn_evals)
        self.optima.append(optimum)

        return newx

    def optimization_function(self, t, gp):
        """Build the PyGRANSO ``combined_fn`` for time ``t`` and model ``gp``.

        The returned function maps PyGRANSO's variable struct to
        ``[objective, ci, ce]``:

        * ``ci`` holds the inequality constraints; ``optimize`` treats a value
          ``<= feasibility_tol`` as satisfied.
        * ``ce`` is ``None``, i.e. there are no equality constraints.

        ``self.max_dist`` is read on every evaluation, so the distance
        constraints use whatever value is current when PyGRANSO calls it.
        """
        objective = self.ObjectiveFactory(gp, t, self.device)

        def function(x_struct):
            x_wo_time = x_struct.x
            # The GP and the objective see t prepended to the decision
            # variables, with a leading batch dimension added.
            x_with_time = torch.cat((t, x_wo_time)).unsqueeze(0)
            y = gp(x_with_time)

            obj = objective(x_with_time)

            n_vars = self.var_in['x'][0]
            ci = pygransoStruct()
            # Box constraints: xmin[i] - x[i] <= 0 and x[i] - xmax[i] <= 0.
            for i in range(n_vars):
                setattr(ci, "xmin" + str(i), -(x_wo_time[i] - self.xmin[i]))
                setattr(ci, "xmax" + str(i), +(x_wo_time[i] - self.xmax[i]))

            # User constraints: c[0] + sum_k c[1][k] * y_k.mean
            #                        + sum_k c[2][k] * sqrt(y_k.variance) <= 0,
            # where k runs over the elements of the GP output y.
            for i, c in enumerate(self.constraints):
                res = c[0]
                for a, b in zip(c[1], y):
                    res = res + a * (b.mean)
                for a, b in zip(c[2], y):
                    res = res + a * torch.sqrt(b.variance)
                setattr(ci, "constraint" + str(i), res)

            if self.max_dist is not None:
                x_old = self._last_query_position(gp)
                # NOTE(review): assumes x_wo_time and x_old share the same 1-D
                # shape; an (n, 1) x_wo_time would broadcast to a matrix here.
                c_dist = torch.norm(x_wo_time - x_old) - self.max_dist
                setattr(ci, "distance", c_dist)
                # Inner bounds, redundant but might help optimization: the
                # axis-aligned box of half-width max_dist around x_old.
                for i in range(n_vars):
                    setattr(ci, "xmin_inner_" + str(i), +(-x_wo_time[i] + x_old[i] - self.max_dist))
                    setattr(ci, "xmax_inner_" + str(i), +(x_wo_time[i] - x_old[i] - self.max_dist))

            ce = None

            return [obj, ci, ce]
        return function

    def test(self,x):
        """Evaluate a hard-coded closed-form expression of ``x[0, 1]`` and ``x[0, 2]``.

        NOTE(review): not referenced within this class; presumably a debugging
        helper. ``x`` must be indexable as ``x[0, 1]`` / ``x[0, 2]``.
        """
        return 12.43542745 + 0.5*torch.log(29.18752846*(torch.special.erf(0.6666666667*x[0,2] + 2.666666667) - 1.*torch.special.erf(0.6666666667*x[0,2] - 2.666666667))*torch.special.erf(-2.666666667 + 0.6666666667*x[0,1]) + 1057.381136 - 29.18752846*(torch.special.erf(0.6666666667*x[0,2] + 2.666666667) - 1.*torch.special.erf(0.6666666667*x[0,2] - 2.666666667))*torch.special.erf(2.666666667 + 0.6666666667*x[0,1]))
    
