import torch
from torch import nn
import numpy as np
import scipy
from deeptime.util.types import ensure_dtraj_list
from tqdm import tqdm
import math

def safe_exp(x, name=None):
    """Exponentiate ``x`` element-wise and report a NaN result on stdout.

    Parameters
    ----------
    x : torch.Tensor
        Values to exponentiate.
    name : optional
        Label interpolated into the diagnostic message.

    Returns
    -------
    torch.Tensor
        ``x.exp()``. The tensor is returned as-is even when the NaN
        message has been printed; nothing is raised.
    """
    exponentiated = x.exp()
    # A NaN total is used as the signal that the result is unusable.
    total_is_nan = bool(torch.isnan(exponentiated.sum()))
    if total_is_nan:
        print('%s is NaN' % name)
    return exponentiated

def isnan_num(x, name):
    """Print ``'<name> is NaN'`` when the sum of ``x`` evaluates to NaN.

    Parameters
    ----------
    x : torch.Tensor
        Tensor to inspect.
    name
        Label interpolated into the diagnostic message.

    Returns
    -------
    None
        This is a diagnostic helper only; it never raises on NaN.
    """
    total = x.sum()
    if bool(torch.isnan(total)):
        print('%s is NaN' % name)

def logsumexp_pair(x, y):
    """Compute ``log(exp(x) + exp(y))`` element-wise in a numerically stable way.

    The larger of the two inputs is subtracted before exponentiating so that
    the exponentials cannot overflow.

    Parameters
    ----------
    x, y : torch.Tensor
        Broadcast-compatible tensors of log-values.

    Returns
    -------
    torch.Tensor
        ``log(exp(x) + exp(y))`` with the broadcast shape of ``x`` and ``y``.
        Entries where both inputs are ``-inf`` yield ``-inf`` and entries
        where either input is ``+inf`` yield ``+inf``.
    """
    m = torch.maximum(x, y)
    # When the maximum is infinite, ``x - m`` would evaluate ``inf - inf``
    # and poison the result with NaN. Shifting by zero in those positions
    # lets exp/log produce the correct limiting value instead.
    shift = torch.where(torch.isfinite(m), m, torch.zeros_like(m))
    ex = torch.exp(x - shift)
    ey = torch.exp(y - shift)
    return shift + torch.log(ex + ey)

def Tensor2Numpy(*args):
    """Convert one or more tensors to NumPy arrays.

    Each argument is moved to the CPU and detached from the autograd graph
    before conversion.

    Parameters
    ----------
    *args : torch.Tensor
        Tensors to convert.

    Returns
    -------
    numpy.ndarray or list of numpy.ndarray
        A bare array when exactly one tensor is given; otherwise a list with
        one array per argument (an empty list when called without arguments).
    """
    arrays = [tensor.cpu().detach().numpy() for tensor in args]
    return arrays[0] if len(arrays) == 1 else arrays

def remove_mean(samples, n_particles, n_dimensions):
    """Makes a configuration of many particle system mean-free.

    The input is interpreted as ``(-1, n_particles, n_dimensions)``, the mean
    over the particle axis is subtracted from every particle, and the result
    is returned in the original shape.

    Parameters
    ----------
    samples : torch.Tensor or array-like with ``reshape``/``mean``
        Positions of n_particles in n_dimensions. The total number of
        elements must be divisible by ``n_particles * n_dimensions``.
    n_particles : int
        Number of particles per configuration.
    n_dimensions : int
        Number of spatial dimensions per particle.

    Returns
    -------
    samples : same type as the input
        Mean-free positions of n_particles in n_dimensions, with the same
        shape as the input. The input itself is not modified.
    """
    shape = samples.shape
    if isinstance(samples, torch.Tensor):
        # ``reshape`` rather than ``view``: ``view`` raises for
        # non-contiguous inputs (e.g. transposed batches), while ``reshape``
        # still returns a view whenever one is possible.
        particles = samples.reshape(-1, n_particles, n_dimensions)
        particles = particles - torch.mean(particles, dim=1, keepdim=True)
    else:
        particles = samples.reshape(-1, n_particles, n_dimensions)
        particles = particles - particles.mean(axis=1, keepdims=True)
    return particles.reshape(shape)

def init_weights_zero(m):
    """Zero the parameters of an ``nn.Linear`` layer in place.

    Modules of any other type are left untouched, so the function can be
    passed to ``Module.apply`` to reset only the linear layers of a network.

    Parameters
    ----------
    m : nn.Module
        Module to (possibly) re-initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.zeros_(m.weight)
    bias = m.bias
    # Layers constructed with ``bias=False`` carry no bias parameter.
    if bias is not None:
        nn.init.zeros_(bias)

def compute_dihedral(p0, p1, p2, p3):
    """Compute the signed dihedral angle defined by four points, per batch row.

    The angle is measured between the plane through (p0, p1, p2) and the
    plane through (p1, p2, p3), using ``atan2`` so the sign is preserved.

    Parameters
    ----------
    p0, p1, p2, p3 : np.ndarray or torch.Tensor
        Point coordinates of shape (batch, 3). The branch taken (NumPy or
        torch) is selected from the type of ``p0``.

    Returns
    -------
    np.ndarray or torch.Tensor
        Dihedral angles in radians, shape (batch,), in the range of
        ``atan2`` ([-pi, pi]).

    Raises
    ------
    TypeError
        If ``p0`` is neither a NumPy array nor a torch tensor.
    """
    b0 = -1.0 * (p1 - p0)
    b1 = p2 - p1
    b2 = p3 - p2

    # b1 is normalised so it does not affect the magnitude of the vector
    # rejections below. The division is deliberately out-of-place: an
    # in-place ``/=`` overwrites the tensor that ``norm`` saved for its
    # backward pass (breaking autograd) and cannot store a float result into
    # an integer NumPy array.
    if isinstance(p0, np.ndarray):
        b1 = b1 / np.linalg.norm(b1, axis=1, keepdims=True)
        # Project b0 and b2 onto the plane perpendicular to b1
        v = b0 - np.sum(b0 * b1, axis=1, keepdims=True) * b1
        w = b2 - np.sum(b2 * b1, axis=1, keepdims=True) * b1

        x = np.sum(v * w, axis=1)
        y = np.sum(np.cross(b1, v) * w, axis=1)
        return np.arctan2(y, x)

    if isinstance(p0, torch.Tensor):
        b1 = b1 / torch.norm(b1, dim=1, keepdim=True)

        # Project b0 and b2 onto the plane perpendicular to b1
        v = b0 - (b0 * b1).sum(dim=1, keepdim=True) * b1
        w = b2 - (b2 * b1).sum(dim=1, keepdim=True) * b1

        x = (v * w).sum(dim=1)
        y = (torch.cross(b1, v, dim=1) * w).sum(dim=1)

        return torch.atan2(y, x)  # in radians

    # Previously this fell through and silently returned None.
    raise TypeError(
        'compute_dihedral expects np.ndarray or torch.Tensor inputs, got %s'
        % type(p0).__name__
    )

def compute_bond_angle(p0, p1, p2):
    """Compute the angle between consecutive bond vectors, per batch row.

    The angle is taken between the displacement ``p1 - p0`` and the
    displacement ``p2 - p1``. Three collinear points ordered p0 -> p1 -> p2
    therefore give 0 degrees, not 180.

    Parameters
    ----------
    p0, p1, p2 : np.ndarray or torch.Tensor
        Point coordinates of shape (batch, 3). The branch taken (NumPy or
        torch) is selected from the type of ``p0``.

    Returns
    -------
    np.ndarray or torch.Tensor
        Angles in degrees, shape (batch,), in [0, 180].

    Raises
    ------
    TypeError
        If ``p0`` is neither a NumPy array nor a torch tensor.
    """
    b0 = p1 - p0
    b1 = p2 - p1
    # Normalisation is out-of-place: in-place ``/=`` breaks autograd for
    # tensors (it overwrites the input saved by ``norm``) and fails on
    # integer NumPy arrays.
    if isinstance(p0, np.ndarray):
        b0 = b0 / np.linalg.norm(b0, axis=1, keepdims=True)
        b1 = b1 / np.linalg.norm(b1, axis=1, keepdims=True)
        cos_angle = np.sum(b0 * b1, axis=1)
        # Guard arccos against round-off pushing the cosine outside [-1, 1].
        cos_angle = np.clip(cos_angle, -1.0, 1.0)
        return np.degrees(np.arccos(cos_angle))
    if isinstance(p0, torch.Tensor):
        b0 = b0 / torch.norm(b0, dim=1, keepdim=True)
        b1 = b1 / torch.norm(b1, dim=1, keepdim=True)
        cos_angle = torch.sum(b0 * b1, dim=1)
        cos_angle = torch.clamp(cos_angle, -1.0, 1.0)
        # torch has no ``degrees`` function; ``rad2deg`` is the equivalent.
        return torch.rad2deg(torch.arccos(cos_angle))
    # Previously this fell through and silently returned None.
    raise TypeError(
        'compute_bond_angle expects np.ndarray or torch.Tensor inputs, got %s'
        % type(p0).__name__
    )

def sinusoidal_embedding(
    pos: torch.Tensor,
    emb_size: int,
    max_pos: int = 10000,
) -> torch.Tensor:
    """Embed scalar positions with sine/cosine features.

    Each position is multiplied by ``half = emb_size // 2`` geometrically
    spaced scales running from ``1`` down to ``1 / max_pos``. The sines of
    the scaled positions fill the first half of the last axis and the
    cosines fill the second half.

    Parameters
    ----------
    pos : torch.Tensor
        Positions of arbitrary shape ``(...)``.
    emb_size : int
        Size of the embedding axis; must be even.
    max_pos : int, optional
        Sets the smallest scale (``1 / max_pos``) and the bound used by the
        sanity assertion on ``pos``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(..., emb_size)`` on the same device as ``pos``.
    """
    # NOTE(review): only the minimum of ``pos`` is range-checked here; the
    # maximum is not. Left unchanged so existing callers are unaffected.
    assert -max_pos <= pos.min().item() <= max_pos
    assert emb_size % 2 == 0, 'Please use an even embedding size.'
    half_emb_size = emb_size // 2
    idx = torch.arange(half_emb_size, dtype=torch.float32, device=pos.device)
    # With emb_size == 2 there is a single scale and ``half_emb_size - 1``
    # is zero; dividing by it produced 0/0 = NaN for the whole embedding.
    # Clamping the denominator leaves that single scale at exp(0) == 1.
    denominator = max(half_emb_size - 1, 1)
    exponent = -1 * idx * math.log(max_pos) / denominator
    emb = pos[..., None] * torch.exp(exponent) # (..., half_emb_size)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # (..., emb_size)
    assert emb.size() == pos.size() + torch.Size([emb_size]), 'Embedding size mismatch.'
    return emb

def rand_projections(z_dim, num_samples=50):
    """Draw random directions on the unit sphere of the latent space.

    Rows are sampled from a standard normal using NumPy's global random
    state and each row is scaled to unit Euclidean length.

    Parameters
    ----------
    z_dim : int
        Dimensionality of the latent space.
    num_samples : int, optional
        Number of directions to draw.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(num_samples, z_dim)`` with unit-norm rows.
    """
    gaussian_draws = np.random.normal(size=(num_samples, z_dim))
    unit_rows = []
    for draw in gaussian_draws:
        length = np.sqrt(np.sum(draw ** 2))
        unit_rows.append(draw / length)
    return torch.from_numpy(np.array(unit_rows)).float()

# Only used for unweighted samples
def sliced_wasserstein_distance(encoded_samples, prior_samples, projection_num=50, p=2, device='cpu'):
    """Estimate the sliced Wasserstein cost between two equally weighted sample sets.

    Both sample sets are projected onto ``projection_num`` random unit
    directions. For every direction the projected values are sorted and
    matched rank by rank, and ``|difference| ** p`` is averaged over all
    samples and directions. No ``p``-th root is applied to the result.

    Parameters
    ----------
    encoded_samples : torch.Tensor
        Samples of shape (n_samples, z_dim).
    prior_samples : torch.Tensor
        Reference samples of shape (n_samples, z_dim). The rank-wise
        subtraction expects the same number of samples as
        ``encoded_samples``.
    projection_num : int, optional
        Number of random projection directions.
    p : int or float, optional
        Power applied to the absolute sorted differences (2 by default).
    device : str or torch.device, optional
        Device the random projections are moved to; it must match the device
        of the sample tensors.

    Returns
    -------
    torch.Tensor
        Scalar tensor holding the mean of ``|difference| ** p``.
    """
    # Latent dimensionality is read off the prior samples.
    z_dim = prior_samples.size(-1)

    # Random unit directions, shape (projection_num, z_dim).
    projections = rand_projections(z_dim, projection_num).to(device)
    directions = projections.transpose(0, 1)

    # One-dimensional projections, shape (n_samples, projection_num).
    encoded_projections = encoded_samples.matmul(directions)
    prior_projections = prior_samples.matmul(directions)

    # Sorting along the sample axis pairs the samples rank by rank, which is
    # the optimal transport plan in one dimension.
    sorted_difference = (torch.sort(encoded_projections, dim=0)[0] -
                         torch.sort(prior_projections, dim=0)[0])

    # The transport cost is |difference| ** p. Without the absolute value,
    # odd p lets positive and negative differences cancel and fractional p
    # yields NaN for negative differences. The default p=2 is unaffected.
    transport_cost = torch.pow(torch.abs(sorted_difference), p)

    # Average over samples and projections.
    return transport_cost.mean()

