import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import autograd


class MLP(torch.nn.Module):
    """Fully connected network with tanh activations.

    The stack is ``Linear(in_dim, mid_dim) -> Tanh``, followed by
    ``mid_layers`` repetitions of ``Linear(mid_dim, mid_dim) -> Tanh``, and a
    final ``Linear(mid_dim, out_dim)`` with no activation after it.

    Args:
        in_dim: Size of the last dimension of the input.
        mid_dim: Width of every hidden layer.
        mid_layers: Number of additional hidden ``Linear``/``Tanh`` pairs
            placed after the input pair.
        out_dim: Size of the last dimension of the output.
    """

    def __init__(self, in_dim, mid_dim, mid_layers, out_dim):
        super().__init__()

        # Consecutive entries give the (fan_in, fan_out) of each hidden Linear.
        widths = [in_dim, mid_dim] + [mid_dim] * mid_layers

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(torch.nn.Linear(fan_in, fan_out))
            modules.append(torch.nn.Tanh())
        # Output projection is left linear (no trailing activation).
        modules.append(torch.nn.Linear(mid_dim, out_dim))

        self.layers = torch.nn.ModuleList(modules)

    def forward(self, x):
        """Apply every module in ``self.layers`` in order and return the result."""
        out = x
        for module in self.layers:
            out = module(out)
        return out
    
    
class LinearSGM(torch.nn.Module):
    """Two MLPs evaluated on the same input: a vector field and a scalar.

    ``field_model`` maps an input of width ``data_dim`` to an output of the
    same width; ``scalar_model`` maps it to a single value that ``forward``
    passes through a softplus.

    Args:
        data_dim: Size of the last dimension of the data.
        hidden_dim: Hidden width passed to both MLPs.
        num_layers: Number of extra hidden layers passed to both MLPs.
    """

    def __init__(self, data_dim, hidden_dim, num_layers):
        super().__init__()

        self.field_model = MLP(data_dim, hidden_dim, num_layers, data_dim)
        self.scalar_model = MLP(data_dim, hidden_dim, num_layers, 1)

    def forward(self, x):
        """Return ``(v, s)``: the field output and the softplus of the scalar output.

        ``v`` is ``field_model(x)``; ``s`` is ``softplus(scalar_model(x))``.
        """
        v = self.field_model(x)
        # softplus(s) == log(1 + exp(s)), but the library version does not
        # overflow to inf when s is large.
        s = torch.nn.functional.softplus(self.scalar_model(x))
        return v, s

    def sample(self):
        """Placeholder: sampling is not implemented; returns ``None``."""
        pass

    def _normal_density(self, x, epsilon):
        """Evaluate the zero-mean Gaussian density with covariance ``epsilon * I``.

        Args:
            x: Tensor whose last dimension is the event dimension.
            epsilon: Variance shared by every coordinate.

        Returns:
            Tensor of densities with the last dimension of ``x`` reduced away.
        """
        dim = x.shape[-1]
        log_norm = -dim / 2.0 * np.log(2 * np.pi * epsilon)
        log_kernel = -torch.pow(x, 2).sum(dim=-1) / (2 * epsilon)
        # Tensor operand first so the addition is dispatched by torch.
        return torch.exp(log_kernel + log_norm)

    def loss_fn(self, x, epsilon, num_slices):
        """Compute the field and scalar losses for a batch.

        The batch is split into two halves. The model is evaluated on the first
        half, and each first-half sample is paired with the second-half sample
        at the same index; the Gaussian density (variance ``epsilon``) of their
        difference is used as a per-sample weight ``coeff``.

        The trace of the field's Jacobian with respect to the input is computed
        exactly when ``dim == 1`` and otherwise estimated with random Gaussian
        projections, averaged over ``num_slices`` draws.

        Args:
            x: Tensor of shape ``(bsz, dim)`` with even ``bsz``. It does not
                need ``requires_grad``; the first half is promoted to a
                grad-tracking leaf when necessary.
            epsilon: Variance of the Gaussian weight.
            num_slices: Number of random projections for the trace estimate
                when ``dim > 1``. Falsy or non-positive values are treated
                as 1.

        Returns:
            Tuple ``(field_loss, scalar_loss)`` of scalar tensors, where
            ``field_loss = mean(sum(v**2, -1) + trace * coeff)`` and
            ``scalar_loss = mean(s**2 - 2 * coeff * s)``.

        Raises:
            AssertionError: If the batch size is odd.
        """
        bsz, dim = x.shape
        assert (bsz % 2 == 0), 'batch size must be a even number!'

        half = bsz // 2
        fir_batch, sec_batch = x[:half], x[half:]
        if not fir_batch.requires_grad:
            # autograd.grad below differentiates w.r.t. the inputs, which
            # fails on a tensor that is not part of the autograd graph.
            fir_batch = fir_batch.detach().requires_grad_(True)

        fields, scalars = self(fir_batch)
        scalars = scalars.squeeze(-1)
        coeff = self._normal_density(sec_batch - fir_batch, epsilon)

        if dim == 1:
            # With one coordinate the Jacobian trace is just dv/dx.
            trace = autograd.grad(torch.sum(fields), fir_batch, create_graph=True)[0]
        else:
            n_slices = max(int(num_slices), 1) if num_slices else 1
            trace = 0
            for _ in range(n_slices):
                slices = torch.randn_like(fir_batch)
                projected = torch.sum(slices * fields)
                # create_graph=True keeps the graph alive for further draws
                # and for backpropagating through the estimate.
                gradv = autograd.grad(projected, fir_batch, create_graph=True)[0]
                trace = trace + torch.sum(gradv * slices, dim=-1)
            trace = trace / n_slices

        field_loss = torch.mean(torch.sum(torch.pow(fields, 2), -1) + trace.squeeze(-1) * coeff)
        scalar_loss = torch.mean(torch.pow(scalars, 2) - 2 * coeff * scalars)
        return field_loss, scalar_loss
