import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt


def PeriodNorm(x, period_len=24):
    """Remove a trailing moving average of length ``period_len`` from ``x``.

    For every position along the last axis the mean of the current value and
    the ``period_len - 1`` values before it is computed.  The first
    ``period_len - 1`` positions have no complete window, so they reuse the
    mean of the first complete window (replicate padding on the left).

    Variance normalisation is currently disabled (only the mean is removed),
    so the third return value is the constant ``1``.

    Args:
        x: Tensor with 3 or 4 dimensions.  A 3-D tensor gets a temporary
            singleton axis inserted before the last axis.  The last axis is
            the one being averaged (presumably time -- confirm with callers).
        period_len: Window length of the moving average; must be at least 1
            and no larger than the size of the last axis.

    Returns:
        Tuple ``(out, mean, 1)`` where ``out = x - mean``.  Both tensors have
        the 4-D working shape with the second-to-last axis squeezed away
        whenever that axis has size 1 (this also applies to 4-D inputs whose
        second-to-last axis is a singleton).

    Raises:
        ValueError: If ``period_len < 1``, if the last axis is shorter than
            ``period_len``, or if ``x`` is neither 3-D nor 4-D.
    """
    if period_len < 1:
        raise ValueError(f"period_len must be >= 1, got {period_len}")

    if x.dim() == 3:
        x = x.unsqueeze(-2)
    b, c, n, t = x.shape

    if t < period_len:
        # Without this guard the window slices would use zero/negative end
        # indices, which wrap around and yield empty or ragged windows.
        raise ValueError(
            f"last dimension ({t}) must be at least period_len ({period_len})"
        )

    # One slice per offset inside the window, newest element first; all
    # slices share the same length, so stacking gives (b, c, n, windows, period).
    n_windows = t - period_len + 1
    windows = torch.stack(
        [x[..., start:start + n_windows] for start in range(period_len - 1, -1, -1)],
        dim=-1,
    )
    mean = windows.mean(-1)

    # Collapse to 3-D before padding: replicate padding of only the last axis
    # is applied here on a (b * c, n, windows) tensor, then the shape is restored.
    mean = F.pad(
        mean.reshape(b * c, n, -1), pad=(period_len - 1, 0), mode='replicate'
    ).reshape(b, c, n, -1)

    out = x - mean
    out, mean = out.squeeze(-2), mean.squeeze(-2)
    return out, mean, 1


class TSMixer(nn.Module):
    """Multi-head projection wrapper around a pluggable attention callable.

    Queries, keys and values are each passed through their own linear layer,
    split into ``n_heads`` heads, handed to ``attention``, and the merged
    result is passed through a final linear layer.

    Args:
        attention: Callable invoked as
            ``attention(q, k, v, res=res, attn=attn)`` with 4-D
            ``(B, len, n_heads, head_dim)`` tensors; it must return a pair
            ``(out, attn)`` where ``out`` can be reshaped to ``(B, L, -1)``.
        d_model: Input and output feature size of all four linear layers.
            It needs to be divisible by ``n_heads`` for the head split in
            ``forward`` to succeed.
        n_heads: Number of heads the projected features are split into.
    """

    def __init__(self, attention, d_model, n_heads):
        super().__init__()

        self.attention = attention
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.n_heads = n_heads

    def forward(self, q, k, v, res=False, attn=None):
        """Project the inputs, run the attention callable and merge the heads.

        Args:
            q: 3-D tensor ``(B, L, d_model)``.
            k: 3-D tensor ``(B, S, d_model)``.
            v: Tensor reshaped with the same ``B`` and ``S`` as ``k``, so it
                must hold ``B * S * d_model`` elements.
            res: Forwarded unchanged to the attention callable.
            attn: Forwarded unchanged to the attention callable.

        Returns:
            Tuple ``(output, attn)``: the output-projected tensor of shape
            ``(B, L, d_model)`` and whatever the attention callable returned
            as its second value.
        """
        B, L, _ = q.shape
        _, S, _ = k.shape
        H = self.n_heads

        # Split the feature axis into (heads, head_dim).
        q = self.q(q).view(B, L, H, -1)
        k = self.k(k).view(B, S, H, -1)
        v = self.v(v).view(B, S, H, -1)

        out, attn = self.attention(
            q, k, v,
            res=res, attn=attn
        )
        # reshape (not view): the attention callable is user supplied and may
        # return a non-contiguous tensor; for contiguous input this is a view.
        out = out.reshape(B, L, -1)

        return self.out(out), attn

class ResAttention(nn.Module):
    """Scaled dot-product attention over multi-head inputs.

    Args:
        attention_dropout: Dropout probability applied to the softmax
            attention weights.
        scale: Multiplier applied to the raw scores before the softmax.
            ``None`` selects ``1 / sqrt(head_dim)``.  Any other value is used
            as given, including ``0`` or a tensor that broadcasts against the
            ``(B, H, L, S)`` score tensor.
    """

    def __init__(self, attention_dropout=0.1, scale=None):
        super().__init__()
        self.scale = scale
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, res=False, attn=None):
        """Compute attention weights and the weighted sum of ``values``.

        Args:
            queries: Tensor of shape ``(B, L, H, E)``.
            keys: Tensor of shape ``(B, S, H, E)``.
            values: Tensor of shape ``(B, S, H, D)``.
            res: Accepted for interface compatibility; not used here.
            attn: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(V, A)``: ``V`` is a contiguous ``(B, L, H, D)`` tensor
            and ``A`` holds the ``(B, H, L, S)`` attention weights after
            dropout.
        """
        head_dim = queries.shape[-1]
        # Explicit None check so that a deliberate scale of 0 (or a tensor
        # scale, whose truth value may be ambiguous) is honoured.
        scale = self.scale if self.scale is not None else 1. / sqrt(head_dim)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        return V.contiguous(), A