from base_classes import ODEblock
import torch
from utils import get_rw_adj, gcn_norm_fill_val
from block_fractional_euler import caputoEuler,caputoEuler_corrector , GL_method,PIEX_method,PIIM_method,PIIM_trap_method,implicit_l1,caputoEuler_memory

class ConstantODEblock_GRAPH(ODEblock):
  """Graph ODE block integrated over a fixed time interval.

  ``forward`` normalizes the incoming graph with ``gcn_norm_fill_val``, attaches
  the normalized edges to ``self.odefunc`` and integrates ``self.odefunc`` from
  the input features with the solver named by ``opt['method']``. The order
  ``alpha`` passed to the fractional solvers is derived from
  ``opt['function']`` / ``opt['num_terms']`` or read from ``opt['alpha_ode']``.
  """

  # Methods handed to the torchdiffeq integrator (only when alpha <= 1).
  _TORCHDIFFEQ_METHODS = ("euler", "rk4")

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=None):
    """Build the block.

    Args:
      odefunc: ODE-function class; instantiated here with input/output size
        ``aug_dim * opt['hidden_dim']``.
      regularization_fns: forwarded to ``ODEblock.__init__``.
      opt: option dict (reads 'augment', 'hidden_dim', 'adjoint' here and more
        keys in ``forward``).
      data: forwarded to ``ODEblock.__init__`` and to ``odefunc``.
      device: device the normalized graph tensors are moved to.
      t: integration interval tensor; defaults to ``torch.tensor([0, 1])``.
    """
    # A tensor default would be created once at definition time and shared by
    # every instance; build a fresh one per construction instead.
    if t is None:
      t = torch.tensor([0, 1])
    super(ConstantODEblock_GRAPH, self).__init__(odefunc, regularization_fns, opt, data, device, t)

    self.aug_dim = 2 if opt['augment'] else 1
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint

    self.train_integrator = odeint
    self.test_integrator = odeint
    self.set_tol()
    self.device = device
    self.opt = opt

  def forward(self, x, edge_index, edge_weight=None):
    """Integrate the block's ODE starting from ``x`` on the given graph.

    Args:
      x: node feature tensor; ``x.shape[0]`` is used as the number of nodes.
      edge_index: graph connectivity, passed to ``gcn_norm_fill_val``.
      edge_weight: optional edge weights, passed to ``gcn_norm_fill_val``.

    Returns:
      The output of the selected fractional solver, or for the torchdiffeq
      methods ('euler', 'rk4') the solution at ``self.t[1]``.

    Raises:
      ValueError: if ``opt['method']`` is not supported for the computed order
        (only 'ceuler', 'ceuler_corrector', 'GL' and 'implicit' are accepted
        when alpha > 1).
    """
    self._attach_graph(x, edge_index, edge_weight)
    alpha = self._fractional_order()
    func = self.odefunc
    state = x
    method = self.opt['method']

    if alpha > 1:
      # NOTE(review): for alpha > 1 the initial state is doubled (x + x). This
      # mirrors the original implementation; confirm it is intended.
      state = state + x
      solvers = self._fractional_solvers(order_above_one=True)
    else:
      if method in self._TORCHDIFFEQ_METHODS:
        return self._integrate_torchdiffeq(func, state, x)
      solvers = self._fractional_solvers(order_above_one=False)

    solver = solvers.get(method)
    if solver is None:
      raise ValueError("Method not implemented")
    return solver(alpha, func, state, tspan=self._tspan(), device=self.device)

  def _attach_graph(self, x, edge_index, edge_weight):
    """GCN-normalize the graph and store it on ``self.odefunc`` (on ``self.device``)."""
    edge_index, edge_weight = gcn_norm_fill_val(edge_index, edge_weight=edge_weight,
                                                fill_value=self.opt['self_loop_weight'],
                                                num_nodes=x.shape[0],
                                                dtype=x.dtype)
    self.odefunc.edge_index = edge_index.to(self.device)
    self.odefunc.edge_weight = edge_weight.to(self.device)

  def _fractional_order(self):
    """Return the order alpha as a 0-d tensor.

    Function names containing "graphconterm" use 0.5 / num_terms, other names
    containing "term" use 1 / num_terms, everything else uses ``alpha_ode``.
    """
    function_name = self.opt['function']
    if "graphconterm" in function_name:
      return torch.tensor(0.5 / self.opt['num_terms'])
    if "term" in function_name:
      return torch.tensor(1 / self.opt['num_terms'])
    return torch.tensor(self.opt['alpha_ode'])

  def _fractional_solvers(self, order_above_one):
    """Map method names to the fractional solvers allowed for this order range."""
    # Built per call (not as a class attribute) so the module-level solver
    # names are looked up at call time, exactly as the inline calls were.
    solvers = {
        "ceuler": caputoEuler,
        "ceuler_corrector": caputoEuler_corrector,
        "GL": GL_method,
        "implicit": implicit_l1,
    }
    if not order_above_one:
      solvers.update({
          "memory": caputoEuler_memory,
          "PIEX": PIEX_method,
          "PIIM": PIIM_method,
          "PIIM_trap": PIIM_trap_method,
      })
    return solvers

  def _tspan(self):
    """Time grid for the fractional solvers: [0, opt['time']) with step opt['step_size']."""
    return torch.arange(0, self.opt['time'], self.opt['step_size'])

  def _integrate_torchdiffeq(self, func, state, x):
    """Integrate with the torchdiffeq integrator over ``self.t``; return the solution at ``t[1]``."""
    t = self.t.type_as(x)
    integrator = self.train_integrator if self.training else self.test_integrator
    state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options=dict(step_size=self.opt['step_size']),
        atol=self.atol,
        rtol=self.rtol)
    return state_dt[1]

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
