import torch
import torch.nn as nn
import einops
import copy
from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline


class AverageSeq2Seq(nn.Module):
    """Naive baseline that forecasts the mean of each input series.

    The last dimension of ``x`` is treated as time; every leading dimension
    (e.g. batch, channel) is preserved in the output.
    """

    def __init__(self, seq_len: int, pred_len: int):
        super(AverageSeq2Seq, self).__init__()
        self.seq_len = seq_len  # stored for interface parity; forward does not use it
        self.pred_len = pred_len

    def forward(self, x: torch.FloatTensor):
        """Return the mean of ``x`` over its last axis, repeated ``pred_len`` times.

        Args:
            x: tensor of shape ``(..., seq_len)``.

        Returns:
            Tensor of shape ``(..., pred_len)`` (an expanded view, not a copy).
        """
        # keepdim + expand over all leading dims works for any input rank; the
        # previous hard-coded 3-D expand turned a (B, L) input into (B, B, pred_len).
        return x.mean(dim=-1, keepdim=True).expand(*x.shape[:-1], self.pred_len)


class MA(nn.Module):
    """Moving-average smoother along the time axis with edge replication.

    ``forward`` expects ``x`` of shape ``(batch, time, channels)``: time is
    sliced on dim 1 for padding and moved to the last dim for ``AvgPool1d``.
    With ``stride=1`` the output has the same shape as the input.
    """

    def __init__(self, kernel_size, stride):
        super(MA, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # A window of size k needs k - 1 extra steps for a length-preserving
        # output. Split them between both ends by replicating the first/last
        # step; for even k the extra step goes to the end. The previous split
        # (k//2 - 1, k//2) lost one step for odd k and went negative for k == 1.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)


class Model(nn.Module):
    """Decomposition forecaster: periodic seasonal + foundation-model trend.

    The (optionally normalised) input is split into a moving-average trend and
    a seasonal residual. The seasonal residual is extrapolated by tiling an
    exponentially weighted average cycle. The trend is down-sampled by
    ``sample_rate``, forecast by the backend selected via
    ``configs.model_type``, and interpolated back to full resolution with a
    natural cubic spline.
    """

    # model_type -> module exposing a ``Model`` class. Modules are imported
    # lazily so only the selected backend's dependencies must be installed.
    _BACKEND_MODULES = {
        'moirai': 'models.MOIRAI',
        'timer': 'models.Timer',
        'timesfm': 'models.TimesFM',
        'timemoe': 'models.TimeMoE',
        'chronos': 'models.Chronos',
        'visionts': 'models.VisionTS',
        'toto': 'models.Toto',
    }

    def reshape_reconstruct(self, x):
        """Extrapolate a periodic signal by tiling its weighted average cycle.

        Args:
            x: tensor ``(batch, channels, length)`` with ``length`` a multiple
               of ``self.cycle``.

        Returns:
            Tensor ``(batch, channels, pred_len)``.
        """
        batch_size, n_channel, _ = x.shape
        x_period = x.reshape(batch_size, n_channel, -1, self.cycle)
        n_period = x_period.shape[2]
        # Exponential decay over cycles: the most recent cycle gets weight 0.9,
        # the one before 0.81, and so on.
        weight = 0.9 ** torch.arange(n_period, 0, -1, device=x.device).float()
        # Rescale so the mean over the period axis is a normalised weighted average.
        weight = (weight / weight.sum() * n_period).view(1, 1, -1, 1)
        x_period = torch.mean(x_period * weight, 2)
        pred = einops.repeat(x_period, 'B C point_per_period -> B C (num_period point_per_period)', num_period=self.pred_len // self.cycle + 1)
        return pred[:, :, :self.pred_len]

    def __init__(self, configs):
        super(Model, self).__init__()
        import importlib

        self.model_type = configs.model_type
        self.cycle = configs.cycle
        # The trend is forecast at ~4 points per cycle. Clamp to 1 so cycles
        # shorter than 4 do not yield a zero slicing step.
        self.sample_rate = max(1, self.cycle // 4)
        self.seq_len: int = configs.seq_len
        self.label_len: int = configs.label_len
        self.pred_len: int = configs.pred_len
        self.d_model: int = configs.d_model
        self.use_revin = configs.use_revin
        self.hist_trend_len = self.seq_len // self.sample_rate
        # Two extra knots so the spline support extends past the horizon.
        self.infer_trend_len = self.pred_len // self.sample_rate + 2
        assert self.seq_len % self.cycle == 0, f'seq_len: {self.seq_len}, cycle: {self.cycle}'

        self.trend_extractor = MA(kernel_size=self.cycle // 2, stride=1)

        self.valid_model_types = list(self._BACKEND_MODULES)
        assert self.model_type in self.valid_model_types, f"Model type {self.model_type} not supported. Choose from {self.valid_model_types}"

        # Every backend sees the down-sampled trend, so override its lengths.
        backend_cls = importlib.import_module(self._BACKEND_MODULES[self.model_type]).Model
        backend_config = copy.deepcopy(configs)
        backend_config.seq_len = self.hist_trend_len
        backend_config.pred_len = self.infer_trend_len
        backend_config.enable_plot = False
        self.model = backend_cls(backend_config)

    def forecast(self, x, cycle_len):
        """Forecast ``pred_len`` future steps, preceded by ``label_len`` fitted steps.

        Args:
            x: history, assumed ``(batch_size, seq_len, n_channels)`` (time on
               dim 1, as used by the trend extractor) — TODO confirm.
            cycle_len: unused; kept for interface compatibility.

        Returns:
            Tensor ``(batch_size, label_len + pred_len, n_channels)``.

        Raises:
            ValueError: if ``label_len > seq_len`` (no history to fit).
        """
        if self.label_len > self.seq_len:
            raise ValueError(f'label_len ({self.label_len}) must not exceed seq_len ({self.seq_len})')

        x_index = torch.arange(-self.seq_len, 0, dtype=x.dtype, device=x.device)
        y_index = torch.arange(-self.label_len, self.pred_len, dtype=x.dtype, device=x.device)

        if self.use_revin:
            seq_mean = torch.mean(x, dim=1, keepdim=True)
            seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
            x = (x - seq_mean) / torch.sqrt(seq_var)

        trend = self.trend_extractor(x)
        seasonal_hist = x - trend
        seasonal_pred = self.reshape_reconstruct(seasonal_hist.permute(0, 2, 1)).permute(0, 2, 1)  # B, L, C

        # Sample so the last kept point is at index -sample_rate. The forecast
        # knots 0, sample_rate, ... then keep the same spacing, and exactly
        # hist_trend_len points reach the backend. The offset is 0 when
        # seq_len is divisible by sample_rate.
        offset = self.seq_len % self.sample_rate
        trend = trend[:, offset::self.sample_rate]
        y = self.model(trend)
        x_support = torch.cat((trend, y), 1)
        pred_knots = torch.arange(0, self.infer_trend_len, dtype=x.dtype, device=x.device) * self.sample_rate
        index_support = torch.cat((x_index[offset::self.sample_rate], pred_knots), -1)
        spline = NaturalCubicSpline(natural_cubic_spline_coeffs(index_support, x_support))
        y = spline.evaluate(y_index)

        # y spans label_len + pred_len steps. Pair the label (historical) part
        # with the observed seasonal residual and the future part with the
        # extrapolated one. For label_len == 0 the prefix slice is empty.
        seasonal_label = seasonal_hist[:, self.seq_len - self.label_len:]
        y = y + torch.cat((seasonal_label, seasonal_pred), 1)

        if self.use_revin:
            y = y * torch.sqrt(seq_var) + seq_mean

        return y

    def forward(self, x, cycle):
        """Alias for :meth:`forecast`; ``cycle`` is forwarded but unused."""
        return self.forecast(x, cycle)

