from function_transformer_attention import ODEFuncTransformerAtt
from function_GAT_attention import ODEFuncAtt
from function_laplacian_diffusion import LaplacianODEFunc
from block_transformer_attention import AttODEblock
from block_constant import ConstantODEblock
from block_mixed import MixedODEblock
from block_transformer_hard_attention import HardAttODEblock
from block_transformer_rewiring import RewireAttODEblock

from function_laplacian_convection import ODEFuncLapCONV

from function_GAT_convection import ODEFuncAttConv

from function_transformer_convection import ODEFuncTransConv
from block_transformer_attention import AttODEblock_PLOT

from block_constant_fractional import ConstantODEblock_FRAC
from block_transformer_fractional import AttODEblock_FRAC

from block_constant_graph import ConstantODEblock_GRAPH
from block_attention_graph import AttODEblock_GRAPH

from function_transformer_graphcon import ODEFuncTransformerAtt_graphcon
from function_GAT_graphcon import ODEFuncAtt_graphcon

from function_laplacian_multiterm import LaplacianODEFunc_multiterm
from block_constant_fractional_multiterm import ConstantODEblock_FRAC_TERM
from function_transformer_multiterm import ODEFuncTransformerAtt_multiterm
from block_constant_graph_terms import ConstantODEblock_GRAPH_TERMS

from function_GAT_convection_multiterm import ODEFuncAttConv_CDE_term
from function_laplacian_graphcon_multiterm import LaplacianODEFunc_graphcon_terms
from function_GAT_graphcon_multiterm import ODEFuncAtt_graphcon_terms
from function_transformer_graphcon_multiterm import ODEFuncTransformerAtt_graphcon_terms
from block_constant_fractional_order import ConstantODEblock_FRAC_MULTI_ORDER


class BlockNotDefined(Exception):
  """Raised by ``set_block`` when ``opt['block']`` names no known block."""

class FunctionNotDefined(Exception):
  """Raised by ``set_function`` when ``opt['function']`` names no known function."""


def set_block(opt):
  """Return the ODE block class selected by ``opt['block']``.

  Args:
    opt: Mapping of options; only the ``'block'`` key is read here.

  Returns:
    The block class (not an instance) registered under ``opt['block']``.

  Raises:
    KeyError: If ``opt`` has no ``'block'`` entry.
    BlockNotDefined: If ``opt['block']`` is not a recognised block name.
      The message names the offending value and lists the valid choices.
  """
  ode_str = opt['block']
  # Name -> class registry. Built per call; it only maps strings to classes
  # that are already imported at module level.
  blocks = {
      'mixed': MixedODEblock,
      'attention': AttODEblock,
      'hard_attention': HardAttODEblock,
      'rewire_attention': RewireAttODEblock,
      'constant': ConstantODEblock,
      'attplot': AttODEblock_PLOT,
      'constant_frac': ConstantODEblock_FRAC,
      'att_frac': AttODEblock_FRAC,
      'constant_graph': ConstantODEblock_GRAPH,
      'att_graph': AttODEblock_GRAPH,
      'constant_term': ConstantODEblock_FRAC_TERM,
      'constantgraph_term': ConstantODEblock_GRAPH_TERMS,
      'constant_fracorder': ConstantODEblock_FRAC_MULTI_ORDER,
  }
  try:
    return blocks[ode_str]
  except (KeyError, TypeError):
    # TypeError covers unhashable values, which the previous if/elif chain
    # also rejected with BlockNotDefined rather than crashing differently.
    raise BlockNotDefined(
        "Unknown block {!r}; expected one of: {}".format(
            ode_str, ', '.join(sorted(blocks)))) from None


def set_function(opt):
  """Return the ODE function class selected by ``opt['function']``.

  Args:
    opt: Mapping of options; only the ``'function'`` key is read here.

  Returns:
    The function class (not an instance) registered under
    ``opt['function']``.

  Raises:
    KeyError: If ``opt`` has no ``'function'`` entry.
    FunctionNotDefined: If ``opt['function']`` is not a recognised name.
      The message names the offending value and lists the valid choices.
  """
  ode_str = opt['function']
  # Name -> class registry. Built per call; it only maps strings to classes
  # that are already imported at module level.
  functions = {
      'laplacian': LaplacianODEFunc,
      'GAT': ODEFuncAtt,
      'transformer': ODEFuncTransformerAtt,
      'lapconv': ODEFuncLapCONV,
      'gatconv': ODEFuncAttConv,
      'transconv': ODEFuncTransConv,
      'transgraphcon': ODEFuncTransformerAtt_graphcon,
      'gatgraphcon': ODEFuncAtt_graphcon,
      'lapterm': LaplacianODEFunc_multiterm,
      'gatconvterm': ODEFuncAttConv_CDE_term,
      'transterm': ODEFuncTransformerAtt_multiterm,
      'lapgraphconterm': LaplacianODEFunc_graphcon_terms,
      'gatgraphconterm': ODEFuncAtt_graphcon_terms,
      'transgraphconterm': ODEFuncTransformerAtt_graphcon_terms,
  }
  try:
    return functions[ode_str]
  except (KeyError, TypeError):
    # TypeError covers unhashable values, which the previous if/elif chain
    # also rejected with FunctionNotDefined rather than crashing differently.
    raise FunctionNotDefined(
        "Unknown function {!r}; expected one of: {}".format(
            ode_str, ', '.join(sorted(functions)))) from None
