import torch
import torch.nn as nn
import copy
import pytorch_lightning as pl
import argparse
import pandas as pd
import numpy as np
import plotly.express as px
from torchdiffeq import odeint
import diffeq_layers

#from causalode.sinkhorn import SinkhornDistance

from torch import Tensor
import torch.nn.functional as F
#import torchsde

#decoder
class MLPSimple(nn.Module):
    """Feed-forward decoder: input projection, ``depth`` hidden blocks, linear head.

    Layout: ``Linear(input_dim, hidden_dim) -> ReLU``, then ``depth`` blocks of
    ``Linear(hidden_dim, hidden_dim) -> BatchNorm1d(hidden_dim) -> activation``,
    then ``Linear(hidden_dim, output_dim)``.

    Args:
        activations: one activation module per hidden block (exactly ``depth``
            of them); defaults to a fresh ReLU for every block.
        dropout_p: accepted for backward compatibility but unused - the hidden
            blocks use batch normalisation instead of dropout.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, depth, activations=None, dropout_p=None):
        super().__init__()
        # Construction order (input, output, hidden blocks) determines the RNG
        # draws used for parameter init and the state_dict key order.
        self.input_layer = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.output_layer = nn.Sequential(nn.Linear(hidden_dim, output_dim))
        acts = activations if activations is not None else [nn.ReLU() for _ in range(depth)]
        assert len(acts) == depth
        self.layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                acts[idx],
            )
            for idx in range(depth)
        )

    def forward(self, x):
        hidden = self.input_layer(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.output_layer(hidden)

"""
def divergence_approx(f, y, e=None):
    e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0] #batchSize x stateDim
    e_dzdx_e = e_dzdx.mul(e) #batchSize x stateDim

    cnt = 0
    while not e_dzdx_e.requires_grad and cnt < 10:
        # print("RequiresGrad:f=%s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt=%d"
        #       % (f.requires_grad, y.requires_grad, e_dzdx.requires_grad,
        #          e.requires_grad, e_dzdx_e.requires_grad, cnt))
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx * e
        cnt += 1
    approx_tr_dzdx = e_dzdx_e.sum(dim=-1)
    regTerm2 = torch.square(e_dzdx).sum(dim=-1)
    assert approx_tr_dzdx.requires_grad, \
        "(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \
        % (
        f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad, e.requires_grad, e_dzdx_e.requires_grad, cnt)
    return approx_tr_dzdx.unsqueeze(-1), regTerm2.unsqueeze(-1)
"""

def divergence_approx(f, y, e=None):
    """Hutchinson estimate of the divergence (Jacobian trace) of ``f`` w.r.t. ``y``.

    For each probe ``e_`` the vector-Jacobian product ``e_^T (df/dy)`` is
    computed with ``create_graph=True`` (so the estimate stays differentiable).
    ``e_^T (df/dy) e_`` is summed over all non-batch dimensions and averaged
    over the probes.

    Args:
        f: output computed from ``y``; must have the same shape as ``y``
            (first dimension = batch).
        y: input tensor that ``f`` was computed from (must require grad).
        e: probe(s) shaped like ``y``. Either an iterable of tensors (a list,
            or a tensor stacked along a new leading dimension), or a single
            tensor with exactly ``y``'s shape. If None, one Rademacher (+/-1)
            probe is drawn.

    Returns:
        ``(approx_tr_dzdx, regTerm2)``, each of shape ``(batch,)``: the
        divergence estimate and the mean squared entry of ``e_^T (df/dy)``,
        both averaged over probes.

    Raises:
        ValueError: if ``e`` is an empty iterable.
    """
    if e is None:
        # A single Rademacher probe in y's dtype/device.
        e = [(torch.randint(0, 2, y.shape, device=y.device) * 2 - 1).to(y.dtype)]
    elif torch.is_tensor(e) and e.shape == y.shape:
        # One probe: iterating it directly would split the batch dimension.
        e = [e]

    batch = y.shape[0]
    samples = []
    sqnorms = []
    for e_ in e:
        e_dzdx = torch.autograd.grad(f, y, e_, create_graph=True)[0]
        # reshape (not view): autograd results are not guaranteed contiguous.
        sqnorms.append(e_dzdx.reshape(batch, -1).pow(2).mean(dim=1, keepdim=True))
        samples.append((e_dzdx * e_).reshape(batch, -1).sum(dim=1, keepdim=True))

    if not samples:
        raise ValueError("divergence_approx needs at least one probe vector in `e`")

    approx_tr_dzdx = torch.cat(samples, dim=1).mean(dim=1)
    regTerm2 = torch.cat(sqnorms, dim=1).mean(dim=1)
    return approx_tr_dzdx, regTerm2


def sample_rademacher_like(y):
    """Return i.i.d. Rademacher (+/-1, equiprobable) noise with ``y``'s shape, device and dtype.

    ``torch.randint`` alone yields int64. Matching ``y``'s dtype avoids relying
    on implicit casts when the probes are passed as ``grad_outputs`` and
    multiplied with floating-point gradients.
    """
    signs = torch.randint(low=0, high=2, size=y.shape, device=y.device) * 2 - 1
    return signs.to(y.dtype)

# Activation lookup by name (used by ODEnet's ``nonlinearity`` argument).
# The module instances are shared by every network that looks them up; that
# is harmless here because these activations have no parameters.
# Swish / square / identity entries are disabled: their helper classes
# (Swish, Lambda) are not defined in this module.
NONLINEARITIES = dict(
    tanh=nn.Tanh(),
    relu=nn.ReLU(),
    softplus=nn.Softplus(),
    elu=nn.ELU(),
)


class ODEnet(nn.Module):
    """
    Helper class to make neural nets for use in continuous normalizing flows.

    Stacks ``len(hidden_dims) + 1`` context-conditioned layers from
    ``diffeq_layers`` mapping ``state_dim[0] -> hidden_dims[0] -> ... ->
    state_dim[0]``. The chosen nonlinearity is applied between consecutive
    layers and not after the last one. Every layer receives a conditioning
    vector of width ``contex_dim + 1``.
    """

    def __init__(self, hidden_dims, state_dim, contex_dim, layer_type="concat", nonlinearity="softplus"):
        """
        Args:
            hidden_dims: widths of the hidden layers; any sequence of ints
                (a list works as well as a tuple).
            state_dim: shape-like sequence; only ``state_dim[0]`` (the state
                width) is used.
            contex_dim: context width. Layers are built for ``contex_dim + 1``
                conditioning features (presumably the context plus one time
                column - confirm against the caller).
            layer_type: key of the layer table below.
            nonlinearity: key into ``NONLINEARITIES``.

        Raises:
            KeyError: for an unknown ``layer_type`` or ``nonlinearity``.
        """
        super().__init__()
        base_layer = {
            "ignore": diffeq_layers.IgnoreLinear,
            "squash": diffeq_layers.SquashLinear,
            "scale": diffeq_layers.ScaleLinear,
            "concat": diffeq_layers.ConcatLinear,
            "concat_v2": diffeq_layers.ConcatLinear_v2,
            "concatsquash": diffeq_layers.ConcatSquashLinear,
            "concatscale": diffeq_layers.ConcatScaleLinear,
        }[layer_type]
        activation = NONLINEARITIES[nonlinearity]

        state_width = state_dim[0]
        # tuple() so that a list of hidden widths is accepted too.
        widths = tuple(hidden_dims) + (state_width,)

        layers = []
        in_dim = state_width
        for out_dim in widths:
            layers.append(base_layer(in_dim, out_dim, contex_dim + 1))
            in_dim = out_dim

        self.layers = nn.ModuleList(layers)
        # One activation between each pair of consecutive layers; the output
        # layer stays linear.
        self.activation_fns = nn.ModuleList([activation] * (len(widths) - 1))

    def forward(self, context, y):
        """Apply the layer stack to ``y``, conditioning every layer on ``context``."""
        dx = y
        # zip stops after the second-to-last layer (one activation fewer than layers).
        for layer, act in zip(self.layers, self.activation_fns):
            dx = act(layer(context, dx))
        return self.layers[-1](context, dx)

class ODEfunc(nn.Module):
    """Dynamics function for ``odeint``, conditioned on a time-indexed context.

    ``forward(t, states)`` looks up the context row of the grid interval that
    contains ``t`` (via ``input_index`` or ``input_index_all_included``). It
    evaluates ``self.diffeq(tc, y)`` with ``tc = [time offset, context row]``.
    When more than one state is integrated, it also returns a stochastic
    divergence estimate and two regularisation terms.

    ``before_odeint`` must be called before each solve. It binds the context,
    the time grid and the flags that ``forward`` reads.
    """

    def __init__(self, diffeq, obs_dim):
        super(ODEfunc, self).__init__()
        self.diffeq = diffeq
        # Grid index chosen by the previous lookup. It is the starting guess
        # for the next lookup, so the bisection can often be skipped.
        self.time = 0
        # Width of the observation block: the last ``obs_dim`` columns of both
        # the state and each context row (the remaining context columns are actions).
        self.obs_dim = obs_dim
        self.divergence_fn = divergence_approx
        self.register_buffer("_num_evals", torch.tensor(0.))
        self.register_buffer("regTerm1", torch.tensor(0.))
        self.register_buffer("regTerm2", torch.tensor(0.))

    @staticmethod
    def _last_index_before(t, time_list):
        """Bisect for the largest ``i`` in ``[0, len(time_list) - 2]`` with ``time_list[i] < t``.

        Returns 0 when no grid point lies strictly before ``t``. The last grid
        index is excluded because no observations are passed for it. Assumes
        ``time_list`` is sorted ascending and has at least two points.
        """
        low = 0
        high = len(time_list) - 2
        while low < high:
            mid = (low + high + 1) // 2
            if time_list[mid] >= t:
                high = mid - 1
            else:
                low = mid
        return high

    def input_index(self, t, time_list, context, _s, parent_forcing, addNoise):
        """Select the context row for the grid interval containing ``t``.

        Finds the last grid index ``i`` (at most ``len(time_list) - 2``) with
        ``time_list[i] < t`` (0 if there is none) and caches it in
        ``self.time``. Returns ``(_s, context[:, i, :], time_list[i])`` with
        ``_s`` unchanged. ``parent_forcing`` and ``addNoise`` are ignored here;
        they only exist so both lookup methods share one signature.

        Raises:
            IndexError: if ``i`` is out of range for ``context`` or ``time_list``.
        """
        high = self.time
        if time_list[high] < t and time_list[high + 1] >= t:
            pass  # still inside the cached interval
        elif high > 0 and time_list[high] >= t and time_list[high - 1] < t:
            high -= 1  # stepped back into the previous interval
        else:
            high = self._last_index_before(t, time_list)

        self.time = high
        try:
            return _s, context[:, high, :], time_list[high]
        except IndexError as err:
            raise IndexError(
                f"input_index: grid index {high} out of range for t={t} "
                f"(context shape {tuple(context.shape)}, {len(time_list)} grid points)"
            ) from err

    def input_index_all_included(self, t, time_list, context, _s, parent_forcing, addNoise):
        """NaN-tolerant lookup that also applies parent forcing and state noise.

        Steps:

        1. Find the last grid index ``high`` strictly before ``t``
           (same rule as ``input_index``).
        2. If the action columns (all but the last ``obs_dim``) of
           ``context[:, high]`` contain NaN, step back to an earlier row. It
           takes at most 10 steps and never goes below index 0.
        3. If ``t`` is within 1e-4 of ``time_list[high]``:
           - fill NaN observation entries of ``context[:, high]`` **in place**
             (the caller's tensor is modified) with the matching state entries;
           - if ``parent_forcing``, replace the observed part of the state by
             ``context[:, high, -obs_dim:]``. This is straight-through: the
             value changes, but the gradient path through ``_s`` is kept.
        4. If ``addNoise``, add Gaussian noise (std 0.05) to the state with the
           same straight-through trick.

        Returns:
            ``(state, context[:, high, :], time_list[high])``.
        """
        k = self.obs_dim
        high = self.time
        if not (time_list[high] < t and time_list[high + 1] >= t):
            high = self._last_index_before(t, time_list)

        # Missing actions: fall back to the nearest earlier grid row whose
        # actions are defined. ``high > 0`` stops a negative index from
        # wrapping around to the end of the sequence.
        steps = 0
        while high > 0 and steps < 10 and torch.isnan(context[:, high, :-k]).any():
            high -= 1
            steps += 1

        if torch.abs(t - time_list[high]) < 0.0001:
            # Replace missing observations in the context with the model state,
            # element-wise per sample (in place on ``context``).
            obs_ctx = context[:, high, -k:]
            missing = torch.isnan(obs_ctx)
            if missing.any():
                obs_ctx[missing] = _s[:, -k:][missing]

            if parent_forcing:
                # Assigning ``.data`` on a slice only rebinds a temporary view.
                # Rebuild the state instead, so the observed dims take the data
                # values while gradients still flow through ``_s``.
                obs_state = _s[:, -k:]
                forced = (context[:, high, -k:] - obs_state).detach() + obs_state
                _s = torch.cat([_s[:, :-k], forced], dim=1)

        if addNoise:
            # randn_like already matches _s's device and dtype.
            noise = torch.randn_like(_s) * 0.05
            _s = ((noise + _s).detach() - _s).detach() + _s

        self.time = high
        if torch.isnan(context[:, high, :]).any():
            # Debug output: the selected context row still contains NaNs.
            print('time')
            print(t)
            print(time_list[high])
            print(t - time_list[high])
            print('nans')
            print(torch.isnan(context[:, high, :-self.obs_dim]).any())
            print(context[:, high, :])

        return _s, context[:, high, :], time_list[high]

    def before_odeint(self, context, integration_times, parent_forcing=True, addNoise=True, deBug=False, logpx=None, e=None):
        """Reset per-solve state and bind the inputs that ``forward`` reads.

        Args:
            context: tensor indexed as ``context[:, i, :]`` for grid index
                ``i``. The last ``obs_dim`` features are observations and the
                rest are actions; NaNs mark missing entries.
            integration_times: time grid matching the second axis of
                ``context`` (the bisection assumes it is ascending).
            parent_forcing: at grid times, overwrite the observed state
                dimensions with the data.
            addNoise: perturb the state with straight-through Gaussian noise at
                every evaluation (only on the ``input_index_all_included`` path).
            deBug: print diagnostics in ``forward``.
            logpx: stored on ``self.logpx``; not read by this class.
            e: divergence probes; if None they are sampled on the first
                ``forward`` call and reused for the rest of the solve.
        """
        self._e = e
        self._num_evals.fill_(0)
        self.regTerm1 = torch.tensor([0.])
        self.regTerm2 = torch.tensor([0.])
        self.time = 0
        self.integration_times = integration_times
        self.context = context
        self.parent_forcing = parent_forcing
        self.logpx0 = torch.tensor([0.])
        self.deBug = deBug
        self.logpx = logpx
        self.addNoise = addNoise

        # The NaN-aware lookup is slower; use it only when the context has
        # NaNs or parent forcing is requested.
        if torch.isnan(self.context).any() or self.parent_forcing:
            self.get_fun = lambda t, _x: self.input_index_all_included(
                t, self.integration_times, self.context, _x, self.parent_forcing, self.addNoise)
        else:
            self.get_fun = lambda t, _x: self.input_index(
                t, self.integration_times, self.context, _x, self.parent_forcing, self.addNoise)

    def get_regularization_terms(self):
        return self.regTerm1, self.regTerm2

    def get_num_evals(self):
        return self._num_evals

    def forward(self, t, states):
        """Evaluate the dynamics at solver time ``t``.

        Args:
            t: current solver time.
            states: tuple whose first element is the state ``y``. Any further
                elements are the extra integrated quantities (log-density and
                regularisers).

        Returns:
            ``dy`` if ``states`` has one element; otherwise the tuple
            ``(dy, -divergence, regTerm1, regTerm2)``.
        """
        y = states[0]
        y, c, t_ = self.get_fun(t, y)
        # Offset of the selected grid time from the current time: time_list[i] - t.
        t_ = t_ - t
        t = torch.ones(y.size(0), 1, device=y.device) * t_.clone().detach().requires_grad_(True)
        self._num_evals += 1
        for state in states:
            state.requires_grad_(True)

        # Sample the divergence probes once per solve (5 Rademacher vectors)
        # and reuse them for every function evaluation.
        if self._e is None:
            self._e = [sample_rademacher_like(y) for _ in range(5)]

        with torch.set_grad_enabled(True):
            tc = torch.cat([t, c], dim=1)
            y.requires_grad_(True)
            tc.requires_grad_(True)
            dy = self.diffeq(tc, y)
            if self.deBug:
                print('DY')
                print(t[0])
                print(dy.mean())
                print(tc.mean())
                print(y.mean())

            if len(states) == 1:  # no divergence needed
                return dy

            # Log-likelihood estimates.
            divergence, regTerm2 = self.divergence_fn(dy, y, e=self._e)
            # 0.5 * squared norm of dy per sample.
            regTerm1 = 0.5 * torch.sum(torch.pow(dy, 2), -1, keepdims=True)
            return (dy,) + (-divergence, regTerm1.requires_grad_(True), regTerm2.requires_grad_(True))
            