import torch
import random
import numpy as np
import time
from SIP.args import cmd_args
from statsmodels.tsa.stattools import acf
from itertools import combinations
from tqdm import tqdm
import math
from torch.distributions import Exponential

class BaseSampler(object):
    """Common driver for the Markov chain samplers in this module.

    Subclasses implement ``_step`` to advance the chain by one transition.
    ``_step`` reads and updates ``self.x`` (the current state) and
    ``self.energy`` (its energy as returned by ``model.energy``), and it
    increments ``self.succ`` whenever it moves the chain.
    """

    def __init__(self, seed=cmd_args.seed):
        # NOTE: the default seed is read from ``cmd_args`` once, when this
        # module is imported, not at construction time.
        self.rng = np.random.default_rng(seed)
        self._reset()

    def _reset(self):
        """Clear the statistics collected by a previous ``sample`` call."""
        self.logp = []
        self.trace = []
        self.elapse = 0
        self.succ = 0

    def _step(self, model, t, *args, **kwargs):
        """Advance the chain by one transition; implemented by subclasses.

        ``sample`` calls this as ``_step(model, t, *args, **kwargs)``. The
        current state is available as ``self.x`` and is not passed in.
        """
        raise NotImplementedError

    def sample(self, model, T=1000, method='Gibbs', *args, **kwargs):
        """Run the chain for ``T`` transitions starting from ``model.init_state()``.

        Args:
            model: target model. This method uses ``init_state``, ``energy``,
                ``trace`` and ``info``; subclasses may use more.
            T: number of transitions.
            method: label shown in the progress bar only.
            *args, **kwargs: forwarded to ``_step`` on every iteration.

        Returns:
            ``(logp, trace, elapse, succ)``:
            - ``logp``: the per-step ``[energy]`` values. The list is named
              ``logp`` but it stores ``self.energy``.
            - ``trace``: the per-step ``[model.trace(x)]`` values.
            - ``elapse``: seconds spent in the loop, including progress-bar
              and bookkeeping overhead.
            - ``succ``: the number of moves counted by ``_step``.
        """
        self._reset()
        self.x = model.init_state()
        self.energy = model.energy(self.x)

        progress_bar = tqdm(range(T))
        progress_bar.set_description(f"[{method}, {model.info}]")
        # perf_counter is monotonic, unlike time.time, so clock adjustments
        # cannot distort the measured duration.
        t_begin = time.perf_counter()
        for t in progress_bar:
            self._step(model, t, *args, **kwargs)
            self.logp.append([self.energy.item()])
            self.trace.append([model.trace(self.x)])
        t_end = time.perf_counter()
        self.elapse += t_end - t_begin

        return self.logp, self.trace, self.elapse, self.succ


class RWSampler(BaseSampler):
    """Random-walk Metropolis sampler that flips a random set of coordinates.

    Each step does the following:
    - picks a radius uniformly from ``{1, ..., R}``;
    - flips that many distinct coordinates, chosen uniformly at random;
    - accepts the result with probability ``min(1, exp(E(x) - E(y)))``.
    """

    def __init__(self, R=1, seed=cmd_args.seed):
        super().__init__(seed=seed)
        self.R = R

    def _step(self, model, t, R=1, *args, **kwargs):
        # The ``R`` argument is accepted for call compatibility only. The
        # radius bound actually used is the one fixed at construction
        # (``self.R``).
        radius = self.rng.integers(1, self.R + 1)
        flipped = self.rng.choice(model.num_nodes, radius, replace=False)
        proposal = model.flip_state(self.x, flipped)
        proposal_energy = model.energy(proposal)

        accept_prob = torch.exp(self.energy - proposal_energy)
        if not (self.rng.random() < accept_prob):
            return

        self.x, self.energy = proposal, proposal_energy
        self.succ += 1


class GibbsSampler(BaseSampler):
    """Systematic-scan block Gibbs sampler over ``R`` consecutive coordinates.

    Step ``t`` updates the block of coordinates ``(t * R + i) % num_nodes``
    for ``i < R``. It enumerates every flip pattern of the block, then draws
    one pattern from the exact conditional ``p(y) ∝ exp(-E(y))``.
    """

    def __init__(self, R=1, seed=cmd_args.seed):
        super().__init__(seed=seed)
        self.R = R

    def _step(self, model, t, *args, **kwargs):
        coordinates = [(t * self.R + i) % model.num_nodes for i in range(self.R)]
        # All subsets of the block, smallest first. Index 0 is the empty
        # subset, i.e. the current state itself.
        new_state = torch.stack(
            [
                model.flip_state(self.x, idx)
                for r in range(self.R + 1)
                for idx in combinations(coordinates, r)
            ],
            dim=0,
        )
        new_energy = model.energy(new_state)
        # softmax(-E) is the normalized conditional. The previous
        # exp(E(x) - E(y)) weights overflowed to inf whenever a candidate was
        # much better than x. torch.multinomial then raised, and the
        # exception handler kept the chain stuck at x.
        # assumes model.energy returns one value per stacked state (1-D) — TODO confirm
        prob = torch.softmax(-new_energy, dim=0)
        try:
            idx = torch.multinomial(prob, 1, replacement=True).item()
        except RuntimeError:
            # Degenerate weights (e.g. NaN energies): keep the current state.
            idx = 0

        if idx != 0:
            self.succ += 1
            self.x = new_state[idx]
            self.energy = new_energy[idx]

class LBSampler(BaseSampler):
    """Locally balanced single-flip Metropolis–Hastings sampler.

    The proposal flips coordinate ``i`` with probability
    ``softmax(-model.change(x) / 2)[i]``. The move is then corrected with the
    probability of flipping the same coordinate back from ``y``.
    """

    def __init__(self, R=1, seed=cmd_args.seed):
        super().__init__(seed)
        # The acceptance ratio below is written for a single flip per step.
        assert R == 1
        self.R = R

    def _step(self, model, t, *args, **kwargs):
        change_x = model.change(self.x)
        log_prob_x = torch.log_softmax(- change_x / 2, dim=0)
        idx = torch.multinomial(log_prob_x.exp(), self.R, replacement=True)
        y = model.flip_state(self.x, idx)
        energy_y = model.energy(y)

        change_y = model.change(y)
        log_prob_y = torch.log_softmax(- change_y / 2, dim=0)

        # Combine everything in log space and exponentiate once. The
        # previous form exp(E(x) - E(y)) * prob_y[idx] / prob_x[idx] could
        # overflow to inf while prob_y[idx] underflowed to 0. The resulting
        # inf * 0 = NaN rejected moves that should be accepted.
        log_accept = (
            - energy_y + self.energy + log_prob_y[idx] - log_prob_x[idx]
        )
        if self.rng.random() < torch.exp(log_accept):
            self.x = y
            self.energy = energy_y
            # Proposal distribution at the new state, kept as an attribute
            # for callers that inspect it.
            self.prob = log_prob_y.exp()
            self.succ += 1


class MTSampler(BaseSampler):
    """Multiple-try Metropolis sampler with locally balanced weights.

    Each step proceeds as follows:
    - Draw ``K`` candidates around ``x``. Slot ``k`` flips ``radii[k]``
      distinct, uniformly chosen coordinates, where ``radii[k]`` is uniform
      in ``[1, 2R-1]``.
    - Select one candidate with weight ``exp((E(x) - E(y)) / 2)``.
    - Accept it with probability ``min(1, Z(x) / Z(y))``, where ``Z`` sums the
      weights of the forward candidates and of the reference set drawn
      around ``y``.
    """

    def __init__(self, K=20, R=1, seed=cmd_args.seed):
        super().__init__(seed=seed)
        self.R = R
        self.K = K

    def _step(self, model, t, *args, **kwargs):
        radii = self.rng.integers(1, 2 * self.R, self.K)

        # Forward tries around x, one per radius slot.
        candidates = []
        for r in radii:
            flips = self.rng.choice(model.num_nodes, r, replace=False)
            candidates.append(model.flip_state(self.x, flips))
        candidates = torch.stack(candidates, dim=0)
        cand_energy = model.energy(candidates)

        # Log weights with softmax/logsumexp. Raw exp() weights overflow to
        # inf for large energy gaps, which made torch.multinomial raise.
        log_w = (- cand_energy + self.energy) / 2
        idx = int(torch.multinomial(torch.softmax(log_w, dim=0), 1, replacement=True))
        log_zx = torch.logsumexp(log_w, dim=0)
        y = candidates[idx]
        energy_y = cand_energy[idx]

        # Reference set around y. x occupies the selected slot, and the other
        # K - 1 slots reuse the radii of the *unselected* tries. Reusing
        # radii[:K-1] instead would include the selected slot's radius (which
        # is biased by the selection) whenever idx < K - 1, and that breaks
        # detailed balance.
        references = [self.x]
        for r in np.delete(radii, idx):
            flips = self.rng.choice(model.num_nodes, r, replace=False)
            references.append(model.flip_state(y, flips))
        references = torch.stack(references, dim=0)
        ref_energy = model.energy(references)
        log_zy = torch.logsumexp((- ref_energy + energy_y) / 2, dim=0)

        if self.rng.random() < torch.exp(log_zx - log_zy):
            self.x = y
            self.energy = energy_y
            self.succ += 1


class HBSampler(BaseSampler):
    """Hamming-ball block sampler (auxiliary-variable scheme of Titsias & Yau).

    Coordinates are visited in blocks of ``block_size``, taken from a random
    permutation. A fresh permutation is drawn once every coordinate has been
    visited. Within a block, each step
      1. draws an auxiliary state ``u`` uniformly from the states that differ
         from ``x`` in at most ``hamming_dist`` coordinates of the block, then
      2. draws the new state exactly from ``p(y) ∝ exp(-E(y))`` restricted to
         the same Hamming ball around ``u``.
    Both moves are exact conditional updates of the joint ``p(x) U(u | x)``.
    The ball has the same size around every state, so ``p`` stays invariant
    without an accept/reject step.
    """

    def __init__(self, block_size=10, hamming_dist=1, seed=cmd_args.seed):
        super().__init__(seed=seed)
        self.block_size = block_size
        # The attribute is named ``hamming_dis`` (kept for compatibility).
        self.hamming_dis = hamming_dist
        self._inds = []

    def _step(self, model, t, *args, **kwargs):
        if len(self._inds) == 0:
            self._inds = self.rng.permutation(model.num_nodes)

        inds = self._inds[:self.block_size]
        self._inds = self._inds[self.block_size:]

        # Flip patterns of the ball: every subset of the block with at most
        # hamming_dis elements. The empty subset (no flip) comes first.
        block = [int(i) for i in inds]
        radius = min(self.hamming_dis, len(block))
        flips = [idx for r in range(radius + 1) for idx in combinations(block, r)]

        # 1) Auxiliary centre u ~ Uniform(ball around x).
        u = model.flip_state(self.x, flips[self.rng.integers(len(flips))])

        # 2) y ~ p restricted to the ball around u. The same patterns apply
        #    around u, so x itself is always among the candidates.
        candidates = torch.stack([model.flip_state(u, idx) for idx in flips], dim=0)
        energies = model.energy(candidates)
        # assumes model.energy returns one value per stacked state (1-D) — TODO confirm
        choice = int(torch.multinomial(torch.softmax(-energies, dim=0), 1))
        y = candidates[choice]

        if not torch.equal(y, self.x):
            self.x = y
            self.energy = energies[choice]
            self.succ += 1

    @staticmethod
    def hamming_ball(n, k):
        """Return indicator vectors of all subsets of ``range(n)`` of size <= k.

        Each vector is a float ``np.ndarray`` of shape ``(n,)`` with 1. at the
        subset's positions. The all-zeros vector comes first, then all subsets
        of size 1, 2, ..., k.
        """
        ball = [np.zeros((n,))]
        for i in range(1, k + 1):
            for tup in combinations(range(n), i):
                vec = np.zeros((n,))
                vec[list(tup)] = 1.
                ball.append(vec)
        return ball




class GWGSampler(BaseSampler):
    """Gibbs-with-gradients sampler (first-order locally balanced proposal).

    Each step flips ``R'`` coordinates, with ``R'`` uniform in ``[1, 2R-1]``.
    They are drawn with replacement from ``softmax(-delta(x) * grad(x) / 2)``.
    The step then applies the Metropolis–Hastings correction using the same
    quantity computed at ``y``.
    """

    def __init__(self, R=1, seed=cmd_args.seed):
        super().__init__(seed=seed)
        self.R = R

    def _step(self, model, t, *args, **kwargs):
        # Scalar draw. int() of a size-1 array is deprecated in NumPy >= 1.25.
        R = int(self.rng.integers(1, 2 * self.R))
        grad = model.grad(self.x)
        delta_x = model.delta(self.x)
        energy_change = delta_x * grad
        log_x = torch.log_softmax(- energy_change / 2, dim=-1)
        # torch.multinomial normalizes its weights, so exp(log_softmax) needs
        # no extra renormalization.
        idx = torch.multinomial(log_x.exp(), R, replacement=True)
        y = model.flip_state(self.x, idx)

        energy_y = model.energy(y)
        grad_y = model.grad(y)
        delta_y = model.delta(y)
        new_energy_change = delta_y * grad_y
        log_y = torch.log_softmax(- new_energy_change / 2, dim=-1)

        # Sum of the per-draw log-probabilities, counting repeated indices
        # once per draw.
        # NOTE(review): when R > 1, idx may contain repeats, and the sum over
        # draw orders is then not a single product. Whether this ratio is
        # exact depends on how model.flip_state treats repeated indices —
        # confirm.
        log_q_forward = log_x[idx].sum()
        log_q_reverse = log_y[idx].sum()

        if self.rng.random() < torch.exp(- energy_y + self.energy + log_q_reverse - log_q_forward):
            self.x = y
            self.energy = energy_y
            self.succ += 1


class MSFSampler(BaseSampler):
    """Multi-step sampler using a gradient frozen at the starting state.

    A step makes ``R'`` successive single flips, with ``R'`` uniform in
    ``[1, 2R-1]``. Each flip is drawn from
    ``softmax(-delta(x_i) * grad(x_0) / 2)``, where ``x_i`` is the current
    intermediate state. The Metropolis–Hastings correction compares this
    forward path with the reverse path ``y -> ... -> x_0``, generated the same
    way but with the gradient frozen at ``y``.
    """

    def __init__(self, R=1, seed=cmd_args.seed):
        super().__init__(seed)
        self.R = R
        # Gradient at the current state. It is cached after an accepted move
        # so that the next step can skip model.grad. Cleared by _reset().
        self.grad = None

    def _reset(self):
        super()._reset()
        # sample() restarts from model.init_state(). A gradient cached for the
        # previous chain's final state would be stale.
        self.grad = None

    def _step(self, model, t, *args, **kwargs):
        R = int(self.rng.integers(1, 2 * self.R))
        grad_x = model.grad(self.x) if self.grad is None else self.grad

        # Forward path x_0 -> x_1 -> ... -> x_R, one flip per step.
        x = self.x
        deltas = []  # deltas[i] = model.delta(x_i) for i = 0..R
        flips = []   # flips[i] = coordinate flipped when leaving x_i
        for _ in range(R):
            delta = model.delta(x)
            deltas.append(delta)
            prob = torch.softmax(- delta * grad_x / 2, dim=-1)
            idx = torch.multinomial(prob, 1, replacement=True)
            flips.append(idx)
            x = model.flip_state(x, idx)
        deltas.append(model.delta(x))

        grad_y = model.grad(x)
        energy_y = model.energy(x)
        # Row i of log_xy: log-probabilities of each flip when leaving x_i on
        # the forward path (gradient at x_0). Row i of log_yx: the same when
        # leaving x_{i+1} towards x_i on the reverse path (gradient at y).
        log_xy = torch.log_softmax(- torch.stack(deltas[:-1], dim=0) * grad_x / 2, dim=-1)
        log_yx = torch.log_softmax(- torch.stack(deltas[1:], dim=0) * grad_y / 2, dim=-1)
        log_ratio = 0
        for i, idx in enumerate(flips):
            log_ratio += log_yx[i][idx] - log_xy[i][idx]

        if self.rng.random() < torch.exp(- energy_y + self.energy + log_ratio):
            self.x = x
            self.energy = energy_y
            self.grad = grad_y
            self.succ += 1


class MSASampler(BaseSampler):
    """
    Multi Step Accurate Sampler

    A step makes ``R'`` successive single flips, with ``R'`` uniform in
    ``[1, 2R-1]``. Each flip is drawn from ``softmax(-change(x_i) / 2)``,
    using the exact single-flip energy changes from ``model.change``. With
    these weights the path acceptance ratio telescopes to ``Z(x_0) / Z(x_R)``,
    where ``log Z(x) = logsumexp(-change(x) / 2)``.
    """

    def __init__(self, R=1, seed=cmd_args.seed):
        super().__init__(seed=seed)
        self.R = R
        # change(x) and log Z(x) for the current state, cached after an
        # accepted move. Cleared by _reset().
        self.energy_change = None
        self.Z = None

    def _reset(self):
        super()._reset()
        # sample() restarts from model.init_state(). Without this, the cache
        # from the previous chain's final state would be reused.
        self.energy_change = None
        self.Z = None

    def _step(self, model, t, *args, **kwargs):
        R = int(self.rng.integers(1, 2 * self.R))
        if self.Z is None:
            energy_change = model.change(self.x)
            log_zx = torch.logsumexp(- energy_change / 2, dim=-1)
        else:
            energy_change = self.energy_change
            log_zx = self.Z

        x = self.x
        for step in range(R):
            if step > 0:
                energy_change = model.change(x)
            # Use softmax rather than exp(). The raw weights overflow to inf,
            # and torch.multinomial raises, when some flip lowers the energy
            # by a large amount.
            prob = torch.softmax(- energy_change / 2, dim=-1)
            idx = torch.multinomial(prob, 1, replacement=True)
            x = model.flip_state(x, idx)

        energy_change_y = model.change(x)
        log_zy = torch.logsumexp(- energy_change_y / 2, dim=-1)

        if self.rng.random() < torch.exp(log_zx - log_zy):
            self.x = x
            self.energy = model.energy(self.x)
            self.succ += 1
            self.energy_change = energy_change_y
            self.Z = log_zy


# class MSFSampler(BaseSampler):
#     """
#     Multi Step Fast Sampler
#     """
#     def __init__(self, R=1, seed=cmd_args.seed):
#         super().__init__(seed=seed)
#         self.R = R
#         self.Zx = None
#         self.prob_x = None
#
#     def _step(self, model, t, *args, **kwargs):
#         R = int(self.rng.integers(1, 2 * self.R, 1))
#         if self.Zx is None:
#             grad = model.grad(self.x)
#             delta_x = model.delta(self.x)
#             energy_change = delta_x * grad
#             Zx = torch.exp(- energy_change / 2).sum()
#             prob_x = torch.softmax(- energy_change / 2, dim=0)
#             if prob_x.sum() != 1:
#                 prob_x /= prob_x.sum()
#         else:
#             prob_x = self.prob_x
#             Zx = self.Zx
#         idx = torch.multinomial(prob_x, R, replacement=False)
#         y = model.flip_state(self.x, idx)
#
#         new_grad = model.grad(y)
#         delta_y = model.delta(y)
#         energy_change_y = delta_y * new_grad
#         Zy = torch.exp(- energy_change_y / 2).sum()
#
#         if self.rng.random() < Zx / Zy:
#             self.x = y
#             self.energy = model.energy(self.x)
#             self.succ += 1
#             self.prob_x = torch.softmax(- energy_change_y / 2, dim=0)
#             if self.prob_x.sum() != 1:
#                 self.prob_x /= self.prob_x.sum()
#             self.Zx = Zy



#

# class MSFSampler(BaseSampler):
#     def __init__(self, R=1, seed=cmd_args.seed):
#         super().__init__(seed)
#         self.R = R
#
#     def _step(self, model, t, *args, **kwargs):
#         R = int(self.rng.integers(1, 2 * self.R, 1))
#
#         # first step
#         grad_x = model.grad(self.x)
#         delta = model.delta(self.x)
#         energy_change = delta * grad_x
#         prob = torch.softmax(- energy_change / 2, dim=-1)
#         Zx = torch.logsumexp(- energy_change / 2, dim=-1)
#         idx = torch.multinomial(prob, 1, replacement=True)
#         x = model.flip_state(self.x, idx)
#
#         # intermediate steps
#         for _ in range(2, R):
#             delta = model.delta(x)
#             energy_change = delta * grad_x
#             prob = torch.softmax(- energy_change / 2, dim=-1)
#             idx = torch.multinomial(prob, 1, replacement=True)
#             x = model.flip_state(x, idx)
#
#         # last step
#         delta_y = model.delta(x)
#         energy_y = model.energy(x)
#         energy_change_y = delta_y * grad_x
#         Zy = torch.logsumexp(- energy_change_y / 2, dim=-1)
#
#         if self.rng.random() < torch.exp(- energy_y + self.energy + (grad_x * (x - self.x)).sum(dim=-1) + Zx - Zy):
#             self.x = x
#             self.energy = energy_y
#             self.succ += 1

# class MSASampler(BaseSampler):
#     """
#     Multi Step Accurate Sampler
#     """
#     def __init__(self, R=1, seed=cmd_args.seed):
#         super().__init__(seed=seed)
#         self.R = R
#
#     def _step(self, model, t, *args, **kwargs):
#         R = int(self.rng.integers(1, 2 * self.R, 1))
#
#         indices = []
#         change = []
#         energy_change = model.change(self.x)
#         prob = torch.exp(- energy_change / 2)
#         idx = torch.multinomial(prob, 1, replacement=True)
#         indices.append(idx)
#         change.append(energy_change)
#         x = model.flip_state(self.x, idx)
#
#         for t in range(1, R):
#             energy_change = model.change(x)
#             prob = torch.exp(- energy_change / 2)
#             prob[torch.cat(indices, dim=0)] = 0
#             idx = torch.multinomial(prob, 1, replacement=True)
#             indices.append(idx)
#             change.append(energy_change)
#             x = model.flip_state(self.x, idx)
#
#         energy_change = model.change(x)
#         energy_y = model.energy(x)
#         change.append(energy_change)
#         change_x = torch.stack(change[:-1], dim=0)
#         change_y = torch.stack(change[1:], dim=0)
#
#         for i, idx in enumerate(indices):
#             change_x[i + 1:, idx] = float('inf')
#             change_y[:i, idx] = float('inf')
#             # change_x[i+1:, idx].fill_(float('inf'))
#             # change_y[:i, idx].fill_(float('inf'))
#         change_x2y = torch.log_softmax(- change_x / 2, dim=-1)
#         change_y2x = torch.log_softmax(- change_y / 2, dim=-1)
#
#         log_ratio = 0
#         for i, idx in enumerate(indices):
#             log_ratio += change_y2x[i, idx] - change_x2y[i, idx]
#
#
#         if self.rng.random() < torch.exp(- energy_y + self.energy + log_ratio):
#             self.x = x
#             self.energy = model.energy(self.x)
#             self.succ += 1

# class ContinuousTimeSampler(object):
#     def __init__(self, deltat=1e-3, seed=cmd_args.seed):
#         self.rng = np.random.default_rng(seed)
#         self.rs = np.random.RandomState(seed)
#         self.trace = []
#         self.t_track = []
#         self.succ = 0
#         self.deltat = deltat
#         self.t = 0
#
#     def _reset(self):
#         self.t = 0
#         self.trace = []
#         self.t_track = []
#         self.succ = 0
#
#     def _step(self, model, t, x, *args, **kwargs):
#         raise NotImplementedError
#
#     def sample(self, model, T=1000, method='Gibbs', max_lag=100, *args, **kwargs):
#         self._reset()
#         self.x = model.init_state()
#         self.energy = model.energy(self.x)
#
#         progress_bar = tqdm(range(T))
#         progress_bar.set_description(f"[{method}]")
#         total_time = 0
#         for t in progress_bar:
#             t_begin = time.time()
#             self._step(model, t, *args, **kwargs)
#             t_end = time.time()
#             total_time += t_end - t_begin
#         auto_cor = acf(self.trace[T//2:], nlags=min([T//2 - 1, max_lag]), fft=True)
#         rho = 0
#         for i in range(len(auto_cor)):
#             if auto_cor[i] < 0:
#                 break
#             rho += auto_cor[i]
#         ess_n = (T // 2) / (1 + 2 * rho)
#         cost_t = total_time
#         ess_t = ess_n / cost_t
#         acc_rate = self.succ / T
#         print(f"method: {method}, avg jump time: {np.mean(self.t_track):.6f}")
#         return self.trace, cost_t, auto_cor, ess_n, ess_t, acc_rate


# class StocBalancedSampler(BaseSampler):
#     def __init__(self, K=1, seed=cmd_args.seed):
#         super().__init__(seed=seed)
#         self.K = K
#         self.p = None
#         self.count = 0
#
#     def _step(self, model, t, R=1, *args, **kwargs):
#         R = self.rs.randint(0, R - 1, self.K) + 1
#         indices = [[]]
#         for k in range(self.K):
#             if R[k] > 1:
#                 indices.append([self.rng.choice(model.num_nodes, R[k] - 1, p=self.p, replace=False)])
#         # indices = [[]]
#         # if R > 1:
#         #     for k in range(self.K):
#         #         indices.append([self.rng.choice(model.num_nodes, self.rs.randint(R) + 1, p=self.p, replace=False)])
#         state_Y = []
#         for i in indices:
#             state_Y.append(model.flip_state(self.x, i))
#         state_Y = torch.stack(state_Y, dim=0)
#         energy_Y, grad_Y = model.grad(state_Y)
#         delta_Y = model.delta(state_Y)
#         energy_Y_change = energy_Y - self.energy + delta_Y * grad_Y
#         energy_Y_change = energy_Y_change.reshape(-1)
#         prob_Y = torch.softmax(- energy_Y_change / 2, dim=0)
#         idx = torch.multinomial(prob_Y, 1, replacement=True)
#         Qx = prob_Y[idx]
#         idx_y = (int(idx) // model.num_nodes, int(idx) % model.num_nodes)
#         y = model.flip_state(state_Y[idx_y[0]], [idx_y[1]])
#
#         state_X = []
#         for i in indices:
#             state_X.append(model.flip_state(y, i))
#         state_X = torch.stack(state_X, dim=0)
#         energy_X, grad_X = model.grad(state_X)
#         energy_y = energy_X[0]
#         delta_X = model.delta(state_X)
#         energy_X_change = (energy_X - energy_y + delta_X * grad_X).reshape(-1)
#         prob_X = torch.softmax(- energy_X_change / 2, dim=0)
#         Qy = prob_X[idx]
#
#         if self.rs.rand() < torch.exp(- energy_y + self.energy) * Qy / Qx:
#             self.x = y
#             self.energy = energy_y
#             self.succ += 1
#             self.p = prob_X[: model.num_nodes]
#             self.p[idx_y[1]] = 0
#             self.p /= self.p.sum()
#             self.p = self.p.detach().cpu()
#
# class FixRadiusSampler(BaseSampler):
#     def __init__(self, K=1, seed=cmd_args.seed):
#         super().__init__(seed=seed)
#         self.K = K
#         self.p = None
#         self.count = 0
#
#     def _step(self, model, t, R=1, *args, **kwargs):
#         self.count += R
#         indices = [[]]
#         if R > 1:
#             for k in range(self.K):
#                 indices.append([self.rng.choice(model.num_nodes, R - 1, p=self.p, replace=False)])
#         state_Y = []
#         for i in indices:
#             state_Y.append(model.flip_state(self.x, i))
#         state_Y = torch.stack(state_Y, dim=0)
#         energy_Y, grad_Y = model.grad(state_Y)
#         delta_Y = model.delta(state_Y)
#         energy_Y_change = energy_Y - self.energy + delta_Y * grad_Y
#         energy_Y_change = energy_Y_change.reshape(-1)
#         prob_Y = torch.softmax(- energy_Y_change / 2, dim=0)
#         idx = torch.multinomial(prob_Y, 1, replacement=True)
#         Qx = prob_Y[idx]
#         idx_y = (int(idx) // model.num_nodes, int(idx) % model.num_nodes)
#         y = model.flip_state(state_Y[idx_y[0]], [idx_y[1]])
#
#         state_X = []
#         for i in indices:
#             state_X.append(model.flip_state(y, i))
#         state_X = torch.stack(state_X, dim=0)
#         energy_X, grad_X = model.grad(state_X)
#         energy_y = energy_X[0]
#         delta_X = model.delta(state_X)
#         energy_X_change = (energy_X - energy_y + delta_X * grad_X).reshape(-1)
#         prob_X = torch.softmax(- energy_X_change / 2, dim=0)
#         Qy = prob_X[idx]
#
#         if self.rs.rand() < torch.exp(- energy_y + self.energy) * Qy / Qx:
#             self.x = y
#             self.energy = energy_y
#             self.succ += 1
#             self.p = prob_X[: model.num_nodes]
#             self.p[idx_y[1]] = 0
#             self.p /= self.p.sum()
#             self.p = self.p.detach().cpu()
#
#
# class StocAdaptiveSampler(BaseSampler):
#     def __init__(self, K=1, R=1, burn_in=2000, seed=cmd_args.seed):
#         super().__init__(seed=seed)
#         self.K = K
#         self.R = R
#         self.deltat = None
#         self.deltar = np.zeros(K, dtype=int)
#         self.p = None
#         self.burn_in = burn_in
#
#     def _step(self, model, t, *args, **kwargs):
#         # currently, the energy change is specfic for ising model
#         R = self.R + self.deltar
#         indices = [[]]
#         for k in range(self.K):
#             if R[k] > 1:
#                 indices.append([self.rng.choice(model.num_nodes, R[k] - 1, p=self.p, replace=False)])
#         state_Y = []
#         for i in indices:
#             state_Y.append(model.flip_state(self.x, i))
#         state_Y = torch.stack(state_Y, dim=0)
#         energy_Y, grad_Y = model.grad(state_Y)
#         delta_Y = model.delta(state_Y)
#         energy_Y_change = energy_Y - self.energy + delta_Y * grad_Y
#         energy_Y_change = energy_Y_change.reshape(-1)
#         prob_Y = torch.softmax(- energy_Y_change / 2, dim=0)
#         idx = torch.multinomial(prob_Y, 1, replacement=True)
#         Qx = prob_Y[idx]
#         idx_y = (int(idx) // model.num_nodes, int(idx) % model.num_nodes)
#         y = model.flip_state(state_Y[idx_y[0]], [idx_y[1]])
#
#         state_X = []
#         for i in indices:
#             state_X.append(model.flip_state(y, i))
#         state_X = torch.stack(state_X, dim=0)
#         energy_X, grad_X = model.grad(state_X)
#         energy_y = energy_X[0]
#         delta_X = model.delta(state_X)
#         energy_X_change = (energy_X - energy_y + delta_X * grad_X).reshape(-1)
#         prob_X = torch.softmax(- energy_X_change / 2, dim=0)
#         Qy = prob_X[idx]
#
#         if self.rs.rand() < torch.exp(- energy_y + self.energy) * Qy / Qx:
#             self.x = y
#             self.energy = energy_y
#             self.succ += 1
#             # todo: p for sample
#             self.p = prob_X[: model.num_nodes]
#             self.p[idx_y[1]] = 0
#             self.p /= self.p.sum()
#             self.p = self.p.detach().cpu()
#             if t > self.burn_in:
#                 Z = torch.exp(- energy_X_change[:model.num_nodes] / 2).sum()
#                 if self.deltat == None:
#                     self.deltat = self.R / Z
#                 else:
#                     self.deltat = 0.9 * self.deltat + 0.1 * self.R / Z
#                 d = Exponential(Z)
#                 jump_time = d.sample([self.K, self.R+1])
#                 for k in range(self.K):
#                     if jump_time[k, :self.R - 2].sum() > self.deltat:
#                         self.deltar[k] = -2
#                     elif jump_time[k, :self.R - 1].sum() > self.deltat:
#                         self.deltar[k] = -1
#                     elif jump_time[k, :self.R].sum() > self.deltat:
#                         self.deltar[k] = 0
#                     elif jump_time[k].sum() > self.deltat:
#                         self.deltar[k] = 1
#                     else:
#                         self.deltar[k] = 2
#                 # print(self.deltar)


# class ContinuousTimeLocallyBalancedSampler(ContinuousTimeSampler):
#     def __init__(self, deltat=2e-2, seed=cmd_args.seed):
#         super().__init__(deltat, seed)
#
#     def _step(self, model, t, R=1, *args, **kwargs):
#         if t * self.deltat < self.t:
#             self.trace.append([self.energy])
#         else:
#             new_state = []
#             for r in range(1, R+1):
#                 for idx in combinations(range(model.num_nodes), r):
#                     new_state.append(model.flip_state(self.x, idx))
#             new_state = torch.stack(new_state, dim=0)
#             new_energy = model.energy(new_state)
#             prob = torch.exp((- new_energy + self.energy) / 2)
#             Zx = prob.sum()
#             prob /= Zx
#             idx = int(self.rng.choice(new_energy.shape[0], size=1, p=prob))
#             y = new_state[idx]
#             energy_y = new_energy[idx].item()
#
#             self.x = y
#             self.energy = energy_y
#             self.succ += 1
#             d = Exponential(Zx)
#             jump_time = d.sample()
#             self.t += jump_time
#             self.t_track.append(jump_time)
#             self.trace.append([self.energy])
#
#
#
# class TabuSampler(ContinuousTimeSampler):
#     def __init__(self, deltat=3e-2, seed=cmd_args.seed):
#         super().__init__(deltat, seed)
#         self.tau = 1
#
#     def _step(self, model, t, R=1, *args, **kwargs):
#         if t == 0:
#             self.alpha = torch.ones(model.num_nodes)
#         if t * self.deltat < self.t:
#             self.trace.append([self.energy])
#         else:
#             new_state = []
#             for r in range(1, R+1):
#                 for idx in combinations(range(model.num_nodes), r):
#                     new_state.append(model.flip_state(self.x, idx))
#             new_state = torch.stack(new_state, dim=0)
#             new_energy = model.energy(new_state)
#             prob = torch.exp((- new_energy + self.energy) / 2)
#             mask = self.alpha == self.tau
#             Zp = prob[mask].sum()
#             Zn = prob[~mask].sum()
#             Zx = Zp if Zp > Zn else Zn
#             d = Exponential(Zx)
#             jump_time = d.sample()
#             self.t += jump_time
#             self.t_track.append(jump_time)
#             if self.rs.rand() < Zp / Zx:
#                 prob[~mask] = 0
#                 idx = int(torch.multinomial(prob, 1))
#                 self.alpha[idx] *= -1
#                 y = new_state[idx]
#                 energy_y = new_energy[idx].item()
#                 self.x = y
#                 self.energy = energy_y
#                 self.succ += 1
#             if self.rs.rand() < (Zx - Zp) / Zx:
#                 self.tau *= -1
#             self.trace.append([self.energy])
#
#
# class BlockBalancedSampler(BaseSampler):
#     def __init__(self, seed=cmd_args.seed, B=10):
#         super().__init__(seed)
#         self.B = B
#
#     def _step(self, model, t, R=1, *args, **kwargs):
#         # if t % (model.num_nodes // self.B) == 0:
#         if t == 0:
#             self.coordinates = self.rng.permutation(model.num_nodes)
#         block_idx = t % (model.num_nodes // self.B)
#         new_state = []
#         for r in range(1, R+1):
#             for idx in combinations(self.coordinates[block_idx * self.B : (block_idx + 1) * self.B], r):
#                 new_state.append(model.flip_state(self.x, idx))
#         new_state = torch.stack(new_state, dim=0)
#         new_energy = model.energy(new_state)
#         prob = torch.exp((- new_energy + self.energy) / 2)
#         Zx = prob.sum()
#         prob /= Zx
#         idx = int(self.rng.choice(new_energy.shape[0], size=1, p=prob))
#         y = new_state[idx]
#         energy_y = new_energy[idx].item()
#
#         new_state = []
#         for r in range(1, R + 1):
#             for idx in combinations(self.coordinates[block_idx * self.B : (block_idx + 1) * self.B], r):
#                 new_state.append(model.flip_state(y, idx))
#         new_state = torch.stack(new_state, dim=0)
#         new_energy = model.energy(new_state)
#         Zy = torch.exp((- new_energy + energy_y) / 2).sum()
#
#         if self.rs.rand() < Zx / Zy:
#             self.x = y
#             self.energy = energy_y
#             self.succ += 1

# class MultipleTryBalancedSampler(RandomWalkSampler):
#     def __init__(self, K=1, seed=cmd_args.seed):
#         super().__init__(seed=seed)
#         self.K = K
#
#     def _step(self, model, t, R=1, *args, **kwargs):
#         if t == 0:
#             self._init_p(model.num_nodes, R)
#         new_state = []
#         for k in range(self.K):
#             r = self._sample_radius()
#             idx = self.rng.choice(model.num_nodes, r, replace=False)
#             new_state.append(model.flip_state(self.x, idx))
#         new_state = torch.stack(new_state, dim=0)
#         new_energy = model.energy(new_state)
#         prob = torch.exp((- new_energy + self.energy) / 2)
#         Zx = prob.sum()
#         idx = int(self.rng.choice(self.K, size=1, p=prob / Zx))
#         y = new_state[idx]
#         energy_y = new_energy[idx].item()
#
#         new_state = [self.x]
#         for k in range(self.K - 1):
#             r = self._sample_radius()
#             idx = self.rng.choice(model.num_nodes, r, replace=False)
#             new_state.append(model.flip_state(y, idx))
#         new_state = torch.stack(new_state, dim=0)
#         new_energy = model.energy(new_state)
#         Zy = torch.exp((- new_energy + energy_y) / 2).sum()
#
#         if self.rs.rand() < Zx / Zy:
#             self.x = y
#             self.energy = energy_y
#             self.succ += 1

## todo: add s-w

