###########################
# Latent ODEs for Irregularly-Sampled Time Series
# Author: Yulia Rubanova
###########################

import gc
import numpy as np
import sklearn as sk
import numpy as np
#import gc
import torch
import torch.nn as nn
from torch.nn.functional import relu

import lib.utils as utils
from lib.utils import get_device
from lib.encoder_decoder import *
from lib.likelihood_eval import *

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence, Independent


def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None):
    """Length-normalised Gaussian log-density of ``data_2d`` around ``mu_2d``.

    The last axis is treated as the event axis: element-wise log-densities
    are summed over it and the sum is divided by its length.

    Args:
        mu_2d: tensor of predicted means.
        data_2d: tensor of observations scored against ``mu_2d``.
        obsrv_std: tensor with the observation standard deviation; it is
            tiled ``mu_2d.size(-1)`` times to form the scale vector.
        indices: not used by this function; callers in this file pass it
            through the ``likelihood_func(mu, data, indices=...)`` convention.

    Returns:
        The normalised log-density for every leading index, or a 0-dim zero
        tensor when the last axis of ``mu_2d`` is empty.
    """
    n_points = mu_2d.size(-1)

    # No points on the event axis: the likelihood is defined as zero.
    if n_points <= 0:
        return torch.zeros([1]).to(get_device(data_2d)).squeeze()

    scale = obsrv_std.repeat(n_points)
    density = Independent(Normal(loc = mu_2d, scale = scale), 1)
    return density.log_prob(data_2d) / n_points


def poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas):
    """Poisson-process log-likelihood of one masked series.

    Args:
        masked_log_lambdas: log-intensities at the observed points; all of
            them are summed.
        masked_data: the observed values; only the length of the last axis
            is inspected, to detect a series without observations.
        indices: index (e.g. a tuple) selecting this series' entry in
            ``int_lambdas``.
        int_lambdas: tensor of integrated intensities.

    Returns:
        ``sum(masked_log_lambdas) - int_lambdas[indices]``, or a 0-dim zero
        tensor when ``masked_data`` has no entries on its last axis.
    """
    if masked_data.size(-1) > 0:
        return masked_log_lambdas.sum() - int_lambdas[indices]

    # Nothing observed for this series: contribute zero.
    return torch.zeros([1]).to(get_device(masked_data)).squeeze()



def compute_binary_CE_loss(label_predictions, mortality_label):
    """Binary cross-entropy (on logits) between predictions and labels.

    Labels that are NaN are treated as missing and dropped, together with
    the matching prediction columns, before the loss is computed.

    Args:
        label_predictions: logits. A 1-D tensor is treated as a single
            trajectory sample; otherwise the first axis indexes trajectory
            samples and the remaining axes are flattened.
        mortality_label: binary labels (flattened internally); NaN marks a
            missing label.

    Returns:
        Scalar tensor: ``nn.BCEWithLogitsLoss`` over all samples and valid
        labels, divided by the number of trajectory samples. A zero scalar
        is returned when no valid label is present.
    """
    mortality_label = mortality_label.reshape(-1)

    if label_predictions.dim() == 1:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples = label_predictions.size(0)
    label_predictions = label_predictions.reshape(n_traj_samples, -1)

    idx_not_nan = ~torch.isnan(mortality_label)
    # Count the *valid* labels (not the mask length) and bail out early:
    # BCE over an empty selection would otherwise evaluate to NaN.
    if not idx_not_nan.any():
        print("All are labels are NaNs!")
        return torch.zeros((), dtype=label_predictions.dtype, device=label_predictions.device)

    label_predictions = label_predictions[:, idx_not_nan]
    mortality_label = mortality_label[idx_not_nan]

    if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.) == 0:
        print("Warning: all examples in a batch belong to the same class -- please increase the batch size.")

    assert(not torch.isnan(label_predictions).any())
    assert(not torch.isnan(mortality_label).any())

    # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
    mortality_label = mortality_label.repeat(n_traj_samples, 1)
    ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)

    # The mean above already covers every sample; the extra division by
    # n_traj_samples is kept from the original implementation.
    return ce_loss / n_traj_samples


def compute_multiclass_CE_loss(label_predictions, true_label, mask, weight=None):
    """Cross-entropy between per-time-point class logits and their labels.

    Args:
        label_predictions: logits of shape
            ``[n_traj_samples, n_traj, n_tp, n_dims]``; a 3-D tensor is
            treated as a single trajectory sample.
        true_label: labels per time point. When both it and the predictions
            have a last axis longer than 1, it is taken to be one-hot and
            converted to class indices. It may carry one entry per
            trajectory or a single entry shared by every trajectory; either
            way it is broadcast over the trajectory samples.
        mask: observation mask. If every entry equals 1 all time points are
            scored; otherwise only time points where ``mask.sum(-1) > 0``.
        weight: optional per-class weights forwarded to
            ``nn.CrossEntropyLoss``.

    Returns:
        Scalar tensor with the ``nn.CrossEntropyLoss`` value over the scored
        time points.
    """
    if label_predictions.dim() == 3:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()

    if (n_dims > 1) and (true_label.size(-1) > 1):
        assert(n_dims == true_label.size(-1))
        # targets are in one-hot encoding -- convert to indices
        _, true_label = true_label.max(-1)

    # Line the labels up with the flattened predictions: one label per
    # (sample, trajectory, time point). Labels given once ([1, n_tp]) or per
    # trajectory ([n_traj, n_tp]) are broadcast; labels that already have
    # one row per (sample, trajectory) are used as they are.
    labels = true_label.reshape(-1, n_tp)
    if labels.size(0) != n_traj_samples * n_traj:
        labels = labels.expand(n_traj_samples, n_traj, n_tp)
    labels = labels.reshape(-1).long()
    logits = label_predictions.reshape(-1, n_dims)

    criterion = nn.CrossEntropyLoss(weight=weight)

    if (mask == 1).all():
        # Everything is observed -- no need to filter time points.
        return criterion(logits, labels)

    # Keep only the time points with at least one observed feature.
    observed = (torch.sum(mask, -1) > 0).repeat(n_traj_samples, 1, 1).flatten()
    return criterion(logits[observed], labels[observed])




def _mean_over_observed_timepoints(per_point, observed, like):
    """Scatter per-observation values into ``like``'s shape and average over axis 2."""
    dense = torch.zeros(like.shape, dtype=per_point.dtype, device=like.device)
    dense[observed] = per_point

    # Number of observed time points per series; series without any
    # observation keep a zero numerator, so dividing by 1 leaves them at 0.
    counts = observed.sum(2).to(per_point.dtype)
    counts[counts == 0] = 1
    return dense.sum(2) / counts


def _loop_masked_likelihood(mu, data, mask, likelihood_func):
    """Call ``likelihood_func`` on every (sample, trajectory, feature) series."""
    n_traj_samples, n_traj, _, n_dims = data.size()
    observed = mask.bool()

    res = []
    for i in range(n_traj_samples):
        for k in range(n_traj):
            for j in range(n_dims):
                keep = observed[i, k, :, j]
                mu_masked = torch.masked_select(mu[i, k, :, j], keep)
                data_masked = torch.masked_select(data[i, k, :, j], keep)
                res.append(likelihood_func(mu_masked, data_masked, indices = (i, k, j)))

    res = torch.stack(res, 0).to(data.device)
    return res.reshape((n_traj_samples, n_traj, n_dims))


def compute_masked_likelihood(mu, data, mask, likelihood_func, _type=None, obsrv_std=None):
    """Score ``mu`` against ``data`` on observed entries only.

    The score is computed per (sample, trajectory, feature) series so that
    series with many observations do not dominate, then averaged over the
    feature axis.

    Args:
        mu: predictions, indexed like ``data``.
        data: observations of shape
            ``[n_traj_samples, n_traj, n_timepoints, n_dims]``.
        mask: tensor of the same shape; non-zero marks an observed entry.
        likelihood_func: callable ``f(mu_masked, data_masked, indices=(i, k, j))``
            returning a scalar tensor. Only used when ``_type`` is ``None``.
        _type: ``'logprob'`` for the mean Gaussian log-density of the
            observed points (vectorised), ``'mse'`` for their mean squared
            error (vectorised), or ``None`` to evaluate ``likelihood_func``
            on each series in a Python loop.
        obsrv_std: observation standard deviation; required for
            ``'logprob'``.

    Returns:
        Tensor of shape ``[n_traj, n_traj_samples]``.

    Raises:
        ValueError: if ``_type`` is not one of the values above, or if
            ``'logprob'`` is requested without ``obsrv_std``.
    """
    if _type is None:
        per_series = _loop_masked_likelihood(mu, data, mask, likelihood_func)
    elif _type in ('logprob', 'mse'):
        observed = mask.bool()
        # Boolean indexing flattens the observed entries into a 1-D tensor.
        mu_obs = mu[observed]
        data_obs = data[observed]

        if _type == 'logprob':
            if obsrv_std is None:
                raise ValueError("obsrv_std is required when _type == 'logprob'")
            per_point = Normal(loc = mu_obs, scale = obsrv_std).log_prob(data_obs)
        else:
            per_point = (mu_obs - data_obs) ** 2

        per_series = _mean_over_observed_timepoints(per_point, observed, data)
    else:
        raise ValueError("Unknown _type: {!r}".format(_type))

    # Average over features, then put trajectories on the first axis.
    return per_series.mean(-1).transpose(0, 1)


def masked_gaussian_log_density(mu, data, obsrv_std, mask = None):
    """Gaussian log-density of ``data`` under ``mu``, optionally masked.

    Lower-rank inputs are first promoted to four axes: a 3-D ``mu`` gains a
    leading sample axis; a 2-D ``data`` gains a leading sample axis and a
    time axis at position 2; a 3-D ``data`` gains a leading sample axis.

    Args:
        mu: predicted means, 3-D or 4-D.
        data: observations, 2-D, 3-D or 4-D; its last axis must match
            ``mu``'s last axis.
        obsrv_std: observation standard deviation tensor.
        mask: optional observation mask. When given, the computation is
            delegated to ``compute_masked_likelihood`` in ``'logprob'`` mode.

    Returns:
        Without a mask, a tensor reshaped to ``[n_traj_samples, n_traj]``
        (sizes taken from ``data``) and then transposed. With a mask,
        whatever ``compute_masked_likelihood`` returns.
    """
    # these cases are for plotting through plot_estim_density
    if mu.dim() == 3:
        # add the sample axis
        mu = mu.unsqueeze(0)

    data_rank = data.dim()
    if data_rank == 2:
        # add the sample axis and the time-step axis
        data = data.unsqueeze(0).unsqueeze(2)
    elif data_rank == 3:
        # add the sample axis
        data = data.unsqueeze(0)

    n_traj_samples, n_traj, n_timepoints, n_dims = mu.size()

    assert(data.size()[-1] == n_dims)

    if mask is not None:
        # Compute the likelihood per patient so that patients with more
        # measurements are not prioritised.
        def per_series_log_density(mu, data, indices):
            return gaussian_log_likelihood(mu, data, obsrv_std = obsrv_std, indices = indices)

        return compute_masked_likelihood(
            mu, data, mask, per_series_log_density,
            _type='logprob', obsrv_std=obsrv_std)

    mu_flat = mu.reshape(n_traj_samples * n_traj, n_timepoints * n_dims)
    # From here on the sizes are those of ``data``, which may differ from ``mu``'s.
    n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
    data_flat = data.reshape(n_traj_samples * n_traj, n_timepoints * n_dims)

    flat_log_density = gaussian_log_likelihood(mu_flat, data_flat, obsrv_std)
    return flat_log_density.reshape(n_traj_samples, n_traj).transpose(0, 1)



def mse(mu, data, indices = None):
    """Mean squared error between ``mu`` and ``data``.

    Args:
        mu: predictions.
        data: targets.
        indices: not used by this function; callers in this file pass it
            through the ``likelihood_func(mu, data, indices=...)`` convention.

    Returns:
        Scalar tensor from ``nn.MSELoss`` (mean over all elements), or a
        0-dim zero tensor when the last axis of ``mu`` is empty.
    """
    if mu.size(-1) > 0:
        return nn.MSELoss()(mu, data)

    # Nothing to compare: define the error as zero.
    return torch.zeros([1]).to(get_device(data)).squeeze()


def compute_mse(mu, data, mask = None):
    """Mean squared error between ``mu`` and ``data``, optionally masked.

    Lower-rank inputs are first promoted to four axes: a 3-D ``mu`` gains a
    leading sample axis; a 2-D ``data`` gains a leading sample axis and a
    time axis at position 2; a 3-D ``data`` gains a leading sample axis.

    Args:
        mu: predictions, 3-D or 4-D.
        data: targets, 2-D, 3-D or 4-D; its last axis must match ``mu``'s
            last axis.
        mask: optional observation mask. When given, the computation is
            delegated to ``compute_masked_likelihood`` in ``'mse'`` mode.

    Returns:
        Without a mask, the value of ``mse`` on the flattened tensors. With
        a mask, whatever ``compute_masked_likelihood`` returns.
    """
    # these cases are for plotting through plot_estim_density
    if mu.dim() == 3:
        # add the sample axis
        mu = mu.unsqueeze(0)

    data_rank = data.dim()
    if data_rank == 2:
        # add the sample axis and the time-step axis
        data = data.unsqueeze(0).unsqueeze(2)
    elif data_rank == 3:
        # add the sample axis
        data = data.unsqueeze(0)

    n_traj_samples, n_traj, n_timepoints, n_dims = mu.size()
    assert(data.size()[-1] == n_dims)

    if mask is not None:
        # Compute the error per patient so that patients with more
        # measurements are not prioritised.
        return compute_masked_likelihood(mu, data, mask, mse, _type='mse')

    mu_flat = mu.reshape(n_traj_samples * n_traj, n_timepoints * n_dims)
    # The sizes used for ``data`` are its own, which may differ from ``mu``'s.
    n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
    data_flat = data.reshape(n_traj_samples * n_traj, n_timepoints * n_dims)
    return mse(mu_flat, data_flat)




def compute_poisson_proc_likelihood(truth, pred_y, info, mask = None):
    """Log-likelihood of the observation times under a Poisson process.

    For every (sample, trajectory, feature) series the log-likelihood is the
    sum of the log-intensities at the observed time points minus the
    integrated intensity; see
    https://math.stackexchange.com/questions/344487/log-likelihood-of-a-realization-of-a-poisson-process

    Args:
        truth: observed data. Not read by the computation (the mask alone
            decides which time points count); kept for interface
            compatibility.
        pred_y: predictions; only ``pred_y.size(0)`` (the number of
            trajectory samples) is used, to tile ``mask``.
        info: dict with ``"log_lambda_y"`` (log-intensities, time on axis 2)
            and ``"int_lambda"`` (integrated intensities, same shape without
            the time axis).
        mask: optional observation mask; it is tiled ``pred_y.size(0)`` times
            along a leading sample axis and must then match
            ``info["log_lambda_y"]`` in shape. With a mask, only observed
            time points contribute, and a series without any observation
            contributes zero.

    Returns:
        Tensor of shape ``[n_traj_samples, n_traj]``: the per-series
        log-likelihood averaged over the feature axis.
    """
    log_lambda = info["log_lambda_y"]
    int_lambda = info["int_lambda"]

    if mask is None:
        # Sum log lambdas across all time points, then average over data dims
        poisson_log_l = torch.sum(log_lambda, 2) - int_lambda
        return torch.mean(poisson_log_l, -1)

    observed = mask.repeat(pred_y.size(0), 1, 1, 1).bool()

    # Select with torch.where rather than multiplying by the mask so that
    # non-finite values at unobserved positions cannot leak into the sum.
    summed_log_lambda = torch.where(observed, log_lambda, torch.zeros_like(log_lambda)).sum(2)
    per_series = summed_log_lambda - int_lambda

    # Series with no observed time point contribute zero rather than
    # -int_lambda.
    per_series = torch.where(observed.any(2), per_series, torch.zeros_like(per_series))

    # Average over data dims -> [n_traj_samples, n_traj]
    return per_series.mean(-1)



