import faiss
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint as odeint
from math import sqrt

# Make float32 the process-wide default floating-point dtype.
torch.set_default_dtype(torch.float32)
# Seed PyTorch's global RNG so the Gaussian draws and random indices repeat across runs.
torch.manual_seed(0)
# Module-wide default compute device.
device = torch.device("cpu")

### CLOSED-FORM DIFFUSION CODE ###


class SmoothedScore(nn.Module):
    """Heat-smoothed closed-form score estimated from fresh target samples.

    Each call draws ``n_samples`` points x_j from ``target_sampler``. It treats the
    time-t marginal as the equal-weight Gaussian mixture with components
    N(t * x_j, (1 - t)^2 I) and returns that mixture's score at z. The mixture
    responsibilities can be smoothed by averaging them over ``n_points`` Gaussian
    perturbations of z, which approximates convolution against the heat kernel.
    """

    def __init__(
        self,
        space_dims,
        target_sampler,
        n_samples,
        smoothing_sigma=0,
        n_points=1,
        fixed_noise=None
    ):
        super(SmoothedScore, self).__init__()
        self.space_dims = space_dims
        # Object with a ``.sample(n)`` method returning a (n, space_dims) tensor.
        self.target_sampler = target_sampler
        # Number of target samples drawn on every call to estimate the score.
        self.n_samples = n_samples
        # Std of the Gaussian noise added to z to approximate convolution against the heat kernel.
        self.smoothing_sigma = smoothing_sigma
        # Number of noise draws averaged over to approximate the heat-kernel convolution.
        self.n_points = n_points
        # Optional fixed noise, broadcastable to (B, 1, n_points, space_dims).
        # When set, it is used instead of fresh draws.
        self.fixed_noise = fixed_noise

    def forward(
        self,
        t,  # scalar time; torchdiffeq.odeint always passes a scalar
        z,  # (B, space_dims)
    ):
        """Return the smoothed score h(t, z) as a (B, space_dims) tensor.

        At t == 0 this is ``mean(x) - z``. For t > 0 it is
        ``(sum_j w_j * t * x_j - z) / (1 - t)**2``. The weights w_j are a softmax
        over ``-||z + noise - t * x_j||^2 / (2 (1 - t)**2)``, averaged over the
        ``n_points`` noise draws.
        """
        # Follow z's device instead of a module-level global.
        x = self.target_sampler.sample(self.n_samples).to(z.device)  # (n_samples, space_dims)

        if t == 0:
            # The score points from z to the sample mean (the -z term is important).
            return x.mean(dim=0) - z  # (B, space_dims)

        tx = t * x  # (n_samples, space_dims): samples rescaled by t

        if self.fixed_noise is None:
            noise = (
                torch.randn(z.shape[0], 1, self.n_points, z.shape[1], device=z.device)
                * self.smoothing_sigma
            )  # (B, 1, n_points, space_dims)
        else:
            noise = self.fixed_noise
        noisy_z = z[:, None, None, :] + noise  # (B, 1, n_points, space_dims)

        sigma_sq = (1 - t) ** 2
        # Gaussian log-weights of every rescaled sample for every noisy copy of z.
        log_w = -((tx[None, :, None, :] - noisy_z) ** 2).sum(dim=3) / (2 * sigma_sq)  # (B, n_samples, n_points)
        # F.softmax is shift-invariant and numerically stable on its own.
        w = F.softmax(log_w, dim=1).mean(dim=2)  # (B, n_samples)

        avg_sample = (w.unsqueeze(2) * tx.unsqueeze(0)).sum(dim=1)  # (B, space_dims)
        # Use z as-is: z.squeeze() would turn a (B, 1) input into (B,) and broadcast to (B, B).
        return (avg_sample - z) / sigma_sq  # (B, space_dims)
        
class KNNSmoothedScore(nn.Module):
    """Heat-smoothed closed-form score estimated DEANN-style from nearest neighbours.

    Instead of drawing fresh samples, this uses the whole dataset
    ``target_sampler.data`` (N rows). The mixture sum is approximated by two parts:

    - the ``num_nn`` nearest neighbours of ``z / t``, which give exact terms;
    - ``n_random_samples`` uniformly drawn rows, whose weights are multiplied by
      ``(N - num_nn) / n_random_samples`` so they stand in for the remaining rows.
    """

    def __init__(
        self,
        space_dims,
        target_sampler,
        n_random_samples,
        smoothing_sigma=0,
        n_points=1,
        fixed_noise=None,
        index=None,
        num_nn=100
    ):
        super(KNNSmoothedScore, self).__init__()
        self.space_dims = space_dims
        # Object exposing a ``.data`` tensor of shape (N, space_dims) holding the dataset.
        self.target_sampler = target_sampler
        # Number of uniformly drawn rows used to estimate the non-neighbour part of the sum.
        self.n_random_samples = n_random_samples
        # Scale of the Gaussian noise added to z; the actual std is smoothing_sigma * t.
        self.smoothing_sigma = smoothing_sigma
        # Number of noise draws averaged over to approximate the heat-kernel convolution.
        self.n_points = n_points
        # Optional fixed noise, broadcastable to (B, 1, n_points, space_dims).
        # When set, it is used instead of fresh draws and is scaled by t.
        self.fixed_noise = fixed_noise
        # Index with a faiss-style ``search(queries, k) -> (distances, labels)``.
        # NOTE(review): presumably built over ``target_sampler.data``; confirm at the construction site.
        self.index = index
        # Number of nearest neighbours retrieved per query.
        self.num_nn = num_nn

    def forward(
        self,
        t,  # scalar time; torchdiffeq.odeint always passes a scalar
        z,  # (B, space_dims)
    ):
        """Return the estimated smoothed score h(t, z) as a (B, space_dims) tensor.

        At t == 0 this is ``mean(x) - z``. For t > 0 it is
        ``(sum_j w_j * t * x_j - z) / (1 - t)**2`` over the neighbour and random
        candidate rows. The w_j are normalised weights proportional to
        ``exp(-||z + noise - t * x_j||^2 / (2 (1 - t)**2))``, with the random
        rows up-weighted, and averaged over the ``n_points`` noise draws.

        NOTE(review): a row can appear both as a neighbour and as a random draw;
        it is then counted twice.
        """
        x = self.target_sampler.data.to(z.device)  # (N, space_dims)

        if t == 0:
            # The score points from z to the data mean (the -z term is important).
            return x.mean(dim=0) - z  # (B, space_dims)

        n_total = x.shape[0]
        # Never ask the index for more neighbours than there are rows.
        n_nn = min(self.num_nn, n_total)
        batch = z.shape[0]
        tx = t * x  # (N, space_dims): dataset rescaled by t

        # Neighbours of z / t among the unscaled rows are neighbours of z among the t * x rows.
        query = ((1 / t) * z).detach().cpu().numpy().astype(np.float32)
        _, nn_idx = self.index.search(query, n_nn)
        nn_idx = torch.from_numpy(nn_idx).to(device=z.device, dtype=torch.long)  # (B, n_nn)
        rand_idx = torch.randint(
            0, n_total, (batch, self.n_random_samples), device=z.device
        )  # (B, n_random_samples)

        if self.fixed_noise is None:
            noise = (
                torch.randn(batch, 1, self.n_points, z.shape[1], device=z.device)
                * self.smoothing_sigma * t
            )  # (B, 1, n_points, space_dims)
        else:
            noise = self.fixed_noise * t
        noisy_z = z[:, None, None, :] + noise  # (B, 1, n_points, space_dims)

        sigma_sq = (1 - t) ** 2
        cand_idx = torch.cat([nn_idx, rand_idx], dim=1)  # (B, K) with K = n_nn + n_random_samples
        candidates = tx[cand_idx]  # (B, K, space_dims)
        log_w = -((candidates[:, :, None, :] - noisy_z) ** 2).sum(dim=3) / (2 * sigma_sq)  # (B, K, n_points)

        # Manual softmax so the random-sample terms can be up-weighted before normalising.
        w = torch.exp(log_w - log_w.max(dim=1, keepdim=True).values)
        if self.n_random_samples > 0:
            # Each random row stands in for (N - n_nn) / n_random_samples non-neighbour rows.
            # Done out-of-place so autograd can still use exp's saved output.
            boost = (n_total - n_nn) / self.n_random_samples
            w = torch.cat([w[:, :n_nn], w[:, n_nn:] * boost], dim=1)
        w = w / w.sum(dim=1, keepdim=True)
        w = w.mean(dim=2)  # (B, K): average over noise draws

        avg_sample = (w.unsqueeze(2) * candidates).sum(dim=1)  # (B, space_dims)
        # Use z as-is: z.squeeze() would turn a (B, 1) input into (B,) and broadcast to (B, B).
        return (avg_sample - z) / sigma_sq  # (B, space_dims)
        
class OneStepSampler(nn.Module):
    """One-step sampler: maps each z to a softmax-weighted average of training samples.

    The weights are a softmax over ``-(||z - x_j||^2 + g_j) / (2 (1 - t)^2)``.
    Here g_j is standard Gumbel noise scaled by ``smoothing_sigma``. The weights are
    averaged over ``n_points`` independent noise draws. Unlike ``SmoothedScore``,
    the training samples are not rescaled by t.
    """

    def __init__(
        self,
        space_dims,
        train_samples,
        n_samples,
        smoothing_sigma=0,
        n_points=1,
        t=0.999
    ):
        super(OneStepSampler, self).__init__()
        self.space_dims = space_dims
        # Fixed (N, space_dims) tensor of training samples.
        self.train_samples = train_samples
        # Kept for backward compatibility; forward() uses every row of train_samples.
        self.n_samples = n_samples
        # Scale of the Gumbel noise added to the squared distances.
        self.gumbel_sigma = smoothing_sigma
        # Number of Gumbel noise draws the weights are averaged over.
        self.n_points = n_points
        # Time t; the Gaussian kernel width is 1 - t.
        self.t = t

    def forward(
        self,
        z,  # (B, space_dims)
    ):
        """Return the (B, space_dims) weighted average of training samples for each row of z."""
        x = self.train_samples.to(z.device)  # (N, space_dims)
        n_train = x.shape[0]
        sigma = 1 - self.t

        sq_dist = ((x[None, :, None, :] - z[:, None, None, :]) ** 2).sum(dim=3)  # (B, N, 1)

        # Size the noise by the real number of rows. Sizing it by n_samples gave a shape
        # error whenever n_samples != N (and for n_samples == 1 shared one noise value
        # across all rows, which the softmax cancels).
        u = torch.rand(z.shape[0], n_train, self.n_points, device=z.device)
        # torch.rand can return exactly 0, where log(-log(u)) is infinite and the softmax
        # below turns into NaN (even when gumbel_sigma == 0, since 0 * inf = NaN).
        u = u.clamp_min(torch.finfo(u.dtype).tiny)
        gumbel_noise = -self.gumbel_sigma * torch.log(-torch.log(u))  # (B, N, n_points)

        log_w = -(sq_dist + gumbel_noise) / (2 * sigma**2)  # (B, N, n_points)
        # F.softmax is shift-invariant and numerically stable on its own.
        w = F.softmax(log_w, dim=1).mean(dim=2)  # (B, N)

        return (w.unsqueeze(2) * x.unsqueeze(0)).sum(dim=1)  # (B, space_dims)


class VelFromScore(nn.Module):
    """Velocity field v(t, z) derived in closed form from a score function.

    ``score_fn(t, z)`` is expected to return a (B, space_dims) tensor. For t != 0
    the velocity is ``(z + (1 - t) * score) / t``; at t == 0 the score itself is
    returned.
    """

    def __init__(self, score_fn):
        super(VelFromScore, self).__init__()
        self.score_fn = score_fn

    def forward(
        self,
        t,  # scalar time; torchdiffeq.odeint always passes a scalar
        z,  # (B, space_dims)
    ):
        """Return the velocity v(t, z) as a (B, space_dims) tensor."""
        score = self.score_fn(t, z)  # (B, space_dims)

        if t == 0:
            return score
        # Use z as-is: squeezing would turn a (B, 1) input into (B,) and broadcast to (B, B).
        return (1 / t) * (z + (1 - t) * score)  # (B, space_dims)


# SamplerAtT takes a fixed set of N samples from a target distribution and a time T;
# on construction it rescales the samples by T. Its .sample(n_samples) method draws
# uniformly (with replacement) from the rescaled samples and adds Gaussian noise with std (1 - T).


class SamplerAtT:
    """Sampler for the time-T marginal built from a fixed set of training samples.

    Each sample is ``T * x_i + (1 - T) * eps``. Here x_i is drawn uniformly (with
    replacement) from ``train_samples`` and eps is standard Gaussian noise.
    """

    def __init__(self, train_samples, T):
        # (N, d) training samples, multiplied by T once at construction.
        self.train_samples = train_samples * T
        self.T = T

    def sample(self, n_samples):
        """Return a (n_samples, d) tensor of noisy, T-rescaled training samples."""
        pool = self.train_samples
        # Create indices and noise on the samples' own device, not a module-level
        # global, so the addition below cannot mix devices.
        picks = torch.randint(pool.shape[0], (n_samples,), device=pool.device)
        noise = torch.randn(n_samples, pool.shape[1], device=pool.device)
        return pool[picks] + noise * (1 - self.T)


def advect_samples(
    target_sampler,
    velocity_field,
    space_dims=2,
    step_size=1e-3,
    n_model_samples=1000,
    T=0,
    train_samples=None,
    display=False,
):
    """Advect samples from t=T to t=1 with ``velocity_field`` using fixed-step Euler.

    Args:
        target_sampler: object with ``.sample(n)``; provides the reference samples.
        velocity_field: callable ``(t, z) -> dz/dt`` passed to ``odeint``.
        space_dims: dimension of the base samples when ``T == 0``.
        step_size: Euler step size passed to ``odeint``.
        n_model_samples: number of samples to advect, and of reference samples drawn.
        T: start time. ``T == 0`` starts from standard Gaussian noise. Otherwise the
            start points come from ``SamplerAtT(train_samples, T)``.
        train_samples: 2-D training-sample tensor; required when ``T != 0``.
        display: if True, scatter-plot the first two coordinates of both sets
            (true samples in blue, model samples in red).

    Returns:
        ``(model_samples, true_samples)``: the advected samples at t=1 and the
        same number of samples from ``target_sampler``.

    Raises:
        ValueError: if ``T != 0`` and ``train_samples`` is None.
    """
    if T == 0:
        base_samples = torch.randn(n_model_samples, space_dims).float().to(device)
    else:
        if train_samples is None:
            raise ValueError("train_samples must be provided when T != 0")
        base_sampler = SamplerAtT(train_samples, T)
        base_samples = base_sampler.sample(n_model_samples).float().to(device)

    # Only the endpoints are requested; the fixed-step Euler solver takes the intermediate steps.
    times = torch.linspace(T, 1, 2, device=device)
    options = dict(step_size=step_size)
    model_samples = odeint(
        velocity_field, base_samples, times, method="euler", options=options
    )[-1]  # (B, d): state at t = 1
    true_samples = target_sampler.sample(n_model_samples).float().to(device)  # (B, d)

    if display:
        plt.figure(figsize=(20, 20))
        plt.scatter(
            true_samples[:, 0].cpu().numpy(),
            true_samples[:, 1].cpu().numpy(),
            color="blue",
            s=100,
        )
        plt.scatter(
            model_samples[:, 0].detach().cpu().numpy(),
            model_samples[:, 1].detach().cpu().numpy(),
            color="red",
            alpha=0.2,
        )
    return model_samples, true_samples
