#!/usr/bin/env python
# coding: utf-8

import torch

# Place tensors on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of model variables: length of the prior mean and side of the
# prior covariance built by get_prior().
nvars = 4
# One-element tensor used as the additive variance term in the
# information-gain scores (presumably the observation-noise variance).
sigma2 = torch.ones(1, device=device)

def condition_on_theta(M, S, theta):
    """Condition a Gaussian with mean ``M`` and covariance ``S`` on coordinate 0.

    Args:
        M: 1-D mean vector of length n.
        S: (n, n) covariance matrix.
        theta: Value fixed for coordinate 0 (Python number or 0-d tensor).

    Returns:
        Tuple ``(M1, S1)``. ``M1`` is the length-n conditional mean, with
        ``theta`` as its first entry. ``S1`` is an (n, n) matrix whose first
        row and column are zero and whose remaining block is the Schur
        complement ``S[1:, 1:] - outer(S[1:, 0], S[0, 1:]) / S[0, 0]``.
    """
    # as_tensor avoids the copy-construct warning when theta is already a
    # tensor and keeps it on the same device/dtype as the mean.
    theta = torch.as_tensor(theta, dtype=M.dtype, device=M.device)
    gain = S[1:, 0] * (1. / S[0, 0])
    M1 = torch.hstack([theta, M[1:] + gain * (theta - M[0])])

    # Size the result from S rather than a module-level constant.
    S1 = torch.zeros_like(S)
    # The correction must be an outer product: `@` between two 1-D tensors
    # is a dot product and would subtract one scalar from every entry.
    S1[1:, 1:] = S[1:, 1:] - torch.outer(gain, S[0, 1:])
    return M1, S1

def condition_on_vector_of_thetas(M, S, thetas):
    """Condition a Gaussian on coordinate 0 for several values at once.

    Args:
        M: 1-D mean vector of length n.
        S: (n, n) covariance matrix.
        thetas: 1-D tensor of k values to fix for coordinate 0.

    Returns:
        Tuple ``(M1, S1)``. ``M1`` has shape (k, n); row i is the conditional
        mean given coordinate 0 equals ``thetas[i]``, with that value in
        column 0. ``S1`` is the (n, n) conditional covariance shared by all
        rows: zero first row/column, and the Schur complement
        ``S[1:, 1:] - outer(S[1:, 0], S[0, 1:]) / S[0, 0]`` elsewhere.
    """
    gain = S[1:, 0] * (1. / S[0, 0])
    M1 = torch.hstack([
        thetas[:, None],
        M[None, 1:] + gain[None, :] * (thetas - M[0])[:, None],
    ])

    # Size the result from S rather than a module-level constant.
    S1 = torch.zeros_like(S)
    # The correction must be an outer product: `@` between two 1-D tensors
    # is a dot product and would subtract one scalar from every entry.
    S1[1:, 1:] = S[1:, 1:] - torch.outer(gain, S[0, 1:])
    return M1, S1

def kldivergence(mu0, mu1, sigma2_0, sigma2_1):
    """Return the KL divergence between two univariate Gaussians.

    Computes ``KL(N(mu0, sigma2_0) || N(mu1, sigma2_1))`` elementwise as
    ``0.5 * (s0/s1 + (mu1 - mu0)^2 / s1 - 1 + log(s1/s0))``.

    Args:
        mu0, mu1: Means of the first and second distribution.
        sigma2_0, sigma2_1: Variances of the first and second distribution
            (tensors; ``torch.log`` is applied to their ratio).
    """
    variance_ratio = sigma2_0 / sigma2_1
    scaled_sq_shift = (mu1 - mu0)**2. / sigma2_1
    log_term = torch.log(sigma2_1 / sigma2_0)
    return (variance_ratio + scaled_sq_shift - 1. + log_term) / 2.

def _EIG(x, M, S):
    """Score each row of ``x`` by ``0.5 * log((sigma2 + x_i S x_i^T) / sigma2)``.

    (EIG presumably stands for expected information gain.)

    Args:
        x: (k, n) matrix with one candidate per row.
        M: Unused; kept for a signature parallel to ``_ETIG``.
        S: (n, n) covariance matrix.

    Returns:
        1-D tensor of length k with one score per row of ``x``.
    """
    # Row-wise quadratic forms x_i S x_i^T. Only these diagonal terms are
    # needed, so the k x k product x @ S @ x.T is never materialised.
    quad = ((x @ S) * x).sum(dim=-1)
    return torch.log((sigma2 + quad) / sigma2) / 2.

def EIG(t, x, M, S, par_i=None):
    """Pick the row of ``x`` with the highest ``_EIG`` score.

    Args:
        t: Unused.
        x: Candidates, one per row, forwarded to ``_EIG``.
        M: Mean vector, forwarded to ``_EIG``.
        S: Covariance matrix, forwarded to ``_EIG``.
        par_i: Unused.

    Returns:
        Dict with ``clb`` set to ``None`` and ``x`` set to the index
        (0-d tensor) of the best-scoring candidate.
    """
    scores = _EIG(x, M, S)
    best = torch.argmax(scores)
    return {"clb": None, "x": best}

def _ETIG(x, M, S):
    """Score each row of ``x`` by the log-ratio of two predictive variances.

    For row ``x_i`` the score is
    ``0.5 * log((sigma2 + x_i S x_i^T) / (sigma2 + x_i S1 x_i^T))``,
    where ``S1`` is the covariance after conditioning on coordinate 0
    (see ``condition_on_theta``).

    Args:
        x: (k, n) matrix with one candidate per row.
        M: Mean vector, forwarded to ``condition_on_theta``.
        S: (n, n) covariance matrix.

    Returns:
        1-D tensor of length k with one score per row of ``x``.
    """
    # Only the conditional covariance is used; it does not depend on the
    # conditioning value, so 0. is an arbitrary placeholder.
    _, S1 = condition_on_theta(M, S, 0.)
    # Row-wise quadratic forms. Only the diagonal of x @ S @ x.T is needed,
    # so the k x k products are never materialised.
    marginal = sigma2 + ((x @ S) * x).sum(dim=-1)
    conditional = sigma2 + ((x @ S1) * x).sum(dim=-1)
    return torch.log(marginal / conditional) / 2.

def ETIG(t, x, M, S, par_i=None):
    """Pick the row of ``x`` with the highest ``_ETIG`` score.

    Args:
        t: Unused.
        x: Candidates, one per row, forwarded to ``_ETIG``.
        M: Mean vector, forwarded to ``_ETIG``.
        S: Covariance matrix, forwarded to ``_ETIG``.
        par_i: Unused.

    Returns:
        Dict with ``clb`` set to ``None`` and ``x`` set to the index
        (0-d tensor) of the best-scoring candidate.
    """
    scores = _ETIG(x, M, S)
    best = torch.argmax(scores)
    return {"clb": None, "x": best}

def get_prior():
    """Return the prior as ``(mean, covariance)``.

    The mean is a zero vector of length ``nvars`` and the covariance is
    ``10 * I`` of size ``nvars x nvars``, both allocated on ``device``.
    """
    prior_mean = torch.zeros((nvars,), device=device)
    prior_cov = 10. * torch.eye(nvars, device=device)
    return prior_mean, prior_cov
