import math
import copy
import numpy as np
import torch
import gpytorch
from pygranso.pygranso import pygranso
from pygranso.pygransoStruct import pygransoStruct
from pygranso.private.getNvar import getNvarTorch
import logging
from itertools import chain
# from torch.profiler import profile, record_function, ProfilerActivity, schedule

def get_string_representation_of_kernel(kernel_expression):
    """Return a compact textual form of a (possibly composite) gpytorch kernel.

    Sums and products are parenthesised and joined with `` + `` / `` * ``;
    a ``ScaleKernel`` becomes ``(c * <base>)``; known base kernels map to short
    tokens (``SE``, ``LIN``, ``PER``, ``MAT32``, ``MAT52``, ``RQ``).  Any other
    kernel is represented by its class name as returned by ``_get_name()``.

    Args:
        kernel_expression: kernel object exposing ``_get_name()`` and, for
            composite kernels, ``kernels`` / ``base_kernel`` / ``nu``.

    Returns:
        The string representation, e.g. ``"(c * (SE + LIN))"``.

    Raises:
        ValueError: for a ``MaternKernel`` whose ``nu`` is neither 1.5 nor 2.5.
    """
    name = kernel_expression._get_name()

    if name in ("AdditiveKernel", "ProductKernel"):
        separator = " + " if name == "AdditiveKernel" else " * "
        parts = (get_string_representation_of_kernel(k) for k in kernel_expression.kernels)
        return "(" + separator.join(parts) + ")"

    if name == "ScaleKernel":
        return f"(c * {get_string_representation_of_kernel(kernel_expression.base_kernel)})"

    if name == "MaternKernel":
        nu = kernel_expression.nu
        if nu == 1.5:
            return "MAT32"
        if nu == 2.5:
            return "MAT52"
        # Previously `raise "shit"`, which itself fails with a TypeError.
        raise ValueError(f"Unsupported MaternKernel smoothness nu={nu}; only 1.5 and 2.5 are supported")

    leaf_tokens = {
        "RBFKernel": "SE",
        "LinearKernel": "LIN",
        "PeriodicKernel": "PER",
        "RQKernel": "RQ",
    }
    # Unknown kernels fall back to their class name.
    return leaf_tokens.get(name, name)

def log_normalized_prior(model, device, prior_dict):
    """Return the log density of an independent Gaussian prior on the raw hyperparameters.

    The raw parameters of ``model`` are flattened in ``named_parameters()``
    order and concatenated.  Each one is paired positionally (via ``zip``, so
    the shorter sequence wins) with a label list
    ``["LIKELIHOOD", "mean_module", <kernel tokens>...]``.  The kernel tokens
    come from ``get_string_representation_of_kernel`` with parentheses and
    spaces stripped, split on ``+`` and ``*``.

    Prior lookup per parameter:
      * names containing ``"likelihood"`` -> ``prior_dict["noise"]["raw_noise"]`` (one entry);
      * names containing ``"mean"``       -> ``prior_dict["mean"]["raw_constant"]`` (one entry);
      * otherwise -> ``prior_dict[<kernel token>][<last component of the name>]``.
        A scalar ``mean`` is repeated for every element of the parameter; a
        sequence ``mean`` / ``std`` is appended as-is.

    ``std`` values are standard deviations; the covariance is ``diag(std)**2``.

    Args:
        model: gpytorch model with ``covar_module`` and ``train_inputs``.
        device: device on which the prior tensors are created.
        prior_dict: nested dict as described above, or ``None`` for the defaults below.

    Returns:
        ``log_prob(params) / number_of_training_points``, a tensor of shape ``(1,)``.

    Raises:
        KeyError: if no prior entry exists for a kernel parameter.
    """
    if prior_dict is None:
        prior_dict = {'SE': {'raw_lengthscale' : {"mean": -0.5 , "std":1.0}},
                    'c':{'raw_outputscale':{"mean": 0.5, "std": 0.1 } },
                    'noise': {'raw_noise':{"mean": -3.0, "std": 1.0 } },
                    'mean': {'raw_constant':{"mean": 10, "std": 0.01 } }
                    }

    covar_string = get_string_representation_of_kernel(model.covar_module)
    for char in "() ":
        covar_string = covar_string.replace(char, "")
    # Positional labels; the assumed parameter order is likelihood, mean, then
    # one parameter per kernel token.
    covar_string_list = ["LIKELIHOOD", "mean_module"] + list(
        chain.from_iterable(term.split("*") for term in covar_string.split("+"))
    )

    flat_params = []
    theta_mu = []
    stds = []
    for (param_name, param), cov_str in zip(model.named_parameters(), covar_string_list):
        # Flatten every parameter (0-d, 1-d or (1, d) ARD shapes) to 1-d.
        flat_params.append(param.reshape(-1))
        if "likelihood" in param_name:
            theta_mu.append(prior_dict["noise"]["raw_noise"]["mean"])
            stds.append(prior_dict["noise"]["raw_noise"]["std"])
        elif "mean" in param_name:
            theta_mu.append(prior_dict["mean"]["raw_constant"]["mean"])
            stds.append(prior_dict["mean"]["raw_constant"]["std"])
        else:
            key = param_name.split(".")[-1]
            try:
                dict_entry = prior_dict[cov_str][key]
            except KeyError as err:
                # Previously dropped into pdb.set_trace(), hanging batch runs.
                raise KeyError(
                    f"no prior for parameter {param_name!r}: expected "
                    f"prior_dict[{cov_str!r}][{key!r}]"
                ) from err
            if isinstance(dict_entry["mean"], (int, float)):
                # Scalar prior: share it across every element of the parameter.
                entries = param.numel()
                theta_mu.extend([dict_entry["mean"]] * entries)
                stds.extend([dict_entry["std"]] * entries)
            else:
                # Per-element prior given as sequences.
                theta_mu.extend(dict_entry["mean"])
                stds.extend(dict_entry["std"])

    params = torch.cat(flat_params)
    # Shape (1, k) location -> log_prob returns shape (1,).
    theta_mu = torch.tensor(theta_mu, device=device).unsqueeze(0)
    std = torch.diag(torch.tensor(stds, device=device))
    covariance = std @ std
    prior = torch.distributions.MultivariateNormal(theta_mu, covariance)

    # for convention reasons the log prior is divided by the number of datapoints
    return prior.log_prob(params) / model.train_inputs[0].shape[0]

class ConstantKernel(gpytorch.kernels.Kernel):
    """Kernel whose covariance equals 1 for every pair of inputs.

    The output has the inputs' dtype and device.  It honours gpytorch's
    ``diag`` flag, returning a vector instead of a matrix, and broadcasts
    leading batch dimensions of ``x1`` / ``x2``.
    """

    is_stationary = True

    def forward(self, x1, x2, diag=False, **params):
        """Return ones of shape ``(*batch, n)`` if ``diag`` else ``(*batch, n, m)``.

        Args:
            x1: inputs of shape ``(*batch, n, d)``.
            x2: inputs of shape ``(*batch, m, d)``.
            diag: if True, return only the diagonal (gpytorch checks this shape).
            **params: further gpytorch kernel options, ignored here.
        """
        batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
        if diag:
            return torch.ones(*batch_shape, x1.shape[-2], dtype=x1.dtype, device=x1.device)
        return torch.ones(*batch_shape, x1.shape[-2], x2.shape[-2], dtype=x1.dtype, device=x1.device)

class GPFullSEARD(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled ARD squared-exponential kernel.

    The RBF kernel receives one lengthscale per input column
    (``train_x.shape[1]``), i.e. every input dimension is used.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Device of the training inputs, exposed as an attribute for callers.
        self.device = train_x.device
        n_dims = train_x.shape[1]
        self.mean_module = gpytorch.means.ConstantMean()
        se_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=n_dims)
        self.covar_module = gpytorch.kernels.ScaleKernel(se_kernel)

    def forward(self, x):
        """Return the GP prior ``N(mean(x), K(x, x))`` at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

class GPTimeConstantSEARD(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled ARD SE kernel that skips column 0.

    The kernel only sees input columns ``1 .. d-1`` (``active_dims``), so the
    first column does not affect the covariance.  Presumably that column is
    time, as the class name suggests; confirm with the data layout.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Device of the training inputs, exposed as an attribute for callers.
        self.device = train_x.device
        self.mean_module = gpytorch.means.ConstantMean()
        n_cols = train_x.shape[1]
        se_kernel = gpytorch.kernels.RBFKernel(
            ard_num_dims=n_cols - 1,
            active_dims=range(1, n_cols),
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(se_kernel)

    def forward(self, x):
        """Return the GP prior ``N(mean(x), K(x, x))`` at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

class SALGP():
    """Collection of independent single-output exact GPs, one per column of ``y``.

    Besides fitting (Adam or PyGRANSO) and prediction, the class provides
    closed-form acquisition criteria evaluated on the GP ``models[sal_crit]``:
    the predictive variance averaged over a uniform box or an axis-aligned
    Gaussian after hypothetically adding a candidate point (IMSPE), and the
    negative Gaussian entropy at a candidate point.

    Inputs may be supplied without the columns listed in
    ``previous_inputs_list``; these are filled in by the
    ``ExtendTensorBy...`` helpers.
    """

    def __init__(self, x, y, model=None, sal_crit=0, prior_dict=None, previous_inputs_list=None):
        """Create one GP per output column and run a short Adam warm-up.

        Args:
            x: training inputs, indexed as ``x[:, j]``.
            y: training targets; column ``i`` is the target of GP ``i``.
            model: GP class name, ``"GPTimeConstantSEARD"`` or ``"GPFullSEARD"``;
                anything else makes ``get_gp`` raise ``NotImplementedError``.
            sal_crit: index of the GP used by the acquisition criteria.
            prior_dict: hyperparameter prior for ``log_normalized_prior``
                (``None`` selects its defaults).
            previous_inputs_list: input column indices that callers may omit;
                ``None`` means an empty list.
        """
        self.x = x
        self.y = y
        self.numoutputs = y.shape[1]
        self.numinputs = x.shape[1]
        self.device = self.x.device
        self.prior_dict = prior_dict
        # Fresh list per instance: a shared mutable default would leak
        # between instances.
        self.previous_inputs_list = [] if previous_inputs_list is None else previous_inputs_list

        self.likelihoods = [gpytorch.likelihoods.GaussianLikelihood().to(self.device) for _ in range(self.numoutputs)]
        self.precomputation_of_matrices_for_marginalization = False
        self.sal_crit = sal_crit

        self.logger = logging.getLogger('log_sal')

        self.model = model
        self.models = [self.get_gp(i) for i in range(self.y.shape[1])]

        for (gp, likelihood) in zip(self.models, self.likelihoods):
            gp.eval()
            likelihood.eval()

        # History of state_dicts, one list entry per training call.
        self.old_models = []

        self.train_adam(training_iter=10, lr=0.1)

    def get_gp(self, i):
        """Build the GP for output ``i`` and draw its raw parameters from N(0, 2**2).

        Raises:
            NotImplementedError: if ``self.model`` is not a known class name.
        """
        if self.model == "GPTimeConstantSEARD":
            result = GPTimeConstantSEARD(self.x, self.y[:,i], self.likelihoods[i]).to(self.device)
        elif self.model == "GPFullSEARD":
            result = GPFullSEARD(self.x, self.y[:,i], self.likelihoods[i]).to(self.device)
        else:
            raise NotImplementedError("This GP Model is not known")
        for param_name, param in result.named_parameters():
            torch.nn.init.normal_(param, mean=0, std=2)
        return result

    def __call__(self, xx):
        """Evaluate every GP at ``xx`` and return the list of outputs.

        If ``xx`` lacks exactly the columns in ``previous_inputs_list``, they are
        filled in.  A single row takes them from the last training point, and
        several rows duplicate the preceding column.

        Raises:
            IndexError: if the column count matches neither layout.
        """
        x = xx
        if self.numinputs != x.shape[1]:
            if x.shape[1] + len(self.previous_inputs_list) == self.numinputs:
                if x.shape[0] == 1:
                    x = self.ExtendTensorByPreviousInputsListAndTrainingData(x)
                else:
                    x = self.ExtendTensorByPreviousInputsList(x)
            else:
                raise IndexError("wrong number of inputs in GP prediction")
        x = x.to(self.models[0].device)
        with torch.profiler.record_function("GP prediction"):
            return [model(x) for model in self.models]

    def set_train_data(self, x, y):
        """Replace the training data (hyperparameters unchanged) and refresh ``model.NLL``.

        The NLL is computed exactly like the training objective: the marginal
        likelihood in train mode, minus the normalized log prior.  The original
        evaluated it in eval mode, which gives the posterior predictive at the
        training points instead.
        """
        self.x = x
        self.y = y
        self.precomputation_of_matrices_for_marginalization = False
        for i, model in enumerate(self.models):
            likelihood = self.likelihoods[i]
            model.set_train_data(self.x, self.y[:,i], strict=False)
            mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
            model.train()
            likelihood.train()
            with torch.no_grad():
                model.NLL = self._negative_log_posterior(model, mll).item()
            model.eval()
            likelihood.eval()

        return self

    def NLL_eval(self, x, y, t):
        """Return, per GP, the negative eval-mode log likelihood of ``y`` at inputs ``[t, x]``.

        A column filled with ``t`` (float64) is prepended to ``x``, and the
        omitted previous-input columns are filled in before evaluation.
        """
        result = []
        tt = t * torch.ones(x.shape[0], 1, device=self.device, dtype=torch.float64)
        tx = torch.concatenate((tt,x.to(self.device)),axis=1)
        tx = self.ExtendTensorByPreviousInputsList(tx)
        for i, model in enumerate(self.models):
            model.eval()
            likelihood = self.likelihoods[i]
            likelihood.eval()
            mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
            output = model(tx)
            # as_tensor avoids torch.tensor's copy-construct warning when y is a tensor.
            loss = -mll(output, torch.as_tensor(y[:,i], device=self.device))
            result.append(loss.detach().item())
        return result

    def _negative_log_posterior(self, model, mll):
        """Return ``-mll(model(train_inputs), train_targets) - log_normalized_prior(model)``.

        2-D targets are squeezed along dim 1.  This is the shared objective for
        Adam, PyGRANSO and ``set_train_data``.
        """
        output = model(model.train_inputs[0])
        targets = model.train_targets.detach()
        if len(targets.shape) > 1:
            targets = targets.squeeze(1)
        loss = -mll(output, targets)
        log_p = log_normalized_prior(model, device=self.device, prior_dict=self.prior_dict)
        return loss - log_p

    def train_single_model_adam(self, model, likelihood, training_iter=500, lr=0.1):
        """Minimise the negative log posterior of one GP with Adam; leave it in eval mode."""

        self.precomputation_of_matrices_for_marginalization = False

        model.train()
        likelihood.train()
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Includes GaussianLikelihood parameters

        for _ in range(training_iter):
            optimizer.zero_grad()
            loss = self._negative_log_posterior(model, mll)
            # NLL of the parameters *before* this step.
            model.NLL = loss.detach().item()
            loss.backward()
            optimizer.step()
        model.eval()
        likelihood.eval()
        self.logger.info(f"loss: {model.NLL}")
        return model

    def train_adam(self, training_iter=500, lr=0.1):
        """Train every GP with Adam, log the parameters and snapshot the state_dicts."""

        self.precomputation_of_matrices_for_marginalization = False

        for (model, likelihood) in zip(self.models, self.likelihoods):
            self.train_single_model_adam(model, likelihood, training_iter, lr)

        self.print_training_parameters()
        self.old_models.append(copy.deepcopy([model.state_dict() for model in self.models]))

        return self

    def train_single_model_pygranso(self, model, likelihood):
        """Minimise the negative log posterior of one GP with PyGRANSO (unconstrained).

        PyGRANSO updates ``model``'s parameters in place through ``var_spec``.
        NOTE(review): confirm whether the parameters after ``pygranso`` returns
        equal ``soln.final.x`` or only the last evaluated point.
        """

        self.precomputation_of_matrices_for_marginalization = False

        model.train()
        likelihood.train()
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

        opts = pygransoStruct()
        opts.torch_device = self.device
        nvar = getNvarTorch(model.parameters())
        opts.x0 = torch.nn.utils.parameters_to_vector(model.parameters()).detach().reshape(nvar,1)
        opts.opt_tol = 1e-5
        opts.limited_mem_size = 100
        opts.globalAD = True
        opts.quadprog_info_msg = False
        opts.print_level = 0
        opts.halt_on_linesearch_bracket = False
        opts.maxit = 200 # default: 1000

        def user_fn(model):
            # objective function; no inequality (ci) or equality (ce) constraints
            loss = self._negative_log_posterior(model, mll)
            model.NLL = loss.detach().item()
            return [loss, None, None]

        soln = pygranso(var_spec=model, combined_fn=user_fn, user_opts=opts)

        self.logger.info('Eval %d/%d - Loss: %.3f - Loss %.3f' % (
            soln.iters, soln.fn_evals, soln.final.f, model.NLL
        ))

        model.eval()
        likelihood.eval()

        return model

    def train_pygranso(self):
        """Train every GP with PyGRANSO, falling back to Adam when PyGRANSO raises."""

        self.precomputation_of_matrices_for_marginalization = False

        for (model, likelihood) in zip(self.models, self.likelihoods):
            try:
                self.train_single_model_pygranso(model, likelihood)
            except Exception as e:
                # Lazy %-formatting. The old warn("...", e) call had no
                # placeholder and tripped logging's formatting error. No bare
                # except: KeyboardInterrupt/SystemExit must propagate.
                self.logger.warning("train pygranso exception: %s", e)
                print(e)
                self.train_single_model_adam(model, likelihood)

        self.print_training_parameters()
        self.old_models.append(copy.deepcopy([model.state_dict() for model in self.models]))

        return self

    def print_training_parameters(self):
        """Log every (constraint-transformed) parameter and its elementwise log."""
        for model in self.models:
            for param_name, param, constraint in model.named_parameters_and_constraints():
                if constraint is not None:
                    value = constraint.transform(param)
                else:
                    value = param
                value = value.cpu().detach().numpy()
                if len(value.shape) == 2:
                    value = value[0,:]
                # Unconstrained parameters can be <= 0; their log is just shown as nan/-inf.
                with np.errstate(divide="ignore", invalid="ignore"):
                    log_value = np.log(value)
                self.logger.info(f'Parameter name: {param_name:60} value = {str(value):30} log_value = {str(log_value):30}')

    # cutoff connection to optimizing hyperparameters
    @torch.no_grad()
    def PrecomputeForMarginalization(self):
        """Cache K (with noise), its Cholesky factor, K^-1, sigma and lengthscales of ``models[sal_crit]``.

        The cache is rebuilt only after a training/data change resets the flag.
        """
        if not self.precomputation_of_matrices_for_marginalization:
            gp = self.models[self.sal_crit]
            self.K = gp.covar_module(gp.train_inputs[0]).to_dense() + gp.likelihood.noise * torch.eye(gp.train_inputs[0].shape[0], device=self.device)
            self.Cholesky = torch.linalg.cholesky_ex(self.K)[0]
            self.Kinv = torch.cholesky_inverse(self.Cholesky)
            self.sigma = torch.sqrt(gp.covar_module.outputscale)
            self.lengthscales = gp.covar_module.base_kernel.lengthscale
            self.precomputation_of_matrices_for_marginalization = True

    def ExtendTensorByPreviousInputsList(self, a):
        """Insert the omitted columns, each a copy of the column directly before it.

        Note: column 0 cannot be in ``previous_inputs_list`` (``res[-1]`` would be empty).
        """
        res = []
        j = 0
        for i in range(self.numinputs):
            if i in self.previous_inputs_list:
                res.append(res[-1])
            else:
                res.append(a[:,j:(j+1)].to(self.device))
                j = j + 1
        return torch.cat(tuple(res), axis=1)

    def ExtendTensorByPreviousInputsListAndTrainingData(self, a):
        """Insert each omitted column ``i`` with the last training point's value ``x[-1, i-1]``.

        The value is detached and repeated for every row of ``a``.
        """
        res = []
        j = 0
        for i in range(self.numinputs):
            if i in self.previous_inputs_list:
                res.append(self.x[-1, i-1].detach().to(self.device).expand(a.shape[0], 1))
            else:
                res.append(a[:,j:(j+1)].to(self.device))
                j = j + 1
        return torch.cat(tuple(res), axis=1)

    def IMSPEMarginalizedOverUniformDistribution(self, xxs, a, b):
        """Predictive variance of ``models[sal_crit]`` averaged over the box [b, a], after adding ``xxs``.

        Uses the block inverse of the covariance augmented by the candidate
        ``xs``.  Only the last ``d`` (lengthscale) dimensions are integrated.
        Gradients w.r.t. ``xxs`` are tracked only if ``xxs.requires_grad``.
        The caller's autograd mode is restored on return.
        """
        assert not a.requires_grad
        assert not b.requires_grad

        grad = xxs.requires_grad
        if xxs.shape[1] + len(self.previous_inputs_list) == self.numinputs:
            xs = self.ExtendTensorByPreviousInputsListAndTrainingData(xxs)
        else:
            assert not xxs.requires_grad
            xs = xxs

        # Context manager: restores the previous grad mode on exit/exception.
        # (The old code forced grad mode to True afterwards.)
        with torch.set_grad_enabled(grad and torch.is_grad_enabled()):
            self.PrecomputeForMarginalization()
            gp = self.models[self.sal_crit]

            d = self.lengthscales.shape[-1]

            lengthscales = self.lengthscales.reshape([1,1,d])
            xdiff = (gp.train_inputs[0].unsqueeze(1)-gp.train_inputs[0].unsqueeze(0))[:,:,-d:]
            xsum = (gp.train_inputs[0].unsqueeze(1)+gp.train_inputs[0].unsqueeze(0))[:,:,-d:]
            xsdiff = (xs-gp.train_inputs[0])[:,-d:]
            xssum = (xs+gp.train_inputs[0])[:,-d:]
            aa = self.ExtendTensorByPreviousInputsList(a.unsqueeze(0))[:,-d:]
            bb = self.ExtendTensorByPreviousInputsList(b.unsqueeze(0))[:,-d:]

            Ks = gp.covar_module(xs,gp.train_inputs[0]).to_dense()
            kappa = torch.cholesky_solve(torch.transpose(Ks,0,1),self.Cholesky)
            Ss = gp.covar_module(xs,xs).to_dense()-torch.matmul(Ks,kappa)
            Ssinv = 1/Ss

            SqrtPiToTheD = torch.pow(torch.tensor(torch.pi, device=self.device), 0.5*d)
            xs_short = xs[:,-d:]

            result = self.sigma**2

            # -add(add((K1[i,j]+Ss1*kappa[i]*kappa[j])*sigma^4*sqrt(Pi)^d*product(l[h]*exp(-(x[j][h]-x[i][h])^2/4/l[h]^2)*(erf((2*a[h]-x[i][h]-x[j][h])/(2*l[h]))-erf((2*b[h]-x[i][h]-x[j][h])/(2*l[h])))/(2*a[h]-2*b[h]),h=1..d),j=1..n),i=1..n)
            # K1[i,j]+Ss1*kappa[i]*kappa[j])
            Mat1 = self.Kinv + torch.matmul(torch.matmul(kappa, Ssinv), torch.transpose(kappa,0,1))
            # exp(-(x[j][h]-x[i][h])^2/4/l[h]^2)
            Mat2 = torch.exp(-0.25*(xdiff/lengthscales)**2)
            # (erf((2*a[h]-x[i][h]-x[j][h])/(2*l[h]))-erf((2*b[h]-x[i][h]-x[j][h])/(2*l[h])))/(2*a[h]-2*b[h])
            Mat3 = (torch.erf(0.5*(2*aa-xsum)/lengthscales)-torch.erf(0.5*(2*bb-xsum)/lengthscales))/(2*(aa-bb))
            # sigma^4*sqrt(Pi)^d*product(l[h])
            factor = self.sigma**4 * SqrtPiToTheD * torch.prod(self.lengthscales)
            result -= factor * torch.sum(Mat1*torch.prod(Mat2,2)*torch.prod(Mat3,2))

            #+2*add(kappa[i]*Ss1*sigma^4*sqrt(Pi)^d*product(l[h]*exp(-(xs[h]-x[i][h])^2/4/l[h]^2)*(erf((2*a[h]-xs[h]-x[i][h])/(2*l[h]))-erf((2*b[h]-xs[h]-x[i][h])/(2*l[h])))/(2*a[h]-2*b[h]),h=1..d),i=1..n)
            # exp(-(xs[h]-x[i][h])^2/4/l[h]^2)
            Vec1 = torch.exp(-0.25*(xsdiff/lengthscales)**2)
            # (erf((2*a[h]-xs[h]-x[i][h])/(2*l[h]))-erf((2*b[h]-xs[h]-x[i][h])/(2*l[h])))/(2*a[h]-2*b[h])
            Vec2 = (torch.erf(0.5*(2*aa-xssum)/lengthscales)-torch.erf(0.5*(2*bb-xssum)/lengthscales))/(2*(aa-bb))
            result += 2 * factor * Ssinv[0,0] * torch.sum(kappa.squeeze(1) * (torch.prod(Vec1, 2) * torch.prod(Vec2, 2)).squeeze(0))

            #-Ss1*sigma^4*sqrt(Pi)^d*product(l[h]*(erf((a[h]-xs[h])/l[h])-erf((b[h]-xs[h])/l[h]))/(2*a[h]-2*b[h]),h=1..d)
            result -= factor * Ssinv[0,0] * torch.prod( (torch.erf(((aa-xs_short)/lengthscales))-torch.erf(((bb-xs_short)/lengthscales))) / (2*(aa-bb)) )

        return result

    def IMSPEMarginalizedOverGaussianDistribution(self, xxs, mm, ss):
        """Predictive variance averaged over N(mm, diag(ss**2)), after adding ``xxs``.

        Same structure as the uniform variant, with Gaussian integrals over the
        last ``d`` dimensions.  The caller's autograd mode is restored on return.
        """
        assert not mm.requires_grad
        assert not ss.requires_grad

        grad = xxs.requires_grad
        if xxs.shape[1] + len(self.previous_inputs_list) == self.numinputs:
            xs = self.ExtendTensorByPreviousInputsListAndTrainingData(xxs)
        else:
            assert not xxs.requires_grad
            xs = xxs

        # Context manager: restores the previous grad mode on exit/exception.
        with torch.set_grad_enabled(grad and torch.is_grad_enabled()):
            self.PrecomputeForMarginalization()
            gp = self.models[self.sal_crit]

            d = self.lengthscales.shape[-1]

            m = self.ExtendTensorByPreviousInputsList(mm.unsqueeze(0))[:,-d:]
            s = self.ExtendTensorByPreviousInputsList(ss.unsqueeze(0))[:,-d:]

            lengthscales = self.lengthscales.reshape([1,1,d])
            x = gp.train_inputs[0]

            Ks = gp.covar_module(xs,x).to_dense()
            kappa = torch.cholesky_solve(torch.transpose(Ks,0,1),self.Cholesky)
            Ss = gp.covar_module(xs,xs).to_dense()-torch.matmul(Ks,kappa)
            Ssinv = 1/Ss

            x = x[:,-d:]
            m = m[-d:]
            s = s[-d:]
            xs = xs[:,-d:]

            result = self.sigma**2

            # -add(add((K1[i,j]+Ss1*kappa[i]*kappa[j])*sigma^4*product(l[h]/sqrt(l[h]^2+2*s[h]^2)*exp((-l[h]^2*(x[j][h]-m[h])^2-l[h]^2*(x[i][h]-m[h])^2-s[h]^2*(x[i][h]-x[j][h])^2)/(2*l[h]^2*(l[h]^2+2*s[h]^2))),h=1..d),j=1..n),i=1..n)
            # K1[i,j]+Ss1*kappa[i]*kappa[j])
            Mat1 = self.Kinv + torch.matmul(torch.matmul(kappa, Ssinv), torch.transpose(kappa,0,1))
            # exp((-l[h]^2*(x[j][h]-m[h])^2-l[h]^2*(x[i][h]-m[h])^2-s[h]^2*(x[i][h]-x[j][h])^2)/(2*l[h]^2*(l[h]^2+2*s[h]^2))
            Mat2 = torch.exp(-0.5*((((x-m)**2).unsqueeze(1)+((x-m)**2).unsqueeze(0))/(lengthscales**2+2*s.unsqueeze(0)**2)+s.unsqueeze(0)**2*(x.unsqueeze(1)-x.unsqueeze(0))**2/(lengthscales**2*(lengthscales**2+2*s.unsqueeze(0)**2))))
            # sigma^4*product(l[h]/sqrt(l[h]^2+2*s[h]^2))
            factor = self.sigma**4 * torch.prod(self.lengthscales/torch.sqrt(self.lengthscales**2+2*s**2))
            result -= factor * torch.sum(Mat1*torch.prod(Mat2,2))

            # +2*add(kappa[i]*Ss1*sigma^4*product(l[h]/sqrt(l[h]^2+2*s[h]^2)*exp((-l[h]^2*(xs[h]-m[h])^2-l[h]^2*(x[i][h]-m[h])^2-s[h]^2*(x[i][h]-xs[h])^2)/(2*l[h]^2*(l[h]^2+2*s[h]^2))),h=1..d),i=1..n)
            # exp((-l[h]^2*(xs[h]-m[h])^2-l[h]^2*(x[i][h]-m[h])^2-s[h]^2*(x[i][h]-xs[h])^2)/(2*l[h]^2*(l[h]^2+2*s[h]^2)))
            Vec = torch.exp(-0.5*((((xs-m)**2)+(x-m)**2)/(lengthscales**2+2*s.unsqueeze(0)**2)+(s**2*(xs-x)**2)/(lengthscales**2*(lengthscales**2+2*s.unsqueeze(0)**2))))
            result += 2 * factor * Ssinv[0,0] * torch.sum(kappa.squeeze(1) * torch.prod(Vec, 2).squeeze(0))

            # -Ss1*sigma^4*product(l[h]/sqrt(l[h]^2+2*s[h]^2)*exp(-(m[h]-xs[h])^2/(l[h]^2+2*s[h]^2)),h=1..d)
            result -= factor * Ssinv[0,0] * torch.prod(torch.exp(-(m-xs)**2/(lengthscales**2+2*s.unsqueeze(0)**2)))

        return result

    def NegEntropy(self, xxs):
        """Return ``-(0.5*log(2*pi*var) + 0.5)`` for the latent posterior variance at ``xxs``.

        The variance is ``k(xs, xs) - k(xs, X) K^-1 k(X, xs)`` for ``models[sal_crit]``.
        """
        if xxs.shape[1] + len(self.previous_inputs_list) == self.numinputs:
            xs = self.ExtendTensorByPreviousInputsListAndTrainingData(xxs)
        else:
            xs = xxs

        self.PrecomputeForMarginalization()
        gp = self.models[self.sal_crit]
        x = gp.train_inputs[0]

        Ks = gp.covar_module(xs,x).to_dense()
        kappa = torch.cholesky_solve(torch.transpose(Ks,0,1),self.Cholesky)
        sigma2 = gp.covar_module(xs,xs).to_dense()-torch.matmul(Ks,kappa)
        return -(0.5 * torch.log(2*torch.pi*sigma2) + 0.5)