from models import DLinear, LightTS, PatchTST, MICN, FiLM, iTransformer, \
    TiDE, FreTS, TimeMixer, TSMixer, SegRNN, TimeXer, LSTM, RLinear, FourierGNN, WPMixer, GRU, CrossGNN, \
    WITRAN, TGCN, DMT, GCN, GCNII, ResGCN, ResGAT, AGCLSTM, THGNN, CATS, DMCT, TQNet, CycleNet, CMCT, \
    TimeBridge, FilterNet, Amplifier, DFCT, SOFTS
from utils.constants_LamaH import Constants


class Exp_Basic(object):
    """Base class for experiment runners.

    Holds the run configuration, the constants loaded via ``Constants(args)``
    (station list, per-station channel information, distance adjacency
    matrix), and a registry mapping model names to model modules.

    Subclasses must override ``_build_model`` (the base version raises
    ``NotImplementedError``). They may override ``_get_data``, ``vali``,
    ``train`` and ``test``; the base versions do nothing and return ``None``.
    """

    def __init__(self, args, station, setting):
        """Initialise the shared experiment state.

        Args:
            args: Run configuration object. It is passed to ``Constants`` and
                must expose at least ``verbose``, ``device`` and ``run_type``
                attributes.
            station: The station this experiment targets; stored as
                ``self.target_station``.
            setting: Setting descriptor for this run; stored unchanged as
                ``self.setting``.
        """
        self.args = args
        self.constants = Constants(args)
        self.target_station = station
        self.setting = setting
        self.all_stations = self.constants.all_stations
        # Per-station channel information used to set the number of channels
        # (presumably channel counts keyed by station — confirm in Constants).
        self.station_channels_dict = self.constants.station_channels_dict
        self.distance_adj_matrix = self.constants.distance_adj_matrix
        self.verbose = self.args.verbose
        self.device = self.args.device
        self.run_type = self.args.run_type
        # Registry of selectable model modules, keyed by model name.
        self.model_dict = {
            'DLinear': DLinear,
            'LightTS': LightTS,
            'PatchTST': PatchTST,
            'MICN': MICN,
            'FiLM': FiLM,
            'iTransformer': iTransformer,
            'TiDE': TiDE,
            'FreTS': FreTS,
            'TimeMixer': TimeMixer,
            'TSMixer': TSMixer,
            'SegRNN': SegRNN,
            'TimeXer': TimeXer,
            'LSTM': LSTM,
            'RLinear': RLinear,
            'FourierGNN': FourierGNN,
            'WPMixer': WPMixer,
            'GRU': GRU,
            'CrossGNN': CrossGNN,
            'WITRAN': WITRAN,
            'TGCN': TGCN,
            'DMT': DMT,
            'GCN': GCN,
            'GCNII': GCNII,
            'ResGCN': ResGCN,
            'ResGAT': ResGAT,
            'AGCLSTM': AGCLSTM,
            'THGNN': THGNN,
            'CATS': CATS,
            'DMCT': DMCT,
            'TQNet': TQNet,
            'CycleNet': CycleNet,
            'CMCT': CMCT,
            'TimeBridge': TimeBridge,
            'FilterNet': FilterNet,
            'Amplifier': Amplifier,
            'DFCT': DFCT,
            'SOFTS': SOFTS,
        }
        # Names of graph-based models. Presumably subclasses check membership
        # here to decide whether to supply graph inputs such as
        # distance_adj_matrix — confirm against the subclasses.
        self.GNN_list = [
            'FourierGNN', 'CrossGNN', 'TGCN', 'GCN', 'GCNII', 'ResGCN', 'ResGAT', 'THGNN', 'AGCLSTM'
        ]
        # Names of cycle-aware models. Presumably they need extra cycle-index
        # inputs — confirm against the subclasses.
        self.cycle_model_list = ['CycleNet', 'TQNet', 'CMCT']

    def _build_model(self):
        """Construct and return the model for this experiment.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _build_model()"
        )

    def _get_data(self):
        """Hook for loading data; the base implementation does nothing."""

    def vali(self):
        """Hook for validation; the base implementation does nothing."""

    def train(self):
        """Hook for training; the base implementation does nothing."""

    def test(self):
        """Hook for testing; the base implementation does nothing."""
