import nf
import numpy as np
import torch
import torch.nn as nn
import torch.distributions as td

class ConfetStochastic(nn.Module):
    """
    Flow over sets assembled as ``Logit -> ContinuousSetFlow -> Sigmoid`` on top
    of a ``Uniform(0, 1)`` base distribution. The network handed to the
    continuous flow is an ``nf.net.EquivariantNet``.

    Args:
        dim: Input dimension
        hidden_dim: Size of the hidden layer
        num_layers: Number of layers
        solver: ODE solver to use
        solver_step: If fixed solver, the size of the ODE step
        **kwargs: Accepted so callers can pass extra options; not used here
    """
    def __init__(self, dim, hidden_dim, num_layers, solver='rk4', solver_step=None, **kwargs):
        super().__init__()

        self.dim = dim
        lower, upper = torch.zeros(self.dim), torch.ones(self.dim)
        self.base_dist = td.Uniform(lower, upper)

        # The step size is only forwarded to the solver when one was given.
        if solver_step is None:
            ode_options = {}
        else:
            ode_options = {'step_size': solver_step}

        hidden_sizes = [hidden_dim] * num_layers
        dynamics = nf.net.EquivariantNet(dim, hidden_sizes, dim)

        self.cnf = nf.ContinuousSetFlow(dim, dynamics, solver=solver, solver_options=ode_options)
        self.transforms = [nf.Logit(), self.cnf, nf.Sigmoid()]
        self.flow = nf.Flow(self.base_dist, self.transforms)

    def forward(self, x, m, **kwargs):
        """
        Return the mask-weighted mean negative log-probability of ``x``.

        Args:
            x: Input tensor
            m: Mask with one dimension fewer than ``x``; a trailing axis is
                added and the mask is expanded to the shape of ``x`` before
                being passed to the flow
        """
        element_mask = m.unsqueeze(-1).expand_as(x)
        log_prob = self.flow.log_prob(x, mask=element_mask)

        # Sum the log-probabilities selected by the mask, normalise by sum(m).
        masked_total = (log_prob.squeeze(-1) * m).sum()
        return -masked_total / m.sum()

    def sample(self, num_samples):
        """
        Ask the flow for a sample of shape ``(1, num_samples)`` and drop the
        leading axis of the result.
        """
        batch = self.flow.sample((1, num_samples))
        return batch.squeeze(0)


class SumInteractionLayer(nn.Module):
    """
    Embeds every set element with an MLP and gives each element the sum of the
    *other* elements' masked embeddings, divided by the mask total of the set.

    Args:
        in_dim: Input dimension of the MLP
        hidden_dim: Size of each hidden layer
        out_dim: Output dimension of the MLP
        num_layers: Number of hidden layers
        **kwargs: Accepted so callers can pass extra options; not used here
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, **kwargs):
        super().__init__()
        self.net = nf.net.MLP(in_dim, [hidden_dim] * num_layers, out_dim)

    def forward(self, x, mask=None, **kwargs):
        """
        Args:
            x: Input whose second-to-last axis indexes set elements and whose
                last axis holds features
            mask: Optional mask; only its first entry along the last axis is
                used, kept as a trailing axis of size 1. Defaults to all ones.

        Returns:
            Tuple ``(y, trace, sqjnorm)``: ``y`` is the aggregated output,
            ``trace`` is an all-zero tensor shaped like ``y``, and ``sqjnorm``
            is an all-zero vector of length ``y.shape[0]``. Both zero tensors
            require grad.
        """
        if mask is None:
            # Build the default mask on the same device as x; a plain
            # torch.ones(...) lands on the default device and breaks on GPU.
            mask = torch.ones(*x.shape[:-1], 1, device=x.device)
        else:
            mask = mask[..., 0, None]

        x_emb = self.net(x) * mask
        set_emb = x_emb.sum(-2, keepdim=True)

        # A fully masked set has a mask total of 0 and an all-zero embedding;
        # divide by 1 there so the output is 0 instead of NaN (0 / 0).
        count = mask.sum(-2, keepdim=True)
        count = torch.where(count == 0, torch.ones_like(count), count)
        y = (set_emb - x_emb) / count

        # NOTE(review): the zero trace presumably relies on y[i] not depending
        # on x[i], which holds only if self.net acts per element -- confirm.
        trace = torch.zeros_like(y).requires_grad_(True)
        sqjnorm = y.new_zeros(y.shape[0]).requires_grad_(True)
        return y, trace, sqjnorm

class ConfetFixed(nn.Module):
    """
    Flow over sets assembled as ``Logit -> ContinuousSetFlow -> Sigmoid`` on top
    of a ``Uniform(0, 1)`` base distribution, optionally followed by spline
    coupling layers. The continuous flow is driven by a ``SumInteractionLayer``
    and is configured with ``returns_divergence=True``.

    Args:
        dim: Input dimension
        hidden_dim: Size of the hidden layer
        num_layers: Number of layers
        solver: ODE solver to use
        solver_step: If fixed solver, the size of the ODE step
        num_coupling_layers: Number of additional coupling layers
        **kwargs: Accepted so callers can pass extra options; not used here
    """
    def __init__(self, dim, hidden_dim, num_layers, solver='rk4', solver_step=None,
                  num_coupling_layers=0, **kwargs):
        super().__init__()

        self.dim = dim
        lower, upper = torch.zeros(self.dim), torch.ones(self.dim)
        self.base_dist = td.Uniform(lower, upper)

        # The step size is only forwarded to the solver when one was given.
        if solver_step is None:
            ode_options = {}
        else:
            ode_options = {'step_size': solver_step}

        dynamics = SumInteractionLayer(dim, hidden_dim, dim, num_layers)
        self.cnf = nf.ContinuousSetFlow(
            dim,
            dynamics,
            solver=solver,
            solver_options=ode_options,
            returns_divergence=True,
            divergence_fn='brute_force',
        )

        stages = [nf.Logit(), self.cnf, nf.Sigmoid()]
        for layer_idx in range(num_coupling_layers):
            # Built in this order (spline, conditioner, mask) on purpose so
            # parameter initialisation consumes the RNG as before.
            spline = nf.Spline(dim, latent_dim=hidden_dim, n_bins=5, lower=0, upper=1)
            conditioner = nf.net.MLP(dim, [hidden_dim], hidden_dim)
            # `right` alternates 0, 1, 0, ... across successive layers.
            ordering = nf.mask.ordered(right=layer_idx % 2)
            stages.append(nf.Coupling(flow=spline, net=conditioner, mask=ordering))
        self.transforms = stages

        self.flow = nf.Flow(self.base_dist, self.transforms)

    def forward(self, x, m, **kwargs):
        """
        Return the mask-weighted mean negative log-probability of ``x``.

        Args:
            x: Input tensor
            m: Mask with one dimension fewer than ``x``; a trailing axis is
                added and the mask is expanded to the shape of ``x`` before
                being passed to the flow
        """
        element_mask = m.unsqueeze(-1).expand_as(x)
        log_prob = self.flow.log_prob(x, mask=element_mask)

        # Sum the log-probabilities selected by the mask, normalise by sum(m).
        masked_total = (log_prob.squeeze(-1) * m).sum()
        return -masked_total / m.sum()

    def sample(self, num_samples):
        """
        Ask the flow for a sample of shape ``(1, num_samples)`` and drop the
        leading axis of the result.
        """
        batch = self.flow.sample((1, num_samples))
        return batch.squeeze(0)


class ConfetExact(nn.Module):
    """
    Flow over sets assembled as ``Logit -> ContinuousSetFlow -> Sigmoid`` on top
    of a ``Uniform(0, 1)`` base distribution. The continuous flow is driven by
    an ``nf.net.CNFFuncAndJac`` network and is configured with
    ``input_time=True`` and ``returns_divergence=True``.

    Args:
        dim: Input dimension
        hidden_dim: Size of the hidden layer
        num_layers: Number of layers
        conditioner_dim: Size of the d_h dimension in the decoupled network;
            falls back to ``hidden_dim`` when not given (or otherwise falsy)
        n_points: If using induced attention, how many inducing points
        induced: Whether to use induced attention
        solver: ODE solver to use
        solver_step: If fixed solver, the size of the ODE step
        interaction: Type of aggregation
        **kwargs: Accepted so callers can pass extra options; not used here.
            ``n_heads`` is not a named parameter, so a value passed for it
            ends up here and has no effect in this constructor.
    """
    def __init__(self, dim, hidden_dim, num_layers, conditioner_dim=None, n_points=None, induced=False,
                 solver='rk4', solver_step=None, interaction=None, **kwargs):
        super().__init__()

        self.dim = dim
        lower, upper = torch.zeros(self.dim), torch.ones(self.dim)
        self.base_dist = td.Uniform(lower, upper)

        # The step size is only forwarded to the solver when one was given.
        if solver_step is None:
            ode_options = {}
        else:
            ode_options = {'step_size': solver_step}

        cond_dim = conditioner_dim if conditioner_dim else hidden_dim
        dynamics = nf.net.CNFFuncAndJac(
            dim,
            hidden_dim,
            cond_dim,
            num_layers,
            interaction=interaction,
            n_points=n_points,
            induced=induced,
        )
        self.cnf = nf.ContinuousSetFlow(
            dim,
            dynamics,
            solver=solver,
            solver_options=ode_options,
            input_time=True,
            returns_divergence=True,
            divergence_fn='brute_force',
        )

        self.transforms = [nf.Logit(), self.cnf, nf.Sigmoid()]
        self.flow = nf.Flow(self.base_dist, self.transforms)

    def forward(self, x, m, **kwargs):
        """
        Return the mask-weighted mean negative log-probability of ``x``.

        Args:
            x: Input tensor
            m: Mask with one dimension fewer than ``x``; a trailing axis is
                added and the mask is expanded to the shape of ``x`` before
                being passed to the flow
        """
        element_mask = m.unsqueeze(-1).expand_as(x)
        log_prob = self.flow.log_prob(x, mask=element_mask)

        # Sum the log-probabilities selected by the mask, normalise by sum(m).
        masked_total = (log_prob.squeeze(-1) * m).sum()
        return -masked_total / m.sum()

    def sample(self, num_samples):
        """
        Ask the flow for a sample of shape ``(1, num_samples)`` and drop the
        leading axis of the result.
        """
        batch = self.flow.sample((1, num_samples))
        return batch.squeeze(0)
