import abc

import torch as th
from torch.utils.data import Dataset


class BaseSet(abc.ABC, Dataset):
    """Abstract dataset built around an unnormalized target energy.

    Subclasses implement :meth:`get_gt_disc`, which returns one energy value
    per sample. ``linear_intepolate`` treats that value as a negative log
    density. ``get_disc`` is the energy the other methods actually use. It is
    bound once in ``__init__`` to either :meth:`linear_intepolate` or
    :meth:`temp_intepolate`, and it reads the current ``temp_t`` on every call.

    Attributes:
        num_sample: Length reported by ``len(self)``.
        temp_t: Interpolation weight / temperature, also exposed as ``big_t``.
        get_disc: Bound energy function selected by ``is_linear``.
        big_z: Normalizing constant. It is fixed at 1.0; the ``big_t`` setter
            does not recompute it.
        data: Object returned by ``__getitem__``. It is ``None`` here and is
            presumably set by subclasses.
        data_ndim: Value exposed by the ``ndim`` property. It is ``None`` here.
    """

    def __init__(self, len_data, is_linear=True):
        """Initialize the dataset.

        Args:
            len_data: Number of items reported by ``len(self)``.
            is_linear: If True, ``get_disc`` is ``linear_intepolate``, which
                blends a Gaussian energy with the target energy and starts at
                ``temp_t = 0.0`` (pure target energy). Otherwise ``get_disc``
                is ``temp_intepolate``, which divides the target energy by
                ``temp_t`` and starts at ``temp_t = 1.0``.
        """
        # NOTE: super().__init__() is intentionally not called. Subclasses
        # mixing in other bases (e.g. nn.Module) may initialize those
        # themselves, and a cooperative call here could reset their state.
        self.num_sample = len_data
        self.temp_t = 0.0 if is_linear else 1.0
        self.get_disc = self.linear_intepolate if is_linear else self.temp_intepolate
        # Route the initial value through the ``big_t`` property setter.
        self.big_t = self.temp_t
        self.big_z = 1.0

        self.data = None
        self.data_ndim = None

    def cal_gt_big_z(self):  # pylint: disable = no-self-use
        """Return the ground-truth normalizing constant (1.0 in the base class)."""
        return 1.0

    @abc.abstractmethod
    def get_gt_disc(self, x):
        """Return the target ("ground-truth") energy of each sample in ``x``.

        ``linear_intepolate`` uses the result as the negative log density and
        requires its shape to be one value per row of ``x``.
        """
        return

    @property
    def ndim(self):
        """Dimensionality recorded in ``data_ndim`` (``None`` unless set)."""
        return self.data_ndim

    def sample(self, batch_size):  # pylint: disable=no-self-use
        """Draw ``batch_size`` samples; not implemented in the base class.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        del batch_size
        raise NotImplementedError

    @property
    def big_t(self):
        """Current interpolation weight / temperature (alias of ``temp_t``)."""
        return self.temp_t

    @big_t.setter
    def big_t(self, value):
        self.temp_t = value
        # ``big_z`` is intentionally NOT recomputed here; re-enable
        # ``self.big_z = self.cal_gt_big_z()`` if a temperature-dependent
        # normalizer is ever needed.

    def temp_intepolate(self, x):
        """Return the target energy divided by the temperature ``temp_t``."""
        return self.get_gt_disc(x) / self.temp_t

    def linear_intepolate(self, x):
        """Blend a Gaussian energy with the target energy.

        Returns ``temp_t * gauss(x) + (1 - temp_t) * get_gt_disc(x)``. Here
        ``gauss(x)_i = 0.5 * ||x_i - mean_over_dim0(x)||^2``, computed after
        flattening every dimension except the first.
        """
        neg_log_p = self.get_gt_disc(x)

        # Calculate the Gaussian part on rows flattened past dimension 0.
        flat_x = x.flatten(start_dim=1)
        # NOTE(review): the centering mean is taken over dim 0 (across the
        # batch), so each sample's Gaussian term depends on the rest of the
        # batch -- confirm this is intended.
        x_res = flat_x - th.mean(flat_x, dim=0, keepdim=True)
        neg_log_gaussian = 0.5 * th.sum(x_res * x_res, dim=1)

        assert neg_log_p.shape == neg_log_gaussian.shape, (
            f"get_gt_disc returned shape {tuple(neg_log_p.shape)}, expected "
            f"{tuple(neg_log_gaussian.shape)} (one value per sample)"
        )

        return self.temp_t * neg_log_gaussian + (1 - self.temp_t) * neg_log_p

    def terminal_loss(self, x):
        """Return ``get_disc(x) - 0.5 * sum(x ** 2, dim=-1)``.

        The quadratic term sums over the last dimension only, so this
        presumably expects ``x`` of shape ``(batch, dim)`` -- TODO confirm.
        """
        disc_loss = self.get_disc(x)
        quad_loss = -0.5 * x.pow(2).sum(dim=-1)
        return disc_loss + quad_loss

    def lgv_gradient(self, x):
        """Return the gradient of ``get_disc(x).sum()`` with respect to ``x``.

        The gradient is taken on a detached copy, so ``x`` is untouched and
        the returned tensor does not require grad. ``th.autograd.grad`` is
        used instead of ``backward()`` for a specific reason: no gradient is
        accumulated into the ``.grad`` of any other tensor that ``get_disc``
        happens to use, such as a subclass's parameters.
        """
        with th.enable_grad():
            copy_x = x.detach().clone().requires_grad_(True)
            # TODO: should it be _gt_disc or _disc
            energy = self.get_disc(copy_x).sum()
            (lgv_data,) = th.autograd.grad(energy, copy_x)
        return lgv_data

    def energy_unpdf(self, x):
        """Return the unnormalized density ``exp(-get_disc(x))``."""
        return th.exp(-1 * self.get_disc(x))

    def cal_big_z(self, sample):
        """Return the mean of ``exp(-get_disc(sample))`` over ``sample``."""
        return th.exp(-1 * self.get_disc(sample)).mean()

    def __len__(self):
        """Return ``num_sample``, the length given at construction."""
        return self.num_sample

    def __getitem__(self, idx):
        """Return ``self.data``; ``idx`` is ignored.

        NOTE(review): presumably ``data`` is a placeholder item rather than an
        indexable collection -- confirm against subclasses.
        """
        return self.data
