import sys
import numpy as np
import models
import torch
import math
import gpytorch
import matplotlib.pyplot as plt
from torch.nn import Module
import torch.nn as nn

class sin_target(object):
    """Synthetic 1-D sine target observed with input-dependent Gaussian noise.

    The noise-free value is ``sin(2 * pi * x)`` for ``x`` inside ``[0, 1]`` and
    ``0`` outside of it.  For fidelity ``i`` the standard-normal noise is scaled
    by ``a[i] * x + b[i]``.

    Args:
        fidelity_fix: ``None`` exposes two fidelities and the caller selects
            one per point through ``index_x``.  Any other value pins every
            point to that fidelity; ``index_x`` passed to the methods is then
            ignored (it may even be ``None``).
    """

    def __init__(self, fidelity_fix=None):
        self.fidelity = 2 if fidelity_fix is None else 1
        # Always define the attribute so reading it never raises AttributeError.
        self.fidelity_fix = fidelity_fix
        self.x_dim = 1
        # Per-fidelity slope and offset of the noise scale.
        self.a = torch.tensor([0.5, -0.5])
        self.b = torch.tensor([0, 0.5])
        # Presumably the lower / upper input bounds -- confirm against callers.
        self.bounds = torch.tensor([[0.], [1.]])

    def _fixed_index(self, count):
        """Return ``count`` long indices that all equal ``fidelity_fix``."""
        return self.fidelity_fix * torch.ones(count, dtype=torch.long)

    def noise_level(self, tr_x, index_x):
        """Return the noise scale ``a[i] * x + b[i]`` for every input point.

        Args:
            tr_x: Inputs, either 1-D or 2-D (only column 0 of a 2-D tensor is
                used).
            index_x: Integer fidelity index per point.  Ignored when a
                fidelity was fixed at construction time.

        Returns:
            1-D tensor with one noise scale per point.
        """
        if len(tr_x.shape) == 2:
            tr_x = tr_x[:, 0]
        if self.fidelity == 1:
            # Size the index from tr_x (as query() does) so that an ignored
            # index_x does not have to be a tensor.
            index_x = self._fixed_index(tr_x.shape[0])
        return self.a[index_x] * tr_x + self.b[index_x]

    def query_ground_truth(self, tr_x, index_x=None):
        """Return the noise-free target value for every input point.

        Args:
            tr_x: Inputs, either 1-D or 2-D (only column 0 of a 2-D tensor is
                used).
            index_x: Unused; accepted so the signature matches ``query``.

        Returns:
            1-D tensor holding ``sin(2 * pi * x)`` inside ``[0, 1]`` and ``0``
            for points outside that interval.
        """
        if len(tr_x.shape) == 2:
            tr_x = tr_x[:, 0]
        tr_y_gt = torch.sin(tr_x * (2 * math.pi))
        # Zero the target outside the unit interval.
        tr_y_gt[torch.logical_or((tr_x < 0), (tr_x > 1))] = 0
        return tr_y_gt

    def query(self, tr_x, index_x):
        """Return a noisy observation of the target for every input point.

        Args:
            tr_x: Inputs, either 1-D or 2-D (only column 0 of a 2-D tensor is
                used).
            index_x: Integer fidelity index per point.  Ignored when a
                fidelity was fixed at construction time.

        Returns:
            1-D tensor: ground truth plus standard-normal noise scaled by
            ``noise_level``.
        """
        if self.fidelity == 1:
            index_x = self._fixed_index(tr_x.shape[0])

        noise_level = self.noise_level(tr_x, index_x)
        return self.query_ground_truth(tr_x, index_x) + torch.randn(tr_x.shape[0]) * noise_level

class band_gap_target(object):
    """Two-fidelity lookup-table target backed by tensors loaded from disk.

    Four tensors are read with ``torch.load`` from files named
    ``<dir>/<stem><follow>.ts`` for the stems ``Z``, ``Y``, ``Y_0`` and
    ``Y_1``.  ``Z`` is indexed by row to obtain inputs, ``Y`` holds the
    ground-truth values and ``Y_0`` / ``Y_1`` hold the per-fidelity values.

    Args:
        dir: Directory prefix of the tensor files.
        follow: Suffix inserted between the stem and the ``.ts`` extension.
        cost: Stored on the instance as ``self.cost``; defaults to ``[1, 1]``.
    """

    def __init__(self, dir, follow, cost=None):
        def load(stem):
            # NOTE(review): torch.load may unpickle arbitrary objects depending
            # on the torch version -- only point this at trusted files.
            return torch.load(dir + '/' + stem + follow + '.ts')

        self.fidelity = 2
        self.Z = load('Z')
        self.Y = load('Y')
        self.Y_0 = load('Y_0')
        # NOTE(review): the reason for the constant 0.9 offset is not visible
        # in this file -- confirm with the data owner.
        self.Y_1 = load('Y_1') + 0.9
        self.size = self.Z.shape[0]
        # Position in this list is the fidelity index used by query_by_num.
        self.Y_low = [self.Y_0, self.Y_1]
        self.cost = [1, 1] if cost is None else cost

    def input_by_num(self, num_x):
        """Return the row(s) of ``Z`` selected by ``num_x``."""
        return self.Z[num_x, :]

    def query_ground_truth_by_num(self, num_x):
        """Return column 0 of ``Y`` for the row(s) selected by ``num_x``."""
        return self.Y[num_x, 0]

    def query_by_num(self, num_x, index_x):
        """Look up one per-fidelity value for every requested row.

        Args:
            num_x: Tensor of row numbers; its first dimension sets the number
                of lookups.
            index_x: Per-lookup fidelity index into ``Y_low``.

        Returns:
            1-D tensor whose entry ``k`` is ``Y_low[index_x[k]][num_x[k], 0]``.
        """
        count = num_x.shape[0]
        values = torch.ones([count])
        for pos in range(count):
            table = self.Y_low[index_x[pos]]
            values[pos] = table[num_x[pos], 0]
        return values



if __name__ == '__main__':
    # Demo script: draw noisy two-fidelity samples of sin(2*pi*x), fit the GP
    # from `models` with its learnt noise model, then plot the fit and the
    # learnt noise level against the values used to generate the data.

    num_x = 500
    num_f = 2
    # Noise scale used to generate the data, per fidelity i: a[i] * x + b[i].
    a = torch.tensor([0.5, -0.5])
    b = torch.tensor([0, 0.5])
    tr_x = torch.linspace(0, 1, num_x)
    tr_y_gt = torch.sin(tr_x * (2 * math.pi))
    # Every training point gets a random fidelity in [0, num_f).
    index_x = torch.randint(0, num_f, [num_x])
    noise_level = (a[index_x] * tr_x + b[index_x])
    tr_y = tr_y_gt + (torch.randn(num_x))*(noise_level)
    noise_model = models.LinearNoise(num_f, 1)
    likelihood = models.LearntNoiseLikelihood(tr_x, index_x, noise_model)
    model = models.ExactGPModel(tr_x, tr_y, likelihood)

    class LinearNoise_true(Module):
        """Noise module that returns the generating noise scale (no parameters)."""

        def __init__(self, num_x, x_dim):
            # Both arguments are accepted but unused.
            super(LinearNoise_true, self).__init__()

        def forward(self, train_x, index_x):
            noise_level = (a[index_x] * train_x + b[index_x])
            return noise_level

    training_iter = 1000
    model.train()
    likelihood.train()
    # Only needed for the optional comparison described at the end of the
    # script; neither object takes part in training.
    noise_model_2 = LinearNoise_true(0, 0)
    likelihood_2 = models.LearntNoiseLikelihood(tr_x, index_x, noise_model_2)
    # NOTE(review): presumably model.parameters() also yields the likelihood /
    # noise-model parameters -- confirm against models.ExactGPModel.
    optimizer = torch.optim.Adam(model.parameters(), lr=.01)

    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    checkpoint_path = './optimal_model.md'
    # Start from +inf so the first iteration always writes a checkpoint; a
    # finite sentinel could leave the load below reading a missing or stale file.
    optimal_loss = float('inf')

    for i in range(training_iter):
        optimizer.zero_grad()
        output = model(tr_x)
        loss = -mll(output, tr_y)
        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        if optimal_loss > current_loss:
            # Keep the best state seen so far on disk.
            torch.save(model.state_dict(), checkpoint_path)
            optimal_loss = current_loss
        print(i, optimal_loss, current_loss, noise_model.F0_weight.item(), noise_model.F0_bias.item(), noise_model.F1_weight.item(), noise_model.F1_bias.item())
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    likelihood.eval()

    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        test_x = torch.linspace(0, 1, 51)
        pred = model(test_x)

    def plot_fit(dist):
        """Plot training data, ground truth and the mean +/- 1 std band of ``dist``."""
        mean = dist.mean.numpy()
        std = np.sqrt(dist.variance.numpy())
        lower, upper = mean - std, mean + std
        # Training data, one marker colour per fidelity.
        plt.plot(tr_x[index_x == 0].numpy(), tr_y[index_x == 0].numpy(), 'k*', alpha=0.5)
        plt.plot(tr_x[index_x == 1].numpy(), tr_y[index_x == 1].numpy(), 'g*', alpha=0.5)
        # Noise-free target (red) and predictive mean (yellow).
        plt.plot(test_x.numpy(), torch.sin(test_x * (2 * math.pi)).numpy(), 'r')
        plt.plot(test_x.numpy(), mean, 'y')
        # Shade between the lower and upper confidence bounds.
        plt.fill_between(test_x.numpy(), lower, upper, alpha=0.5)
        plt.ylim([-1.5, 1.5])
        plt.legend(['Fidelity1', 'Fidelity2', 'Ground Truth', 'Prediction'])
        plt.show()

    with torch.no_grad():
        # Output of the GP model itself.
        plot_fit(pred)

        # Output of the likelihood for all-zeros, then all-ones fidelity indices.
        for make_index in (torch.zeros_like, torch.ones_like):
            fidelity_index = make_index(test_x).type_as(index_x)
            plot_fit(likelihood(pred, test_x=test_x, index_x=fidelity_index))

        # Learnt noise level per fidelity against the generating noise scale.
        # NOTE(review): the training loop reads F0_weight / F0_bias / ... from
        # noise_model while this reads noise_model.model_list -- confirm that
        # models.LinearNoise exposes both.
        plt.plot(test_x.numpy(), abs(noise_model.model_list[0](test_x.unsqueeze(1))))
        plt.plot(test_x.numpy(), abs(noise_model.model_list[1](test_x.unsqueeze(1))))
        plt.plot(test_x.numpy(), (a[0] * test_x + b[0]).numpy())
        plt.plot(test_x.numpy(), (a[1] * test_x + b[1]).numpy())
        # One label per curve, in plotting order.
        plt.legend(['noiselevel prediction f1', 'noiselevel prediction f2', 'noiselevel f1', 'noiselevel f2'])
        plt.show()

        # Optional comparison (disabled): repeat the two likelihood plots above
        # with likelihood_2 in place of likelihood to see the predictive bands
        # under the generating noise model.