import torch
import torch.nn as nn
import numpy as np

from .graph.graph import Graph
from .graph.tools import k_adjacency, normalize_adjacency_matrix, get_adjacency_matrix

import math
import torch.nn.functional as F

class SegmentSELayer(nn.Module):
    """Squeeze-and-excitation style gating over the node axis, one gate per time segment.

    The time axis of an ``[N, C, T, V]`` input is split into ``M`` equal segments
    (resampling T up to a multiple of ``M`` when needed). For every segment the
    features are averaged over channels and frames, pushed through a bottleneck MLP
    on the ``V`` axis, and the resulting gate is applied residually. The result is
    finally smoothed with a 3-frame moving average along time.

    Args:
        M: Number of temporal segments.
        V: Size of the last (node) axis the layer expects.
        reduction: Bottleneck ratio of the gating MLP (hidden width ``V // reduction``,
            but never less than 1).
    """

    def __init__(self, M: int, V: int, reduction: int = 16):
        super().__init__()
        self.M = M
        self.V = V
        self.reduction = reduction
        # Bottleneck MLP on the V axis: V -> hidden -> V.
        # The hidden width is floored at 1: with V < reduction, ``V // reduction`` is 0,
        # which would create a zero-width layer and a constant (input-independent) gate.
        hidden = max(1, V // reduction)
        self.fc = nn.Sequential(
            nn.Linear(V, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, V, bias=False),
            nn.Sigmoid(),
        )

    @staticmethod
    def _resample_time(x: torch.Tensor, length: int) -> torch.Tensor:
        """Bilinearly resample the time axis of an [N, C, T, V] tensor to ``length`` frames."""
        V = x.shape[3]
        # interpolate works on the trailing two axes, so move time last and keep V fixed.
        x = F.interpolate(x.permute(0, 1, 3, 2), size=(V, length),
                          mode='bilinear', align_corners=False)
        return x.permute(0, 1, 3, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the segment-wise gate and temporal smoothing.

        Args:
            x: Tensor of shape ``[N, C, T, V]`` with ``V == self.V``.

        Returns:
            Tensor of shape ``[N, C, T, V]``.
        """
        N, C, T, V = x.shape
        assert V == self.V, f"Expected V={self.V}, got {V}"

        seg_len = math.ceil(T / self.M)
        new_T = seg_len * self.M
        resampled = new_T != T

        if resampled:
            # Stretch time so it divides evenly into M segments.
            x = self._resample_time(x, new_T)

        # reshape (rather than view) so non-contiguous inputs are accepted as well.
        x_seg = x.reshape(N, C, self.M, seg_len, V)

        # Squeeze: average over channels, then over the frames of each segment -> [N, M, V].
        z = x_seg.mean(dim=1).mean(dim=2)

        # Excite: one gate vector per (sample, segment), broadcast over C and seg_len.
        alpha = self.fc(z.reshape(-1, V)).view(N, 1, self.M, 1, V)

        # The gate is clamped to [0, 0.1] and scaled by 0.1, i.e. every element is
        # multiplied by a factor in [1, 1.01].
        alpha = torch.clamp(alpha, 0, 0.1)
        x_se = (x_seg + x_seg * alpha * 0.1).reshape(N, C, new_T, V)

        if resampled:
            # Return to the caller's original number of frames.
            x_se = self._resample_time(x_se, T)

        if T < 2:
            # Reflect padding by 1 needs at least 2 frames; with fewer there is
            # nothing to smooth, so return the gated features directly.
            return x_se

        # 3-frame moving average along time with reflected borders
        # (pad=(0, 0, 1, 1) pads V by 0 and T by 1 on each side).
        padded = F.pad(x_se, pad=(0, 0, 1, 1), mode='reflect')
        return (padded[:, :, :-2, :] + padded[:, :, 1:-1, :] + padded[:, :, 2:, :]) / 3

class MultiScale_GraphConv_SE(nn.Module):
    """Multi-scale graph aggregation followed by segment-wise SE gating and a 1x1 MLP.

    Args:
        num_scales: Number of adjacency scales that are stacked.
        in_channels: Channels of the input features.
        out_channels: Channels produced by the final MLP.
        dataset: Layout name forwarded to ``Graph``; ``'LARA'`` selects 19 nodes,
            anything else 25.
        M: Number of temporal segments used by the SE layer.
        se_reduction: Bottleneck ratio of the SE layer.
        disentangled_agg: If True, build scale ``k`` with ``k_adjacency``; otherwise
            use matrix powers of the normalized self-looped adjacency.
        use_mask: If True, add a learnable residual (initialised to zero) to the
            stacked adjacency.
        dropout: Dropout probability forwarded to the MLP.
        activation: Activation name forwarded to the MLP.
    """

    def __init__(self,
                 num_scales,   # 13
                 in_channels,
                 out_channels,
                 dataset,
                 M=25,
                 se_reduction=4,
                 disentangled_agg=True,
                 use_mask=True,
                 dropout=0.01,
                 activation='relu'):
        super().__init__()
        self.graph = Graph(labeling_mode='spatial', layout=dataset)
        neighbor = self.graph.neighbor
        self.num_scales = num_scales

        # Node count per dataset comes from the shared helper (19 for 'LARA', else 25).
        A_binary = get_adjacency_matrix(neighbor, get_num_nodes(dataset))

        if disentangled_agg:
            # One matrix per scale k, normalized, then stacked along axis 0.
            A_powers = [k_adjacency(A_binary, k, with_self=True) for k in range(num_scales)]
            A_powers = np.concatenate([normalize_adjacency_matrix(g) for g in A_powers])
        else:
            # k-th power of the normalized, self-looped adjacency for k = 0..num_scales-1.
            A_powers = [A_binary + np.eye(len(A_binary)) for _ in range(num_scales)]
            A_powers = [normalize_adjacency_matrix(g) for g in A_powers]
            A_powers = [np.linalg.matrix_power(g, k) for k, g in enumerate(A_powers)]
            A_powers = np.concatenate(A_powers)

        # Non-persistent buffer: follows module.to()/cuda() without being written to
        # state_dict, so existing checkpoints keep loading unchanged.
        self.register_buffer('A_powers', torch.Tensor(A_powers), persistent=False)
        MV = self.A_powers.shape[0]  # rows of the stacked adjacency (num_scales * V)
        self.seg_se = SegmentSELayer(M, MV, reduction=se_reduction)
        self.use_mask = use_mask
        if use_mask:
            self.A_res = nn.Parameter(torch.zeros_like(self.A_powers))
        else:
            self.A_res = None
        self.mlp = MLP(in_channels * num_scales, [out_channels], dropout=dropout, activation=activation)

    def forward(self, x):
        """Aggregate over all scales, gate per segment, and mix scales with the MLP.

        Args:
            x: Tensor laid out as ``[N, C_in, V, T]``.

        Returns:
            Output of the MLP applied to a ``[N, num_scales * C_in, T, V]`` tensor.
        """
        x = x.transpose(2, 3)  # [N, C, T, V]
        N, C, T, V = x.shape
        # No-op when the buffer already lives on x's device with x's dtype.
        A = self.A_powers.to(device=x.device, dtype=x.dtype)
        if self.use_mask:
            A = A + self.A_res.to(x.dtype)
        # Rows of A index (scale, node) pairs -> [N, C, T, num_scales * V]
        support = torch.einsum('vu,nctu->nctv', A, x)

        # SE gating over the stacked (scale, node) axis, per temporal segment.
        support = self.seg_se(support)
        support = support.view(N, C, T, self.num_scales, V)
        # Fold the scale axis into channels -> [N, num_scales * C, T, V]
        support = support.permute(0, 3, 1, 2, 4).contiguous().view(N, self.num_scales * C, T, V)
        out = self.mlp(support)
        return out


def get_num_nodes(dataset):
    """Return the node count for ``dataset``: 19 for ``'LARA'``, 25 for anything else."""
    return 19 if dataset == 'LARA' else 25

class MultiScale_GraphConv(nn.Module):
    """Multi-scale graph aggregation followed by a 1x1 MLP that mixes the scales.

    Args:
        num_scales: Number of adjacency scales that are stacked.
        in_channels: Channels of the input features.
        out_channels: Channels produced by the final MLP.
        dataset: Layout name forwarded to ``Graph``; ``'LARA'`` selects 19 nodes,
            anything else 25.
        disentangled_agg: If True, build scale ``k`` with ``k_adjacency``; otherwise
            use matrix powers of the normalized self-looped adjacency.
        use_mask: If True, add a learnable residual (initialised uniformly in
            [-1e-6, 1e-6]) to the stacked adjacency.
        dropout: Dropout probability forwarded to the MLP.
        activation: Activation name forwarded to the MLP.
    """

    def __init__(self,
                 num_scales,   # 13
                 in_channels,
                 out_channels,
                 dataset,
                 disentangled_agg=True,
                 use_mask=True,
                 dropout=0,
                 activation='relu'):
        super().__init__()

        self.graph = Graph(labeling_mode='spatial', layout=dataset)
        neighbor = self.graph.neighbor
        self.num_scales = num_scales

        # Node count per dataset comes from the shared helper (19 for 'LARA', else 25).
        A_binary = get_adjacency_matrix(neighbor, get_num_nodes(dataset))

        if disentangled_agg:
            # One matrix per scale k, normalized, then stacked along axis 0.
            A_powers = [k_adjacency(A_binary, k, with_self=True) for k in range(num_scales)]
            A_powers = np.concatenate([normalize_adjacency_matrix(g) for g in A_powers])
        else:
            # k-th power of the normalized, self-looped adjacency for k = 0..num_scales-1.
            A_powers = [A_binary + np.eye(len(A_binary)) for _ in range(num_scales)]
            A_powers = [normalize_adjacency_matrix(g) for g in A_powers]
            A_powers = [np.linalg.matrix_power(g, k) for k, g in enumerate(A_powers)]
            A_powers = np.concatenate(A_powers)

        # Non-persistent buffer: follows module.to()/cuda() and is replicated with the
        # module, but is not written to state_dict, so old checkpoints still load.
        self.register_buffer('A_powers', torch.Tensor(A_powers), persistent=False)
        self.use_mask = use_mask
        if use_mask:
            # NOTE: the inclusion of residual mask appears to slow down training noticeably
            self.A_res = nn.Parameter(torch.empty_like(self.A_powers))
            nn.init.uniform_(self.A_res, -1e-6, 1e-6)

        self.mlp = MLP(in_channels * num_scales, [out_channels], dropout=dropout, activation=activation)

    def forward(self, x):
        """Aggregate over all scales and mix them with the MLP.

        Args:
            x: Tensor laid out as ``[N, C, V, T]``.

        Returns:
            Output of the MLP applied to a ``[N, num_scales * C, T, V]`` tensor.
        """
        x = x.transpose(2, 3)  # [N, C, V, T] -> [N, C, T, V]
        N, C, T, V = x.shape
        # No module state is mutated here; the cast is a no-op when the buffer
        # already matches x's device and dtype.
        A = self.A_powers.to(device=x.device, dtype=x.dtype)
        if self.use_mask:
            A = A + self.A_res.to(x.dtype)

        # Rows of A index (scale, node) pairs -> [N, C, T, num_scales * V]
        support = torch.einsum('vu,nctu->nctv', A, x)
        support = support.view(N, C, T, self.num_scales, V)
        # Fold the scale axis into channels -> [N, num_scales * C, T, V]
        support = support.permute(0, 3, 1, 2, 4).contiguous().view(N, self.num_scales * C, T, V)
        out = self.mlp(support)
        return out


class MLP(nn.Module):
    """Stack of pointwise (1x1) Conv2d -> BatchNorm2d -> activation stages.

    ``out_channels`` is a list with one width per stage. A Dropout layer is placed
    in front of every stage when ``dropout`` exceeds 0.001.
    """

    def __init__(self, in_channels, out_channels, activation='relu', dropout=0):
        super().__init__()
        widths = [in_channels] + out_channels
        with_dropout = dropout > 0.001

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            if with_dropout:
                stages.append(nn.Dropout(p=dropout))
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=1),
                nn.BatchNorm2d(c_out),
                activation_factory(activation),
            ])
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Run ``x`` (shape (N, C, T, V)) through every layer in order."""
        out = x
        for module in self.layers:
            out = module(out)
        return out

def activation_factory(name, inplace=True):
    """Build the activation module selected by ``name``.

    Supported names: ``'relu'``, ``'leakyrelu'`` (negative slope 0.2), ``'tanh'``,
    and ``'linear'`` or ``None`` (identity). ``inplace`` is forwarded to the
    ReLU / LeakyReLU constructors only.

    Raises:
        ValueError: If ``name`` is not one of the supported values.
    """
    builders = {
        'relu': lambda: nn.ReLU(inplace=inplace),
        'leakyrelu': lambda: nn.LeakyReLU(0.2, inplace=inplace),
        'tanh': lambda: nn.Tanh(),
        'linear': lambda: nn.Identity(),
    }
    key = 'linear' if name is None else name
    # Only strings can name an activation; the isinstance guard also keeps
    # unhashable values away from the dict lookup.
    build = builders.get(key) if isinstance(key, str) else None
    if build is None:
        raise ValueError('Not supported activation:', name)
    return build()
    
