#!/usr/bin/env python
# coding: utf-8

import numpy as np
from preference.sequential_distributions import Gaussian
from scipy.stats import bernoulli, norm
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Candidate designs, loaded once at import time. map_location remaps the
# stored tensors onto `device` while loading, so a file that was saved from a
# CUDA session can still be read on a CPU-only machine.
designs = torch.load("preference/designs.pt", map_location=device).to(device)

# Default sample sizes; the estimator functions below take them as the
# defaults of their `nouter` / `ninner` keyword arguments.
nouter = 10000
ninner = 100

# Prior location and scale, one entry per model parameter. Entry 1 is the one
# _etig_num uses for the psi (multiplier) prior.
prior_loc = [0.,0.]
prior_scale = [4.,1.]

# NOTE(review): not referenced anywhere in the visible part of this file.
zeroproof_likelihood = True

class OuterDist(Gaussian):
    """Gaussian sequential distribution wired to this module's likelihood.

    Parameter vectors and sample rows are handed to ``likelihood`` in the
    order ``(theta, multiplier)``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _pseudo_log_prob(self, x):
        """Return log prior density plus log likelihood at parameter vector ``x``.

        Args:
            x: NumPy parameter vector; assumed to be 1-D with entries
                ``(theta, multiplier)`` -- TODO confirm against the base
                class caller.

        Returns:
            Sum of the log prior density of ``x`` and the log likelihood of
            the data held in ``self._Y`` / ``self._X``.
        """
        log_pprior = np.log(self._prior.pdf(x)).sum()
        # likelihood() takes (y, x, theta, multiplier). Unpacking x[:, None]
        # yields one length-1 array per parameter, so the stored data is
        # evaluated against a single parameter setting.
        log_lik = np.log(likelihood(
            self._Y.asarray()[None, ...], self._X.asarray(), *x[:, None],
            observed=True, zeroproof=False
        )).sum()
        return log_pprior + log_lik

    def _vectorized_likelihood(self, y, x, samples=None):
        """Return the likelihood of ``(y, x)`` for every parameter sample.

        Args:
            y: Observed responses as a NumPy array.
            x: Designs the responses were observed at.
            samples: Parameter samples, one row per sample; defaults to
                ``self._samples``.

        Returns:
            NumPy array with one entry per sample: the product of the
            per-observation likelihoods over axes 1 and 2.
        """
        if samples is None:
            samples = self._samples
        return likelihood(
            y[None, ...], x, *samples.T, observed=True, zeroproof=False
        ).prod(axis=(1, 2))
    
def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    Args:
        x: A torch tensor.

    Returns:
        Tensor with the same shape as ``x`` and values in ``[0, 1]``.
    """
    # Library kernel instead of composing negate / exp / add / divide by hand.
    return torch.sigmoid(x)
    
def to_torch(x):
    """Return ``x`` as a torch tensor living on the module-level ``device``.

    Args:
        x: A torch tensor, NumPy array, Python scalar or (nested) sequence.

    Returns:
        ``x`` itself (dtype untouched, moved to ``device`` only if it is
        elsewhere) when it is already a tensor; otherwise a float32 tensor on
        ``device`` built from ``x``.
    """
    if isinstance(x, torch.Tensor):
        # .to() hands back the same tensor when it is already on `device`.
        return x.to(device)
    # as_tensor also accepts lists and scalars, which torch.from_numpy rejects.
    return torch.as_tensor(x, dtype=torch.float32, device=device)
    
def forward(x, theta, multiplier, observed=False):
    """Return ``sigmoid(multiplier * x - theta)`` broadcast over parameters.

    Args:
        x: Designs (tensor or NumPy array).
        theta: Offset parameter(s).
        multiplier: Scale parameter(s) applied to the designs.
        observed: Accepted for signature compatibility; not used here.

    Returns:
        Torch tensor of probabilities. ``x`` gains one leading singleton axis
        per axis of ``theta``, while ``theta`` and ``multiplier`` gain two
        trailing singleton axes, and the three are then broadcast together.
    """
    design_t = to_torch(x)
    theta_t = to_torch(theta)
    scale_t = to_torch(multiplier)
    # Prepend one singleton axis to the designs for every parameter axis.
    leading_axes = (None,) * theta_t.ndim
    design_t = design_t[leading_axes + (Ellipsis,)]
    scaled = scale_t[..., None, None] * design_t
    return sigmoid(scaled - theta_t[..., None, None])
    
def likelihood(y, x, theta, multiplier, observed=False, zeroproof=False):
    """Return the probability of the binary responses ``y`` under the model.

    Args:
        y: Responses as a NumPy array; entries equal to 0 select ``1 - p``,
            all other entries select ``p``.
        x: Designs, passed on to ``forward``.
        theta: Offset parameter(s); its ``shape`` / ``ndim`` decide how ``y``
            is expanded.
        multiplier: Scale parameter(s), passed on to ``forward``.
        observed: When true, the leading axis of ``y`` is expanded to
            ``theta.shape[0]``. Also forwarded to ``forward``.
        zeroproof: Accepted for signature compatibility; not used here.

    Returns:
        NumPy array shaped like the output of ``forward``.
    """
    responses = torch.from_numpy(y).float().to(device)
    prob = forward(x, theta, multiplier, observed=observed)
    if observed:
        # Conditioning on data: replicate it for every parameter sample.
        responses = responses.expand((theta.shape[0], *responses.shape[1:]))
    if theta.ndim == 2:
        # Embedded case: replicate along the inner-sample axis as well.
        responses = responses.expand(
            (-1, theta.shape[1], *responses.shape[2:])
        )
    is_zero = responses == 0.
    prob[is_zero] = 1. - prob[is_zero]
    return prob.cpu().numpy()

def weighted_mean(x, weights):
    """Return the weighted sum of ``x`` along its first axis.

    Args:
        x: NumPy array or torch tensor whose first axis indexes samples.
        weights: 1-D array/tensor with one weight per sample. The weights are
            not normalised here, so they must sum to one for the result to be
            a mean.

    Returns:
        ``sum_i weights[i] * x[i]``, i.e. ``x`` with its first axis reduced.
    """
    # One trailing singleton axis per non-sample axis of x, so the weights
    # line up with axis 0 for any rank ([:, None, None] when x is 3-D).
    trailing_axes = (None,) * (x.ndim - 1)
    return (weights[(slice(None),) + trailing_axes] * x).sum(axis=0)

def _eig(designs, outer_model, nouter=nouter, ninner=ninner):
    """Return the per-sample information-gain terms and the sample weights.

    For each parameter sample and design this accumulates, over both binary
    outcomes, ``p(y | sample) * (log p(y | sample) - log p(y))``, where
    ``p(y)`` is the weighted mean of ``p(y | sample)`` over the samples.
    Outcomes whose conditional probability is exactly zero are skipped.

    Args:
        designs: Candidate designs, passed to ``forward``.
        outer_model: Object exposing ``samples`` (transposed and unpacked into
            ``theta`` and ``multiplier``) and ``W`` (sample weights).
        nouter: Accepted for signature compatibility; not used here.
        ninner: Accepted for signature compatibility; not used here.

    Returns:
        Tuple ``(gain, w)``: the tensor of per-sample terms and the weights as
        a torch tensor.
    """
    p_yes = forward(designs, *outer_model.samples.T)
    p_no = 1. - p_yes
    w = to_torch(outer_model.W)
    gain = torch.zeros_like(p_yes)
    # (outcome probability, entries where that probability is non-zero)
    outcomes = (
        (p_yes, p_yes > 0.),
        (p_no, p_yes < 1.),
    )
    for prob, support in outcomes:
        log_cond = torch.log(prob)
        log_marg = torch.log(weighted_mean(prob, w))[None, ...]
        log_marg = log_marg.expand(p_yes.shape)
        gain[support] += prob[support] * (
            log_cond[support] - log_marg[support]
        )
    return gain, w

def eig(trialn, designs, outer_model, par_i=None, nouter=nouter, ninner=ninner):
    """Return the weighted mean over samples of the terms from ``_eig``.

    Args:
        trialn: Not used here.
        designs: Candidate designs.
        outer_model: Model object handed to ``_eig``.
        par_i: Not used here.
        nouter: Forwarded to ``_eig``.
        ninner: Forwarded to ``_eig``.

    Returns:
        Tensor with the sample axis reduced, one value per design entry.
    """
    return weighted_mean(
        *_eig(designs, outer_model, nouter=nouter, ninner=ninner)
    )

def _etig_num(
    designs, outer_model, ninner=ninner, theta=None, psi=None, prior=True
):
    """Return response probabilities with psi averaged out for each theta.

    For every theta, ``ninner`` values of psi are drawn from a proposal and
    combined by self-normalised importance weights
    (prior density x likelihood of the observed data / proposal density).

    Args:
        designs: Candidate designs, passed to ``forward``.
        outer_model: Object exposing ``_distribution``, ``samples``,
            ``_is_prior`` and, when not at the prior, ``_Y`` / ``_X``.
        ninner: Number of psi draws per theta.
        theta: Theta values; defaults to column 0 of ``outer_model.samples``.
        psi: Psi values; defaults to column 1 of ``outer_model.samples``.
        prior: When true, the last psi draw of each row is overwritten with
            the corresponding entry of ``psi``.

    Returns:
        Torch tensor of importance-weighted probabilities, summed over the
        inner-sample axis (axis 1).
    """
    idist = outer_model._distribution
    prior_psi_dist = norm(loc=prior_loc[1], scale=prior_scale[1])
    if theta is None:
        theta = outer_model.samples[:,0]
    if psi is None:
        psi = outer_model.samples[:,1]

    theta_mat = np.repeat(theta[:,None], ninner, axis=1)
    # Only the attribute access is guarded: a distribution without mean/cov
    # falls back to the prior as proposal, while an error in the arithmetic
    # below is no longer silently turned into that fallback.
    try:
        mu0, cov0 = idist.mean, idist.cov
    except AttributeError:
        mdist = prior_psi_dist
    else:
        # Conditional of psi given theta under a bivariate normal:
        # https://en.wikipedia.org/wiki/Multivariate_normal_distribution
        cond_mu = ( mu0[1] + cov0[1,0]*(1./cov0[0,0])*(theta - mu0[0]) )[:,None]
        cond_cov = cov0[1,1] - cov0[1,0]*(1./cov0[0,0])*cov0[0,1]
        cond_cov = cond_cov * np.ones_like(cond_mu)
        mdist = norm(loc=cond_mu, scale=np.sqrt(cond_cov))
    mprime = mdist.rvs(size=theta_mat.shape)
    if prior:
        mprime[:,-1] = psi
    wprime = prior_psi_dist.pdf(mprime)
    if not outer_model._is_prior:
        wprime *= likelihood(
            outer_model._Y.asarray()[None,None,...], outer_model._X.asarray(),
            theta_mat, mprime, observed=True, zeroproof=False
        ).prod(axis=(2,3))
    # eps keeps both the weights and the proposal density away from zero.
    wprime += np.finfo(float).eps
    Q = mdist.pdf(mprime) + np.finfo(float).eps
    wprime /= Q
    wprime /= wprime.sum(axis=1)[:,None]
    cond_p = forward(designs, theta_mat, mprime)
    # Convert the weights explicitly rather than relying on implicit
    # tensor/ndarray mixing, which is not dependable for non-CPU tensors.
    weights = to_torch(wprime)
    return (cond_p*weights[...,None,None]).sum(axis=1)

def _etig(designs, outer_model, nouter=nouter, ninner=ninner):
    """Return per-sample terms built from psi-marginalised probabilities.

    Same accumulation as ``_eig``, except that the conditional log
    probabilities come from ``_etig_num`` (psi averaged out per theta) while
    the multiplying probabilities and the marginal still come from
    ``forward`` on the full samples.

    Args:
        designs: Candidate designs.
        outer_model: Object exposing ``samples`` and ``W``, plus whatever
            ``_etig_num`` reads from it.
        nouter: Accepted for signature compatibility; not used here.
        ninner: Number of inner psi draws, forwarded to ``_etig_num``.

    Returns:
        Tuple ``(tigv, w)``: the tensor of per-sample terms and the weights as
        a torch tensor.
    """
    # Pass ninner through so a caller-supplied value takes effect instead of
    # the module default always being used.
    cond_py1 = to_torch(_etig_num(designs, outer_model, ninner=ninner))
    cond_log_p1 = torch.log(cond_py1)
    cond_log_p0 = torch.log(1.-cond_py1)
    py1 = to_torch(forward(designs, *outer_model.samples.T))
    w = to_torch(outer_model.W)
    marg_log_p1 = torch.log(weighted_mean(py1, w))[None,...]
    marg_log_p0 = torch.log(weighted_mean(1.-py1, w))[None,...]
    marg_log_p1 = marg_log_p1.expand(py1.shape)
    marg_log_p0 = marg_log_p0.expand(py1.shape)
    tigv = torch.zeros_like(py1)
    # Skip outcomes whose probability is exactly zero (0 * log 0 terms).
    pos_indx = py1 > 0.
    neg_indx = py1 < 1.
    tigv[pos_indx] += py1[pos_indx]*(
        cond_log_p1[pos_indx] - marg_log_p1[pos_indx]
    )
    tigv[neg_indx] += (1.-py1[neg_indx])*(
        cond_log_p0[neg_indx] - marg_log_p0[neg_indx]
    )
    return tigv, w

def elig(
    trialn, designs, outer_model, par_i=None, nouter=nouter, ninner=ninner
):
    """Return ``eig`` minus ``etig`` for the given designs.

    Args:
        trialn: Forwarded to ``etig`` and ``eig``.
        designs: Candidate designs.
        outer_model: Model object forwarded to both estimators.
        par_i: Forwarded to ``etig`` and ``eig``.
        nouter: Forwarded to both estimators.
        ninner: Forwarded to both estimators.

    Returns:
        The difference of the two estimates.
    """
    sample_sizes = {"nouter": nouter, "ninner": ninner}
    # etig is evaluated first, then eig.
    embedded_gain = etig(trialn, designs, outer_model, par_i, **sample_sizes)
    total_gain = eig(trialn, designs, outer_model, par_i, **sample_sizes)
    return total_gain - embedded_gain

def etig(
    trialn, designs, outer_model, par_i=None, nouter=nouter, ninner=ninner
):
    """Return the weighted mean over samples of the terms from ``_etig``.

    Args:
        trialn: Not used here.
        designs: Candidate designs.
        outer_model: Model object handed to ``_etig``.
        par_i: Not used here.
        nouter: Forwarded to ``_etig``.
        ninner: Forwarded to ``_etig``.

    Returns:
        Tensor with the sample axis reduced, one value per design entry.
    """
    return weighted_mean(
        *_etig(designs, outer_model, nouter=nouter, ninner=ninner)
    )