import numpy as np
# from pinnutils import *
from scipy.interpolate import griddata
from itertools import product, combinations
from scipy.io import savemat, loadmat

import math
import scipy.io as sio
# from pinnutils import *
import sys
from tqdm import tqdm_notebook as tqdm
import os
import time as timeit
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32768"
import itertools


lst2 = np.asarray(list(itertools.product([0, 1], repeat=2)), dtype = bool)
lst3 = np.asarray(list(itertools.product([0, 1], repeat=3)), dtype = bool)
lst4 = np.asarray(list(itertools.product([0, 1], repeat=4)), dtype = bool)

def combination(con,nR):
    lst = np.asarray(list(itertools.product([0, 1], repeat=nR)))
    con = con.T
    sum_ = 0;
    for i in range(len(lst)):
        sum_ = sum_ + np.prod(con**lst[i,:],axis=1)
        
    return sum_

def combination_two(con,lst2,truth):
    con = con.T
    sum_ = 0; weight_sum = 0;
    for i in range(len(lst2)):
        temp = (con[:,0]**lst2[i,0])*(con[:,1]**lst2[i,1])
        sum_ = sum_ + temp
        weight_sum = weight_sum + truth[i]*temp

    return  weight_sum, sum_

def combination_three(con,lst3,truth):
    con = con.T
    sum_ = 0; weight_sum = 0;
    for i in range(len(lst3)):
        temp = (con[:,0]**lst3[i,0])*(con[:,1]**lst3[i,1])*(con[:,2]**lst3[i,2])
        sum_ = sum_ + temp
        weight_sum = weight_sum + truth[i]*temp
        
    return  weight_sum, sum_

def combination_four(con,lst4,truth):
    con = con.T
    sum_ = 0; weight_sum = 0;
    for i in range(len(lst4)):
        temp = (con[:,0]**lst4[i,0])*(con[:,1]**lst4[i,1])*(con[:,2]**lst4[i,2])*(con[:,3]**lst4[i,3])
        sum_ = sum_ + temp
        weight_sum = weight_sum + truth[i]*temp
    return  weight_sum, sum_


Gn = {'g1':0,'g2':1,'g3':2,'g4':3,
      'g6':4,'g7':5,'g8':6
     }

def interact_BF(x,Ndim,mvec,kvec,nvec):
    x_ = np.zeros_like(x)
    for i in range(0,Ndim):
        x_[i,:] = (x[i,:]/kvec[i])**nvec[i]
 
    f = np.zeros_like(x);

    # g1 <-- ~(g4 | g6)
    keys = ['g4', 'g6']; arr_ind = np.array([Gn[key] for key in keys])
    truth = np.where(~(lst2[:, 0] | lst2[:, 1]), 1, 0)
    num_, den_ = combination_two(x_[arr_ind,:],lst2,truth)
    f[0,:] = mvec[0]*(num_/den_);

    # g2 <-- g1
    f[1,:] = mvec[0]*x_[Gn['g1'],:]/(1+x_[Gn['g1'],:])
    
    # g3 <-- g2
    f[2,:] = mvec[0]*x_[Gn['g2'],:]/(1+x_[Gn['g2'],:])
    
    # g4 <-- (g3 | g4) & (~g6)
    keys = ['g3', 'g4', 'g6']; arr_ind = np.array([Gn[key] for key in keys])
    truth = np.where( (lst3[:, 0] | lst3[:, 1]) & ~lst3[:,2], 1, 0)
    num_, den_ = combination_three(x_[arr_ind,:],lst3,truth)
    f[3,:] = mvec[0]*(num_/den_);
    
    # g6 <-- (g6 | g3) & (~g4)
    keys = ['g6', 'g3', 'g4']; arr_ind = np.array([Gn[key] for key in keys])
    truth = np.where( (lst3[:, 0] | lst3[:, 1]) & ~lst3[:,2], 1, 0)
    num_, den_ = combination_three(x_[arr_ind,:],lst3,truth)
    f[4,:] = mvec[0]*(num_/den_);
    
    # g7 <-- g4
    f[5,:] = mvec[0]*x_[Gn['g4'],:]/(1+x_[Gn['g4'],:])
    
    # g8 <-- ~g4
    keys = ['g4', 'g4']; arr_ind = np.array([Gn[key] for key in keys])
    truth = np.where(~lst2[:, 0] & ~lst2[:, 1], 1, 0)
    num_, den_ = combination_two(x_[arr_ind,:],lst2,truth)
    f[6,:] = mvec[0]*(num_/den_);
    
    return f

def simulate_BF(nsamples,Nsnaps,Ndim,seed_,maxiter,c,lx,r,lp,mvec, kvec, nvec,nm='add'):
    # P = 0; S = 1; F = 2; E = 3; C = 4;
    dt = 0.001; 
    np.random.seed(seed_)
    xold =  np.random.normal(5,1,(Ndim,nsamples))
    xnew =  xold + 0;

    samples_fullx = np.zeros((Nsnaps,nsamples,Ndim+1))
    
    count = 0; 
    tt = np.zeros((Nsnaps,))
    
    for i in range(0,maxiter-1):
        if((i%int(maxiter/Nsnaps)==0)):
            samples_fullx[count,:,0:Ndim] = xold.T
            samples_fullx[count,:,Ndim]   = dt*i
            tt[count] = dt*i
            count = count + 1;
            
        if(nm=='add'): 
            noisex = np.sqrt(dt)*np.random.normal(0,1,(Ndim,nsamples))
        elif(nm=='mult'): 
            noisex = np.sqrt(xold)*np.sqrt(dt)*np.random.normal(0,1,(Ndim,nsamples))
        else:
            noisex = np.sqrt(interact_BF((r/lp)*xold,Ndim,mvec,kvec,nvec) + lx*xold)*np.sqrt(dt)*np.random.normal(0,1,(Ndim,nsamples))
            
        xnew  = xold + dt*(interact_BF((r/lp)*xold,Ndim,mvec,kvec,nvec) - lx*xold) + c*noisex

        ind = np.where(xnew<0)
        xnew[ind[0],ind[1]] = xold[ind[0],ind[1]] # -(xnew[ind[0],ind[1]])
        
        xold  = xnew + 0; 
        
    print(count,dt*i, " SFI dt ",tt[1] - tt[0])
    
    return samples_fullx, tt

if __name__ == "__main__":
    print( " lst2 ", len(lst2))
    print( " lst3 ", len(lst3))
    print( " lst4 ", len(lst4))
    
    Ndim = 7;
    mvec = 20*np.ones((Ndim,)); 
    kvec = 10*np.ones((Ndim,)); 
    nvec = 10*np.ones((Ndim,));

    dt = 0.001; lp = 1;  r = 10;  lx = 5; 
    # control lx to balance between drift and degradation
    c = np.double(sys.argv[1]); 
    noise_model = str(sys.argv[2]);

    nsamples = 20_000; Nsnaps = 200; seed = 10; maxiter = 4001; 
    samples_full, tt = simulate_BF(nsamples,Nsnaps,Ndim,seed,maxiter,c,lx,r,lp,nm=noise_model);
    dt = tt[1] - tt[0];

    tind = np.arange(0,Nsnaps,1); nsnaps = len(tind);
    print(" nsnaps ", len(tind), " tind ", tt[tind], tind, " dtt ", np.round(tt[tind][1] - tt[tind][0],2))

    samples = samples_full[tind,:,:]
    mdic = {"samples":samples,"tt":tt[tind],"c":c,"maxiter":maxiter,"seed":seed,
           "dt":dt,",lp":lp,"lx":lx,"r":r}
    savemat("longtraj_BF_"+str(noise_model)+"c"+str(c)+"_dim_"+str(Ndim)+".mat", mdic)

    print(" samples_full shape ", samples_full.shape, " SFI dt ", tt[1] - tt[0], tt.shape)
    print(" done with simulations.... ")