# Public API of this module: the only class defined here is ``Model``.
# (The previous entry, 'SPM_Net', named nothing in this file, so a
# star-import of the module raised AttributeError.)
__all__ = ['Model']
from torch import nn
from layers.SPM_backbone import SPM_backbone



class Model(nn.Module):
    """Thin wrapper that builds an ``SPM_backbone`` from a config object.

    Hyper-parameters are pulled from ``configs`` at construction time.
    ``forward`` swaps the last two axes of the input before calling the
    backbone and swaps them back on the backbone's first output.
    """

    def __init__(self, configs):
        """Create the wrapped backbone from the attributes of ``configs``.

        Args:
            configs: object exposing ``enc_in``, ``seq_len``, ``pred_len``,
                ``revin``, ``affine``, ``subtract_last`` and the memory
                settings ``ep_topk``, ``hard_num``, ``ep_mem_num``,
                ``mem_num``, ``gamma`` and ``substitution``.
        """
        super().__init__()

        # Map config attribute names onto SPM_backbone keyword arguments.
        # Dict literals evaluate left to right, so attributes are read in
        # the order listed here.
        backbone_kwargs = {
            'c_in': configs.enc_in,
            'context_window': configs.seq_len,
            'target_window': configs.pred_len,
            'revin': configs.revin,
            'affine': configs.affine,
            'subtract_last': configs.subtract_last,
            # memory-related settings
            'ep_topk': configs.ep_topk,
            'num_hard_example': configs.hard_num,
            'ep_mem_num': configs.ep_mem_num,
            'mem_num': configs.mem_num,
            'gamma': configs.gamma,
            'substitution': configs.substitution,
        }
        self.model = SPM_backbone(**backbone_kwargs)

    def forward(self, x, y, flag='not train'):
        """Run the backbone on ``x`` with its last two axes swapped.

        Args:
            x: 3-D tensor laid out as [Batch, Input length, Channel]; it is
                permuted to [Batch, Channel, Input length] for the backbone.
            y: forwarded to the backbone unchanged (presumably the target
                window -- TODO confirm against SPM_backbone).
            flag: mode string forwarded to the backbone, default
                ``'not train'`` (presumably ``'train'`` enables
                training-only behaviour -- TODO confirm).

        Returns:
            Tuple ``(out, query, pos, neg)``: the backbone's first output
            with its last two axes swapped back, followed by the backbone's
            remaining three outputs, unchanged.
        """
        channels_first = x.permute(0, 2, 1)
        backbone_out, query, pos, neg = self.model(channels_first, y, flag)
        return backbone_out.permute(0, 2, 1), query, pos, neg
