import numpy as np
import torch

from HyAR_RL.utils import ReplayBuffer as embedbuffer
from embedding import ActionRepresentation_vae
from mpc_model.hyar_model import vae_train

from utils import device
import utils as u
from models.world_model import world_model
from copy import deepcopy
from torch.autograd import Variable
import os
import math
import copy

import wandb


class Trainer:
    def __init__(self, args):
        """Build the environment, world model, optimisers, buffers and result directory.

        Args:
            args: Run configuration namespace. ``args.device`` is overwritten
                with the module-level ``device``; the namespace returned by
                ``u.make_env`` is stored as ``self.args``.
        """
        args.device = device
        self.algo = args.algo

        u.set_seed(args.seed)
        self.env, self.args = u.make_env(args)

        # save_points() reads self.args, so it must run after make_env.
        if args.save_points:
            self.save_points()

        self.model = world_model(args).to(device)
        self.model_target = deepcopy(self.model).to(device)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.args.td_lr)
        self.pi_optim = torch.optim.Adam(self.model._agent.parameters(), lr=self.args.td_lr)

        self.buffer = u.ReplayBuffer(args)

        # The embedding buffer only exists when action embeddings are enabled.
        if args.embed:
            self.embed_buffer = embedbuffer(args.state_dim,
                                            discrete_action_dim=1,
                                            parameter_action_dim=args.z_dim,
                                            all_parameter_action_dim=args.all_z_dim,
                                            discrete_emb_dim=args.discrete_emb_dim,
                                            parameter_emb_dim=args.parameter_emb_dim,
                                            max_size=int(2e6)
                                            )

        # Results go to result/TDMPC/<env name before the first '-'>/<save_dir>.
        env_family = args.env.split('-')[0]
        redir = os.path.join(f"result/TDMPC/{env_family}", args.save_dir)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(redir, exist_ok=True)
        self.redir = redir
        self.Test_Reward_100 = []
        self.Test_epioside_step_100 = []

    def save_points(self):
        """Start a Weights & Biases run configured from ``self.args``."""
        wandb.init(project="pamdp-mpc", config=self.args, dir="../scratch/wandb")

    def upload_log(self, mylog):
        if self.args.save_points:
            wandb.log(mylog)

    def save_local(self):
        title3 = "Test_Reward_100_td3_platform_"
        title4 = "Test_epioside_step_100_td3_platform_"
        np.savetxt(os.path.join(self.redir, title3 + "{}".format(str(self.args.seed) + ".csv")), self.Test_Reward_100, delimiter=',')
        np.savetxt(os.path.join(self.redir, title4 + "{}".format(str(self.args.seed) + ".csv")), self.Test_epioside_step_100,
                delimiter=',')
        
        model_file = "world_model_"
        torch.save(self.model.state_dict(), os.path.join(self.redir, model_file + "{}".format(str(self.args.seed) + ".pth")))

    def act(self, action, timestep, pre_state=None):
        # print(action)
        ret = self.env.step(action)

        if self.args.env == "simple_catch-v0":
            next_state, reward, terminal_n, _ = ret

            if pre_state[-2]<= 2/3 and action[0][-1] and np.sum(np.square(next_state[0][2:4]))>0.04:
                valid_time = pre_state[-2] + 1/6
            else:
                valid_time = pre_state[-2]

            next_state = next_state[0].tolist() + [valid_time, pre_state[-1]+1/12]
            # next_state = next_state[0]

            reward = reward[0]

            terminal = all(terminal_n)
            if reward > 4 or reward == 0 or timestep >= self.args.episode_length:
                terminal = True

        elif self.args.env == "simple_move_4_direction_v1-v0":
            next_state, reward, done_n, _ = ret
            next_state = next_state[0].tolist()
            terminal = all(done_n)
            reward = reward[0]
            if reward > 4 or timestep >= self.args.episode_length:
                terminal = True

        else:
            (next_state, steps), reward, terminal, _ = ret

        next_state = np.array(next_state, dtype=np.float32, copy=False)
        return next_state, reward, terminal
    
    def reset(self):
        if self.args.env == "simple_catch-v0":
            state = self.env.reset()
            valid_time, timestep = -1., -1.
            state = state[0].tolist() + [valid_time, timestep]  # agent_vol, direction2target

        elif self.args.env == "simple_move_4_direction_v1-v0":
            state = self.env.reset()
            state = state[0].tolist()

        else:
            state, _ = self.env.reset()
        return np.array(state, dtype=np.float32, copy=False)

    def evaluate(self):
        returns = []
        epioside_steps = []
        vis = self.args.visualise

        # self.model.load_state_dict(torch.load("result/TDMPC/simple_catch/tdmpc_model_catch_h20_hard/world_model_0.pth"))
        # self.model.eval()

        for epi in range(self.args.eval_eposides):
            # print(f"eval: {epi}")
            state = self.reset()
            t = 0
            
            with torch.no_grad():
                act, act_param = self.plan(state, eval_mode=True, t0=True, step=0, local_step=t)
                # print(act, act_param)
                action = self.pad_action(act, act_param)
                # print(action)

            if vis:
                self.env.render()

            terminal = False
            
            total_reward = 0.
            
            while not terminal:
                t += 1
                print(f"eval: {epi} || step: {t}")

                state, reward, terminal = self.act(action, t, pre_state=state)

                # print(state, action, reward, terminal)
                # print('true: ', state, reward, terminal, '\n')
                # print(act, state[-2])
                # exit()

                with torch.no_grad():
                    act, act_param = self.plan(state, eval_mode=True, t0=False, step=0, local_step=t)
                    action = self.pad_action(act, act_param)

                # if t > self.args.episode_length:
                #     break

                if vis:
                    self.env.render()

                total_reward += reward
            epioside_steps.append(t)
            returns.append(total_reward)
        print("---------------------------------------")
        print(
            f"Evaluation over {self.args.eval_eposides} episodes_rewards: {np.array(returns).mean():.3f} epioside_steps: {np.array(epioside_steps).mean():.3f}")
        print("---------------------------------------")
        Test_Reward = np.array(returns).mean()
        Test_epioside_step = np.array(epioside_steps).mean()

        self.Test_Reward_100.append(Test_Reward)
        self.Test_epioside_step_100.append(Test_epioside_step)

        self.upload_log({"Test_Reward": Test_Reward, "Test_epioside_step": Test_epioside_step})

    def rand_action(self):
        k = torch.randint(low=0, high=self.args.k_dim, size=[1])
        z = torch.rand([self.args.par_size[k]]) * self.args.scale + self.args.offsets
        
        return k.item(), z
    
    def dealRaw(self, k, z):
        """Zero the unused parameter slots of ``z`` and append it to ``k``.

        For every row, the discrete action is ``k.argmax(-1)`` and only the
        first ``self.args.par_size[action]`` columns of ``z`` are kept; the
        remaining columns are multiplied by zero.

        Returns:
            ``torch.cat([k, masked_z], dim=-1)``.
        """
        # assumes k is (batch, k_dim) and z is (batch, z_dim) - TODO confirm
        n_valid = torch.from_numpy(self.args.par_size).to(device)[k.argmax(-1)]
        n_valid = n_valid.unsqueeze(-1).repeat(1, self.args.z_dim)
        column = torch.arange(self.args.z_dim).to(device).repeat(len(n_valid), 1)
        keep = torch.where(column < n_valid, 1., 0.)
        return torch.cat([k, z * keep], dim=-1)
    
    @torch.no_grad()
    def sample_from_N(self, mean, std):
        if self.args.embed:
            raise "not implemented embed yet"
        else:
            kmean = mean['k']
            zmean, zstd = mean['z'], std

            k_int = torch.multinomial(kmean, self.args.mpc_popsize, replacement=True)
            k_onehot = torch.nn.functional.one_hot(k_int, num_classes=self.args.k_dim).to(device)
            
            z_all = torch.clamp(zmean.unsqueeze(1) + zstd.unsqueeze(1) * \
                    torch.randn(self.args.mpc_horizon, self.args.mpc_popsize, self.args.all_z_dim, device=zstd.device), self.args.lb, self.args.ub)
            
            offsets = torch.tensor(self.args.offset).to(device)[k_int.flatten()].unsqueeze(-1).repeat(1, self.args.z_dim) + torch.arange(self.args.z_dim, device=device)

            z_one = torch.zeros([self.args.mpc_horizon*self.args.mpc_popsize, self.args.all_z_dim+self.args.z_dim], device=device)
            z_one[:, :self.args.all_z_dim] = z_all.reshape([-1, self.args.all_z_dim])
            
            zs = torch.gather(z_one, 1, offsets)
            
            size = torch.from_numpy(self.args.par_size).to(device)[k_int.flatten()].unsqueeze(-1).repeat(1, self.args.z_dim)
            mask = torch.arange(self.args.z_dim).to(device).repeat(len(size), 1)
            mask = torch.where(mask<size, 1., 0.)
            zs = zs * mask
            
            zs = zs.reshape([self.args.mpc_horizon, self.args.mpc_popsize, self.args.z_dim])
            return torch.cat([k_onehot, zs], dim=-1)
    
    @torch.no_grad()
    def estimate_value(self, s, actions, horizon, local_step):
        """Estimate value of a trajectory starting at latent state z and executing given actions."""
        G, discount = 0, 1
        num_traj = s.shape[0]
        c = torch.ones([num_traj, 1], device=device)

        # pre_s = s.detach().clone()

        for t in range(horizon):
            # s, reward, ci = self.model.next(s, actions[t], reparameterize=False, return_log_prob=False, deterministic=True)
            s, reward, ci = self.model.next(s, actions[t], reparameterize=True, return_log_prob=False, deterministic=False)

            ci = ci.argmax(-1).unsqueeze(-1)
            # episilon = 1e-4

            # if self.args.env in [ "simple_catch-v0"]:
            #     if local_step + t >= self.args.episode_length:
            #         break
            #     # print(s.shape)

            #     dist2 = torch.sum(torch.square(s[:, 2:4]))
            #     hard_r = torch.where(abs(s[:, 4]-1)<episilon, 0., - dist2)
            #     k = actions[t][:, :2].argmax(-1)
            #     condition = torch.logical_and((s[:, 4])-1>episilon, k)
            #     condition = torch.logical_and(condition, dist2<=0.2)
            #     hard_r = torch.where(condition, dist2-20, hard_r)
            #     hard_r = hard_r.unsqueeze(-1)
                
            #     reward = hard_r
                
            #     hard_c = (1 - torch.logical_or((reward > 4), (reward == 0)).long())
            #     ci = hard_c

            G += discount * reward * c
            discount *= self.args.mpc_gamma
            c *= ci

            # pre_s = s.detach().clone()

        # k, z = self.model.pi(s, self.args.min_std, reparameterize=False, return_log_prob=False, deterministic=True)

        # pi_a = self.dealRaw(k, z)
        # G += discount * torch.min(*self.model.Q(s, pi_a)) * c

        return G
    
    @torch.no_grad()
    def plan(self, state, eval_mode=False, step=None, t0=True, local_step=None):
        """Choose the next hybrid action with CEM planning over the world model.

        Depending on the configuration this returns, in order of precedence:
        a random action during the seed phase (training only), the policy's
        own action when ``use_policy`` is set without ``use_model``, the
        result of ``embed_cem`` when ``embed`` is set, or otherwise the first
        action of a trajectory sampled from the CEM elites.

        Args:
            state: Current environment state (array-like).
            eval_mode: If True, no exploration noise is added and the policy
                is queried deterministically.
            step: Global training step, compared against
                ``self.args.seed_steps``. Must be a number; the default
                ``None`` cannot be compared with ``<``.
            t0: True on the first step of an episode. When False and a
                previous plan exists, the means are warm-started from it.
            local_step: Step inside the current episode, forwarded to
                ``estimate_value``.

        Returns:
            ``(k, z)``: the discrete action index and a tensor with that
            action's parameters.
        """
        if step < self.args.seed_steps and not eval_mode:
            return self.rand_action()
        
        if eval_mode:
            reparameterize = False
            return_log_prob = False
            deterministic = True
        else:
            reparameterize = False
            return_log_prob = True
            deterministic = False

            # Side effect: counts the planning calls made during training.
            self.model.timestep += 1
        
        # Sample policy trajectories
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        # horizon = int(min(self.args.horizon, h.linear_schedule(self.args.horizon_schedule, step)))

        # Policy-only mode: act directly from the policy, no planning.
        if self.args.use_policy and (not self.args.use_model):
            k, z = self.model.pi(state, self.args.min_std, reparameterize=reparameterize, return_log_prob=return_log_prob, deterministic=deterministic)
            k, z = k.flatten(), z.flatten()
            k = k.argmax().item()
            z = z[:self.args.par_size[k]]
            return k, z
        
        horizon = self.args.mpc_horizon
        num_pi_trajs = int(self.args.mixture_coef * self.args.mpc_popsize) if self.args.use_policy else 0

        # Roll the policy through the model to get extra candidate sequences.
        if num_pi_trajs > 0:
            pi_actions = torch.zeros(horizon, num_pi_trajs, self.args.action_dim, device=device)
            s = state.repeat(num_pi_trajs, 1)
            for t in range(horizon):
                k, z = self.model.pi(s, self.args.min_std, reparameterize=reparameterize, return_log_prob=return_log_prob, deterministic=deterministic)

                pi_actions[t] = self.dealRaw(k, z)

                s, _, _ = self.model.next(s, pi_actions[t], reparameterize=reparameterize, return_log_prob=return_log_prob, deterministic=deterministic)

        # Initialize state and parameters
        s = state.repeat(self.args.mpc_popsize+num_pi_trajs, 1)

        if self.args.embed:
            return self.embed_cem(state, t0, eval_mode)

        # Uniform categorical over discrete actions, zero-mean / std-2
        # Gaussian over all parameters.
        kmean = torch.ones(horizon, self.args.k_dim, device=device)
        kmean /= kmean.sum(-1).unsqueeze(-1)
        
        zmean = torch.zeros(horizon, self.args.all_z_dim, device=device)
        std = 2*torch.ones(horizon, self.args.all_z_dim, device=device)
        mean = {'k': kmean, 'z': zmean}
        # Warm start: shift the previous plan one step forward in time.
        if not t0 and hasattr(self, '_prev_mean'):
            mean['k'][:-1] = self._prev_mean['k'][1:]
            mean['z'][:-1] = self._prev_mean['z'][1:]

        # Iterate CEM
        # NOTE(review): with cem_iter == 0 the names used after the loop
        # (k_score, elite_actions, _std) are never bound.
        for i in range(self.args.cem_iter):
            actions = self.sample_from_N(mean, std)
            if num_pi_trajs > 0:
                actions = torch.cat([actions, pi_actions], dim=1)
                
            # Compute elite actions
            value = self.estimate_value(s, actions, horizon, local_step).nan_to_num_(0)
            elite_idxs = torch.topk(value.squeeze(1), self.args.mpc_num_elites, dim=0).indices
            elite_value = value[elite_idxs]  # [num_elite, 1]
            elite_actions = actions[:, elite_idxs]  # [horizon, num_elite, a_dim]

            max_value = elite_value.max(0)[0]

            # Update k parameters
            # k_score is k weights, softmax(elite_value-max)
            k_score = torch.exp(self.args.mpc_temperature*(elite_value - max_value))
            k_score /= k_score.sum(0)  # [num_elite, 1]
            kelites = elite_actions[:, :, :self.args.k_dim]
            _kmean = torch.sum(k_score.unsqueeze(0) * kelites, dim=1) / (k_score.sum(0) + 1e-9)

            # Update z parameters
            zelites = elite_actions[:, :, self.args.k_dim:]
            k_all = kelites.argmax(-1).unsqueeze(-1)  # [horizon, num_elite, 1]
            z_score = elite_value.unsqueeze(0).repeat([horizon, 1, 1])  # [horizon, num_elite, 1]
            _zmean, _std = torch.zeros_like(mean['z']), torch.zeros_like(std)

            # Parameters are refit separately per discrete action, using only
            # the elites that chose that action at the given step.
            for ki in range(self.args.k_dim):
                selected_ind = (k_all == ki)  # selected discrete type, [horizon, num_elite, 1]
                zis = zelites[:, :, :self.args.par_size[ki]]
                # zi: [horizon, num_elite, z_dim], = zi if selected else 0
                zi = torch.where(selected_ind, zis, torch.zeros_like(zis).to(device))

                # weight: [horizon, num_elite, z_dim], = softmax(selected(z))
                # Non-selected elites get -inf, i.e. weight exp(-inf) = 0.
                weight = torch.where(selected_ind, z_score, torch.tensor([float("-Inf")]).to(device))
                weight = torch.exp(self.args.mpc_temperature*(weight - max_value))
                weight_sum = weight.squeeze(-1).sum(1).reshape([-1, 1, 1]).repeat(1, self.args.mpc_num_elites, 1)
                weight /= (weight_sum + 1e-9)
                
                _zimean = torch.sum(weight * zi, dim=1) / (weight.sum(1) + 1e-9)
                _zistd = torch.sqrt(torch.sum(weight * (zi - _zimean.unsqueeze(1)) ** 2, dim=1) / (weight.sum(1) + 1e-9))

                ind_start = self.args.offset[ki]
                ind_end = ind_start + self.args.par_size[ki]

                # Steps where no elite chose this action keep their current
                # mean and std.
                if_non_select = selected_ind.squeeze(-1).sum(1).unsqueeze(-1)
                _zimean = torch.where(if_non_select==0, mean['z'][:, ind_start:ind_end], _zimean)
                _zistd = torch.where(if_non_select==0, std[:, ind_start:ind_end], _zistd)

                _zmean[:, ind_start:ind_end] = _zimean
                _std[:, ind_start:ind_end] = _zistd

            # Momentum update of the sampling distribution.
            mean['k'] = self.args.mpc_alpha * mean['k'] + (1 - self.args.mpc_alpha) * _kmean
            mean['z'] = self.args.mpc_alpha * mean['z'] + (1 - self.args.mpc_alpha) * _zmean
            std = self.args.mpc_alpha * std + (1 - self.args.mpc_alpha) * _std

        # Outputs
        # Sample one elite trajectory with probability given by its score.
        score = k_score.squeeze(1).cpu().numpy()
        actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
        self._prev_mean = mean
        # From here on `mean` is the chosen first action, and `std` is the
        # last iteration's un-smoothed std (_std), not the momentum-averaged one.
        mean, std = actions[0], _std[0]

        k = mean[:self.args.k_dim].argmax()
        z = mean[self.args.k_dim:self.args.k_dim+self.args.par_size[k]]
        
        # Exploration noise on the chosen parameters during training.
        if not eval_mode:
            ind_start = self.args.offset[k]
            ind_end = ind_start + self.args.par_size[k]
            z += std[ind_start:ind_end] * torch.randn(self.args.par_size[k], device=device)

        k = k.item()
        # In simple_catch, discrete action 1 always gets a single zero parameter.
        if self.args.env in [ "simple_catch-v0"] and k==1:
            z = torch.zeros(1)

        # # print(state.shape, actions[0].shape)
        # pred_s, reward, ci = self.model.next(state, actions[0].unsqueeze(0), reparameterize=False, return_log_prob=False, deterministic=True)

        # dist2 = torch.sum(torch.square(pred_s[:, 2:4]))
        # hard_r = torch.where(pred_s[:, 4]>11, 0., - dist2)
        # condition = torch.logical_and(pred_s[:, 4]<=11, torch.tensor(k))
        # condition = torch.logical_and(condition, dist2<=0.2)
        # hard_r = torch.where(condition, dist2-20, hard_r)
        # hard_r = hard_r.unsqueeze(-1)
        # hard_c = (1 - torch.logical_or((hard_r > 4), hard_r < 1e-3).long())

        # print("pred: ", pred_s.cpu().numpy()[0], hard_r.cpu().numpy(), hard_c.cpu().numpy())
            
        return k, z

    def pad_action(self, act, act_param):
        act_param = act_param.cpu().numpy()

        if self.args.env == "simple_catch-v0":
            if act == 0:
                action = np.hstack(([1], act_param * math.pi, [1], [0]))
            else:  # catch
                action = np.hstack(([1], [0], [0], [1]))
            return [action]
        
        elif self.args.env == "hard_goal-v0":
            return self.pad_hardgoal(act, act_param)
        
        elif self.args.env == "simple_move_4_direction_v1-v0":
            action = np.hstack(([8], [act], [self.args.action_n_dim])).tolist()

            act_params = [0] * (2 ** self.args.action_n_dim)
            act_params[act] = act_param[0]
            action.append(act_params)
            # action = np.array(action)
            # print(action, act, act_param)
            return [action]
        
        else:
            params = [np.zeros((num,), dtype=np.float32) for num in self.args.par_size]
            params[act][:] = act_param
            return (act, params)
    
    @torch.no_grad()
    def _td_target(self, next_obs, reward, c):
        """Compute the TD-target from a reward and the observation at the following time step."""
        k, z = self.model.pi(next_obs, self.args.min_std, reparameterize=False, return_log_prob=False, deterministic=True)

        action = self.dealRaw(k, z)
        td_target = reward + self.args.mpc_gamma * \
            torch.min(*self.model_target.Q(next_obs, action)) * c.unsqueeze(-1)
        
        # print(reward.shape, torch.min(*self.model_target.Q(next_obs, action)).shape, c.shape)
        # print(next_obs[0])
        # exit()
        return td_target
    
    def update_pi(self, zs):
        """Update policy using a sequence of latent states."""
        self.pi_optim.zero_grad(set_to_none=True)
        self.model.track_q_grad(False)
        # Loss is a weighted sum of Q-values
        pi_loss = 0
        for t, s in enumerate(zs):
            k, z = self.model.pi(s, self.args.min_std, reparameterize=True, return_log_prob=True, deterministic=False)

            a = self.dealRaw(k, z)
            Q = torch.min(*self.model.Q(s, a))
            pi_loss += -Q.mean() * (self.args.rho ** t)

        pi_loss.backward(retain_graph=True)

        torch.nn.utils.clip_grad_norm_(self.model._agent.parameters(), self.args.grad_clip_norm, error_if_nonfinite=False)
        self.pi_optim.step()
        self.model.track_q_grad(True)

        return pi_loss.item()
    
    def train_sperate(self, step):
        """Run one update in which the world model and the policy are trained separately.

        A batch of trajectories is sampled from ``self.buffer``. When
        ``use_model`` is set, the model is trained on the consistency, reward
        and continuation losses. When ``use_policy`` is set, all valid
        single-step transitions of the batch are pooled and passed to
        ``self.model.train_policy``.

        Args:
            step: Global training step (currently unused).

        Returns:
            Dict with the mean ``consistency_loss``, ``reward_loss`` and
            ``continuous_loss`` plus ``Q_loss``, ``value_loss`` and
            ``pi_loss`` as returned by ``train_policy`` (zeros when the
            policy is disabled).
        """
        obs, next_obses, ks, zs, reward, idxs, weights, continuous, masks = self.buffer.sample()
        self.optim.zero_grad(set_to_none=True)
        action = torch.cat([ks, zs], dim=-1)
        self.model.train()

        # Pool of single-step transitions for the policy update, seeded with
        # the whole batch at t = 0.
        pobs = obs.detach().clone()
        pk = ks[0].detach().clone()
        pz = zs[0].detach().clone()
        pr = reward[0].detach().clone()
        pc = continuous[0].unsqueeze(-1).detach().clone()
        pnobs = next_obses[0].detach().clone()

        consistency_loss, reward_loss, continue_loss = 0, 0, 0

        for t in range(self.args.mpc_horizon):
            mask = masks[t].unsqueeze(1)

            # Stop as soon as no trajectory in the batch is valid at step t.
            if not mask.any():
                break

            rho =  1. if self.args.train_single else (self.args.rho ** t)

            with torch.no_grad():
                next_obs = next_obses[t]

            # Add the valid transitions of step t to the pool. Step 0 is
            # already in the pool; appending it here would pair its actions
            # with next_obses[-1] (negative index) as the observation.
            if t > 0:
                inds = masks[t].bool()
                pobs = torch.cat([pobs, next_obses[t-1][inds].detach().clone()])
                pk = torch.cat([pk, ks[t][inds].detach().clone()])
                pz = torch.cat([pz, zs[t][inds].detach().clone()])
                pr = torch.cat([pr, reward[t][inds].detach().clone()])
                pc = torch.cat([pc, continuous[t].unsqueeze(-1)[inds].detach().clone()])
                pnobs = torch.cat([pnobs, next_obses[t][inds].detach().clone()])

            if self.args.use_model and (not self.args.train_single):
                # Multi-step rollout: the predicted state feeds the next step.
                obs, reward_pred, c_pred = self.model.next(obs, action[t], reparameterize=True, return_log_prob=True, deterministic=False)

                consistency_loss += rho * torch.mean(u.mse(obs, next_obs), dim=1, keepdim=True) * mask
                reward_loss += rho * u.mse(reward_pred, reward[t]) * mask
                continue_loss += rho * u.ce(c_pred, continuous[t]) * mask

        if self.args.use_model and self.args.train_single:
            # NOTE(review): this uses the t=0 observation with the action and
            # targets of the last visited step t; that only lines up when a
            # single step is visited - confirm the intended behaviour.
            obs, reward_pred, c_pred = self.model.next(obs, action[t], reparameterize=True, return_log_prob=True, deterministic=False)

            consistency_loss = torch.mean(u.mse(obs, next_obs), dim=1, keepdim=True)
            reward_loss = u.mse(reward_pred, reward[t])
            continue_loss = u.ce(c_pred, continuous[t])

        if self.args.use_model:
            total_loss = self.args.consistency_coef * consistency_loss.clamp(max=1e4) + \
                            self.args.reward_coef * reward_loss.clamp(max=1e4) + \
                            self.args.contin_coef * continue_loss.clamp(max=1e4)

            weighted_loss = (total_loss.squeeze(1) * weights).mean()
            # Scale the gradient by 1/horizon so its size is independent of
            # the number of unrolled steps.
            weighted_loss.register_hook(lambda grad: grad * (1/self.args.mpc_horizon))
            weighted_loss.backward()
            self.optim.step()
        else:
            consistency_loss = torch.zeros([self.args.batch_size, 1]).to(device)
            reward_loss = torch.zeros([self.args.batch_size, 1]).to(device)
            continue_loss = torch.zeros([self.args.batch_size, 1]).to(device)

        if self.args.use_policy:
            qf_loss, vf_loss, policy_loss = self.model.train_policy(pobs, pk, pz, pr, pnobs, pc)
        else:
            qf_loss, vf_loss, policy_loss = 0, 0, 0

        self.model.eval()

        return {
                'consistency_loss': float(consistency_loss.mean().item()),
                'reward_loss': float(reward_loss.mean().item()),
                'continuous_loss': float(continue_loss.mean().item()),
                'Q_loss': qf_loss,
                'value_loss': vf_loss,
                'pi_loss': policy_loss,
                }


    def train(self, step):
        """Main update function. Corresponds to one iteration of the TOLD model learning.

        Samples a batch of trajectories from ``self.buffer``, unrolls the
        model for ``mpc_horizon`` steps while accumulating the consistency,
        reward, continuation and (with ``use_policy``) value losses, takes one
        optimiser step on their weighted sum, and then updates replay
        priorities, the policy and, every ``pi_update_freq`` steps, the
        target network.

        Args:
            step: Global training step; only used to schedule the target
                network update.

        Returns:
            Dict of scalar training statistics (the individual losses, the
            total and weighted loss, and the gradient norm).
        """
        # print("Training!")
        obs, next_obses, ks, zs, reward, idxs, weights, continuous, masks = self.buffer.sample()
        self.optim.zero_grad(set_to_none=True)
        # self.std = h.linear_schedule(self.args.std_schedule, step)
        action = torch.cat([ks, zs], dim=-1)
        self.model.train()

        # Representation
        # ss collects the (detached) states visited by the rollout; they are
        # the inputs of the policy update below.
        ss = [obs.detach()]

        consistency_loss, reward_loss, continue_loss, value_loss, priority_loss = 0, 0, 0, 0, 0
        # print(obs.shape, next_obses.shape, ks.shape, zs.shape, reward.shape, idxs.shape, weights.shape, continuous.shape)

        for t in range(self.args.mpc_horizon):
            # mask zeroes the loss of entries that are not valid at step t.
            mask = masks[t].unsqueeze(1)
            # print(mask)
            # exit()
            
            rho =  1. if self.args.train_single else (self.args.rho ** t)

            with torch.no_grad():
                next_obs = next_obses[t]

            # Predictions
            if self.args.use_policy:
                # Q1, Q2 = self.model.Q(obs, action[t])
                # obs is detached so the value loss does not back-propagate
                # into the dynamics rollout.
                Q1, Q2 = self.model.Q(obs.detach(), action[t])
                with torch.no_grad():
                    td_target = self._td_target(next_obs, self.args.reward_scale * reward[t], continuous[t])

                value_loss += rho * (u.mse(Q1, td_target) + u.mse(Q2, td_target)) * mask
                priority_loss += rho * (u.l1(Q1, td_target) + u.l1(Q2, td_target)) * mask

            if self.args.use_model:

                # obs is replaced by the model's prediction, so later steps
                # are unrolled from predicted states.
                obs, reward_pred, c_pred = self.model.next(obs, action[t], reparameterize=True, return_log_prob=True, deterministic=False)
                ss.append(obs.detach())
                
                consistency_loss += rho * torch.mean(u.mse(obs, next_obs), dim=1, keepdim=True) * mask
                reward_loss += rho * u.mse(reward_pred, reward[t]) * mask
                # print(c_pred.shape, continuous[t].shape)
                continue_loss += rho * u.ce(c_pred, continuous[t]) * mask

                # Single-step training: restart every step from the true next
                # observation instead of the prediction.
                if self.args.train_single:
                    obs = next_obs.detach().clone()
            
            # print(consistency_loss.shape, reward_loss.shape, value_loss.shape, priority_loss.shape, mask.shape)
        # exit()

        # Optimize model
        # Losses of a disabled component are still the int 0 from above;
        # replace them with zero tensors so .clamp()/.mean() below work.
        # NOTE(review): if both use_policy and use_model are off, all losses
        # stay ints and .clamp() fails - confirm that case cannot occur.
        if self.args.use_policy and (not self.args.use_model):
            consistency_loss = torch.zeros([self.args.batch_size, 1]).to(device)
            reward_loss = torch.zeros([self.args.batch_size, 1]).to(device)
            continue_loss = torch.zeros([self.args.batch_size, 1]).to(device)
        elif (not self.args.use_policy) and self.args.use_model:
            value_loss = torch.zeros([self.args.batch_size, 1]).to(device)

        total_loss = self.args.consistency_coef * consistency_loss.clamp(max=1e4) + \
                        self.args.reward_coef * reward_loss.clamp(max=1e4) + \
                        self.args.contin_coef * continue_loss.clamp(max=1e4) + \
                        self.args.value_coef * value_loss.clamp(max=1e4)
        
        weighted_loss = (total_loss.squeeze(1) * weights).mean()
        # Scale the gradient by 1/horizon so its size does not grow with the
        # number of unrolled steps.
        weighted_loss.register_hook(lambda grad: grad * (1/self.args.mpc_horizon))
        weighted_loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip_norm, error_if_nonfinite=False)
        self.optim.step()

        if self.args.use_policy:
            self.buffer.update_priorities(idxs, priority_loss.clamp(max=1e4).detach())

            # Update policy + target network
            pi_loss = self.update_pi(ss)
            if step % self.args.pi_update_freq == 0:
                u.ema(self.model, self.model_target, self.args.pi_tau)
        else:
            pi_loss = 0

        self.model.eval()
        
        return {'consistency_loss': float(consistency_loss.mean().item()),
                'reward_loss': float(reward_loss.mean().item()),
                'continuous_loss': float(continue_loss.mean().item()),
                'value_loss': float(value_loss.mean().item()),
                'pi_loss': pi_loss,
                'total_loss': float(total_loss.mean().item()),
                'weighted_loss': float(weighted_loss.mean().item()),
                'grad_norm': float(grad_norm)}
    
    def embed_cem(self, state, t0, eval_mode):
        """Run CEM planning in the learned action-embedding space.

        Discrete and parameter embeddings are sampled from independent
        Gaussians, clamped to ``[-1, 1]``, scored with ``estimate_value`` and
        refit to the elite samples. The first step of a score-sampled elite
        sequence is decoded back to a hybrid action by
        ``self.model.decode_embed``.

        Args:
            state: Current state tensor with a leading batch dimension of 1.
            t0: True on the first step of an episode; disables warm start.
            eval_mode: If True, no exploration noise is added.

        Returns:
            ``(k, z)``: the decoded discrete action and the slice of decoded
            parameters that belongs to it.
        """
        # NOTE(review): this method reads args.iterations, args.num_elites,
        # args.temperature and (via update_new_dis_par) args.momentum, while
        # plan() reads cem_iter / mpc_num_elites / mpc_temperature /
        # mpc_alpha - confirm both sets of options exist.
        horizon = self.args.mpc_horizon
        s = state.repeat(self.args.mpc_popsize, 1)

        ek_mean = torch.zeros(horizon, self.args.discrete_emb_dim, device=device)
        ek_std = 2*torch.ones(horizon, self.args.discrete_emb_dim, device=device)

        ez_mean = torch.zeros(horizon, self.args.parameter_emb_dim, device=device)
        ez_std = 2*torch.ones(horizon, self.args.parameter_emb_dim, device=device)

        # Warm start from the previous plan, shifted one step forward. The
        # guard tests the attribute this method itself stores; '_prev_mean'
        # is only written by the non-embedding branch of plan().
        if not t0 and hasattr(self, '_prev_ek_mean'):
            ek_mean[:-1] = self._prev_ek_mean[1:]
            ez_mean[:-1] = self._prev_ez_mean[1:]

        # Iterate CEM
        for i in range(self.args.iterations):
            kactions = torch.clamp(ek_mean.unsqueeze(1) + ek_std.unsqueeze(1) * \
                torch.randn(horizon, self.args.mpc_popsize, self.args.discrete_emb_dim, device=device), -1, 1)

            zactions = torch.clamp(ez_mean.unsqueeze(1) + ez_std.unsqueeze(1) * \
                torch.randn(horizon, self.args.mpc_popsize, self.args.parameter_emb_dim, device=device), -1, 1)

            actions = torch.cat([kactions, zactions], dim=-1)

            # Compute elite actions. The episode step is not tracked here, so
            # None is passed for estimate_value's local_step argument.
            value = self.estimate_value(s, actions, horizon, None).nan_to_num_(0)
            elite_idxs = torch.topk(value.squeeze(1), self.args.num_elites, dim=0).indices

            elite_value = value[elite_idxs]
            elite_ek_actions = kactions[:, elite_idxs]
            elite_ez_actions = zactions[:, elite_idxs]

            # Update parameters: softmax weights over the elite values.
            max_value = elite_value.max(0)[0]
            score = torch.exp(self.args.temperature*(elite_value - max_value))
            score /= score.sum(0)

            ek_mean, ek_std = self.update_new_dis_par(score, elite_ek_actions, ek_mean, ek_std)
            ez_mean, ez_std = self.update_new_dis_par(score, elite_ez_actions, ez_mean, ez_std)

        # Outputs: sample one elite sequence with probability given by its score.
        score = score.squeeze(1).cpu().numpy()

        ind = np.random.choice(np.arange(score.shape[0]), p=score)
        ek = elite_ek_actions[:, ind]
        ez = elite_ez_actions[:, ind]

        self._prev_ek_mean = ek_mean
        self._prev_ez_mean = ez_mean

        ek, ekstd = ek[0], ek_std[0]
        ez, ezstd = ez[0], ez_std[0]

        # Exploration noise on the first-step embeddings during training.
        if not eval_mode:
            ek += ekstd * torch.randn(self.args.discrete_emb_dim, device=device)
            ez += ezstd * torch.randn(self.args.parameter_emb_dim, device=device)

        k, allz = self.model.decode_embed(state, ek, ez)

        ind_start = self.args.offset[k]
        ind_end = ind_start + self.args.par_size[k]
        z = allz[ind_start:ind_end]

        return k, z
    
    def update_new_dis_par(self, score, elite_actions, mean, std):
        _mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
        _std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9))
        _std = _std.clamp_(self.std, 2)
        mean, std = self.args.momentum * mean + (1 - self.args.momentum) * _mean, _std
        return mean, std
    
    def random_action(self):
        k = np.random.randint(0, self.args.k_dim)
        psize = self.args.par_size[k]
        z = np.random.random(psize) * 2 - 1
        return k, z
    
    def pretrain(self, steps):
        max_steps = 150
        total_reward = 0.
        returns = []
        step = 0
        while step < steps:

            state = self.reset()

            k, z = self.random_action()
            action = self.pad_action(k, z)

            # episode_reward = 0.
            for j in range(max_steps):
                step += 1

                next_state, reward, terminal = self.act(action, j, pre_state=state)
                
                nk, nz = self.random_action()
                next_action = self.pad_action(k, z)
                
                self.embed_buffer.add(state, k, z, None, 
                                    discrete_emb=None,
                                    parameter_emb=None,
                                    next_state=next_state,
                                    state_next_state=None,
                                    reward=reward, 
                                    done=terminal)
                k, z = nk, nz
                action = next_action
                state = next_state
                
                if terminal:
                    break
        
        self.model.vae_train(train_step=5000, buffer=self.embed_buffer)

    def vae_train(self):
        self.model.vae_train(train_step=1, buffer=self.embed_buffer)

    def pad_hardgoal(self, act, act_param):

        c_rate = [[-1.0, -0.6], [-0.6, -0.2], [0.2, 0.2], [0.2, 0.6], [0.6, 1.0]]

        params = [np.zeros((2,)), np.zeros((1,)), np.zeros((1,))]
        if act == 0:
            params[0][0] = act_param[0]
            params[0][1] = act_param[1]
        elif act == 1:
            act_param=self.true_action(0, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[1] = act_param
            act = 1
        elif act == 2:
            act_param=self.true_action(1, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[1] = act_param
            act = 1
        elif act == 3:
            act_param=self.true_action(2, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[1] = act_param
            act = 1
        elif act == 4:
            act_param=self.true_action(3, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[1] = act_param
            act = 1
        elif act == 5:
            act_param=self.true_action(4, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[1] = act_param
            act = 1

        elif act == 6:
            act_param=self.true_action(0, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[2] = act_param
            act = 2
        elif act == 7:
            act_param=self.true_action(1, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[2] = act_param
            act = 2
        elif act == 8:
            act_param=self.true_action(2, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[2] = act_param
            act = 2
        elif act == 9:
            act_param=self.true_action(3, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[2] = act_param
            act = 2
        elif act == 10:
            act_param=self.true_action(4, act_param[0], c_rate)
            act_param=np.array([act_param])
            params[2] = act_param
            act = 2
        return (act, params)
    
    def count_boundary(self, c_rate):
        median = (c_rate[0] - c_rate[1]) / 2
        offset = c_rate[0] - 1 * median
        return median, offset

    def true_action(self, act, act_param, c_rate):
        parameter_action_ = copy.deepcopy(act_param)
        median, offset = self.count_boundary(c_rate[act])
        parameter_action_ = parameter_action_ * median + offset
        return parameter_action_
    

