import torch
import torch.nn as nn


from easytsf.layer.BnKanLayer import BnKANLayer
from easytsf.layer.Bn2KanLayer import Bn2KANLayer

class KANInterface(nn.Module):
    """Uniform wrapper that selects a linear or KAN-style layer by name.

    The chosen layer is stored as ``self.transform``. ``forward`` flattens all
    leading dimensions of its input so the wrapped layer always receives a
    2-D tensor of shape ``(M, in_features)``, then restores the leading
    dimensions on the output (except in ``'moe'`` mode, where the input is
    forwarded unchanged).
    """

    def __init__(self, in_features, out_features, layer_type, n_grid, degree, order, n_center, Bn=3):
        """Build the wrapped layer.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            layer_type: One of ``"Linear"``, ``"BnKAN"`` or ``"Bn2KAN"``.
            n_grid: Forwarded to the KAN layers as their ``num`` argument.
            degree: Accepted for interface compatibility; not used here.
            order: Accepted for interface compatibility; not used here.
            n_center: Accepted for interface compatibility; not used here.
            Bn: Forwarded to the KAN layers as ``Bn`` (presumably the
                Bernstein polynomial degree -- TODO confirm in BnKANLayer).

        Raises:
            NotImplementedError: If ``layer_type`` is not one of the above.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        if layer_type == "Linear":
            print("Using Linear")
            self.transform = nn.Linear(in_features, out_features, bias=True)
        elif layer_type == "BnKAN":  # Bernstein KAN
            print("Using BnKAN")
            self.transform = BnKANLayer(in_features, out_features, num=n_grid, Bn=Bn)
        elif layer_type == "Bn2KAN":  # Bernstein KAN (second variant)
            print("Using Bn2KAN")
            self.transform = Bn2KANLayer(in_features, out_features, num=n_grid, Bn=Bn)
        else:
            raise NotImplementedError(f"Layer type {layer_type} not implemented")

    def forward(self, x, mode=None):
        """Apply the wrapped layer along the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_features``.
            mode: If ``'moe'``, ``x`` is passed to the layer unchanged and
                the caller is responsible for its shape.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_features``.
        """
        if mode == 'moe':
            return self.transform(x)
        if x.dim() <= 2:
            # Already (M, in_features): nothing to flatten or restore.
            return self.transform(x)
        # Collapse every leading dim into one batch axis so the wrapped layer
        # only sees 2-D input, then restore them afterwards.
        lead_shape = x.shape[:-1]
        flat = x.flatten(0, -2)
        out = self.transform(flat)
        return out.reshape(*lead_shape, self.out_features)
