import torch
import torch.nn as nn
from components import LossFunctions, InferenceNet, GenerationNet
import torch.nn.functional as F

class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution whose output at position t sees only inputs at positions <= t.

    The parent ``Conv1d`` is constructed with ``padding=0``. ``forward``
    instead left-pads the last (length) axis with
    ``(kernel_size - 1) * dilation`` zeros, so no future positions leak into
    the output. With ``stride=1`` the output length equals the input length.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 dilation=1, groups=1, bias=True):
        # Number of zeros needed on the left so the receptive field of the
        # last output element ends exactly at the last input element.
        left_pad = dilation * (kernel_size - 1)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.__padding = left_pad

    def forward(self, input):
        # A 2-tuple pad applies (left, right) to the last dimension only.
        padded = F.pad(input, (self.__padding, 0))
        return super().forward(padded)


class context_embedding(torch.nn.Module):
    """Causal 1-D convolution followed by an element-wise tanh.

    Args:
        in_channels: number of channels in the input sequence.
        embedding_size: number of output channels of the convolution.
        k: kernel size of the causal convolution.
    """

    def __init__(self, in_channels=1, embedding_size=256, k=5):
        super().__init__()
        self.causal_convolution = CausalConv1d(in_channels, embedding_size, kernel_size=k)

    def forward(self, x):
        # Squash the convolution output into (-1, 1).
        return torch.tanh(self.causal_convolution(x))

class StackedVAGT(nn.Module):
    """Model that merges inference and generation in a single module for simplicity.

    ``forward`` runs ``InferenceNet`` to get a posterior sample plus hidden
    states, then ``GenerationNet`` to get prior and reconstruction parameters.
    A 1x1-convolution + linear head on the hidden states produces a
    ``horizon``-step output.

    Most constructor arguments are forwarded unchanged to ``InferenceNet`` /
    ``GenerationNet``; their exact meaning is defined there.

    Args:
        layer_xz, layer_h, n_head, embd_h, vocab_len, q_len, dropout, device:
            forwarded to ``InferenceNet``.
        x_dim: output size of the prediction head (``end_fc``); also the row
            count of ``alpha0``.
        z_dim: latent size; row count of ``alpha1`` and column count of ``gamma1``.
        h_dim: hidden size; the prediction head expects hidden states whose
            last axis has this size.
        embd_s: size of the learnable ``gamma*`` / ``alpha*`` embeddings.
        win_len: channel count of ``end_conv``; also forwarded to both sub-nets.
        horizon: output channel count of ``end_conv``.
        beta, anneal_rate, max_beta: only stored as attributes here.
            NOTE(review): presumably consumed by the training loop.
    """

    def __init__(self, layer_xz=1, layer_h=3, n_head=8, x_dim=36, z_dim=15, h_dim=20, embd_h=256, embd_s=256, q_len=1,
                 vocab_len=128, win_len=20, horizon=12, dropout=0.1, beta=0.1, anneal_rate=1, max_beta=1,
                 device=torch.device('cpu')):
        super().__init__()
        self.beta = beta
        self.max_beta = max_beta
        self.anneal_rate = anneal_rate
        self.layer_xz = layer_xz
        self.layer_h = layer_h
        self.n_head = n_head
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.embd_h = embd_h
        self.embd_s = embd_s
        self.q_len = q_len
        self.vocab_len = vocab_len
        self.win_len = win_len
        self.horizon = horizon
        self.dropout = dropout

        self.losses = LossFunctions()
        # Inference / generation networks. Submodule creation order is kept
        # unchanged so seeded initialisation and state_dict order stay identical.
        self.inference = InferenceNet(z_dim, x_dim, h_dim, embd_h, layer_xz, layer_h, n_head,
                                      vocab_len, dropout, q_len, win_len, device)
        # Prediction head: 1x1 conv maps win_len channels -> horizon channels,
        # then a linear layer maps the last axis h_dim -> x_dim.
        self.end_conv = nn.Conv2d(win_len, horizon, kernel_size=(1, 1), bias=True)
        self.end_fc = nn.Linear(h_dim, x_dim)
        self.generation = GenerationNet(h_dim, z_dim, x_dim, win_len)

        # Learnable graph embeddings. Assigning an nn.Parameter to a module
        # attribute registers it, so no explicit register_parameter is needed.
        self.gamma0 = nn.Parameter(torch.randn(embd_s, h_dim))
        self.gamma1 = nn.Parameter(torch.randn(embd_s, z_dim))
        self.gamma2 = nn.Parameter(torch.randn(embd_s, h_dim))
        self.alpha0 = nn.Parameter(torch.randn(x_dim, embd_s))
        self.alpha1 = nn.Parameter(torch.randn(z_dim, embd_s))

        # Not used by forward(); kept so existing checkpoints (state_dict keys) still load.
        self.causal_conv = nn.Conv2d(3, 1, kernel_size=(1, 1), bias=True)

    def loss_LLH(self, x, x_mu, x_logsigma):
        """Return ``self.losses.log_normal`` of ``x`` under mean ``x_mu``.

        The variance passed is ``exp(x_logsigma) ** 2``. All inputs are cast
        to float32 first.
        """
        loglikelihood = self.losses.log_normal(x.float(), x_mu.float(), torch.pow(torch.exp(x_logsigma.float()), 2))
        return loglikelihood

    def loss_KL(self, z_mean_posterior_forward, z_logvar_posterior_forward, z_mean_prior_forward,
                z_logvar_prior_forward):
        """Closed-form KL(q || p) between diagonal Gaussians, summed over all elements.

        ``q`` is given by (posterior mean, log-variance) and ``p`` by
        (prior mean, log-variance).
        """
        z_var_posterior_forward = torch.exp(z_logvar_posterior_forward)
        z_var_prior_forward = torch.exp(z_logvar_prior_forward)
        kld_z_forward = 0.5 * torch.sum(z_logvar_prior_forward - z_logvar_posterior_forward +
                                        ((z_var_posterior_forward + torch.pow(
                                            z_mean_posterior_forward - z_mean_prior_forward, 2)) /
                                         z_var_prior_forward) - 1)
        return kld_z_forward

    def mse_loss(self, x, target):
        """Mean squared error between ``x`` and ``target`` (mean reduction)."""
        return F.mse_loss(x, target)

    def mae_loss(self, x, target):
        """Mean absolute error between ``x`` and ``target`` (mean reduction)."""
        return F.l1_loss(x, target)

    def forward(self, x):
        """Run inference, generation and the prediction head.

        Args:
            x: input tensor. It is cast to float32, and dims 2 and -1 are
                squeezed (each only if its size is 1) before ``InferenceNet``.
                TODO(review): confirm the expected input layout.

        Returns:
            ``(z_posterior, z_mean_posterior, z_logvar_posterior, z_mean_prior,
            z_logvar_prior, x_mu, x_logsigma, output)``. ``end_conv`` and
            ``end_fc`` require the hidden states from ``InferenceNet`` to be 3-D
            with shape ``(B, win_len, h_dim)``. In that case ``output`` has
            shape ``(B, horizon, x_dim, 1)``.
        """
        x = x.float().squeeze(2).squeeze(-1)
        z_posterior_forward, z_mean_posterior_forward, z_logvar_posterior_forward, h_out = \
            self.inference(x, self.alpha1, self.alpha0)

        z_mean_prior_forward, z_logvar_prior_forward, x_mu, x_logsigma = \
            self.generation(h_out, z_posterior_forward, self.alpha1, self.alpha0, self.gamma0,
                            self.gamma1, self.gamma2)

        # Prediction head. (B, win_len, h_dim) -> (B, win_len, 1, h_dim) so that
        # win_len is the channel axis seen by the 1x1 convolution.
        h_out = h_out.unsqueeze(1).permute(0, 2, 1, 3)
        output = self.end_conv(h_out)        # (B, horizon, 1, h_dim)
        output = self.end_fc(output)         # (B, horizon, 1, x_dim)
        output = output.permute(0, 1, 3, 2)  # (B, horizon, x_dim, 1)

        return z_posterior_forward, z_mean_posterior_forward, z_logvar_posterior_forward, z_mean_prior_forward, \
               z_logvar_prior_forward, x_mu, x_logsigma, output