import logging

import torch

from base_classes import ODEblock
from utils import get_rw_adj, gcn_norm_fill_val
from block_fractional_euler import caputoEuler,caputoEuler_corrector , GL_method,PIEX_method,PIIM_method,PIIM_trap_method,implicit_l1,caputoEuler_memory,product_trap

class ConstantODEblock_FRAC(ODEblock):
  """Fractional-order ODE block over a graph whose edges are fixed at construction.

  The normalized adjacency (``edge_index`` / ``edge_weight``) is computed once in
  ``__init__`` and attached to ``self.odefunc``. ``forward`` then integrates the
  fractional ODE with one of the solvers imported from ``block_fractional_euler``,
  chosen by ``opt['method']``.
  """

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1])):
    """Build the ODE function and precompute the normalized graph.

    Args:
      odefunc: callable/class invoked as
        ``odefunc(in_dim, out_dim, opt, data, device)`` to build the ODE function.
      regularization_fns: forwarded unchanged to ``ODEblock.__init__``.
      opt: option dict; this block reads ``'augment'``, ``'hidden_dim'``,
        ``'data_norm'``, ``'self_loop_weight'`` and ``'adjoint'`` here.
      data: graph object providing ``edge_index``, ``edge_attr``, ``num_nodes``
        and ``x`` (only its dtype is used here).
      device: device the edge tensors are moved to.
      t: time tensor forwarded to ``ODEblock.__init__``.
    """
    super(ConstantODEblock_FRAC, self).__init__(odefunc, regularization_fns, opt, data, device, t)

    # Double the feature width when augmentation is enabled.
    self.aug_dim = 2 if opt['augment'] else 1
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)

    # Normalize the adjacency once. 'rw' selects get_rw_adj (row-wise, norm_dim=1);
    # anything else selects gcn_norm_fill_val. opt['self_loop_weight'] is passed as
    # fill_value (presumably the self-loop weight; see utils).
    if opt['data_norm'] == 'rw':
      edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                                           fill_value=opt['self_loop_weight'],
                                           num_nodes=data.num_nodes,
                                           dtype=data.x.dtype)
    else:
      edge_index, edge_weight = gcn_norm_fill_val(data.edge_index, edge_weight=data.edge_attr,
                                                  fill_value=opt['self_loop_weight'],
                                                  num_nodes=data.num_nodes,
                                                  dtype=data.x.dtype)
    self.odefunc.edge_index = edge_index.to(device)
    self.odefunc.edge_weight = edge_weight.to(device)

    # The torchdiffeq integrators are kept as attributes for compatibility;
    # forward() itself uses the fractional solvers instead.
    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint

    self.train_integrator = odeint
    self.test_integrator = odeint
    self.set_tol()
    self.device = device
    self.opt = opt

  def _fractional_order(self):
    """Return the fractional order ``alpha`` as a 0-d tensor.

    * ``'graphconterm'`` in ``opt['function']`` -> ``0.5 / opt['num_terms']``
    * else ``'term'`` in ``opt['function']``    -> ``1 / opt['num_terms']``
    * otherwise                                 -> ``opt['alpha_ode']``
    """
    function_name = self.opt['function']
    if "graphconterm" in function_name:
      gamma = 0.5 / self.opt['num_terms']
    elif "term" in function_name:
      gamma = 1 / self.opt['num_terms']
    else:
      return torch.tensor(self.opt['alpha_ode'])
    # Debug-level log instead of print: this runs on every forward pass.
    logging.getLogger(__name__).debug("gamma: %s", gamma)
    return torch.tensor(gamma)

  def forward(self, x):
    """Integrate the fractional ODE starting from ``x``.

    The solver is picked by ``opt['method']``. It is called as
    ``solver(alpha, self.odefunc, state, tspan=..., device=self.device)`` with
    ``tspan = torch.arange(0, opt['time'], opt['step_size'])``. The ``'memory'``
    solver additionally receives ``memory_k=opt['memory_k']``.

    Args:
      x: initial state handed to the solver.

    Returns:
      Whatever the selected solver returns.

    Raises:
      ValueError: if ``opt['method']`` is unknown, or if it names a solver that
        is only wired up for ``alpha <= 1`` while ``alpha > 1``.
    """
    alpha = self._fractional_order()
    func = self.odefunc
    state = x

    # Solvers accepted for every alpha.
    solvers = {
        "ceuler": caputoEuler,
        "ceuler_corrector": caputoEuler_corrector,
        "GL": GL_method,
        "implicit": implicit_l1,
        "trap": product_trap,
    }
    if alpha > 1:
      # NOTE(review): for alpha > 1 the initial state becomes x + x (i.e. 2x).
      # Kept as in the original; confirm this is the intended initial condition.
      state = state + x
    else:
      # These solvers are only dispatched for alpha <= 1.
      solvers.update({
          "memory": caputoEuler_memory,
          "PIEX": PIEX_method,
          "PIIM": PIIM_method,
          "PIIM_trap": PIIM_trap_method,
      })

    method = self.opt['method']
    if method not in solvers:
      raise ValueError("Method not implemented")

    extra_kwargs = {}
    if method == "memory":
      # Only read when needed, so other methods do not require this option.
      extra_kwargs['memory_k'] = self.opt['memory_k']

    tspan = torch.arange(0, self.opt['time'], self.opt['step_size'])
    return solvers[method](alpha, func, state, tspan=tspan, device=self.device, **extra_kwargs)

  def __repr__(self):
    """Show the class name and the [t0, t1] interval stored in ``self.t``."""
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
