import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt

print(torch.__version__)

import os
# Kept from the original: works around an error seen when running on macOS.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Numerical constants shared by the GP code below.
JITTER = 1e-6   # added to the diagonal of kernel matrices before Cholesky
EPS = 1e-10     # added to standard deviations to avoid division by zero
PI = np.pi      # full precision; the previous literal 3.1415 was truncated

class fides_mis(nn.Module):
    def __init__(self,
                 xtr,
                 ytr,
                 xte,
                 mis_index,
                 train_begin_index,
                 train_samples_num,
                 fidelity_num,
                 seed,
                 niteration,
                 learning_rate,
                 normal_y_mode=0):
        """Store the data/configuration and create the trainable hyperparameters.

        Args:
            xtr: indexable container; only ``xtr[0]`` is used. It must be a
                2-D tensor (``shape[1]`` gives the number of input dimensions).
            ytr: per-fidelity targets, stored unchanged as ``self.Y``.
            xte: test inputs, stored unchanged (un-normalised) as ``self.xte``.
            mis_index: description of missing samples; read as
                ``mis_index[i + 1]`` by ``train_mis_data``.
            train_begin_index: first row of the training window.
            train_samples_num: number of rows in the training window.
            fidelity_num: number of fidelity levels.
            seed: seed used by ``kernel_res`` for its random samples.
            niteration: stored as ``self.niteration``.
            learning_rate: stored as ``self.learning_rate``.
            normal_y_mode: accepted for API compatibility; not used here.
        """
        super().__init__()

        # --- plain configuration -------------------------------------------
        self.mis_index = mis_index
        self.train_begin_index = train_begin_index
        self.train_samples_num = train_samples_num
        self.fidelity_num = fidelity_num
        self.seed = seed
        self.niteration = niteration
        self.learning_rate = learning_rate

        # --- data ------------------------------------------------------------
        inputs = xtr[0]
        n_dims = inputs.shape[1]

        # Standardise every input dimension independently.
        self.Xmean = inputs.mean(0)
        self.Xstd = inputs.std(0)
        self.X = (inputs - self.Xmean) / (self.Xstd + EPS)

        self.Y = ytr
        self.Ymean = []
        self.Ystd = []
        self.xte = xte

        # --- trainable hyperparameters (registration order is significant) --
        # Squared-exponential part used by kernel_res, plus the noise term.
        self.log_beta = nn.Parameter(torch.ones(1) * 1e-7)
        self.log_length_scale = nn.Parameter(torch.zeros(n_dims))  # ARD, x
        self.log_length_scale_z = nn.Parameter(torch.zeros(1))     # fidelity z
        self.log_scale = nn.Parameter(torch.zeros(1))              # amplitude

        # Matern-3/2 part used by kernel.
        self.log_length_matern3 = nn.Parameter(torch.zeros(n_dims))
        self.log_coe_matern3 = nn.Parameter(torch.zeros(1))
        self.log_length_matern3_z = nn.Parameter(torch.zeros(1))

        # Coefficient of the exponential fidelity weighting in kernel_res.
        self.b = nn.Parameter(torch.ones(1))

    def y_norm(self, y, normal_y_mode = 0):
        """Standardise targets and return the statistics that were used.

        Args:
            y: tensor of targets.
            normal_y_mode: ``0`` normalises with a single mean/std computed
                over all entries; ``1`` normalises every column (dim 0
                statistics) independently.

        Returns:
            ``(Y, y_m, y_s)``: the normalised targets, the mean and the
            standard deviation.

        Raises:
            ValueError: if ``normal_y_mode`` is neither 0 nor 1 (previously
                this surfaced as an ``UnboundLocalError``).
        """
        if normal_y_mode == 0:
            y_m, y_s = y.mean(), y.std()
        elif normal_y_mode == 1:
            y_m, y_s = y.mean(0), y.std(0)
        else:
            raise ValueError(
                "normal_y_mode must be 0 or 1, got {!r}".format(normal_y_mode))

        # EPS keeps the division finite when the targets are constant.
        Y = (y - y_m.expand_as(y)) / (y_s.expand_as(y) + EPS)
        return Y, y_m, y_s

    def y_denorm(self, y, index, normal_y_mode = 0):
        """Undo the target normalisation using stored statistics.

        Args:
            y: normalised values.
            index: position in ``self.Ymean`` / ``self.Ystd`` of the
                statistics to apply.
            normal_y_mode: 0 or 1. Both modes use the same formula; only the
                shape of the stored statistics differs.

        Returns:
            ``y * self.Ystd[index] + self.Ymean[index]``.

        Raises:
            ValueError: if ``normal_y_mode`` is neither 0 nor 1 (previously
                this surfaced as an ``UnboundLocalError``).
            IndexError: if no statistics are stored at ``index``.

        NOTE(review): ``self.Ymean`` / ``self.Ystd`` start as empty lists and
        the appends in ``y_norm`` are commented out, so the caller must fill
        them before using this method.
        """
        if normal_y_mode not in (0, 1):
            raise ValueError(
                "normal_y_mode must be 0 or 1, got {!r}".format(normal_y_mode))
        return y * self.Ystd[index] + self.Ymean[index]
    
    # def kernel_matern3
    def kernel(self, x1, x2):
        r"""Matern-3/2 kernel between two sets of points.

        latex formula:
        \sigma ^2\left( 1+\frac{\sqrt{3}d}{\rho} \right) \exp \left( -\frac{\sqrt{3}d}{\rho} \right)

        Here the amplitude is ``exp(self.log_coe_matern3)`` and the inputs are
        divided per dimension by ``exp(self.log_length_matern3)`` before the
        Euclidean distance ``d`` is taken.

        :param x1: x_point1, 2-D tensor with one point per row
        :param x2: x_point2, 2-D tensor with one point per row
        :return: kernel matrix with ``x1.size(0)`` rows and ``x2.size(0)`` columns
        """
        length_scale = self.log_length_matern3.exp()
        # A Python float keeps full precision and follows the dtype/device of
        # the inputs; the previous float32 CPU tensor did neither.
        sqrt_3 = 3.0 ** 0.5
        distance = sqrt_3 * torch.cdist(x1 / length_scale, x2 / length_scale, p=2)
        return self.log_coe_matern3.exp() * (1 + distance) * (-distance).exp()

    def forward(self, x_tr, y_tr, y_m, y_s, Xte):
        """GP posterior prediction with the Matern-3/2 kernel.

        Args:
            x_tr: training inputs, used as given (i.e. already normalised).
            y_tr: normalised training targets, one row per training point.
            y_m: mean used to normalise the targets.
            y_s: standard deviation used to normalise the targets.
            Xte: raw test inputs; they are normalised here with
                ``self.Xmean`` / ``self.Xstd``.

        Returns:
            ``(mean, var_diag)``: predictive mean and per-point predictive
            variance (including observation noise), both de-normalised with
            ``y_m`` / ``y_s`` and of the same shape.
        """
        # Same normalisation as the training inputs in __init__; EPS keeps
        # constant input columns (zero std) from producing NaN.
        Xte = (Xte - self.Xmean.expand_as(Xte)) / (self.Xstd.expand_as(Xte) + EPS)

        noise_var = self.log_beta.exp().pow(-1)
        K = self.kernel(x_tr, x_tr)
        eye = torch.eye(x_tr.size(0), dtype=K.dtype, device=K.device)
        Sigma = K + noise_var * eye + JITTER * eye

        kx = self.kernel(x_tr, Xte)
        # torch.linalg.cholesky replaces the deprecated torch.cholesky and is
        # what the likelihood methods in this class already use.
        L = torch.linalg.cholesky(Sigma)
        LinvKx, _ = torch.triangular_solve(kx, L, upper=False)

        mean = kx.t() @ torch.cholesky_solve(y_tr, L)

        # k(x*, x*) - k*^T Sigma^{-1} k*, computed through the Cholesky factor.
        var_diag = self.kernel(Xte, Xte).diag().view(-1, 1) \
            - (LinvKx ** 2).sum(dim=0).view(-1, 1)

        # add the noise uncertainty
        var_diag = var_diag + noise_var

        # de-normalise
        mean = mean * y_s.expand_as(mean) + y_m.expand_as(mean)
        var_diag = var_diag.expand_as(mean) * y_s ** 2

        return mean, var_diag

    def negative_log_likelihood(self, x, y):
        """Negative log marginal likelihood under the Matern-3/2 kernel.

        Args:
            x: training inputs, one point per row.
            y: 2-D tensor of targets; every column is treated as an
                independent output sharing the same covariance.

        Returns:
            0-dim tensor: data-fit term + log-determinant term + constant.
        """
        n_points, n_outputs = y.shape
        n_train = x.size(0)

        # Kernel matrix plus observation noise plus a small jitter.
        gram = self.kernel(x, x)
        noise = self.log_beta.exp().pow(-1) * torch.eye(n_train)
        Sigma = gram + noise + JITTER * torch.eye(n_train)

        chol = torch.linalg.cholesky(Sigma)
        # Solve chol @ whitened = y, so sum(whitened**2) == y^T Sigma^{-1} y.
        whitened, _ = torch.triangular_solve(y, chol, upper=False)

        data_fit = 0.5 * (whitened ** 2).sum()
        # sum(log(diag(chol))) is half of log|Sigma|.
        log_det = chol.diag().log().sum() * n_outputs
        constant = 0.5 * n_points * torch.log(2 * torch.tensor(PI)) * n_outputs
        return data_fit + log_det + constant

    def train_adam(self, x, y, niteration=10, lr=0.1):
        """Minimise ``negative_log_likelihood(x, y)`` with Adam.

        Args:
            x: training inputs passed to ``negative_log_likelihood``.
            y: training targets passed to ``negative_log_likelihood``.
            niteration: number of Adam steps.
            lr: Adam learning rate.

        All parameters of the module are handed to the optimiser; the loss
        value of every step is printed.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        optimizer.zero_grad()

        def closure():
            # One loss evaluation: clear old gradients, recompute, backprop.
            optimizer.zero_grad()
            nll = self.negative_log_likelihood(x, y)
            nll.backward()
            return nll

        for step in range(niteration):
            nll = optimizer.step(closure)
            print('iter', step, 'nll:{:.5f}'.format(nll.item()))

    def warp(self, lf1, hf1, lf2, hf2):
        """Return the four fidelity bounds unchanged.

        This is an identity hook: the arguments come back as the tuple
        ``(lf1, hf1, lf2, hf2)``.

        NOTE(review): the previous docstring held a disabled implementation
        that transformed each bound ``v`` as ``1 - (1 - v**a)**b`` using
        ``self.warp_a`` / ``self.warp_b``; those attributes are not defined
        in ``__init__``.
        """
        bounds = (lf1, hf1, lf2, hf2)
        return bounds

    def forward_res(self, x_tr, y_tr, l1, h1, l2, h2, y_m, y_s, Xte):
        """GP posterior prediction of a residual using ``kernel_res``.

        Args:
            x_tr: training inputs, used as given (i.e. already normalised).
            y_tr: normalised residual targets, one row per training point.
            l1, h1, l2, h2: fidelity bounds forwarded to ``kernel_res``.
            y_m: mean used to normalise the targets.
            y_s: standard deviation used to normalise the targets.
            Xte: raw test inputs; they are normalised here with
                ``self.Xmean`` / ``self.Xstd``.

        Returns:
            ``(mean, var_diag)``: predictive mean and per-point predictive
            variance (including observation noise), both de-normalised with
            ``y_m`` / ``y_s`` and of the same shape.
        """
        # Same normalisation as the training inputs in __init__; EPS keeps
        # constant input columns (zero std) from producing NaN.
        Xte = (Xte - self.Xmean.expand_as(Xte)) / (self.Xstd.expand_as(Xte) + EPS)

        noise_var = self.log_beta.exp().pow(-1)
        K = self.kernel_res(x_tr, x_tr, l1, h1, l2, h2)
        eye = torch.eye(x_tr.size(0), dtype=K.dtype, device=K.device)
        Sigma = K + noise_var * eye + JITTER * eye

        kx = self.kernel_res(x_tr, Xte, l1, h1, l2, h2)
        # torch.linalg.cholesky replaces the deprecated torch.cholesky and is
        # what the likelihood methods in this class already use.
        L = torch.linalg.cholesky(Sigma)
        LinvKx, _ = torch.triangular_solve(kx, L, upper=False)

        mean = kx.t() @ torch.cholesky_solve(y_tr, L)

        # The prior variance must come from the SAME kernel as Sigma and kx.
        # Using the Matern kernel here (as before) mixes two different
        # amplitudes and can yield negative "variances".
        prior_var = self.kernel_res(Xte, Xte, l1, h1, l2, h2).diag().view(-1, 1)
        var_diag = prior_var - (LinvKx ** 2).sum(dim=0).view(-1, 1)

        # add the noise uncertainty
        var_diag = var_diag + noise_var

        # de-normalise
        mean = mean * y_s.expand_as(mean) + y_m.expand_as(mean)
        var_diag = var_diag.expand_as(mean) * y_s ** 2

        return mean, var_diag
    
    def kernel_res(self, X1, X2, l1, h1, l2, h2):
        """Residual kernel: ARD squared-exponential in x times a fidelity factor.

        The fidelity factor is a Monte Carlo estimate of the double integral of
        ``exp(-b*(z1 - h1) - b*(z2 - h2) - 0.5*((z1 - z2)/len_z)**2)`` over
        ``[l1, h1] x [l2, h2]`` (bounds pass through ``self.warp`` first).

        Args:
            X1: 2-D tensor, one point per row.
            X2: 2-D tensor, one point per row.
            l1, h1: lower/upper fidelity bound for the first argument.
            l2, h2: lower/upper fidelity bound for the second argument.

        Returns:
            Kernel matrix with ``X1.size(0)`` rows and ``X2.size(0)`` columns.
        """
        lf1, hf1, lf2, hf2 = self.warp(l1, h1, l2, h2)

        n_mc = 100
        # A private generator gives the same samples on every call (as the
        # previous torch.manual_seed(self.seed) did) without resetting the
        # process-wide RNG as a side effect.
        rng = torch.Generator()
        rng.manual_seed(self.seed)
        z1 = torch.rand(n_mc, generator=rng) * (hf1 - lf1) + lf1
        z2 = torch.rand(n_mc, generator=rng) * (hf2 - lf2) + lf2

        length_scale = self.log_length_scale.exp()
        X1 = X1 / length_scale
        X2 = X2 / length_scale
        X1_norm2 = torch.sum(X1 * X1, dim=1).view(-1, 1)
        X2_norm2 = torch.sum(X2 * X2, dim=1).view(-1, 1)

        # Squared Euclidean distance via |a|^2 + |b|^2 - 2 a.b
        n1, n2 = X1.size(0), X2.size(0)
        sq_dist = -2.0 * X1 @ X2.t() + X1_norm2.expand(n1, n2) + X2_norm2.t().expand(n1, n2)
        K = self.log_scale.exp() * torch.exp(-0.5 * sq_dist)

        # Fidelity part: plain Monte Carlo average times the integration area.
        length_scale_z = self.log_length_scale_z.exp()
        dist_z = (z1 / length_scale_z - z2 / length_scale_z) ** 2
        z_part1 = -self.b * (z1 - hf1)
        z_part2 = -self.b * (z2 - hf2)
        z_part = (z_part1 + z_part2 - 0.5 * dist_z).exp()
        z_part_mc = z_part.mean() * (hf1 - lf1) * (hf2 - lf2)

        K_ard = z_part_mc * K
        return K_ard

    def negative_log_likelihood_res(self, x, y, l1, h1, l2, h2):
        """Negative log marginal likelihood under the residual kernel.

        Args:
            x: training inputs, one point per row.
            y: 2-D tensor of targets; every column is treated as an
                independent output sharing the same covariance.
            l1, h1, l2, h2: fidelity bounds forwarded to ``kernel_res``.

        Returns:
            0-dim tensor: data-fit term + log-determinant term + constant.
        """
        n_points, n_outputs = y.shape
        n_train = x.size(0)

        # Kernel matrix plus observation noise plus a small jitter.
        gram = self.kernel_res(x, x, l1, h1, l2, h2)
        noise = self.log_beta.exp().pow(-1) * torch.eye(n_train)
        Sigma = gram + noise + JITTER * torch.eye(n_train)

        chol = torch.linalg.cholesky(Sigma)
        # Solve chol @ whitened = y, so sum(whitened**2) == y^T Sigma^{-1} y.
        whitened, _ = torch.triangular_solve(y, chol, upper=False)

        data_fit = 0.5 * (whitened ** 2).sum()
        # sum(log(diag(chol))) is half of log|Sigma|.
        log_det = chol.diag().log().sum() * n_outputs
        constant = 0.5 * n_points * torch.log(2 * torch.tensor(PI)) * n_outputs
        return data_fit + log_det + constant

    def train_adam_res(self, x, y, l1, h1, l2, h2, niteration=10, lr=0.1):
        """Minimise ``negative_log_likelihood_res`` with Adam.

        Args:
            x: training inputs.
            y: residual training targets. They are treated as fixed data: any
                autograd history attached to them is dropped.
            l1, h1, l2, h2: fidelity bounds forwarded to the likelihood.
            niteration: number of Adam steps.
            lr: Adam learning rate.

        All parameters of the module are handed to the optimiser; the loss
        value of every step is printed.
        """
        # Targets are data. If they still carry a graph (e.g. they were
        # produced by an earlier prediction), backpropagating into it would
        # require retain_graph=True and would leak gradients into that stale
        # graph. Detaching removes both problems.
        y = y.detach()

        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        optimizer.zero_grad()
        for i in range(niteration):
            optimizer.zero_grad()
            loss = self.negative_log_likelihood_res(x, y, l1, h1, l2, h2)
            # The graph is rebuilt from scratch every iteration, so there is
            # nothing to retain.
            loss.backward()
            optimizer.step()
            print('iter', i, 'nll:{:.5f}'.format(loss.item()))
 
    def train_and_test_lowest_fidelity(self):
        """Fit the Matern GP on the lowest fidelity and predict at ``self.xte``.

        Uses rows ``train_begin_index .. train_begin_index + train_samples_num``
        of the normalised inputs and of ``self.Y[0]``. The predictive mean and
        variance are stored in ``self.yte_mean`` / ``self.yte_var``.

        NOTE(review): training runs for a fixed 200 iterations at lr 0.02;
        ``self.niteration`` / ``self.learning_rate`` are not consulted here.
        """
        start = self.train_begin_index
        stop = start + self.train_samples_num
        x = self.X[start:stop]
        y = self.Y[0][start:stop]
        y, y_m, y_s = self.y_norm(y)
        self.train_adam(x = x, y = y, niteration = 200, lr = 0.02)
        # Prediction only: no need to record an autograd graph, which would
        # otherwise stay attached to the stored results.
        with torch.no_grad():
            self.yte_mean, self.yte_var = self.forward(x, y, y_m, y_s, self.xte)
        print("finish lowest fidelity")

    def train_mis_data(self):
        """Impute missing higher-fidelity targets with the Matern GP.

        For every fidelity ``i >= 1`` the rows listed in
        ``self.mis_index[i + 1]`` (an iterable of ``(start, stop)`` index
        ranges, or ``None`` when nothing is missing) are treated as missing.
        A GP is trained on the remaining rows and its predictive mean fills
        the gaps.

        Returns:
            List with one target tensor per fidelity, all on the ORIGINAL
            (un-normalised) scale: entry 0 is the training window of
            ``self.Y[0]``; entry ``i`` holds the observed values where
            available and the GP mean where missing.

        Side effects:
            ``self.missing_var`` gets one entry per higher fidelity: the
            predictive variance at the missing rows (an empty tensor when
            nothing was missing). Hyperparameters are updated by training.

        Raises:
            ValueError: if every training row of a fidelity is missing.

        NOTE(review): rows of the higher fidelities are addressed as
        ``0 .. train_samples_num - 1`` while fidelity 0 uses the window that
        starts at ``train_begin_index``; confirm the intended alignment when
        ``train_begin_index != 0``.
        """
        n_train = self.train_samples_num
        begin = self.train_begin_index
        self.missing_var = []
        ytr_f = [self.Y[0][begin:begin + n_train]]

        for i in range(1, self.fidelity_num):
            missing_ranges = self.mis_index[i + 1]
            missing = []
            if missing_ranges is not None:
                for bounds in missing_ranges:
                    missing.extend(range(bounds[0], bounds[1]))

            # First position of every missing row inside `missing`
            # (dict lookups replace repeated list scans).
            missing_pos = {}
            for pos, j in enumerate(missing):
                missing_pos.setdefault(j, pos)
            exist_index = [j for j in range(n_train) if j not in missing_pos]
            if not exist_index:
                raise ValueError(
                    "fidelity {}: every training sample is missing".format(i))
            exist_pos = {j: pos for pos, j in enumerate(exist_index)}

            xtr_exist = torch.stack([self.X[j] for j in exist_index])
            ytr_exist_raw = torch.stack([self.Y[i][j] for j in exist_index])

            if not missing:
                # Nothing to impute: keep the observed targets and skip the
                # (otherwise pointless) GP fit. Previously this case crashed
                # on torch.stack([]).
                self.missing_var.append(
                    ytr_exist_raw.new_empty((0,) + tuple(ytr_exist_raw.shape[1:])))
                ytr_f.append(ytr_exist_raw)
                continue

            ytr_exist, ytr_m, ytr_s = self.y_norm(ytr_exist_raw)
            self.train_adam(x = xtr_exist, y = ytr_exist, niteration = 100, lr = 0.02)

            # forward() normalises its test inputs itself, but self.X is
            # already normalised; undo that first so the missing rows are
            # not normalised twice.
            xtr_missing = torch.stack([self.X[j] for j in missing])
            xtr_missing_raw = xtr_missing * (self.Xstd + EPS) + self.Xmean
            with torch.no_grad():
                missing_mean, missing_variance = self.forward(
                    xtr_exist, ytr_exist, ytr_m, ytr_s, xtr_missing_raw)
            self.missing_var.append(missing_variance)

            # forward() returns de-normalised means, so the observed rows must
            # be taken on the raw scale too (not the normalised copy).
            ytr_full = [
                missing_mean[missing_pos[j]] if j in missing_pos
                else ytr_exist_raw[exist_pos[j]]
                for j in range(n_train)
            ]
            ytr_f.append(torch.stack(ytr_full))
        return ytr_f

        

    def train_mod(self):
        """Run the full multi-fidelity pipeline and return test predictions.

        Steps:
            1. Fit and predict the lowest fidelity
               (``train_and_test_lowest_fidelity``).
            2. Build complete per-fidelity training targets
               (``train_mis_data``).
            3. For each consecutive pair of fidelities, fit the residual GP on
               ``targets[hf] - targets[lf]`` and add its prediction at
               ``self.xte`` to the running mean/variance.

        Returns:
            ``(self.yte_mean, self.yte_var)`` after all residuals were added.

        NOTE(review): every residual fit uses a fixed 100 iterations at lr
        0.02; ``self.niteration`` / ``self.learning_rate`` are not consulted.
        NOTE(review): ``train_mis_data`` addresses the higher fidelities by
        rows ``0 .. train_samples_num - 1``; confirm the alignment with the
        input window when ``train_begin_index != 0``.
        """
        self.train_and_test_lowest_fidelity()
        ytr_f = self.train_mis_data()

        begin = self.train_begin_index
        xtr_res = self.X[begin:begin + self.train_samples_num]

        for lf in range(self.fidelity_num - 1):
            hf = lf + 1

            # The entries of ytr_f already cover exactly the training window,
            # so they must not be sliced by train_begin_index a second time
            # (doing so dropped rows whenever the index was non-zero).
            # detach(): residual targets are data, not part of any graph.
            ytr_res = (ytr_f[hf] - ytr_f[lf]).detach()

            ytr_res, yy_m, yy_s = self.y_norm(ytr_res)
            self.train_adam_res(x = xtr_res, y = ytr_res, l1 = lf, h1 = hf, l2 = lf, h2 = hf, niteration = 100, lr = 0.02)
            # Prediction only: no autograd graph needed.
            with torch.no_grad():
                yte_res_mean, yte_res_var = self.forward_res(xtr_res, ytr_res, l1 = lf, h1 = hf, l2 = lf, h2 = hf, y_m = yy_m, y_s = yy_s, Xte = self.xte)
            self.yte_mean += yte_res_mean
            # abs() is kept from the original as a guard against negative
            # variance estimates.
            self.yte_var += abs(yte_res_var)
            print("finish", hf, "fidelity")

        return self.yte_mean, self.yte_var
        