import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import OrderedDict


# Public API: only the `tidespl_vae` factory is exported by `from <module> import *`.
__all__ = ["tidespl_vae"]


class TiDeSPLVAE(nn.Module):
    """Recurrent VAE that separates a content latent from a style latent.

    The model walks three time-major sequences in lockstep: an anchor ``x``,
    a positive ``x_pos`` and a negative ``x_neg``. At every step it

    * encodes the step input with ``encoder_x``;
    * predicts a deterministic content code from ``[features, h_content]``;
    * predicts a Gaussian style posterior ``[mu, log_var]`` from
      ``[features, h_style]`` (anchor and positive only) and a Gaussian style
      prior from ``h_style`` alone;
    * decodes ``[z_content, z_style, h_style]`` through a Softplus head, both
      with each sequence's own latents and with the content code swapped
      between anchor and positive;
    * advances a content GRU and a style GRU for each sequence.

    Args:
        input_dim: Size of each per-step input vector.
        content_dim: Size of the content latent.
        style_dim: Size of the style latent.
        hidden_state_dim: Hidden size of both (single-layer) GRUs.
        hidden_dim: Optional sequence of encoder layer widths. Defaults to
            ``[input_dim, content_dim + style_dim]``. After construction
            ``self.hidden_dims`` holds ``[input_dim, *widths, widths[-1]]``,
            the widths the decoder mirrors.

    Raises:
        ValueError: If ``hidden_dim`` is an empty sequence.
    """

    def __init__(self, input_dim, content_dim, style_dim, hidden_state_dim, hidden_dim=None):
        super().__init__()
        self.input_dim = input_dim
        self.content_dim = content_dim
        self.style_dim = style_dim
        self.latent_dim = self.content_dim + self.style_dim
        if hidden_dim is None:
            self.hidden_dims = [self.input_dim, self.latent_dim]
        else:
            # Copy into a list: accepts tuples (the decoder below concatenates
            # lists) and never aliases the caller's object.
            self.hidden_dims = list(hidden_dim)
        if not self.hidden_dims:
            raise ValueError("hidden_dim must contain at least one layer width")
        self.hidden_state_dim = hidden_state_dim
        # Width of encoder_x's output; posteriors, decoder and GRUs are sized from it.
        feature_dim = self.hidden_dims[-1]

        # encoder_x: Linear -> BatchNorm1d -> ReLU for each configured width.
        modules = []
        in_features = self.input_dim
        for width in self.hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(in_features, width),
                    nn.BatchNorm1d(width),
                    nn.ReLU()
                )
            )
            in_features = width
        self.encoder_x = nn.Sequential(*modules)

        # posterior: content is a point estimate; style emits [mu, log_var].
        self.posterior_content = nn.Sequential(
            nn.Linear(feature_dim + self.hidden_state_dim, self.content_dim),
            nn.BatchNorm1d(self.content_dim),
            nn.ReLU(),
            nn.Linear(self.content_dim, self.content_dim)
        )
        self.posterior_style = nn.Sequential(
            nn.Linear(feature_dim + self.hidden_state_dim, self.style_dim),
            nn.BatchNorm1d(self.style_dim),
            nn.ReLU(),
            nn.Linear(self.style_dim, self.style_dim * 2)
        )

        # prior: style [mu, log_var] predicted from the style hidden state only.
        self.prior_style = nn.Sequential(
            nn.Linear(self.hidden_state_dim, self.style_dim),
            nn.BatchNorm1d(self.style_dim),
            nn.ReLU(),
            nn.Linear(self.style_dim, self.style_dim * 2)
        )

        # decoder: walk the (extended) encoder widths in reverse starting from
        # [z, h_style], then project back to input_dim through a Softplus.
        self.hidden_dims = [self.input_dim] + self.hidden_dims + [feature_dim]
        modules = []
        in_features = self.latent_dim + self.hidden_state_dim
        for i in range(len(self.hidden_dims) - 1, 0, -1):
            modules.append(
                nn.Sequential(
                    nn.Linear(in_features, self.hidden_dims[i]),
                    nn.BatchNorm1d(self.hidden_dims[i]),
                    nn.ReLU()
                )
            )
            in_features = self.hidden_dims[i]
        modules.append(
            nn.Sequential(
                nn.Linear(in_features, self.hidden_dims[0]),
                nn.Softplus()
            )
        )
        self.decoder = nn.Sequential(*modules)

        # recurrence: content GRU consumes encoder features; the style GRU
        # additionally consumes [z_content, z_style].
        self.rnn_content = nn.GRU(
            input_size=feature_dim,
            hidden_size=self.hidden_state_dim,
            num_layers=1,
            bias=True
        )
        self.rnn_style = nn.GRU(
            input_size=feature_dim + self.latent_dim,
            hidden_size=self.hidden_state_dim,
            num_layers=1,
            bias=True
        )

        self.init_weight()

    def encode(self, x, h_content, h_style=None, style=False):
        """Encode a single time step.

        Args:
            x: Step input, passed straight to ``encoder_x``.
            h_content: Content hidden state concatenated with the features.
            h_style: Style hidden state; required when ``style`` is True.
            style: Whether to also compute the style posterior.

        Returns:
            ``(feature_x, z_content)`` when ``style`` is False, otherwise
            ``(feature_x, z_content, z_style_mu, z_style_log_var)``.
        """
        feature_x = self.encoder_x(x)
        content_input = torch.cat((feature_x, h_content), dim=-1)
        z_content = self.posterior_content(content_input)
        if not style:
            return feature_x, z_content

        style_input = torch.cat((feature_x, h_style), dim=-1)
        z_style = self.posterior_style(style_input)
        # posterior_style emits [mu | log_var] side by side on the last dim.
        z_style_mu = z_style[:, :self.style_dim]
        z_style_log_var = z_style[:, self.style_dim:]

        return feature_x, z_content, z_style_mu, z_style_log_var

    def reparameterize(self, z_mu, z_log_var):
        """Draw ``z = mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``."""
        epsilon = torch.randn_like(z_mu)
        z = z_mu + torch.exp(0.5 * z_log_var) * epsilon

        return z

    def set_output_dict(self):
        """Return a fresh dict of empty lists, one per quantity collected by ``forward``."""
        output = {
            "x_recons_ori": [],
            "x_pos_recons_ori": [],
            "x_recons_swap": [],
            "x_pos_recons_swap": [],
            "z_content": [],
            "z_style_mu": [],
            "z_style_log_var": [],
            "z_prior_style_mu": [],
            "z_prior_style_log_var": [],
            "z_pos_content": [],
            "z_pos_style_mu": [],
            "z_pos_style_log_var": [],
            "z_prior_pos_style_mu": [],
            "z_prior_pos_style_log_var": [],
            "z_neg_content": []
        }

        return output

    def _init_hidden(self, batch_size):
        """Return a zero GRU state ``(1, batch_size, hidden_state_dim)``.

        The state uses the parameters' device AND dtype, so the model also
        runs after ``.double()`` / ``.half()`` (a plain ``torch.zeros`` is
        always float32).
        """
        ref = next(self.parameters())
        return torch.zeros(1, batch_size, self.hidden_state_dim, device=ref.device, dtype=ref.dtype)

    def forward(self, x, x_pos, x_neg):
        """Run the model over anchor, positive and negative sequences.

        Args:
            x, x_pos, x_neg: Time-major inputs: step ``t`` is ``x[t]`` for
                ``t in range(x.size(0))`` and ``x.size(1)`` is the batch size.
                ``x_pos`` / ``x_neg`` are assumed to share ``x``'s length and
                batch size.

        Returns:
            dict with the keys of ``set_output_dict()``; each value is the
            per-step tensors stacked along a new leading time dimension.
        """
        T = x.size(0)
        batch_size = x.size(1)

        output = self.set_output_dict()

        # One zero-initialised hidden state per (sequence, GRU) pair.
        h_content = self._init_hidden(batch_size)
        h_style = self._init_hidden(batch_size)
        h_pos_content = self._init_hidden(batch_size)
        h_pos_style = self._init_hidden(batch_size)
        h_neg_content = self._init_hidden(batch_size)

        for t in range(T):
            # posterior (h[-1] selects the last -- here the only -- GRU layer)
            feature_x, z_content_t, z_style_mu_t, z_style_log_var_t = self.encode(x[t], h_content[-1], h_style[-1], style=True)
            z_style_t = self.reparameterize(z_style_mu_t, z_style_log_var_t)
            output["z_content"].append(z_content_t)
            output["z_style_mu"].append(z_style_mu_t)
            output["z_style_log_var"].append(z_style_log_var_t)

            feature_x_pos, z_pos_content_t, z_pos_style_mu_t, z_pos_style_log_var_t = self.encode(x_pos[t], h_pos_content[-1], h_pos_style[-1], style=True)
            z_pos_style_t = self.reparameterize(z_pos_style_mu_t, z_pos_style_log_var_t)
            output["z_pos_content"].append(z_pos_content_t)
            output["z_pos_style_mu"].append(z_pos_style_mu_t)
            output["z_pos_style_log_var"].append(z_pos_style_log_var_t)

            # Negatives only contribute a content code (used by the InfoNCE loss).
            feature_x_neg, z_neg_content_t = self.encode(x_neg[t], h_neg_content[-1], style=False)
            output["z_neg_content"].append(z_neg_content_t)

            # prior
            z_prior_style_t = self.prior_style(h_style[-1])
            output["z_prior_style_mu"].append(z_prior_style_t[:, :self.style_dim])
            output["z_prior_style_log_var"].append(z_prior_style_t[:, self.style_dim:])
            z_prior_pos_style_t = self.prior_style(h_pos_style[-1])
            output["z_prior_pos_style_mu"].append(z_prior_pos_style_t[:, :self.style_dim])
            output["z_prior_pos_style_log_var"].append(z_prior_pos_style_t[:, self.style_dim:])

            # decode: own latents, and content swapped between anchor and positive
            z_original_t = torch.cat((z_content_t, z_style_t), dim=-1)
            z_pos_original_t = torch.cat((z_pos_content_t, z_pos_style_t), dim=-1)

            z_swap_t = torch.cat((z_pos_content_t, z_style_t), dim=-1)
            z_pos_swap_t = torch.cat((z_content_t, z_pos_style_t), dim=-1)

            output["x_recons_ori"].append(self.decoder(torch.cat((z_original_t, h_style[-1]), dim=-1)))
            output["x_pos_recons_ori"].append(self.decoder(torch.cat((z_pos_original_t, h_pos_style[-1]), dim=-1)))

            output["x_recons_swap"].append(self.decoder(torch.cat((z_swap_t, h_style[-1]), dim=-1)))
            output["x_pos_recons_swap"].append(self.decoder(torch.cat((z_pos_swap_t, h_pos_style[-1]), dim=-1)))

            # recurrence: feed one step (seq_len=1) into each GRU
            _, h_content = self.rnn_content(feature_x.unsqueeze(0), h_content)
            _, h_pos_content = self.rnn_content(feature_x_pos.unsqueeze(0), h_pos_content)
            _, h_neg_content = self.rnn_content(feature_x_neg.unsqueeze(0), h_neg_content)

            _, h_style = self.rnn_style(torch.cat((feature_x, z_content_t, z_style_t), dim=-1).unsqueeze(0), h_style)
            _, h_pos_style = self.rnn_style(torch.cat((feature_x_pos, z_pos_content_t, z_pos_style_t), dim=-1).unsqueeze(0), h_pos_style)

        for key in output.keys():
            output[key] = torch.stack(output[key], dim=0)
        return output

    def compute_reconstruction_loss(self, x, x_recons):
        """Poisson NLL of ``x`` under rates ``x_recons`` (``log_input=False``).

        Rates are clamped to ``[1e-7, 1e7]``; the summed loss is divided by
        ``x.size(0) * x.size(1)`` (time steps x batch for time-major input).
        """
        x_recons = torch.clamp(x_recons, min=1e-7, max=1e7)
        return F.poisson_nll_loss(x_recons, x, log_input=False, reduction="sum") / x.size(0) / x.size(1)

    def compute_kld_loss(self, prior_mu, prior_log_var, post_mu, post_log_var):
        """KL(N(post_mu, exp(post_log_var)) || N(prior_mu, exp(prior_log_var))).

        Summed over the last (latent) dimension and averaged over the rest.
        """
        return 0.5 * torch.mean(torch.sum(-1 + ((post_mu - prior_mu) ** 2 + torch.exp(post_log_var)) / torch.exp(prior_log_var) - post_log_var + prior_log_var, dim=-1))

    def compute_infonce(self, z, z_pos, z_neg, temperature):
        """InfoNCE loss over whole content trajectories.

        Each input is permuted to batch-first and flattened over (time, dim),
        then L2-normalised. For every batch item the positive logit is its
        cosine similarity with ``z_pos`` and the negatives are its
        similarities with every row of ``z_neg``; all are divided by
        ``temperature``.

        Returns:
            Mean over the batch of ``logsumexp([neg..., pos]) - pos``.
        """
        z = F.normalize(z.permute(1, 0, 2).flatten(1, 2), p=2, dim=-1)
        z_pos = F.normalize(z_pos.permute(1, 0, 2).flatten(1, 2), p=2, dim=-1)
        z_neg = F.normalize(z_neg.permute(1, 0, 2).flatten(1, 2), p=2, dim=-1)

        pos_sim = torch.sum(z * z_pos, dim=-1) / temperature
        neg_sim = torch.mm(z, z_neg.T) / temperature
        logits = torch.cat([neg_sim, pos_sim.unsqueeze(-1)], dim=-1)
        # logsumexp instead of log(sum(exp(.))): logits reach 1/temperature,
        # and exp() overflows to inf for small temperatures (~<0.011 in fp32).
        loss = torch.mean(torch.logsumexp(logits, dim=-1) - pos_sim)

        return loss

    def compute_prior_loss(self, prior_mu, prior_log_var):
        """Squared-L2 penalty on the prior mean and log-variance (sum over last dim, mean over rest)."""
        return torch.mean(torch.sum(prior_mu ** 2, dim=-1)) + torch.mean(torch.sum(prior_log_var ** 2, dim=-1))

    def compute_loss(self, **kwargs):
        """Combine all training objectives.

        Expected kwargs: ``x``, ``x_pos``, the ``forward`` outputs used below,
        and the scalars ``temperature``, ``kld_weight``, ``cont_weight`` and
        ``prior_weight``.

        Returns:
            OrderedDict with the differentiable total under ``"loss"`` and
            detached components for logging.
        """
        reconstruction_ori_loss = (self.compute_reconstruction_loss(kwargs["x"], kwargs["x_recons_ori"]) + self.compute_reconstruction_loss(kwargs["x_pos"], kwargs["x_pos_recons_ori"])) / 2
        reconstruction_swap_loss = (self.compute_reconstruction_loss(kwargs["x"], kwargs["x_recons_swap"]) + self.compute_reconstruction_loss(kwargs["x_pos"], kwargs["x_pos_recons_swap"])) / 2

        kld_loss = (self.compute_kld_loss(kwargs["z_prior_style_mu"], kwargs["z_prior_style_log_var"], kwargs["z_style_mu"], kwargs["z_style_log_var"]) + self.compute_kld_loss(kwargs["z_prior_pos_style_mu"], kwargs["z_prior_pos_style_log_var"], kwargs["z_pos_style_mu"], kwargs["z_pos_style_log_var"])) / 2

        infonce_loss = self.compute_infonce(kwargs["z_content"], kwargs["z_pos_content"], kwargs["z_neg_content"], kwargs["temperature"])

        prior_loss = (self.compute_prior_loss(kwargs["z_prior_style_mu"], kwargs["z_prior_style_log_var"]) + self.compute_prior_loss(kwargs["z_prior_pos_style_mu"], kwargs["z_prior_pos_style_log_var"])) / 2

        loss = reconstruction_ori_loss + reconstruction_swap_loss + kwargs["kld_weight"] * kld_loss + kwargs["cont_weight"] * infonce_loss + kwargs["prior_weight"] * prior_loss

        return OrderedDict([("loss", loss), ("reconstruction_ori_loss", reconstruction_ori_loss.detach()), ("reconstruction_swap_loss", reconstruction_swap_loss.detach()), ("kld_loss", kld_loss.detach()), ("infonce_loss", infonce_loss.detach()), ("prior_loss", prior_loss.detach())])

    def init_weight(self):
        """Kaiming-normal weights and zero biases for Linear/Conv/RNN layers; identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.GRU, nn.RNN, nn.LSTM)):
                # Cover every layer/direction and tolerate bias=False instead of
                # hard-coding the *_l0 attribute names.
                for name, param in m.named_parameters():
                    if name.startswith("weight"):
                        nn.init.kaiming_normal_(param)
                    elif name.startswith("bias"):
                        nn.init.constant_(param, 0)


def tidespl_vae(**kwargs):
    """Construct a ``TiDeSPLVAE``, forwarding every keyword argument to its constructor."""
    model = TiDeSPLVAE(**kwargs)
    return model
