import math
import torch
import gpytorch
from matplotlib import pyplot as plt
# from evaluate import load
# bertscore = load("bertscore")
# predictions = ["hello there", "general kenobi"]
# references = ["hello there", "general kenobi"]
# results = bertscore.compute(predictions=predictions, references=references, lang="en")

# from evaluator import USE, SentenceEncoder
# use = SentenceEncoder() 
# use_score = use.get_sim(sentence, candidate_sent) 



class CustomKernel(gpytorch.kernels.kernel.Kernel):
    """Scaled squared-exponential kernel with two learnable positive hyperparameters.

    k(x1, x2) = hyperparameter2 * exp(-||x1 / hyperparameter1 - x2 / hyperparameter1||^2 / 2)

    ``hyperparameter1`` acts as a lengthscale (inputs are divided by it) and
    ``hyperparameter2`` as an output scale (the covariance is multiplied by it).
    Each is stored as an unconstrained ``raw_<name>`` parameter and mapped through
    ``gpytorch.constraints.Positive``. This is the same pattern gpytorch's built-in
    kernels use. It stops the optimizer from pushing the lengthscale to zero
    (division by zero) or the output scale below zero (non-PSD covariance).
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        for name in ('hyperparameter1', 'hyperparameter2'):
            constraint = gpytorch.constraints.Positive()
            # Choose the raw value so the *constrained* value starts at 1.0,
            # matching the original torch.ones initialization.
            init_raw = constraint.inverse_transform(torch.ones(*self.batch_shape, 1))
            self.register_parameter(name='raw_' + name, parameter=torch.nn.Parameter(init_raw))
            self.register_constraint('raw_' + name, constraint)

    def _set_positive(self, raw_name, value):
        """Write the positive ``value`` into raw parameter ``raw_name`` via the inverse transform."""
        raw = getattr(self, raw_name)
        if not torch.is_tensor(value):
            value = torch.as_tensor(value, dtype=raw.dtype, device=raw.device)
        constraint = getattr(self, raw_name + '_constraint')
        self.initialize(**{raw_name: constraint.inverse_transform(value.to(raw))})

    @property
    def hyperparameter1(self):
        """Positive lengthscale that the inputs are divided by."""
        return self.raw_hyperparameter1_constraint.transform(self.raw_hyperparameter1)

    @hyperparameter1.setter
    def hyperparameter1(self, value):
        self._set_positive('raw_hyperparameter1', value)

    @property
    def hyperparameter2(self):
        """Positive output scale that the covariance is multiplied by."""
        return self.raw_hyperparameter2_constraint.transform(self.raw_hyperparameter2)

    @hyperparameter2.setter
    def hyperparameter2(self, value):
        self._set_positive('raw_hyperparameter2', value)

    def forward(self, x1, x2, diag=False, **params):
        '''
        Return the covariance between x1 and x2 (only its diagonal when diag=True).

        TODO: replace this placeholder kernel with one that is suitable for
        measuring similarity between texts x1 and x2.
        '''
        # Apply the constraint transform once and reuse it for both inputs.
        lengthscale = self.hyperparameter1
        x1_ = x1.div(lengthscale)
        x2_ = x2.div(lengthscale)
        sq_dist = self.covar_dist(x1_, x2_, square_dist=True, diag=diag, **params)
        return sq_dist.div_(-2).exp_() * self.hyperparameter2

    def hyperparameters(self,):
        """Return the current hyperparameter values as a comma-separated, 3-decimal string."""
        values = []
        for param in (self.hyperparameter1, self.hyperparameter2):
            # Flatten so batched kernels (non-empty batch_shape) format instead of
            # failing on .item(); the unbatched case yields one value per parameter.
            values.extend(param.detach().reshape(-1).tolist())
        return ', '.join("%.3f" % value for value in values)

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP whose prior is a learnable constant mean plus a ``CustomKernel`` covariance."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        # Prior components: constant mean function and the file's placeholder kernel.
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = CustomKernel()

    def forward(self, x):
        """Return the GP prior evaluated at ``x`` as a ``MultivariateNormal``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

class GPR:
    """Thin wrapper that fits an ``ExactGPModel`` and queries its posterior.

    ``fit`` builds a fresh model and Gaussian likelihood on every call and trains
    them with Adam on the exact marginal log likelihood. ``infer`` predicts with
    the most recently fitted pair.
    """

    def __init__(self):
        # Both are populated by ``fit``; ``None`` means "not fitted yet".
        self.model = None
        self.likelihood = None

    def fit(self, train_x, train_y, training_iter=50, lr=0.1, verbose=True):
        """Fit a new GP to ``(train_x, train_y)`` by maximizing the exact MLL.

        Args:
            train_x: training inputs, passed to ``ExactGPModel``.
            train_y: training targets, passed to ``ExactGPModel``.
            training_iter: number of Adam steps.
            lr: Adam learning rate.
            verbose: if True, print loss, kernel hyperparameters and likelihood
                noise at every iteration.
        """
        # initialize likelihood and model
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPModel(train_x, train_y, likelihood)

        # Training mode so the model returns the prior over train_x for the MLL.
        model.train()
        likelihood.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Includes GaussianLikelihood parameters

        # "Loss" for GPs - the marginal log likelihood
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

        for i in range(training_iter):
            optimizer.zero_grad()
            output = model(train_x)
            loss = -mll(output, train_y)
            loss.backward()
            if verbose:
                print('Iter %d/%d - Loss: %.3f   hyperparameters: %s   noise: %.3f' % (
                    i + 1, training_iter, loss.item(),
                    model.covar_module.hyperparameters(),
                    model.likelihood.noise.item()
                ))
            optimizer.step()

        self.model = model
        self.likelihood = likelihood

    def infer(self, test_x):
        """Predict at ``test_x`` with the fitted model.

        Returns:
            ``(covariance, mean, (lower, upper))``: the covariance matrix of the
            model's (latent) posterior, the mean of the prediction passed through
            the likelihood, and that prediction's ``confidence_region()``.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.model is None or self.likelihood is None:
            raise RuntimeError('GPR.infer() called before GPR.fit()')

        # Get into evaluation (predictive posterior) mode
        self.model.eval()
        self.likelihood.eval()

        # Evaluate every (lazily computed) returned quantity inside the context
        # so no autograd graph is recorded for it.
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            latent_dist = self.model(test_x)
            observed_pred = self.likelihood(latent_dist)
            covariance = latent_dist.covariance_matrix
            mean = observed_pred.mean
            region = observed_pred.confidence_region()
        return covariance, mean, region

def unit_test_of_gpr(out_path="gpr.pdf"):
    """Smoke-test ``GPR`` on a sine wave and save a plot of the fit.

    Trains on 15 evenly spaced points of sin(2*pi*x) over [0, 1] and predicts at
    51 points over [0, 5]. Saves the plot to ``out_path`` (default ``"gpr.pdf"``),
    then prints the diagonal of the returned posterior covariance.
    """
    train_x = torch.linspace(0, 1, 15)
    train_y = torch.sin(train_x * (2 * math.pi))

    gpr = GPR()
    gpr.fit(train_x, train_y)

    test_x = torch.linspace(0, 5, 51)
    with torch.no_grad():
        gram, pred_mean, (lower, upper) = gpr.infer(test_x)

    fig, ax = plt.subplots(1, 1, figsize=(4, 3))
    try:
        # Training data as black stars, predictive mean as a blue line, and the
        # confidence region shaded between the lower and upper bounds.
        ax.plot(train_x.numpy(), train_y.numpy(), 'k*', label='Observed Data')
        ax.plot(test_x.numpy(), pred_mean.numpy(), 'b', label='Mean')
        ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5, label='Confidence')
        ax.set_ylim([-3, 3])
        ax.legend()
        fig.savefig(out_path)
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

    print("uncertainties", gram.diagonal(dim1=-2, dim2=-1).numpy())

class BayesianOpt:
    """Skeleton of an uncertainty-driven sampler around a seed text.

    Starts from ``x`` and one random perturbation of it. It then repeatedly asks
    ``maximum_uncertainty`` for the next text, scores it with ``logp_func`` and
    refits the GP surrogate. This continues until ``n_samples`` points beyond
    ``x`` have been collected.

    NOTE(review): ``GPR.fit`` is handed plain Python lists (texts and scores).
    The underlying gpytorch ExactGP presumably needs tensor inputs, so this
    likely depends on the TODO text kernel / encoding being in place — confirm.
    """

    def __init__(self, x, logp_func, n_samples, random_perturb_func):
        self.logp_func = logp_func
        self.n_samples = n_samples
        # Stored for use by ``maximum_uncertainty`` (not yet implemented).
        self.random_perturb_func = random_perturb_func

        x_logp = logp_func(x)
        # randomly perturb once so the surrogate starts with two observations
        x_hat = random_perturb_func(x)
        x_hat_logp = logp_func(x_hat)

        # samples[0] is always the original text.
        self.samples = [x, x_hat]
        self.targets = [x_logp, x_hat_logp]

        self.gpr = GPR()
        self.gpr.fit(self.samples, self.targets)

    def main_loop(self):
        """Query, score and refit until ``n_samples`` points beyond ``samples[0]`` exist."""
        # The original text (samples[0]) does not count toward n_samples.
        while len(self.samples) - 1 < self.n_samples:
            next_x = self.maximum_uncertainty()
            next_x_logp = self.logp_func(next_x)
            self.samples.append(next_x)
            self.targets.append(next_x_logp)
            self.gpr.fit(self.samples, self.targets)

    def uncertainty_fn(self, x):
        """Return the diagonal of the covariance matrix that ``self.gpr.infer(x)`` returns first."""
        return self.gpr.infer(x)[0].diagonal(dim1=-2, dim2=-1)

    # Backward-compatible alias for the original (misspelled) method name.
    uncertaunty_fn = uncertainty_fn

    def maximum_uncertainty(self):
        '''
        Return the next text to query.

        TODO: find a text next_x that maximizes self.uncertainty_fn(next_x) in the
        neighborhood of the original text self.samples[0] (similar to an
        adversarial attack on texts).

        Raises:
            NotImplementedError: always, until the search is implemented. The
                former bare ``pass`` returned None, which main_loop would then
                score and append as a sample.
        '''
        raise NotImplementedError('BayesianOpt.maximum_uncertainty is not implemented yet')

def unit_test_of_bayesian_opt(logp_func=None, random_perturb_func=None):
    """Run ``BayesianOpt`` on a toy sentence and return the collected statistics.

    Args:
        logp_func: scorer mapping a text to its log-probability. Defaults to a
            module-level ``gpt2_logp`` if one is defined (todo: implement it).
        random_perturb_func: maps a text to a perturbed text. Defaults to a
            module-level ``t5_modify`` if one is defined (todo: implement it).

    Returns:
        ``(texts, logps, gram_matrix)``: the sampled texts (original first), their
        scores, and the first element returned by ``gpr.infer`` on the original text.

    Raises:
        NotImplementedError: if no scorer or perturbation function is available.
    """
    # Fall back to module-level implementations only if they exist; referencing
    # them directly raised NameError because they are still TODO.
    if logp_func is None:
        logp_func = globals().get('gpt2_logp')
    if random_perturb_func is None:
        random_perturb_func = globals().get('t5_modify')
    if logp_func is None or random_perturb_func is None:
        raise NotImplementedError(
            'unit_test_of_bayesian_opt needs logp_func and random_perturb_func '
            '(gpt2_logp / t5_modify are not implemented yet)'
        )

    x = 'This is a dog'
    n_samples = 5
    bo = BayesianOpt(x, logp_func, n_samples, random_perturb_func)
    bo.main_loop()

    ret_texts = bo.samples
    ret_logps = bo.targets
    with torch.no_grad():
        ret_grammatrix = bo.gpr.infer(x)[0]

    # todo: construct a detector with these statistics
    return ret_texts, ret_logps, ret_grammatrix

if __name__ == "__main__":
    # Script entry point: run the Bayesian-optimization smoke test.
    unit_test_of_bayesian_opt()
