import torch
import torch.nn as nn
from prettytable import PrettyTable
from torch.distributions.multivariate_normal import MultivariateNormal
import math

class XYNet(nn.Module):
    """Paired encoders for ``x`` and ``y`` plus a neighbour-mixing head.

    ``net1`` (applied to ``x``) and ``net2`` (applied to ``y``) are separate
    MLPs of identical shape ``in_dim -> 512 -> 256 -> out_dim``, each linear
    layer followed by ``Tanh``. ``net3`` is an ``out_dim -> out_dim -> out_dim``
    Tanh MLP applied after the ``y`` embeddings are mixed with a matrix
    derived from ``L``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Keep the build order net1 -> net2 -> net3 and the per-net layer order:
        # it fixes the order of RNG draws for initialisation and the state_dict
        # key layout (Linear at even indices, Tanh at odd indices).
        self.net1 = self._tanh_mlp((in_dim, 512, 256, out_dim))
        self.net2 = self._tanh_mlp((in_dim, 512, 256, out_dim))
        self.net3 = self._tanh_mlp((out_dim, out_dim, out_dim))

    @staticmethod
    def _tanh_mlp(widths):
        """Stack ``Linear(widths[i], widths[i + 1])`` + ``Tanh`` pairs into a Sequential."""
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Tanh())
        return nn.Sequential(*layers)

    def forward(self, x, y, L):
        """Encode ``x`` and ``y`` and propagate the ``y`` codes through ``L``.

        ``W = diag(diag(L)) - L`` zeroes the diagonal of ``L`` and negates its
        off-diagonal entries. If ``L`` is a graph Laplacian ``D - A``
        (presumably — confirm with callers), ``W`` is the adjacency ``A``.
        ``L`` must be square with as many rows as ``y_hat``.

        Returns:
            ``(x_hat, y_hat, Wy_hat)`` where ``Wy_hat = net3(W @ y_hat)``.
        """
        diag_entries = torch.diag(L)
        W = torch.diag(diag_entries) - L
        x_hat, y_hat = self.net1(x), self.net2(y)
        return x_hat, y_hat, self.net3(W @ y_hat)



class PDF(nn.Module):
    """Learnable density whose ``forward`` returns a negative log-likelihood.

    The parameters live in ``nn.Embedding`` tables so they train like
    ordinary weights:

    * ``mu``     -- a ``(1, dim)`` table holding the mean.
    * ``ln_var`` -- a ``(dim, dim)`` table. ``compute_negative_ln_prob`` uses
      ``ln_var @ ln_var.T`` as the precision (inverse covariance) matrix.
      Despite its name it is not a log-variance.
    """

    def __init__(self, dim, pdf):
        super().__init__()
        # NOTE(review): 'logistic' passes this check, but compute_negative_ln_prob
        # only implements 'gauss' and raises ValueError for anything else.
        assert pdf in {'gauss', 'logistic'}
        self.dim, self.pdf = dim, pdf
        # Registration order (mu, then ln_var) fixes init RNG order and state_dict keys.
        self.mu = nn.Embedding(1, dim)
        self.ln_var = nn.Embedding(dim, dim)

    def forward(self, Y):
        """Return the mean negative log-likelihood of the rows of ``Y``.

        ``Y`` is presumably ``(N, dim)``; ``mu`` broadcasts against it.
        """
        mean, factor = self.mu.weight, self.ln_var.weight
        return compute_negative_ln_prob(Y, mean, factor, self.pdf)


def compute_negative_ln_prob(Y, mu, ln_var, pdf):
    """Return the mean negative log-likelihood of the rows of ``Y``.

    Only ``pdf == 'gauss'`` is implemented: a multivariate normal with mean
    ``mu`` and precision (inverse covariance) ``P = ln_var @ ln_var.T``::

        -ln p(y) = 0.5 (y - mu)^T P (y - mu) + 0.5 d ln(2 pi) - 0.5 ln det P

    The quadratic term is averaged over the rows of ``Y``.

    Args:
        Y: samples, presumably shaped ``(N, d)``; ``Y.size(1)`` is used as d.
        mu: mean, broadcastable against ``Y`` (``PDF`` passes a ``(1, d)`` tensor).
        ln_var: ``(d, d)`` factor of the precision matrix. Despite the name it
            is not a log-variance.
        pdf: density family name; only ``'gauss'`` is supported.

    Returns:
        A scalar tensor on the same device/dtype as the inputs.

    Raises:
        ValueError: if ``pdf`` is not ``'gauss'``.
    """
    if pdf != 'gauss':
        raise ValueError('Unknown PDF: %s' % (pdf))

    inv_var = ln_var @ ln_var.T
    diff = Y - mu
    # Row-wise quadratic form diff_i^T P diff_i in O(N d^2). Taking the diagonal
    # of diff @ P @ diff.T would materialise an N x N matrix just to read N values.
    quad = ((diff @ inv_var) * diff).sum(dim=1)
    # Small jitter keeps logdet finite for a (near-)singular factor. It is built on
    # inv_var's device/dtype, because adding a CPU matrix to a CUDA tensor fails.
    eye = torch.eye(inv_var.shape[0], dtype=inv_var.dtype, device=inv_var.device)
    log_det_precision = torch.logdet(inv_var + 1e-6 * eye)

    return (0.5 * quad.mean()
            + 0.5 * Y.size(1) * math.log(2 * math.pi)
            - 0.5 * log_det_precision)


class mySequential(nn.Sequential):
    """``nn.Sequential`` variant whose stages may take and return several values.

    The call arguments are handed to the first stage. After that, whenever a
    stage returns a value whose type is exactly ``tuple`` (subclasses such as
    namedtuples do not count), it is unpacked into the next stage's positional
    arguments. Any other value is passed as a single argument. With no stages,
    the tuple of call arguments is returned unchanged.
    """

    def forward(self, *inputs):
        carried = inputs
        for stage in self:  # nn.Sequential iterates its child modules in order
            carried = stage(*carried) if type(carried) is tuple else stage(carried)
        return carried

def print_parameters(model):
    """Print a per-parameter table of trainable element counts and return the total.

    Parameters with ``requires_grad == False`` are skipped. Each row holds the
    parameter's name from ``model.named_parameters()`` and its ``numel()``.
    """
    rows = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    summary = PrettyTable(["Modules", "Parameters"])
    for param_name, count in rows:
        summary.add_row([param_name, count])
    grand_total = sum(count for _, count in rows)
    print(summary)
    print(f"Total Trainable Params: {grand_total}")
    return grand_total
