from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
    Informer, LightTS, Pyraformer, PatchTST, MICN, FiLM, iTransformer, \
    Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, TemporalFusionTransformer, SCINet, TimeXer, \
    Linear, LSTM, RLinear, FourierGNN, StemGNN, GWNet, WPMixer, GRU, CrossGNN, GTS, FlowNet, \
    WITRAN, RLinear_sub, STGCN, FlowNet_i, TGCN, FlowNet_D, DMT, GCN, GCNII, ResGCN, ResGAT, AGCLSTM, AGCLSTM_revin, \
    CATS, CycleNet, TQNet, FilterNet, SOFTS, DMCT
from utils.constants import Constants


class Exp_Basic(object):
    """Base class for experiment runners.

    Stores the run configuration, station metadata obtained from
    ``Constants(args)``, and a registry (``model_dict``) mapping model-name
    strings to the model entries imported from the ``models`` package.

    Subclasses are expected to implement ``_build_model`` (which raises
    ``NotImplementedError`` here) and to override the ``_get_data``,
    ``vali``, ``train`` and ``test`` hooks, which are no-ops in this class.
    """

    def __init__(self, args, phase, verbose=True, device='cpu'):
        """Initialise state shared by all experiments.

        Args:
            args: Run configuration object. It must expose ``data_path``,
                which is read directly here. It is also passed to
                ``Constants``; whatever other fields that reads is not
                visible in this file.
            phase: Stored unchanged as ``self.phase``. It is presumably a
                stage label such as train/test; confirm against callers.
            verbose: Stored as ``self.verbose``.
            device: Stored as ``self.device``; defaults to ``'cpu'``.
        """
        self.args = args
        self.constants = Constants(args)
        # Target station = the text before the first '.csv' in data_path.
        # NOTE(review): assumes data_path looks like '<station>.csv'; any
        # directory prefix would be kept as part of the station name.
        self.target_station = self.args.data_path.split('.csv')[0]
        self.all_stations = self.constants.all_stations
        # Number of channels per station (structure defined by Constants).
        self.station_channels_dict = self.constants.station_channels_dict

        self.distance_adj_matrix = self.constants.distance_adj_matrix
        self.verbose = verbose
        self.device = device
        self.phase = phase
        # Registry: model-name string -> model entry from the `models` package.
        # Every model imported at the top of the file is registered here.
        self.model_dict = {
            'TimesNet': TimesNet,
            'Autoformer': Autoformer,
            'Transformer': Transformer,
            'Nonstationary_Transformer': Nonstationary_Transformer,
            'DLinear': DLinear,
            'FEDformer': FEDformer,
            'Informer': Informer,
            'LightTS': LightTS,
            'PatchTST': PatchTST,
            'Pyraformer': Pyraformer,
            'MICN': MICN,
            'FiLM': FiLM,
            'iTransformer': iTransformer,
            'Koopa': Koopa,
            'TiDE': TiDE,
            'FreTS': FreTS,
            'TimeMixer': TimeMixer,
            'TSMixer': TSMixer,
            'SegRNN': SegRNN,
            'TemporalFusionTransformer': TemporalFusionTransformer,
            'SCINet': SCINet,
            'TimeXer': TimeXer,
            'Linear': Linear,
            'LSTM': LSTM,
            'RLinear': RLinear,
            'FourierGNN': FourierGNN,
            'StemGNN': StemGNN,
            'GWNet': GWNet,
            'WPMixer': WPMixer,
            'GRU': GRU,
            'CrossGNN': CrossGNN,
            'GTS': GTS,
            'FlowNet': FlowNet,
            'WITRAN': WITRAN,
            'RLinear_sub': RLinear_sub,
            'STGCN': STGCN,
            'FlowNet_i': FlowNet_i,
            'TGCN': TGCN,
            'FlowNet_D': FlowNet_D,
            'DMT': DMT,
            'GCN': GCN,
            'GCNII': GCNII,
            'ResGCN': ResGCN,
            'ResGAT': ResGAT,
            'AGCLSTM': AGCLSTM,
            # Previously imported but missing from the registry.
            'AGCLSTM_revin': AGCLSTM_revin,
            'CATS': CATS,
            'CycleNet': CycleNet,
            'TQNet': TQNet,
            'FilterNet': FilterNet,
            'SOFTS': SOFTS,
            'DMCT': DMCT,
        }
        # Names of models that subclasses presumably treat specially (e.g.
        # extra cycle-related inputs). Confirm at the use sites.
        self.cycle_model_list = ['CycleNet', 'TQNet']

    def _build_model(self):
        """Construct and return the model; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _get_data(self):
        """Hook for loading data; a no-op returning None in the base class."""

    def vali(self):
        """Hook for validation; a no-op returning None in the base class."""

    def train(self):
        """Hook for training; a no-op returning None in the base class."""

    def test(self):
        """Hook for testing; a no-op returning None in the base class."""
