import numpy as np
import torch

from functools import partial

from .target_distribution import TargetDistribution

from scipy.interpolate import CubicSpline

def multi_double_well_energy(
    dists: torch.Tensor,
    a: float = 0.0,
    b: float = -4.0,
    c: float = 0.9,
    offset: float = 4.0
) -> torch.Tensor:
    """
    Evaluate the double well pair potential element-wise on a tensor of distances.

    For every distance ``d`` the result is::

        a * d + b * (d - offset)**2 + c * (d - offset)**4

    With ``a == 0``, ``b < 0`` and ``c > 0`` the two minima sit at
    ``d = offset +/- sqrt(-b / (2 * c))``.

    Arguments:
        dists: Tensor of pairwise distances, typically of shape
            `[n_batch, n_particles, n_other_particles]`. The computation is
            element-wise, so any shape is accepted.
        a: Coefficient of the linear term (applied to the raw distance).
        b: Coefficient of the quadratic term in the shifted distance.
        c: Coefficient of the quartic term in the shifted distance.
        offset: Distance about which the quadratic and quartic terms are centred.

    Returns:
        energy: Tensor with the same shape as `dists` holding the pair energies.
    """
    centred = dists - offset
    return a * dists + b * centred.pow(2) + c * centred.pow(4)

def tile(a, dim, n_tile):
    """
    Repeat every slice of a tensor `n_tile` times, consecutively, along `dim`.

    Each slice is repeated in place (``[x0, x1]`` tiled twice along dim 0
    becomes ``[x0, x0, x1, x1]``, not ``[x0, x1, x0, x1]``).

    Parameters
    ----------
    a : PyTorch tensor
        the tensor which is to be tiled
    dim : Integer
        dimension along which the tensor is tiled (negative values allowed)
    n_tile : Integer
        number of repetitions of each slice

    Returns
    -------
    b : PyTorch tensor
        the tensor with dimension `dim` enlarged `n_tile` times
    """
    # repeat_interleave builds its gather index as int64 internally. A manual
    # index routed through `a`'s dtype would be rounded for half-precision
    # inputs once it exceeds 2048, selecting wrong or out-of-range slices.
    # repeat_interleave also handles a size-0 `dim` cleanly.
    return torch.repeat_interleave(a, n_tile, dim=dim)

def distance_vectors(
    x: torch.Tensor, 
    remove_diagonal: bool = True
) -> torch.Tensor:
    """
    Build the tensor R of all pairwise difference vectors, R_{ij} = x_i - x_j.

    Arguments:
        x: Tensor of shape `[n_batch, n_particles, n_dimensions]` containing input points.
        remove_diagonal: If True, drop the all-zero self-differences x_i - x_i,
            so each particle keeps only its `n_particles - 1` partners.

    Returns:
        R: Tensor of shape `[n_batch, n_particles, n_other_particles, n_dimensions]`,
            where `n_other_particles` is `n_particles - 1` when `remove_diagonal`
            is True and `n_particles` otherwise.
    """
    n_particles = x.shape[1]
    # rows[b, i, j] == x[b, i] for every j; its transpose over (i, j) gives x[b, j].
    rows = x.unsqueeze(2).expand(-1, -1, n_particles, -1)
    diffs = rows - rows.permute([0, 2, 1, 3])
    if not remove_diagonal:
        return diffs
    keep = ~torch.eye(n_particles, dtype=torch.bool, device=x.device)
    return diffs[:, keep].view(-1, n_particles, n_particles - 1, x.shape[2])

def distances_from_vectors(r, eps=1e-6):
    """
    Convert a tensor of difference vectors into Euclidean distances.

    Parameters
    ----------
    r : torch.Tensor
        Difference vectors, shape `[n_batch, n_particles, n_other_particles, n_dimensions]`.
    eps : float
        Added to the squared norm before the square root. This keeps the
        gradient of the square root finite when a vector is exactly zero.

    Returns
    -------
    d : torch.Tensor
        Distances sqrt(|r|^2 + eps), shape `[n_batch, n_particles, n_other_particles]`.
    """
    squared_norms = torch.sum(r ** 2, dim=-1)
    return torch.sqrt(squared_norms + eps)

def cubic_spline(x_new, x, c):
    """
    Evaluate a piecewise cubic polynomial at the query points `x_new`.

    Code adapted from
    https://github.com/cambridge-mlg/Progressive-Tempering-Sampler-with-Diffusion/blob/main/ptsd/targets/lennard_jones.py

    Arguments:
        x_new: Tensor of query points (any shape).
        x: 1-D tensor of breakpoints, sorted ascending (required by
            `torch.bucketize`), of length n.
        c: Coefficient tensor indexed as `c[power_rank, interval]`. Row 0
            multiplies dx**3 and row 3 is the constant term. This is assumed
            to follow the layout of `scipy.interpolate.CubicSpline.c`,
            i.e. shape `[4, n - 1]`.

    Returns:
        Tensor shaped like `x_new`. Queries below x[0] or above x[-1] reuse the
        first / last interval's polynomial (extrapolation).
    """
    knots = x.to(x_new.device)
    coeffs = c.to(x_new.device)
    # bucketize returns the insertion index; minus one gives the interval whose
    # left breakpoint precedes the query (a query equal to a breakpoint falls
    # into the preceding interval, which agrees by continuity).
    seg = torch.bucketize(x_new, knots) - 1
    # Clamp so out-of-range queries use the end polynomials.
    seg = seg.clamp(min=0, max=len(knots) - 2)
    h = x_new - knots[seg]
    lead, quad, lin, const = (coeffs[k, seg] for k in range(4))
    return lead * h**3 + quad * h**2 + lin * h + const

class MultiDoubleWell(TargetDistribution):
    """
    Multi-particle double well target distribution.

    A configuration of `n_particles` particles in `particle_dim` dimensions is
    a flat vector of length `dim = n_particles * particle_dim` (or an already
    reshaped `[n_particles, particle_dim]` block per sample). Its energy is the
    sum of `multi_double_well_energy` over the distances of every ordered pair
    of distinct particles. `log_prob` returns the negative of that energy,
    i.e. an unnormalized log density.
    """

    def __init__(
        self, 
        dim: int,
        n_particles: int,
        a: float = 0.0,
        b: float = -4.0,
        c: float = 0.9,
        offset: float = 4.0,
        data_path: str = None
    ) -> None:
        """
        Multi-dimensional Double Well distribution.

        Arguments:
            dim (int): Total dimension of the distribution. It must be divisible by
                n_particles, i.e., dim = n_particles * particle_dim.
            n_particles (int): The number of particles in the system.
            a (float): Coefficient of the linear term in the pair potential.
            b (float): Coefficient of the quadratic term in the pair potential.
            c (float): Coefficient of the quartic term in the pair potential.
            offset (float): Distance about which the quadratic and quartic terms are centred.
            data_path (str): Optional path to a file loadable with `np.load` that
                holds a single array whose first axis indexes precomputed samples.
                Only used by `sample`.
        """
        super().__init__()

        assert dim % n_particles == 0, "dim must be divisible by n_particles"

        self._n_particles = n_particles
        self._particle_dim = dim // n_particles

        self._a = a
        self._b = b
        self._c = c
        self._offset = offset

        self._data_path = data_path
        # Filled on the first call to `sample` so the data file is read at most once.
        self._samples_cache = None

    @property
    def dim(self) -> int:
        """
        Returns the dimension of the distribution.
        Returns:
            dim: int
        """
        return self._n_particles * self._particle_dim

    def _distance_vectors(self, x):
        """Reshape x to `[batch, n_particles, particle_dim]` and return all off-diagonal x_i - x_j."""
        x_particles = x.view(-1, self._n_particles, self._particle_dim)
        return distance_vectors(x_particles)

    def _energy(self, x):
        """
        Computes the multi double well energy of the input x.

        Arguments:
            x: tensor of samples of size (batch_size, dim) or (batch_size, n_particles, particle_dim)

        Returns:
            energy: tensor of energies of size (batch_size, 1)

        NOTE(review): `distance_vectors` keeps both (i, j) and (j, i), so every
        unordered particle pair contributes twice and no 1/2 factor is applied.
        Some reference double well implementations include that 1/2. Confirm
        which convention any samples loaded via `data_path` were generated with.
        """
        batch_size = x.shape[0]
        dists = distances_from_vectors(self._distance_vectors(x))

        energies = multi_double_well_energy(
            dists, a=self._a, b=self._b, c=self._c, offset=self._offset
        )
        # Collapse all pair energies of each sample into one scalar per sample.
        energies = energies.view(batch_size, -1).sum(dim=-1)

        return energies.unsqueeze(-1)

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns the (unnormalized) log probability of the input x, i.e. minus its energy.

        Arguments:
            x: tensor of samples of size (batch_size, dim) or (batch_size, n_particles, particle_dim)
        Returns:
            log_prob: (batch_size, 1)
        """
        return -self._energy(x)

    def sample(self, num_samples):
        """
        Draw `num_samples` distinct rows, uniformly at random, from the precomputed samples.

        Arguments:
            num_samples: number of rows to draw without replacement.

        Returns:
            float32 tensor containing the selected rows of the stored array.

        Raises:
            ValueError: if no `data_path` was given, or (from `np.random.choice`)
                if `num_samples` exceeds the number of stored samples.
        """
        if self._data_path is None:
            raise ValueError(
                "MultiDoubleWell.sample requires `data_path` to point to a file of precomputed samples"
            )
        if self._samples_cache is None:
            self._samples_cache = np.load(self._data_path)
        samples_array = self._samples_cache
        idx = np.random.choice(samples_array.shape[0], num_samples, replace=False)
        return torch.tensor(samples_array[idx], dtype=torch.float32)