import torch
import torch.nn as nn
from models.utils import Segmentation
from .ACL import GlobalGraphLearner
from .GNN import Propagation
from .ESM import EvolutionaryStateModeling
from .Pool import Pooling


class DSN(nn.Module):
    """Temporal encoder per channel followed by graph propagation across channels.

    The mode is selected by substring match on ``args.task``:

    * ``'cls'``: the temporal states are pooled to ``(B, C, 1, D)``, then propagated over
      channel graphs produced by ``GlobalGraphLearner``. The result is decoded to one logit
      per sample, shape ``(B,)``.
    * ``'anomaly'``: a channel graph is produced for every timestep, features are propagated
      per timestep, and one logit is decoded per timestep, shape ``(B, T)``.

    If a list attribute ``global_graphs`` is attached to an instance from outside, the anomaly
    branch appends the graphs of every global layer to it (detached, on CPU, stacked as
    ``(B, T, ...)``) for visualization.
    """

    def __init__(self, args):
        """Build all sub-modules from the hyper-parameters carried by ``args``.

        Args:
            args: configuration namespace; every attribute read below must be present.
                The number of timesteps is ``T = args.window // args.seg``.

        Raises:
            ValueError: if ``args.task`` contains neither ``'cls'`` nor ``'anomaly'``, or if
                ``args.classifier`` is neither ``'mlp'`` nor ``'max'``.
        """
        super().__init__()

        self.hidden = args.hidden
        self.n_channels = args.n_channels
        self.seq_len = args.window // args.seg
        self.dropout = args.dropout
        self.preprocess = args.preprocess
        self.dataset = args.dataset
        self.task = args.task
        # Explicit check instead of ``assert`` so the validation survives ``python -O``.
        if 'cls' not in self.task and 'anomaly' not in self.task:
            raise ValueError(f"DSN: task must contain 'cls' or 'anomaly', got {self.task!r}")

        self.time_layers = args.time_layers
        self.time_activation = args.time_activation
        self.time_decay = args.time_decay

        self.pool_method = args.pool_method
        self.pool_heads = args.pool_heads

        self.global_graph_method = args.global_graph_method
        self.global_gnn_layers = args.global_gnn_layers
        self.global_gnn_activation = args.global_gnn_activation
        self.global_gnn_depth = args.global_gnn_depth

        self.use_ffn = args.use_ffn
        self.input_dim = args.input_dim

        self.activation = args.activation
        self.classifier = args.classifier

        # Optional input projection to ``hidden``; any other value passes the input through as-is.
        if self.preprocess == 'seg':
            self.segmentation = Segmentation(self.input_dim, self.hidden, self.n_channels)
        elif self.preprocess == 'fft':
            self.fc = nn.Linear(self.input_dim, self.hidden)

        # temporal
        self.time_layer = nn.ModuleList()
        self.time_ln = nn.ModuleList()
        for _ in range(self.time_layers):
            self.time_layer.append(EvolutionaryStateModeling(self.hidden, self.seq_len, self.time_layers, self.dropout,
                                                             self.time_activation, self.time_decay))
            self.time_ln.append(nn.LayerNorm(self.hidden))

        # pooling
        self.pooling = Pooling(self.hidden, self.seq_len, self.n_channels, self.pool_method, self.pool_heads)

        # global
        self.global_graph_learner = nn.ModuleList()
        self.global_layer = nn.ModuleList()
        self.global_ln = nn.ModuleList()
        for _ in range(self.global_gnn_layers):
            self.global_graph_learner.append(GlobalGraphLearner(self.hidden, 1, self.n_channels,
                                                                self.global_graph_method, self.dropout, pos_enc=True))
            self.global_layer.append(Propagation(self.hidden, self.n_channels, self.global_gnn_layers, self.dropout,
                                                 self.global_gnn_activation, self.global_gnn_depth))
            self.global_ln.append(nn.LayerNorm(self.hidden))

        # ffn (applied as residual + post-LayerNorm in forward)
        if self.use_ffn:
            self.ffn = nn.Sequential(nn.Linear(self.hidden, 4 * self.hidden),
                                     nn.GELU(),
                                     nn.Linear(4 * self.hidden, self.hidden))
            self.ffn_ln = nn.LayerNorm(self.hidden)

        # decoder
        if self.activation == 'tanh':
            self.act = nn.Tanh()
        elif self.activation == 'relu':
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()
        if self.classifier == 'mlp':
            # Scores all channels jointly from their flattened (C*D) features.
            self.decoder = nn.Sequential(nn.Linear(self.n_channels * self.hidden, self.hidden),
                                         nn.ReLU(),
                                         nn.Linear(self.hidden, 1))
        elif self.classifier == 'max':
            # Scores each channel separately; forward takes the max over channels.
            self.decoder = nn.Linear(self.hidden, 1)
        else:
            raise ValueError(f"DSN: classifier must be 'mlp' or 'max', got {self.classifier!r}")

    def forward(self, x, p, y):
        """Run the network.

        Args:
            x: input batch; per the original shape note it is ``(B, T, C, D/S)`` before
                preprocessing and must be ``(B, T, C, hidden)`` after it.
            p: not used by this forward (kept for caller compatibility).
            y: not used by this forward (kept for caller compatibility).

        Returns:
            ``(logits, None)``. ``logits`` is ``(B,)`` for ``'cls'`` tasks and ``(B, T)``
            for ``'anomaly'`` tasks.
        """
        bs = x.shape[0]

        if self.preprocess == 'seg':
            x = self.segmentation.segment(x)  # (B, T, C, D)
        elif self.preprocess == 'fft':
            x = self.fc(x)  # (B, T, C, D)

        x = self._encode_time(x, bs)  # (B, C, T, D)

        if 'cls' in self.task:
            return self._forward_cls(x, bs), None
        return self._forward_anomaly(x, bs), None

    def _encode_time(self, x, bs):
        """Run the temporal layers with channels folded into the batch; (B, T, C, D) -> (B, C, T, D)."""
        x = x.transpose(2, 1).reshape(bs * self.n_channels, self.seq_len, self.hidden)  # (B*C, T, D)
        for layer in range(self.time_layers):
            x = self.time_layer[layer](x)  # (B*C, T, D)
            x = self.time_ln[layer](x)  # (B*C, T, D)
        return x.reshape(bs, self.n_channels, self.seq_len, self.hidden)

    def _forward_cls(self, x, bs):
        """Sequence-level head: pool, propagate over channel graphs, decode to logits of shape (B,)."""
        # local graph pooling
        x = self.pooling(x)  # (B, C, 1, D)

        # global graph
        for layer in range(self.global_gnn_layers):
            global_graph = self.global_graph_learner[layer](x)  # (B, C, C)
            x = self.global_layer[layer](x.squeeze(dim=2), global_graph)  # (B, C, D)
            x = self.global_ln[layer](x.unsqueeze(dim=2))  # (B, C, 1, D)

        # ffn
        z = x
        if self.use_ffn:
            z = self.ffn_ln(z + self.ffn(z))

        # decoder: average out the pooled (size-1) axis -> (B, C, D)
        z = self.act(torch.mean(z, dim=-2))
        if self.classifier == 'mlp':
            z = z.reshape(bs, self.n_channels * self.hidden)  # (B, C*D)
            z = self.decoder(z).squeeze(dim=-1)  # (B,)
        elif self.classifier == 'max':
            z = z.reshape(bs, self.n_channels, self.hidden)  # (B, C, D)
            z = self.decoder(z).squeeze(dim=-1)  # (B, C)
            z, _ = torch.max(z, dim=1)  # max over channels -> (B,)
        return z

    def _forward_anomaly(self, x, bs):
        """Per-timestep head: one channel graph per timestep, propagate, decode to logits of shape (B, T)."""
        x = x.transpose(1, 2)  # (B, T, C, D)

        # global graph
        for layer in range(self.global_gnn_layers):
            graphs = [self.global_graph_learner[layer](x[:, t, :, :].unsqueeze(dim=2))  # (B, C, C) each
                      for t in range(self.seq_len)]
            # Stack on dim=1 so that flattening gives row index b*T + t. This is the same order as
            # ``x.reshape(bs * seq_len, ...)`` below. Concatenating on dim=0 would give t*B + b and
            # pair each (sample, timestep) with the wrong graph whenever B > 1 and T > 1.
            graphs = torch.stack(graphs, dim=1)  # (B, T, C, C)
            if hasattr(self, 'global_graphs'):  # for visualization
                self.global_graphs.append(graphs.detach().cpu())
            graphs = graphs.flatten(0, 1)  # (B*T, C, C)

            x = x.reshape(bs * self.seq_len, self.n_channels, self.hidden)  # (B*T, C, D)
            x = self.global_layer[layer](x, graphs)  # (B*T, C, D)
            x = x.reshape(bs, self.seq_len, self.n_channels, self.hidden)  # (B, T, C, D)
            x = self.global_ln[layer](x)  # (B, T, C, D)

        # ffn
        z = x
        if self.use_ffn:
            z = self.ffn_ln(z + self.ffn(z))

        # decoder
        z = self.act(z)
        if self.classifier == 'mlp':
            z = z.reshape(bs, self.seq_len, self.n_channels * self.hidden)  # (B, T, C*D)
            z = self.decoder(z).squeeze(dim=-1)  # (B, T)
        elif self.classifier == 'max':
            z = z.reshape(bs, self.seq_len, self.n_channels, self.hidden)  # (B, T, C, D)
            z = self.decoder(z).squeeze(dim=-1)  # (B, T, C)
            z, _ = torch.max(z, dim=-1)  # max over channels -> (B, T)
        return z
