# Modeling Irregular Time Series with Continuous Recurrent Units (CRUs)
# Copyright (c) 2022 Robert Bosch GmbH
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# This source code is derived from Pytorch RKN Implementation (https://github.com/ALRhub/rkn_share)
# Copyright (c) 2021 Philipp Becker (Autonomous Learning Robots Lab @ KIT)
# licensed under MIT License
# cf. 3rd-party-licenses.txt file in the root directory of this source tree.

from lib.CRU import CRU
#from lib.GRU import GRU
from lib.mTAND import mTAND
from lib.utils import MyLayerNorm2d
from lib.GRUD import GRUD
from lib.TACD_GRU import TACD_GRU
#from lib.LatentODE import LatentODE
from lib.LatentODE_alt_impl import LatentODE
from lib.ncdssm_components import AuxInferenceModel, GaussianOutput
from lib.ncdssm_modules import MLP
from lib.ncdssm import NCDSSMLTI, NCDSSMLL, NCDSSMNL
from lib.ContiFormer import ContiFormer
from lib.Raindrop import Raindrop
from lib.ODERNN import ODERNN
from lib.GRU_delta import GRUDelta
import torch
import pdb
nn = torch.nn


# new code component
def load_model(args, train_dl, means=None):
    """Build the model selected by the flags on ``args``.

    The model flags are checked in a fixed order (``gru``, ``mTAND``,
    ``grud``, ``grudplus``, ``raindrop``, ``tacd_gru``, ``ncdssmnl``,
    ``contiformer``, ``grudelta``, ``ode_rnn``, ``latent_ode``, ``nhlstm``,
    ``peannlstm``) and the first truthy one wins.  When none of them is set,
    the default model is picked from ``args.dataset`` (and ``args.task`` for
    the pendulum data).

    Args:
        args: Namespace-like object carrying the model flags, ``dataset``,
            ``task`` and ``latent_state_dim``.  It is also forwarded to the
            model constructor.
        train_dl: Training data loader; only forwarded to ``LSTM_PEANN``.
        means: Forwarded as ``means`` to the models that accept it.  Must
            not be ``None`` when ``args.grud`` is set.

    Returns:
        The constructed model.

    Raises:
        AssertionError: If ``args.grud`` is set and ``means`` is ``None``.
        ValueError: If the selected model has no configuration for the
            requested dataset/task combination (previously this surfaced as
            an ``UnboundLocalError`` at the ``return`` statement).
        Exception: If ``args.dataset == 'pendulum'`` and ``args.task`` is
            neither ``'regression'`` nor ``'interpolation'``.
    """
    # Target dimension per dataset, shared by most model families below.
    dims = {'physionet': 37, 'mimic': 506, 'ushcn': 5}
    dims_fbm = dict(dims, fBM=1)

    model = None
    target_dim = None

    if args.gru:
        # NOTE(review): the `from lib.GRU import GRU` import is commented out
        # at the top of this file, so this branch raises NameError unless
        # `GRU` is made available some other way.
        dataset = args.dataset
        if dataset == 'physionet':
            target_dim = 37
        elif dataset == 'mimic' and args.task in ('extrapolation',
                                                  'classification'):
            target_dim = 506
        elif dataset == 'ushcn' and args.task == 'extrapolation':
            target_dim = 5
        if target_dim is not None:
            model = GRU(target_dim=target_dim, lsd=args.latent_state_dim,
                        args=args, use_cuda_if_available=True)

    elif args.mTAND:
        dataset = args.dataset
        if dataset in ('physionet', 'ushcn'):
            target_dim = dims[dataset]
        elif dataset == 'fBM' and args.task == 'extrapolation':
            target_dim = 1
        elif dataset == 'mimic' and args.task in ('extrapolation',
                                                  'next_obs_prediction',
                                                  'classification'):
            target_dim = 506
        if target_dim is not None:
            model = mTAND(target_dim=target_dim,
                          lsd=args.latent_state_dim, args=args,
                          use_cuda_if_available=True)

    elif args.grud:
        assert means is not None, "Means cannot be none for GRUD"
        target_dim = dims_fbm.get(args.dataset)
        if target_dim is not None:
            model = GRUD(target_dim=target_dim, lsd=args.latent_state_dim,
                         args=args, means=means,
                         use_cuda_if_available=True)

    elif args.grudplus:
        # NOTE(review): `GRUDPlus` is not imported or defined in the visible
        # part of this file.
        dataset = args.dataset
        if (dataset in ('physionet', 'mimic')
                or (dataset == 'ushcn' and args.task == 'extrapolation')):
            # Only the MIMIC configuration enables the encoder.
            model = GRUDPlus(target_dim=dims[dataset],
                             lsd=args.latent_state_dim,
                             args=args, means=means,
                             use_cuda_if_available=True,
                             use_encoder=(dataset == 'mimic'))

    elif args.raindrop:
        dataset = args.dataset
        if (dataset in ('physionet', 'mimic')
                or (dataset == 'ushcn' and args.task == 'extrapolation')):
            model = Raindrop(target_dim=dims[dataset],
                             lsd=args.latent_state_dim,
                             args=args, means=means,
                             use_cuda_if_available=True,
                             use_encoder=False)

    elif args.tacd_gru:
        dataset = args.dataset
        target_dim = dict(dims_fbm, activity=12).get(dataset)
        if target_dim is not None:
            # Only the 'activity' (classification) configuration enables the
            # encoder.
            model = TACD_GRU(target_dim=target_dim,
                             lsd=args.latent_state_dim,
                             args=args, means=means,
                             use_cuda_if_available=True,
                             use_encoder=(dataset == 'activity'))

    elif args.ncdssmnl:
        if args.dataset == 'mimic':
            model = get_NCDSSMNL(args)

    elif args.contiformer:
        target_dim = dims.get(args.dataset)
        if target_dim is not None:
            model = ContiFormer(target_dim=target_dim,
                                lsd=args.latent_state_dim, args=args)

    elif args.grudelta:
        target_dim = dims.get(args.dataset)
        if target_dim is not None:
            model = GRUDelta(target_dim=target_dim,
                             lsd=args.latent_state_dim,
                             args=args, means=means,
                             use_cuda_if_available=True)

    elif args.ode_rnn:
        target_dim = dims.get(args.dataset)
        if target_dim is not None:
            model = ODERNN(target_dim=target_dim,
                           lsd=args.latent_state_dim,
                           args=args, means=means,
                           use_cuda_if_available=True)

    elif args.latent_ode:
        target_dim = dims.get(args.dataset)
        if target_dim is not None:
            model = LatentODE(target_dim=target_dim,
                              lsd=args.latent_state_dim,
                              args=args, means=means,
                              use_cuda_if_available=True)

    elif args.nhlstm:
        # NOTE(review): `LSTM_NHP` is not imported or defined in the visible
        # part of this file.  'activity' maps to 37 here while the TACD_GRU
        # branch uses 12 for the same dataset -- confirm which is intended.
        target_dim = {'mimic': 506, 'activity': 37}.get(args.dataset)
        if target_dim is not None:
            model = LSTM_NHP(target_dim=target_dim,
                             lsd=args.latent_state_dim, args=args,
                             use_cuda_if_available=True)

    elif args.peannlstm:
        # NOTE(review): `LSTM_PEANN` is not imported or defined in the
        # visible part of this file.  This branch ignores args.dataset.
        model = LSTM_PEANN(target_dim=37, lsd=args.latent_state_dim,
                           args=args, use_cuda_if_available=True,
                           train_dl=train_dl)

    else:
        # all other than proposed method go here
        dataset = args.dataset
        if dataset == 'pendulum':
            if args.task == 'regression':
                model = Pendulum_reg(target_dim=2,
                                     lsd=args.latent_state_dim, args=args,
                                     layer_norm=False,
                                     use_cuda_if_available=True)
            elif args.task == 'interpolation':
                # NOTE(review): `Pendulum` is not imported or defined in the
                # visible part of this file.
                model = Pendulum(target_dim=(1, 24, 24),
                                 lsd=args.latent_state_dim, args=args,
                                 layer_norm=True,
                                 use_cuda_if_available=True,
                                 bernoulli_output=True)
            else:
                raise Exception('Task not available for Pendulum data')
        else:
            # USHCN, Physionet and MIMIC share the same model class.
            target_dim = dims.get(dataset)
            if target_dim is not None:
                model = Physionet_USHCN(target_dim=target_dim,
                                        lsd=args.latent_state_dim,
                                        args=args,
                                        use_cuda_if_available=True)

    if model is None:
        # Fail loudly instead of hitting an unbound `model` at the return.
        raise ValueError(
            f"No model configuration for dataset {args.dataset!r} "
            f"(task {getattr(args, 'task', None)!r}) with the selected "
            f"model flags"
        )

    return model


def get_NCDSSMNL(args):
    """Assemble an ``NCDSSMNL`` model with a fixed, hard-coded configuration.

    ``load_model`` in this file calls this only when
    ``args.dataset == 'mimic'``.

    Args:
        args: Passed through unchanged as the first argument of ``NCDSSMNL``.

    Returns:
        The constructed ``NCDSSMNL`` instance.
    """
    obs_dim = 506  # might change
    latent_dim = 10  # hyper param

    # Both Gaussian heads share the same fixed, tied covariance settings.
    fixed_cov = dict(use_tied_cov=True, use_trainable_cov=False, sigma=1e-4)

    # Auxiliary inference model: identity base network feeding an identity
    # Gaussian head.
    inference_net = AuxInferenceModel(
        nn.Identity(),
        GaussianOutput(nn.Identity(), dist_dim=obs_dim, **fixed_cov),
        aux_dim=obs_dim,
        concat_mask=False,
    )

    emission_net = GaussianOutput(
        nn.Identity(),
        dist_dim=obs_dim,
        use_independent=False,
        **fixed_cov,
    )

    # Non-linear drift function, built from the MLP helper.
    drift = MLP(
        in_dim=latent_dim,
        h_dim=64,
        out_dim=latent_dim,
        nonlinearity=nn.Softplus,
        last_nonlinearity=False,
        n_hidden_layers=1,
        zero_init_last=False,
        apply_spectral_norm=True,
    )

    return NCDSSMNL(
        args,
        inference_net,
        emission_net,
        aux_dim=obs_dim,
        z_dim=latent_dim,
        y_dim=obs_dim,
        u_dim=0,
        f=drift,
        integration_step_size=0.1,
        integration_method='euler',
        H=None,
        sporadic=True,
    )

# new code component
class Physionet_USHCN(CRU):
    """CRU subclass whose encoder and decoders are fully connected stacks.

    ``load_model`` in this file uses it as the default model for the USHCN,
    Physionet and MIMIC datasets.
    """

    def __init__(self, target_dim: int, lsd: int, args,
                 use_cuda_if_available: bool = True):
        """Store the layer sizes, then run the ``CRU`` constructor.

        Args:
            target_dim: Input width of the first encoder layer.
            lsd: Forwarded to ``CRU.__init__``.
            args: Must provide ``hidden_units``; forwarded to ``CRU.__init__``.
            use_cuda_if_available: Forwarded to ``CRU.__init__``.
        """
        # Set before the base constructor runs -- presumably CRU.__init__
        # calls the _build_* hooks below, which read these; TODO confirm.
        self.hidden_units = args.hidden_units
        self.target_dim = target_dim

        super().__init__(target_dim, lsd, args, use_cuda_if_available)

    def _build_enc_hidden_layers(self):
        """Return ``(encoder layers, width of the last hidden layer)``.

        The encoder is three Linear -> ReLU -> LayerNorm blocks cast to
        float64.
        """
        width = self.hidden_units
        stack = nn.ModuleList()
        for fan_in in (self.target_dim, width, width):
            stack.extend([
                nn.Linear(fan_in, width),
                nn.ReLU(),
                nn.LayerNorm(width),
            ])
        # size last hidden
        return stack.to(dtype=torch.float64), width

    def _build_dec_hidden_layers_mean(self):
        """Return ``(mean-decoder layers, width of the last hidden layer)``.

        Three Linear -> ReLU -> LayerNorm blocks cast to float64; the first
        Linear takes ``2 * self._lod`` input features.
        """
        width = self.hidden_units
        stack = nn.ModuleList()
        for fan_in in (2 * self._lod, width, width):
            stack.extend([
                nn.Linear(in_features=fan_in, out_features=width),
                nn.ReLU(),
                nn.LayerNorm(width),
            ])
        return stack.to(dtype=torch.float64), width

    def _build_dec_hidden_layers_var(self):
        """Return ``(variance-decoder layers, width of the last hidden layer)``.

        A single Linear -> ReLU -> LayerNorm block cast to float64; the
        Linear takes ``3 * self._lod`` input features.
        """
        width = self.hidden_units
        block = nn.ModuleList()
        block.extend([
            nn.Linear(in_features=3 * self._lod, out_features=width),
            nn.ReLU(),
            nn.LayerNorm(width),
        ])
        return block.to(dtype=torch.float64), width

# new code component
class Physionet_USHCN_old(CRU):
    """Legacy CRU subclass; not referenced by ``load_model`` in this file.

    NOTE(review): this class looks like the result of a broken merge -- the
    head of an ``__init__`` is followed directly by the tail of a
    convolutional encoder builder.  As written it cannot be instantiated
    (see the notes in ``__init__``).  Confirm against version control before
    using or repairing it.
    """

    def __init__(self, target_dim: int, lsd: int, args,
                 use_cuda_if_available: bool = True):
        # NOTE(review): the base-class constructor is never called here, and
        # `lsd` / `use_cuda_if_available` are unused.
        self.hidden_units = args.hidden_units
        self.target_dim = target_dim
        self._layer_norm = True
        # NOTE(review): `layers` is not assigned anywhere in this method (nor
        # at module level in the visible file), so the first `layers.append`
        # below raises NameError.  The list initialisation and the first
        # encoder layer appear to be missing.
        if self._layer_norm:
            layers.append(MyLayerNorm2d(channels=12))
        layers.append(nn.ReLU())
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # hidden layer 2
        layers.append(nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=2, padding=1))
        if self._layer_norm:
            layers.append(MyLayerNorm2d(channels=12))
        layers.append(nn.ReLU())
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # hidden layer 3
        layers.append(nn.Flatten())
        layers.append(nn.Linear(in_features=108, out_features=30))
        layers.append(nn.ReLU())
        # NOTE(review): returning a non-None value from __init__ is a
        # TypeError in Python; this return presumably belonged to an encoder
        # builder method whose `def` line was lost.
        return nn.ModuleList(layers), 30

    def _build_dec_hidden_layers(self):
        # Returns (decoder layers, 12): a Linear to 144 features, reshaped to
        # [16, 3, 3], followed by two ConvTranspose2d -> MyLayerNorm2d -> ReLU
        # stages ending in 12 channels.
        # NOTE(review): the other CRU subclasses in this file define
        # `_build_dec_hidden_layers_mean` / `_build_dec_hidden_layers_var`
        # instead of this name -- confirm which hooks CRU actually calls.
        return nn.ModuleList([
            nn.Linear(in_features=2 * self._lod, out_features=144),
            nn.ReLU(),
            nn.Unflatten(1, [16, 3, 3]),

            nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=5, stride=4, padding=2),
            MyLayerNorm2d(channels=16),
            nn.ReLU(),

            nn.ConvTranspose2d(in_channels=16, out_channels=12, kernel_size=3, stride=2, padding=1),
            MyLayerNorm2d(channels=12),
            nn.ReLU()
        ]), 12


# taken from https://github.com/ALRhub/rkn_share/ and modified
class Pendulum_reg(CRU):
    """CRU subclass with a convolutional encoder and small Tanh decoders.

    ``load_model`` in this file selects it for the pendulum dataset with the
    ``'regression'`` task.
    """

    def __init__(self, target_dim: int, lsd: int, args, layer_norm: bool,
                 use_cuda_if_available: bool = True):
        """Record the layer-norm switch, then run the ``CRU`` constructor.

        Args:
            target_dim: Forwarded to ``CRU.__init__``.
            lsd: Forwarded to ``CRU.__init__``.
            args: Forwarded to ``CRU.__init__``.
            layer_norm: Whether the encoder inserts ``MyLayerNorm2d`` after
                each convolution.
            use_cuda_if_available: Forwarded to ``CRU.__init__``.
        """
        # Set before the base constructor runs -- presumably CRU.__init__
        # calls the encoder builder, which reads this flag; TODO confirm.
        self._layer_norm = layer_norm
        super().__init__(target_dim, lsd, args, use_cuda_if_available)

    def _build_enc_hidden_layers(self):
        """Return ``(encoder layers, 30)``; the layers are cast to float64."""

        def optional_norm():
            # Zero or one normalisation layer, depending on the switch.
            return [MyLayerNorm2d(channels=12)] if self._layer_norm else []

        encoder = nn.ModuleList([
            # hidden layer 1
            nn.Conv2d(in_channels=1, out_channels=12,
                      kernel_size=5, padding=2),
            *optional_norm(),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # hidden layer 2
            nn.Conv2d(in_channels=12, out_channels=12,
                      kernel_size=3, stride=2, padding=1),
            *optional_norm(),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # hidden layer 3
            nn.Flatten(),
            nn.Linear(in_features=108, out_features=30),
            nn.ReLU(),
        ])
        return encoder.to(dtype=torch.float64), 30

    def _build_dec_hidden_layers_mean(self):
        """Return ``(mean-decoder layers, 30)``: Linear from ``2 * _lod``, then Tanh."""
        mean_layers = nn.ModuleList([
            nn.Linear(in_features=2 * self._lod, out_features=30),
            nn.Tanh(),
        ])
        return mean_layers, 30

    def _build_dec_hidden_layers_var(self):
        """Return ``(variance-decoder layers, 30)``: Linear from ``3 * _lod``, then Tanh."""
        var_layers = nn.ModuleList([
            nn.Linear(in_features=3 * self._lod, out_features=30),
            nn.Tanh(),
        ])
        return var_layers, 30

