import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
import itertools
from torch.optim import Adam
import gym
import time
from torchdiffeq import odeint_adjoint as odeint
from tqdm import tqdm
import random
import os
import sac.core as core
from sac import ReplayBuffer
from utils.logx import EpochLogger
from utils.run_utils import setup_logger_kwargs
from utils.run_utils import setup_logger_kwargs
import pdb


def choose_nonlinearity(name):
    """Return the activation callable registered under ``name``.

    Supported names are 'tanh', 'relu', 'sigmoid', 'softplus', 'selu',
    'elu' and 'swish' (``x * sigmoid(x)``).

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    functional = torch.nn.functional
    # Ordered (label, callable) pairs; scanned with ``==`` so that any value
    # of ``name`` (even an unhashable one) ends in the ValueError below.
    registry = (
        ('tanh', torch.tanh),
        ('relu', torch.relu),
        ('sigmoid', torch.sigmoid),
        ('softplus', functional.softplus),
        ('selu', functional.selu),
        ('elu', functional.elu),
        ('swish', lambda x: x * torch.sigmoid(x)),
    )
    for label, activation in registry:
        if name == label:
            return activation
    raise ValueError("nonlinearity not recognized")


class MLP(torch.nn.Module):
    """Fully connected network with two hidden layers.

    Both hidden layers have width ``hidden_dim`` and are followed by the
    activation selected through ``nonlinearity``.  The output layer is linear
    and has no bias.  All three weight matrices are re-initialised with
    ``torch.nn.init.orthogonal_``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nonlinearity='tanh'):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = torch.nn.Linear(hidden_dim, output_dim, bias=False)

        # Orthogonal weights, applied in layer order.
        for layer in (self.linear1, self.linear2, self.linear3):
            torch.nn.init.orthogonal_(layer.weight)

        self.nonlinearity = choose_nonlinearity(nonlinearity)

    def forward(self, x):
        """Map ``x`` through both activated hidden layers and the linear head."""
        activate = self.nonlinearity
        hidden = activate(self.linear1(x))
        hidden = activate(self.linear2(hidden))
        out = self.linear3(hidden)
        return out


class MLPAutoencoder(torch.nn.Module):
    """MLP autoencoder whose two inner hidden layers use residual connections.

    The encoder is ``linear1 .. linear4`` (input -> latent) and the decoder is
    ``linear5 .. linear8`` (latent -> input).  In each half the first layer is
    activated, the two middle layers are added back onto their own input, and
    the last layer is purely linear.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, nonlinearity='tanh'):
        super().__init__()
        # Encoder layers.
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear4 = torch.nn.Linear(hidden_dim, latent_dim)

        # Decoder layers.
        self.linear5 = torch.nn.Linear(latent_dim, hidden_dim)
        self.linear6 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear7 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear8 = torch.nn.Linear(hidden_dim, input_dim)

        # Orthogonal weights for every layer, in layer order.
        every_layer = (self.linear1, self.linear2, self.linear3, self.linear4,
                       self.linear5, self.linear6, self.linear7, self.linear8)
        for layer in every_layer:
            torch.nn.init.orthogonal_(layer.weight)

        self.nonlinearity = choose_nonlinearity(nonlinearity)

    def _residual_stack(self, inputs, first, res_a, res_b, last):
        """Run ``inputs`` through: activated layer, two residual layers, linear head."""
        activate = self.nonlinearity
        hidden = activate(first(inputs))
        hidden = hidden + activate(res_a(hidden))
        hidden = hidden + activate(res_b(hidden))
        return last(hidden)

    def encode(self, x):
        """Map an input batch to its latent code."""
        return self._residual_stack(x, self.linear1, self.linear2, self.linear3, self.linear4)

    def decode(self, z):
        """Map a latent code back to input space."""
        return self._residual_stack(z, self.linear5, self.linear6, self.linear7, self.linear8)

    def forward(self, x):
        """Return the reconstruction ``decode(encode(x))``."""
        return self.decode(self.encode(x))


class ODEFunc(nn.Module):
    """Vector field ``f(t, x)`` for an ODE solver, conditioned on a stored action.

    The action is not an argument of ``forward``; callers assign it to
    ``self.a`` before invoking the module.  The time argument ``t`` is ignored.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ODEFunc, self).__init__()
        self.mlp = MLP(input_dim, hidden_dim, output_dim)
        # Action batch concatenated to the state in ``forward``; set by the caller.
        self.a = None

    def forward(self, t, x):
        """Evaluate the MLP on ``x`` joined with ``self.a`` along dimension 1."""
        state_and_action = torch.cat((x, self.a), dim=1)
        return self.mlp(state_and_action)



class NODANoPartial(nn.Module):
    """Latent environment model: autoencoder + action-conditioned latent dynamics.

    Observations are encoded by ``self.ae``.  The latent state is advanced
    either by integrating ``self.odefunc`` from t=0 to t=1 with ``odeint``
    (``use_ode=True``) or by a single evaluation of ``self.odefunc``
    (``use_ode=False``).  Rewards are predicted by ``self.rew_nn`` from the
    current latent state and the action.
    """

    def __init__(self, observation_space, action_space, hidden_dim_ode=256, hidden_dim_ae=256, latent_dim=32,
                 use_ode=True, tol=1e-5):
        super().__init__()
        obs_size = observation_space.shape
        act_size = action_space.shape
        # Shapes given as tuples are reduced to their first entry.
        if type(obs_size) is tuple:
            obs_size = obs_size[0]
        if type(act_size) is tuple:
            act_size = act_size[0]
        self.observation_shape = obs_size
        self.action_shape = act_size
        self.latent_dim = latent_dim

        latent_plus_action = latent_dim + act_size
        self.ae = MLPAutoencoder(obs_size, hidden_dim_ae, latent_dim, nonlinearity='relu')
        self.integration_time = torch.tensor([0, 1]).float()
        self.odefunc = ODEFunc(latent_plus_action, hidden_dim_ode, latent_dim)
        self.rew_nn = MLP(latent_plus_action, hidden_dim_ae, 1)
        self.use_ode = use_ode
        self.tol = tol
        # Rollout state used by ``set_o`` / ``step``.
        self.o = None
        self.latent_s = None

    def _transition(self, latent, a):
        """Advance ``latent`` by one model step under action ``a``."""
        self.odefunc.a = a
        if not self.use_ode:
            return self.odefunc(0, latent)
        times = self.integration_time.to(latent.device)
        trajectory = odeint(self.odefunc, latent, times, rtol=self.tol, atol=self.tol)
        # Index 1 is the solution at the end of the integration interval.
        return trajectory[1]

    def set_o(self, o):
        """Store ``o`` (batched if 1-D) as the rollout observation and encode it."""
        self.o = o.unsqueeze(0) if o.dim() == 1 else o
        self.latent_s = self.ae.encode(self.o)

    def forward(self, o, a):
        """Return ``(next observation prediction, reward prediction, reconstruction of o)``."""
        latent_now = self.ae.encode(o)
        latent_next = self._transition(latent_now, a)
        o2 = self.ae.decode(latent_next)
        o_recon = self.ae.decode(latent_now)
        r = self.rew_nn(torch.cat((latent_now, a), dim=1)).squeeze(-1)
        return o2, r, o_recon

    def step(self, a):
        """Roll the stored latent state forward by one step without gradients.

        Returns ``(observation, reward, done)`` where ``done`` is an all-zero
        numpy array with one entry per reward.
        """
        with torch.no_grad():
            if a.dim() == 1:
                a = a.unsqueeze(0)
            latent_now = self.latent_s
            latent_next = self._transition(latent_now, a)
            self.o = self.ae.decode(latent_next)
            r = self.rew_nn(torch.cat((latent_now, a), dim=1)).squeeze(-1)
            self.latent_s = latent_next
        dones = np.array([0] * len(r))
        return self.o, r, dones


class NODA(nn.Module):
    """Latent environment model with Hamiltonian-style latent dynamics.

    Observations are encoded by ``self.ae`` into a latent vector of even size
    ``latent_dim``.  The latent state is advanced by one of two mechanisms:

    * ``use_ode=True``: integrate from t=0 to t=1 the field built by
      ``_hamiltonian_field``.  ``odefunc1`` produces a scalar per sample whose
      gradient with respect to the latent is taken; the derivative of the
      first latent half is the gradient with respect to the second half, and
      the derivative of the second half is minus the gradient with respect to
      the first half plus ``odefunc2(latent, action)``.
    * ``use_ode=False``: a single evaluation of the MLP ``self.odefunc`` on
      the latent state concatenated with the action.

    Rewards are predicted by ``self.rew_nn`` from the current latent state and
    the action.

    Args:
        observation_space: object with a ``shape`` attribute; if it is a
            tuple only its first entry is used.
        action_space: object with a ``shape`` attribute, treated likewise.
        hidden_dim_ode (int): hidden width of the dynamics MLPs.
        hidden_dim_ae (int): hidden width of the autoencoder and reward MLP.
        latent_dim (int): latent size; must be even.
        use_ode (bool): selects the transition mechanism described above.
        tol (float): value passed as both ``rtol`` and ``atol`` to ``odeint``.
    """

    def __init__(self, observation_space, action_space, hidden_dim_ode=256, hidden_dim_ae=256, latent_dim=32,
                 use_ode=True, tol=1e-5):
        super(NODA, self).__init__()
        observation_shape = observation_space.shape
        action_shape = action_space.shape
        if isinstance(observation_shape, tuple):
            observation_shape = observation_shape[0]
        if isinstance(action_shape, tuple):
            action_shape = action_shape[0]
        self.observation_shape = observation_shape
        self.action_shape = action_shape
        self.latent_dim = latent_dim
        # The latent vector is split into two equally sized halves.
        assert self.latent_dim % 2 == 0, "latent_dim must be even"
        self.half_latent_dim = self.latent_dim // 2
        self.ae = MLPAutoencoder(observation_shape, hidden_dim_ae, self.latent_dim, nonlinearity='relu')
        self.integration_time = torch.tensor([0, 1]).float()
        # Discrete transition network, only used when use_ode is False.
        self.odefunc = MLP(self.latent_dim + self.action_shape, hidden_dim_ode, self.latent_dim)
        # Scalar function of the latent whose gradient drives the ODE field.
        self.odefunc1 = MLP(self.latent_dim, hidden_dim_ode, 1)
        # Action-dependent term added to the second half of the ODE field.
        self.odefunc2 = MLP(self.latent_dim + self.action_shape, hidden_dim_ode, self.half_latent_dim)
        self.rew_nn = MLP(self.latent_dim + self.action_shape, hidden_dim_ae, 1)
        self.use_ode = use_ode
        self.tol = tol
        # Rollout state used by ``set_o`` / ``step``.
        self.o = None
        self.latent_s = None

    def set_o(self, o):
        """Store ``o`` (batched if 1-D) as the rollout observation and encode it."""
        if len(o.size()) == 1:
            self.o = o.unsqueeze(0)
        else:
            self.o = o
        self.latent_s = self.ae.encode(self.o)

    def _hamiltonian_field(self, a):
        """Return the closure ``f(t, x)`` integrated by ``odeint`` for action ``a``."""
        half = self.half_latent_dim

        def odefunc(t, x):
            # NOTE(review): this needs grad mode enabled and ``x`` attached to
            # the autograd graph at call time; confirm that the solver in use
            # guarantees both.
            grad = torch.autograd.grad(self.odefunc1(x).sum(), x, create_graph=False, retain_graph=False)[0]
            with torch.no_grad():
                # NOTE(review): create_graph=False plus no_grad means the
                # returned derivative carries no autograd history, so
                # odefunc1 / odefunc2 get no gradient through it.  Intended?
                first_half = grad[:, half:]
                second_half = -grad[:, :half] + self.odefunc2(torch.cat((x, a), dim=1))
                return torch.cat((first_half, second_half), dim=1)

        return odefunc

    def _advance(self, latent_s, a):
        """Advance ``latent_s`` by one model step under action ``a``."""
        if self.use_ode:
            # NOTE(review): torchdiffeq's adjoint solver is documented as
            # expecting an nn.Module (or explicit adjoint_params); a plain
            # closure is passed here.  Verify against the installed version.
            return odeint(self._hamiltonian_field(a),
                          latent_s,
                          self.integration_time.to(latent_s.device),
                          rtol=self.tol,
                          atol=self.tol)[1]
        # self.odefunc is a plain MLP built for (latent_dim + action_shape)
        # inputs, so it takes the concatenated state and action.  The earlier
        # ``self.odefunc(0, latent_s)`` call passed one positional argument
        # too many and raised TypeError.
        return self.odefunc(torch.cat((latent_s, a), dim=1))

    def forward(self, o, a):
        """Return ``(next observation prediction, reward prediction, reconstruction of o)``."""
        latent_s = self.ae.encode(o)
        latent_s2 = self._advance(latent_s, a)
        o2 = self.ae.decode(latent_s2)
        o_recon = self.ae.decode(latent_s)
        r = self.rew_nn(torch.cat((latent_s, a), dim=1)).squeeze(-1)
        return o2, r, o_recon

    def step(self, a):
        """Roll the stored latent state forward by one step.

        ``set_o`` must have been called first.  Returns
        ``(observation, reward, done)``; observation and reward are computed
        without gradient tracking and ``done`` is an all-zero numpy array with
        one entry per reward.  The stored latent state is replaced by the
        detached next latent state.
        """
        if len(a.size()) == 1:
            a = a.unsqueeze(0)
        # Work on a detached leaf that requires grad: the ODE field
        # differentiates with respect to the state, and flipping
        # ``requires_grad`` directly fails when the cached latent is a
        # non-leaf tensor (e.g. ``set_o`` called with grad enabled).
        latent_s = self.latent_s.detach().requires_grad_(True)
        latent_s2 = self._advance(latent_s, a)
        with torch.no_grad():
            self.o = self.ae.decode(latent_s2)
            r = self.rew_nn(torch.cat((latent_s, a), dim=1)).squeeze(-1)
        self.latent_s = latent_s2.detach()
        return self.o, r, np.array([0] * len(r))


def noda(env_fn, actor_critic=core.MLPActorCritic, ac_kwargs=None, model=NODA, model_kwargs=None,
         seed=0, steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99,
         polyak=0.995, lr=1e-3, alpha=0.2, batch_size=256, start_steps=10000,
         update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000,
         logger_kwargs=None, save_freq=1, device='cpu', noise=0.0):
    """
    Neural Ordinary Differential Equation Auto-encoder (NODA)

    Trains a SAC agent on real transitions and, once ``start_steps`` have
    elapsed, additionally on transitions generated by rolling out a learned
    environment model.

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with an ``act``
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of
            observations as inputs, and ``q1`` and ``q2`` should accept a batch
            of observations and a batch of actions as inputs. When called,
            ``act``, ``q1``, and ``q2`` should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each
                                           | observation.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current
                                           | estimate of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ===========  ================  ======================================

            Calling ``pi`` should return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                           | actions in ``a``. Importantly: gradients
                                           | should be able to flow back into ``a``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
            you provided to SAC.

        model: the constructor of the environment model.

        model_kwargs (dict): Settings for the model and the model rollouts.
            The following keys are read and must all be present:
            ``latent_dim``, ``hidden_dim_ode``, ``hidden_dim_ae``, ``use_ode``
            (passed to the model constructor), ``explore_lr`` and
            ``update_action_turns`` (action refinement before a model step),
            ``model_step`` (model rollout length) and
            ``update_model_interval`` (after ``start_steps`` the model is only
            trained on update rounds whose step index is a multiple of it).

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate (used for policy, value and model learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_every (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long
            you wait between updates, the ratio of env steps to gradient steps
            on real data is locked to 1.

        num_test_episodes (int): Number of episodes to test the deterministic
            policy at the end of each epoch.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function. Currently unused because
            the periodic save call is disabled.

        device (string): the device for running the experiment.

        noise (float): scale of the Gaussian noise added to each next
            observation returned by the training environment.

    Returns:
        The ``EpochLogger`` used during training.
    """
    os.environ['PYTHONHASHSEED'] = '0'
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if ac_kwargs is None:
        ac_kwargs = dict()
    if model_kwargs is None:
        model_kwargs = dict()
    if logger_kwargs is None:
        logger_kwargs = dict()
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    env, test_env = env_fn().unwrapped, env_fn().unwrapped
    # Disable early termination on *both* environments.  The second branch
    # previously wrote to ``env`` again and left ``test_env`` untouched.
    if hasattr(env, '_terminate_when_unhealthy'):
        env._terminate_when_unhealthy = False
    if hasattr(test_env, '_terminate_when_unhealthy'):
        test_env._terminate_when_unhealthy = False
    env.seed(seed)
    test_env.seed(seed)
    env.observation_space.seed(seed)
    test_env.observation_space.seed(seed)
    env.action_space.seed(seed)
    test_env.action_space.seed(seed)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape[0]

    # Create actor-critic module and target networks
    ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs).to(device)
    ac_targ = deepcopy(ac).to(device)

    # Create environment model
    model = model(env.observation_space,
                  env.action_space,
                  latent_dim=model_kwargs['latent_dim'],
                  hidden_dim_ode=model_kwargs['hidden_dim_ode'],
                  hidden_dim_ae=model_kwargs['hidden_dim_ae'],
                  use_ode=model_kwargs['use_ode']).to(device)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # Parameters of both Q-networks.  This must be a list: a bare
    # itertools.chain is a one-shot iterator that the optimizer constructor
    # would exhaust, turning the freeze/unfreeze loops in ``update`` into
    # silent no-ops.
    q_params = list(itertools.chain(ac.q1.parameters(), ac.q2.parameters()))

    # Experience buffers: real transitions and model-generated transitions
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size, device=device)
    replay_buffer_model = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size, device=device)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.q1, ac.q2, model])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d, \t model: %d\n' % var_counts)

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        q1 = ac.q1(o, a)
        q2 = ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = ac.pi(o2)

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().cpu().numpy(),
                      Q2Vals=q2.detach().cpu().numpy())

        return loss_q, q_info

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        o = data['obs']
        pi, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, pi)
        q2_pi = ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (alpha * logp_pi - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=logp_pi.detach().cpu().numpy())

        return loss_pi, pi_info

    # Set up function for computing model loss
    def compute_loss_model(data):
        o, a, r, o2 = data['obs'], data['act'], data['rew'], data['obs2']

        o2_pred, r_pred, o_recon = model(o, a)
        loss_o_pred = ((o2_pred - o2) ** 2).mean()
        loss_r_pred = ((r_pred - r) ** 2).mean()
        loss_o_recon = ((o_recon - o) ** 2).mean()

        loss_model = 0.5 * (loss_o_pred + loss_o_recon) + 0.5 * loss_r_pred
        # No per-term diagnostics are recorded at the moment.
        model_info = dict()

        return loss_model, model_info

    # Set up optimizers for policy, q-function and model
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q_optimizer = Adam(q_params, lr=lr)
    model_optimizer = Adam(model.parameters(), lr=lr)

    # Set up model saving
    logger.setup_pytorch_saver([ac, model])

    def update_model(data):
        # One gradient step on the environment model
        model_optimizer.zero_grad()
        loss_model, model_info = compute_loss_model(data)
        loss_model.backward()
        model_optimizer.step()

        # Record things
        logger.store(LossM=loss_model.item(), **model_info)

    def update(data, real_data=True):
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        q_optimizer.step()

        # Record things (model-generated data is logged under separate keys)
        if real_data:
            logger.store(LossQ=loss_q.item(), **q_info)
        else:
            logger.store(LossMQ=loss_q.item(), **q_info)

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so they can be optimized at the next update.
        for p in q_params:
            p.requires_grad = True

        # Record things
        if real_data:
            logger.store(LossPi=loss_pi.item(), **pi_info)
        else:
            logger.store(LossMPi=loss_pi.item(), **pi_info)

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use an in-place operations "mul_", "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def get_modified_action(obs, act):
        # Refine ``act`` with a few Adam steps on an objective built from the
        # current Q estimates, treating the action itself as the only
        # optimized variable.
        obs = obs.detach()
        act = act.detach()
        act.requires_grad = True
        act_optimizer = Adam([act], lr=model_kwargs['explore_lr'])
        for _ in range(model_kwargs['update_action_turns']):
            loss = -torch.min(ac.q1(obs, act), ac.q2(obs, act)).mean()
            loss = 1.0 / (loss - 1e-6)
            act_optimizer.zero_grad()
            loss.backward()
            act_optimizer.step()
        act.requires_grad = False
        return act

    def interact_with_model(batch):
        # Roll the model forward from the observations in ``batch`` and store
        # every generated transition in the model replay buffer.
        o_model = batch['obs'].clone()
        # Derive the number of parallel rollouts from the batch itself rather
        # than from a variable assigned later in the enclosing loop.
        n_rollouts = len(batch['rew'])
        with torch.no_grad():
            model.set_o(o_model)
        for j in range(model_kwargs['model_step']):
            if j == 0:
                # First step starts from the actions that were actually taken.
                a_model = get_modified_action(o_model, batch['act'].clone())
            else:
                a_model = get_modified_action(o_model,
                                              torch.as_tensor(get_action(o_model), dtype=torch.float32).to(device))
            o2_model, r_model, d_model = model.step(a_model)
            for k in range(n_rollouts):
                replay_buffer_model.store(o_model[k].cpu().numpy(), a_model[k].cpu().numpy(),
                                          r_model[k].cpu().numpy(), o2_model[k].cpu().numpy(), d_model[k])
            o_model = o2_model

    def get_action(o, deterministic=False):
        return ac.act(torch.as_tensor(o, dtype=torch.float32).to(device), deterministic)

    def test_agent():
        for _ in range(num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0

    # The bar is advanced manually once per environment step.
    bar = tqdm(total=steps_per_epoch, desc='Epoch 1')
    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):
        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env and perturb the next observation with Gaussian noise
        o2, r, d, _ = env.step(a)
        o2 += noise * np.random.randn(*o2.shape)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, ep_ret, ep_len = env.reset(), 0, 0

        # Update handling
        if t >= update_after and t % update_every == 0:
            for _ in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch)
                # The model is trained on every round during the random
                # exploration phase and only periodically afterwards.
                if t <= start_steps:
                    update_model(data=batch)
                elif t % model_kwargs['update_model_interval'] == 0:
                    update_model(data=batch)
            if replay_buffer_model.size > update_every * batch_size:
                # The number of updates on model data shrinks by one per epoch.
                for _ in range(max(update_every - (t + 1) // steps_per_epoch, 0)):
                    batch = replay_buffer_model.sample_batch(batch_size)
                    update(data=batch, real_data=False)

        # Every 64 steps after the exploration phase, roll the model out from
        # the most recent real transitions.
        if t > start_steps and t % 64 == 0:
            model_batch_size = int(64 * batch_size)
            if replay_buffer.ptr >= model_batch_size:
                batch = replay_buffer.get_batch(np.arange(replay_buffer.ptr - model_batch_size, replay_buffer.ptr))
            elif replay_buffer.size == replay_buffer.max_size:
                # The buffer has wrapped around: take the head plus the
                # matching number of entries from the tail.
                batch = replay_buffer.get_batch(
                    np.concatenate((np.arange(0, replay_buffer.ptr),
                                    np.arange(replay_buffer.max_size - model_batch_size + replay_buffer.ptr,
                                              replay_buffer.max_size))))
            else:
                batch = replay_buffer.get_batch(np.arange(0, replay_buffer.ptr))
            interact_with_model(batch)

        # Show running statistics on the progress bar, minus the array-valued ones.
        current_stats = logger.get_current_stats(mean_scope=update_every)
        for bulky_key in ('LogPi', 'Q1Vals', 'Q2Vals'):
            current_stats.pop(bulky_key, None)
        bar.set_postfix(**current_stats)
        bar.update()

        # End of epoch handling
        if (t + 1) % steps_per_epoch == 0:
            bar.close()
            epoch = (t + 1) // steps_per_epoch

            # Test the performance of the deterministic version of the agent.
            test_agent()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            try:
                logger.log_tabular('LossM', average_only=True)
                logger.log_tabular('LossMOPred', average_only=True)
                logger.log_tabular('LossMORecon', average_only=True)
                logger.log_tabular('LossMRPred', average_only=True)
            except Exception:
                # Best effort: these model diagnostics may not have been
                # stored this epoch.  Catch Exception rather than everything
                # so KeyboardInterrupt / SystemExit still propagate.
                pass
            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()

            if epoch < epochs:
                bar = tqdm(total=steps_per_epoch, desc=f'Epoch {epoch + 1}')
    return logger
