# %%
import numpy as np
import math
import time
import matplotlib.pyplot as plt
from scipy.special import gammaln
from scipy.optimize import minimize

from dms_variants.ispline import Isplines
from multiprocessing import Pool
from functools import lru_cache


class CoxVB_trunc_post():
    """Coordinate-ascent variational fit of a Cox-type regression model.

    The cumulative baseline hazard evaluated at ``L`` is modelled as
    ``bil1 @ gamma`` where ``bil1`` holds basis-function values (a step basis
    when ``s == 1``, otherwise I-splines from ``dms_variants``).  The
    variational factors maintained by this class are

    * ``gamma_l``: described by shape ``a_gammal`` and rate ``b_gammal``
      (its mean is taken as ``a_gammal / b_gammal``);
    * ``beta``: described by mean ``E_beta`` and covariance ``Var_beta``,
      obtained from a Laplace-style step (mode search + inverse negative
      Hessian + one Newton correction).

    Parameters
    ----------
    X : ndarray, shape (n, p)
        Covariate matrix.
    L : array-like, length n
        Time points at which the basis is evaluated (labelled "left-censored"
        in the original code).  For ``s == 1`` every value must coincide with
        one of ``knots[1:]``.
    delta1 : array-like, length n
        Multiplies the expected latent counts in ``update_E_Zi``
        (presumably an observation-type indicator -- confirm with the caller).
    delta3 : array-like
        Stored on the instance; not used by any method of this class.
    s : int
        Degree of the basis; ``s == 1`` selects the built-in step basis.
    k_n : int
        Added to ``s`` to give the number of basis functions ``m``.
    knots : array-like
        Knot sequence handed to the basis construction.
    a0l, rho0l : array-like, length m
        Prior shape and rate hyper-parameters for ``gamma_l``.
    mu0 : ndarray, shape (p,)
        Prior mean of ``beta``.
    alpha0 : float
        Prior precision of ``beta`` (prior covariance is ``I / alpha0``).
    """

    # Magnitudes at or below this are treated as zero before dividing.  This
    # reproduces ``math.isclose(x, 0, rel_tol=1e-07, abs_tol=1e-09)``: when the
    # reference value is 0 the relative tolerance can never be the binding
    # term, so the test reduces to ``abs(x) <= 1e-09``.
    _ZERO_ABS_TOL = 1e-09

    def __init__(self, X, L, delta1, delta3, s, k_n, knots, a0l, rho0l, mu0, alpha0):
        self.X = X              # covariates
        self.n = X.shape[0]     # number of samples
        self.p = X.shape[1]     # dimension of covariates
        self.s = s              # degree of ispline basis
        self.m = s + k_n        # number of ispline basis
        self.knots = knots      # knots
        self.L = L              # left-censored
        self.delta1 = delta1
        self.delta3 = delta3
        # Seeded with two dummy values whose difference is 1 so that the very
        # first ``has_converged`` check is False; ``plot_elbo`` skips them.
        self.ELBOs = [1, 2]

        # basis values at L, shape (n, m)
        self.bil1 = self.isplines_values()

        # rows whose basis values are not all zero
        self.valid_mask = self.valid_masking()

        # hyperparameters
        self.a0l = a0l
        self.rho0l = rho0l
        self.mu0 = mu0
        self.alpha0 = alpha0

        # variational parameters
        self.E_beta = mu0
        self.Var_beta = np.identity(self.p) / self.alpha0
        self.a_gammal = a0l
        self.b_gammal = rho0l

        # expectations
        self.E_Zi = np.zeros(self.n)
        self.E_Zil = np.zeros((self.n, self.m))
        self.E_gammal = self.a0l / self.rho0l
        self.E_betaexpi = self.update_E_betaexpi()

        # expansion point used by the gradient / Hessian of beta
        self.beta_ = self.mu0

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _expected_cum_hazard(self):
        """Return ``bil1 @ E_gammal``: one value per sample, shape (n,)."""
        return np.dot(self.bil1, self.E_gammal)

    def _guard_zero(self, values):
        """Replace entries with ``|v| <= _ZERO_ABS_TOL`` by 1 so they can be
        used as a denominator without dividing by zero."""
        values = np.asarray(values, dtype=float)
        return np.where(np.abs(values) <= self._ZERO_ABS_TOL, 1, values)

    def _step_basis(self, points):
        """Evaluate the ``s == 1`` step basis at ``points``.

        For each point the first position ``j`` in ``knots[1:]`` equal to the
        point is located, and columns ``0..j`` of that row are set to 1.

        Raises
        ------
        IndexError
            If a point does not equal any entry of ``knots[1:]``.
        """
        knots_rev = np.asarray(self.knots)[1:]
        points = np.asarray(points)

        hits = knots_rev[np.newaxis, :] == points[:, np.newaxis]
        unmatched = ~hits.any(axis=1)
        if unmatched.any():
            raise IndexError(
                "step basis (s == 1) needs every evaluation point to equal "
                f"one of knots[1:]; no match for {points[unmatched].tolist()}"
            )

        # argmax yields the first True per row, i.e. the first matching knot
        first_hit = hits.argmax(axis=1)
        columns = np.arange(self.m)
        return (columns[np.newaxis, :] <= first_hit[:, np.newaxis]).astype(float)

    # ------------------------------------------------------------------
    # objective and driver
    # ------------------------------------------------------------------
    def calc_elbo(self):
        """Return the scalar objective tracked in ``self.ELBOs``.

        Only rows selected by ``valid_mask`` contribute to the likelihood
        term; the remaining rows have a zero basis row and would produce
        ``log(0)``.
        """
        E_Lambdai1 = self._expected_cum_hazard()
        E_betaxi = np.dot(self.X, self.E_beta)

        # Suppress the log(0) / 0 * inf warnings locally: the affected rows
        # are dropped by valid_mask right below.  A scoped errstate is used so
        # that NumPy's process-wide error settings are left untouched.
        with np.errstate(divide='ignore', invalid='ignore'):
            logp_O_before = (
                self.E_Zi * (np.log(E_Lambdai1) + E_betaxi)
                - E_Lambdai1 * self.E_betaexpi
                - gammaln(self.E_Zi + 1)
            )
        logp_O = np.sum(logp_O_before[self.valid_mask])

        centred = self.E_beta - self.mu0
        logp_beta = - self.alpha0 / 2 * np.dot(centred, centred)

        logp_gamma = np.sum((self.a0l - 1) * np.log(self.E_gammal) - self.rho0l * self.E_gammal)

        # slogdet avoids the under/overflow that det() followed by log() can
        # hit; a negative determinant still maps to nan, as log(det) would.
        sign, logabsdet = np.linalg.slogdet(self.Var_beta)
        log_det = logabsdet if sign >= 0 else np.nan
        logq_beta = - 0.5 * log_det

        logq_gamma = np.sum(
            self.a_gammal * np.log(self.b_gammal)
            - gammaln(self.a_gammal)
            + (self.a_gammal - 1) * np.log(self.E_gammal)
            - self.b_gammal * self.E_gammal
        )

        return logp_O + logp_beta + logp_gamma - logq_beta - logq_gamma

    def has_converged(self, tol):
        """True when the last two recorded ELBO values differ by less than ``tol``."""
        diff = abs(self.ELBOs[-1] - self.ELBOs[-2])
        return diff < tol

    def coordinate_ascent(self, iterate, tol=1e-7):
        """Run the update cycle until convergence or ``iterate`` sweeps.

        Each sweep refreshes, in order: ``E_Zi``, ``E_Zil``, the gamma
        parameters, the beta expansion point, ``Var_beta``, ``E_beta`` and
        ``E_betaexpi``; then appends the new objective to ``self.ELBOs``.
        """
        itr = 0
        while itr < iterate and not self.has_converged(tol):
            itr += 1
            # updating (the order matters: each step reads earlier results)
            self.E_Zi = self.update_E_Zi()
            self.E_Zil = self.update_E_Zil()
            self.a_gammal = self.update_a_gammal()
            self.b_gammal = self.update_b_gammal()
            self.E_gammal = self.update_E_gammal()
            self.beta_ = self.beta_initializing()
            self.Var_beta = self.update_Var_beta()
            self.E_beta = self.update_E_beta()
            self.E_betaexpi = self.update_E_betaexpi()

            # calculate ELBO
            self.ELBOs.append(self.calc_elbo())

    # ------------------------------------------------------------------
    # basis construction
    # ------------------------------------------------------------------
    def isplines_values(self):
        """Return the (n, m) matrix of basis values at ``self.L``.

        ``s == 1`` uses the built-in step basis.  Otherwise an ``Isplines``
        object is built and its basis functions ``1..isplines_left.n`` are
        evaluated in a ``multiprocessing.Pool``; the elapsed time is printed.
        """
        if self.s == 1:
            return self._step_basis(self.L)

        isplines_left = Isplines(self.s, self.knots, self.L)

        start_time = time.perf_counter()
        with Pool() as pool:
            bil1_list = pool.map(
                parallel_compute_ispline,
                [(isplines_left, l) for l in range(1, isplines_left.n + 1)],
            )
        elapsed_time = time.perf_counter() - start_time
        print(f'Execution time: {elapsed_time:.6f} seconds')

        return np.stack(bil1_list, axis=1)  # n*m size

    def valid_masking(self):
        """Boolean mask of rows of ``bil1`` whose entries sum to more than 0."""
        return np.sum(self.bil1, axis=1) > 0

    # ------------------------------------------------------------------
    # coordinate updates
    # ------------------------------------------------------------------
    def update_E_betaexpi(self):
        """
        Return ``exp(x_i' E_beta + 0.5 * x_i' Var_beta x_i)`` for every row.

        X : n x p
        E_beta : p
        Var_beta : p x p
        """
        quad = np.einsum('ij,jk,ik->i', self.X, self.Var_beta, self.X)
        return np.exp(np.dot(self.X, self.E_beta) + 0.5 * quad)

    def update_E_Zi(self):
        """
        Return ``lambda_i * delta1_i / (1 - exp(-lambda_i))`` with
        ``lambda_i = E_betaexpi_i * (bil1 @ E_gammal)_i``.  A denominator
        within ``_ZERO_ABS_TOL`` of zero is replaced by 1.

        E_gammal : m
        bil1 : n x m
        E_betaexpi : n
        """
        lambdai1 = self.E_betaexpi * self._expected_cum_hazard()
        denom = self._guard_zero(1 - np.exp(- lambdai1))

        return np.multiply(lambdai1, self.delta1) / denom

    def update_E_Zil(self):
        """Split ``E_Zi`` across basis functions, returning an (n, m) array.

        Row ``i`` is ``E_Zi[i] * E_gammal * bil1[i] / (bil1[i] @ E_gammal)``;
        a near-zero denominator is replaced by 1.
        """
        denom = self._guard_zero(self._expected_cum_hazard())
        pil = np.multiply(self.E_gammal, self.bil1) / denom[:, np.newaxis]  # n*m size

        return np.asarray(self.E_Zi)[:, np.newaxis] * pil  # n*m size

    def update_a_gammal(self):
        """Return ``a0l`` plus the column sums of ``E_Zil``."""
        return self.a0l + np.sum(self.E_Zil, axis=0)

    def update_b_gammal(self):
        """Return ``rho0l + bil1.T @ E_betaexpi``."""
        return self.rho0l + np.dot(self.bil1.T, self.E_betaexpi)

    def update_E_gammal(self):
        """Return ``a_gammal / b_gammal``."""
        return self.a_gammal / self.b_gammal

    def minus_Laplace_func(self, beta_):
        """Negative objective in ``beta_`` minimised by ``beta_initializing``.

        Only rows in ``valid_mask`` are summed, consistently with
        ``calc_elbo``.  Rows with an all-zero basis have a zero cumulative
        hazard, and ``update_E_Zi`` assigns them ``E_Zi == 0``; evaluating
        them would give ``0 * log(0) = nan`` and turn the whole objective
        into NaN for the optimiser.
        """
        valid = self.valid_mask
        E_Lambdai1 = self._expected_cum_hazard()[valid]
        betaxi_ = np.dot(self.X, beta_)[valid]
        E_Zi = np.asarray(self.E_Zi)[valid]

        logp_O = np.sum(E_Zi * (np.log(E_Lambdai1) + betaxi_) - E_Lambdai1 * np.exp(betaxi_))

        centred = beta_ - self.mu0
        logp_beta = - self.alpha0 / 2 * np.dot(centred, centred)

        return - (logp_O + logp_beta)

    def beta_initializing(self):
        """Minimise ``minus_Laplace_func`` with L-BFGS-B, starting at ``E_beta``."""
        result = minimize(self.minus_Laplace_func, self.E_beta, method='L-BFGS-B')

        return result.x

    def gradient_beta(self):
        """
        Gradient of the beta objective evaluated at ``self.beta_``.

        X : n x p
        E_Zi : n
        """
        beta_ = self.beta_

        hazard_term = self._expected_cum_hazard() * np.exp(np.dot(self.X, beta_))

        return np.dot(self.E_Zi - hazard_term, self.X) - self.alpha0 * (beta_ - self.mu0)

    def hessian_beta(self):
        """Hessian (p x p) of the beta objective evaluated at ``self.beta_``."""
        weights = self._expected_cum_hazard() * np.exp(np.dot(self.X, self.beta_))

        return (
            - np.einsum('i,ij,ik->jk', weights, self.X, self.X)
            - self.alpha0 * np.identity(self.p)
        )

    def update_Var_beta(self):
        """Return the inverse of the negative Hessian at ``self.beta_``."""
        return np.linalg.inv(- self.hessian_beta())

    def update_E_beta(self):
        """Return ``beta_`` plus one Newton step ``gradient @ Var_beta.T``."""
        gradient = self.gradient_beta()

        return self.beta_ + np.dot(gradient, self.Var_beta.T)

    # ------------------------------------------------------------------
    # reporting
    # ------------------------------------------------------------------
    def plot_elbo(self):
        """Plot the recorded ELBO values, skipping the two dummy seeds."""
        reduced_elbo = self.ELBOs[2:]

        plt.plot(reduced_elbo, marker='o', linestyle='-', color='b')
        plt.title('Change of ELBO')
        plt.xlabel('iterations')
        plt.ylabel('ELBO')
        plt.grid(True)
        plt.show()

    def get_beta(self):
        """Return the current ``E_beta``."""
        return self.E_beta

    def get_Var_beta(self):
        """Return the current ``Var_beta``."""
        return self.Var_beta

    def get_cum_hazard(self, x_plot):
        """Return the estimated cumulative hazard ``basis(x_plot) @ E_gammal``.

        For ``s == 1`` every entry of ``x_plot`` must equal one of
        ``knots[1:]`` (``IndexError`` otherwise).
        """
        if self.s == 1:
            bl_plot = self._step_basis(x_plot)
        else:
            isplines = Isplines(self.s, self.knots, x_plot)
            bl_plot = np.stack([isplines.I(l) for l in range(1, isplines.n + 1)], axis=1)    # xplot * m size

        return np.dot(bl_plot, self.E_gammal)

    def get_estimation(self, x_plot):
        """Return ``(E_beta, cumulative hazard at x_plot, exp(-cumulative hazard))``."""
        est_beta = self.get_beta()
        est_Lambda_plot = self.get_cum_hazard(x_plot)
        est_surv_plot = np.exp(- est_Lambda_plot)
        return est_beta, est_Lambda_plot, est_surv_plot


@lru_cache(maxsize=None)
def cached_compute_ispline(isplines_left, l):
    """Return ``isplines_left.I(l)``, memoised on the ``(isplines_left, l)`` pair.

    Both arguments must be hashable because they form the cache key.  The
    cache is unbounded and lives in whichever process executes the call.
    """
    basis_values = isplines_left.I(l)
    return basis_values

def parallel_compute_ispline(args):
    """Single-argument adapter around ``cached_compute_ispline``.

    ``args`` is a two-item ``(isplines_left, l)`` pair; it is unpacked and
    forwarded so the function can be used with a one-argument ``map``.
    """
    spline_obj, basis_index = args
    return cached_compute_ispline(spline_obj, basis_index)