import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from math import sqrt
from utils.masking import TriangularCausalMask, ProbMask
from reformer_pytorch import LSHSelfAttention
from einops import rearrange


class TSMixer(nn.Module):
    """Multi-head projection wrapper around a pluggable attention callable.

    The query, key and value inputs are each passed through their own linear
    layer, split into ``n_heads`` heads, handed to ``attention``, and the
    merged result is passed through a final linear layer.

    Args:
        attention: Callable invoked as ``attention(q, k, v, res=..., attn=...)``
            that returns an ``(output, attention)`` pair.
        d_model: Feature size of every linear projection (input and output).
        n_heads: Number of heads the projected features are split into.
    """

    def __init__(self, attention, d_model, n_heads):
        super().__init__()

        self.attention = attention
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.n_heads = n_heads

    def forward(self, q, k, v, res=False, attn=None):
        """Project, split into heads, attend, merge heads and project again.

        Args:
            q: Query tensor with three dimensions ``[B, L, d_model]``.
            k: Key tensor with three dimensions ``[B, S, d_model]``.
            v: Value tensor; reshaped with the key length ``S``.
            res: Forwarded unchanged to ``attention``.
            attn: Forwarded unchanged to ``attention``.

        Returns:
            Tuple of the output tensor ``[B, L, d_model]`` and whatever the
            attention callable returned as its second element.
        """
        heads = self.n_heads
        batch, q_len, _ = q.shape
        _, kv_len, _ = k.shape

        def split_heads(tensor, layer, length):
            # [batch, length, d_model] -> [batch, length, heads, d_model / heads]
            return layer(tensor).reshape(batch, length, heads, -1)

        q_heads = split_heads(q, self.q_proj, q_len)
        k_heads = split_heads(k, self.k_proj, kv_len)
        v_heads = split_heads(v, self.v_proj, kv_len)

        mixed, attn = self.attention(q_heads, k_heads, v_heads, res=res, attn=attn)

        # Merge the head dimension back into the feature dimension.
        merged = mixed.view(batch, q_len, -1)
        return self.out(merged), attn


class ResAttention(nn.Module):
    """Scaled dot-product attention over pre-split heads, with optional heat-map dumps.

    Args:
        attention_dropout: Dropout probability applied to the attention weights.
        scale: Multiplier applied to the raw scores before the softmax. A falsy
            value (``None`` or ``0``) falls back to ``1 / sqrt(E)``, where ``E``
            is the per-head query size.
        attn_map: When exactly ``True``, every forward call writes PNG heat
            maps of head 0 to disk (debugging aid; slow).
        nst: When exactly ``True`` the heat maps go to ``./time map/``,
            otherwise to ``./stable time map/``.
        map_batch_size: Number of groups the leading dimension of the attention
            tensor is split into when plotting (the first index of the file
            name). Defaults to 32, the value that used to be hard-coded.
    """

    def __init__(self, attention_dropout=0.1, scale=None, attn_map=False, nst=False,
                 map_batch_size=32):
        super(ResAttention, self).__init__()

        if map_batch_size < 1:
            raise ValueError(f"map_batch_size must be >= 1, got {map_batch_size}")

        self.nst = nst
        self.scale = scale
        self.attn_map = attn_map
        self.map_batch_size = map_batch_size
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, res=False, attn=None):
        """Compute softmax attention and apply it to ``values``.

        Args:
            queries: Tensor ``[B, L, H, E]``.
            keys: Tensor ``[B, S, H, E]``.
            values: Tensor ``[B, S, H, D]``.
            res: Accepted for interface compatibility; not used here.
            attn: Accepted for interface compatibility; not used here.

        Returns:
            Tuple of the contiguous output ``[B, L, H, D]`` and the attention
            weights after dropout ``[B, H, L, S]``.
        """
        _, _, _, E = queries.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        attn_map = torch.softmax(scale * scores, dim=-1)

        # Strict identity check on purpose: a truthy non-bool flag (e.g. the
        # int 1) has never enabled plotting and must not start doing so.
        if self.attn_map is True:
            self._save_heat_maps(attn_map)

        A = self.dropout(attn_map)
        V = torch.einsum("bhls,bshd->blhd", A, values)

        return V.contiguous(), A

    def _save_heat_maps(self, attn_map):
        """Write one PNG heat map of head 0 per (group, channel) pair."""
        import os  # only needed on this debugging path

        total, _, L, S = attn_map.shape
        if total % self.map_batch_size != 0:
            raise RuntimeError(
                f"cannot split leading dimension {total} of the attention map "
                f"into map_batch_size={self.map_batch_size} groups"
            )

        # Only head 0 is plotted, so move just that slice to the host, once.
        head0 = attn_map[:, 0].detach().cpu().numpy()
        heat_map = head0.reshape(self.map_batch_size, -1, L, S)

        out_dir = './time map' if self.nst is True else './stable time map'
        os.makedirs(out_dir, exist_ok=True)  # savefig does not create directories

        # Set the fonts before any figure exists so every figure is drawn
        # with the same settings (they used to be set after the first one).
        plt.rcParams['font.family'] = 'serif'
        plt.rcParams['font.serif'] = ['Times New Roman']

        for b in range(heat_map.shape[0]):
            for c in range(heat_map.shape[1]):
                plt.figure(figsize=(10, 8), dpi=200)
                try:
                    plt.imshow(heat_map[b, c], cmap='Reds', interpolation='nearest')
                    plt.colorbar()
                    plt.xlabel('Key Time Patch', fontsize=14)
                    plt.ylabel('Query Time Patch', fontsize=14)
                    plt.tight_layout()
                    plt.savefig(f'{out_dir}/{b}_sample_{c}_channel.png')
                finally:
                    # Always release the figure, even if drawing or saving failed.
                    plt.close()






# ----------------- CrossScaleCohesionAttention (enhanced version) -----------------
class CrossScaleCohesionAttention(nn.Module):
    """Channel-level attention block with optional cross-scale keys/values.

    For a 4-D input ``[B, C, P, D]`` the positions are averaged into one
    embedding per channel, attention is computed between channels, the result
    is broadcast back over the positions and added to the input, and a
    residual feed-forward network is applied. When both ``cross_k`` and
    ``cross_v`` are passed to ``forward``, they are used as keys/values
    instead of the current input, which gives a cross-attention between the
    current channels and another scale.

    For a 3-D input ``[B, P, D]`` only the residual feed-forward network is
    applied; no attention is computed.

    Args:
        mixer: Callable tried on the input first (see ``forward``). Its
            ``n_heads`` attribute, if present, sets the initial logit scale.
        d_model: Feature size ``D``.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability for the attention weights and inside the
            feed-forward network.
        activation: ``'relu'`` selects ReLU in the feed-forward network; any
            other value selects GELU.
        stable, enc_in, stable_len, use_spectral: Stored on the instance but
            not used by this class.
    """

    def __init__(self, mixer, d_model, d_ff, dropout=0.0,
                 activation='relu', stable=False, enc_in=None, stable_len=8,
                 use_spectral=False):
        super().__init__()
        self.mixer = mixer
        self.d_model = d_model
        self.d_ff = d_ff
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.stable = stable
        self.enc_in = enc_in
        self.stable_len = stable_len
        self.use_spectral = use_spectral

        # Channel-level projections: Q from the current input, K/V from the
        # current input or from the cross-scale tensors.
        self.to_q = nn.Linear(d_model, d_model, bias=False)
        self.to_k = nn.Linear(d_model, d_model, bias=False)
        self.to_v = nn.Linear(d_model, d_model, bias=False)

        # Small bottleneck applied to the channel embedding before Q/K/V.
        self.bottleneck = nn.Sequential(
            nn.Linear(d_model, max(32, d_model // 4)),
            nn.GELU(),
            nn.Linear(max(32, d_model // 4), d_model),
        )

        # Feed-forward network applied to every position after aggregation.
        act = nn.ReLU if activation == 'relu' else nn.GELU
        self.ffn = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_ff),
            act(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

        self.proj_out = nn.Identity()
        # Learnable logit scale, initialised to 1 / sqrt(n_heads) of the mixer
        # (1.0 when the mixer has no ``n_heads`` attribute).
        self.scale = nn.Parameter(torch.tensor(1.0 / sqrt(max(1, getattr(self.mixer, 'n_heads', 1)))))

    def forward(self, x, *args, **kwargs):
        """Apply the block.

        Args:
            x: Input tensor, normally ``[B, C, P, D]`` or ``[B, P, D]``.
            *args: Ignored.
            **kwargs: ``cross_k`` and ``cross_v`` (optional). Each may be
                ``[B, C2, D]`` or ``[B, C2, P2, D]``; a 4-D tensor is averaged
                over its position dimension. They are only used for 4-D
                inputs and only when BOTH are given; otherwise the block
                attends over its own channels.

        Returns:
            A tensor with the same layout as the (mixed) input. Tensors that
            are neither 3-D nor 4-D are returned unchanged.
        """
        # Best-effort mixing: if calling the mixer with the single input fails
        # for any reason, the input is used unchanged.
        # NOTE(review): a mixer whose forward needs more than one positional
        # argument always raises here and is therefore silently bypassed.
        try:
            z = self.mixer(x)
        except Exception:
            z = x

        if not torch.is_tensor(z):
            z = torch.as_tensor(z)

        if z.dim() == 4:
            ch_out = self._channel_attention(
                z, kwargs.get('cross_k', None), kwargs.get('cross_v', None)
            )  # [B, C, D]
            # Broadcast the per-channel result over the position dimension.
            out = z + ch_out.unsqueeze(2)
        elif z.dim() == 3:
            out = z
        else:
            return z

        # Every layer of the FFN acts on the last dimension only, so it is
        # applied directly. Flattening with ``.view`` is unnecessary and
        # raised on non-contiguous tensors (e.g. a transposed input).
        out = out + self.ffn(out)
        return self.proj_out(out)

    def _channel_attention(self, z, cross_k=None, cross_v=None):
        """Return the channel attention output ``[B, C, D]`` for ``z`` of shape ``[B, C, P, D]``."""
        # One embedding per channel: average over positions, then bottleneck.
        chan_emb = self.bottleneck(z.mean(dim=2))  # [B, C, D]
        Q = self.to_q(chan_emb)  # [B, C, D]

        if cross_k is not None and cross_v is not None:
            # Cross-scale keys/values; collapse a position dimension if present.
            K = cross_k.mean(dim=2) if cross_k.dim() == 4 else cross_k  # [B, C2, D]
            V = cross_v.mean(dim=2) if cross_v.dim() == 4 else cross_v  # [B, C2, D]
        else:
            # Fallback: attention between the channels of the current input.
            K = V = chan_emb

        Kp = self.to_k(K)
        Vp = self.to_v(V)
        attn_logits = torch.matmul(Q, Kp.transpose(-1, -2)) * self.scale  # [B, C, C2]
        attn = self.dropout(torch.softmax(attn_logits, dim=-1))
        return torch.matmul(attn, Vp)  # [B, C, D]
