# encoding:utf-8
import os
import sys
import torch.optim as optim
from model import *
from model import sthsl
sys.path.append(os.getcwd())
from lib.utils import *
from lib.base_model import *
class trainer(BasicTrainer):
    """Trainer that builds an ``sthsl`` network and hands it to ``BasicTrainer``.

    Only part of the constructor's arguments reach the network; the rest are
    accepted but never referenced in this class (see ``__init__``).
    """

    def __init__(self, scaler, lrate, wdecay, device,
                 in_dim, seq_len, horizon, num_nodes,
                 supports, hdim, dropout, gcn_bool,
                 addaptadj, aptinit, ou_dim=1, order=3):
        """Construct the ``sthsl`` model, store it, then run the base initializer.

        Args:
            scaler, lrate, wdecay: forwarded unchanged to
                ``BasicTrainer.__init__``. NOTE(review): presumably a data
                scaler, learning rate and weight decay; confirm against
                ``lib.base_model``.
            device: given both to ``sthsl`` and to ``BasicTrainer.__init__``.
            in_dim, seq_len, num_nodes, supports, dropout: forwarded to
                ``sthsl`` under the same keyword names.
            horizon: forwarded to ``sthsl`` as its ``out_dim`` keyword.
            hdim, gcn_bool, addaptadj, aptinit, ou_dim, order: accepted but
                not used here. NOTE(review): presumably kept so this trainer
                shares a constructor signature with sibling trainers; confirm.
        """
        sthsl_kwargs = {
            "device": device,
            "num_nodes": num_nodes,
            "dropout": dropout,
            "supports": supports,
            "in_dim": in_dim,
            # ``horizon`` is what the network receives as its output size.
            "out_dim": horizon,
            "seq_len": seq_len,
        }
        # Assign the model before the base initializer runs, then forward
        # that same attribute.
        self.model = sthsl(**sthsl_kwargs)
        super().__init__(self.model, scaler, lrate, wdecay, device)



