from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.priors import GammaPrior
from gpytorch.priors import NormalPrior
import numpy as np
import os
import pickle as pkl
import torch
from torch.distributions import Binomial as Binomial

# Create every subsequent floating-point tensor in double precision.
torch.set_default_dtype(torch.double)
seed = 200
torch.manual_seed(seed)  # seed the torch RNG used by all draws below

cc = 3
scale = 1.25
# Gamma prior on psi, the lengthscale handed to the second (nuisance) RBF
# kernel in set_GP; parameterised by rate = 1 / scale.
psi_ls_prior =  GammaPrior(concentration=cc, rate=scale**(-1))

# Normal prior that we will exponentiate for theta
theta_ls_prior = NormalPrior(loc=1, scale=1)

nsamples = 2000 # samples from the prior ptheta and ppsi
n=80 # number of trajectories observed
M = 20  # number of theta draws averaged over in zeta() / logp_psi()

N = 7  # total_count of the Binomial feedback model used in simulation()

def set_kernel(nsamples=nsamples):
    """Build a batched sum-of-two-RBF kernel wrapped in a ScaleKernel.

    Args:
        nsamples: batch size given to both RBF components.

    Returns:
        ``ScaleKernel(RBF + RBF)`` where each RBF has ``ard_num_dims=1`` and
        ``batch_shape=torch.Size([nsamples])``.
    """
    batch_shape = torch.Size([nsamples])
    target_rbf = RBFKernel(ard_num_dims=1, batch_shape=batch_shape)
    nuisance_rbf = RBFKernel(ard_num_dims=1, batch_shape=batch_shape)
    return ScaleKernel(target_rbf + nuisance_rbf)

def set_GP(kernel, target_ls, nuisance_ls, x, batch=False, nsamples=nsamples):
    """Configure ``kernel`` and return the zero-mean Gaussian it induces on ``x``.

    The kernel is modified in place: the first additive component receives
    ``target_ls`` as lengthscale, the second receives ``nuisance_ls``, and the
    output scale is set to 1.

    Args:
        kernel: a ``ScaleKernel`` whose ``base_kernel`` exposes two ``kernels``.
        target_ls: lengthscale(s) for the first component.
        nuisance_ls: lengthscale(s) for the second component.
        x: input locations; ``len(x)`` fixes the dimension of the Gaussian.
        batch: when true, the diagonal jitter is tiled ``nsamples`` times to
            match a batched covariance.
        nsamples: number of batch entries (only read when ``batch`` is true).

    Returns:
        A ``torch.distributions.MultivariateNormal`` with zero mean and
        covariance ``kernel(x) + 1e-11 * I``.
    """
    components = kernel.base_kernel.kernels
    components[0].lengthscale = target_ls
    components[1].lengthscale = nuisance_ls
    kernel.outputscale = 1.

    npoints = len(x)
    # Small diagonal jitter added to the dense kernel matrix.
    jitter = 1e-11 * torch.eye(npoints)
    if batch:
        jitter = jitter.repeat(nsamples, 1, 1)

    covariance = kernel(x).to_dense() + jitter
    mvn_cls = torch.distributions.multivariate_normal.MultivariateNormal
    return mvn_cls(torch.zeros(npoints), covariance)

# Reference lengthscales: `tildef_theta` is the default target lengthscale and
# `tildef_psi` the default nuisance lengthscale used by GP_sample().
tildef_theta = 6.
tildef_psi = 3.
tildef = torch.tensor([tildef_theta, tildef_psi])

# Prior draws, one row per sample: column 0 holds exp(Normal) draws for theta,
# column 1 holds Gamma draws for psi.
theta_samples = theta_ls_prior.sample((nsamples, 1)).exp()
psi_samples = psi_ls_prior.sample((nsamples, 1))
Theta = torch.cat((theta_samples, psi_samples), dim=1)

def GP_sample(x, tildef_theta=tildef_theta, tildef_psi=tildef_psi, target_clust_size=0):
    """Draw ``n`` GP trajectories on ``x`` (``n`` is the module-level count).

    Every trajectory uses ``tildef_theta`` as target lengthscale.  The first
    ``target_clust_size`` trajectories use ``tildef_psi`` as nuisance
    lengthscale; each remaining one draws its own from ``psi_ls_prior``.

    Args:
        x: input locations.
        tildef_theta: target lengthscale shared by all trajectories.
        tildef_psi: nuisance lengthscale of the leading cluster.
        target_clust_size: number of leading trajectories in that cluster.

    Returns:
        ``(samples, psis)``: an ``(n, len(x))`` tensor of draws and the list of
        nuisance lengthscales used, in trajectory order.
    """
    shared_kernel = ScaleKernel(RBFKernel(ard_num_dims=1) + RBFKernel(ard_num_dims=1))
    shared_kernel.outputscale = 1
    draws = torch.zeros((n, len(x)))
    nuisance_used = []
    for row in range(n):
        in_target_cluster = row < target_clust_size
        # The prior is only sampled for trajectories outside the cluster.
        row_psi = tildef_psi if in_target_cluster else psi_ls_prior.sample()
        draws[row] = set_GP(shared_kernel, tildef_theta, row_psi, x).sample()
        nuisance_used.append(row_psi)
    return (draws, nuisance_used)

# Define zeta as E_theta[p(y | x, theta, psi)]
# For each psi, sample M thetas, and then compute the expectation as 1/M \sum_{i=1}^M p(y | x, theta_i, psi)
# The distribution the thetas are sampled from changes at each iteration of the `get_robust_eta()` loop
def logp_psi(theta, psi, x, y):
    """Log of the theta-averaged GP density of each row of ``y``, per psi.

    ``theta`` and ``psi`` are paired row by row and must be laid out as ``M``
    consecutive groups (one group per theta draw, each group listing every
    psi), because the batch axis is reshaped to ``(M, NM // M, ...)`` and the
    average is taken over the leading axis.

    Args:
        theta: target lengthscales, one per batch entry (``NM`` rows).
        psi: nuisance lengthscales, one per batch entry.
        x: input locations of the trajectories.
        y: observed trajectories, one per row.

    Returns:
        ``log(1/M * sum_m p(y | x, theta_m, psi))`` with the psi index on the
        first axis and the trajectory index on the second.
    """
    NM = theta.shape[0]
    gp = set_GP(set_kernel(NM), theta, psi, x, batch=True, nsamples=NM)
    logprob = gp.log_prob(y[:, None, :]).T
    logprob = logprob.reshape((M, NM // M, *logprob.shape[1:]))
    # Stay in log space: exponentiating first underflows to 0 (or overflows to
    # inf) once the log-densities are large in magnitude.
    return torch.logsumexp(logprob, dim=0) - torch.log(torch.tensor(float(M)))

def zeta(x, y, f, theta=None):
    """Theta-averaged likelihood of ``y`` relative to its value at ``y = 0``.

    Args:
        x: input locations.
        y: observed trajectories, one per row.
        f: two-column tensor; only column 1 (psi) is used.
        theta: optional theta draws; only the first ``M`` rows are used.  When
            omitted, ``M`` draws are taken from the exponentiated theta prior.

    Returns:
        ``exp(logp_psi(y) - logp_psi(0))`` for every (psi, trajectory) pair.
    """
    nuisance = f[:, 1]
    if theta is None:
        theta = theta_ls_prior.sample((M, 1)).exp()

    # Pair every theta draw with every psi: thetas vary slowest, psis fastest.
    npsi = nuisance.shape[0]
    theta_grid = theta[:M, :].repeat_interleave(npsi, dim=0)
    psi_grid = nuisance.repeat((M,))[:, None]

    log_marginal = logp_psi(theta_grid, psi_grid, x, y)
    log_at_zero = logp_psi(theta_grid, psi_grid, x, torch.zeros_like(y))
    return (log_marginal - log_at_zero).exp()

# First, compute zeta(X, Y, psi) by taking the expectation over theta wrt P_Theta
# Then, use the resulting weights to get an (approximate) likelihood for theta, and approximate the
#     resulting posterior \hat{P}^R_Theta using a Gaussian
# Then, recompute zeta(X, Y, psi) by taking the expectation over theta wrt \hat{P}^R_Theta
# Repeat for `reps` iterations, and output the resulting weights
def get_robust_eta(eta, X, Y, Theta, L_psi, reps=3):
    """Iteratively refine the per-trajectory weights ``eta``.

    Each iteration weights the prior samples in ``Theta`` by their
    eta-tempered likelihood plus ``L_psi``, moment-matches the weighted theta
    column, draws ``M`` new thetas from the fitted distribution and recomputes
    ``eta`` with ``zeta``.

    Args:
        eta: current weights, broadcastable against the (sample, trajectory)
            log-likelihood matrix.
        X: input locations.
        Y: trajectories, one per row.
        Theta: prior samples; column 0 is theta, column 1 is psi.
        L_psi: per-sample log-likelihood term added to the tempered sum.
        reps: number of refinement iterations.

    Returns:
        The weights produced by the last iteration (``eta`` itself if
        ``reps`` is 0).
    """
    target_ls, nuisance_ls = Theta[:, 0], Theta[:, 1]
    # The untempered log-likelihoods do not depend on eta, so they are
    # evaluated once instead of once per iteration.
    loglik = set_GP(
        set_kernel(), target_ls, nuisance_ls, X, batch=True
    ).log_prob(Y[:, None, :]).T

    for _ in range(reps):
        L = (loglik * eta).sum(dim=1) + L_psi
        # Normalised importance weights.  softmax equals exp(L) / sum(exp(L))
        # but does not produce NaN/inf when exp(L) under- or overflows.
        wghts = torch.softmax(L, dim=0)
        mn_theta = target_ls @ wghts
        var_theta = ((target_ls - mn_theta) ** 2.) @ wghts
        mn_theta = mn_theta * torch.ones((M, 1))
        var_theta = var_theta * torch.ones((M, 1))
        # NOTE(review): the moments above are taken on the theta scale, yet the
        # Gaussian draw is exponentiated -- confirm whether moments of
        # log(theta) were intended.
        theta_samples_ = torch.normal(mean=mn_theta, std=var_theta ** .5).exp()
        eta = zeta(X, Y, Theta, theta=theta_samples_)
    return eta

def generate_source_data(
    Xgrid, tildef_theta=tildef_theta, tildef_psi=tildef_psi, target_clust_size=25
):
    """Sample trajectories on ``Xgrid`` and package them as a dataset.

    Args:
        Xgrid: input locations passed to ``GP_sample``.
        tildef_theta: target lengthscale forwarded to ``GP_sample``.
        tildef_psi: nuisance lengthscale forwarded to ``GP_sample``.
        target_clust_size: cluster size forwarded to ``GP_sample``.

    Returns:
        ``((Xgrid, Y0), PSI)`` where ``Y0`` and ``PSI`` are the two values
        returned by ``GP_sample``.
    """
    trajectories, nuisance_draws = GP_sample(
        Xgrid,
        tildef_theta=tildef_theta,
        tildef_psi=tildef_psi,
        target_clust_size=target_clust_size,
    )
    dataset = (Xgrid, trajectories)
    return dataset, nuisance_draws

def _log_mean_exp(values):
    """Return ``log(mean(exp(values)))`` over dim 0 without leaving log space."""
    count = torch.tensor(float(values.shape[0]))
    return torch.logsumexp(values, dim=0) - torch.log(count)


def ig_theta(D0, Z, robust=True, reps=10, tildef=tildef, N=None):
    """Log ratio of the likelihood with theta pinned to ``tildef[0]`` over the
    likelihood under the prior samples in ``Theta``.

    Args:
        D0: ``(X, Y)`` pair of input locations and trajectories.
        Z: per-trajectory feedback counts; NaN marks trajectories without
            feedback.
        robust: when true, the log-likelihoods are tempered by the weights from
            ``get_robust_eta`` and the pinned-theta term is averaged with
            weights derived from the feedback likelihood; otherwise plain
            means over the prior samples are used.
        reps: number of iterations passed to ``get_robust_eta``.
        tildef: tensor whose first entry is the pinned theta value.
        N: ``total_count`` of the Binomial feedback model.  It is passed to
            ``Binomial`` unchanged, so callers need to supply a value.

    Returns:
        Scalar tensor ``log E[p(Y | tildef theta, psi)] - log E[p(Y | theta, psi)]``.
    """
    X, Y = D0
    # Copy of the prior samples with the theta column pinned to tildef[0].
    Theta_ = Theta.clone()
    Theta_[:, 0] = tildef[0]

    eta = zeta(X, Y, Theta)
    I = torch.where(~torch.isnan(Z))[0]
    # Log-likelihood of the observed feedback for every prior sample, shifted
    # so that exp(L_psi) has mean 1 across samples.
    L_psi = Binomial(N, eta[:, I]).log_prob(Z[None, I]).sum(dim=1)
    L_psi = L_psi - _log_mean_exp(L_psi)

    # ~I is of the same shape as I
    # NOTE(review): I holds integer indices, so ~I is the bitwise NOT (-I - 1)
    # and selects rows counted from the end, not the complement of I.  Kept
    # as-is; confirm this selection is the intended one.
    Y, eta = Y[~I], eta[:, ~I]
    target_ls, nuisance_ls = Theta[:, 0], Theta[:, 1]
    target_ls_, nuisance_ls_ = Theta_[:, 0], Theta_[:, 1]

    L = set_GP(set_kernel(), target_ls, nuisance_ls, X, batch=True).log_prob(Y[:, None, :]).T
    L_ = set_GP(set_kernel(), target_ls_, nuisance_ls_, X, batch=True).log_prob(Y[:, None, :]).T

    if robust:
        eta = get_robust_eta(eta, X, Y, Theta, L_psi, reps=reps)
        L = L * eta
        L_ = L_ * eta
        # Weights proportional to the feedback likelihood exp(L_psi).
        # Normalising the log-likelihoods themselves (L_psi / L_psi.sum()) can
        # give negative weights and disagrees with get_robust_eta.
        log_wghts = torch.log_softmax(L_psi, dim=0)
        # log(sum_s w_s * exp(L_s)), evaluated in log space.
        L_ = torch.logsumexp(L_.sum(dim=1) + log_wghts, dim=0)
    else:
        L_ = _log_mean_exp(L_.sum(dim=1))

    L = _log_mean_exp(L.sum(dim=1))

    return L_ - L

def simulation(
    res=25, reps=3, percent=1., tildef_theta=tildef_theta, target_clust_size=25,
    pflip=0.
):
    """Run one simulated experiment and compare robust vs. plain ``ig_theta``.

    Args:
        res: number of grid points on [0, 10].
        reps: iterations forwarded to the robust ``ig_theta`` call.
        percent: fraction of the ``n`` trajectories that keep their feedback;
            the rest are set to NaN.
        tildef_theta: target lengthscale used to generate the data.
        target_clust_size: cluster size forwarded to ``generate_source_data``.
        pflip: probability with which each feedback probability ``p`` is
            replaced by ``1 - p``.

    Returns:
        ``ig_theta(robust) - ig_theta(non-robust)`` as a tensor, or a NaN
        tensor if a ``RuntimeError`` is raised along the way.
    """
    try:
        truth = tildef.clone()
        truth[0] = tildef_theta
        truth[1] = psi_ls_prior.sample((1, 1))

        grid = torch.linspace(0., 10., res)[:, None]
        (X, Y), _ = generate_source_data(
            grid,
            tildef_theta=tildef_theta,
            tildef_psi=truth[1],
            target_clust_size=target_clust_size,
        )

        # Feedback probabilities under the generating nuisance lengthscale,
        # each flipped to 1 - p with probability pflip.
        eta_true = zeta(X, Y, truth[None, :])[0, :]
        flipped = torch.bernoulli(pflip * torch.ones_like(eta_true)).bool()
        eta_true[flipped] = 1. - eta_true[flipped]

        Z = Binomial(N, eta_true).sample()
        # Only the first int(percent * n) trajectories keep their feedback.
        Z[int(percent * n):] = torch.nan

        data = (X, Y)
        robust_gain = ig_theta(data, Z=Z, reps=reps, tildef=truth, N=N)
        plain_gain = ig_theta(data, Z=Z, robust=False, tildef=truth, N=N)
        return robust_gain - plain_gain

    except RuntimeError:
        return torch.tensor(torch.nan)

# Number of independent simulations stored in each results file.
nsims = 50


def run_simulations(i, pct, res, reps, target_clust_size, tildef_theta, pflip):
    """Run ``nsims`` simulations and pickle the outcomes to ``results_{i}``.

    Nothing is computed if the output file already exists; a message is
    printed instead.

    Args:
        i: index used in the file name and stored alongside the results.
        pct: feedback fraction, forwarded as ``percent``.
        res: grid resolution (converted with ``int``).
        reps: robust iterations (converted with ``int``).
        target_clust_size: cluster size (converted with ``int``).
        tildef_theta: target lengthscale forwarded to ``simulation``.
        pflip: flip probability forwarded to ``simulation``.
    """
    fname = f"results_{i}"
    if os.path.exists(fname):
        print(f"{fname} exists.")
        return

    outcomes = []
    for _ in range(nsims):
        outcome = simulation(
            reps=int(reps),
            res=int(res),
            percent=pct,
            tildef_theta=tildef_theta,
            target_clust_size=int(target_clust_size),
            pflip=pflip,
        )
        outcomes.append(outcome)

    rr = np.array([outcome.detach().numpy() for outcome in outcomes])
    payload = {"index": i, "results": rr}
    with open(fname, "wb") as wfh:
        pkl.dump(payload, wfh)
