# encoding:utf-8
import torch.nn.functional as F
import torch
import torch.nn as nn
from lib.layers import gcn
from space_gcn import adp_gcn,Mish
from einops.layers.torch import Rearrange
from einops import rearrange
from tcn import Hight_order_Gate
from tadconv import TAdFeatureCNN,TAdSpaceCNN
class Encoder(nn.Module):
    """Stack of gated temporal / adaptive graph-convolution stages.

    The input is lifted to ``latchannels`` channels by a 1x1 convolution
    followed by Mish, passed through ``blocks * layers`` stages, and finally
    projected by another 1x1 convolution + Mish.  Each stage applies a
    ``Hight_order_Gate``, then an ``adp_gcn`` using the supports ``adjs``,
    adds a residual cropped to the stage's output length on axis 3, and
    applies ``BatchNorm2d``.

    Args:
        supports_len: number of support matrices each ``adp_gcn`` expects.
        in_dim: number of input channels (axis 1 of ``x``).
        blocks: number of blocks; the gate dilation restarts at 1 per block.
        layers: stages per block; the gate dilation doubles at every stage.
        kernel_size: only used to compute ``self.receptive_field``.
        dropout: forwarded to every ``adp_gcn``.
        latchannels: hidden channel width.
    """

    def __init__(self, supports_len, in_dim=2, blocks=4,
                 layers=2, kernel_size=2, dropout=0.3,
                 latchannels=64):
        super().__init__()
        # Pointwise lift from in_dim to the latent width.
        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=latchannels, kernel_size=(1, 1)),
            Mish(),
        )
        self.blocks = blocks
        self.layers = layers
        self.supports_len = supports_len
        # Module registration order below is kept stable so state_dict keys
        # and parameters() order do not change.
        self.tcns = nn.ModuleList()
        # NOTE(review): tsimpler and skip_convs are never populated or read here.
        self.tsimpler = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.gconvs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        field = 1
        for _block in range(blocks):
            for depth in range(layers):
                dilation = 2 ** depth
                self.tcns.append(Hight_order_Gate(
                    in_channels=latchannels, out_channels=latchannels,
                    kernel_size=3, padding=1, dilation=dilation,
                ))
                self.bns.append(nn.BatchNorm2d(latchannels))
                # NOTE(review): growth uses `kernel_size`, but the gate above is
                # built with kernel_size=3, padding=1 -- confirm the two agree.
                field += (kernel_size - 1) * dilation
                self.gconvs.append(adp_gcn(latchannels, latchannels,
                                           dropout, support_len=self.supports_len))
        self.receptive_field = field
        # Pointwise output projection.
        self.out_linear = nn.Sequential(
            nn.Conv2d(in_channels=latchannels, out_channels=latchannels, kernel_size=(1, 1)),
            Mish(),
        )

    def forward(self, x, adjs):
        """Encode ``x`` with the support matrices ``adjs``.

        When axis 3 of ``x`` is shorter than ``self.receptive_field`` it is
        zero-padded on the left up to that length.  Presumably ``x`` is
        (batch, in_dim, nodes, time) -- TODO confirm against callers.
        """
        shortfall = self.receptive_field - x.size(3)
        if shortfall > 0:
            x = nn.functional.pad(x, (shortfall, 0, 0, 0))
        h = self.start_conv(x)
        for idx in range(self.blocks * self.layers):
            skip = h
            h = self.gconvs[idx](self.tcns[idx](h), adjs)
            # Crop the residual to the (possibly shorter) stage output length.
            h = self.bns[idx](h + skip[:, :, :, -h.size(3):])
        return self.out_linear(h)

class Decoder(nn.Module):
    """Project encoder features to ``out_dim`` outputs.

    Layers, in order: a ``TAdFeatureCNN`` from ``input_channels`` to half as
    many channels (``int(input_channels / 2)``), a Mish activation, a second
    ``TAdFeatureCNN`` down to ``out_dim`` channels with bias, and an einops
    ``Rearrange('b t n f -> b f n t')``, which swaps axes 1 and 3.
    """

    def __init__(self, input_channels, out_dim):
        super().__init__()
        hidden = int(input_channels / 2)
        squeeze = TAdFeatureCNN(in_channels=input_channels, out_channels=hidden,
                                kernel_size=1, padding=0)
        activation = Mish()
        project = TAdFeatureCNN(in_channels=hidden, out_channels=out_dim,
                                kernel_size=(1, 1), bias=True)
        reorder = Rearrange('b t n f -> b f n t')
        self.model = nn.Sequential(squeeze, activation, project, reorder)

    def forward(self, x):
        """Apply the projection stack to ``x`` and return the result."""
        return self.model(x)

class sthsl(nn.Module):
    """Spatio-temporal model: a learned top-k adjacency, an ``Encoder`` and a ``Decoder``.

    Args:
        device: device the learnable node embeddings are moved to.
        num_nodes: number of graph nodes.
        seq_len: accepted for interface compatibility; not used by this class.
        dropout: dropout rate forwarded to the encoder.
        supports: optional sequence of predefined support matrices; the learned
            adjacency is appended after them in ``forward``.
        in_dim, out_dim, latchannels, kernel_size, blocks, layers: forwarded to
            the encoder / decoder.
    """

    def __init__(self, device,
                 num_nodes,
                 seq_len,
                 dropout=0.3,
                 supports=None,
                 in_dim=2,
                 out_dim=12,
                 latchannels=32,
                 kernel_size=2,
                 blocks=4,
                 layers=2):
        super(sthsl, self).__init__()
        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers
        # NOTE(review): these tensors are not registered as buffers, so
        # model.to(...) will not move them -- callers must place them.
        self.supports = supports
        self.h_num_nodes = int(num_nodes)
        self.embed_dim = 10
        # Upper bound on neighbours kept per node by top_k (clamped there).
        self.k = 128

        self.supports_len = 0
        if supports is not None:
            self.supports_len += len(supports)
        # Node embeddings whose product defines the learned adjacency. Values
        # are drawn with torch.randn and then moved to `device`. The Parameter
        # itself is assigned directly so it is always registered on the module.
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, self.embed_dim).to(device), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(self.embed_dim, num_nodes).to(device), requires_grad=True)
        # +1 for the learned adjacency appended in forward().
        self.encoder = Encoder(supports_len=self.supports_len + 1, in_dim=in_dim, blocks=blocks,
                               layers=layers, kernel_size=kernel_size, dropout=dropout,
                               latchannels=latchannels)
        self.decoder = Decoder(input_channels=latchannels, out_dim=out_dim)

    def top_k(self):
        """Return the learned adjacency restricted to each row's top-k scores.

        Scores are ``relu(nodevec1 @ nodevec2)``. Entries outside each row's
        ``k`` largest scores are set to 0, then a row-wise softmax is applied.
        ``k`` is clamped to the number of nodes, so small graphs keep every
        entry instead of making ``topk`` raise.
        """
        adp = F.relu(torch.mm(self.nodevec1, self.nodevec2))
        k = min(self.k, adp.size(1))
        _, indices = adp.topk(k, dim=-1)
        # Allocate the mask on adp's own device/dtype so CPU and any GPU work.
        mask = torch.zeros_like(adp).scatter_(1, indices, 1.0)
        # NOTE(review): masked entries become 0 rather than -inf, so they still
        # receive exp(0) weight in the softmax and the result is not sparse.
        # Confirm this is intended before changing it (it alters training).
        return F.softmax(adp * mask, dim=1)

    def forward(self, input, input_neg=None):
        """Encode and decode ``input``; ``input_neg`` is accepted but ignored.

        The learned adjacency from ``top_k`` is appended after any predefined
        supports and the whole list is passed to the encoder.
        """
        adp = self.top_k()
        if self.supports_len == 0:
            new_supports = [adp]
        else:
            # list(...) lets callers pass supports as a tuple as well.
            new_supports = list(self.supports) + [adp]
        x = self.encoder(input, new_supports)
        return self.decoder(x)


if __name__ == '__main__':
    # Smoke test: run the Encoder on random data and print the output shape.
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # One random 207x207 support matrix (207 nodes).
    adjs = [torch.randn((207, 207)).to(device)]
    model = Encoder(len(adjs), in_dim=2, blocks=4, layers=2, kernel_size=2,
                    dropout=0.3, latchannels=32).to(device)
    # Presumably (batch, in_dim, nodes, time) -- TODO confirm.
    x = torch.randn((64, 2, 207, 5)).to(device)
    y = model(x, adjs)
    print(y.shape)





