import os
import argparse
import math
import random
import logging
import numpy as np
import numpy.random as npr
import matplotlib
import matplotlib.pyplot as plt
import torch.optim as optim
from torch import nn
import torch
import torchcde
from torchdiffeq import odeint
import pdb



class LatentODEfunc(nn.Module):
    """Three-layer ELU network used as the right-hand side of the latent ODE.

    The network maps a latent state of width ``latent_dim`` to a derivative of
    the same width through two hidden layers of width ``nhidden``.  The
    ``nfe`` attribute counts how many times ``forward`` has been called.
    """

    def __init__(self, latent_dim=4, nhidden=20):
        super().__init__()
        self.elu = nn.ELU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, nhidden)
        self.fc3 = nn.Linear(nhidden, latent_dim)
        # Number of function evaluations performed so far.
        self.nfe = 0

    def forward(self, t, x):
        """Return the network output for state ``x``.

        ``t`` is accepted to match the ``f(t, x)`` calling convention but is
        not used: the dynamics do not depend on time explicitly.
        """
        self.nfe = self.nfe + 1
        hidden = self.elu(self.fc1(x))
        hidden = self.elu(self.fc2(hidden))
        return self.fc3(hidden)

class RecognitionRNN(nn.Module):
    """Single-step recurrent cell that emits the approximate posterior.

    Each call consumes one observation and the previous hidden state and
    returns a vector of width ``2 * latent_dim`` together with the updated
    hidden state.
    """

    def __init__(self, latent_dim=4, obs_dim=2, nhidden=25, nbatch=1):
        super().__init__()
        self.nhidden, self.nbatch = nhidden, nbatch
        self.i2h = nn.Linear(obs_dim + nhidden, nhidden)
        self.h2o = nn.Linear(nhidden, latent_dim * 2)

    def forward(self, x, h):
        """Advance the cell by one step.

        ``x`` and ``h`` are concatenated along dim 1, so both must share the
        same leading (batch) size.  Returns ``(out, new_h)``.
        """
        new_h = self.i2h(torch.cat([x, h], dim=1)).tanh()
        return self.h2o(new_h), new_h

    def initHidden(self):
        """Return a zero hidden state of shape ``(1, nhidden)``.

        The stored ``nbatch`` is not used here; the result always has a
        leading size of 1.
        """
        return torch.zeros((1, self.nhidden))


class Decoder(nn.Module):
    """Two-layer ReLU network mapping latent states to observation space."""

    def __init__(self, latent_dim=4, obs_dim=2, nhidden=20):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(latent_dim, nhidden)
        self.fc2 = nn.Linear(nhidden, obs_dim)

    def forward(self, z):
        """Decode latent ``z`` (last dim ``latent_dim``) into ``obs_dim`` values."""
        hidden = self.relu(self.fc1(z))
        return self.fc2(hidden)


def log_normal_pdf(x, mean, logvar):
    """Element-wise log-density of ``x`` under N(mean, exp(logvar)).

    All three arguments are tensors that broadcast against each other.  The
    2*pi constant is held as a one-element float32 tensor on ``x``'s device,
    so the result always has at least one dimension.
    """
    two_pi = torch.tensor([2. * np.pi], dtype=torch.float32, device=x.device)
    log_two_pi = two_pi.log()
    squared_error = (x - mean) ** 2.
    return -.5 * (log_two_pi + logvar + squared_error / logvar.exp())

def normal_kl(mu1, lv1, mu2, lv2):
    """Element-wise KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))).

    ``lv1`` and ``lv2`` are log-variances.  No reduction is applied; the
    result has the broadcast shape of the inputs.
    """
    var_q, var_p = lv1.exp(), lv2.exp()
    # Difference of log standard deviations: log(sigma_p) - log(sigma_q).
    log_std_gap = lv2 / 2. - lv1 / 2.
    quadratic = (var_q + (mu1 - mu2) ** 2.) / (2. * var_p)
    return log_std_gap + quadratic - .5



class NeuralODE(nn.Module):
    """Latent-ODE variational autoencoder over timestamped sequences.

    A recognition RNN reads each sequence backwards in time and produces the
    mean and log-variance of the initial latent state.  A sample of that
    state is integrated forward with ``odeint`` and decoded back to
    observation space.

    Args:
        obs_dim: Number of observed features per time step.
        latent_dim: Width of the latent state.
        device: Device the sub-modules and helper tensors are placed on.
        batch_size: Stored on the instance; not used by the methods here.
    """

    def __init__(self, obs_dim, latent_dim, device, batch_size=200):
        super(NeuralODE, self).__init__()
        self.device = device
        self.latent_dim = latent_dim
        self.func = LatentODEfunc(self.latent_dim, 16).to(device)
        # The recogniser sees one extra input feature: the timestamp that
        # ``forward`` appends to every observation.
        self.rec = RecognitionRNN(self.latent_dim, obs_dim + 1, 16, 1).to(device)
        self.dec = Decoder(self.latent_dim, obs_dim, 16).to(device)
        self.batch_size = batch_size

    def forward(self, samples, orig_ts, **kwargs):
        """Encode ``samples``, integrate the latent ODE and decode.

        Args:
            samples: Tensor indexed as ``(batch, time, feature)``.
            orig_ts: Tensor indexed as ``(batch, time)`` holding the
                timestamp of every observation in ``samples``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tuple ``(pred_x, qz0_mean, qz0_logvar)``: the decoded trajectory
            evaluated at each sample's own timestamps, and the posterior
            mean / log-variance of the initial latent state.

        Raises:
            ValueError: If ``samples`` has no time steps.
        """
        samples = torch.cat([samples, orig_ts[..., None]], dim=-1)
        batch, n_steps = samples.shape[0], samples.size(1)
        if n_steps == 0:
            raise ValueError("samples must contain at least one time step")

        # Run the recogniser backwards in time so its final output
        # summarises the whole sequence as seen from the first time point.
        h = self.rec.initHidden().to(self.device).repeat(batch, 1)
        for t in reversed(range(n_steps)):
            out, h = self.rec(samples[:, t, :], h)
        qz0_mean, qz0_logvar = out[:, :self.latent_dim], out[:, self.latent_dim:]

        # Reparameterisation trick: z0 = mean + eps * std.
        epsilon = torch.randn(qz0_mean.size()).to(self.device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

        # Solve the ODE once on the sorted union of all timestamps, then pick
        # for every sample the states that fall on its own timestamps.
        # ``indices`` has the same shape as ``orig_ts`` and points into
        # ``unique_ts``.
        unique_ts, indices = torch.unique(orig_ts, sorted=True, return_inverse=True)
        pred_z = odeint(self.func, z0, unique_ts).permute(1, 0, 2)
        batch_idx = torch.arange(batch, device=indices.device).unsqueeze(-1)
        pred_z = pred_z[batch_idx, indices]

        pred_x = self.dec(pred_z)
        return pred_x, qz0_mean, qz0_logvar

    def calculate_loss(self, out, target, mask_target, numeric_event_ids,
                       noise_std=0.01):
        """Return the negative ELBO averaged over the batch.

        Args:
            out: The ``(pred_x, qz0_mean, qz0_logvar)`` tuple from ``forward``.
            target: Ground-truth tensor with the same layout as ``pred_x``.
            mask_target: Mask with the same layout as ``target``; truthy
                entries mark observed values that contribute to the
                likelihood.
            numeric_event_ids: Indices along the last axis selecting the
                features scored with the Gaussian likelihood.
            noise_std: Standard deviation of the fixed observation noise.

        Returns:
            Scalar tensor: mean over the batch of
            ``-log p(x | z) + KL(q(z0) || N(0, I))`` computed per sample.
        """
        pred_x, qz0_mean, qz0_logvar = out

        observed = mask_target[:, :, numeric_event_ids].bool()
        num_pred_x = pred_x[:, :, numeric_event_ids]
        num_target_x = target[:, :, numeric_event_ids]
        # Overwrite unobserved targets with a finite value so placeholder
        # entries (e.g. NaN) cannot leak into the loss or its gradient.
        num_target_x = torch.where(
            observed, num_target_x, torch.zeros_like(num_target_x))

        noise_logvar = torch.full_like(num_pred_x, 2. * math.log(noise_std))
        pointwise = log_normal_pdf(num_target_x, num_pred_x, noise_logvar)
        # Mask in place instead of boolean-indexing: indexing would flatten
        # the batch axis and turn the per-sample log-likelihood into a single
        # batch-wide total, mis-scaling it against the per-sample KL below.
        logpx = pointwise.masked_fill(~observed, 0.).sum(-1).sum(-1)

        pz0_mean = pz0_logvar = torch.zeros_like(qz0_mean)
        analytic_kl = normal_kl(qz0_mean, qz0_logvar,
                                pz0_mean, pz0_logvar).sum(-1)
        return torch.mean(-logpx + analytic_kl, dim=0)
