import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Residual two-layer feed-forward block.

    Computes ``x + fc2(dropout(relu(fc1(x))))``, so the input and output
    widths are both ``d_model`` and the hidden width is ``configs.d_ff``.

    Args:
        configs: object providing ``dropout`` and ``d_ff`` (and ``d_model``
            when the ``d_model`` argument is omitted).
        d_model: input/output feature width; falls back to
            ``configs.d_model`` when ``None``.
    """

    def __init__(self, configs, d_model=None):
        super().__init__()
        self.drop = nn.Dropout(p=configs.dropout)
        width = configs.d_model if d_model is None else d_model
        self.act = F.relu
        # Expand to the hidden width, then project back so the residual add
        # in ``forward`` lines up.
        self.fc1 = nn.Linear(width, configs.d_ff)
        self.fc2 = nn.Linear(configs.d_ff, width)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.drop(hidden)
        return x + self.fc2(hidden)


class Model(nn.Module):
    """Per-channel linear forecaster with optional timestamp/channel embeddings.

    Each channel's look-back window of length ``seq_len`` is projected to
    ``d_model`` features by ``self.encoder``.  Optionally, a learned timestamp
    embedding (``t_dim`` features, shared by all channels of a sample) and a
    learned per-channel embedding (``c_dim`` features) are concatenated.  The
    result passes through ``t_layers`` residual ``MLP`` blocks (an identity
    when ``t_layers < 1``), and ``self.predictor`` maps it to ``pred_len``
    values per channel.

    Timestamp embeddings are selected by ``configs.time_type``:

    * contains ``'rlt'``: a time-of-day table with ``steps_per_day`` rows,
      plus a 7-row day-of-week table that is added only when ``time_type``
      also contains ``'week'`` (e.g. ``'rlt_week'``);
    * otherwise: calendar tables for minute, hour, day of week, day of month
      and month.
    """

    def __init__(self, configs, revin=None):
        super().__init__()

        # An explicit ``revin`` argument overrides ``configs.revin``.
        self.revin = configs.revin if revin is None else revin

        self.c_in = configs.enc_in
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        self.time_stamp = configs.time_stamp
        self.channel_mark = configs.channel_mark
        self.steps_per_day = configs.steps_per_day
        self.minute = configs.minute

        self.c_dim = configs.c_dim
        self.t_dim = configs.t_dim
        self.month = configs.month
        self.time_type = configs.time_type
        self.encoder = nn.Linear(self.seq_len, configs.d_model)

        # Only the tables for the configured timestamp flavour are created,
        # so the dispatch in ``index_fusion`` must use the same test
        # (``_uses_rlt``) or it will reach for tables that do not exist.
        if self._uses_rlt():
            if self.steps_per_day is None:
                raise ValueError("steps_per_day must be set when time_type contains 'rlt'")
            self.rlt_m_emb = nn.Parameter(torch.zeros(self.steps_per_day, configs.t_dim))
            self.rlt_w_emb = nn.Parameter(torch.zeros(7, configs.t_dim))
        else:
            # One row per ``self.minute``-minute slot within an hour.
            self.mi_emb = nn.Parameter(torch.zeros(60 // self.minute, configs.t_dim))
            self.hr_emb = nn.Parameter(torch.zeros(24, configs.t_dim))
            self.dw_emb = nn.Parameter(torch.zeros(7, configs.t_dim))
            self.dm_emb = nn.Parameter(torch.zeros(31, configs.t_dim))
            self.mo_emb = nn.Parameter(torch.zeros(12, configs.t_dim))
        self.noc_emb = nn.Parameter(torch.zeros(self.c_in, self.c_dim))

        # Width of the fused per-channel feature vector fed to the decoder.
        d_model = configs.d_model
        if self.channel_mark:
            d_model += self.c_dim
        if self.time_stamp:
            d_model += self.t_dim

        if configs.t_layers >= 1:
            self.decoder = nn.Sequential(*[
                MLP(configs, d_model=d_model)
                for _ in range(configs.t_layers)
            ])
        else:
            self.decoder = nn.Identity()
        self.predictor = nn.Linear(d_model, self.pred_len, bias=False)

    def _uses_rlt(self):
        """Return True when ``time_type`` selects the 'rlt' embedding tables."""
        return 'rlt' in self.time_type

    def forecast(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None):
        """Forecast ``pred_len`` steps for every channel.

        Args:
            x_enc: input window of shape ``(b, seq_len, c)``.
            x_mark_enc: time features of shape ``(b, seq_len, n_marks)``; only
                time step 0 is read.  Defaults to an all-zero ``(b, t, 6)``
                tensor.
            x_dec, x_mark_dec: unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(b, pred_len, c)``.
        """
        b, t, _ = x_enc.shape

        if x_mark_enc is None:
            x_mark_enc = torch.zeros((b, t, 6), device=x_enc.device)

        if self.revin:
            # Per-sample, per-channel normalisation over the time axis.
            # NOTE(review): unbiased std is NaN when t == 1.
            mean = x_enc.mean(1, keepdim=True).detach()
            scale = x_enc.std(1, keepdim=True).detach() + 1e-5
            x_enc = (x_enc - mean) / scale

        out = self.index_fusion(x_enc.transpose(-1, -2), x_mark_enc)
        out = out.transpose(-1, -2)

        if self.revin:
            # Undo the normalisation with the *same* scale it used, so
            # de-normalisation is the exact inverse (including the epsilon).
            out = out * scale + mean
        return out

    def index_fusion(self, x_enc, x_mark_enc):
        """Encode, fuse the optional embeddings, decode and predict.

        Args:
            x_enc: channel-major window of shape ``(b, c, seq_len)``.
            x_mark_enc: time features; only time step 0 is read.

        Returns:
            Tensor of shape ``(b, c, pred_len)``.
        """
        b, c, _ = x_enc.shape
        parts = [self.encoder(x_enc)]

        if self.time_stamp:
            if self._uses_rlt():
                time_stamp = self.rlt_time_stamp(x_enc, x_mark_enc)
            else:
                time_stamp = self.abs_time_stamp(x_enc, x_mark_enc)
            # Same timestamp vector for every channel; expand avoids a copy.
            parts.append(time_stamp.unsqueeze(1).expand(-1, c, -1))

        if self.channel_mark:
            # Same channel table for every sample in the batch.
            parts.append(self.noc_emb.unsqueeze(0).expand(b, -1, -1))

        t_enc = torch.cat(parts, dim=-1)
        t_enc = self.decoder(t_enc)
        return self.predictor(t_enc)

    def rlt_time_stamp(self, x_enc, x_mark_enc):
        """Look up the 'rlt' timestamp embedding for step 0 of each window.

        Column ``-2`` of ``x_mark_enc`` is scaled by ``steps_per_day`` and
        column ``-1`` by 7, then truncated to integer row indices.  Presumably
        both hold fractions in ``[0, 1)`` (time of day / day of week) --
        TODO confirm against the data loader; a value of exactly 1.0 would
        index past the end of the table.

        Args:
            x_enc: tensor whose ``device`` the indices are moved to.
            x_mark_enc: time features of shape ``(b, t, n_marks)``.

        Returns:
            Tensor of shape ``(b, t_dim)``.
        """
        first = x_mark_enc[:, 0]
        device = x_enc.device
        # ``.long()`` casts on the tensor's current device, whereas
        # ``.type(torch.LongTensor)`` would first copy to the CPU.
        m_idx = (first[:, -2] * self.steps_per_day).long().to(device)
        d_idx = (first[:, -1] * 7).long().to(device)
        time_stamp = self.rlt_m_emb[m_idx]
        if 'week' in self.time_type:
            time_stamp = time_stamp + self.rlt_w_emb[d_idx]
        return time_stamp

    def abs_time_stamp(self, x_enc, x_mark_enc):
        """Sum the calendar embeddings for step 0 of each window.

        Reads columns 1-5 of ``x_mark_enc`` as month (1-12), day of week
        (0-6), day of month (1-31), hour (0-23) and minute.  Month and
        day-of-month contributions are zeroed when ``self.month`` is truthy;
        the minute contribution is zeroed when ``self.minute == 60``.
        NOTE(review): ``mi_emb`` has ``60 // self.minute`` rows, so column 5
        presumably already holds a slot index rather than a raw 0-59 minute
        -- confirm against the data loader.

        Args:
            x_enc: tensor whose ``device`` the indices are moved to.
            x_mark_enc: time features of shape ``(b, t, n_marks)``.

        Returns:
            Tensor of shape ``(b, t_dim)``.
        """
        first = x_mark_enc[:, 0]
        device = x_enc.device
        month = (first[:, 1] - 1).long().to(device)  # 1-12 -> 0-11
        dow = first[:, 2].long().to(device)  # 0-6
        dom = (first[:, 3] - 1).long().to(device)  # 1-31 -> 0-30
        hour = first[:, 4].long().to(device)  # 0-23
        minute = first[:, 5].long().to(device)

        calendar_gate = 0. if self.month else 1.
        time_list = [
            self.mo_emb[month] * calendar_gate,
            self.dm_emb[dom] * calendar_gate,
            self.dw_emb[dow],
            self.hr_emb[hour],
            self.mi_emb[minute] * (0. if self.minute == 60 else 1.0),  # no sub-hour resolution
        ]
        return torch.stack(time_list, dim=1).sum(dim=1)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """Alias for :meth:`forecast`; ``mask`` is ignored."""
        return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
