"""
CODE ADAPTED FROM: https://github.com/microsoft/cf-ode/tree/main

"""
import pytorch_lightning as pl
import sys
sys.path.insert(0,"../")
#import utils
import torch
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
import os
import argparse
import numpy as np
from scipy.integrate import odeint
import pandas as pd
import random

#python train_model.py --logger_type=wandb --seed=313 --max_epochs 20 --N_int 400 --N_obs 600 --batch_size 256 --num_blocks 1 --lr 0.01 --loss_func mse --observational 2 --reg1 0.00 --reg2 0.00 --state_dim 20


def create_pendulum_data_actions(N, observational, noise_std, obs_length, seed = 421):
    """Simulate ``N`` pendulum trajectories subjected to two treatment interventions.

    Each trajectory evolves as a free pendulum (``theta'' = -(g/l) sin(theta)``)
    up to the first intervention time. From then on a 4-dimensional auxiliary
    "treatment" state is integrated jointly with the pendulum; its first
    component ``u`` scales the gravity term by ``(1 + u)``. At each of the two
    interventions the treatment state's last component is reset to an action
    value ``A_`` chosen by a policy:

    * ``observational`` truthy: ``A_`` depends on the pendulum length, the
      current angle and the current angular velocity (state-dependent policy);
    * otherwise: ``A_`` depends only on the current angle (``cos(theta - 0.9)``).

    In both cases ``A_`` is scaled by a random per-trajectory amplitude in [0, 15).

    Args:
        N: number of trajectories (0 yields empty tensors).
        observational: selects the action policy (see above).
        noise_std: standard deviation of Gaussian noise added to the angle
            before the sin/cos embedding.
        obs_length: number of jittered observation times; a ``t = 0`` sample is
            added, so each trajectory has ``obs_length + 1`` points. Must be >= 3
            so that both intervention times can be drawn.
        seed: seeds NumPy's global RNG and the generator used for the noise, so
            the output is reproducible for a fixed seed.

    Returns:
        X: tensor ``(N, obs_length + 1, 2)`` holding ``[sin(theta), cos(theta)]``.
        At: tensor ``(N, obs_length + 1, 1)`` holding the treatment state's last
            component after each intervention (zero up to the first one).
        thetas_0: ndarray ``(N,)`` of initial angles.
        t_X: tensor ``(N, obs_length + 1)`` of observation times (shared by all rows).

    Raises:
        ValueError: if ``obs_length < 3``.
    """
    if obs_length < 3:
        raise ValueError(
            "obs_length must be >= 3 so both intervention times can be drawn, got {}".format(obs_length))

    np.random.seed(seed)
    g = 9.81
    l_s = np.random.rand(N) * 2 + 1.5  # pendulum lengths, uniform in [1.5, 3.5)
    phi, delta = 1, 1
    thetas_0 = np.random.rand(N) * 5 + 0.5  # initial angles, uniform in [0.5, 5.5)

    A = np.random.rand(N, 3) * 15  # action amplitudes, uniform in [0, 15)

    def df_dt(x, t, l):
        # free pendulum, state (theta, omega)
        return x[1], -(g / l) * np.sin(x[0])

    def dfu_dt(x, t, phi, delta):
        # 4-dim treatment dynamics
        return (phi * x[1] * x[3] - delta * x[2] * x[3], -phi * x[2], phi * x[1], -delta * x[3])

    def df_dt_complete(x, t, l, phi, delta):
        # pendulum coupled to the treatment: gravity scaled by (1 + u), u = x[2]
        return (x[1], -(g / l) * np.sin(x[0]) * (1 + x[2])) + dfu_dt(x[2:], t, phi, delta)

    # obs_length jittered times on [0.08, 6], plus t = 0, sorted
    t = np.linspace(0.08, 6, obs_length) + np.random.uniform(low=-0.08, high=0.08, size=obs_length)
    t = np.sort(np.append(t, 0.0))

    X = []
    At = []
    for i in range(N):
        # one intervention in the first half of the grid, one in the second half
        # (the last grid point is never chosen, so the final segment is non-empty)
        idx_interv = [np.random.randint(0, obs_length // 2),
                      np.random.randint((obs_length // 2) + 1, obs_length)]
        t_interv = t[idx_interv]
        v_treatment = np.zeros((len(t), 1))
        x0 = np.array([thetas_0[i], 0])
        x = odeint(df_dt, x0, t[t <= t_interv[0]], args=(l_s[i],))  # before interventions

        # extend the last pendulum state with the initial treatment state (0, 1, 0, 0)
        x0_ = np.concatenate((x[-1], [0, 1, 0, 0]))
        n_interv = len(t_interv)
        for i_t, t_start in enumerate(t_interv):
            if i_t == n_interv - 1:  # last intervention
                mask_t = t > t_start
            else:
                mask_t = (t > t_start) & (t <= t_interv[i_t + 1])

            # generate observational (state-dependent) or interventional policy
            if observational:
                A_ = A[i, i_t] * (l_s[i] / 4) * np.sin(x0_[0]) / (np.abs(x0_[1]) + 1)
            else:
                A_ = A[i, i_t] * np.cos(x0_[0] - 0.9) / 2

            # reset the treatment state's last component to the action value
            x0_ = np.concatenate((x0_[:-1], [A_]))
            # integrate on time measured from the intervention; use the scalar
            # start time rather than a float-equality mask over t
            t_ = np.concatenate(([0.0], t[mask_t] - t_start))
            x_ = odeint(df_dt_complete, x0_, t_, args=(l_s[i], phi, delta))
            x = np.concatenate((x, x_[1:, :2]), 0)
            v_treatment[mask_t, 0] = x_[1:, -1]  # continuous dimming action
            x0_ = x_[-1, :]
        X.append(torch.Tensor(x[:, 0]))
        At.append(torch.Tensor(v_treatment))

    X = torch.stack(X) if X else torch.empty(0, len(t))
    At = torch.stack(At) if At else torch.empty(0, len(t), 1)

    # seeded generator so that the noise is reproducible for a given seed
    noise_gen = torch.Generator().manual_seed(seed)
    X = X + noise_std * torch.randn(N, len(t), generator=noise_gen)

    t_X = torch.Tensor(np.tile(t[None, :], (X.shape[0], 1)))

    X = torch.stack([torch.sin(X), torch.cos(X)], -1)

    return X, At, thetas_0, t_X





class PendulumDataset(Dataset):
    """Torch dataset of simulated pendulum trajectories.

    Wraps ``create_pendulum_data_actions``. Indexing returns the tuple
    ``(X, A, t_X, weights)`` for one trajectory:

    * ``X``: sin/cos embedding of the (possibly noisy) pendulum angle;
    * ``A``: the treatment signal with one extra trailing channel holding the
      value of ``observational`` (a bool, so 1.0 or 0.0, when passed as True/False);
    * ``t_X``: observation times;
    * ``weights``: a tensor of ones shaped like ``X``.

    ``state_dim`` and ``missing_prop`` are accepted for API compatibility but
    are not used.
    """

    def __init__(self,N, observational, state_dim,noise_std, seed, obs_length, missing_prop=0):
        simulated = create_pendulum_data_actions(
            N=N,
            observational=observational,
            noise_std=noise_std,
            seed=seed,
            obs_length=obs_length,
        )
        trajectories, actions, initial_angles, times = simulated

        self.X = trajectories

        # Append a constant indicator channel marking the data-generating policy.
        indicator = torch.zeros_like(actions[:, :, :1]) + observational
        self.A = torch.cat((actions, indicator), dim=-1)
        self.thetas_0 = initial_angles
        self.t_X = times
        self.weights = torch.full_like(self.X, fill_value=True)

        # Sanity check: warn if the simulation produced NaNs.
        if torch.isnan(trajectories).any():
            print('data has NaN values')

    def __getitem__(self,idx):
        fields = (self.X, self.A, self.t_X, self.weights)
        return tuple(field[idx] for field in fields)

    def __len__(self):
        return len(self.X)


class PendulumDataModule(pl.LightningDataModule):
    """Lightning data module for the simulated pendulum treatment data.

    ``observational`` selects the training set:

    * ``0``: observational-policy data only (``N_obs`` samples, ``N_int`` forced to 0);
    * ``1``: interventional-policy data only (``N_int`` samples, ``N_obs`` forced to 0);
    * any other value: both sets, concatenated.

    Validation and test sets always use the interventional policy. A
    ``batch_size`` of 0 means full-batch loading.
    """

    # number of simulated test trajectories
    _TEST_SIZE = 1000

    def __init__(self,batch_size, seed, obs_length, N_int, N_obs, observational, state_dim, noise_std, num_workers = 4, **kwargs):

        super().__init__()
        self.batch_size = batch_size
        self.seed = seed
        self.obs_length=obs_length
        self.num_workers = num_workers

        self.train_shuffle = True

        self.input_dim = 2
        self.output_dim = 2
        self.action_dim = 2

        self.N_int = N_int
        self.N_obs = N_obs
        self.noise_std = noise_std

        self.state_dim = state_dim
        self.observational = observational

    def _make_dataset(self, N, observational, seed):
        """Build a PendulumDataset with this module's shared simulation settings."""
        return PendulumDataset(
            N=N,
            observational=observational,
            state_dim=self.state_dim,
            noise_std=self.noise_std,
            seed=seed,
            obs_length=self.obs_length,
        )

    def _is_mixed(self):
        """True when training uses both observational and interventional data."""
        return self.observational not in (0, 1)

    def prepare_data(self):
        """Simulate the training and validation sets and choose batch sizes.

        NOTE(review): this method stores datasets on ``self``; Lightning may call
        ``prepare_data`` on a single process only, so multi-process training
        would likely need this in ``setup`` — confirm before using DDP.

        Raises:
            ValueError: if the selected training set is empty.
        """
        # GENERATE TRAINING DATA
        if self.observational == 0:  # only observational
            self.N_int = 0
            self.train = self._make_dataset(self.N_obs, observational=True, seed=self.seed)
            self.train_len = len(self.train)
        elif self.observational == 1:  # only interventional
            self.N_obs = 0
            self.train = self._make_dataset(self.N_int, observational=False, seed=self.seed)
            self.train_len = len(self.train)
        else:  # both
            self.dataset_Obs = self._make_dataset(self.N_obs, observational=True, seed=self.seed)
            self.dataset_Int = self._make_dataset(self.N_int, observational=False, seed=self.seed + 1)
            self.train = [self.dataset_Obs, self.dataset_Int]
            self.train_len = len(self.dataset_Obs) + len(self.dataset_Int)

        print("Training with Interventional data {0} kpl and Observational data {1} kpl".format(self.N_int, self.N_obs))
        print("Total train set is size of {0} ".format(self.train_len))
        if self.train_len == 0:
            raise ValueError(
                "training set is empty; check N_obs/N_int for observational={}".format(self.observational))

        # GENERATE VALIDATION DATA: 1000 samples, or the train size if smaller
        valN = min(1000, self.train_len)
        testN = self._TEST_SIZE
        self.val = self._make_dataset(valN, observational=False, seed=self.seed + 2)
        print("Validating with {} interventional data".format(len(self.val)))

        if self.batch_size == 0:  # full-batch mode
            # the whole training set (both parts in mixed mode), not just N_int
            self.train_batch_size = self.train_len
            self.val_batch_size = valN
            self.test_batch_size = testN
        else:
            self.train_batch_size = min(self.batch_size, self.train_len)
            self.val_batch_size = min(self.batch_size, valN)
            self.test_batch_size = min(self.batch_size, testN)  # test in smaller batches to evaluate loglikelihood

    def seed_worker(self, worker_id):
        """Seed NumPy and ``random`` in a DataLoader worker from torch's worker seed."""
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def train_dataloader(self):
        """Return the shuffled training loader.

        In mixed mode the two training sets are concatenated and the last
        partial batch is kept; otherwise the last partial batch is dropped.
        """
        # TODO: pass worker_init_fn/generator here for reproducible shuffling.
        if self._is_mixed():
            return DataLoader(
                ConcatDataset(self.train),
                batch_size=self.train_batch_size,
                shuffle=self.train_shuffle,
                num_workers=self.num_workers,
                drop_last=False,
                pin_memory=True,
                # DataLoader rejects persistent_workers=True when num_workers == 0
                persistent_workers=self.num_workers > 0,
            )
        return DataLoader(
            self.train,
            batch_size=self.train_batch_size,
            shuffle=self.train_shuffle,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        """Return the (unshuffled) validation loader; the last partial batch is dropped."""
        return DataLoader(
            self.val,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self):
        """Simulate a fresh interventional test set and return its loader."""
        testN = self._TEST_SIZE
        self.test = self._make_dataset(testN, observational=False, seed=self.seed + 5)
        print("Testing with {} interventional data".format(len(self.test)))
        return DataLoader(
            self.test,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
            worker_init_fn=self.seed_worker,
        )

    @classmethod
    def add_dataset_specific_args(cls, parent):
        """Return a parser extending ``parent`` with this data module's CLI options."""
        import argparse
        parser = argparse.ArgumentParser(parents=[parent], add_help=False)
        parser.add_argument('--observational', type=int, default=0, help= "0: only observational: 1: only interventional 2: both")
        parser.add_argument('--seed', type=int, default=42)
        parser.add_argument('--batch_size', type=int, default=128) #128
        parser.add_argument('--N_int', type=int, default=500)
        parser.add_argument('--N_obs', type=int, default=500)
        parser.add_argument('--obs_length', type=int, default=30)
        parser.add_argument('--noise_std', type=float, default=0)
        return parser

if __name__=="__main__":
    # Smoke test: simulate observational-only training data with CLI-like defaults.
    # All constructor arguments without defaults must be supplied explicitly.
    datam = PendulumDataModule(
        batch_size=32,
        seed=42,
        obs_length=30,
        N_int=0,
        N_obs=1000,
        observational=0,
        state_dim=20,
        noise_std=0.,
    )
    datam.prepare_data()
