import torch
import copy
import torch.nn as nn



class Model(nn.Module):
    """Forecasting wrapper that delegates to a foundation-model backbone.

    The backbone is chosen by ``configs.model_type`` and imported lazily, so
    only the selected backbone's module is loaded. When ``configs.use_revin``
    is set, each input window is normalized by its own mean/variance before
    the backbone runs. The backbone output is then de-normalized with the same
    statistics (reversible instance normalization).
    """

    # Supported ``model_type`` values -> module that defines the backbone's
    # ``Model`` class. Insertion order is preserved in ``valid_model_types``.
    _MODEL_MODULES = {
        'moirai': 'models.MOIRAI',
        'timer': 'models.Timer',
        'timesfm': 'models.TimesFM',
        'timemoe': 'models.TimeMoE',
        'chronos': 'models.Chronos',
        'visionts': 'models.VisionTS',
        'toto': 'models.Toto',
    }

    def __init__(self, configs):
        """Build the wrapper and instantiate the selected backbone.

        Args:
            configs: Config object. The attributes ``seq_len``, ``pred_len``,
                ``enc_in``, ``cycle``, ``model_type``, ``d_model`` and
                ``use_revin`` are read here. A deep copy with
                ``enable_plot = False`` is passed to the backbone, so the
                caller's object is not mutated.

        Raises:
            ValueError: if ``configs.model_type`` is not a supported backbone.
        """
        super().__init__()
        import importlib  # local import keeps this block self-contained

        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        self.cycle_len = configs.cycle
        self.model_type = configs.model_type
        self.d_model = configs.d_model
        self.use_revin = configs.use_revin

        # Kept as a public list for backward compatibility with callers that
        # read it.
        self.valid_model_types = list(self._MODEL_MODULES)
        # An explicit raise (not `assert`) so the check survives `python -O`;
        # otherwise `self.model` would silently never be set.
        if self.model_type not in self._MODEL_MODULES:
            raise ValueError(
                f"Model type {self.model_type} not supported. "
                f"Choose from {self.valid_model_types}"
            )

        backbone_cls = importlib.import_module(
            self._MODEL_MODULES[self.model_type]
        ).Model
        # Copy so disabling plotting for the backbone does not leak back into
        # the caller's config.
        backbone_config = copy.deepcopy(configs)
        backbone_config.enable_plot = False
        self.model = backbone_cls(backbone_config)

    def forward(self, x, cycle_index, x_mark=None, dec_inp=None, y_mark=None):
        """Run the backbone on ``x``, with optional instance normalization.

        Args:
            x: Rank-3 input tensor. Normalization statistics are taken over
                dim 1. It is presumably (batch, seq_len, channels), since the
                original unpacked it as ``B, _, C``; confirm with callers.
            cycle_index, x_mark, dec_inp, y_mark: Accepted but unused here,
                presumably to match the call signature of other models in the
                training loop.

        Returns:
            The backbone's output. When ``use_revin`` is set, it must
            broadcast against the (x.shape[0], 1, x.shape[2]) statistics.

        Raises:
            ValueError: if ``x`` is not rank 3.
        """
        if x.dim() != 3:
            raise ValueError(f"expected a 3-D input tensor, got shape {tuple(x.shape)}")

        if self.use_revin:
            seq_mean = torch.mean(x, dim=1, keepdim=True)
            # NOTE: torch.var defaults to the unbiased estimator, so a
            # length-1 window along dim 1 yields NaN here.
            seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
            x = (x - seq_mean) / torch.sqrt(seq_var)

        # Every supported backbone consumes x in its original layout.
        # `__init__` guarantees model_type is supported, so no other path
        # exists.
        y = self.model(x)

        if self.use_revin:
            y = y * torch.sqrt(seq_var) + seq_mean
        return y
    