import torch.nn as nn
import torch
from torchinfo import summary
import util


class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention across the -2 dim.

    The -1 dim is ``model_dim``. Permute the tensor so that the axis to attend
    over sits at -2 before calling. E.g. for an input of shape
    ``(batch_size, in_steps, num_nodes, model_dim)`` the attention is
    performed across the nodes.

    Query and key/value may have different lengths along -2
    (``tgt_length`` vs ``src_length``), but key and value must have the same
    length.

    Args:
        model_dim: Feature size of query/key/value and of the output. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        mask: If True, a lower-triangular mask of shape
            ``(tgt_length, src_length)`` is applied to the scores, so query
            position ``i`` only attends to key positions ``<= i``.

    Raises:
        ValueError: If ``model_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, model_dim, num_heads=8, mask=False):
        super().__init__()

        # Without this check the head split below silently yields a different
        # number of heads (e.g. 10 // 4 -> chunks of 2 -> 5 heads) or fails in
        # torch.cat with an unrelated-looking error.
        if model_dim % num_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.mask = mask

        self.head_dim = model_dim // num_heads

        self.FC_Q = nn.Linear(model_dim, model_dim)
        self.FC_K = nn.Linear(model_dim, model_dim)
        self.FC_V = nn.Linear(model_dim, model_dim)

        self.out_proj = nn.Linear(model_dim, model_dim)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``; returns query's shape."""
        batch_size = query.shape[0]
        tgt_length = query.shape[-2]
        src_length = key.shape[-2]

        query = self.FC_Q(query)
        key = self.FC_K(key)
        value = self.FC_V(value)

        # Fold heads into dim 0: (B, ..., model_dim) -> (num_heads * B, ..., head_dim).
        query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)

        attn_score = (query @ key.transpose(-1, -2)) / self.head_dim ** 0.5

        if self.mask:
            mask = torch.ones(
                tgt_length, src_length, dtype=torch.bool, device=query.device
            ).tril()
            attn_score.masked_fill_(~mask, -torch.inf)

        attn_score = torch.softmax(attn_score, dim=-1)
        out = attn_score @ value
        # Unfold heads back onto the feature dim: (num_heads * B, ..., head_dim) -> (B, ..., model_dim).
        out = torch.cat(torch.split(out, batch_size, dim=0), dim=-1)

        return self.out_proj(out)


class SelfAttentionLayer(nn.Module):
    """Transformer encoder layer with post-LayerNorm residual sub-blocks.

    Sub-block 1: ``ln1(h + dropout1(attn(h, h, h)))``.
    Sub-block 2: ``ln2(h + dropout2(feed_forward(h)))`` where the feed-forward
    net is Linear -> ReLU -> Linear.

    ``forward(x, dim)`` swaps ``dim`` with -2 so attention runs over ``dim``,
    then swaps back, so the output has the same shape as ``x``.
    """

    def __init__(
            self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
    ):
        super().__init__()

        # Creation order (attn first, then feed-forward) fixes the RNG draws
        # used to initialise the weights; keep it stable.
        self.attn = AttentionLayer(model_dim, num_heads, mask)
        hidden = nn.Linear(model_dim, feed_forward_dim)
        projection = nn.Linear(feed_forward_dim, model_dim)
        self.feed_forward = nn.Sequential(hidden, nn.ReLU(inplace=True), projection)
        self.ln1 = nn.LayerNorm(model_dim)
        self.ln2 = nn.LayerNorm(model_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, dim=-2):
        h = x.transpose(dim, -2)
        attended = self.dropout1(self.attn(h, h, h))
        h = self.ln1(h + attended)
        transformed = self.dropout2(self.feed_forward(h))
        h = self.ln2(h + transformed)
        return h.transpose(dim, -2)


class nconv(nn.Module):
    """Propagate features with one shared (batch-independent) matrix.

    Computes ``out[n, c, w, l] = sum_v A[w, v] * x[n, c, v, l]`` and returns a
    contiguous tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        propagated = torch.einsum('ncvl,wv->ncwl', x, A)
        return propagated.contiguous()


class d_nconv(nn.Module):
    """Propagate features with a separate matrix per batch element.

    Computes ``out[n, c, v, l] = sum_w A[n, v, w] * x[n, c, w, l]`` and
    returns a contiguous tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        propagated = torch.einsum('ncwl,nvw->ncvl', x, A)
        return propagated.contiguous()


class linear(nn.Module):
    """Channel-mixing 1x1 convolution (a per-position linear layer over dim 1)."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        mixed = self.mlp(x)
        return mixed


class linear_(nn.Module):
    """Float64 convolution with a (1, 2) kernel and dilation 2.

    Along the last dim each output position combines inputs ``t`` and
    ``t + 2``, so that dim shrinks by 2; no padding, stride 1.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        conv = torch.nn.Conv2d(
            c_in,
            c_out,
            kernel_size=(1, 2),
            dilation=(2, 2),
            padding=0,
            stride=1,
            bias=True,
        )
        self.mlp = conv.double()

    def forward(self, x):
        mixed = self.mlp(x)
        return mixed


class gcn(nn.Module):
    """Mix-hop graph convolution over a list of dense supports (float64).

    For each support ``a`` the input is propagated ``order`` times with
    ``nconv`` (``a x``, ``a (a x)``, ...). The input and every propagated
    tensor are concatenated along the channel dim (1), mixed by a float64
    1x1 convolution, and passed through functional dropout.

    For ``order >= 1`` the concatenated channel count is
    ``(order * len(support) + 1) * c_in``, so ``support_len`` must match the
    number of supports given to ``forward``.
    """

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super(gcn, self).__init__()
        self.nconv = nconv()
        stacked_channels = (order * support_len + 1) * c_in
        self.mlp = linear(stacked_channels, c_out).double()
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        x = x.double()
        hops = [x]
        for adj in support:
            # First hop is always taken; the remaining `order - 1` chain on it.
            hop = self.nconv(x, adj)
            hops.append(hop)
            for _ in range(self.order - 1):
                hop = self.nconv(hop, adj)
                hops.append(hop)
        mixed = self.mlp(torch.cat(hops, dim=1))
        return nn.functional.dropout(mixed, self.dropout, training=self.training)


class dhgcn(nn.Module):
    """Mix-hop propagation with a per-sample operator, then a time-reducing conv.

    ``G`` is applied ``order`` times via ``d_nconv`` (one matrix per batch
    element). The input and all hops are concatenated along the channel dim
    and passed through ``linear_`` (float64 (1, 2) conv, dilation 2), then
    functional dropout.
    """

    def __init__(self, c_in, c_out, dropout, order=2):
        super(dhgcn, self).__init__()
        self.d_nconv = d_nconv()
        self.mlp = linear_((order + 1) * c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, G):
        x = x.double()
        hops = [x]
        # First hop is always taken; the remaining `order - 1` chain on it.
        hop = self.d_nconv(x, G)
        hops.append(hop)
        for _ in range(self.order - 1):
            hop = self.d_nconv(hop, G)
            hops.append(hop)
        mixed = self.mlp(torch.cat(hops, dim=1))
        return nn.functional.dropout(mixed, self.dropout, training=self.training)


class MvHSTM(nn.Module):
    """Multi-view hypergraph spatio-temporal model.

    Pipeline (see ``forward``):

    1. Embed the input: a linear projection of the first ``input_dim``
       channels plus, when enabled, a time-of-day lookup (input channel 1
       scaled by ``steps_per_day``), a day-of-week lookup (input channel 2),
       a per-node embedding and a learnable
       ``(in_steps, num_nodes, adaptive_embedding_dim)`` embedding, all
       concatenated into ``model_dim`` features.
    2. Apply ``num_layers`` self-attention layers across the time axis (dim 1)
       and project to ``(batch, out_steps, num_nodes, output_dim)``.
    3. Feed that projection into two gated hypergraph-convolution branches,
       one built from the ``spatial_*`` matrices and one from the
       ``semantic_*`` matrices. Each branch accumulates skip connections and
       ends with two 1x1 convolutions.
    4. Concatenate both branch outputs along dim 3 and mix them with
       ``output_layer`` (``Linear(2, 1)``, so each branch presumably ends with
       size 1 on that dim — TODO confirm).

    Notes:
        ``in_dim`` must equal ``output_dim``: the projected output is the
        input of ``bn_start``/``start_conv``.
        The hypergraph matrices are stored as plain attributes, not buffers or
        parameters, so ``.to()`` does not move them; presumably the caller
        passes them on ``device`` already (TODO confirm).
        ``hyper_graph_lap`` and ``line_graph_lap`` are accepted but unused.
    """

    def __init__(
            self,
            device,
            hyper_graph_lap,
            line_graph_lap,
            num_nodes,
            spatial_H_a, spatial_H_b, spatial_G0, spatial_G1, spatial_H_T_new, spatial_lwjl,
            semantic_H, semantic_H_T_new, semantic_G0, semantic_G1,
            in_steps=12,
            out_steps=12,
            steps_per_day=288,
            input_dim=3,
            output_dim=2,
            input_embedding_dim=24,
            tod_embedding_dim=24,
            dow_embedding_dim=24,
            spatial_embedding_dim=0,
            adaptive_embedding_dim=80,
            feed_forward_dim=256,
            num_heads=4,
            num_layers=3,
            dropout=0.1,
            supports=None,
            in_dim=2, out_dim=12, residual_channels=40, dilation_channels=40, skip_channels=320, end_channels=640,
            kernel_size=2, blocks=3, layers=1,
            use_mixed_proj=True,
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.in_steps = in_steps
        self.out_steps = out_steps
        self.steps_per_day = steps_per_day
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_embedding_dim = input_embedding_dim
        self.tod_embedding_dim = tod_embedding_dim
        self.dow_embedding_dim = dow_embedding_dim
        self.spatial_embedding_dim = spatial_embedding_dim
        self.adaptive_embedding_dim = adaptive_embedding_dim
        self.model_dim = (
                input_embedding_dim
                + tod_embedding_dim
                + dow_embedding_dim
                + spatial_embedding_dim
                + adaptive_embedding_dim
        )
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.use_mixed_proj = use_mixed_proj

        # NOTE: the order in which modules/parameters are created below fixes
        # the sequence of RNG draws used for initialisation; keep it stable.
        self.input_proj = nn.Linear(input_dim, input_embedding_dim)
        if tod_embedding_dim > 0:
            self.tod_embedding = nn.Embedding(steps_per_day, tod_embedding_dim)
        if dow_embedding_dim > 0:
            self.dow_embedding = nn.Embedding(7, dow_embedding_dim)
        if spatial_embedding_dim > 0:
            self.node_emb = nn.Parameter(
                torch.empty(self.num_nodes, self.spatial_embedding_dim)
            )
            nn.init.xavier_uniform_(self.node_emb)
        if adaptive_embedding_dim > 0:
            self.adaptive_embedding = nn.init.xavier_uniform_(
                nn.Parameter(torch.empty(in_steps, num_nodes, adaptive_embedding_dim))
            )

        if use_mixed_proj:
            self.output_proj = nn.Linear(
                in_steps * self.model_dim, out_steps * output_dim
            )
        else:
            self.temporal_proj = nn.Linear(in_steps, out_steps)
            self.output_proj = nn.Linear(self.model_dim, self.output_dim)

        self.attn_layers_t = nn.ModuleList(
            [
                SelfAttentionLayer(self.model_dim, feed_forward_dim, num_heads, dropout)
                for _ in range(num_layers)
            ]
        )

        # NOTE: attn_layers_s is never used in forward; kept so existing
        # checkpoints (state_dict keys) still load.
        self.attn_layers_s = nn.ModuleList(
            [
                SelfAttentionLayer(self.model_dim, feed_forward_dim, num_heads, dropout)
                for _ in range(num_layers)
            ]
        )
        self.spatial_H_a = spatial_H_a
        self.spatial_H_b = spatial_H_b
        self.spatial_G0 = spatial_G0
        self.spatial_G1 = spatial_G1
        self.spatial_H_T_new = spatial_H_T_new
        self.spatial_lwjl = spatial_lwjl
        self.semantic_H = semantic_H
        self.semantic_H_T_new = semantic_H_T_new
        self.semantic_G0 = semantic_G0
        self.semantic_G1 = semantic_G1

        def rand_param(rows, cols):
            # Uniform [0, 1) init, moved to `device` before wrapping so the
            # registered Parameter itself lives on `device`.
            return nn.Parameter(torch.rand(rows, cols).to(device), requires_grad=True)

        # Rank-10 factors; their products rescale the incidence matrices
        # element-wise in forward.
        self.edge_node_vec1 = rand_param(self.spatial_H_a.size(1), 10)
        self.edge_node_vec2 = rand_param(10, self.spatial_H_a.size(0))
        self.edge_node_vec3 = rand_param(self.semantic_H.size(1), 10)
        self.edge_node_vec4 = rand_param(10, self.semantic_H.size(0))

        self.node_edge_vec1 = rand_param(self.spatial_H_a.size(0), 10)
        self.node_edge_vec2 = rand_param(10, self.spatial_H_a.size(1))
        self.node_edge_vec3 = rand_param(self.semantic_H.size(0), 10)
        self.node_edge_vec4 = rand_param(10, self.semantic_H.size(1))

        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers

        self.skip_convs = nn.ModuleList()
        self.skip_convs2 = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        self.filter_convs_h = nn.ModuleList()
        self.gate_convs_h = nn.ModuleList()
        self.filter_convs_h2 = nn.ModuleList()
        self.gate_convs_h2 = nn.ModuleList()
        self.gconv_dgcn_w = nn.ModuleList()
        self.gconv_dgcn_w2 = nn.ModuleList()
        self.dhgconv = nn.ModuleList()
        self.dhgconv2 = nn.ModuleList()
        self.bn_hg = nn.ModuleList()
        self.bn_hg2 = nn.ModuleList()

        self.start_conv = nn.Conv2d(in_channels=in_dim,
                                    out_channels=residual_channels,
                                    kernel_size=(1, 1)).float()
        self.start_conv2 = nn.Conv2d(in_channels=in_dim,
                                     out_channels=residual_channels,
                                     kernel_size=(1, 1)).float()
        self.supports = supports
        # Previously `len(supports)` crashed on the default `supports=None`.
        self.supports_len = (len(supports) if supports is not None else 0) + 1

        half_residual = residual_channels // 2
        receptive_field = 1
        for _ in range(blocks):
            additional_scope = kernel_size
            new_dilation = 2
            for _ in range(layers):
                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=(1, 1)).double())
                self.skip_convs2.append(nn.Conv2d(in_channels=dilation_channels,
                                                  out_channels=skip_channels,
                                                  kernel_size=(1, 1)).double())
                self.bn.append(nn.BatchNorm2d(residual_channels).double())
                self.bn2.append(nn.BatchNorm2d(residual_channels).double())
                self.filter_convs_h.append(nn.Conv2d(in_channels=1 + residual_channels * 2,
                                                     out_channels=dilation_channels,
                                                     kernel_size=(1, kernel_size), dilation=new_dilation))
                self.gate_convs_h.append(nn.Conv2d(in_channels=1 + residual_channels * 2,
                                                   out_channels=dilation_channels,
                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
                self.filter_convs_h2.append(nn.Conv2d(in_channels=residual_channels,
                                                      out_channels=dilation_channels,
                                                      kernel_size=(1, kernel_size), dilation=new_dilation))
                self.gate_convs_h2.append(nn.Conv2d(in_channels=residual_channels,
                                                    out_channels=dilation_channels,
                                                    kernel_size=(1, kernel_size), dilation=new_dilation))
                receptive_field += (additional_scope * 2)
                # gcn(..., support_len=2, order=1) expects exactly two supports.
                self.gconv_dgcn_w.append(
                    gcn(residual_channels, 1, dropout, support_len=2, order=1))
                self.gconv_dgcn_w2.append(
                    gcn(residual_channels, 1, dropout, support_len=2, order=1))
                self.dhgconv.append(dhgcn(dilation_channels, half_residual, dropout))
                self.dhgconv2.append(dhgcn(dilation_channels, half_residual, dropout))
                self.bn_hg.append(nn.BatchNorm2d(half_residual).double())
                self.bn_hg2.append(nn.BatchNorm2d(half_residual).double())

        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
                                    out_channels=end_channels,
                                    kernel_size=(1, 1),
                                    bias=True).double()
        self.end_conv_3 = nn.Conv2d(in_channels=skip_channels,
                                    out_channels=end_channels,
                                    kernel_size=(1, 1),
                                    bias=True).double()

        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
                                    out_channels=out_dim,
                                    kernel_size=(1, 1),
                                    bias=True).double()
        self.end_conv_4 = nn.Conv2d(in_channels=end_channels,
                                    out_channels=out_dim,
                                    kernel_size=(1, 1),
                                    bias=True).double()
        self.receptive_field = receptive_field
        self.bn_start = nn.BatchNorm2d(in_dim, affine=False).float()
        self.bn_start2 = nn.BatchNorm2d(in_dim, affine=False).float()
        self.output_layer = nn.Linear(in_features=2, out_features=1).double()
        # NOTE: change_layer is not used in forward; kept for checkpoint compatibility.
        self.change_layer = nn.Linear(in_features=13, out_features=1).double()

    @staticmethod
    def _accumulate_skip(skip, s):
        """Return ``s + skip``, cropping a tensor ``skip`` to ``s``'s last-dim length.

        ``skip`` starts as the int ``0`` (replaces the former bare
        ``try/except`` that relied on ``0[...]`` raising).
        """
        if isinstance(skip, torch.Tensor):
            skip = skip[:, :, :, -s.size(3):]
        return s + skip

    @staticmethod
    def _hyperedge_weights(node_features, weight_gcn, supports, G0, G1, eps=0.0):
        """Build per-sample matrices ``G0 @ diag(w_b) @ G1``.

        ``node_features`` is treated as ``(B, C, N, T)``: it is averaged over
        ``T``, rearranged to ``(1, C, N, B)`` and passed to ``weight_gcn``,
        which must return ``(1, 1, N, B)`` (as a ``gcn`` with ``c_out=1``
        does). Each sample's ``N`` scores become a diagonal matrix.

        Args:
            node_features: Node feature tensor, indexed ``(B, C, N, T)``.
            weight_gcn: Callable ``(tensor, supports) -> (1, 1, N, B)`` tensor.
            supports: Passed through to ``weight_gcn``.
            G0, G1: Matrices multiplied on the left/right of each diagonal.
            eps: Added to ``node_features`` first when non-zero.

        Returns:
            Tensor of shape ``(B, G0.size(-2), G1.size(-1))``.
        """
        if eps:
            node_features = node_features + eps
        w = node_features.transpose(1, 2)      # (B, N, C, T)
        w = torch.mean(w, 3)                   # (B, N, C)
        w = w.transpose(0, 2)                  # (C, N, B)
        w = torch.unsqueeze(w, dim=0)          # (1, C, N, B)
        w = weight_gcn(w, supports)            # (1, 1, N, B)
        # Index rather than squeeze(): squeeze() also dropped the batch dim
        # when B == 1 (and the node dim when N == 1), which made the transpose
        # below fail for single-sample batches.
        w = w[0, 0].transpose(0, 1)            # (B, N)
        return G0 @ torch.diag_embed(w) @ G1

    def _embed(self, history):
        """Concatenate the input projection and all enabled embeddings on dim -1."""
        batch_size = history.shape[0]
        features = [self.input_proj(history[..., : self.input_dim])]
        if self.tod_embedding_dim > 0:
            # Channel 1 is scaled by steps_per_day and truncated to an index;
            # presumably a fraction of the day in [0, 1) (TODO confirm).
            tod = history[..., 1]
            features.append(self.tod_embedding((tod * self.steps_per_day).long()))
        if self.dow_embedding_dim > 0:
            features.append(self.dow_embedding(history[..., 2].long()))
        if self.spatial_embedding_dim > 0:
            features.append(self.node_emb.expand(
                batch_size, self.in_steps, *self.node_emb.shape
            ))
        if self.adaptive_embedding_dim > 0:
            features.append(self.adaptive_embedding.expand(
                size=(batch_size, *self.adaptive_embedding.shape)
            ))
        return torch.cat(features, dim=-1)

    def _encode(self, x, batch_size):
        """Temporal self-attention, then projection to (B, out_steps, N, output_dim)."""
        for attn in self.attn_layers_t:
            x = attn(x, dim=1)

        if self.use_mixed_proj:
            out = x.transpose(1, 2).reshape(
                batch_size, self.num_nodes, self.in_steps * self.model_dim
            )
            out = self.output_proj(out).view(
                batch_size, self.num_nodes, self.out_steps, self.output_dim
            )
            return out.transpose(1, 2)

        out = self.temporal_proj(x.transpose(1, 3))
        return self.output_proj(out.transpose(1, 3))

    def _spatial_branch(self, x_1):
        """Gated hypergraph convolution using the spatial_* matrices."""
        x_1 = self.bn_start(x_1)
        x_1 = self.start_conv(x_1)
        skip = 0

        edge_node_H = self.spatial_H_T_new * torch.mm(self.edge_node_vec1, self.edge_node_vec2)
        node_edge_scale = torch.mm(self.node_edge_vec1, self.node_edge_vec2)
        # Assigned to self rather than kept local; nothing else in this class
        # reads them back, but external code might.
        self.spatial_H_a_ = (self.spatial_H_a * node_edge_scale).float()
        self.spatial_H_b_ = (self.spatial_H_b * node_edge_scale).float()

        for i in range(self.blocks * self.layers):
            x_1 = x_1.float()
            edge_feature = util.feature_node_to_edge(
                x_1, self.spatial_H_a_, self.spatial_H_b_, operation="concat"
            )
            edge_feature = torch.cat(
                [edge_feature, self.spatial_lwjl.repeat(1, 1, 1, edge_feature.size(3))], dim=1
            ).float()
            # Gated activation: tanh(filter) * sigmoid(gate).
            x_h = (torch.tanh(self.filter_convs_h[i](edge_feature))
                   * torch.sigmoid(self.gate_convs_h[i](edge_feature)))

            residual = x_1
            weights = self._hyperedge_weights(
                residual, self.gconv_dgcn_w[i], self.supports, self.spatial_G0, self.spatial_G1
            )
            x_h = self.dhgconv[i](x_h, weights)
            x_h = self.bn_hg[i](x_h)

            x_1 = util.fusion_edge_node(x_1, x_h, edge_node_H)
            x_1 = x_1 + residual[:, :, :, -x_1.size(3):]
            x_1 = self.bn[i](x_1)

            skip = self._accumulate_skip(skip, self.skip_convs[i](x_1))

        out = nn.functional.leaky_relu(skip)
        out = nn.functional.leaky_relu(self.end_conv_1(out))
        return self.end_conv_2(out)

    def _semantic_branch(self, x_2):
        """Gated hypergraph convolution using the semantic_* matrices."""
        x_2 = self.bn_start2(x_2)
        x_2 = self.start_conv2(x_2)
        skip2 = 0

        edge_node_H2 = self.semantic_H_T_new * torch.mm(self.edge_node_vec3, self.edge_node_vec4)
        # NOTE(review): .detach() means node_edge_vec3/node_edge_vec4 get no
        # gradient (this is their only use in the class), unlike the spatial
        # node_edge_vec1/2. Kept as-is; confirm this is intended.
        semantic_H = (self.semantic_H * torch.mm(self.node_edge_vec3, self.node_edge_vec4)).float().detach()
        semantic_H = semantic_H.transpose(0, 1)

        for i in range(self.blocks * self.layers):
            x_2 = x_2.float()
            # Multiply the node axis (dim 2) by the transposed, rescaled semantic_H.
            edge_feature2 = torch.matmul(x_2.permute(0, 1, 3, 2), semantic_H).permute(0, 1, 3, 2)
            x_h2 = (torch.tanh(self.filter_convs_h2[i](edge_feature2))
                    * torch.sigmoid(self.gate_convs_h2[i](edge_feature2)))

            residual2 = x_2
            weights = self._hyperedge_weights(
                residual2, self.gconv_dgcn_w2[i], self.supports,
                self.semantic_G0, self.semantic_G1, eps=1e-8,
            )
            x_h2 = self.dhgconv2[i](x_h2, weights)
            x_h2 = self.bn_hg2[i](x_h2)

            x_2 = util.fusion_edge_node(x_2, x_h2, edge_node_H2)
            x_2 = x_2 + residual2[:, :, :, -x_2.size(3):]
            x_2 = self.bn2[i](x_2)

            skip2 = self._accumulate_skip(skip2, self.skip_convs2[i](x_2))

        out = nn.functional.leaky_relu(skip2)
        out = nn.functional.leaky_relu(self.end_conv_3(out))
        return self.end_conv_4(out)

    def forward(self, x):
        """Run the full model.

        Args:
            x: Input indexed as ``(batch, in_steps, num_nodes, channels)``.
                Channels ``[:input_dim]`` are projected; channels 1 and 2 feed
                the time-of-day / day-of-week embeddings when enabled.

        Returns:
            ``output_layer`` applied to the dim-3 concatenation of the spatial
            and semantic branch outputs.
        """
        batch_size = x.shape[0]
        x = self._embed(x)
        out = self._encode(x, batch_size)

        # (B, out_steps, N, output_dim) -> (B, output_dim, N, out_steps).
        hg_input = out.transpose(1, 3)
        # Prepend one zero step on the time axis, then left-pad up to the
        # receptive field if still shorter.
        hg_input = nn.functional.pad(hg_input, pad=(1, 0), mode='constant', value=0)
        in_len = hg_input.size(3)
        if in_len < self.receptive_field:
            hg_input = nn.functional.pad(hg_input, (self.receptive_field - in_len, 0, 0, 0))

        x_1 = self._spatial_branch(hg_input)
        x_2 = self._semantic_branch(hg_input)

        x_concat = torch.cat((x_1, x_2), dim=3)
        return self.output_layer(x_concat)


