import torch

torch.set_default_tensor_type(torch.cuda.FloatTensor)
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math
import gzip
import itertools
from torch.distributions import Normal

# Device on which every model, parameter and batch in this module is placed.
device = torch.device("cuda")
# Module-level batch-size constant; the train()/predict() methods below use their own defaults.
BATCH_SIZE: int = 100


class StandardScaler(object):
    """Per-feature standardization (zero mean, unit variance) for NumPy arrays.

    ``fit`` computes statistics along axis 0 and stores them as ``mu`` and
    ``std`` with ``keepdims=True`` so they broadcast against the data.
    ``fit`` must be called before ``transform`` / ``inverse_transform``.
    """

    def __init__(self):
        pass

    def fit(self, data):
        """Store the per-feature mean and standard deviation of ``data``.

        Features whose standard deviation is below ``1e-12`` get a std of 1.0,
        so ``transform`` never divides by (nearly) zero.

        Arguments:
        data (np.ndarray): A numpy array containing the input; statistics are
            taken over axis 0.

        Returns: None.
        """
        self.mu = np.mean(data, axis=0, keepdims=True)
        self.std = np.std(data, axis=0, keepdims=True)
        # Constant features would otherwise cause a division by ~0 in transform().
        self.std[self.std < 1e-12] = 1.0

    def transform(self, data):
        """Standardize ``data`` using the statistics stored by ``fit``.

        Arguments:
        data (np.array): A numpy array containing the points to be transformed.

        Returns: (np.array) ``(data - mu) / std``.
        """
        return (data - self.mu) / self.std

    def inverse_transform(self, data):
        """Undo ``transform``, mapping standardized values back to the original scale.

        Arguments:
        data (np.array): A numpy array containing standardized points.

        Returns: (np.array) ``std * data + mu``.
        """
        return self.std * data + self.mu


def init_weights(m):
    """Initialize ``nn.Linear`` and ``gaussianFC`` layers in place; ignore other modules.

    Meant for ``module.apply(init_weights)``. Weights are drawn from a normal
    distribution with mean 0 and std ``1 / (2 * sqrt(in_features))``, truncated
    to within two standard deviations of the mean. Biases (when present) are zeroed.
    """
    def truncated_normal_init(t, mean=0.0, std=0.01):
        # Fill ``t`` in place. The previous rejection loop rebound a local name to
        # freshly allocated tensors, so resampled values never reached the parameter.
        return torch.nn.init.trunc_normal_(t, mean=mean, std=std, a=mean - 2 * std, b=mean + 2 * std)

    if isinstance(m, nn.Linear) or isinstance(m, gaussianFC):
        input_dim = m.in_features
        truncated_normal_init(m.weight, std=float(1 / (2 * np.sqrt(input_dim))))
        # Layers built with bias=False register ``bias`` as None.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

class TerminatedNetwork(nn.Module):
    """Three-layer ReLU MLP mapping a state to a single sigmoid output in (0, 1).

    The network owns an Adam optimizer (lr 1e-3, weight decay 1e-4) over all
    of its layer parameters.
    """

    def __init__(self, num_inputs, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        # Built after the layers so it captures every parameter.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=0.0001)

    def forward(self, state):
        hidden = state
        for layer in (self.linear1, self.linear2):
            hidden = torch.relu(layer(hidden))
        return torch.sigmoid(self.linear3(hidden))

class gaussianFC(nn.Module):
    """Fully connected layer computing ``input @ weight (+ bias)``.

    Unlike ``nn.Linear``, the weight is stored as ``(in_features, out_features)``.
    ``weight_decay`` is not used by the layer itself; it is stored for owning
    models to read (e.g. the ``get_decay_loss`` methods in this file).
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, weight_decay: float = 0., bias: bool = True) -> None:
        super(gaussianFC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        self.weight_decay = weight_decay
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Zero all parameters so the layer never holds uninitialized memory.

        Zeros draw nothing from the RNG, so seeded runs that later apply a random
        initializer (such as ``init_weights``) see the same random stream as before.
        """
        nn.init.zeros_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        w_times_x = torch.matmul(input, self.weight)
        # Layers built with bias=False have ``bias`` registered as None.
        if self.bias is None:
            return w_times_x
        return torch.add(w_times_x, self.bias[None, :])  # w times x + b

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


class GaussianModel(nn.Module):
    """MLP predicting a diagonal Gaussian over ``state_size + reward_size`` outputs.

    Four Swish-activated ``gaussianFC`` hidden layers feed an output layer whose
    result is split into a mean half and a raw log-variance half. The
    log-variance is softly squashed into ``[min_logvar, max_logvar]`` with two
    softplus operations. The model owns an Adam optimizer and is stepped through
    ``train(loss)``.
    """

    def __init__(self, state_size, action_size, reward_size, hidden_size=200, learning_rate=1e-3, use_decay=False):
        super(GaussianModel, self).__init__()
        self.hidden_size = hidden_size
        # Per-layer L2 coefficients; only applied when ``use_decay`` is True (see get_decay_loss).
        self.nn1 = gaussianFC(state_size + action_size, hidden_size, weight_decay=0.000025)
        self.nn2 = gaussianFC(hidden_size, hidden_size, weight_decay=0.00005)
        self.nn3 = gaussianFC(hidden_size, hidden_size, weight_decay=0.000075)
        self.nn4 = gaussianFC(hidden_size, hidden_size, weight_decay=0.000075)
        self.use_decay = use_decay

        self.output_dim = state_size + reward_size
        # Output layer emits [mean | raw log-variance], hence 2 * output_dim columns.
        self.nn5 = gaussianFC(hidden_size, self.output_dim * 2, weight_decay=0.0001)

        # Fixed (requires_grad=False) soft bounds on the predicted log-variance.
        self.max_logvar = nn.Parameter((torch.ones((1, self.output_dim)).float() / 2).to(device), requires_grad=False)
        self.min_logvar = nn.Parameter((-torch.ones((1, self.output_dim)).float() * 10).to(device), requires_grad=False)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.apply(init_weights)
        # Swish activation, x * sigmoid(x). ``nn.SiLU`` is PyTorch's built-in version.
        # The only ``Swish`` class visible in this file is commented out, so the
        # previous ``Swish()`` call raised NameError.
        self.swish = nn.SiLU()

    def forward(self, x, ret_log_var=False):
        """Return ``(mean, logvar)`` if ``ret_log_var`` else ``(mean, variance)``."""
        nn1_output = self.swish(self.nn1(x))
        nn2_output = self.swish(self.nn2(nn1_output))
        nn3_output = self.swish(self.nn3(nn2_output))
        nn4_output = self.swish(self.nn4(nn3_output))
        nn5_output = self.nn5(nn4_output)

        mean = nn5_output[:, :self.output_dim]

        # Soft upper bound, then soft lower bound, on the raw log-variance.
        logvar = self.max_logvar - F.softplus(self.max_logvar - nn5_output[:, self.output_dim:])
        logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)

        if ret_log_var:
            return mean, logvar
        else:
            return mean, torch.exp(logvar)

    def get_decay_loss(self):
        """Return sum over gaussianFC children of ``weight_decay * ||W||^2 / 2``."""
        decay_loss = 0.
        for m in self.children():
            if isinstance(m, gaussianFC):
                decay_loss += m.weight_decay * torch.sum(torch.square(m.weight)) / 2.
        return decay_loss

    def loss(self, mean, logvar, labels, inc_var_loss=True):
        """Gaussian negative log-likelihood (up to constants) or plain MSE.

        mean, logvar, labels: N x dim (2-D; enforced by the assert).
        Returns ``(total_loss, mse_loss)`` as scalars averaged over batch and dim.
        With ``inc_var_loss`` the MSE term is scaled by the inverse variance and
        the mean log-variance is added. Otherwise ``total_loss == mse_loss``,
        the unweighted MSE.
        """
        assert len(mean.shape) == len(logvar.shape) == len(labels.shape) == 2
        inv_var = torch.exp(-logvar)
        if inc_var_loss:
            mse_loss = torch.mean(torch.mean(torch.pow(mean - labels, 2) * inv_var, dim=-1), dim=-1)
            var_loss = torch.mean(torch.mean(logvar, dim=-1), dim=-1)
            total_loss = mse_loss + var_loss
        else:
            mse_loss = torch.mean(torch.mean(torch.pow(mean - labels, 2), dim=-1), dim=-1)
            total_loss = mse_loss
        return total_loss, mse_loss

    def variational_loss(self, mean, logvar, labels, weights, inc_var_loss=True):
        """Per-sample-weighted Gaussian negative log-likelihood.

        mean, logvar, labels: N x dim (2-D; enforced by the assert).
        weights: per-sample weights, broadcast against the length-N vector of
        per-sample losses.
        Returns ``(total_loss, mse_loss)``. ``mse_loss`` is the inverse-variance
        weighted MSE averaged over the batch, without the sample weights.
        Without ``inc_var_loss`` both are the plain unweighted MSE.
        """
        assert len(mean.shape) == len(logvar.shape) == len(labels.shape) == 2
        inv_var = torch.exp(-logvar)
        if inc_var_loss:
            mse_loss = torch.mean(torch.pow(mean - labels, 2) * inv_var, dim=-1)
            log_prob = mse_loss + torch.mean(logvar, dim=-1)
            mse_loss = torch.mean(mse_loss, dim=-1)
            total_loss = torch.mean(log_prob * weights, dim=-1)
        else:
            mse_loss = torch.mean(torch.mean(torch.pow(mean - labels, 2), dim=-1), dim=-1)
            total_loss = mse_loss
        return total_loss, mse_loss

    def train(self, loss=True):
        """Take one optimizer step on ``loss``, adding L2 decay when ``use_decay``.

        This method shadows ``nn.Module.train(mode)``. ``nn.Module.eval()`` and a
        parent module's ``train()`` call it with a bool, so bools (and the no-arg
        call) are forwarded to the base implementation to keep mode switching working.
        """
        if isinstance(loss, bool):
            return super(GaussianModel, self).train(loss)
        self.optimizer.zero_grad()
        if self.use_decay:
            # Out-of-place add so the caller's loss tensor is not mutated.
            loss = loss + self.get_decay_loss()
        loss.backward()
        self.optimizer.step()

class CombinedModel(nn.Module):
    """MLP with two diagonal-Gaussian heads ("p" and "q") over ``state_size + reward_size`` outputs.

    The output layer is split into four equal slices:
    ``[mean_p | raw logvar_p | mean_q | raw logvar_q]``. Both log-variances
    are softly squashed into ``[min_logvar, max_logvar]``.
    """

    def __init__(self, state_size, action_size, reward_size, hidden_size=200, learning_rate=1e-3, use_decay=False):
        super(CombinedModel, self).__init__()
        self.hidden_size = hidden_size
        # Per-layer L2 coefficients; only applied when ``use_decay`` is True (see get_decay_loss).
        self.nn1 = gaussianFC(state_size + action_size, hidden_size, weight_decay=0.000025)
        self.nn2 = gaussianFC(hidden_size, hidden_size, weight_decay=0.00005)
        self.nn3 = gaussianFC(hidden_size, hidden_size, weight_decay=0.000075)
        self.nn4 = gaussianFC(hidden_size, hidden_size, weight_decay=0.000075)
        self.use_decay = use_decay

        self.output_dim = state_size + reward_size
        # Two heads, each with a mean and a raw log-variance: 4 * output_dim columns.
        self.nn5 = gaussianFC(hidden_size, self.output_dim * 4, weight_decay=0.0001)

        # Fixed (requires_grad=False) soft bounds shared by both log-variance heads.
        self.max_logvar = nn.Parameter((torch.ones((1, self.output_dim)).float() / 2).to(device), requires_grad=False)
        self.min_logvar = nn.Parameter((-torch.ones((1, self.output_dim)).float() * 10).to(device), requires_grad=False)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.apply(init_weights)
        # Swish activation, x * sigmoid(x). ``nn.SiLU`` is PyTorch's built-in version.
        # The only ``Swish`` class visible in this file is commented out, so the
        # previous ``Swish()`` call raised NameError.
        self.swish = nn.SiLU()

    def forward(self, x, ret_log_var=False):
        """Return ``(mean_p, logvar_p, mean_q, logvar_q)`` (variances instead if not ``ret_log_var``)."""
        nn1_output = self.swish(self.nn1(x))
        nn2_output = self.swish(self.nn2(nn1_output))
        nn3_output = self.swish(self.nn3(nn2_output))
        nn4_output = self.swish(self.nn4(nn3_output))
        nn5_output = self.nn5(nn4_output)

        mean_p = nn5_output[:, :self.output_dim]

        logvar_p = self.max_logvar - F.softplus(self.max_logvar - nn5_output[:, self.output_dim:self.output_dim*2])
        logvar_p = self.min_logvar + F.softplus(logvar_p - self.min_logvar)

        mean_q = nn5_output[:, self.output_dim*2:self.output_dim*3]

        logvar_q = self.max_logvar - F.softplus(self.max_logvar - nn5_output[:, self.output_dim*3:])
        logvar_q = self.min_logvar + F.softplus(logvar_q - self.min_logvar)

        if ret_log_var:
            return mean_p, logvar_p, mean_q, logvar_q
        else:
            return mean_p, torch.exp(logvar_p), mean_q, torch.exp(logvar_q)

    def get_decay_loss(self):
        """Return sum over gaussianFC children of ``weight_decay * ||W||^2 / 2``."""
        decay_loss = 0.
        for m in self.children():
            if isinstance(m, gaussianFC):
                decay_loss += m.weight_decay * torch.sum(torch.square(m.weight)) / 2.
        return decay_loss

    def combined_loss(self, mean_p, logvar_p, mean_q, logvar_q, labels, weights):
        """Sum of the p-head Gaussian NLL and the per-sample-weighted q-head Gaussian NLL.

        mean_*, logvar_*, labels: N x dim.
        weights: per-sample weights for the q head, broadcast against the
        length-N vector of per-sample losses.
        Each NLL is, per sample, ``mean((mean - labels)^2 / var) + mean(logvar)``
        over dim (up to constants). The batch means are then summed.
        """
        inv_var_p = torch.exp(-logvar_p)
        inv_var_q = torch.exp(-logvar_q)

        p_loss = torch.mean(torch.pow(mean_p - labels, 2) * inv_var_p, dim=-1)
        log_prob_p = p_loss + torch.mean(logvar_p, dim=-1)

        q_loss = torch.mean(torch.pow(mean_q - labels, 2) * inv_var_q, dim=-1)
        # Each head is penalized by its own log-variance. This previously used
        # logvar_p, which left logvar_q free to drift to max_logvar.
        log_prob_q = q_loss + torch.mean(logvar_q, dim=-1)

        total_loss = torch.mean(log_prob_q * weights, dim=-1) + torch.mean(log_prob_p, dim=-1)
        return total_loss

    def train(self, loss=True):
        """Take one optimizer step on ``loss``, adding L2 decay when ``use_decay``.

        This method shadows ``nn.Module.train(mode)``. Bools (and the no-arg call)
        are forwarded to the base implementation so ``.eval()`` / ``.train()``
        keep working.
        """
        if isinstance(loss, bool):
            return super(CombinedModel, self).train(loss)
        self.optimizer.zero_grad()

        if self.use_decay:
            # Out-of-place add so the caller's loss tensor is not mutated.
            loss = loss + self.get_decay_loss()
        loss.backward()

        self.optimizer.step()

class GaussianDynamicsModel():
    """Trains a single ``GaussianModel`` with holdout-based early stopping.

    Inputs are standardized with a ``StandardScaler`` fitted on the training
    split. ``predict`` applies the same scaling.
    """

    def __init__(self, state_size, action_size, reward_size=1, hidden_size=200, use_decay=False):
        self.model_list = []
        self.state_size = state_size
        self.action_size = action_size
        self.reward_size = reward_size
        self.gaussian_model = GaussianModel(state_size, action_size, reward_size, hidden_size, use_decay=use_decay)
        self.scaler = StandardScaler()

    def optimize_model(self, inputs, labels):
        """Debug stub: converts ``inputs``/``labels`` to device tensors and prints '3'; trains nothing."""
        train_input = torch.from_numpy(inputs).float().to(device)
        train_label = torch.from_numpy(labels).float().to(device)
        print('3')

    def train(self, inputs, labels, batch_size=256, holdout_ratio=0., max_epochs_since_update=5):
        """Fit the model until the holdout MSE stops improving.

        The first ``int(len(inputs) * holdout_ratio)`` rows are held out. Each
        epoch makes one shuffled pass over the rest, then evaluates holdout MSE.
        Training stops after more than ``max_epochs_since_update`` epochs without
        a >1% relative improvement. With no holdout rows (the default
        ``holdout_ratio=0.``) there is no early-stopping signal, so training runs
        for exactly ``max_epochs_since_update + 1`` epochs.
        """
        self._max_epochs_since_update = max_epochs_since_update
        self._epochs_since_update = 0
        self._state = {}
        self._snapshots = (None, 1e10)

        num_holdout = int(inputs.shape[0] * holdout_ratio)

        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

        self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)

        for epoch in itertools.count():
            train_idx = np.random.permutation(train_inputs.shape[0])
            for start_pos in range(0, train_inputs.shape[0], batch_size):
                idx = train_idx[start_pos: start_pos + batch_size]
                train_input = torch.from_numpy(train_inputs[idx]).float().to(device)
                train_label = torch.from_numpy(train_labels[idx]).float().to(device)
                mean, logvar = self.gaussian_model(train_input, ret_log_var=True)
                loss, _ = self.gaussian_model.loss(mean, logvar, train_label)
                self.gaussian_model.train(loss)

            holdout_mse = self._holdout_mse(holdout_inputs, holdout_labels, batch_size)
            if self._save_best(epoch, holdout_mse):
                break

    def _holdout_mse(self, holdout_inputs, holdout_labels, batch_size):
        """Return the mean-prediction MSE over the (already scaled) holdout rows.

        Returns ``inf`` when there are no holdout rows; ``_save_best`` never
        counts that as an improvement. (``np.stack([])`` used to raise here.)
        """
        num_rows = holdout_inputs.shape[0]
        if num_rows == 0:
            return float('inf')
        total = 0.0
        with torch.no_grad():
            for start_pos in range(0, num_rows, batch_size):
                stop = start_pos + batch_size
                holdout_input = torch.from_numpy(holdout_inputs[start_pos:stop]).float().to(device)
                holdout_label = torch.from_numpy(holdout_labels[start_pos:stop]).float().to(device)
                holdout_mean, holdout_logvar = self.gaussian_model(holdout_input, ret_log_var=True)
                _, batch_mse = self.gaussian_model.loss(holdout_mean, holdout_logvar, holdout_label,
                                                        inc_var_loss=False)
                # Weight by batch size so a short final batch does not skew the average.
                total += batch_mse.item() * holdout_input.shape[0]
        return total / num_rows

    def _save_best(self, epoch, holdout_losses):
        """Record ``holdout_losses`` if it beats the best so far by more than 1% (relative).

        Returns True once more than ``_max_epochs_since_update`` consecutive
        epochs passed without such an improvement, i.e. when training should stop.
        """
        _, best = self._snapshots
        # A best of 0 cannot be improved on (and would divide by zero).
        if best > 0 and (best - holdout_losses) / best > 0.01:
            self._snapshots = (epoch, holdout_losses)
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self._max_epochs_since_update

    def predict(self, inputs, batch_size=1024):
        """Return ``(mean, variance)`` NumPy arrays for ``inputs`` (scaled with the fitted scaler)."""
        inputs = self.scaler.transform(inputs)
        gaussian_mean, gaussian_var = [], []
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for i in range(0, inputs.shape[0], batch_size):
                batch = torch.from_numpy(inputs[i:i + batch_size]).float().to(device)
                b_mean, b_var = self.gaussian_model(batch, ret_log_var=False)
                gaussian_mean.append(b_mean.cpu().numpy())
                gaussian_var.append(b_var.cpu().numpy())
        if not gaussian_mean:
            # np.concatenate([]) raises; return correctly shaped empty arrays instead.
            empty = np.zeros((0, self.gaussian_model.output_dim), dtype=np.float32)
            return empty, empty.copy()
        return np.concatenate(gaussian_mean, axis=0), np.concatenate(gaussian_var, axis=0)


class VariationalGaussianDynamicsModel():
    """Gaussian state-transition model (``reward_size=0``) trained with critic-derived objectives.

    ``train`` fits it with a per-sample weighted Gaussian NLL. The weights are
    computed from ``critic`` and ``q_c``. ``reverse_train`` fits it
    variationally against a prior model ``p_model``.
    """

    def __init__(self, state_size, action_size, hidden_size=200, use_decay=False):
        self.model_list = []
        self.state_size = state_size
        self.action_size = action_size
        self.gaussian_model = GaussianModel(state_size, action_size, 0, hidden_size, use_decay=use_decay)
        self.scaler = StandardScaler()

    def optimize_model(self, inputs):
        """Debug stub: prints '3' and does nothing else."""
        print('3')

    def train(self, inputs, labels, states, actions, rewards, next_states, done, critic, q_c, batch_size=256, holdout_ratio=0., max_epochs_since_update=6, beta = 1):
        """Fit with per-sample weights, early-stopping on holdout MSE.

        Each sample's Gaussian NLL is weighted by
        ``exp(clamp((min Q(s, a) - (r + 0.99 * not_done * min Q(s', a'))) / beta, -4, 4))``,
        where ``a' ~ q_c.sample(s')`` and both Q values come from ``critic``.
        Early stopping follows ``_save_best``. With no holdout rows (the default
        ``holdout_ratio=0.``) training runs for exactly
        ``max_epochs_since_update + 1`` epochs.
        """
        # logical_not handles bool, int and float ``done`` arrays alike; ``~`` on an
        # int array is bitwise NOT and yields -1/-2 instead of 1/0.
        mask = np.logical_not(done).astype(int)
        self._max_epochs_since_update = max_epochs_since_update
        self._epochs_since_update = 0
        self._state = {}
        self._snapshots = (None, 1e10)

        num_holdout = int(inputs.shape[0] * holdout_ratio)

        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

        self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)

        # as_tensor accepts both the scalar default (beta=1) and NumPy arrays;
        # torch.from_numpy raised TypeError for a plain Python number.
        beta = torch.as_tensor(beta, dtype=torch.float32, device=device)
        for epoch in itertools.count():
            train_idx = np.random.permutation(train_inputs.shape[0])
            for start_pos in range(0, train_inputs.shape[0], batch_size):
                idx = train_idx[start_pos: start_pos + batch_size]
                train_input = torch.from_numpy(train_inputs[idx]).float().to(device)
                train_label = torch.from_numpy(train_labels[idx]).float().to(device)
                state_batch = torch.from_numpy(states[idx]).float().to(device)
                action_batch = torch.from_numpy(actions[idx]).float().to(device)
                next_state_batch = torch.from_numpy(next_states[idx]).float().to(device)
                reward_batch = torch.as_tensor(rewards[idx], dtype=torch.float32, device=device).unsqueeze(-1)
                mask_batch = torch.as_tensor(mask[idx], dtype=torch.float32, device=device).unsqueeze(-1)

                with torch.no_grad():
                    qf1_pi, qf2_pi = critic(state_batch, action_batch)
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    next_state_action, next_state_log_pi, _ = q_c.sample(next_state_batch)

                    qf1_next_target, qf2_next_target = critic(next_state_batch, next_state_action)
                    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)

                    next_q_value = reward_batch + mask_batch * 0.99 * (min_qf_next_target)

                    # Clamped exponent keeps the sample weights within [e^-4, e^4].
                    td_error = torch.exp(torch.clamp((min_qf_pi - next_q_value) / beta, min=-4, max=4)).squeeze(-1)

                mean, logvar = self.gaussian_model(train_input, ret_log_var=True)
                loss, _ = self.gaussian_model.variational_loss(mean, logvar, train_label, td_error)
                self.gaussian_model.train(loss)

            holdout_mse = self._holdout_mse(holdout_inputs, holdout_labels, batch_size)
            if self._save_best(epoch, holdout_mse):
                break

    def reverse_train(self, train_inputs, p_model, critic, q_c, beta, batch_size=256):
        """Five epochs of variational training of this model against the prior ``p_model``.

        Per batch the objective is
        ``mean(q_term - p_term) - mean(not_done * critic(s') / beta)``, where
        ``s'`` is a reparameterized sample from this model's Gaussian.
        ``q_term - p_term`` equals, up to constants, the KL divergence from this
        model's Gaussian to ``p_model``'s state Gaussian. ``q_c`` is accepted but
        not used here.
        """
        # NOTE(review): this refits (overwrites) the scaler fitted in train(), and
        # p_model receives inputs scaled by *this* scaler rather than p_model.scaler — confirm intended.
        self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        beta = torch.as_tensor(beta, dtype=torch.float32, device=device)

        for _ in range(5):
            train_idx = np.random.permutation(train_inputs.shape[0])
            for start_pos in range(0, train_inputs.shape[0], batch_size):
                idx = train_idx[start_pos: start_pos + batch_size]
                train_input = torch.from_numpy(train_inputs[idx]).float().to(device)
                # NOTE(review): this is the *scaled* input state, so predicted next
                # states mix scaled-state and delta units; confirm that critic() and
                # the |s| < 7 test below expect this.
                cur_state = train_input[:, :self.state_size]

                mean_q, log_var_q = self.gaussian_model(train_input, ret_log_var=True)
                # Out-of-place so the forward outputs are not modified in place.
                mean_q = mean_q + cur_state

                var_q = log_var_q.exp()
                std_q = torch.sqrt(var_q)
                normal_q = Normal(mean_q, std_q)

                next_state_batch = normal_q.rsample()

                # The prior is not optimized here; no_grad avoids building its graph and
                # accumulating stale gradients in p_model's parameters.
                with torch.no_grad():
                    mean_p, log_var_p = p_model.gaussian_model(train_input, ret_log_var=True)
                # Column 0 of p_model's output is skipped (presumably its reward
                # prediction — confirm); the remaining columns are a state delta.
                mean_p_state = mean_p[:, 1:] + cur_state
                var_p_state = log_var_p[:, 1:].exp()

                q_term = -torch.sum(log_var_q, 1) / 2

                p_term = (torch.square(mean_q - mean_p_state) + var_q) / var_p_state
                p_term = -torch.sum(p_term, 1) / 2

                qf1_pi = critic(next_state_batch)

                # NOTE(review): hard-coded, environment-specific termination region.
                not_done = (torch.abs(next_state_batch[:, 0]) < 7.0) * (torch.abs(next_state_batch[:, 1]) < 7.0)
                not_done = not_done.unsqueeze(-1)

                min_qf_pi = (qf1_pi / beta) * not_done

                # Averaging the two parts separately equals the mean of the broadcast
                # ``q_term - p_term - min_qf_pi``. It also avoids a (B, B) intermediate
                # if the critic returns a (B, 1) column (assumed; confirm).
                variational_loss = (q_term - p_term).mean() - min_qf_pi.mean()
                self.gaussian_model.train(variational_loss)

    def _holdout_mse(self, holdout_inputs, holdout_labels, batch_size):
        """Return the mean-prediction MSE over the (already scaled) holdout rows.

        Returns ``inf`` when there are no holdout rows; ``_save_best`` never
        counts that as an improvement. (``np.stack([])`` used to raise here.)
        """
        num_rows = holdout_inputs.shape[0]
        if num_rows == 0:
            return float('inf')
        total = 0.0
        with torch.no_grad():
            for start_pos in range(0, num_rows, batch_size):
                stop = start_pos + batch_size
                holdout_input = torch.from_numpy(holdout_inputs[start_pos:stop]).float().to(device)
                holdout_label = torch.from_numpy(holdout_labels[start_pos:stop]).float().to(device)
                holdout_mean, holdout_logvar = self.gaussian_model(holdout_input, ret_log_var=True)
                _, batch_mse = self.gaussian_model.loss(holdout_mean, holdout_logvar, holdout_label,
                                                        inc_var_loss=False)
                # Weight by batch size so a short final batch does not skew the average.
                total += batch_mse.item() * holdout_input.shape[0]
        return total / num_rows

    def _save_best(self, epoch, holdout_losses):
        """Record ``holdout_losses`` if it beats the best so far by more than 1% (relative).

        Returns True once more than ``_max_epochs_since_update`` consecutive
        epochs passed without such an improvement, i.e. when training should stop.
        """
        _, best = self._snapshots
        # A best of 0 cannot be improved on (and would divide by zero).
        if best > 0 and (best - holdout_losses) / best > 0.01:
            self._snapshots = (epoch, holdout_losses)
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self._max_epochs_since_update

    def predict(self, inputs, batch_size=1024, factored=True):
        """Return ``(mean, variance)`` NumPy arrays for ``inputs``.

        ``factored`` is accepted for interface compatibility and is unused.
        """
        inputs = self.scaler.transform(inputs)
        gaussian_mean, gaussian_var = [], []
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for i in range(0, inputs.shape[0], batch_size):
                batch = torch.from_numpy(inputs[i:i + batch_size]).float().to(device)
                b_mean, b_var = self.gaussian_model(batch, ret_log_var=False)
                gaussian_mean.append(b_mean.cpu().numpy())
                gaussian_var.append(b_var.cpu().numpy())
        if not gaussian_mean:
            # np.concatenate([]) raises; return correctly shaped empty arrays instead.
            empty = np.zeros((0, self.gaussian_model.output_dim), dtype=np.float32)
            return empty, empty.copy()
        return np.concatenate(gaussian_mean, axis=0), np.concatenate(gaussian_var, axis=0)


# class CombinedDynamicsModel():
#     def __init__(self, state_size, action_size, hidden_size=200, use_decay=False):
#         self.model_list = []
#         self.state_size = state_size
#         self.action_size = action_size
#         self.gaussian_model = CombinedModel(state_size, action_size, 1, hidden_size, use_decay=use_decay)
#         self.scaler = StandardScaler()
#         # self.terminated_model = TerminatedNetwork(state_size, 256)
#
#     def train(self, inputs, labels, states, actions, rewards, next_states, done, critic, q_c, batch_size=256, holdout_ratio=0., max_epochs_since_update=6, beta = 1):
#         mask = (~done).astype(int)
#
#         num_holdout = int(inputs.shape[0] * holdout_ratio)
#
#         train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
#
#         self.scaler.fit(train_inputs)
#         train_inputs = self.scaler.transform(train_inputs)
#
#         beta = torch.from_numpy(beta).float().to(device)
#         for epoch in range(5):
#             train_idx = np.random.permutation(train_inputs.shape[0])
#             for start_pos in range(0, train_inputs.shape[0], batch_size):
#                 idx = train_idx[start_pos: start_pos + batch_size]
#                 train_input = torch.from_numpy(train_inputs[idx]).float().to(device)
#                 train_label = torch.from_numpy(train_labels[idx]).float().to(device)
#                 state_batch = torch.from_numpy(states[idx]).float().to(device)
#                 action_batch = torch.from_numpy(actions[idx]).float().to(device)
#                 next_state_batch = torch.from_numpy(next_states[idx]).float().to(device)
#                 reward_batch = torch.FloatTensor(rewards[idx]).to(device).unsqueeze(-1)
#                 mask_batch = torch.FloatTensor(mask[idx]).to(device).unsqueeze(-1)
#
#                 with torch.no_grad():
#                     qf1_pi, qf2_pi = critic(state_batch, action_batch)
#                     min_qf_pi = torch.min(qf1_pi, qf2_pi)
#                     next_state_action, next_state_log_pi, _ = q_c.sample(next_state_batch)
#
#                     qf1_next_target, qf2_next_target = critic(next_state_batch, next_state_action)
#                     min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
#
#                     next_q_value = reward_batch + mask_batch * 0.99 * (min_qf_next_target)
#
#                     td_error = torch.exp(torch.clamp( (min_qf_pi - next_q_value)/beta, min = -4, max = 4)).squeeze(-1)
#
#                 # exp_weight = torch.exp(torch.clamp((train_reward + V_values - Q_values), max=5)).detach()
#
#                 mean_p, logvar_p, mean_q, logvar_q = self.gaussian_model(train_input, ret_log_var=True)
#                 loss = self.gaussian_model.combined_loss(mean_p, logvar_p, mean_q, logvar_q, train_label, td_error)
#                 self.gaussian_model.train(loss)
#
#                 # predicted_outputs = self.terminated_model(next_state_batch)
#                 # t_loss = nn.BCELoss()(predicted_outputs, mask_batch)
#                 #
#                 # self.terminated_model.optimizer.zero_grad()
#                 # t_loss.backward()
#                 # self.terminated_model.optimizer.step()
#
#     def predict(self, inputs, batch_size=1024):
#         inputs = self.scaler.transform(inputs)
#         gaussian_mean, gaussian_var = [], []
#         for i in range(0, inputs.shape[0], batch_size):
#             input = torch.from_numpy(inputs[i:min(i + batch_size, inputs.shape[0])]).float().to(device)
#             b_mean, b_var = self.gaussian_model(input, ret_log_var=False)
#             gaussian_mean.append(b_mean.detach().cpu().numpy())
#             gaussian_var.append(b_var.detach().cpu().numpy())
#         gaussian_mean = np.concatenate(gaussian_mean, axis=0)
#         gaussian_var = np.concatenate(gaussian_var, axis=0)
#
#         return gaussian_mean, gaussian_var
#
# class Swish(nn.Module):
#     def __init__(self):
#         super(Swish, self).__init__()
#
#     def forward(self, x):
#         x = x * F.sigmoid(x)
#         return x
#
#
# def get_data(inputs_file_path, labels_file_path, num_examples):
#     with open(inputs_file_path, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
#         bytestream.read(16)
#         buf = bytestream.read(28 * 28 * num_examples)
#         data = np.frombuffer(buf, dtype=np.uint8) / 255.0
#         inputs = data.reshape(num_examples, 784)
#
#     with open(labels_file_path, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
#         bytestream.read(8)
#         buf = bytestream.read(num_examples)
#         labels = np.frombuffer(buf, dtype=np.uint8)
#
#     return np.array(inputs, dtype=np.float32), np.array(labels, dtype=np.int8)
#
#
