import torch
import torch.nn as nn
import numpy as np

from lib.ncdssm_components import AuxInferenceModel
from lib.ncdssm_base import BaseNL, BaseLTI, BaseLL
from lib.ncdssm_torch_utils import merge_leading_dims, grad_norm
from lib.ncdssm_type import Tensor, Optional, List
from torch.utils.tensorboard import SummaryWriter
from lib.utils import log_to_tensorboard, make_dir, compute_physionet_intermediate, compute_mimic_intermediate
from datetime import datetime
from lib.data_utils import  align_output_and_target, adjust_obs_for_extrapolation
from lib.losses import rmse, mse, GaussianNegLogLik, bernoulli_nll, mae
import pdb
# Short module-level aliases for the torch namespaces used throughout this file.
optim, nn = torch.optim, torch.nn
F = torch.nn.functional


class NCDSSM(nn.Module):
    """Base class for NCDSSM models.

    This class wires an auxiliary inference network and an emission network
    around a base state space model. It does not create the base model
    itself: ``forward`` and ``forecast`` read ``self.base_ssm``, which
    subclasses are expected to assign in their constructors.

    Parameters
    ----------
    aux_inference_net
        The auxiliary inference model parameterized by a neural network
    y_emission_net
        The emission model parameterized by a neural network
    aux_dim
        The dimension of the auxiliary variables
    z_dim
        The dimension of the latent states
    y_dim
        The dimension of the observations
    u_dim
        The dimension of the control inputs (must be 0, control inputs are
        not implemented)
    integration_method, optional
        The integration method, can be one of "euler" or "rk4",
        by default "rk4"
    integration_step_size, optional
        The integration step size, by default 0.1
    sporadic, optional
        A flag to indicate whether the dataset is sporadic,
        i.e., with values missing in both time and feature dimensions,
        by default False
    **kwargs
        Accepted so that subclasses can forward extra options; unused here.
    """

    def __init__(
        self,
        aux_inference_net: nn.Module,
        y_emission_net: nn.Module,
        aux_dim: int,
        z_dim: int,
        y_dim: int,
        u_dim: int,
        integration_method: str = "rk4",
        integration_step_size: float = 0.1,
        sporadic: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert u_dim == 0, "Support for control inputs is not implemented yet"
        self.aux_inference_net = aux_inference_net
        self.y_emission_net = y_emission_net
        self.aux_dim = aux_dim
        self.z_dim = z_dim
        self.y_dim = y_dim
        self.u_dim = u_dim
        self.integration_method = integration_method
        self.integration_step_size = integration_step_size
        self.sporadic = sporadic

    def forward(
        self,
        y: Tensor,
        mask: Tensor,
        times: Tensor,
        num_samples: int = 1,
        deterministic: bool = False,
    ):
        """Compute the ELBO / IWELBO objectives for a batch of observations.

        Parameters
        ----------
        y
            The tensor of observations; its first dimension is the batch
        mask
            The mask of missing values, either of the same shape as ``y``
            or of the shape of ``y`` without its last dimension
        times
            The observation times; the first entry must be 0
        num_samples, optional
            The number of auxiliary samples drawn per series, by default 1
        deterministic, optional
            Forwarded to the auxiliary inference network, by default False

        Returns
        -------
            A dict with the entries ``elbo``, ``iwelbo``, ``likelihood`` and
            ``regularizer``. ``iwelbo`` holds one value per batch element;
            the other entries hold one value per (sample, batch element).
        """
        assert (
            y.size() == mask.size() or y.size()[:-1] == mask.size()
        ), "Shapes of y and mask should match!"
        # Currently assumes that t = 0 is in times
        assert times[0] == 0.0, "First timestep should be 0!"
        batch_size = y.size(0)

        aux_samples, aux_entropy, aux_post_log_prob = self.aux_inference_net(
            y,
            mask,
            num_samples=num_samples,
            deterministic=deterministic,
        )
        aux_samples = merge_leading_dims(aux_samples, ndims=2)
        aux_entropy = merge_leading_dims(aux_entropy, ndims=2)
        aux_post_log_prob = merge_leading_dims(aux_post_log_prob, ndims=2)

        # Compute likelihoods
        if mask.ndim < y.ndim:
            mask = mask.unsqueeze(-1)
        repeated_mask = mask.repeat(num_samples, 1, 1)
        emission_dist = self.y_emission_net(aux_samples)
        y_log_likelihood = emission_dist.log_prob(y.repeat(num_samples, 1, 1))
        # Masked-out entries contribute nothing to the likelihood.
        y_log_likelihood = y_log_likelihood * repeated_mask
        y_log_likelihood = y_log_likelihood.sum(-1).sum(-1)

        # Without sporadic data the filter only needs to know which time
        # steps carry at least one observed feature.
        filter_mask = (
            repeated_mask if self.sporadic else (repeated_mask.sum(-1) > 0).float()
        )
        filter_result = self.base_ssm.filter(aux_samples, filter_mask, times)
        aux_log_likelihood = filter_result["log_prob"]
        aux_entropy = aux_entropy.sum(dim=-1)

        # Compute ELBO
        regularizer = aux_log_likelihood + aux_entropy
        elbo = y_log_likelihood + regularizer

        # Compute IWELBO: log-mean-exp over the samples of
        # log p(y | aux) + log p(aux) - log q(aux | y).
        log_weights = (
            y_log_likelihood.view(num_samples, batch_size)
            + aux_log_likelihood.view(num_samples, batch_size)
            - aux_post_log_prob.sum(dim=-1).view(num_samples, batch_size)
        )
        iwelbo = torch.logsumexp(log_weights, dim=0) - np.log(num_samples)

        return dict(
            elbo=elbo,
            iwelbo=iwelbo,
            likelihood=y_log_likelihood,
            regularizer=regularizer,
        )

    def _sample_emissions(self, aux: Tensor, num_samples: int, batch_size: int):
        """Sample observations from the emission model for auxiliary paths."""
        emit_dist = self.y_emission_net(merge_leading_dims(aux, ndims=2))
        return emit_dist.sample().view(
            num_samples, batch_size, aux.shape[-2], self.y_dim
        )

    @torch.no_grad()
    def forecast(
        self,
        y: Tensor,
        mask: Tensor,
        past_times: Tensor,
        future_times: Tensor,
        num_samples: int = 80,
        deterministic: bool = False,
        no_state_sampling: bool = False,
        use_smooth: bool = False,
    ):
        """Make predictions (imputation and forecast) using the observed data.

        Parameters
        ----------
        y
            The tensor of observations, of shape (batch_size, num_timesteps, y_dim)
        mask
            The mask of missing values (1: observed, 0: missing),
            of shape (batch_size, num_timesteps, y_dim), if sporadic,
            else (batch_size, num_timesteps)
        past_times
            The times of the past observations, of shape (num_past_steps,)
        future_times
            The times of the forecast, of shape (num_forecast_steps,)
        num_samples, optional
            The number of sample paths to draw, by default 80
        deterministic, optional
            Whether to perform deterministic sampling from auxiliary model,
            by default False (not really used)
        no_state_sampling, optional
            Forwarded to the base model's ``forecast``; controls whether the
            predicted state distributions are sampled or only their means
            are used, by default False
        use_smooth, optional
            Whether to perform smoothing after filtering (useful for imputation),
            by default False

        Returns
        -------
            The reconstructed context (imputing values, if required) and the
            forecast, together with the latent (``z_*``) and auxiliary
            (``aux_*``) paths they were decoded from
        """
        # Unpacking three dims also enforces that y is a rank-3 tensor.
        B, _, _ = y.shape

        # A single auxiliary sample is drawn here; the `num_samples` paths
        # are generated by the base model below.
        aux_samples, _, _ = self.aux_inference_net(
            y,
            mask,
            num_samples=1,
            deterministic=deterministic,
        )
        # aux_samples.shape = num_samples x B x time x aux_dim
        aux_samples = merge_leading_dims(aux_samples, ndims=2)
        if mask.ndim < y.ndim:
            mask = mask.unsqueeze(-1)
        # Generate predictions from the base CDKF
        base_predictions = self.base_ssm.forecast(
            aux_samples,
            mask if self.sporadic else (mask.sum(-1) > 0).float(),
            past_times,
            future_times,
            num_samples=num_samples,
            no_state_sampling=no_state_sampling,
            use_smooth=use_smooth,
        )
        aux_reconstruction = base_predictions["reconstruction"]
        aux_forecast = base_predictions["forecast"]
        z_reconstruction = base_predictions["z_reconstruction"]
        z_forecast = base_predictions["z_forecast"]

        # Decode aux --> y (reconstruction first, then forecast)
        y_reconstruction = self._sample_emissions(aux_reconstruction, num_samples, B)
        y_forecast = self._sample_emissions(aux_forecast, num_samples, B)

        return dict(
            reconstruction=y_reconstruction,
            forecast=y_forecast,
            z_reconstruction=z_reconstruction,
            z_forecast=z_forecast,
            aux_reconstruction=aux_reconstruction,
            aux_forecast=aux_forecast,
        )


class NCDSSMLTI(NCDSSM):
    """NCDSSM variant whose base state space model is a ``BaseLTI``
    (linear time-invariant dynamics).

    Parameters
    ----------
    aux_inference_net
        Neural network implementing the auxiliary inference model
    y_emission_net
        Neural network implementing the emission model
    aux_dim
        Size of the auxiliary variables; passed to the base model as its
        observation dimension
    z_dim
        Size of the latent states
    y_dim
        Size of the observations
    u_dim
        Size of the control inputs
    integration_method, optional
        Either "euler" or "rk4", by default "rk4"
    integration_step_size, optional
        Step size used by the integrator, by default 0.1
    sporadic, optional
        Set when the dataset is sporadic, i.e., values are missing in both
        the time and the feature dimensions, by default False
    **kwargs
        Extra options, handed to both the parent constructor and ``BaseLTI``
    """

    def __init__(
        self,
        aux_inference_net: nn.Module,
        y_emission_net: nn.Module,
        aux_dim: int,
        z_dim: int,
        y_dim: int,
        u_dim: int,
        integration_method: str = "rk4",
        integration_step_size: float = 0.1,
        sporadic: bool = False,
        **kwargs,
    ):
        super().__init__(
            aux_inference_net=aux_inference_net,
            y_emission_net=y_emission_net,
            aux_dim=aux_dim,
            z_dim=z_dim,
            y_dim=y_dim,
            u_dim=u_dim,
            integration_method=integration_method,
            integration_step_size=integration_step_size,
            sporadic=sporadic,
            **kwargs,
        )
        # The base model observes the auxiliary variables, hence y_dim=aux_dim.
        ssm_options = {
            "z_dim": z_dim,
            "y_dim": aux_dim,
            "u_dim": u_dim,
            "integration_method": integration_method,
            "integration_step_size": integration_step_size,
            "sporadic": sporadic,
        }
        self.base_ssm = BaseLTI(**ssm_options, **kwargs)


class NCDSSMLL(NCDSSM):
    """NCDSSM variant whose base state space model is a ``BaseLL``
    (locally-linear dynamics).

    Parameters
    ----------
    aux_inference_net
        Neural network implementing the auxiliary inference model
    y_emission_net
        Neural network implementing the emission model
    aux_dim
        Size of the auxiliary variables; passed to the base model as its
        observation dimension
    K
        Number of base matrices (i.e., dynamics)
    z_dim
        Size of the latent states
    y_dim
        Size of the observations
    u_dim
        Size of the control inputs
    alpha_net
        Mixing network that maps the state `z` to the mixing
        coefficients of the base dynamics
    integration_method, optional
        Either "euler" or "rk4", by default "rk4"
    integration_step_size, optional
        Step size used by the integrator, by default 0.1
    sporadic, optional
        Set when the dataset is sporadic, i.e., values are missing in both
        the time and the feature dimensions, by default False
    **kwargs
        Extra options, handed to both the parent constructor and ``BaseLL``
    """

    def __init__(
        self,
        aux_inference_net: AuxInferenceModel,
        y_emission_net: nn.Module,
        aux_dim: int,
        K: int,
        z_dim: int,
        y_dim: int,
        u_dim: int,
        alpha_net: nn.Module,
        integration_method: str = "rk4",
        integration_step_size: float = 0.1,
        sporadic: bool = False,
        **kwargs,
    ):
        super().__init__(
            aux_inference_net=aux_inference_net,
            y_emission_net=y_emission_net,
            aux_dim=aux_dim,
            z_dim=z_dim,
            y_dim=y_dim,
            u_dim=u_dim,
            integration_method=integration_method,
            integration_step_size=integration_step_size,
            sporadic=sporadic,
            **kwargs,
        )
        # The base model observes the auxiliary variables, hence y_dim=aux_dim.
        ssm_options = {
            "K": K,
            "z_dim": z_dim,
            "y_dim": aux_dim,
            "u_dim": u_dim,
            "alpha_net": alpha_net,
            "integration_method": integration_method,
            "integration_step_size": integration_step_size,
            "sporadic": sporadic,
        }
        self.base_ssm = BaseLL(**ssm_options, **kwargs)

    @torch.no_grad()
    def forecast(
        self,
        y: Tensor,
        mask: Tensor,
        past_times: Tensor,
        future_times: Tensor,
        num_samples: int = 80,
        deterministic: bool = False,
        no_state_sampling: bool = False,
        use_smooth: bool = False,
    ):
        """Reconstruct the context and forecast, also returning the mixing
        coefficients reported by the base model.

        Takes the same arguments as the parent ``forecast``. The returned
        dict holds ``reconstruction``, ``forecast``, ``z_reconstruction``,
        ``z_forecast``, ``alpha_reconstruction``, ``alpha_forecast``,
        ``aux_reconstruction`` and ``aux_forecast``.
        """
        # Three-way unpacking doubles as a check that y is rank 3.
        batch_size, _, _ = y.shape

        aux_samples, _, _ = self.aux_inference_net(
            y, mask, num_samples=1, deterministic=deterministic
        )
        # aux_samples.shape = num_samples x B x time x aux_dim
        aux_samples = merge_leading_dims(aux_samples, ndims=2)

        if mask.ndim < y.ndim:
            mask = mask.unsqueeze(-1)
        if self.sporadic:
            ssm_mask = mask
        else:
            # Per-timestep indicator: 1 where any feature is observed.
            ssm_mask = (mask.sum(-1) > 0).float()

        # Generate predictions from the base CDKF
        predictions = self.base_ssm.forecast(
            aux_samples,
            ssm_mask,
            past_times,
            future_times,
            num_samples=num_samples,
            no_state_sampling=no_state_sampling,
            use_smooth=use_smooth,
        )
        aux_rec = predictions["reconstruction"]
        aux_fut = predictions["forecast"]
        z_rec = predictions["z_reconstruction"]
        z_fut = predictions["z_forecast"]
        alpha_rec = predictions["alpha_reconstruction"]
        alpha_fut = predictions["alpha_forecast"]

        # Decode aux --> y; the context is sampled before the forecast.
        rec_dist = self.y_emission_net(merge_leading_dims(aux_rec, ndims=2))
        y_rec = rec_dist.sample().view(
            num_samples, batch_size, aux_rec.shape[-2], self.y_dim
        )
        fut_dist = self.y_emission_net(merge_leading_dims(aux_fut, ndims=2))
        y_fut = fut_dist.sample().view(
            num_samples, batch_size, aux_fut.shape[-2], self.y_dim
        )

        return {
            "reconstruction": y_rec,
            "forecast": y_fut,
            "z_reconstruction": z_rec,
            "z_forecast": z_fut,
            "alpha_reconstruction": alpha_rec,
            "alpha_forecast": alpha_fut,
            "aux_reconstruction": aux_rec,
            "aux_forecast": aux_fut,
        }


class NCDSSMNL(NCDSSM):
    """NCDSSM variant whose base state space model is a ``BaseNL``
    (non-linear dynamics).

    NOTE: a second class named ``NCDSSMNL`` is defined later in this module
    and rebinds the name, so this definition is not the one reachable as
    ``NCDSSMNL`` after import.

    Parameters
    ----------
    aux_inference_net
        Neural network implementing the auxiliary inference model
    y_emission_net
        Neural network implementing the emission model
    aux_dim
        Size of the auxiliary variables; passed to the base model as its
        observation dimension
    z_dim
        Size of the latent states
    y_dim
        Size of the observations
    u_dim
        Size of the control inputs
    f
        The dynamics/drift function f(z) (required by the signature)
    gs, optional
        The list of diffusion functions g(z), one for each z_dim,
        by default None
    integration_method, optional
        Either "euler" or "rk4", by default "rk4"
    integration_step_size, optional
        Step size used by the integrator, by default 0.1
    sporadic, optional
        Set when the dataset is sporadic, i.e., values are missing in both
        the time and the feature dimensions, by default False
    **kwargs
        Extra options, handed to both the parent constructor and ``BaseNL``
    """

    def __init__(
        self,
        aux_inference_net: AuxInferenceModel,
        y_emission_net: nn.Module,
        aux_dim: int,
        z_dim: int,
        y_dim: int,
        u_dim: int,
        f: nn.Module,
        gs: Optional[List[nn.Module]] = None,
        integration_method: str = "rk4",
        integration_step_size: float = 0.1,
        sporadic: bool = False,
        **kwargs,
    ):
        super().__init__(
            aux_inference_net=aux_inference_net,
            y_emission_net=y_emission_net,
            aux_dim=aux_dim,
            z_dim=z_dim,
            y_dim=y_dim,
            u_dim=u_dim,
            integration_method=integration_method,
            integration_step_size=integration_step_size,
            sporadic=sporadic,
            **kwargs,
        )
        # The base model observes the auxiliary variables, hence y_dim=aux_dim.
        ssm_options = {
            "z_dim": z_dim,
            "y_dim": aux_dim,
            "u_dim": u_dim,
            "f": f,
            "gs": gs,
            "integration_method": integration_method,
            "integration_step_size": integration_step_size,
            "sporadic": sporadic,
        }
        self.base_ssm = BaseNL(**ssm_options, **kwargs)


class NCDSSMNL(NCDSSM):
    """The NCDSSM model with non-linear dynamics, bundled with training and
    evaluation loops for the extrapolation task.

    This definition rebinds the name ``NCDSSMNL`` declared earlier in this
    module; unlike that one it takes a leading ``args`` argument.

    Parameters
    ----------
    args
        Run configuration object. The methods of this class read its
        ``dataset``, ``task``, ``lr``, ``lr_decay``, ``epochs``,
        ``tensorboard`` and ``log_wandb`` attributes.
    aux_inference_net
        The auxiliary inference model parameterized by a neural network
    y_emission_net
        The emission model parameterized by a neural network
    aux_dim
        The dimension of the auxiliary variables
    z_dim
        The dimension of the latent states
    y_dim
        The dimension of the observations
    u_dim
        The dimension of the control inputs
    f
        The dynamics/drift function f(z)
    gs, optional
        The list of diffusion functions g(z), one for each z_dim, by default None
    integration_method, optional
        The integration method, can be one of "euler" or "rk4",
        by default "rk4"
    integration_step_size, optional
        The integration step size, by default 0.1
    sporadic, optional
        A flag to indicate whether the dataset is sporadic,
        i.e., with values missing in both time and feature dimensions,
        by default False
    """

    # Divisor applied to raw timestamps before they are given to the model
    # (presumably timestamps are in seconds, i.e. three days -- TODO confirm).
    _TIME_SCALE = 3 * 24 * 60 * 60
    # Global gradient norm above which gradients are clipped.
    _MAX_GRAD_NORM = 100.0
    # Number of sample paths averaged into a point forecast during evaluation.
    _NUM_EVAL_SAMPLES = 5

    def __init__(
        self,
        args,
        aux_inference_net: AuxInferenceModel,
        y_emission_net: nn.Module,
        aux_dim: int,
        z_dim: int,
        y_dim: int,
        u_dim: int,
        f: nn.Module,
        gs: Optional[List[nn.Module]] = None,
        integration_method: str = "rk4",
        integration_step_size: float = 0.1,
        sporadic: bool = False,
        **kwargs,
    ):
        super().__init__(
            aux_inference_net,
            y_emission_net,
            aux_dim,
            z_dim,
            y_dim,
            u_dim,
            integration_method,
            integration_step_size,
            sporadic,
            **kwargs,
        )
        self.args = args
        self.base_ssm = BaseNL(
            z_dim=z_dim,
            y_dim=aux_dim,
            u_dim=u_dim,
            f=f,
            gs=gs,
            integration_method=integration_method,
            integration_step_size=integration_step_size,
            sporadic=sporadic,
            **kwargs,
        )
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        # nn.Module.to moves the parameters in place, so there is no need
        # to rebind `self`.
        self.to(self._device)
        print('done initializing')

    @staticmethod
    def _select_event_ids(numeric_event_ids):
        """Reduce batched numeric event ids to the first row."""
        if numeric_event_ids.ndim > 1:
            assert numeric_event_ids.ndim == 2, "more than two dimensions in numeric event ids"
            numeric_event_ids = numeric_event_ids[0, :]
        return numeric_event_ids

    def extrapolation(self, data, track_gradient=False):
        """Forecast the future part of one batch and score it.

        Only used during inference. Time steps with ``obs_times <= 0`` form
        the context and those with ``obs_times > 0`` form the forecast
        target; errors are computed on the features selected by
        ``numeric_event_ids``.

        :param data: iterable of seven tensors (obs, truth, obs_valid,
            obs_times, mask_truth, mask_obs, numeric_event_ids)
        :param track_gradient: passed to ``torch.set_grad_enabled``; note
            that the inherited ``forecast`` is wrapped in ``torch.no_grad``,
            so the forecast itself is computed without gradients either way
        :return: tuple (masked MSE, masked MAE) of the mean forecast
        """
        obs, truth, obs_valid, obs_times, mask_truth, mask_obs, numeric_event_ids = [
            j.to(self._device) for j in data]
        numeric_event_ids = self._select_event_ids(numeric_event_ids)

        # The adjusted tensors are not consumed below (the forecast is driven
        # by `truth`/`mask_truth`); the call is kept in case the helper
        # validates or mutates its inputs -- TODO confirm and drop if not.
        adjust_obs_for_extrapolation(
            self.args.dataset, obs, obs_valid, mask_obs, obs_times)

        with torch.set_grad_enabled(track_gradient):
            # Split on the raw times first, then shift so the context starts at 0.
            past_mask = obs_times <= 0.0
            future_mask = obs_times > 0.0
            obs_times = obs_times.min().abs() + obs_times
            past_obs = truth[past_mask]
            past_times = obs_times[past_mask]
            past_obs_mask = mask_truth[past_mask]
            future_obs = truth[future_mask]
            future_times = obs_times[future_mask]
            future_obs_mask = mask_truth[future_mask]

            predict_result = self.forecast(
                past_obs[None, ...].float(),
                past_obs_mask[None, ...].float(),
                past_times.float() / self._TIME_SCALE,
                future_times.float() / self._TIME_SCALE,
                num_samples=self._NUM_EVAL_SAMPLES,
                no_state_sampling=False,
                use_smooth=False)

            # Average over the sample paths to obtain a point forecast.
            obs_hat = predict_result['forecast'].mean(0)

            target = future_obs[None, ...][..., numeric_event_ids]
            target_mask = future_obs_mask[None, ...][..., numeric_event_ids]
            prediction = obs_hat[..., numeric_event_ids]
            imput_mse = mse(target, prediction, mask=target_mask)
            imput_mae = mae(target, prediction, mask=target_mask)

        return imput_mse, imput_mae

    def train_epoch(self, dl, optimizer):
        """Run one optimization pass over the training data.

        Each batch (batch size 1 only) is scored with ``forward`` and the
        negative of ``likelihood + regularizer`` is minimized. Batches whose
        gradient norm is not finite are skipped.

        :param dl: training dataloader yielding seven tensors per batch
        :param optimizer: optimizer over this model's parameters
        :return: mean loss over the batches whose update was applied, or
            ``nan`` if no update was applied
        """
        total_loss = 0.0
        applied_updates = 0

        for data in dl:
            optimizer.zero_grad()
            assert self.args.task == 'extrapolation', "extrapolation is only supported"
            obs, truth, obs_valid, obs_times, mask_truth, mask_obs, numeric_event_ids = [
                j.to(self._device) for j in data]
            print('obs size ', truth.shape)

            # Not needed for the loss; kept so malformed ids fail here as
            # they do during evaluation.
            self._select_event_ids(numeric_event_ids)

            # Shift the times so that the series starts at 0.
            obs_times = obs_times.min().abs().item() + obs_times

            assert obs_times.shape[0] == 1, "currently support batch sizes of 1"
            out = self.forward(
                y=truth.float(),
                times=obs_times[0, :].float() / self._TIME_SCALE,
                mask=mask_truth.float())
            loss = -(out["likelihood"] + out["regularizer"]).mean(0)
            loss.backward()

            total_grad_norm = grad_norm(self.parameters())
            if total_grad_norm < float("inf"):
                if self._MAX_GRAD_NORM != float("inf"):
                    torch.nn.utils.clip_grad_norm_(
                        self.parameters(), max_norm=self._MAX_GRAD_NORM
                    )
                optimizer.step()
                total_loss += loss.item()
                applied_updates += 1
            else:
                # Also reached for a NaN norm, since the comparison is False.
                print("Skipped gradient update!")
                optimizer.zero_grad()

        if applied_updates == 0:
            return float("nan")
        return total_loss / applied_updates

    def eval_epoch(self, dl, wandb=None):
        """Evaluates model on the entire dataset

        :param dl: dataloader containing validation or test data
        :param wandb: unused, kept for interface compatibility
        :return: tuple (mean MSE, mean MAE) over the batches; both are
            ``nan`` if the dataloader yields no batches
        """
        total_mse = 0.0
        total_mae = 0.0
        num_batches = 0

        for data in dl:
            imput_mse, imput_mae = self.extrapolation(data, track_gradient=False)
            total_mse += imput_mse.item()
            total_mae += imput_mae.item()
            num_batches += 1

        if num_batches == 0:
            return float("nan"), float("nan")
        return total_mse / num_batches, total_mae / num_batches

    def run_train(self, train_dl, valid_dl, identifier, logger, epoch_start=0, wandb=None):
        """Trains model on trainset and evaluates on validation data after
        every epoch, logging the results. The model is not saved here.

        :param train_dl: training dataloader
        :param valid_dl: validation dataloader
        :param identifier: logger id, used as the tensorboard run name
        :param logger: logger object
        :param epoch_start: starting epoch
        :param wandb: wandb handle, required when ``args.log_wandb`` is set
        :raises ValueError: if ``args.log_wandb`` is set but ``wandb`` is None
        """
        if self.args.log_wandb and wandb is None:
            raise ValueError("args.log_wandb is set but no wandb handle was passed")

        optimizer = optim.Adam(self.parameters(), self.args.lr)
        # Exponential decay: the base lr is scaled by lr_decay ** epoch.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda epoch: self.args.lr_decay ** epoch)

        make_dir(f'../results/tensorboard/{self.args.dataset}')
        writer = SummaryWriter(f'../results/tensorboard/{self.args.dataset}/{identifier}')

        try:
            for epoch in range(epoch_start, self.args.epochs):
                start = datetime.now()
                logger.info(f'Epoch {epoch} starts: {start.strftime("%H:%M:%S")}')

                # train
                train_loss = self.train_epoch(train_dl, optimizer)
                if self.args.tensorboard and train_loss is not None:
                    writer.add_scalar('train/loss', train_loss, epoch)

                # eval
                valid_mse, valid_mae = self.eval_epoch(valid_dl, wandb=wandb)
                print('Epoch: ', epoch)
                print('Eval MSE: ', valid_mse)
                print('Eval MAE: ', valid_mae)

                if self.args.log_wandb:
                    wandb.log({'valid_mse': valid_mse, 'valid_mae': valid_mae})

                if self.args.tensorboard:
                    writer.add_scalar('valid/mse', valid_mse, epoch)
                    writer.add_scalar('valid/mae', valid_mae, epoch)

                scheduler.step()
        finally:
            # Flush pending events even if training is interrupted.
            writer.close()
