import math
from typing import Callable, Mapping, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch


class Gaussian:
    """Isotropic Gaussian N(loc, var * I) in ``dim`` dimensions.

    Args:
        dim: Dimensionality of the distribution.
        scale: Stored as ``self.scale``; not used by ``density`` or ``sample``.
        loc: Mean vector. Defaults to the origin (``[0] * dim``), so its length
            always matches ``dim``.
        var: Per-coordinate variance; the covariance is ``var * I`` in both
            ``density`` and ``sample``.
    """

    def __init__(self, dim=2, scale=1, loc=None, var=0.1):
        self.dim = dim
        self.scale = scale
        self.var = var
        # A None default avoids a shared mutable list and keeps the mean's
        # length consistent with ``dim`` (previously it was always [0, 0]).
        self.loc = torch.tensor([0] * dim if loc is None else loc)

    def density(self, point):
        """Evaluate the Gaussian probability density at ``point``.

        Args:
            point: Tensor whose last axis holds the ``dim`` coordinates; any
                leading axes are treated as batch axes.

        Returns:
            Tensor of densities with the last axis reduced away.
        """
        # Bring the mean onto the query's device without mutating self.loc.
        loc = self.loc.to(point.device)
        sq_dist = torch.sum((point - loc) ** 2, dim=-1)
        # Normaliser of an isotropic Gaussian: (2*pi*var)^(-dim/2).
        # For dim=2 this is exactly the previous 1 / (2*pi*var).
        norm = 1.0 / (2 * math.pi * self.var) ** (self.dim / 2)
        return norm * torch.exp(-sq_dist / (2 * self.var))

    def sample(self, num_samples=1):
        """Draw ``num_samples`` points from N(loc, var * I) on the CPU.

        Returns:
            float32 tensor of shape ``(num_samples, dim)``.
        """
        # scale_tril is the Cholesky factor of the covariance, so
        # sqrt(var) * I yields covariance var * I, matching ``density``.
        # (Passing it positionally would make it the covariance matrix and
        # give variance sqrt(var) instead of var.)
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            self.loc.cpu().float(),
            scale_tril=math.sqrt(self.var) * torch.eye(self.dim),
        )
        return dist.sample((num_samples,))


def generate_meshgrid(
    reward_model: Callable[[torch.Tensor], torch.Tensor],
    xlim=(-1, 1), ylim=(-1, 1), num_points=100
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate ``reward_model`` on a regular 2-D grid.

    Args:
        reward_model: Callable taking a float32 tensor of shape ``(N, 2)``
            and returning one value per row (e.g. shape ``(N,)`` or ``(N, 1)``).
        xlim: ``(min, max)`` range of the x axis.
        ylim: ``(min, max)`` range of the y axis.
        num_points: Number of grid points along each axis.

    Returns:
        ``(x, y, z)``: the ``np.meshgrid`` coordinate arrays and the rewards,
        all of shape ``(num_points, num_points)``, with ``z[i, j]`` the reward
        at ``(x[i, j], y[i, j])``.

    Raises:
        ValueError: if the model does not return exactly one value per point.
    """
    x = np.linspace(xlim[0], xlim[1], num_points)
    y = np.linspace(ylim[0], ylim[1], num_points)
    x, y = np.meshgrid(x, y)
    # Evaluate every grid point in a single batched call instead of
    # num_points**2 separate single-point calls.
    points = torch.from_numpy(np.stack([x.ravel(), y.ravel()], axis=-1)).float()
    with torch.no_grad():
        rewards = reward_model(points)
    rewards = rewards.detach().cpu().numpy()
    # reshape raises ValueError unless there is exactly one reward per point;
    # cast back to the grid's float64 dtype, as the original z array had.
    z = rewards.reshape(x.shape).astype(x.dtype)
    return x, y, z


def plot_2d(
    samples: torch.Tensor, reward_function,
    title=None, xlim=None, ylim=None
    ):
    """Scatter-plot 2-D ``samples`` coloured by ``reward_function(samples)``.

    Args:
        samples: Array/tensor indexed as ``samples[:, 0]`` (x) and
            ``samples[:, 1]`` (y).
        reward_function: Callable mapping ``samples`` to one reward per row,
            used as the scatter colours.
        title: Optional plot title.
        xlim, ylim: Passed to ``plt.xlim`` / ``plt.ylim``; ``None`` leaves the
            limits unchanged.
    """
    with torch.no_grad():
        rewards = reward_function(samples)
    # matplotlib needs host-memory data without autograd history; move any
    # tensors (possibly on an accelerator) to the CPU before plotting.
    if isinstance(samples, torch.Tensor):
        samples = samples.detach().cpu()
    if isinstance(rewards, torch.Tensor):
        rewards = rewards.detach().cpu()
    plt.figure()
    plt.scatter(
        samples[:, 0], samples[:, 1],
        s=1, c=rewards, cmap='viridis',
    )
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.colorbar(label='Reward', shrink=0.8)
    plt.title(title)
    plt.show()


def plot_3d(
    x, y, z, title=None,
    xlabel='x', ylabel='y', zlabel='z'
    ):
    """Draw ``z`` over the ``(x, y)`` grid as a 3-D surface and show it.

    Args:
        x, y, z: Grid arrays as accepted by ``Axes3D.plot_surface``
            (e.g. the output of ``generate_meshgrid``).
        title: Optional axes title.
        xlabel, ylabel, zlabel: Axis labels.
    """
    axes = plt.figure().add_subplot(111, projection='3d')
    axes.plot_surface(x, y, z, cmap='viridis')
    axes.set(xlabel=xlabel, ylabel=ylabel, zlabel=zlabel, title=title)
    plt.show()