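"""
==========
Utilities
==========

Helpers for zero centre-of-gravity Gaussians (sampling and log-likelihoods),
standard Gaussian log-likelihoods, flow transforms that return log-Jacobian
terms (affine, sigmoid, argmax / hypercube partitions), random rotation data
augmentation and adaptive gradient clipping.
"""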

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool


def affine(x, translation, log_scale):
    """Elementwise affine flow step ``z = translation + x * exp(log_scale)``.

    Returns:
        A pair ``(z, ldj)``. ``ldj`` is the log-determinant of the Jacobian,
        i.e. ``log_scale`` summed over every dimension except the first
        (batch) one.
    """
    scale = log_scale.exp()
    z = translation + x * scale
    return z, sum_except_batch(log_scale)

def affine_tobatch(x, translation, log_scale, batch):
    """Elementwise affine flow step for node-level (``batch``-indexed) tensors.

    Computes ``z = translation + x * exp(log_scale)``. The log-determinant is
    ``log_scale`` pooled per ``batch`` index with ``sum_to_batch``.

    Returns:
        A pair ``(z, ldj)``.
    """
    scale = log_scale.exp()
    z = translation + x * scale
    return z, sum_to_batch(log_scale, batch)

def assert_correctly_masked(variable, node_mask):
    """Assert that ``variable`` is (numerically) zero outside ``node_mask``.

    The part of ``variable`` that survives multiplication by ``1 - node_mask``
    must have absolute value below 1e-4 everywhere.
    """
    leaked = variable * (1 - node_mask)
    assert leaked.abs().max().item() < 1e-4, \
        'Variables not masked properly.'
    
def assert_mean_zero(x):
    """Assert that the mean of ``x`` over dimension 1 is below 1e-4 in magnitude."""
    deviation = x.mean(dim=1, keepdim=True).abs()
    assert deviation.max().item() < 1e-4


def assert_mean_zero_with_mask(x, node_mask):
    """Assert ``x`` is zero outside ``node_mask`` and sums to ~0 over dimension 1.

    The masked positions are zero (checked first), so a zero sum over dim 1 is
    equivalent to a zero mean over the active entries.
    """
    assert_correctly_masked(x, node_mask)
    node_sum = x.sum(dim=1, keepdim=True)
    assert node_sum.abs().max().item() < 1e-4, \
        'Mean is not zero'

def center_gravity_zero_gaussian_log_likelihood(x):
    """Log-density of ``x`` under a standard Gaussian restricted to zero mean.

    Args:
        x: 3-D tensor ``(B, N, D)`` whose mean over dimension 1 is zero
            (asserted).

    Returns:
        Per-sample log-likelihood of shape ``(B,)``, normalised over the
        ``(N - 1) * D``-dimensional zero-mean subspace.
    """
    assert x.dim() == 3
    _, n_nodes, n_dims = x.size()
    assert_mean_zero(x)

    # The squared radius is invariant to a change of basis within the
    # zero-mean hyperplane, so it can be computed in ambient coordinates.
    r2 = sum_except_batch(x.pow(2))

    # Removing the mean leaves (N - 1) * D free dimensions.
    dof = (n_nodes - 1) * n_dims
    log_norm_const = -0.5 * dof * np.log(2*np.pi)
    return -0.5 * r2 + log_norm_const

def center_gravity_zero_gaussian_log_likelihood_with_mask(x, node_mask):
    """Masked variant of the zero-centre-of-gravity Gaussian log-likelihood.

    Args:
        x: 3-D tensor ``(B, N, D)``. It must be zero outside ``node_mask`` and
            sum to zero over dimension 1 (both asserted).
        node_mask: mask with a trailing singleton dimension (it is squeezed at
            dim 2), presumably ``(B, N, 1)`` with 0/1 entries.

    Returns:
        Per-sample log-likelihood of shape ``(B,)``. Each sample uses its own
        active-node count ``N`` for the ``(N - 1) * D`` degrees of freedom.
    """
    assert x.dim() == 3
    n_dims = x.size(2)
    assert_mean_zero_with_mask(x, node_mask)

    # Masked-out entries are zero, so they add nothing to the squared radius.
    r2 = sum_except_batch(x.pow(2))

    # Per-sample count of active nodes, shape (B,).
    n_active = node_mask.squeeze(2).sum(1)
    dof = (n_active-1) * n_dims

    log_norm_const = -0.5 * dof * np.log(2*np.pi)
    return -0.5 * r2 + log_norm_const

def center_gravity_zero_gaussian_log_likelihood_tobatch(x, batch):
    """Zero-centre-of-gravity Gaussian log-likelihood for node-stacked graphs.

    Args:
        x: 2-D tensor ``(num_nodes, D)``. The rows that share a ``batch`` index
            form one graph, and each graph's rows must sum to ~0 (asserted).
        batch: 1-D integer tensor giving the graph index of every row of ``x``.

    Returns:
        Per-graph log-likelihood of shape ``(num_graphs,)``. Each graph is
        normalised over its ``(N - 1) * D``-dimensional zero-mean subspace,
        matching the dense and masked variants.
    """
    assert len(x.size()) == 2
    _, D = x.size()

    # Check the centre of gravity of each graph. assert_mean_zero would
    # average the D coordinates of each node instead, which is not the
    # constraint this density assumes.
    graph_sums = global_add_pool(x, batch)
    assert graph_sums.abs().max().item() < 1e-4, 'Mean is not zero'

    # r is invariant to a basis change in the relevant hyperplane.
    r2 = sum_to_batch(x.pow(2), batch)

    # Node count per graph. Unlike torch.unique, bincount keeps a slot for
    # every graph id, so it lines up index-by-index with the pooled r2.
    n_nodes = torch.bincount(batch, minlength=r2.size(0)).to(x.device)

    # The relevant hyperplane is (N-1) * D dimensional.
    degrees_of_freedom = (n_nodes - 1) * D

    # Normalizing constant and logpx are computed:
    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2*np.pi)
    log_px = -0.5 * r2 + log_normalizing_constant

    return log_px

def remove_mean(x):
    """Return ``x`` minus its mean over dimension 1 (broadcast back over that dim)."""
    return x - x.mean(dim=1, keepdim=True)

def remove_mean_with_mask(x, node_mask):
    """Subtract each sample's mean over active nodes from its active entries.

    Args:
        x: tensor that is already zero wherever ``node_mask`` is zero
            (asserted). The mean is taken over dimension 1.
        node_mask: mask broadcastable against ``x``, presumably ``(B, N, 1)``
            with 0/1 entries (TODO confirm against callers).

    Returns:
        ``x`` with the per-sample masked mean removed. Masked-out entries stay
        zero. A sample with no active nodes is returned unchanged (all zeros)
        rather than filled with NaN.
    """
    assert (x * (1 - node_mask)).abs().sum().item() < 1e-8
    N = node_mask.sum(1, keepdims=True)
    # An all-masked sample would give 0 / 0 = NaN, and NaN * 0 stays NaN in
    # the output. Its sum is zero anyway, so dividing by 1 gives mean 0.
    N = N.masked_fill(N == 0, 1)

    mean = torch.sum(x, dim=1, keepdim=True) / N
    x = x - mean * node_mask
    return x


def sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask):
    """Sample masked standard Gaussian noise and project it to zero mean.

    Args:
        size: 3-element shape (asserted).
        device: device for the sample.
        node_mask: mask multiplied into the noise before centring.

    Returns:
        Noise that is zero outside ``node_mask`` and has its masked mean over
        dimension 1 removed (via ``remove_mean_with_mask``).
    """
    assert len(size) == 3
    noise = torch.randn(size, device=device) * node_mask

    # Projecting onto the zero-mean subspace is valid only because the
    # Gaussian is rotation invariant around zero and its entries are
    # independent.
    return remove_mean_with_mask(noise, node_mask)

def sample_center_gravity_zero_gaussian(size, device):
    """Sample standard Gaussian noise of ``size`` and remove its mean over dimension 1."""
    return remove_mean(torch.randn(size, device=device))


def sample_gaussian(size, device):
    """Return standard normal noise of shape ``size`` allocated on ``device``."""
    return torch.randn(size, device=device)

def sample_gaussian_with_mask(size, device, node_mask):
    """Return standard normal noise of shape ``size`` multiplied by ``node_mask``."""
    noise = torch.randn(size, device=device)
    return noise * node_mask

def sigmoid(x, node_mask):
    """Sigmoid flow step with a masked log-Jacobian.

    Returns:
        ``(sigmoid(x), ldj)``. ``ldj`` sums ``log sigmoid'(x)`` (equal to
        ``logsigmoid(x) + logsigmoid(-x)``) over the entries where
        ``node_mask`` is non-zero, per sample.
    """
    squashed = torch.sigmoid(x)
    log_derivative = F.logsigmoid(x) + F.logsigmoid(-x)
    return squashed, sum_except_batch(node_mask * log_derivative)

def sigmoid_no_mask(x, batch):
    """Sigmoid flow step for node-level tensors, log-Jacobian pooled per ``batch`` id.

    Returns:
        ``(sigmoid(x), ldj)``. ``ldj`` is ``logsigmoid(x) + logsigmoid(-x)``
        summed with ``sum_to_batch``.
    """
    squashed = torch.sigmoid(x)
    log_derivative = F.logsigmoid(x) + F.logsigmoid(-x)
    return squashed, sum_to_batch(log_derivative, batch)

def standard_gaussian_log_likelihood(x):
    """Per-sample log-density of ``x`` under an isotropic standard normal.

    The elementwise log-density is summed over all non-batch dimensions.
    """
    log_density = -0.5 * x * x - 0.5 * np.log(2*np.pi)
    return sum_except_batch(log_density)

def standard_gaussian_log_likelihood_tobatch(x,batch):
    """Standard normal log-density of node-level ``x``, pooled per ``batch`` id.

    The elementwise log-density is summed with ``sum_to_batch``.
    """
    log_density = -0.5 * x * x - 0.5 * np.log(2*np.pi)
    return sum_to_batch(log_density, batch)


def standard_gaussian_log_likelihood_with_mask(x, node_mask):
    """Per-sample standard normal log-density of ``x``, weighted by ``node_mask``.

    Masked-out entries (mask 0) contribute nothing to the sum.
    """
    log_density = -0.5 * x * x - 0.5 * np.log(2*np.pi)
    masked_log_density = log_density * node_mask
    return sum_except_batch(masked_log_density)

def sum_except_batch(x):
    """Sum ``x`` over every dimension but the first, returning shape ``(x.size(0),)``."""
    batch_size = x.size(0)
    flat = x.reshape(batch_size, -1)
    return flat.sum(dim=-1)

def sum_to_batch(x, batch):
    """Sum rows of ``x`` sharing a ``batch`` index, then sum the last dimension.

    Rows are pooled with torch_geometric's ``global_add_pool``.
    """
    pooled = global_add_pool(x, batch)
    return pooled.sum(dim=-1)

def transform_to_argmax_partition(onehot, u, node_mask):
    """Map unconstrained ``u`` into the argmax cell chosen by ``onehot`` (masked).

    On active nodes the selected class keeps its value ``T = u[onehot]``.
    Every other class becomes ``T - softplus(T - u)``, which is strictly below
    ``T``, so the one-hot class is the argmax of the result.

    Args:
        onehot: one-hot tensor whose last-dimension rows sum to 1 on active
            nodes (asserted).
        u: unconstrained values, same shape as ``onehot``.
        node_mask: mask broadcastable against ``onehot``.

    Returns:
        ``(z, ldj)``. ``ldj`` sums ``log sigmoid(T - u)`` over the non-selected
        classes of active nodes, per sample.
    """
    row_sums = onehot.sum(-1, keepdims=True)
    assert torch.allclose(
        row_sums * node_mask,
        torch.ones_like(onehot[..., 0:1]) * node_mask)

    others = 1 - onehot
    selected = (onehot * u).sum(-1, keepdim=True)
    z = onehot * u + node_mask * others * (selected - F.softplus(selected - u))
    # d/du [T - softplus(T - u)] = sigmoid(T - u).
    log_jacobian = others * F.logsigmoid(selected - u) * node_mask

    assert_correctly_masked(z, node_mask)
    assert_correctly_masked(log_jacobian, node_mask)

    return z, sum_except_batch(log_jacobian)

def transform_to_argmax_partition_no_mask(onehot, u, batch):
    """Node-level argmax-partition transform with log-Jacobian pooled per ``batch`` id.

    The selected class keeps ``T = u[onehot]``. Every other class becomes
    ``T - softplus(T - u) < T``.

    Args:
        onehot: one-hot tensor whose last-dimension rows all sum to 1
            (asserted).
        u: unconstrained values, same shape as ``onehot``.
        batch: graph index per node, used by ``sum_to_batch``.

    Returns:
        ``(z, ldj)``. ``ldj`` is ``log sigmoid(T - u)`` over non-selected
        classes, summed per graph.
    """
    row_sums = onehot.sum(-1, keepdims=True)
    assert torch.allclose(
        row_sums,
        torch.ones_like(onehot[..., 0:1]))

    others = 1 - onehot
    selected = (onehot * u).sum(-1, keepdim=True)
    z = onehot * u + others * (selected - F.softplus(selected - u))
    # d/du [T - softplus(T - u)] = sigmoid(T - u).
    log_jacobian = others * F.logsigmoid(selected - u)

    return z, sum_to_batch(log_jacobian, batch)

def transform_to_hypercube_partition(integer, interval_noise):
    """Dequantize ``integer`` by adding noise drawn from the unit interval.

    Args:
        integer: integer-valued tensor.
        interval_noise: tensor broadcastable to ``integer`` with every value
            in ``[0, 1]``.

    Returns:
        ``integer + interval_noise``.

    Raises:
        AssertionError: if any noise value lies outside ``[0, 1]``.
    """
    # Both bounds belong in the condition. Placed after a comma, the upper
    # bound would only be the assertion message and would never be checked.
    assert interval_noise.min().item() >= 0. and interval_noise.max().item() <= 1., \
        'interval_noise must lie in [0, 1]'
    return integer + interval_noise

# Rotation data augmentation
def random_rotation(x):
    """Rotate every sample's coordinates by an independent random rotation.

    Args:
        x: tensor of shape ``(batch, n_nodes, n_dims)`` with ``n_dims`` equal
            to 2 or 3.

    Returns:
        A contiguous tensor with the same shape, dtype and device as ``x``.
        In 2-D each sample is rotated by one angle drawn uniformly from
        ``[-pi, pi)``. In 3-D the rotation is ``Rz @ Ry @ Rx`` with three
        independently drawn angles. Note that this composition is not uniformly
        distributed over SO(3).

    Raises:
        NotImplementedError: for any other ``n_dims``.
    """
    bs, n_nodes, n_dims = x.size()
    device = x.device
    dtype = x.dtype
    angle_range = np.pi * 2

    def _random_angle():
        # Draw from the default CPU generator and then move to x's device, so
        # seeded runs keep the same random stream. Cast to x's dtype so the
        # matmul below also works for float64 inputs.
        return torch.rand(bs, 1, 1).to(device=device, dtype=dtype) * angle_range - np.pi

    def _plane_rotation(i, j, sin_sign):
        # Batched 3x3 rotation acting in the (i, j) coordinate plane:
        # R[i,i] = R[j,j] = cos, R[i,j] = sin_sign*sin, R[j,i] = -sin_sign*sin.
        theta = _random_angle()
        cos = torch.cos(theta)
        sin = torch.sin(theta)
        R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).repeat(bs, 1, 1)
        R[:, i:i + 1, i:i + 1] = cos
        R[:, i:i + 1, j:j + 1] = sin_sign * sin
        R[:, j:j + 1, i:i + 1] = -sin_sign * sin
        R[:, j:j + 1, j:j + 1] = cos
        return R

    if n_dims == 2:
        theta = _random_angle()
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)
        R_row0 = torch.cat([cos_theta, -sin_theta], dim=2)
        R_row1 = torch.cat([sin_theta, cos_theta], dim=2)
        R = torch.cat([R_row0, R_row1], dim=1)

        x = x.transpose(1, 2)
        x = torch.matmul(R, x)
        x = x.transpose(1, 2)

    elif n_dims == 3:
        # Angles are drawn in the order x, y, z axis.
        Rx = _plane_rotation(1, 2, 1)
        Ry = _plane_rotation(0, 2, -1)
        Rz = _plane_rotation(0, 1, 1)

        x = x.transpose(1, 2)
        x = torch.matmul(Rx, x)
        x = torch.matmul(Ry, x)
        x = torch.matmul(Rz, x)
        x = x.transpose(1, 2)
    else:
        raise NotImplementedError(
            f"random_rotation supports n_dims of 2 or 3, got {n_dims}")

    return x.contiguous()

# Gradient clipping
class Queue:
    """Bounded history of recent values, newest first.

    Attributes:
        items: stored values; index 0 is the most recently added.
        max_len: capacity. Once it is exceeded, the oldest value is dropped.
    """

    def __init__(self, max_len=50):
        self.items = []
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def add(self, item):
        """Prepend ``item`` and drop the oldest value if capacity is exceeded."""
        history = self.items
        history.insert(0, item)
        if len(self) > self.max_len:
            del history[-1]

    def mean(self):
        """Return the NumPy mean of the stored values."""
        return np.mean(self.items)

    def std(self):
        """Return the NumPy (population) standard deviation of the stored values."""
        return np.std(self.items)

def gradient_clipping(flow, gradnorm_queue):
    """Clip ``flow``'s gradients to a threshold adapted from recent gradient norms.

    The allowed norm is ``1.5 * mean + 2 * std`` of the values in
    ``gradnorm_queue``. The recorded value is the measured norm, or the
    threshold when clipping happened.

    Args:
        flow: module whose parameter gradients are clipped in place.
        gradnorm_queue: history object exposing ``len()``, ``mean()``,
            ``std()`` and ``add()`` (e.g. ``Queue``).

    Returns:
        The total 2-norm of the gradients before clipping, as returned by
        ``torch.nn.utils.clip_grad_norm_``.
    """
    if len(gradnorm_queue) == 0:
        # Mean/std of an empty history are NaN, and clipping to a NaN threshold
        # would turn every gradient into NaN. Without history, only measure.
        max_grad_norm = float('inf')
    else:
        # Allow gradient norm to be 150% + 2 * stdev of the recent history.
        max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()

    # Clips gradient and returns the norm
    grad_norm = torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=max_grad_norm, norm_type=2.0)

    clipped = float(grad_norm) > max_grad_norm
    # Store the threshold instead of the raw norm when clipping, so outliers
    # do not enter the history.
    gradnorm_queue.add(float(max_grad_norm) if clipped else float(grad_norm))

    if clipped:
        print(f'Clipped gradient with value {grad_norm:.1f} '
              f'while allowed {max_grad_norm:.1f}')
    return grad_norm