"""
Adapted toy dataset generation from XXXX
"""

import os
import sys

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from torchdiffeq import odeint_adjoint as odeint


class ODEDef(nn.Module):
    '''
    nn.Module wrapper around a user-defined ODE right-hand side, so that it can
    be integrated with torchdiffeq's odeint_adjoint.

    Attributes:
        f               callable f(t, y, params) returning dy/dt
        params          parameters passed to f
        p_transform     optional callable applied to params before each
                        trajectory (e.g. to add noise to the parameters)
        y_transform     optional callable y_transform(t, y, params) applied to
                        the solved trajectory (e.g. add state variables, noise,
                        bias)
        method          integration method name handed to odeint
        device          stored on the instance; not used by this class itself
        params_traj     parameters in effect for the most recent trajectory

    Methods:
        forward         right-hand side evaluated by the ODE solver
        get_trajectory  main entry point: solve from y_0 over t and return the
                        solution, optionally passed through y_transform
    '''

    def __init__(self, f, params, p_transform=None, y_transform=None,
                 method = 'dopri5', device=None):
        super().__init__()
        # user-supplied dynamics and their parameters
        self.f = f
        self.params = params
        # optional pre-/post-processing hooks
        self.p_transform = p_transform
        self.y_transform = y_transform
        # solver configuration
        self.method = method
        self.device = device
        # filled in by get_trajectory before each solve
        self.params_traj = None

    def forward(self, t, y):
        # evaluate dy/dt with the parameters chosen for the current trajectory
        rhs = self.f
        return rhs(t, y, self.params_traj)

    def get_trajectory(self, y_0, t):
        '''
        Solve the ODE from initial state y_0 over the time points t.

        params_traj is refreshed first (through p_transform when one is set).
        The solve runs under torch.no_grad(). If y_transform is set, it is
        then applied to the solution.
        '''
        hook = self.p_transform
        self.params_traj = self.params if hook is None else hook(self.params)

        with torch.no_grad():
            solution = odeint(self, y_0, t, method=self.method)

        if self.y_transform is None:
            return solution

        # NOTE(review): y_transform receives the base self.params, not the
        # (possibly transformed) params_traj used for the solve -- confirm intended.
        return self.y_transform(t, solution, self.params)



class ODEDataSet(Dataset):
    '''
    Class for using and making ODE datasets for use with a pytorch DataLoader.

    A data set is made of Y, a T-by-M-by-N tensor; t, a length-T tensor; and
    label, a length-M list. T is the number of time samples, M is the number
    of trajectories and N is the number of dimensions of the ODE output.

    Attributes:
        data_file       optional file name used to load/save the data set
        Y               T-by-M-by-N tensor of trajectories
        t               length-T tensor of time points
        label           length-M list of labels
        item_transform  optional callable (Y, t, label) -> (Y, t, label)
                        applied in __getitem__
        output_type     stored on the instance; not used by this class
        device          device for the tensors created by make_dataset

    Methods:
        __len__         get number of trajectories in dataset (M)
        __getitem__     retrieve (Y, t, label) tuple for one trajectory
        make_dataset    construct a data set from a list of odes, y_0's, labels
    '''
    def __init__(
        self,
        data_file = None,
        Y = None,
        t = None,
        label = None,
        item_transform = None,
        output_type = None,
        device = None
    ):
        self.data_file = data_file
        self.Y = Y
        self.t = t
        self.label = label
        self.item_transform = item_transform
        self.output_type = output_type
        self.device = device

        # An existing data file overrides the Y/t/label arguments.
        if self.data_file is not None and os.path.exists(self.data_file):
            # NOTE: torch.load unpickles the file; only load trusted data files.
            self.Y, self.t, self.label = torch.load(self.data_file)

    def __len__(self):
        # trajectories live on the second axis of Y
        return self.Y.shape[1]

    def __getitem__(self, idx):
        # one trajectory: all time samples and features for index idx
        Y_idx = self.Y[:, idx, :]
        t_idx = self.t
        label_idx = self.label[idx]

        if self.item_transform is not None:
            Y_idx, t_idx, label_idx = self.item_transform(Y_idx, t_idx, label_idx)

        return Y_idx, t_idx, label_idx

    def make_dataset(self, ode_list, y_0_list, t_max, t_samples, label_list):
        '''
        Make a data set from a list of ODEs and initial conditions over the
        time range 0..t_max, sampled at t_samples evenly spaced points.

        ode_list[i] is integrated from every row of y_0_list[i].

        Y: tensor T-by-M-by-N
        t: tensor T

        where
            T is number of time samples, set by t_samples
            M is number of trajectories, the sum of y_0_list[i].shape[0]
            N is number of output features, taken from the last dimension of
              the trajectories returned by get_trajectory

        If label_list is None, self.label is left unchanged. If data_file is
        set, [Y, t, label] is written to it with torch.save.

        Raises:
            ValueError  if ode_list and y_0_list differ in length, if no
                        initial conditions are supplied, or if label_list does
                        not hold one label per trajectory
        '''
        # validate everything before solving or mutating any state
        if len(ode_list) != len(y_0_list):
            raise ValueError('ode_list and y_0_list must have the same length')

        n_traj = sum(y_0.shape[0] for y_0 in y_0_list)
        if n_traj == 0:
            raise ValueError('y_0_list must contain at least one initial condition')

        # one label per sample
        if label_list is not None and len(label_list) != n_traj:
            raise ValueError('length of label list must equal total number of trajectories')

        t = torch.linspace(0., t_max, t_samples, device=self.device)

        # The output width N is only known after solving. Allocate Y from the
        # first trajectory rather than solving an extra trajectory up front
        # (which would also consume randomness from a stochastic p_transform).
        Y = None
        m_idx = 0
        for ode_f, y_0_grp in zip(ode_list, y_0_list):
            for s_idx in range(y_0_grp.shape[0]):
                y_s = ode_f.get_trajectory(y_0_grp[s_idx:(s_idx + 1), :], t)
                if Y is None:
                    Y = torch.zeros(t_samples, n_traj, y_s.shape[2], device=self.device)
                Y[:, m_idx:(m_idx + 1), :] = y_s  # 1-wide slice to preserve dimensions on LHS
                m_idx += 1

        # commit only after every trajectory has been solved
        self.t = t
        self.Y = Y
        if label_list is not None:
            self.label = label_list

        if self.data_file is not None:
            torch.save([self.Y, self.t, self.label], self.data_file)

def cubic_oscillator_conditional_sd(t, y, params):
    # Cubic oscillator for the conditional ODE framework with static and
    # dynamic conditions.
    #
    # Defines an oscillator with cubic dynamics as in Duvenaud et al. 2018,
    # but with the oscillator matrix changing linearly with a condition that
    # itself changes linearly in time.
    #
    # y is B-by-4 (the dataset code in this file integrates one row at a time)
    #   columns 0-1 are spatial variables
    #   columns 2-3 are condition (state) variables: column 2 is held fixed
    #   (static), column 3 changes at rate dc_dt (dynamic)
    # params is tuple (A1, dA1_dc, dc_dt)
    #   A1      baseline 2-by-2 parameter matrix
    #   dA1_dc  change in A per unit change of the dynamic condition
    #   dc_dt   change in the dynamic condition per unit time
    # t is not used.
    #
    # Each row uses its own dynamic condition y[i, 3]. For a single-row y this
    # is the same as using y[0, 3].

    # unpack parameters
    A1, dA1_dc, dc_dt = params

    # copy of y gives dy the same shape, dtype and device
    dy = y.detach().clone()

    # set derivatives for static, dynamic conditions
    dy[:, 2] = 0
    dy[:, 3] = dc_dt

    # per-row matrix A_i = A1 + dA1_dc * c_i, shape (B, 2, 2)
    curr_A = A1 + dA1_dc * y[:, 3].reshape(-1, 1, 1)

    # spatial derivative, cubic dependence made explicit as in torchdiffeq:
    # dx_i/dt = (x_i ** 3) @ A_i, computed as one batched product
    dy[:, 0:2] = torch.bmm((y[:, 0:2] ** 3).unsqueeze(1), curr_A).squeeze(1)

    return dy


def cubic_oscillator(t, y, params):
    # The torchdiffeq example equation: dy/dt = (y ** 3) @ A, where A is the
    # matrix passed in as params. t is not used.
    cubed = y.pow(3)
    return cubed.mm(params)