import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F


class Ring(object):
    """Isotropic 2-D Gaussian N(0, sigma^2 I) truncated to the annulus r_min < |x| < r_max.

    Args:
        device: torch device used for sampling and for evaluating the plotting grid.
        sigma: standard deviation of the underlying isotropic Gaussian.
        r_min: inner radius of the annulus.
        r_max: outer radius of the annulus.
        edge_width: step length used by ``contour_plot`` to mark samples that lie
            close to the boundary (see ``_edge_samples``).
    """

    def __init__(self, device="cpu", sigma=2, r_min=1, r_max=2, edge_width=0.1):
        self.device = device
        self.sigma = sigma
        self.r_min = r_min
        self.r_max = r_max
        self.edge_width = edge_width

    def bondary_eq(self, X):
        """Return g(X) = 1/4 (|X|^2 - r_min^2)(|X|^2 - r_max^2), one value per row of X.

        g < 0 strictly inside the annulus, g = 0 on either circle, g > 0 outside.
        (The misspelled method name is kept for compatibility with callers.)
        """
        sq_norm = torch.sum(X ** 2, dim=-1)
        return 1 / 4 * (sq_norm - self.r_min ** 2) * (sq_norm - self.r_max ** 2)

    def nabla_bound(self, X):
        """Return the exact gradient of ``bondary_eq`` for each row of ``X`` (shape (N, 2)).

        dg/dx = x * (|x|^2 - (r_min^2 + r_max^2) / 2).
        """
        sq_norm = (X ** 2).sum(dim=-1)[:, None]
        return X * (sq_norm - (self.r_max ** 2 + self.r_min ** 2) / 2)

    def logp(self, X):
        """Return the Gaussian log-density of each row of ``X``.

        Rows with |x| < r_min or |x| > r_max are set to the finite floor -1000
        instead of -inf. The normalizer is the untruncated Gaussian one.
        """
        logp_ori = -0.5 * 2 * np.log(2 * np.pi) - 2 * np.log(self.sigma) - 1 / (2 * self.sigma ** 2) * torch.sum(X ** 2, dim=-1)
        norms = X.norm(2, dim=1)
        logp_ori[torch.logical_or(self.r_min > norms, norms > self.r_max)] = -1000
        return logp_ori

    def sample(self, bs=1000, max_rounds=100):
        """Draw up to ``bs`` points from the truncated Gaussian by rejection sampling.

        Each round draws ``bs * 10`` Gaussian points and keeps those strictly inside
        the annulus. Rounds repeat until ``bs`` points are accepted or ``max_rounds``
        rounds have run, so fewer than ``bs`` rows are returned only if the
        acceptance rate is extremely low.

        Returns:
            Tensor of shape (<= bs, 2) on ``self.device``.
        """
        accepted = []
        n_accepted = 0
        for _ in range(max(1, max_rounds)):
            draws = torch.randn([bs * 10, 2]).to(self.device) * self.sigma
            norms = draws.norm(2, dim=1)
            kept = draws[torch.logical_and(self.r_min < norms, norms < self.r_max)]
            accepted.append(kept)
            n_accepted += kept.shape[0]
            if n_accepted >= bs:
                break
        return torch.cat(accepted, dim=0)[:bs]

    def score(self, X):
        """Return the Gaussian score -X / sigma^2, zeroed for rows outside the annulus."""
        score_ori = -X / self.sigma ** 2
        norms = X.norm(2, dim=1)
        score_ori[norms > self.r_max] = 0
        score_ori[norms < self.r_min] = 0
        return score_ori

    @staticmethod
    def _to_numpy(x):
        """Convert a torch tensor (on any device) or an array-like object to a NumPy array."""
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _edge_samples(self, samples):
        """Return the subset of ``samples`` lying within ``edge_width`` of the boundary.

        Each point is moved by ``edge_width`` along the unit normal grad(g)/|grad(g)|.
        The step is signed by -sign(g), so it heads towards the zero level set.
        A point is kept when this step flips the sign of ``bondary_eq``.
        """
        g = self.bondary_eq(samples)
        normal = self.nabla_bound(samples)
        unit_normal = normal / normal.norm(2, dim=-1)[:, None]
        # torch.sign (not np.sign) so this also works for CUDA tensors.
        stepped = samples - torch.sign(g)[:, None] * self.edge_width * unit_normal
        return samples[self.bondary_eq(stepped) * g < 0]

    def contour_plot(self, ax, fnet=None, samples=None, save_to_path="./result.png", fig_title="", quiver=False, num_pt=5000, plot_edge=True):
        """Plot the density, samples, optional edge samples and optional vector field on ``ax``.

        Args:
            ax: matplotlib Axes to draw on (assumed).
            fnet: optional callable mapping an (N, 2) tensor to an (N, 2) output. It is
                drawn as the quiver field; when None, ``nabla_bound`` is drawn instead.
            samples: optional pair ``[samples, edge_samples]`` (tensors or arrays).
                When None, ``num_pt`` samples are drawn and the edge samples are computed.
            save_to_path: when not None, saves the samples to ``<path without .png>.pt``,
                the quiver arrows (only if ``quiver``) to ``<path without .png>scores.pt``,
                and the figure containing ``ax`` to ``save_to_path``.
            fig_title: axes title.
            quiver: whether to draw the vector field on a 30x30 grid.
            num_pt: number of samples to draw when ``samples`` is None.
            plot_edge: whether to draw ``samples[1]`` in red.
        """
        bbox = [-3, 3, -3, 3]
        xx, yy = np.mgrid[bbox[0]:bbox[1]:500j, bbox[2]:bbox[3]:500j]
        positions = np.vstack([xx.ravel(), yy.ravel()])
        with torch.no_grad():
            grid_logp = self.logp(torch.Tensor(positions.T).to(self.device)).cpu().numpy()
        # logp is negative here, so -log(-logp) is a monotone contrast stretch for plotting.
        f = -np.log(-np.reshape(grid_logp, xx.shape))
        if samples is None:
            drawn = self.sample(num_pt)
            edge_sample = self._edge_samples(drawn)
            samples = [self._to_numpy(drawn), self._to_numpy(edge_sample)]

        cxx, cyy = np.mgrid[bbox[0]:bbox[1]:30j, bbox[2]:bbox[3]:30j]
        ax.axis(bbox)
        ax.set_aspect(abs(bbox[1] - bbox[0]) / abs(bbox[3] - bbox[2]))
        ax.contourf(xx, yy, f, cmap='Blues', alpha=0.8, levels=11)
        main_pts = self._to_numpy(samples[0])
        ax.plot(main_pts[:, 0], main_pts[:, 1], '.', markersize=2, color='#ff7f0e')
        if plot_edge:
            edge_pts = self._to_numpy(samples[1])
            ax.plot(edge_pts[:, 0], edge_pts[:, 1], '.', markersize=2, color='red')
        scores = None
        if quiver:
            cpositions = np.vstack([cxx.ravel(), cyy.ravel()])
            grid = torch.Tensor(cpositions.T).to(self.device)
            field = self.nabla_bound(grid) if fnet is None else fnet(grid)
            scores = np.reshape(field.detach().cpu().numpy(), cpositions.T.shape)
            ax.quiver(cxx, cyy, scores[:, 0], scores[:, 1], width=0.002)
        # Style the given axes rather than whatever pyplot considers "current".
        ax.tick_params(axis="both", labelsize=15)
        ax.set_title(fig_title, fontsize=30, y=1.04)
        if save_to_path is not None:
            if scores is not None:
                torch.save(scores, save_to_path.replace(".png", "scores.pt"))
            torch.save(samples, save_to_path.replace(".png", ".pt"))
            ax.figure.savefig(save_to_path, bbox_inches='tight')


class Cardioid(object):
    """Isotropic 2-D Gaussian N(0, sigma^2 I) truncated to a cardioid-shaped region.

    The region is {x : bondary_eq(x) < 0} with
        g(x) = x0^2 + (shape_param * x1 - (x0^2)^(1/3))^2 - 4.

    Args:
        device: torch device used for sampling and for evaluating the plotting grid.
        sigma: standard deviation of the underlying isotropic Gaussian.
        shape_param: vertical scaling of the cardioid (multiplies x1 in g).
        edge_width: step length used by ``contour_plot`` to mark samples that lie
            close to the boundary (see ``_edge_samples``).
    """

    def __init__(self, device="cpu", sigma=2, shape_param=1.2, edge_width=0.1):
        self.device = device
        self.sigma = sigma
        self.shape_param = shape_param
        self.edge_width = edge_width

    def bondary_eq(self, X):
        """Return g(X) per row: negative inside the cardioid, positive outside."""
        return X[:, 0] ** 2 + (X[:, 1] * self.shape_param - torch.pow(X[:, 0] ** 2, 1 / 3)) ** 2 - 4

    def nabla_bound(self, X):
        """Return an approximate gradient of ``bondary_eq`` per row (shape (N, 2)).

        With aux = shape_param * x1 - (x0^2)^(1/3):
            dg/dx1 = 2 * shape_param * aux                       (exact)
            dg/dx0 = 2 x0 - 4/3 * sign(x0) * |x0|^(-1/3) * aux
        dg/dx0 is singular at x0 = 0. sign(x0) is therefore replaced by
        2 * (sigmoid(x0) - 0.5) (= tanh(x0 / 2)), and 1e-4 guards are added, presumably
        to keep the field finite near the cusp.
        """
        aux = X[:, 1] * self.shape_param - torch.pow(X[:, 0] ** 2 + 0.0001, 1 / 3)
        smooth_sign = (torch.sigmoid(X[:, 0]) - 0.5) * 2
        inv_cbrt = 1 / (torch.pow(torch.abs(X[:, 0]) + 0.0001, 1 / 3) + 0.0001)
        d_x0 = 2 * X[:, 0] - 4 / 3 * smooth_sign * inv_cbrt * aux
        # The factor 2 was previously missing, which made this inconsistent with d_x0.
        d_x1 = 2 * self.shape_param * aux
        return torch.stack((d_x0, d_x1), dim=-1)

    def logp(self, X):
        """Return the Gaussian log-density of each row; rows with g(x) > 0 get -1000.

        The normalizer is the untruncated Gaussian one.
        """
        bond_eq = self.bondary_eq(X)
        logp_ori = -0.5 * 2 * np.log(2 * np.pi) - 2 * np.log(self.sigma) - 1 / (2 * self.sigma ** 2) * torch.sum(X ** 2, dim=-1)
        logp_ori[bond_eq > 0] = -1000
        return logp_ori

    def sample(self, bs=1000, max_rounds=100):
        """Draw up to ``bs`` points from the truncated Gaussian by rejection sampling.

        Each round draws ``bs * 10`` Gaussian points and keeps those with g(x) < 0.
        Rounds repeat until ``bs`` points are accepted or ``max_rounds`` rounds have run.

        Returns:
            Tensor of shape (<= bs, 2) on ``self.device``.
        """
        accepted = []
        n_accepted = 0
        for _ in range(max(1, max_rounds)):
            draws = torch.randn([bs * 10, 2]).to(self.device) * self.sigma
            kept = draws[self.bondary_eq(draws) < 0]
            accepted.append(kept)
            n_accepted += kept.shape[0]
            if n_accepted >= bs:
                break
        return torch.cat(accepted, dim=0)[:bs]

    def score(self, X):
        """Return the Gaussian score -X / sigma^2.

        NOTE(review): unlike ``logp``, this is not masked outside the region
        (g(x) > 0); confirm this is intended.
        """
        return -X / self.sigma ** 2

    @staticmethod
    def _to_numpy(x):
        """Convert a torch tensor (on any device) or an array-like object to a NumPy array."""
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _edge_samples(self, samples):
        """Return the subset of ``samples`` lying within ``edge_width`` of the boundary.

        Each point is moved by ``edge_width`` along the unit normal from ``nabla_bound``.
        The step is signed by -sign(g), so it heads towards the zero level set.
        A point is kept when this step flips the sign of ``bondary_eq``.
        """
        g = self.bondary_eq(samples)
        normal = self.nabla_bound(samples)
        unit_normal = normal / normal.norm(2, dim=-1)[:, None]
        # torch.sign (not np.sign) so this also works for CUDA tensors.
        stepped = samples - torch.sign(g)[:, None] * self.edge_width * unit_normal
        return samples[self.bondary_eq(stepped) * g < 0]

    def contour_plot(self, ax, fnet=None, samples=None, save_to_path="./result.png", fig_title="", quiver=False, num_pt=5000, plot_edge=True):
        """Plot the density, samples, optional edge samples and optional vector field on ``ax``.

        Args:
            ax: matplotlib Axes to draw on (assumed).
            fnet: optional callable mapping an (N, 2) tensor to an (N, 2) output. It is
                drawn as the quiver field; when None, ``nabla_bound`` is drawn instead.
            samples: optional pair ``[samples, edge_samples]`` (tensors or arrays).
                When None, ``num_pt`` samples are drawn and the edge samples are computed.
            save_to_path: when not None, saves the samples to ``<path without .png>.pt``,
                the quiver arrows (only if ``quiver``) to ``<path without .png>scores.pt``,
                and the figure containing ``ax`` to ``save_to_path``.
            fig_title: axes title.
            quiver: whether to draw the vector field on a 30x30 grid.
            num_pt: number of samples to draw when ``samples`` is None.
            plot_edge: whether to draw ``samples[1]`` in red.
        """
        bbox = [-3, 3, -3, 3]
        xx, yy = np.mgrid[bbox[0]:bbox[1]:500j, bbox[2]:bbox[3]:500j]
        positions = np.vstack([xx.ravel(), yy.ravel()])
        with torch.no_grad():
            grid_logp = self.logp(torch.Tensor(positions.T).to(self.device)).cpu().numpy()
        # logp is negative here, so -log(-logp) is a monotone contrast stretch for plotting.
        f = -np.log(-np.reshape(grid_logp, xx.shape))
        if samples is None:
            drawn = self.sample(num_pt)
            edge_sample = self._edge_samples(drawn)
            samples = [self._to_numpy(drawn), self._to_numpy(edge_sample)]

        cxx, cyy = np.mgrid[bbox[0]:bbox[1]:30j, bbox[2]:bbox[3]:30j]
        ax.axis(bbox)
        ax.set_aspect(abs(bbox[1] - bbox[0]) / abs(bbox[3] - bbox[2]))
        ax.contourf(xx, yy, f, cmap='Blues', alpha=0.8, levels=11)
        main_pts = self._to_numpy(samples[0])
        ax.plot(main_pts[:, 0], main_pts[:, 1], '.', markersize=2, color='#ff7f0e')
        if plot_edge:
            edge_pts = self._to_numpy(samples[1])
            ax.plot(edge_pts[:, 0], edge_pts[:, 1], '.', markersize=2, color='red')
        scores = None
        if quiver:
            cpositions = np.vstack([cxx.ravel(), cyy.ravel()])
            grid = torch.Tensor(cpositions.T).to(self.device)
            field = self.nabla_bound(grid) if fnet is None else fnet(grid)
            scores = np.reshape(field.detach().cpu().numpy(), cpositions.T.shape)
            ax.quiver(cxx, cyy, scores[:, 0], scores[:, 1], width=0.002)
        # Style the given axes rather than whatever pyplot considers "current".
        ax.tick_params(axis="both", labelsize=15)
        ax.set_title(fig_title, fontsize=30, y=1.04)
        if save_to_path is not None:
            if scores is not None:
                torch.save(scores, save_to_path.replace(".png", "scores.pt"))
            torch.save(samples, save_to_path.replace(".png", ".pt"))
            ax.figure.savefig(save_to_path, bbox_inches='tight')


class DoubleMoon(object):
    """Unnormalized 2-D "double moon" density truncated to a super-level set.

    Unnormalized log-density:
        h(x) = -2 (|x| - 3)^2 + logsumexp_m(-2 (x0 - m)^2),   m in {3, -3},
    so mass concentrates near the circle |x| = 3 and near x0 = +-3.
    The constrained region is {x : h(x) > bd}, i.e. bondary_eq(x) = bd - h(x) < 0.

    Args:
        device: torch device used for sampling and for evaluating the plotting grid.
        bd: threshold on h defining the region.
        edge_width: step length used by ``contour_plot`` to mark samples that lie
            close to the boundary (see ``_edge_samples``).
    """

    def __init__(self, device="cpu", bd=-5.0, edge_width=0.1):
        self.device = device
        self.bd = bd
        self.edge_width = edge_width

    def _log_density(self, X):
        """Return the untruncated, unnormalized log-density h(x) for each row of ``X``."""
        means = torch.tensor([[3.0, -3.0]], dtype=X.dtype, device=X.device)
        radial = -2 * (torch.sqrt(torch.sum(X ** 2, dim=-1)) - 3.0) ** 2
        lateral = torch.logsumexp(-2 * (X[:, 0][:, None] - means) ** 2, dim=1)
        return radial + lateral

    def _log_density_grad(self, X):
        """Return grad h(x) per row (shape (N, 2)); 1e-6 guards avoid division by zero at x = 0."""
        means = torch.tensor([[3.0, -3.0]], dtype=X.dtype, device=X.device)
        radius = torch.sqrt(torch.sum(X ** 2, dim=-1))[:, None]
        grad = -4 * (X - (3.0 * (X + 1e-6) / (radius + 1e-6)))
        x0_minus_means = X[:, 0][:, None] - means
        # d/dx0 of logsumexp: softmax-weighted sum of the per-mean derivatives.
        lateral_grad = -4 * (x0_minus_means * F.softmax((-2) * x0_minus_means ** 2, dim=-1)).sum(dim=-1)
        grad[:, 0] = grad[:, 0] + lateral_grad
        return grad

    def bondary_eq(self, X):
        """Return bd - h(X) per row: negative inside the region, positive outside."""
        return -self._log_density(X) + self.bd

    def nabla_bound(self, X):
        """Return the gradient of ``bondary_eq`` per row, i.e. -grad h(x) (shape (N, 2))."""
        return -self._log_density_grad(X)

    def logp(self, X):
        """Return h(X) per row; rows with h < bd are set to the finite floor -10000."""
        logp_ori = self._log_density(X)
        logp_ori[logp_ori < self.bd] = -10000
        return logp_ori

    def score(self, X):
        """Return grad h(X) per row, zeroed for rows outside the region (bondary_eq > 0)."""
        score_ori = self._log_density_grad(X)
        score_ori[self.bondary_eq(X) > 0] = 0
        return score_ori

    def sample(self, bs=1000, loop=10000, epsilon_0=5 * 1e-4, alpha=0, accept_rate=False):
        """Sample with unadjusted Langevin dynamics, then keep only points inside the region.

        ``bs * 10`` chains start at the origin and run ``loop`` steps of
            Z <- Z + eps_t / 2 * score(Z) + sqrt(eps_t) * N(0, I),   eps_t = epsilon_0 / (1 + t)^alpha.
        ``score`` is zero outside the region, so chains there move by noise alone.

        Args:
            bs: maximum number of samples returned.
            loop: number of Langevin steps.
            epsilon_0: initial step size.
            alpha: step-size decay exponent.
            accept_rate: if True, also return the fraction of chains that ended inside.

        Returns:
            Tensor of shape (<= bs, 2), or ``(samples, rate)`` when ``accept_rate`` is True.
        """
        Z = torch.zeros(bs * 10, 2).to(self.device)
        for t in range(loop):
            step = epsilon_0 / (1 + t) ** alpha
            Z = Z + step / 2 * self.score(Z) + np.sqrt(step) * torch.randn([Z.shape[0], 2]).to(self.device)
        constrained_samples = Z[self.bondary_eq(Z) < 0]
        # Keep the caller's flag distinct from the computed rate. The original
        # overwrote the flag, so a tuple was returned whenever the rate was non-zero.
        if accept_rate:
            rate = constrained_samples.shape[0] / Z.shape[0] if Z.shape[0] else 0.0
            return constrained_samples[:bs], rate
        return constrained_samples[:bs]

    @staticmethod
    def _to_numpy(x):
        """Convert a torch tensor (on any device) or an array-like object to a NumPy array."""
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _edge_samples(self, samples):
        """Return the subset of ``samples`` lying within ``edge_width`` of the boundary.

        Each point is moved by ``edge_width`` along the unit normal from ``nabla_bound``.
        The step is signed by -sign(bondary_eq), so it heads towards the zero level set.
        A point is kept when this step flips the sign of ``bondary_eq``.
        """
        g = self.bondary_eq(samples)
        normal = self.nabla_bound(samples)
        unit_normal = normal / normal.norm(2, dim=-1)[:, None]
        # torch.sign (not np.sign) so this also works for CUDA tensors.
        stepped = samples - torch.sign(g)[:, None] * self.edge_width * unit_normal
        return samples[self.bondary_eq(stepped) * g < 0]

    def contour_plot(self, ax, fnet=None, samples=None, save_to_path="./result.png", fig_title="", quiver=False, num_pt=5000, plot_edge=True):
        """Plot the density, samples, optional edge samples and optional vector field on ``ax``.

        Args:
            ax: matplotlib Axes to draw on (assumed).
            fnet: optional callable mapping an (N, 2) tensor to an (N, 2) output. It is
                drawn as the quiver field; when None, ``nabla_bound`` is drawn instead.
            samples: optional pair ``[samples, edge_samples]`` (tensors or arrays).
                When None, exactly ``num_pt`` Langevin samples are required (asserted).
            save_to_path: when not None, saves the samples to ``<path without .png>.pt``,
                the quiver arrows (only if ``quiver``) to ``<path without .png>scores.pt``,
                and the figure containing ``ax`` to ``save_to_path``.
            fig_title: axes title.
            quiver: whether to draw the vector field on a 30x30 grid.
            num_pt: number of samples to draw when ``samples`` is None.
            plot_edge: whether to draw ``samples[1]`` in red.
        """
        bbox = [-4, 4, -4, 4]
        xx, yy = np.mgrid[bbox[0]:bbox[1]:500j, bbox[2]:bbox[3]:500j]
        positions = np.vstack([xx.ravel(), yy.ravel()])
        with torch.no_grad():
            grid_logp = self.logp(torch.Tensor(positions.T).to(self.device)).cpu().numpy()
        # NOTE(review): -log(-logp) is only finite where logp < 0; confirm this holds on the plot box.
        f = -np.log(-np.reshape(grid_logp, xx.shape))
        if samples is None:
            drawn, _ = self.sample(num_pt, accept_rate=True)
            assert drawn.shape[0] == num_pt, "The number of samples does not meet the requirements."
            edge_sample = self._edge_samples(drawn)
            samples = [self._to_numpy(drawn), self._to_numpy(edge_sample)]

        cxx, cyy = np.mgrid[bbox[0]:bbox[1]:30j, bbox[2]:bbox[3]:30j]
        ax.axis(bbox)
        ax.set_aspect(abs(bbox[1] - bbox[0]) / abs(bbox[3] - bbox[2]))
        ax.contourf(xx, yy, f, cmap='Blues', alpha=0.8, levels=11)
        main_pts = self._to_numpy(samples[0])
        ax.plot(main_pts[:, 0], main_pts[:, 1], '.', markersize=2, color='#ff7f0e')
        if plot_edge:
            edge_pts = self._to_numpy(samples[1])
            ax.plot(edge_pts[:, 0], edge_pts[:, 1], '.', markersize=2, color='red')
        scores = None
        if quiver:
            cpositions = np.vstack([cxx.ravel(), cyy.ravel()])
            grid = torch.Tensor(cpositions.T).to(self.device)
            field = self.nabla_bound(grid) if fnet is None else fnet(grid)
            scores = np.reshape(field.detach().cpu().numpy(), cpositions.T.shape)
            ax.quiver(cxx, cyy, scores[:, 0], scores[:, 1], width=0.002)
        # Style the given axes rather than whatever pyplot considers "current".
        ax.tick_params(axis="both", labelsize=15)
        ax.set_title(fig_title, fontsize=30, y=1.04)
        if save_to_path is not None:
            if scores is not None:
                torch.save(scores, save_to_path.replace(".png", "scores.pt"))
            torch.save(samples, save_to_path.replace(".png", ".pt"))
            ax.figure.savefig(save_to_path, bbox_inches='tight')