from base_classes import ODEblock
import torch
from utils import get_rw_adj, gcn_norm_fill_val
from block_fractional_euler import caputoEuler,caputoEuler_corrector , GL_method,PIEX_method,PIIM_method,PIIM_trap_method
from function_transformer_attention import SpGraphTransAttentionLayer
class AttODEblock_GRAPH(ODEblock):
  """Graph ODE block integrated with a fixed-step fractional-order solver.

  ``forward`` normalizes the graph, refreshes the attention weights on the ODE
  function, and then integrates the node features with one of the fractional
  solvers from ``block_fractional_euler``. The solver is selected by
  ``opt['method']`` and the fractional order by ``opt['alpha_ode']``.
  The torchdiffeq integrators configured in ``__init__`` are kept as attributes
  but are not used by ``forward``.
  """

  # Maps opt['method'] to the fractional solver called in forward().
  # Values are plain module-level functions read from the dict (not bound methods).
  _FRACTIONAL_SOLVERS = {
    "ceuler": caputoEuler,
    "ceuler_corrector": caputoEuler_corrector,
    "GL": GL_method,
    "PIEX": PIEX_method,
    "PIIM": PIIM_method,
    "PIIM_trap": PIIM_trap_method,
  }

  def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1])):
    """Build the ODE function, the torchdiffeq integrators and the attention layer.

    Args:
      odefunc: ODE function class, called as
        ``odefunc(in_dim, out_dim, opt, data, device)``.
      regularization_fns: forwarded to the ``ODEblock`` base class.
      opt: configuration dict. Reads 'augment', 'hidden_dim' and 'adjoint' here,
        plus further keys in ``forward``.
      data: forwarded to the base class and to ``odefunc``.
      device: device the attention layer is moved to and edges are placed on.
      t: time interval tensor, forwarded to the base class (used by ``__repr__``).
    """
    super(AttODEblock_GRAPH, self).__init__(odefunc, regularization_fns, opt, data, device, t)

    # Feature dimension is doubled when augmentation is enabled.
    self.aug_dim = 2 if opt['augment'] else 1
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint

    # Kept for compatibility with the base class / external users; forward() uses
    # the fractional solvers instead.
    self.train_integrator = odeint
    self.test_integrator = odeint
    self.set_tol()
    self.device = device
    self.opt = opt
    self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt,
                                                          device, edge_weights=None).to(device)

  def get_attention_weights(self, x):
    """Return the attention weights of ``multihead_att_layer`` over the current edges.

    Uses ``self.odefunc.edge_index``, which ``forward`` refreshes before calling this.
    """
    attention, _ = self.multihead_att_layer(x, self.odefunc.edge_index)
    return attention

  def forward(self, x, edge_index, edge_weight):
    """Normalize the graph, refresh attention weights and integrate ``x``.

    Args:
      x: node feature tensor; ``x.shape[0]`` is used as the number of nodes and
        ``x.dtype`` as the dtype of the normalized edge weights.
      edge_index: graph connectivity passed to the normalization helper.
      edge_weight: edge weights passed to the normalization helper (presumably
        ``None`` is accepted for unweighted graphs — confirm against utils).

    Returns:
      The output of the fractional solver selected by ``opt['method']``.

    Raises:
      ValueError: if ``opt['method']`` does not name a known fractional solver.
    """
    # Resolve the solver first so a bad configuration fails before any work is done.
    method = self.opt['method']
    try:
      solver = self._FRACTIONAL_SOLVERS[method]
    except KeyError:
      raise ValueError(f"Method not implemented: {method!r}") from None

    num_nodes = x.shape[0]
    if self.opt['data_norm'] == 'rw':
      # Pass the caller's edge weights (previously edge_index was passed here by mistake).
      edge_index, edge_weight = get_rw_adj(edge_index, edge_weight=edge_weight, norm_dim=1,
                                           fill_value=self.opt['self_loop_weight'],
                                           num_nodes=num_nodes,
                                           dtype=x.dtype)
    else:
      edge_index, edge_weight = gcn_norm_fill_val(edge_index, edge_weight=edge_weight,
                                                  fill_value=self.opt['self_loop_weight'],
                                                  num_nodes=num_nodes,
                                                  dtype=x.dtype)
    self.odefunc.edge_index = edge_index.to(self.device)
    self.odefunc.edge_weight = edge_weight.to(self.device)

    self.odefunc.attention_weights = self.get_attention_weights(x)
    # reg_odefunc is not used below; mirror the weights only if the base class created one.
    # NOTE(review): reg_odefunc.odefunc may be the base class's odefunc rather than the
    # one rebuilt in __init__ — confirm whether this mirroring is still needed.
    reg_odefunc = getattr(self, 'reg_odefunc', None)
    if reg_odefunc is not None:
      reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights

    alpha = torch.tensor(self.opt['alpha_ode'])
    # torch.arange excludes the end point, so the grid covers [0, time) in step_size steps.
    tspan = torch.arange(0, self.opt['time'], self.opt['step_size'])
    return solver(alpha, self.odefunc, x, tspan=tspan, device=self.device)

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"
