import os.path as osp
import pathlib

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torch.distributions as D
from ANONYMOUStorch.utils import as_numpy
from torch.distributions.mixture_same_family import MixtureSameFamily

from datasets.img_tools import ImageEnergy, prepare_image
from utils.ksd import KSD

from .base_set import BaseSet
# pylint: disable=global-statement

g_fns = {}


def register(func):
    """Store ``func`` in ``g_fns`` keyed by its ``__name__`` and return it unchanged.

    Meant to be used as a decorator. Registering a second function with the
    same name replaces the earlier entry. The dict is mutated in place, so no
    ``global`` declaration is needed.
    """
    key = func.__name__
    g_fns.update({key: func})
    return func


@register
def checkerboard(x):
    """Checkerboard energy with a quadratic penalty outside the box [-1, 1].

    The first two coordinates are divided into 0.5-wide cells. Cells whose
    index sum is odd receive a flat penalty of 1e6. Every point additionally
    pays 1e5 times the squared Euclidean distance from the box |x_i| <= 1
    (zero inside the box).

    Args:
        x: batch of points indexed as ``x[:, 0]`` / ``x[:, 1]`` (presumably
            shape ``(batch, 2)``; not checked here).

    Returns:
        One energy value per row of ``x``.
    """
    # floor(2 * coordinate) is the index of the half-unit cell along that axis.
    cell_x = th.div(2 * x[:, 0], 1, rounding_mode="floor")
    cell_y = th.div(2 * x[:, 1], 1, rounding_mode="floor")
    parity = (cell_x + cell_y) % 2

    # Signed distance to the box, built from two parts:
    # - exterior: norm of the positive overshoot;
    # - interior: the largest overshoot, clamped to be <= 0.
    # Clamping their sum at zero leaves only the distance for points outside.
    overshoot = th.abs(x) - 1
    exterior = th.clip(overshoot, 0.0).norm(dim=1)
    interior = th.clip(overshoot.max(dim=1).values, max=0)
    dist_outside = th.clip(exterior + interior, 0)

    return parity * 1e6 + dist_outside ** 2 * 1e5


@register
def strip(x):
    """Band-shaped energy that depends only on the first coordinate.

    The first coordinate is divided into 0.5-wide bands. A point gets energy
    1e6 if it lies in an odd-indexed band or if ``x[:, 0]`` falls outside
    [-1, 1); otherwise its energy is 0. The result is clipped to [0, 1e6].
    """
    band = th.div(2 * x[:, 0], 1, rounding_mode="floor")
    odd_band = band % 2
    # band < -2  <=>  x < -1 ;  band >= 2  <=>  x >= 1
    beyond = (band < -2) | (band >= 2)
    return th.clip(odd_band * 1e6 + beyond * 1e6, 0, 1e6)


# pylint: disable=attribute-defined-outside-init


class Base2DSet(BaseSet):
    """Common base for the 2D toy targets in this module.

    Sets ``data_ndim = 2`` and builds a KSD evaluator around
    ``self.lgv_gradient``. It also provides grid-based helpers for plotting
    the unnormalised density and for a grid-based normalisation quantity.
    Subclasses supply the energy (``get_gt_disc``) and, where available, an
    exact sampler (``sample``).
    """

    def __init__(self, len_data, is_linear=True):
        super().__init__(len_data, is_linear)
        self.data = th.tensor([0.0, 0.0]).cuda()  # pylint: disable= not-callable
        self.data_ndim = 2
        self.worker = KSD(self.lgv_gradient, beta=0.2)

    @staticmethod
    def _grid_points(lim, num):
        """Return the ``num * num`` nodes of a square grid over [-lim, lim]^2 on CUDA.

        Rows follow ``th.meshgrid``'s default "ij" order: point ``i * num + j``
        is ``(axis[i], axis[j])``.
        """
        axis = th.linspace(-lim, lim, num).cuda()
        xx, yy = th.meshgrid(axis, axis)
        return th.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=1)

    @th.no_grad()
    def ksd(self, points):
        """Return KSD(points) minus the KSD of 5000 draws from ``self.sample``.

        The reference call passes ``adjust_beta=True`` and runs first. The
        decorator already disables autograd for the whole method, so no inner
        ``no_grad`` block is needed.
        """
        gt_ksd = self.worker(self.sample(5000), adjust_beta=True)
        return self.worker(points) - gt_ksd

    def viz_pdf(self, fsave="checkboard-density.png", lim=6, num=100):
        """Save an image of the unnormalised density exp(-get_disc) on [-lim, lim]^2.

        Args:
            fsave: output image path.
            lim: half-width of the plotted square.
            num: grid resolution per axis.
        """
        points = self._grid_points(lim, num)
        un_pdf = th.exp(-1 * self.get_disc(points))
        fig, axs = plt.subplots(1, 1, figsize=(7, 7))
        # The grid is in "ij" order, so rows index the first coordinate.
        # Transposing and using origin="lower" makes the first coordinate run
        # left-to-right and the second bottom-to-top, with axes in data units.
        axs.imshow(
            as_numpy(un_pdf.view(num, num).T),
            origin="lower",
            extent=(-lim, lim, -lim, lim),
        )
        fig.savefig(fsave)
        plt.close(fig)

    def cal_gt_big_z(self, lim=10, num=500):
        """Grid-based quantity built from exp(-get_disc) on [-lim, lim]^2.

        Args:
            lim: half-width of the evaluation square.
            num: grid resolution per axis.

        Returns:
            A scalar tensor equal to ``sum(un_pdf**2) / sum(un_pdf)`` over the grid.
        """
        points = self._grid_points(lim, num)
        un_pdf = th.exp(-1 * self.get_disc(points))
        pdf = un_pdf / un_pdf.sum()
        # NOTE(review): this is the mean of un_pdf under the grid-normalised
        # density, not a Riemann estimate of Z = integral of exp(-E). That
        # estimate would be un_pdf.sum() * cell_area. Left unchanged; confirm
        # which quantity callers expect.
        return (pdf * un_pdf).sum()


class FnPs(Base2DSet):
    """2D target whose energy is a function registered in ``g_fns``.

    Args:
        len_data, is_linear: forwarded to ``Base2DSet``.
        fn_str: name under which the energy function was registered.

    Raises:
        KeyError: if ``fn_str`` is not a registered name. The message lists
            the available names.
    """

    def __init__(self, len_data, is_linear=True, fn_str="checkerboard"):
        # Only reads the module-level registry, so no ``global`` is needed.
        try:
            self.fn = g_fns[fn_str]
        except KeyError as err:
            raise KeyError(
                f"unknown fn_str {fn_str!r}; registered: {sorted(g_fns)}"
            ) from err
        # self.fn is assigned before the base initialiser runs. Presumably this
        # is because BaseSet.__init__ may already evaluate get_gt_disc; that
        # code is not visible here.
        super().__init__(len_data, is_linear)

    def get_gt_disc(self, x):
        """Return ``self.fn(x) / self.temp_t`` (``temp_t`` presumably set by BaseSet)."""
        return self.fn(x) / self.temp_t


class MG2D(Base2DSet):
    """Equal-weight mixture of ``nmode`` isotropic Gaussians placed evenly on a circle.

    Args:
        len_data, is_linear: forwarded to ``Base2DSet``.
        nmode: number of mixture components.
        xlim: radius of the circle holding the component means.
        scale: per-axis standard deviation relative to ``xlim``. The absolute
            standard deviation is ``scale * xlim``.
    """

    def __init__(self, len_data, is_linear=True, nmode=3, xlim=3.0, scale=0.15):
        mix = D.Categorical(th.ones(nmode).cuda())
        # Evenly spaced angles over the full circle. The endpoint is excluded
        # so the first and last modes do not coincide.
        angles = np.linspace(0, 2 * np.pi, nmode, endpoint=False)
        poses = xlim * np.stack([np.cos(angles), np.sin(angles)]).T
        # NOTE(review): poses is float64 (NumPy default), while the scales
        # below are float32. Mixed-dtype arithmetic will promote results to
        # float64; confirm this is intended.
        poses = th.from_numpy(poses).cuda()
        comp = D.Independent(
            D.Normal(poses, th.ones(size=(nmode, 2)).cuda() * scale * xlim), 1
        )

        self.gmm = MixtureSameFamily(mix, comp)

        super().__init__(len_data, is_linear)

    def get_gt_disc(self, x):
        """Energy = negative log-density of the mixture."""
        return -self.gmm.log_prob(x)

    def sample(self, batch_size):
        """Draw ``batch_size`` exact samples from the mixture."""
        return self.gmm.sample((batch_size,))


class Fun(Base2DSet):
    """Equal-weight mixture of 9 Gaussians centred on the grid {-5, 0, 5}^2.

    Each component has per-axis variance 0.3 (standard deviation sqrt(0.3)).
    """

    def __init__(self, len_data, is_linear=True):
        # np.mgrid returns both 3x3 coordinate grids stacked along axis 0.
        # Flattening each grid and pairing the entries gives the 9 nodes in
        # row-major order.
        centres = np.mgrid[-5:5:3j, -5:5:3j].reshape(2, -1).T
        std = th.ones(size=(9, 2)).cuda() * np.sqrt(0.3)
        self.gmm = MixtureSameFamily(
            D.Categorical(th.ones(9).cuda()),
            D.Independent(D.Normal(th.from_numpy(centres).cuda(), std), 1),
        )

        super().__init__(len_data, is_linear)

    def get_gt_disc(self, x):
        """Energy = negative log-density of the mixture."""
        log_density = self.gmm.log_prob(x)
        return -log_density

    def sample(self, batch_size):
        """Draw ``batch_size`` exact samples from the mixture."""
        shape = (batch_size,)
        return self.gmm.sample(shape)


class Chlg2D(Base2DSet):
    """Three-component Gaussian mixture, symmetrised under swapping the two coordinates.

    ``self.gmm`` has equal weights and full covariances. The means are divided
    by 0.3 and the covariances by 0.09 (= 0.3**2), which is the same mixture
    with coordinates scaled by 1/0.3. The target density is
    ``(p(x) + p(swap(x))) / 2``, where ``p`` is the density of ``self.gmm``.
    """

    def __init__(self, len_data, is_linear=True):
        mix = D.Categorical(th.ones(3).cuda())
        mean = th.tensor([[0.9, 0.0], [-0.75, 0.0], [0.6, 0.9]]).cuda() / 0.3
        cov = (
            th.tensor(
                [
                    [[0.063, 0.0], [0.0, 0.0045]],
                    [[0.063, 0.0], [0.0, 0.0045]],
                    [[0.09, 0.085], [0.085, 0.09]],
                ]
            ).cuda()
            / 0.09
        )
        comp = D.Independent(D.multivariate_normal.MultivariateNormal(mean, cov), 0)

        self.gmm = MixtureSameFamily(mix, comp)
        super().__init__(len_data, is_linear)

    def get_gt_disc(self, x):
        """Return -log((p(x) + p(swap(x))) / 2), computed in log space.

        Exponentiating the two log-densities and adding them would underflow
        for points far from all modes. That gives log(0) = -inf, so the energy
        becomes +inf and gradients become NaN. logsumexp avoids this.
        """
        log_p = self.gmm.log_prob(x)
        log_p_swapped = self.gmm.log_prob(x.flip(dims=(-1,)))
        log_mean = th.logsumexp(th.stack([log_p, log_p_swapped]), dim=0) - np.log(2.0)
        return -log_mean

    def sample(self, batch_size):
        """Draw ``batch_size`` exact samples from the symmetrised target.

        Draw from ``self.gmm``, then swap the coordinates of a uniformly random
        half of the draws. This reproduces the density ``(p(x) + p(swap(x))) / 2``
        used by ``get_gt_disc``.
        """
        samples = self.gmm.sample((batch_size,))
        swap = th.rand(batch_size, device=samples.device) < 0.5
        return th.where(swap[:, None], samples.flip(dims=(-1,)), samples)

    @th.no_grad()
    def ksd(self, points):
        """Return KSD of the first 1000 points minus the KSD of 1000 reference samples.

        The decorator already disables autograd for the whole method.
        """
        gt_ksd = self.worker(self.sample(1000), adjust_beta=True)
        return self.worker(points[:1000]) - gt_ksd  # fix cuda problem


class ImgPs(Base2DSet):
    """2D target whose energy is derived from the bundled image ``labrador.jpg``."""

    def __init__(self, len_data, is_linear=True):
        here = pathlib.Path(__file__).parent.resolve()
        raw = mpimg.imread(osp.join(here, "labrador.jpg"))
        _, energy_map = prepare_image(
            raw,
            crop=(10, 710, 240, 940),
            white_cutoff=225,
            gauss_sigma=3,
            background=0.01,
        )
        # Reverse the row order before building the energy. This is presumably
        # so that the image's top row maps to large y; confirm against
        # ImageEnergy's conventions.
        flipped = energy_map[::-1].copy()
        self.energy = ImageEnergy(flipped, mean=[350, 350], scale=[300, 300])
        super().__init__(len_data, is_linear)

    def get_gt_disc(self, x):
        """Evaluate the image energy on CPU, then return it flattened on CUDA."""
        on_cpu = x.cpu()
        return self.energy.energy(on_cpu).cuda().flatten()
