import random
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import torch.nn.functional as F
import smplx


# -------------------   TCN   ---------------------- #

class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along the final (time) axis.

    Used after a symmetrically padded ``Conv1d`` to trim the extra
    right-hand outputs produced by the padding.

    Args:
        chomp_size: number of trailing time steps to remove.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` without its last ``chomp_size`` steps, as a contiguous tensor."""
        # ``x[..., :-0]`` is an *empty* slice, so a zero chomp (e.g. kernel_size=1,
        # which gives padding 0) must be handled explicitly as a no-op.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block made of two weight-normalised dilated ``Conv1d`` layers.

    Each convolution is followed by ``Chomp1d(padding)``, ReLU and dropout.
    The block output is ``relu(net(x) + res)`` where ``res`` is ``x`` itself,
    or a 1x1 convolution of ``x`` when the channel counts differ.

    Args:
        n_inputs: input channel count.
        n_outputs: output channel count.
        kernel_size: kernel size of both convolutions.
        stride: stride of both convolutions.
        dilation: dilation of both convolutions.
        padding: padding of both convolutions; the same amount is chomped
            from the end of each convolution output.
        dropout: dropout probability applied after each ReLU.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

        # First conv -> chomp -> relu -> dropout stage.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_kwargs))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second stage, same layout, operating on n_outputs channels.
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_kwargs))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )

        # A 1x1 projection is only needed when the residual has a different width.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the conv (and optional downsample) weights from N(0, 0.01)."""
        # NOTE(review): conv1/conv2 are wrapped by weight_norm, which presumably
        # recomputes ``weight`` from its own parameters on forward - confirm
        # whether this initialisation actually takes effect for them.
        for layer in (self.conv1, self.conv2, self.downsample):
            if layer is not None:
                layer.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Apply the two-stage conv stack and add the (possibly projected) input."""
        residual = self.downsample(x) if self.downsample is not None else x
        return self.relu(self.net(x) + residual)


class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` levels with exponentially growing dilation.

    Level ``i`` uses dilation ``2 ** i`` and padding
    ``(kernel_size - 1) * 2 ** i``; all levels use stride 1.

    Args:
        num_inputs: channel count of the network input.
        num_channels: sequence with the output channel count of every level.
        kernel_size: kernel size shared by all levels.
        dropout: dropout probability shared by all levels.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        for level in range(len(num_channels)):
            dilation = 2 ** level
            # The first level consumes the raw input, later ones the previous level.
            fan_in = num_channels[level - 1] if level > 0 else num_inputs
            blocks.append(
                TemporalBlock(
                    fan_in,
                    num_channels[level],
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every temporal block in order."""
        return self.network(x)


# ----------- 1 full conv-based encoder------------- #
"""
Adapted from TM2T.
TM2T: Stochastic and Tokenized Modeling for the Reciprocal Generation of 3D Human Motions and Texts
https://github.com/EricGuo5513/TM2T
"""
from .quantizer import *
from .utils.layer import ResBlock, init_weight, EncResBlock


class SCFormer(nn.Module):
    """Encoder skeleton that applies a spatial encoder per frame, then a temporal one.

    ``__init__`` builds ``self.main``: ``args.vae_layer`` stages of
    ``Conv1d(kernel=4, stride=2, padding=1) -> LeakyReLU(0.2) -> ResBlock``,
    every stage being ``args.vae_length`` channels wide, with the first conv
    reading ``args.vae_test_dim`` channels.

    NOTE(review): ``forward`` does not use ``self.main``; it calls
    ``self.spatial_transformer_encoder`` and ``self.temporal_cnn_encoder``,
    which are not created anywhere in this class. They must be attached by a
    subclass or the caller before ``forward`` is usable - confirm the intent.

    Args:
        args: config object providing ``vae_layer``, ``vae_length`` and
            ``vae_test_dim``.
    """

    def __init__(self, args):
        # Fixed: this used to be ``super(VQEncoderV3, self)``, which raises
        # TypeError because SCFormer is not a subclass of VQEncoderV3.
        super(SCFormer, self).__init__()

        n_down = args.vae_layer
        channels = [args.vae_length]
        for i in range(n_down - 1):
            channels.append(args.vae_length)

        input_size = args.vae_test_dim
        assert len(channels) == n_down
        layers = [
            nn.Conv1d(input_size, channels[0], 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            ResBlock(channels[0]),
        ]

        for i in range(1, n_down):
            layers += [
                nn.Conv1d(channels[i - 1], channels[i], 4, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                ResBlock(channels[i]),
            ]
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):  # bs t n
        """Encode ``inputs`` of shape ``(bs, t, n)``.

        Frames are flattened to ``(bs * t, n)`` for the spatial encoder; its
        ``(bs * t, cs)`` output is regrouped to ``(bs * cs, t)`` for the
        temporal encoder, whose ``(bs * cs, ct)`` output is returned as
        ``(bs, ct, cs)``.

        Original author notes on the expected features:
            face 51 or 106
            hand 30*(15)
            upper body
            lower body
            global 1*3
            max length around 180 --> 450
        """
        bs, t, n = inputs.shape
        inputs = inputs.reshape(bs * t, n)
        inputs = self.spatial_transformer_encoder(inputs)  # bs*t c
        cs = inputs.shape[1]
        # Move the feature axis in front of time so each channel becomes its own sequence.
        inputs = inputs.reshape(bs, t, cs).permute(0, 2, 1).reshape(bs * cs, t)
        inputs = self.temporal_cnn_encoder(inputs)  # bs*c t
        ct = inputs.shape[1]
        outputs = inputs.reshape(bs, cs, ct).permute(0, 2, 1)  # bs ct cs
        return outputs


class VQEncoderV3(nn.Module):
    """Strided Conv1d encoder: every stage uses kernel 4, stride 2, padding 1.

    The network has ``args.vae_layer`` stages of
    ``Conv1d -> LeakyReLU(0.2) -> ResBlock``, each ``args.vae_length`` channels
    wide; the first conv reads ``args.vae_test_dim`` channels.

    Args:
        args: config object providing ``vae_layer``, ``vae_length`` and
            ``vae_test_dim``.
    """

    def __init__(self, args):
        super().__init__()
        depth = args.vae_layer
        widths = [args.vae_length] + [args.vae_length] * (depth - 1)
        in_dim = args.vae_test_dim
        assert len(widths) == depth

        # Pair each stage's input width with its output width.
        stages = []
        for c_in, c_out in zip([in_dim] + widths[:-1], widths):
            stages += [
                nn.Conv1d(c_in, c_out, 4, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                ResBlock(c_out),
            ]

        self.main = nn.Sequential(*stages)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Run ``self.main`` over the time axis.

        The last two axes are swapped before and after the conv stack, so the
        input is presumably ``(batch, time, features)`` - TODO confirm.
        """
        channels_first = inputs.permute(0, 2, 1)
        encoded = self.main(channels_first)
        return encoded.permute(0, 2, 1)


class VQEncoderV6(nn.Module):
    """Conv1d encoder with Tanh activations; all convs use kernel 3, stride 1, padding 1.

    The network has ``args.vae_layer`` stages of ``Conv1d -> Tanh -> ResBlock``,
    each ``args.vae_length`` channels wide; the first conv reads
    ``args.vae_test_dim`` channels.

    Args:
        args: config object providing ``vae_layer``, ``vae_length`` and
            ``vae_test_dim``.
    """

    def __init__(self, args):
        super().__init__()
        depth = args.vae_layer
        widths = [args.vae_length] + [args.vae_length] * (depth - 1)
        in_dim = args.vae_test_dim
        assert len(widths) == depth

        stages = []
        fan_ins = [in_dim] + widths[:-1]
        for c_in, c_out in zip(fan_ins, widths):
            stages.append(nn.Conv1d(c_in, c_out, 3, 1, 1))
            stages.append(nn.Tanh())
            stages.append(ResBlock(c_out))

        self.main = nn.Sequential(*stages)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Swap the last two axes, apply ``self.main``, and swap them back."""
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)


class VQEncoderV4(nn.Module):
    """Conv1d encoder whose first stage is strided and whose later stages are not.

    Stage 0 uses ``Conv1d(kernel=4, stride=2, padding=1)``; every later stage
    uses ``Conv1d(kernel=3, stride=1, padding=1)``. Each conv is followed by
    ``LeakyReLU(0.2)`` and a ``ResBlock``. There are ``args.vae_layer`` stages
    of width ``args.vae_length``; the first conv reads ``args.vae_test_dim``
    channels.

    Args:
        args: config object providing ``vae_layer``, ``vae_length`` and
            ``vae_test_dim``.
    """

    def __init__(self, args):
        super().__init__()
        depth = args.vae_layer
        widths = [args.vae_length] + [args.vae_length] * (depth - 1)
        in_dim = args.vae_test_dim
        assert len(widths) == depth

        stages = []
        for idx, (c_in, c_out) in enumerate(zip([in_dim] + widths[:-1], widths)):
            # Only the very first convolution downsamples in time.
            kernel, stride = (4, 2) if idx == 0 else (3, 1)
            stages += [
                nn.Conv1d(c_in, c_out, kernel, stride, 1),
                nn.LeakyReLU(0.2, inplace=True),
                ResBlock(c_out),
            ]

        self.main = nn.Sequential(*stages)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Swap the last two axes, apply ``self.main``, and swap them back."""
        encoded = self.main(inputs.permute(0, 2, 1))
        return encoded.permute(0, 2, 1)


class VQEncoderV5(nn.Module):
    """Non-strided Conv1d encoder: all convs use kernel 3, stride 1, padding 1.

    The network has ``args.vae_layer`` stages of
    ``Conv1d -> LeakyReLU(0.2) -> ResBlock``, each ``args.vae_length`` channels
    wide; the first conv reads ``args.vae_test_dim`` channels.

    Args:
        args: config object providing ``vae_layer``, ``vae_length`` and
            ``vae_test_dim``.
    """

    def __init__(self, args):
        super().__init__()
        depth = args.vae_layer
        widths = [args.vae_length] + [args.vae_length] * (depth - 1)
        in_dim = args.vae_test_dim
        assert len(widths) == depth

        stages = []
        prev = in_dim
        for width in widths:
            stages += [
                nn.Conv1d(prev, width, 3, 1, 1),
                nn.LeakyReLU(0.2, inplace=True),
                ResBlock(width),
            ]
            prev = width

        self.main = nn.Sequential(*stages)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Swap the last two axes, apply ``self.main``, and swap them back."""
        channels_first = inputs.permute(0, 2, 1)
        encoded = self.main(channels_first)
        return encoded.permute(0, 2, 1)


class VQEncoderV7(nn.Module):
    """Non-strided Conv1d encoder whose input is offset by a projected condition.

    ``forward`` adds ``style_layer(cond_emb)`` (a ``Linear(4096, input_size)``)
    to the input before the conv stack. The stack has ``args.vae_layer`` stages
    of ``Conv1d(kernel=3, stride=1, padding=1) -> LeakyReLU(0.2) -> ResBlock``,
    each ``args.vae_length`` channels wide.

    ``v1``/``v2``/``v3`` are registered parameters that the current ``forward``
    does not use; they are kept so existing checkpoints still load.

    Args:
        args: config object providing ``vae_layer``, ``vae_length`` and
            ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQEncoderV7, self).__init__()
        n_down = args.vae_layer
        channels = [args.vae_length]
        for i in range(n_down - 1):
            channels.append(args.vae_length)

        input_size = args.vae_test_dim
        assert len(channels) == n_down
        layers = [
            nn.Conv1d(input_size, channels[0], 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            ResBlock(channels[0]),
        ]

        for i in range(1, n_down):
            layers += [
                nn.Conv1d(channels[i - 1], channels[i], 3, 1, 1),
                nn.LeakyReLU(0.2, inplace=True),
                ResBlock(channels[i]),
            ]
        # torch.empty is the documented way to allocate uninitialised storage;
        # the values are filled in by xavier_uniform_ below.
        self.v1 = nn.Parameter(torch.empty(4096, input_size))
        self.v2 = nn.Parameter(torch.empty(4096, input_size))
        self.v3 = nn.Parameter(torch.empty(4096, input_size))
        self.style_layer = nn.Linear(4096, input_size)
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)
        self.style_layer.apply(init_weight)
        # In-place variant; the un-suffixed ``xavier_uniform`` is deprecated.
        torch.nn.init.xavier_uniform_(self.v1)
        torch.nn.init.xavier_uniform_(self.v2)
        torch.nn.init.xavier_uniform_(self.v3)

    def forward(self, inputs, cond_emb):
        """Encode ``inputs`` after adding the projected ``cond_emb``.

        Args:
            inputs: tensor whose last two axes are swapped before the conv
                stack (presumably ``(batch, time, features)`` - TODO confirm).
            cond_emb: tensor whose last axis has size 4096; after projection
                and axis swap it must broadcast against the swapped ``inputs``.

        Returns:
            The conv stack output with its last two axes swapped back.
        """
        inputs = inputs.permute(0, 2, 1)
        style = self.style_layer(cond_emb)
        style = style.permute(0, 2, 1)
        inputs = inputs + style
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs


class VQEncoderV9(nn.Module):
    """Multi-branch Conv1d encoder with one branch per temporal kernel size.

    Five parallel stacks are built, with kernel sizes 3, 5, 33, 9 and 17 (stored
    as ``main``, ``main0``, ``main1``, ``main2``, ``main3``). Each stack has
    ``args.vae_layer`` stride-1 stages of
    ``Conv1d -> LeakyReLU(0.2) -> ResBlock`` followed by a final
    ``Conv1d -> LeakyReLU(0.2)`` whose stride is 1, 2, 16, 4 and 8 respectively.

    ``forward`` evaluates only ``main``, ``main0``, ``main2`` and ``main3``.
    ``main1``, ``v1``/``v2``/``v3``, ``w_in``, ``style_layer`` and
    ``log_softmax`` are constructed but unused there; they are kept so that
    existing checkpoints still load.

    Args:
        args: config object providing ``vae_layer``, ``vae_length`` and
            ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQEncoderV9, self).__init__()
        n_down = args.vae_layer
        channels = [args.vae_length]
        for i in range(n_down - 1):
            channels.append(args.vae_length)

        input_size = args.vae_test_dim
        assert len(channels) == n_down

        # (kernel_size, padding, final_stride) per branch, listed in the order
        # main, main0, main1, main2, main3. Modules are created stage by stage
        # across branches so the construction order matches the original code.
        branch_specs = (
            (3, 1, 1),
            (5, 2, 2),
            (33, 16, 16),
            (9, 4, 4),
            (17, 8, 8),
        )
        branches = [[] for _ in branch_specs]

        for c_in, c_out in zip([input_size] + channels[:-1], channels):
            for stack, (kernel, pad, _) in zip(branches, branch_specs):
                # The ResBlock width follows the conv output width in every
                # branch (some branches previously hard-coded channels[0]).
                stack += [
                    nn.Conv1d(c_in, c_out, kernel, 1, pad),
                    nn.LeakyReLU(0.2, inplace=True),
                    ResBlock(c_out, kernel, 1, pad),
                ]

        for stack, (kernel, pad, final_stride) in zip(branches, branch_specs):
            stack += [
                nn.Conv1d(channels[-1], channels[-1], kernel, final_stride, pad),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers, layers_0, layers_1, layers_2, layers_3 = branches

        self.v1 = nn.Parameter(torch.empty(4096, input_size))
        self.v2 = nn.Parameter(torch.empty(4096, input_size))
        self.v3 = nn.Parameter(torch.empty(4096, input_size))
        self.w_in = nn.Linear(input_size, channels[0])
        self.style_layer = nn.Linear(4096, input_size)

        self.main = nn.Sequential(*layers)
        self.main1 = nn.Sequential(*layers_1)
        self.log_softmax = nn.LogSoftmax(dim=1)

        self.main2 = nn.Sequential(*layers_2)
        self.main3 = nn.Sequential(*layers_3)
        self.main0 = nn.Sequential(*layers_0)
        self.main.apply(init_weight)
        self.main1.apply(init_weight)
        self.main0.apply(init_weight)
        self.main2.apply(init_weight)
        self.main3.apply(init_weight)
        self.style_layer.apply(init_weight)
        # In-place variant; the un-suffixed ``xavier_uniform`` is deprecated.
        torch.nn.init.xavier_uniform_(self.v1)
        torch.nn.init.xavier_uniform_(self.v2)
        torch.nn.init.xavier_uniform_(self.v3)

    def forward(self, inputs):
        """Encode ``inputs`` with four of the branches.

        The last two axes of ``inputs`` are swapped before each branch and
        swapped back afterwards.

        Returns:
            ``[main, main0, main2, main3]`` outputs, i.e. the kernel-3,
            kernel-5, kernel-9 and kernel-17 branches in that order.
        """
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        outputs_0 = self.main0(inputs).permute(0, 2, 1)
        # main1 (kernel 33) is intentionally not evaluated here.
        outputs_2 = self.main2(inputs).permute(0, 2, 1)
        outputs_3 = self.main3(inputs).permute(0, 2, 1)
        return [outputs, outputs_0, outputs_2, outputs_3]


class VQEncoderV19(nn.Module):
    """Multi-branch Conv1d encoder with one branch per temporal kernel size.

    Five parallel stacks are built, with kernel sizes 3, 5, 33, 9 and 17 (stored
    as ``main``, ``main0``, ``main1``, ``main2``, ``main3``). Each stack has
    ``args.vae_layer`` stride-1 stages of
    ``Conv1d -> LeakyReLU(0.2) -> ResBlock`` followed by a final
    ``Conv1d -> LeakyReLU(0.2)`` whose stride is 1, 2, 16, 4 and 8 respectively.

    ``forward`` evaluates only ``main``, ``main0``, ``main2`` and ``main3``.
    ``main1``, ``v1``/``v2``/``v3``, ``w_in``, ``style_layer`` and
    ``log_softmax`` are constructed but unused there; they are kept so that
    existing checkpoints still load.

    Args:
        args: config object providing ``vae_layer``, ``vae_length`` and
            ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQEncoderV19, self).__init__()
        n_down = args.vae_layer
        channels = [args.vae_length]
        for i in range(n_down - 1):
            channels.append(args.vae_length)

        input_size = args.vae_test_dim
        assert len(channels) == n_down

        # (kernel_size, padding, final_stride) per branch, listed in the order
        # main, main0, main1, main2, main3. Modules are created stage by stage
        # across branches so the construction order matches the original code.
        branch_specs = (
            (3, 1, 1),
            (5, 2, 2),
            (33, 16, 16),
            (9, 4, 4),
            (17, 8, 8),
        )
        branches = [[] for _ in branch_specs]

        for c_in, c_out in zip([input_size] + channels[:-1], channels):
            for stack, (kernel, pad, _) in zip(branches, branch_specs):
                # The ResBlock width follows the conv output width in every
                # branch (some branches previously hard-coded channels[0]).
                stack += [
                    nn.Conv1d(c_in, c_out, kernel, 1, pad),
                    nn.LeakyReLU(0.2, inplace=True),
                    ResBlock(c_out, kernel, 1, pad),
                ]

        for stack, (kernel, pad, final_stride) in zip(branches, branch_specs):
            stack += [
                nn.Conv1d(channels[-1], channels[-1], kernel, final_stride, pad),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers, layers_0, layers_1, layers_2, layers_3 = branches

        self.v1 = nn.Parameter(torch.empty(4096, input_size))
        self.v2 = nn.Parameter(torch.empty(4096, input_size))
        self.v3 = nn.Parameter(torch.empty(4096, input_size))
        self.w_in = nn.Linear(input_size, channels[0])
        self.style_layer = nn.Linear(4096, input_size)

        self.main = nn.Sequential(*layers)
        self.main1 = nn.Sequential(*layers_1)
        self.log_softmax = nn.LogSoftmax(dim=1)

        self.main2 = nn.Sequential(*layers_2)
        self.main3 = nn.Sequential(*layers_3)
        self.main0 = nn.Sequential(*layers_0)
        self.main.apply(init_weight)
        self.main1.apply(init_weight)
        self.main0.apply(init_weight)
        self.main2.apply(init_weight)
        self.main3.apply(init_weight)
        self.style_layer.apply(init_weight)
        # In-place variant; the un-suffixed ``xavier_uniform`` is deprecated.
        torch.nn.init.xavier_uniform_(self.v1)
        torch.nn.init.xavier_uniform_(self.v2)
        torch.nn.init.xavier_uniform_(self.v3)

    def forward(self, inputs):
        """Encode ``inputs`` with four of the branches.

        The last two axes of ``inputs`` are swapped before each branch and
        swapped back afterwards.

        Returns:
            ``[main, main0, main2, main3]`` outputs, i.e. the kernel-3,
            kernel-5, kernel-9 and kernel-17 branches in that order.
        """
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        outputs_0 = self.main0(inputs).permute(0, 2, 1)
        # main1 (kernel 33) is intentionally not evaluated here.
        outputs_2 = self.main2(inputs).permute(0, 2, 1)
        outputs_3 = self.main3(inputs).permute(0, 2, 1)
        return [outputs, outputs_0, outputs_2, outputs_3]


class VQEncoderV11(nn.Module):
    """Fixed-size (256-channel, 2-stage) multi-branch encoder using ``EncResBlock``.

    Five parallel stacks are built, with kernel sizes 3, 5, 33, 9 and 17 (stored
    as ``main``, ``main0``, ``main1``, ``main2``, ``main3``). Each stack has two
    stride-1 stages of ``Conv1d -> LeakyReLU(0.2) -> EncResBlock`` followed by a
    final ``Conv1d -> LeakyReLU(0.2)`` whose stride is 1, 2, 16, 4 and 8
    respectively.

    ``forward`` evaluates ``main``, ``main0``, ``main2`` and ``main3`` and
    divides each result by its kernel size squared. ``main1`` and
    ``log_softmax`` are constructed but not used there.
    """

    def __init__(self):
        super().__init__()
        depth = 2
        widths = [256] + [256] * (depth - 1)
        in_dim = 256
        assert len(widths) == depth

        # (kernel_size, padding, final_stride) per branch, in the order
        # main, main0, main1, main2, main3. Modules are instantiated stage by
        # stage across the branches.
        specs = (
            (3, 1, 1),
            (5, 2, 2),
            (33, 16, 16),
            (9, 4, 4),
            (17, 8, 8),
        )
        stacks = [[] for _ in specs]

        for c_in, c_out in zip([in_dim] + widths[:-1], widths):
            for stack, (kernel, pad, _) in zip(stacks, specs):
                stack.extend([
                    nn.Conv1d(c_in, c_out, kernel, 1, pad),
                    nn.LeakyReLU(0.2, inplace=True),
                    EncResBlock(c_out, kernel, 1, pad),
                ])

        for stack, (kernel, pad, last_stride) in zip(stacks, specs):
            stack.extend([
                nn.Conv1d(widths[-1], widths[-1], kernel, last_stride, pad),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        k3, k5, k33, k9, k17 = stacks

        self.main = nn.Sequential(*k3)
        self.main1 = nn.Sequential(*k33)
        self.log_softmax = nn.LogSoftmax(dim=1)

        self.main2 = nn.Sequential(*k9)
        self.main3 = nn.Sequential(*k17)
        self.main0 = nn.Sequential(*k5)
        for branch in (self.main, self.main1, self.main0, self.main2, self.main3):
            branch.apply(init_weight)

    def forward(self, inputs):
        """Encode ``inputs`` with four branches, each scaled by 1 / kernel_size**2.

        The last two axes of ``inputs`` are swapped before each branch and
        swapped back afterwards.

        Returns:
            ``[main / 9, main0 / 25, main2 / 81, main3 / 289]``.
        """
        channels_first = inputs.permute(0, 2, 1)
        # (branch, kernel size) pairs; main1 (kernel 33) is not evaluated.
        scaled_branches = (
            (self.main, 3),
            (self.main0, 5),
            (self.main2, 9),
            (self.main3, 17),
        )
        return [
            branch(channels_first).permute(0, 2, 1) / (kernel * kernel)
            for branch, kernel in scaled_branches
        ]


class VQEncoderV10(nn.Module):
    """Stride-1 convolutional encoder built from five parallel kernel-size branches.

    Five ``nn.Sequential`` stacks are created with kernel sizes 3, 5, 33, 9
    and 17 (stored as ``main``, ``main0``, ``main1``, ``main2`` and ``main3``).
    Each stack has ``args.vae_layer`` stages of
    ``Conv1d -> LeakyReLU(0.2) -> ResBlock`` with padding ``kernel // 2``.

    Only ``main`` is evaluated by :meth:`forward`.  The remaining branches,
    ``v1``/``v2``/``v3``, ``w_in``, ``style_layer`` and ``log_softmax`` are
    still constructed, in the original order, so that the registered
    parameters (and therefore the ``state_dict`` keys) stay the same.

    Args:
        args: configuration object.  This class reads ``args.vae_layer``
            (number of stages), ``args.vae_length`` (channel width of every
            stage) and ``args.vae_test_dim`` (number of input channels).
    """

    def __init__(self, args):
        super(VQEncoderV10, self).__init__()
        n_down = args.vae_layer
        channels = [args.vae_length] + [args.vae_length] * (n_down - 1)
        input_size = args.vae_test_dim
        assert len(channels) == n_down

        # Kernel sizes in construction order: main, main0, main1, main2, main3.
        kernel_sizes = (3, 5, 33, 9, 17)
        branches = [[] for _ in kernel_sizes]

        # Build stage by stage (not branch by branch) so modules are created
        # in the same order as before and seeded runs see the same RNG stream.
        in_channels = input_size
        for out_channels in channels:
            for branch, kernel in zip(branches, kernel_sizes):
                padding = kernel // 2
                branch += [
                    nn.Conv1d(in_channels, out_channels, kernel, 1, padding),
                    nn.LeakyReLU(0.2, inplace=True),
                    # Every entry of ``channels`` is args.vae_length, so using
                    # the stage's own width here matches the previous
                    # ``channels[0]`` while staying consistent across branches.
                    ResBlock(out_channels, kernel, 1, padding),
                ]
            in_channels = out_channels
        layers, layers_0, layers_1, layers_2, layers_3 = branches

        # Not read by forward(); kept for checkpoint compatibility.
        self.v1 = nn.Parameter(torch.empty(4096, input_size))
        self.v2 = nn.Parameter(torch.empty(4096, input_size))
        self.v3 = nn.Parameter(torch.empty(4096, input_size))
        self.w_in = nn.Linear(input_size, channels[0])
        self.style_layer = nn.Linear(4096, input_size)

        self.main = nn.Sequential(*layers)
        self.main1 = nn.Sequential(*layers_1)
        self.log_softmax = nn.LogSoftmax(dim=1)

        self.main2 = nn.Sequential(*layers_2)
        self.main3 = nn.Sequential(*layers_3)
        self.main0 = nn.Sequential(*layers_0)

        for module in (self.main, self.main1, self.main0, self.main2,
                       self.main3, self.style_layer):
            module.apply(init_weight)
        # ``xavier_uniform`` (no trailing underscore) is deprecated; the
        # in-place ``xavier_uniform_`` is the supported spelling.
        for projection in (self.v1, self.v2, self.v3):
            nn.init.xavier_uniform_(projection)

    def forward(self, inputs):
        """Encode ``inputs`` with the kernel-size-3 branch.

        Args:
            inputs: tensor whose last axis has ``args.vae_test_dim`` entries;
                the last two axes are swapped before the ``Conv1d`` stack, so
                the layout is ``(batch, length, vae_test_dim)``.

        Returns:
            The output of ``self.main`` with the last two axes swapped back.
        """
        features = inputs.permute(0, 2, 1)
        return self.main(features).permute(0, 2, 1)


class VQDecoderV4(nn.Module):
    """Convolutional decoder that upsamples by 2 on every stage except the last.

    The stack is: an optional input ``Conv1d`` (only when the input width
    differs from the first hidden width), two ``ResBlock`` modules, then
    ``args.vae_layer`` stages of ``Upsample -> Conv1d(k=3) -> LeakyReLU(0.2)``
    and a final ``Conv1d(k=3)``.  The last stage maps ``args.vae_length``
    channels to ``args.vae_test_dim`` and uses an upsampling factor of 1.

    Args:
        args: configuration object providing ``vae_layer``, ``vae_length``
            and ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQDecoderV4, self).__init__()
        n_up = args.vae_layer
        hidden = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [hidden] * (n_up - 1) + [hidden, args.vae_test_dim]
        input_size = hidden
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = [] if input_size == channels[0] else [
            nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)
        ]
        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        for stage in range(n_up):
            scale = 2 if stage < n_up - 1 else 1
            layers.append(nn.Upsample(scale_factor=scale, mode='nearest'))
            layers.append(nn.Conv1d(channels[stage], channels[stage + 1],
                                    kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Decode ``inputs``; the last two axes are swapped around the conv stack."""
        channel_first = inputs.permute(0, 2, 1)
        return self.main(channel_first).permute(0, 2, 1)


class VQDecoderV5(nn.Module):
    """Convolutional decoder without any upsampling.

    The stack is: an optional input ``Conv1d`` (only when the input width
    differs from the first hidden width), two ``ResBlock`` modules,
    ``args.vae_layer`` stages of ``Conv1d(k=3, s=1, p=1) -> LeakyReLU(0.2)``
    and a final ``Conv1d(k=3)``.  The last stage maps ``args.vae_length``
    channels to ``args.vae_test_dim``.

    Args:
        args: configuration object providing ``vae_layer``, ``vae_length``
            and ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQDecoderV5, self).__init__()
        n_up = args.vae_layer
        hidden = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [hidden] * (n_up - 1) + [hidden, args.vae_test_dim]
        input_size = hidden
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = [] if input_size == channels[0] else [
            nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)
        ]
        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        for stage in range(n_up):
            layers.append(nn.Conv1d(channels[stage], channels[stage + 1],
                                    kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Decode ``inputs``; the last two axes are swapped around the conv stack."""
        channel_first = inputs.permute(0, 2, 1)
        return self.main(channel_first).permute(0, 2, 1)


class VQDecoderV8(nn.Module):
    """Convolutional decoder conditioned on an additive style embedding.

    ``forward`` projects ``cond_emb`` (last axis of size 4096) to
    ``args.vae_length`` channels with ``style_layer``, adds it to the input
    and runs the sum through ``main``: an optional input ``Conv1d``, two
    ``ResBlock`` modules, ``args.vae_layer`` stages of
    ``Conv1d(k=3) -> LeakyReLU(0.2)`` and a final ``Conv1d(k=3)``.

    ``motion_main`` and ``v1``/``v2``/``v3`` are constructed but not read by
    ``forward``; they are kept so the registered parameters (and the
    ``state_dict`` keys) are unchanged.

    Args:
        args: configuration object providing ``vae_layer``, ``vae_length``
            and ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQDecoderV8, self).__init__()
        n_up = args.vae_layer
        hidden = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [hidden] * (n_up - 1) + [hidden, args.vae_test_dim]
        input_size = hidden
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = [] if input_size == channels[0] else [
            nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)
        ]
        motion_layers = [
            nn.Conv1d(6, 10, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            ResBlock(10),
        ]
        layers += [ResBlock(channels[0]) for _ in range(n_resblk)]
        for i in range(n_up):
            layers += [
                nn.Conv1d(channels[i], channels[i + 1], kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)]

        self.v1 = nn.Parameter(torch.empty(4096, input_size))
        self.v2 = nn.Parameter(torch.empty(4096, input_size))
        self.v3 = nn.Parameter(torch.empty(4096, input_size))
        self.style_layer = nn.Linear(4096, input_size)
        self.style_layer.apply(init_weight)
        self.main = nn.Sequential(*layers)
        self.motion_main = nn.Sequential(*motion_layers)
        self.main.apply(init_weight)

        # v1/v2/v3 used to be left as uninitialised memory, so whatever bytes
        # happened to be there (possibly NaN/Inf) ended up in checkpoints.
        # Give them defined values.  Done last so the initialisation of the
        # layers above is unaffected under a fixed seed.
        for projection in (self.v1, self.v2, self.v3):
            nn.init.xavier_uniform_(projection)

    def forward(self, inputs, cond_emb):
        """Decode ``inputs`` after adding the projected conditioning embedding.

        Args:
            inputs: tensor whose last two axes are swapped before the conv
                stack; its last axis must have ``args.vae_length`` entries.
            cond_emb: tensor whose last axis has 4096 entries; after
                ``style_layer`` and the axis swap it must be broadcastable
                against the swapped ``inputs``.

        Returns:
            The output of ``self.main`` with the last two axes swapped back.
        """
        channel_first = inputs.permute(0, 2, 1)
        style = self.style_layer(cond_emb).permute(0, 2, 1)
        return self.main(channel_first + style).permute(0, 2, 1)


class AttentionWeightedFusion(nn.Module):
    """Fuse several same-shaped feature streams with a residual attention pooling.

    For every ``(batch, time)`` position the mean over the streams acts as the
    query, and the ``key``/``value`` projections of the individual streams are
    attended over.  The result is ``mean + softmax(q . k / sqrt(d)) @ v``.

    ``query``, ``out_proj``, ``layer_norm``, ``attention`` and ``alpha`` are
    registered but not read by :meth:`forward`; they are kept so the
    parameter set (and the ``state_dict`` keys) stays unchanged.

    Args:
        feature_dim: size of the last axis of every input stream.
        num_heads: stored as ``self.num_heads`` and used to derive
            ``self.head_dim``; :meth:`forward` does not split heads.
    """

    def __init__(self, feature_dim, num_heads=2):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads

        # Linear projections for Q, K, V
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)

        # Output projection
        self.out_proj = nn.Linear(feature_dim, feature_dim)

        # Layer normalization
        self.layer_norm = nn.LayerNorm(feature_dim)

        self.attention = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Linear(feature_dim // 2, 1),
            nn.Softmax(dim=2)
        )

        self.alpha = nn.Parameter(torch.ones(1) * 0.5)

    def forward(self, outputs_list):
        """Fuse the streams in ``outputs_list``.

        Args:
            outputs_list: sequence of tensors, each of shape
                ``[batch_size, seq_len, feature_dim]``.

        Returns:
            tuple: ``(fused, attn_weights)`` where ``fused`` has shape
            ``[batch_size, seq_len, feature_dim]`` and ``attn_weights`` has
            shape ``[batch_size, seq_len, num_streams]`` and sums to 1 over
            its last axis.
        """
        stacked = torch.stack(outputs_list, dim=2)  # [bs, seq_len, n_streams, feat_dim]

        query = stacked.mean(dim=-2)                # [bs, seq_len, feat_dim]
        keys = self.key(stacked)                    # [bs, seq_len, n_streams, feat_dim]
        values = self.value(stacked)                # [bs, seq_len, n_streams, feat_dim]

        # Scaled dot-product of the mean query against each stream's key.
        attn_scores = torch.matmul(query.unsqueeze(2), keys.transpose(-1, -2))  # [bs, seq_len, 1, n_streams]
        attn_weights = F.softmax(attn_scores / math.sqrt(query.shape[-1]), dim=-1)

        # Residual attention.  ``query`` is added as-is: it is already
        # [bs, seq_len, feat_dim], and squeezing axis 2 would drop the
        # feature axis when feat_dim == 1 and make the sum broadcast to
        # [bs, seq_len, seq_len].
        weighted = query + torch.matmul(attn_weights, values).squeeze(2)

        return weighted, attn_weights.squeeze(2)


class VQDecoderV9(nn.Module):
    """Multi-branch transposed-conv decoder whose outputs are fused by attention.

    Five ``nn.Sequential`` stacks are built, one per (kernel, stride) pair:

    ========  ======  ======  ==============
    attribute kernel  stride  output_padding
    ========  ======  ======  ==============
    main      3       1       0
    main0     5       2       1
    main1     33      16      15
    main2     9       4       3
    main3     17      8       7
    ========  ======  ======  ==============

    Each stack starts with one ``ConvTranspose1d`` using the stride above
    (padding ``kernel // 2``), followed by ``args.vae_layer`` stride-1 stages
    of ``ConvTranspose1d -> LeakyReLU(0.2)`` and a final stride-1
    ``ConvTranspose1d``.  With ``output_padding = stride - 1`` the first layer
    multiplies the sequence length by exactly ``stride``.

    ``forward`` evaluates ``main``, ``main0``, ``main2`` and ``main3`` and
    fuses them with ``AttentionWeightedFusion``.  ``main1``, ``w_in`` and
    ``log_softmax`` are constructed but not read there; they are kept so the
    registered parameters are unchanged.

    Args:
        args: configuration object providing ``vae_layer``, ``vae_length``
            and ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQDecoderV9, self).__init__()
        n_up = args.vae_layer
        hidden = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [hidden] * (n_up - 1) + [hidden, args.vae_test_dim]
        input_size = hidden
        assert len(channels) == n_up + 1

        # (kernel, stride, output_padding) in construction order:
        # main, main0, main1, main2, main3.
        # main1 previously used output_padding=1, which yields 16*L - 14
        # samples instead of 16*L; stride - 1 matches the other branches.
        branch_specs = ((3, 1, 0), (5, 2, 1), (33, 16, 15), (9, 4, 3), (17, 8, 7))
        stacks = [[] for _ in branch_specs]

        # Layers are created stage by stage across all branches so the module
        # construction order (and the RNG stream under a fixed seed) is the
        # same as before.
        if input_size != channels[0]:
            for stack, (kernel, _, _) in zip(stacks, branch_specs):
                stack += [
                    nn.Conv1d(input_size, channels[0], kernel_size=kernel, stride=1, padding=kernel // 2),
                    nn.LeakyReLU(0.2, inplace=True),
                ]

        for stack, (kernel, stride, out_pad) in zip(stacks, branch_specs):
            stack.append(nn.ConvTranspose1d(input_size, input_size, kernel_size=kernel, stride=stride,
                                            padding=kernel // 2, output_padding=out_pad))

        for i in range(n_up):
            for stack, (kernel, _, _) in zip(stacks, branch_specs):
                stack += [
                    nn.ConvTranspose1d(channels[i], channels[i + 1], kernel_size=kernel, stride=1,
                                       padding=kernel // 2),
                    nn.LeakyReLU(0.2, inplace=True),
                ]

        for stack, (kernel, _, _) in zip(stacks, branch_specs):
            stack.append(nn.ConvTranspose1d(channels[-1], channels[-1], kernel_size=kernel, stride=1,
                                            padding=kernel // 2))

        layers, layers_0, layers_1, layers_2, layers_3 = stacks

        self.attention_fusion = AttentionWeightedFusion(feature_dim=channels[-1])

        self.main = nn.Sequential(*layers)
        self.main0 = nn.Sequential(*layers_0)
        self.main1 = nn.Sequential(*layers_1)
        self.w_in = nn.Linear(channels[0], channels[-1])
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.main.apply(init_weight)
        self.main0.apply(init_weight)
        self.main1.apply(init_weight)
        self.main2 = nn.Sequential(*layers_2)
        self.main3 = nn.Sequential(*layers_3)
        self.main2.apply(init_weight)
        self.main3.apply(init_weight)

    def forward(self, inputs):
        """Decode four streams and fuse them.

        Args:
            inputs: indexable of at least four tensors.  ``inputs[0]``,
                ``inputs[1]``, ``inputs[2]`` and ``inputs[3]`` are decoded by
                ``main`` (x1), ``main0`` (x2), ``main2`` (x4) and ``main3``
                (x8) respectively.  The four decoded tensors are stacked by
                the fusion module, so they must end up with identical shapes.

        Returns:
            The fused tensor returned by ``self.attention_fusion``; the
            attention weights it also returns are discarded.
        """
        decoded = [
            branch(inputs[idx].permute(0, 2, 1)).permute(0, 2, 1)
            for idx, branch in enumerate((self.main, self.main0, self.main2, self.main3))
        ]
        outputs_f, _attn_weights = self.attention_fusion(decoded)
        return outputs_f


class VQDecoderV19(nn.Module):
    """Multi-branch transposed-conv decoder whose outputs are averaged.

    Five ``nn.Sequential`` stacks are built, one per (kernel, stride) pair:

    ========  ======  ======  ==============
    attribute kernel  stride  output_padding
    ========  ======  ======  ==============
    main      3       1       0
    main0     5       2       1
    main1     33      16      15
    main2     9       4       3
    main3     17      8       7
    ========  ======  ======  ==============

    Each stack starts with one ``ConvTranspose1d`` using the stride above
    (padding ``kernel // 2``), followed by ``args.vae_layer`` stride-1 stages
    of ``ConvTranspose1d -> LeakyReLU(0.2)`` and a final stride-1
    ``ConvTranspose1d``.  With ``output_padding = stride - 1`` the first layer
    multiplies the sequence length by exactly ``stride``.

    ``forward`` evaluates ``main``, ``main0``, ``main2`` and ``main3`` and
    returns their arithmetic mean.  ``main1``, ``w_in`` and ``log_softmax``
    are constructed but not read there; they are kept so the registered
    parameters are unchanged.

    Args:
        args: configuration object providing ``vae_layer``, ``vae_length``
            and ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQDecoderV19, self).__init__()
        n_up = args.vae_layer
        hidden = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [hidden] * (n_up - 1) + [hidden, args.vae_test_dim]
        input_size = hidden
        assert len(channels) == n_up + 1

        # (kernel, stride, output_padding) in construction order:
        # main, main0, main1, main2, main3.
        # main1 previously used output_padding=1, which yields 16*L - 14
        # samples instead of 16*L; stride - 1 matches the other branches.
        branch_specs = ((3, 1, 0), (5, 2, 1), (33, 16, 15), (9, 4, 3), (17, 8, 7))
        stacks = [[] for _ in branch_specs]

        # Layers are created stage by stage across all branches so the module
        # construction order (and the RNG stream under a fixed seed) is the
        # same as before.
        if input_size != channels[0]:
            for stack, (kernel, _, _) in zip(stacks, branch_specs):
                stack += [
                    nn.Conv1d(input_size, channels[0], kernel_size=kernel, stride=1, padding=kernel // 2),
                    nn.LeakyReLU(0.2, inplace=True),
                ]

        for stack, (kernel, stride, out_pad) in zip(stacks, branch_specs):
            stack.append(nn.ConvTranspose1d(input_size, input_size, kernel_size=kernel, stride=stride,
                                            padding=kernel // 2, output_padding=out_pad))

        for i in range(n_up):
            for stack, (kernel, _, _) in zip(stacks, branch_specs):
                stack += [
                    nn.ConvTranspose1d(channels[i], channels[i + 1], kernel_size=kernel, stride=1,
                                       padding=kernel // 2),
                    nn.LeakyReLU(0.2, inplace=True),
                ]

        for stack, (kernel, _, _) in zip(stacks, branch_specs):
            stack.append(nn.ConvTranspose1d(channels[-1], channels[-1], kernel_size=kernel, stride=1,
                                            padding=kernel // 2))

        layers, layers_0, layers_1, layers_2, layers_3 = stacks

        self.main = nn.Sequential(*layers)
        self.main0 = nn.Sequential(*layers_0)
        self.main1 = nn.Sequential(*layers_1)
        self.w_in = nn.Linear(channels[0], channels[-1])
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.main.apply(init_weight)
        self.main0.apply(init_weight)
        self.main1.apply(init_weight)
        self.main2 = nn.Sequential(*layers_2)
        self.main3 = nn.Sequential(*layers_3)
        self.main2.apply(init_weight)
        self.main3.apply(init_weight)

    def forward(self, inputs):
        """Decode four streams and average them.

        Args:
            inputs: indexable of at least four tensors.  ``inputs[0]``,
                ``inputs[1]``, ``inputs[2]`` and ``inputs[3]`` are decoded by
                ``main`` (x1), ``main0`` (x2), ``main2`` (x4) and ``main3``
                (x8) respectively.  The decoded tensors are summed, so their
                shapes must be compatible.

        Returns:
            The mean of the four decoded tensors.
        """
        outputs = self.main(inputs[0].permute(0, 2, 1)).permute(0, 2, 1)
        outputs0 = self.main0(inputs[1].permute(0, 2, 1)).permute(0, 2, 1)
        outputs2 = self.main2(inputs[2].permute(0, 2, 1)).permute(0, 2, 1)
        outputs3 = self.main3(inputs[3].permute(0, 2, 1)).permute(0, 2, 1)
        return (outputs + outputs0 + outputs2 + outputs3) / 4


class VQDecoderV7(nn.Module):
    """Convolutional decoder conditioned through a self-attention context.

    ``forward`` projects ``cond_emb`` (last axis of size 4096) with the
    learned matrices ``v1`` (query), ``v2`` (key) and ``v3`` (value), computes
    ``softmax(Q K^T) V`` and adds that context to the input before running
    ``main``: an optional input ``Conv1d``, two ``ResBlock`` modules,
    ``args.vae_layer`` stages of ``Conv1d(k=3) -> LeakyReLU(0.2)`` and a final
    ``Conv1d(k=3)``.  The last stage outputs ``args.vae_test_dim + 4``
    channels.

    Args:
        args: configuration object providing ``vae_layer``, ``vae_length``
            and ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQDecoderV7, self).__init__()
        n_up = args.vae_layer
        hidden = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [hidden] * (n_up - 1) + [hidden, args.vae_test_dim + 4]
        input_size = hidden
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = [] if input_size == channels[0] else [
            nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)
        ]
        layers += [ResBlock(channels[0]) for _ in range(n_resblk)]
        for i in range(n_up):
            layers += [
                nn.Conv1d(channels[i], channels[i + 1], kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)]

        self.v1 = nn.Parameter(torch.empty(4096, input_size))
        self.v2 = nn.Parameter(torch.empty(4096, input_size))
        self.v3 = nn.Parameter(torch.empty(4096, input_size))
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

        # v1/v2/v3 are read in forward() but were previously left as
        # uninitialised memory, so the attention context started from
        # arbitrary (possibly NaN/Inf) values.  Initialise them explicitly.
        # Done last so the layers above initialise identically under a seed.
        for projection in (self.v1, self.v2, self.v3):
            nn.init.xavier_uniform_(projection)

    def forward(self, inputs, cond_emb):
        """Decode ``inputs`` after adding an attention context built from ``cond_emb``.

        Args:
            inputs: tensor whose last two axes are swapped before the conv
                stack; its last axis must have ``args.vae_length`` entries.
            cond_emb: tensor whose last axis has 4096 entries.  The context
                derived from it is axis-swapped and added to the swapped
                ``inputs``, so the two must be broadcast-compatible.

        Returns:
            The output of ``self.main`` with the last two axes swapped back.
        """
        channel_first = inputs.permute(0, 2, 1)

        query = torch.matmul(cond_emb, self.v1)
        key = torch.matmul(cond_emb, self.v2)
        value = torch.matmul(cond_emb, self.v3)

        # Unscaled dot-product attention over the conditioning sequence.
        attention_probs = F.softmax(torch.matmul(query, key.transpose(-1, -2)), dim=-1)
        context = torch.matmul(attention_probs, value).permute(0, 2, 1)

        return self.main(channel_first + context).permute(0, 2, 1)


class VQDecoderV3(nn.Module):
    """Convolutional decoder that upsamples by 2 on every stage.

    The stack is: an optional input ``Conv1d`` (only when the input width
    differs from the first hidden width), two ``ResBlock`` modules, then
    ``args.vae_layer`` stages of
    ``Upsample(x2, nearest) -> Conv1d(k=3) -> LeakyReLU(0.2)`` and a final
    ``Conv1d(k=3)``.  The last stage maps ``args.vae_length`` channels to
    ``args.vae_test_dim``.

    Args:
        args: configuration object providing ``vae_layer``, ``vae_length``
            and ``vae_test_dim``.
    """

    def __init__(self, args):
        super(VQDecoderV3, self).__init__()
        n_up = args.vae_layer
        hidden = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [hidden] * (n_up - 1) + [hidden, args.vae_test_dim]
        input_size = hidden
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = [] if input_size == channels[0] else [
            nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)
        ]
        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        for stage in range(n_up):
            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
            layers.append(nn.Conv1d(channels[stage], channels[stage + 1],
                                    kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Decode ``inputs``; the last two axes are swapped around the conv stack."""
        channel_first = inputs.permute(0, 2, 1)
        return self.main(channel_first).permute(0, 2, 1)


class VQDecoderV6(nn.Module):
    """Convolutional decoder without temporal upsampling.

    The input is expected to carry ``2 * args.vae_length`` channels, which a
    first convolution projects down to ``args.vae_length``.

    Configuration read from ``args``:
        vae_layer: number of conv + activation stages.
        vae_length: channel width of the hidden stages.
        vae_test_dim: channel width of the decoded output.
    """

    def __init__(self, args):
        super(VQDecoderV6, self).__init__()
        num_stages = args.vae_layer
        input_size = args.vae_length * 2
        # Hidden width feeds every stage; only the last stage emits vae_test_dim.
        widths = [args.vae_length] * (num_stages - 1) + [args.vae_length, args.vae_test_dim]
        assert len(widths) == num_stages + 1

        num_resblocks = 2
        stack = []
        if input_size != widths[0]:
            # Project the doubled-width input down to the hidden width.
            stack.append(nn.Conv1d(input_size, widths[0], kernel_size=3, stride=1, padding=1))

        for _ in range(num_resblocks):
            stack.append(ResBlock(widths[0]))

        for stage in range(num_stages):
            # No nn.Upsample here: the sequence length is left unchanged.
            stack.append(nn.Conv1d(widths[stage], widths[stage + 1], kernel_size=3, stride=1, padding=1))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Final width-preserving convolution on the output channels.
        stack.append(nn.Conv1d(widths[-1], widths[-1], kernel_size=3, stride=1, padding=1))

        self.main = nn.Sequential(*stack)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Run the decoder; the last two axes are swapped before and after ``main``."""
        # NOTE(review): presumably inputs is (batch, time, channels) - confirm.
        channel_first = inputs.permute(0, 2, 1)
        return self.main(channel_first).permute(0, 2, 1)


# -----------2 conv+mlp based fixed-length input ae ------------- #
from .utils.layer import reparameterize, ConvNormRelu, BasicBlock

"""
Adapted from Trimodal.
encoder:
    bs, n, c_in --conv--> bs, n/k, c_out_0 --mlp--> bs, c_out_1 (only supports fixed-length input)
decoder:
    bs, c_out_1 --mlp--> bs, n/k*c_out_0 --> bs, n/k, c_out_0 --deconv--> bs, n, c_in
"""


class PoseEncoderConv(nn.Module):
    """Conv + MLP encoder mapping a pose sequence to one latent vector.

    Args:
        length: accepted for interface compatibility; not read by this class.
        dim: number of channels per pose frame.
        feature_length: base channel width, also the size of the latent code.
    """

    def __init__(self, length, dim, feature_length=32):
        super().__init__()
        self.base = feature_length
        width = self.base

        # The trailing numbers are the original author's size notes.
        conv_stack = [
            ConvNormRelu(dim, width, batchnorm=True),  # 32
            ConvNormRelu(width, width * 2, batchnorm=True),  # 30
            ConvNormRelu(width * 2, width * 2, True, batchnorm=True),  # 14
            nn.Conv1d(width * 2, width, 3),
        ]
        self.net = nn.Sequential(*conv_stack)

        # NOTE: per the nn.LeakyReLU signature the positional ``True`` binds to
        # negative_slope (i.e. 1.0), not to inplace. Kept unchanged so that the
        # numerics of existing checkpoints are preserved.
        mlp_stack = [
            nn.Linear(12 * width, width * 4),  # for 34 frames
            nn.BatchNorm1d(width * 4),
            nn.LeakyReLU(True),
            nn.Linear(width * 4, width * 2),
            nn.BatchNorm1d(width * 2),
            nn.LeakyReLU(True),
            nn.Linear(width * 2, width),
        ]
        self.out_net = nn.Sequential(*mlp_stack)

        self.fc_mu = nn.Linear(width, width)
        self.fc_logvar = nn.Linear(width, width)

    def forward(self, poses, variational_encoding=None):
        """Encode ``poses`` and return ``(z, mu, logvar)``.

        ``z`` is sampled through ``reparameterize`` when ``variational_encoding``
        is truthy, otherwise it is ``mu`` itself.
        """
        feats = self.net(poses.transpose(1, 2))  # to (bs, dim, seq)
        feats = self.out_net(feats.flatten(1))
        mu = self.fc_mu(feats)
        logvar = self.fc_logvar(feats)
        z = reparameterize(mu, logvar) if variational_encoding else mu
        return z, mu, logvar


class PoseDecoderFC(nn.Module):
    """Fully connected decoder from a 32-d latent code to a pose sequence.

    Args:
        gen_length: number of frames to generate.
        pose_dim: number of channels per generated frame.
        use_pre_poses: when True, a 32-d embedding of four preceding poses is
            concatenated in front of the latent code.
    """

    def __init__(self, gen_length, pose_dim, use_pre_poses=False):
        super().__init__()
        self.gen_length = gen_length
        self.pose_dim = pose_dim
        self.use_pre_poses = use_pre_poses

        latent_width = 32
        if use_pre_poses:
            self.pre_pose_net = nn.Sequential(
                nn.Linear(pose_dim * 4, 32),
                nn.BatchNorm1d(32),
                nn.ReLU(),
                nn.Linear(32, 32),
            )
        in_size = latent_width + 32 if use_pre_poses else latent_width

        # Four Linear -> BatchNorm -> ReLU stages followed by the output layer.
        widths = [in_size, 128, 128, 256, 512]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.BatchNorm1d(fan_out))
            stages.append(nn.ReLU())
        stages.append(nn.Linear(widths[-1], gen_length * pose_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, latent_code, pre_poses=None):
        """Return poses of shape ``(-1, gen_length, pose_dim)``.

        ``pre_poses`` is only read when the module was built with
        ``use_pre_poses=True``; it is flattened per sample before embedding.
        """
        feat = latent_code
        if self.use_pre_poses:
            flat_pre = pre_poses.reshape(pre_poses.shape[0], -1)
            feat = torch.cat((self.pre_pose_net(flat_pre), latent_code), dim=1)
        decoded = self.net(feat)
        return decoded.view(-1, self.gen_length, self.pose_dim)


class PoseDecoderConv(nn.Module):
    """MLP + (transposed) convolution decoder from a latent code to poses.

    Args:
        length: number of frames to generate; only 64, 34 and 32 are supported,
            any other value trips an assertion.
        dim: number of channels per generated frame.
        use_pre_poses: when True, a 32-d embedding of four preceding poses is
            concatenated in front of the latent code.
        feature_length: width of the incoming latent code.
    """

    def __init__(self, length, dim, use_pre_poses=False, feature_length=32):
        super().__init__()
        self.use_pre_poses = use_pre_poses
        self.feat_size = feature_length

        if use_pre_poses:
            self.pre_pose_net = nn.Sequential(
                nn.Linear(dim * 4, 32),
                nn.BatchNorm1d(32),
                nn.ReLU(),
                nn.Linear(32, 32),
            )
            self.feat_size += 32

        # (supported length, hidden-width multiplier of the expanding MLP)
        for frames, scale in ((64, 1), (34, 2), (32, 2)):
            if length == frames:
                hidden = self.feat_size * scale
                # NOTE: per the nn.LeakyReLU signature the positional ``True``
                # binds to negative_slope (i.e. 1.0), not to inplace. Kept
                # unchanged so the numerics of existing checkpoints match.
                self.pre_net = nn.Sequential(
                    nn.Linear(self.feat_size, hidden),
                    nn.BatchNorm1d(hidden),
                    nn.LeakyReLU(True),
                    nn.Linear(hidden, self.feat_size // 8 * frames),
                )
                break
        else:
            assert False

        self.decoder_size = self.feat_size // 8
        wide = self.feat_size
        # Two kernel-3 transposed convs followed by two unpadded kernel-3 convs.
        self.net = nn.Sequential(
            nn.ConvTranspose1d(self.decoder_size, wide, 3),
            nn.BatchNorm1d(wide),
            nn.LeakyReLU(0.2, True),

            nn.ConvTranspose1d(wide, wide, 3),
            nn.BatchNorm1d(wide),
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(wide, wide * 2, 3),
            nn.Conv1d(wide * 2, dim, 3),
        )

    def forward(self, feat, pre_poses=None):
        """Decode ``feat`` and return the result with its last two axes swapped.

        ``pre_poses`` is only read when the module was built with
        ``use_pre_poses=True``; it is flattened per sample before embedding.
        """
        if self.use_pre_poses:
            flat_pre = pre_poses.reshape(pre_poses.shape[0], -1)
            feat = torch.cat((self.pre_pose_net(flat_pre), feat), dim=1)
        # Expand to decoder_size channels; the remaining axis is the sequence.
        seed = self.pre_net(feat).view(feat.shape[0], self.decoder_size, -1)
        return self.net(seed).transpose(1, 2)


'''
Our CaMN Modification
'''


class PoseEncoderConvResNet(nn.Module):
    """Residual-block + MLP encoder mapping a pose sequence to one latent vector.

    Args:
        length: accepted for interface compatibility; not read by this class.
        dim: number of channels per pose frame.
        feature_length: base channel width, also the size of the latent code.
    """

    def __init__(self, length, dim, feature_length=32):
        super().__init__()
        self.base = feature_length
        width = self.base

        # The trailing numbers are the original author's size notes.
        self.conv1 = BasicBlock(dim, width, reduce_first=1, downsample=False, first_dilation=1)  # 34
        self.conv2 = BasicBlock(width, width * 2, downsample=False, first_dilation=1)  # 34
        self.conv3 = BasicBlock(width * 2, width * 2, first_dilation=1, downsample=True, stride=2)  # 17
        self.conv4 = BasicBlock(width * 2, width, first_dilation=1, downsample=False)

        # NOTE: per the nn.LeakyReLU signature the positional ``True`` binds to
        # negative_slope (i.e. 1.0), not to inplace. Kept unchanged so that the
        # numerics of existing checkpoints are preserved.
        mlp_stack = [
            nn.Linear(17 * width, width * 4),  # for 34 frames
            nn.BatchNorm1d(width * 4),
            nn.LeakyReLU(True),
            nn.Linear(width * 4, width * 2),
            nn.BatchNorm1d(width * 2),
            nn.LeakyReLU(True),
            nn.Linear(width * 2, width),
        ]
        self.out_net = nn.Sequential(*mlp_stack)

        self.fc_mu = nn.Linear(width, width)
        self.fc_logvar = nn.Linear(width, width)

    def forward(self, poses, variational_encoding=None):
        """Encode ``poses`` and return ``(z, mu, logvar)``.

        ``z`` is sampled through ``reparameterize`` when ``variational_encoding``
        is truthy, otherwise it is ``mu`` itself.
        """
        feats = poses.transpose(1, 2)  # to (bs, dim, seq)
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats = block(feats)
        feats = self.out_net(feats.flatten(1))
        mu = self.fc_mu(feats)
        logvar = self.fc_logvar(feats)
        z = reparameterize(mu, logvar) if variational_encoding else mu
        return z, mu, logvar

    # -----------3 lstm ------------- #


'''
bs, n, c_in --> bs, n, c_out or bs, 1 (hidden), c_out
'''


class AELSTM(nn.Module):
    """Bidirectional 4-layer LSTM autoencoder over pose sequences.

    Configuration read from ``args``:
        vae_test_dim: channels per input/output pose frame.
        vae_length: embedding width and LSTM hidden size.
    """

    def __init__(self, args):
        super().__init__()
        pose_dim = args.vae_test_dim
        width = args.vae_length

        self.motion_emb = nn.Linear(pose_dim, width)
        self.lstm = nn.LSTM(
            width,
            hidden_size=width,
            num_layers=4,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        self.out = nn.Sequential(
            nn.Linear(width, width // 2),
            nn.LeakyReLU(0.2, True),
            nn.Linear(width // 2, pose_dim),
        )
        self.hidden_size = width

    def forward(self, inputs):
        """Return a dict with the summed LSTM features and the reconstruction.

        Keys:
            "poses_feat": forward and backward LSTM outputs added together.
            "rec_pose": ``poses_feat`` projected back to ``vae_test_dim``.
        """
        embedded = self.motion_emb(inputs)
        both_dirs, _ = self.lstm(embedded)
        # The bidirectional output holds the two directions side by side on the
        # last axis; split it in half and sum the halves.
        first_half, second_half = both_dirs.chunk(2, dim=-1)
        merged = first_half + second_half
        return {
            "poses_feat": merged,
            "rec_pose": self.out(merged),
        }


class PoseDecoderLSTM(nn.Module):
    """
    Bidirectional 4-layer LSTM decoder from per-frame latent codes to poses.

    input bs*n*feature_length (the original note says bs*n*64)
    """

    def __init__(self, pose_dim, feature_length):
        super().__init__()
        self.pose_dim = pose_dim
        self.base = feature_length
        self.hidden_size = 256
        hidden = self.hidden_size

        self.lstm_d = nn.LSTM(
            self.base,
            hidden_size=hidden,
            num_layers=4,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        # NOTE: per the nn.LeakyReLU signature the positional ``True`` binds to
        # negative_slope (i.e. 1.0), not to inplace. Kept unchanged so that the
        # numerics of existing checkpoints are preserved.
        self.out_d = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.LeakyReLU(True),
            nn.Linear(hidden // 2, self.pose_dim),
        )

    def forward(self, latent_code):
        """Decode ``latent_code`` into one pose vector per time step."""
        batch, steps = latent_code.shape[0], latent_code.shape[1]
        both_dirs, _ = self.lstm_d(latent_code)
        # sum bidirectional outputs (the two halves of the last axis)
        first_half, second_half = both_dirs.chunk(2, dim=-1)
        merged = first_half + second_half
        # Apply the head to every frame independently, then restore the layout.
        per_frame = self.out_d(merged.reshape(-1, merged.shape[2]))
        return per_frame.view(batch, steps, -1)


# ---------------4 transformer --------------- #
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a batch-first sequence.

    The table is stored as the buffer ``pe`` with shape ``(1, max_len, d_model)``:
    even feature columns hold sines and odd columns hold cosines.

    Args:
        d_model: feature width of the sequences to encode (odd widths are
            supported; the last column is then a sine column).
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd-index column than even-index ones,
        # so trim the angles to avoid a shape mismatch (no-op for even widths).
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0)  # leading batch axis for broadcasting

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.shape[1]])``; axis 1 of ``x`` indexes positions."""
        x = x + self.pe[:, :x.shape[1]]
        return self.dropout(x)


class Encoder_TRANSFORMER(nn.Module):
    """Transformer encoder over pose sequences (batch-first).

    Configuration read from ``args``:
        vae_test_dim: channels per input pose frame.
        vae_length: model width of the transformer.
    """

    def __init__(self, args):
        super().__init__()
        width = args.vae_length
        self.skelEmbedding = nn.Linear(args.vae_test_dim, width)
        self.sequence_pos_encoder = PositionalEncoding(width, 0.3)

        # NOTE(review): dim_feedforward is 1025 here while Decoder_TRANSFORMER
        # uses 1024; possibly a typo, but it fixes parameter shapes, so it is
        # kept unchanged.
        layer_options = dict(
            d_model=width,
            nhead=4,
            dim_feedforward=1025,
            dropout=0.3,
            activation="gelu",
            batch_first=True,
        )
        encoder_layer = nn.TransformerEncoderLayer(**layer_options)
        self.seqTransEncoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

    def _generate_square_subsequent_mask(self, sz):
        """Return an ``sz x sz`` float mask: 0.0 on/below the diagonal, -inf above.

        Not called by ``forward`` in this class.
        """
        lower = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = lower.float()
        mask = mask.masked_fill(lower == 0, float('-inf'))
        mask = mask.masked_fill(lower == 1, float(0.0))
        return mask

    def forward(self, inputs):
        """Encode ``inputs`` and return ``(final, mu, logvar)``.

        ``mu`` and ``logvar`` are the encoder outputs at sequence positions 0
        and 1 respectively, each kept as a length-1 slice of axis 1.
        """
        embedded = self.skelEmbedding(inputs)  # bs * n * vae_length
        final = self.seqTransEncoder(self.sequence_pos_encoder(embedded))
        mu, logvar = final[:, 0:1, :], final[:, 1:2, :]
        return final, mu, logvar


class Decoder_TRANSFORMER(nn.Module):
    """Transformer decoder that attends to a memory using positional queries.

    Configuration read from ``args``:
        vae_test_len: number of time steps to generate.
        vae_length: model width of the transformer.
        vae_test_dim: channels per generated pose frame.
    """

    def __init__(self, args):
        super().__init__()
        self.vae_test_len = args.vae_test_len
        self.vae_length = args.vae_length
        width = args.vae_length

        self.sequence_pos_encoder = PositionalEncoding(width, 0.3)
        layer_options = dict(
            d_model=width,
            nhead=4,
            dim_feedforward=1024,
            dropout=0.3,
            activation="gelu",
            batch_first=True,
        )
        decoder_layer = nn.TransformerDecoderLayer(**layer_options)
        self.seqTransDecoder = nn.TransformerDecoder(decoder_layer, num_layers=4)
        self.finallayer = nn.Linear(width, args.vae_test_dim)

    def forward(self, inputs):
        """Decode ``vae_test_len`` frames per sample, using ``inputs`` as memory."""
        # Queries start as zeros, so they carry only the positional encoding.
        query_shape = (inputs.shape[0], self.vae_test_len, self.vae_length)
        queries = self.sequence_pos_encoder(torch.zeros(*query_shape, device=inputs.device))
        decoded = self.seqTransDecoder(tgt=queries, memory=inputs)
        return self.finallayer(decoded)

    # --------- 5 skcnn --------------- #


'''
from NeMF,
NeMF: Neural Motion Fields for Kinematic Animation
'''
from .utils.skeleton import ResidualBlock, SkeletonResidual, residual_ratio, SkeletonConv, SkeletonPool, find_neighbor, \
    build_edge_topology


class LocalEncoder(nn.Module):
    """Stack of skeleton-aware layers built from a skeleton ``topology``.

    One layer is created per ``args.vae_layer``. Each layer is followed by a
    new (pooled) topology, which is recorded in ``self.topologies`` and used
    to build the next layer.

    NOTE: the constructor overwrites several hyper-parameters on the ``args``
    object passed in (see the assignments at the top of ``__init__``), so the
    caller's ``args`` is mutated.
    """

    def __init__(self, args, topology):
        """Build the layer stack.

        Args:
            args: configuration object; ``vae_layer`` and ``vae_grow`` are read
                from it, the remaining fields used here are overwritten below.
                ``vae_grow`` is indexed once per layer, so it needs at least
                ``vae_layer`` entries.
            topology: edge list of the input skeleton; ``len(topology)`` is
                taken as the initial number of edges.
        """
        super(LocalEncoder, self).__init__()
        # Hard-coded hyper-parameters, written onto the caller's args object.
        args.channel_base = 6
        args.activation = "tanh"
        args.use_residual_blocks = True
        args.z_dim = 1024
        args.temporal_scale = 8
        args.kernel_size = 4
        args.num_layers = args.vae_layer
        args.skeleton_dist = 2
        args.extra_conv = 0
        # check how to reflect in 1d
        args.padding_mode = "constant"
        args.skeleton_pool = "mean"
        args.upsampling = "linear"

        # Per-layer bookkeeping; index i describes the input of layer i.
        self.topologies = [topology]
        self.channel_base = [args.channel_base]

        self.channel_list = []
        self.edge_num = [len(topology)]
        self.pooling_list = []
        self.layers = nn.ModuleList()
        self.args = args
        # self.convs = []

        kernel_size = args.kernel_size
        kernel_even = False if kernel_size % 2 else True
        padding = (kernel_size - 1) // 2
        bias = True
        self.grow = args.vae_grow
        # Channels per edge after each layer: previous value times grow[i].
        for i in range(args.num_layers):
            self.channel_base.append(self.channel_base[-1] * self.grow[i])

        for i in range(args.num_layers):
            seq = []
            neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist)
            in_channels = self.channel_base[i] * self.edge_num[i]
            out_channels = self.channel_base[i + 1] * self.edge_num[i]
            if i == 0:
                self.channel_list.append(in_channels)
            self.channel_list.append(out_channels)
            last_pool = True if i == args.num_layers - 1 else False

            # (T, J, D) => (T, J', D)
            # The pool is created for every layer, even on the residual path
            # where it is not appended to seq: its new_edges / pooling_list
            # are recorded below and define the next layer's topology.
            pool = SkeletonPool(edges=self.topologies[i], pooling_mode=args.skeleton_pool,
                                channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool)

            # args.use_residual_blocks is forced to True above, so as written
            # only the first branch is taken.
            if args.use_residual_blocks:
                # (T, J, D) => (T/2, J', 2D)
                seq.append(SkeletonResidual(self.topologies[i], neighbour_list, joint_num=self.edge_num[i],
                                            in_channels=in_channels, out_channels=out_channels,
                                            kernel_size=kernel_size, stride=2, padding=padding,
                                            padding_mode=args.padding_mode, bias=bias,
                                            extra_conv=args.extra_conv, pooling_mode=args.skeleton_pool,
                                            activation=args.activation, last_pool=last_pool))
            else:
                for _ in range(args.extra_conv):
                    # (T, J, D) => (T, J, D)
                    seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
                                            joint_num=self.edge_num[i],
                                            kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                                            stride=1,
                                            padding=padding, padding_mode=args.padding_mode, bias=bias))
                    seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh())
                # (T, J, D) => (T/2, J, 2D)
                seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                        joint_num=self.edge_num[i], kernel_size=kernel_size, stride=2,
                                        padding=padding, padding_mode=args.padding_mode, bias=bias, add_offset=False,
                                        in_offset_channel=3 * self.channel_base[i] // self.channel_base[0]))
                # self.convs.append(seq[-1])

                seq.append(pool)
                seq.append(nn.PReLU() if args.activation == 'relu' else nn.Tanh())
            self.layers.append(nn.Sequential(*seq))

            # Record the pooled skeleton so the next iteration builds on it.
            self.topologies.append(pool.new_edges)
            self.pooling_list.append(pool.pooling_list)
            self.edge_num.append(len(self.topologies[-1]))

        # in_features = self.channel_base[-1] * len(self.pooling_list[-1])
        # in_features *= int(args.temporal_scale / 2)
        # self.reduce = nn.Linear(in_features, args.z_dim)
        # self.mu = nn.Linear(in_features, args.z_dim)
        # self.logvar = nn.Linear(in_features, args.z_dim)

    def forward(self, input):
        """Apply all layers to ``input``, swapping its last two axes before and after.

        NOTE(review): presumably ``input`` is (batch, time, channels) - confirm.
        """
        # bs, n, c = input.shape[0], input.shape[1], input.shape[2]
        output = input.permute(0, 2, 1)  # input.reshape(bs, n, -1, 6)
        for layer in self.layers:
            output = layer(output)
        # output = output.view(output.shape[0], -1)
        output = output.permute(0, 2, 1)
        return output


