#cluster

from collections import OrderedDict
import numpy as np

import torch
import torch.optim as optim
from torch import nn as nn

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.core.rl_algorithm import MetaRLAlgorithm
from rlkit.samplers.in_place import InPlacePathSampler

import copy
import random

# Candidate cluster counts. For each entry the algorithm keeps one pair of
# cluster-model sets and one cluster assignment per training step.
# (Previously explored settings: [2, 4, 8, 16, 32] and a single n_cluster=30.)
n_cluster_list = [2 ** exponent for exponent in range(1, 4)]

class PEARLSoftActorCritic(MetaRLAlgorithm):
    """PEARL soft actor-critic augmented with cluster-based task pseudo-labels.

    Besides the usual PEARL agent and critics, this algorithm keeps, for every
    entry of the module-level ``n_cluster_list``, a pair of model sets:

    * ``dynamics_list[0][k]`` -- ``n_cluster`` models mapping observations to
      actions, and
    * ``dynamics_list[1][k]`` -- ``n_cluster`` models mapping
      (observation, action) to the reward (``dyna_mode == 1``) or to
      (reward, next observation) otherwise.

    Each training step assigns every task of the meta-batch to the cluster
    whose ``dynamics_list[1]`` model fits that task's encoder-buffer data best.
    It then trains the cluster models, each task's loss being routed to its
    assigned cluster. Finally, it uses the assignment for one randomly chosen
    cluster count as pseudo task labels in a contrastive loss on the latent
    ``task_z`` (see ``z_loss``).

    Two extra exploration agents (``expl_agent1`` / ``expl_agent2``) are trained
    with plain SAC on ``forward_no_context`` outputs. ``reset_expl1`` /
    ``reset_expl2`` re-create them from ``self.agent_bf``; that attribute is not
    set in this class and is presumably provided by the base class (confirm).
    """

    def __init__(
            self,
            env,
            train_tasks,
            eval_tasks,
            latent_dim,
            nets,

            policy_lr=1e-3,
            qf_lr=1e-3,
            vf_lr=1e-3,
            context_lr=1e-3,
            kl_lambda=1.,
            policy_mean_reg_weight=1e-3,
            policy_std_reg_weight=1e-3,
            policy_pre_activation_weight=0.,
            optimizer_class=optim.Adam,
            recurrent=False,
            use_information_bottleneck=True,
            use_next_obs_in_context=False,
            sparse_rewards=False,

            soft_target_tau=1e-2,
            plotter=None,
            render_eval_paths=False,
            dyna_mode=1,
            **kwargs
    ):
        """Build optimizers for all networks and the exploration agents.

        Args:
            nets: sequence ``(agent, qf1, qf2, vf, dynamics_list)``.
                ``dynamics_list`` is a pair ``(action_model_sets,
                reward_model_sets)`` with one set of ``n_cluster`` models per
                entry of ``n_cluster_list``.
            dyna_mode: 1 -> ``dynamics_list[1]`` models predict rewards only;
                any other value -> they predict rewards concatenated with next
                observations.
            kl_lambda: stored, but the KL term is currently not applied in
                ``_take_step``.
            Remaining arguments follow the standard PEARL SAC meaning; extra
            ``kwargs`` are forwarded to ``MetaRLAlgorithm``.
        """
        super().__init__(
            env=env,
            agent=nets[0],
            train_tasks=train_tasks,
            eval_tasks=eval_tasks,
            **kwargs
        )

        self.soft_target_tau = soft_target_tau
        self.policy_mean_reg_weight = policy_mean_reg_weight
        self.policy_std_reg_weight = policy_std_reg_weight
        self.policy_pre_activation_weight = policy_pre_activation_weight
        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.recurrent = recurrent
        self.latent_dim = latent_dim
        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()
        self.vib_criterion = nn.MSELoss()
        self.l2_reg_criterion = nn.MSELoss()
        self.kl_lambda = kl_lambda

        self.dynamics_criterion = nn.MSELoss()
        self.dyna_mode = dyna_mode

        self.use_information_bottleneck = use_information_bottleneck
        self.sparse_rewards = sparse_rewards
        self.use_next_obs_in_context = use_next_obs_in_context

        self.qf1, self.qf2, self.vf, self.dynamics_list = nets[1:]
        self.target_vf = self.vf.copy()

        self.policy_optimizer = optimizer_class(
            self.agent.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self.vf.parameters(),
            lr=vf_lr,
        )
        self.context_optimizer = optimizer_class(
            self.agent.context_encoder.parameters(),
            lr=context_lr,
        )
        # One optimizer per cluster-model set: first one per entry of
        # dynamics_list[0], followed by one per entry of dynamics_list[1].
        # Each model of a set becomes its own parameter group.
        self.dynamics_optimizer_list = [
            optimizer_class([{'params': dyna.parameters()} for dyna in dynamics], lr=qf_lr)
            for dynamics in list(self.dynamics_list[0]) + list(self.dynamics_list[1])
        ]

        # Untouched copies of the freshly constructed critics; the exploration
        # critics are (re)initialised from these in reset_expl1/reset_expl2.
        self.qf1_bf, self.qf2_bf, self.vf_bf = copy.deepcopy(self.qf1).cuda(), copy.deepcopy(self.qf2).cuda(), copy.deepcopy(self.vf).cuda()
        self.optimizer_class = optimizer_class
        self.policy_lr = policy_lr
        self.qf_lr = qf_lr
        self.vf_lr = vf_lr
        # The first pass cannot build the sampler: each reset wires the sampler
        # to the *other* exploration agent, which must already exist.
        self.reset_expl1(set_sampler=False)
        self.reset_expl2(set_sampler=False)
        self.reset_expl1(set_sampler=True)
        self.reset_expl2(set_sampler=True)

        # identity matrices used to one-hot encode cluster assignments
        self.torch_eye_list = [torch.eye(n_cluster).cuda() for n_cluster in n_cluster_list]
        self.analyze = [[] for _ in range(100)]  # kept for offline analysis; not used here
        # per cluster count: recent cluster ids (bounded window, see _recycle_unused_clusters)
        self.used_list = [[] for _ in range(len(n_cluster_list))]
        self.z_loss_weight = 10

    def reset_expl1(self, set_sampler=True):
        """Re-create exploration agent 1, its critics and their optimizers.

        The agent is copied from ``self.agent_bf`` and the critics from the
        initial copies made in ``__init__``. When ``set_sampler`` is true, the
        exploration sampler is rebuilt around ``expl_agent2`` -- the *other*
        agent, not the one just reset -- so it must already exist.
        """
        self.expl_agent1 = copy.deepcopy(self.agent_bf).cuda()

        if set_sampler:
            self.sampler_expl = InPlacePathSampler(
                env=self.env,
                policy=self.expl_agent2,
                max_path_length=self.max_path_length,
            )

        self.qf1_expl1, self.qf2_expl1, self.vf_expl1 = copy.deepcopy(self.qf1_bf).cuda(), copy.deepcopy(self.qf2_bf).cuda(), copy.deepcopy(self.vf_bf).cuda()
        self.target_vf_expl1 = self.vf_expl1.copy().cuda()

        self.expl_policy_optimizer1 = self.optimizer_class(
            self.expl_agent1.policy.parameters(),
            lr=self.policy_lr,
        )
        self.expl_qf1_optimizer1 = self.optimizer_class(
            self.qf1_expl1.parameters(),
            lr=self.qf_lr,
        )
        self.expl_qf2_optimizer1 = self.optimizer_class(
            self.qf2_expl1.parameters(),
            lr=self.qf_lr,
        )
        self.expl_vf_optimizer1 = self.optimizer_class(
            self.vf_expl1.parameters(),
            lr=self.vf_lr,
        )

    def reset_expl2(self, set_sampler=False):
        """Re-create exploration agent 2, its critics and their optimizers.

        Mirror of ``reset_expl1``: when ``set_sampler`` is true, the exploration
        sampler is rebuilt around ``expl_agent1`` (the other agent).
        """
        self.expl_agent2 = copy.deepcopy(self.agent_bf).cuda()

        if set_sampler:
            self.sampler_expl = InPlacePathSampler(
                env=self.env,
                policy=self.expl_agent1,
                max_path_length=self.max_path_length,
            )

        self.qf1_expl2, self.qf2_expl2, self.vf_expl2 = copy.deepcopy(self.qf1_bf).cuda(), copy.deepcopy(self.qf2_bf).cuda(), copy.deepcopy(self.vf_bf).cuda()
        self.target_vf_expl2 = self.vf_expl2.copy().cuda()

        self.expl_policy_optimizer2 = self.optimizer_class(
            self.expl_agent2.policy.parameters(),
            lr=self.policy_lr,
        )
        self.expl_qf1_optimizer2 = self.optimizer_class(
            self.qf1_expl2.parameters(),
            lr=self.qf_lr,
        )
        self.expl_qf2_optimizer2 = self.optimizer_class(
            self.qf2_expl2.parameters(),
            lr=self.qf_lr,
        )
        self.expl_vf_optimizer2 = self.optimizer_class(
            self.vf_expl2.parameters(),
            lr=self.vf_lr,
        )

    ###### Torch stuff #####
    @property
    def networks(self):
        """Modules handled by ``training_mode`` and ``to``.

        This covers the agent's networks, the agent, the main critics and every
        cluster model. NOTE(review): the exploration agents and critics are not
        included; they are placed on the GPU explicitly in the reset methods.
        """
        nets = self.agent.networks + [self.agent] + [self.qf1, self.qf2, self.vf, self.target_vf]
        for dynamics in self.dynamics_list[0]:
            nets += dynamics
        for dynamics in self.dynamics_list[1]:
            nets += dynamics
        return nets

    def training_mode(self, mode):
        """Call ``train(mode)`` on every module in ``self.networks``."""
        for net in self.networks:
            net.train(mode)

    def to(self, device=None):
        """Move every module in ``self.networks`` to ``device`` (default ``ptu.device``)."""
        if device is None:
            device = ptu.device
        for net in self.networks:
            net.to(device)

    ##### Data handling #####
    def unpack_batch(self, batch, sparse_reward=False):
        ''' unpack a batch and return individual elements

        Each element gets a new leading axis of size 1 so that batches of
        several tasks can be concatenated along dim 0. Returns
        ``[obs, actions, rewards, next_obs, terminals]``; the rewards come from
        ``'sparse_rewards'`` when ``sparse_reward`` is true.
        '''
        o = batch['observations'][None, ...]
        a = batch['actions'][None, ...]
        if sparse_reward:
            r = batch['sparse_rewards'][None, ...]
        else:
            r = batch['rewards'][None, ...]
        no = batch['next_observations'][None, ...]
        t = batch['terminals'][None, ...]
        return [o, a, r, no, t]

    def _sample_stacked(self, buffer, indices):
        """Draw ``batch_size`` transitions per task from ``buffer``.

        Returns ``[obs, actions, rewards, next_obs, terms]``, each stacked over
        tasks along dim 0. Dense rewards are always used.
        """
        batches = [ptu.np_to_pytorch_batch(buffer.random_batch(idx, batch_size=self.batch_size)) for idx in indices]
        unpacked = [self.unpack_batch(batch) for batch in batches]
        # group like elements together, then stack the tasks
        return [torch.cat(group, dim=0) for group in zip(*unpacked)]

    def sample_sac(self, indices):
        ''' sample batch of training data from a list of tasks for training the actor-critic '''
        # this batch consists of transitions sampled randomly from replay buffer
        # rewards are always dense
        return self._sample_stacked(self.replay_buffer, indices)

    def sample_sac_enc(self, indices):
        ''' sample batch of training data from a list of tasks, drawn from the encoder replay buffer '''
        # used to fit and assign the cluster models; rewards are always dense
        return self._sample_stacked(self.enc_replay_buffer, indices)

    def sample_context(self, indices):
        ''' sample batch of context from a list of tasks from the replay buffer

        Accepts a single task index or an iterable of indices. The context is the
        concatenation along dim 2 of [obs, act, rewards] (plus next_obs when
        ``use_next_obs_in_context``); terminals are never included.
        '''
        # make method work given a single task index
        if not hasattr(indices, '__iter__'):
            indices = [indices]
        batches = [ptu.np_to_pytorch_batch(self.enc_replay_buffer.random_batch(idx, batch_size=self.embedding_batch_size, sequence=self.recurrent)) for idx in indices]
        context = [self.unpack_batch(batch, sparse_reward=self.sparse_rewards) for batch in batches]
        # group like elements together
        context = [torch.cat(group, dim=0) for group in zip(*context)]
        # full context consists of [obs, act, rewards, next_obs, terms]
        # if dynamics don't change across tasks, don't include next_obs
        # don't include terminals in context
        if self.use_next_obs_in_context:
            context = torch.cat(context[:-1], dim=2)
        else:
            context = torch.cat(context[:-2], dim=2)
        return context

    ##### Training #####
    def _do_training(self, indices):
        """Train the main agent on one context batch, split into mini-batches."""
        mb_size = self.embedding_mini_batch_size
        num_updates = self.embedding_batch_size // mb_size

        # sample context batch
        context_batch = self.sample_context(indices)

        # zero out context and hidden encoder state
        self.agent.clear_z(num_tasks=len(indices))

        # do this in a loop so we can truncate backprop in the recurrent encoder
        for i in range(num_updates):
            context = context_batch[:, i * mb_size: i * mb_size + mb_size, :]
            self._take_step(indices, context)

            # stop backprop
            self.agent.detach_z()

    ##### Training #####
    def _do_training_expl(self, indices):
        """Train both exploration agents for as many steps as ``_do_training`` takes."""
        mb_size = self.embedding_mini_batch_size
        num_updates = self.embedding_batch_size // mb_size

        # sample context batch
        context_batch = self.sample_context(indices)

        # zero out context and hidden encoder state
        self.expl_agent1.clear_z(num_tasks=len(indices))
        self.expl_agent2.clear_z(num_tasks=len(indices))

        # do this in a loop so we can truncate backprop in the recurrent encoder
        for i in range(num_updates):
            context = context_batch[:, i * mb_size: i * mb_size + mb_size, :]
            self._take_step_expl(indices, context)

            # stop backprop
            self.expl_agent1.detach_z()
            self.expl_agent2.detach_z()

    def _min_q(self, obs, actions, task_z):
        """Element-wise minimum of the two main Q heads (task_z detached)."""
        q1 = self.qf1(obs, actions, task_z.detach())
        q2 = self.qf2(obs, actions, task_z.detach())
        min_q = torch.min(q1, q2)
        return min_q

    def _min_q_expl1(self, obs, actions, task_z):
        """Element-wise minimum of exploration agent 1's Q heads (task_z detached)."""
        q1 = self.qf1_expl1(obs, actions, task_z.detach())
        q2 = self.qf2_expl1(obs, actions, task_z.detach())
        min_q = torch.min(q1, q2)
        return min_q

    def _min_q_expl2(self, obs, actions, task_z):
        """Element-wise minimum of exploration agent 2's Q heads (task_z detached)."""
        q1 = self.qf1_expl2(obs, actions, task_z.detach())
        q2 = self.qf2_expl2(obs, actions, task_z.detach())
        min_q = torch.min(q1, q2)
        return min_q

    def _update_target_network(self):
        """Soft-update the main target V network with rate ``soft_target_tau``."""
        ptu.soft_update_from_to(self.vf, self.target_vf, self.soft_target_tau)

    def _update_target_network_expl1(self):
        """Soft-update exploration agent 1's target V network."""
        ptu.soft_update_from_to(self.vf_expl1, self.target_vf_expl1, self.soft_target_tau)

    def _update_target_network_expl2(self):
        """Soft-update exploration agent 2's target V network."""
        ptu.soft_update_from_to(self.vf_expl2, self.target_vf_expl2, self.soft_target_tau)

    def z_loss(self, indices, task_z, b, epsilon=1e-3, threshold=0.999):
        """Contrastive loss on latent task embeddings.

        For every pair of tasks (i, j), the loss compares the rows
        ``task_z[i * b]`` and ``task_z[j * b]``. This presumes ``task_z`` holds
        ``b`` consecutive rows per task; TODO confirm against the agent.

        * Pairs whose labels are equal contribute
          ``sqrt(mean squared distance + epsilon)``, which pulls them together.
        * Pairs whose labels differ contribute
          ``1 / (mean squared distance + 100 * epsilon)``, which pushes them apart.

        Each term is averaged over its own number of pairs.

        Args:
            indices: per-task labels, compared with ``==``.
            task_z: latent embeddings indexed as ``task_z[task * b]``.
            b: number of rows per task in ``task_z``.
            epsilon: numerical-stability constant.
            threshold: unused.

        Returns:
            The summed loss; the float ``0.0`` when there are fewer than two tasks.
        """
        pos_z_loss = 0.
        neg_z_loss = 0.
        pos_cnt = 0
        neg_cnt = 0
        num = len(indices)
        for i in range(num):
            idx_i = i * b  # index in task * batch dim
            for j in range(i + 1, num):
                idx_j = j * b  # index in task * batch dim
                if indices[i] == indices[j]:
                    pos_z_loss += torch.sqrt(torch.mean((task_z[idx_i] - task_z[idx_j]) ** 2) + epsilon)
                    pos_cnt += 1
                else:
                    neg_z_loss += 1 / (torch.mean((task_z[idx_i] - task_z[idx_j]) ** 2) + epsilon * 100)
                    neg_cnt += 1
        return pos_z_loss / (pos_cnt + epsilon) + neg_z_loss / (neg_cnt + epsilon)

    @staticmethod
    def _stacked_predictions(models, inputs, n_cluster, t, b):
        """Run every model of one cluster set on ``inputs`` (``t * b`` rows).

        Returns the predictions stacked and reshaped to ``(n_cluster, t, b, out_dim)``.
        """
        pred = torch.stack([model(inputs) for model in models], dim=0)  # (n_cl, t*b, out)
        return pred.view(n_cluster, t, b, -1)

    @staticmethod
    def _standardized_task_errors(pred, target, t, b):
        """Per-cluster, per-task squared prediction error of shape ``(n_cluster, t)``.

        The errors are jittered by up to 1% and then standardized across clusters.
        """
        dist = ((pred - target.view(t, b, -1).unsqueeze(0)) ** 2).sum(3).sum(2)  # (n_cl, t)
        # small multiplicative noise, presumably to break ties between clusters
        dist = (1 + torch.rand_like(dist) * 0.01) * dist
        return (dist - dist.mean(0, keepdim=True)) / dist.std(0, keepdim=True)

    def _assign_clusters(self, obs_actions_flat, dyna_flat, t, b):
        """Assign each task to a cluster, independently for each cluster count.

        Only the ``dynamics_list[1]`` models (reward / reward+next-obs
        predictors) decide the assignment.

        Returns:
            ``(closest_list, cluster_prob_list)``, with one entry per
            ``n_cluster_list`` item:
            * ``closest``: shape ``(t,)``, the chosen cluster per task;
            * ``cluster_prob``: shape ``(n_cluster, t)``, its one-hot encoding.
        """
        closest_list, cluster_prob_list = [], []
        with torch.no_grad():
            for dynamics2, torch_eye, n_cluster in zip(self.dynamics_list[1], self.torch_eye_list, n_cluster_list):
                pred = self._stacked_predictions(dynamics2, obs_actions_flat, n_cluster, t, b)
                dist = self._standardized_task_errors(pred, dyna_flat, t, b)
                closest = dist.argmin(0)  # (t)
                closest_list.append(closest)
                cluster_prob_list.append(torch_eye[closest].permute(1, 0))  # (n_cl, t)
        return closest_list, cluster_prob_list

    def _recycle_unused_clusters(self, dynamics1, dynamics2, closest, used, n_cluster):
        """Track recent cluster usage and re-seed clusters that went unused.

        ``used`` is this cluster count's history list (an entry of
        ``self.used_list``). It is extended with ``closest`` and truncated *in
        place* to the last ``n_cluster * 20`` assignments. Once that window is
        full, every cluster missing from it has its models overwritten by those
        of a randomly chosen active cluster.
        """
        window = n_cluster * 20
        used.extend(closest.tolist())
        # Truncate in place: rebinding the local name would leave the list held
        # by self.used_list growing without bound.
        del used[:-window]
        if len(used) < window:
            return
        used_set = set(used)
        active = [cluster for cluster in range(n_cluster) if cluster in used_set]
        inactive = [cluster for cluster in range(n_cluster) if cluster not in used_set]
        for cluster in inactive:
            # NOTE(review): the two model sets pick independent donors, so a
            # recycled cluster may combine models from different clusters.
            dynamics1[cluster].load_state_dict(dynamics1[random.choice(active)].state_dict())
            dynamics2[cluster].load_state_dict(dynamics2[random.choice(active)].state_dict())

    def _train_cluster_models(self, closest_list, cluster_prob_list, obs_flat, actions_flat,
                              obs_actions_flat, dyna_flat, t, b):
        """Recycle unused clusters, then take one optimizer step on both model sets.

        This is done for every cluster count. Each task's prediction loss is
        routed only to its assigned cluster through the one-hot
        ``cluster_prob`` weights.
        """
        n_sets = len(self.dynamics_list[0])
        # dynamics_optimizer_list = [dynamics_list[0] optimizers..., dynamics_list[1] optimizers...]
        optimizers1 = self.dynamics_optimizer_list[:n_sets]
        optimizers2 = self.dynamics_optimizer_list[n_sets:]
        for dynamics1, dynamics2, closest, cluster_prob, used, n_cluster, optimizer1, optimizer2 in zip(
                self.dynamics_list[0], self.dynamics_list[1], closest_list, cluster_prob_list,
                self.used_list, n_cluster_list, optimizers1, optimizers2):
            self._recycle_unused_clusters(dynamics1, dynamics2, closest, used, n_cluster)

            optimizer1.zero_grad()
            optimizer2.zero_grad()
            weights = cluster_prob.unsqueeze(2).unsqueeze(3)  # (n_cl, t, 1, 1)

            pred = (self._stacked_predictions(dynamics1, obs_flat, n_cluster, t, b) * weights).sum(0)  # (t, b, a)
            dyna_loss1 = self.dynamics_criterion(pred.view(t * b, -1), actions_flat)

            pred = (self._stacked_predictions(dynamics2, obs_actions_flat, n_cluster, t, b) * weights).sum(0)
            dyna_loss2 = self.dynamics_criterion(pred.view(t * b, -1), dyna_flat)

            (dyna_loss1 + dyna_loss2).backward()
            # Both sets must be stepped; previously only the dynamics_list[0]
            # optimizers were reached and the dynamics_list[1] models never learned.
            optimizer1.step()
            optimizer2.step()

    def _q_target(self, rewards, terms, target_v_values, num_tasks):
        """Bellman target ``scaled_r + (1 - done) * discount * V_target(s')``.

        Rewards and terminals are flattened to ``batch_size * num_tasks`` rows.
        """
        rewards_flat = rewards.view(self.batch_size * num_tasks, -1)
        # scale rewards for Bellman update
        rewards_flat = rewards_flat * self.reward_scale
        terms_flat = terms.view(self.batch_size * num_tasks, -1)
        return rewards_flat + (1. - terms_flat) * self.discount * target_v_values

    def _policy_loss(self, policy_outputs, log_pi, policy_mean, policy_log_std, min_q_new_actions):
        """SAC policy loss ``mean(log_pi - min_Q)`` plus regularizers.

        The regularizers penalize the mean, the log-std and the pre-tanh values.
        ``policy_outputs[-1]`` is taken as the pre-tanh value.
        """
        # n.b. policy update includes dQ/da
        policy_loss = (log_pi - min_q_new_actions).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        return policy_loss + policy_reg_loss

    def _take_step(self, indices, context):
        """Run one meta-training step for the main agent.

        The step covers cluster assignment and cluster-model training, then SAC
        updates in which the critic loss also carries the cluster z-loss. The
        encoder receives gradients only from that critic loss.
        """
        # cluster count whose assignment supplies the z-loss pseudo-labels
        now_n_cluster = np.random.randint(len(n_cluster_list))

        num_tasks = len(indices)

        # data is (task, batch, feat)
        obs, actions, rewards, next_obs, terms = self.sample_sac(indices)

        # run inference in networks
        policy_outputs, task_z = self.agent(obs, context)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

        # flattens out the task dimension
        t, b, _ = obs.size()

        # cluster models are fit on, and assign clusters from, encoder-buffer data
        obs_e, actions_e, rewards_e, next_obs_e, _ = self.sample_sac_enc(indices)
        obs_flat = obs_e.view(t * b, -1)
        actions_flat = actions_e.view(t * b, -1)
        next_obs_flat = next_obs_e.view(t * b, -1)
        enc_rewards_flat = rewards_e.view(t * b, -1)
        obs_actions_flat = torch.cat([obs_flat, actions_flat], -1)
        dyna_flat = enc_rewards_flat if self.dyna_mode == 1 else torch.cat([enc_rewards_flat, next_obs_flat], -1)

        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)

        closest_list, cluster_prob_list = self._assign_clusters(obs_actions_flat, dyna_flat, t, b)
        self._train_cluster_models(closest_list, cluster_prob_list, obs_flat, actions_flat,
                                   obs_actions_flat, dyna_flat, t, b)

        # cluster ids act as pseudo task labels in the z-loss
        cluster_labels = closest_list[now_n_cluster].tolist()

        # Q and V networks
        # encoder will only get gradients from Q nets (and the z-loss)
        q1_pred = self.qf1(obs, actions, task_z)
        q2_pred = self.qf2(obs, actions, task_z)
        v_pred = self.vf(obs, task_z.detach())
        # get targets for use in V and Q updates
        with torch.no_grad():
            target_v_values = self.target_vf(next_obs, task_z)

        # NOTE: the PEARL KL term on z is currently disabled (kl_lambda unused).
        z_loss = self.z_loss_weight * self.z_loss(indices=cluster_labels, task_z=task_z, b=b)

        # qf and encoder update (note encoder does not get grads from policy or vf)
        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
        self.context_optimizer.zero_grad()
        q_target = self._q_target(rewards, terms, target_v_values, num_tasks)
        qf_loss = torch.mean((q1_pred - q_target) ** 2) + torch.mean((q2_pred - q_target) ** 2)
        qf_loss = qf_loss + z_loss
        qf_loss.backward()
        self.qf1_optimizer.step()
        self.qf2_optimizer.step()
        self.context_optimizer.step()

        # compute min Q on the new actions
        min_q_new_actions = self._min_q(obs, new_actions, task_z)

        # vf update
        v_target = min_q_new_actions - log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        self._update_target_network()

        # policy update
        policy_loss = self._policy_loss(policy_outputs, log_pi, policy_mean, policy_log_std, min_q_new_actions)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # save some statistics for eval
        if self.eval_statistics is None:
            # eval should set this to None.
            # this way, these statistics are only computed for one batch.
            self.eval_statistics = OrderedDict()
            if self.use_information_bottleneck:
                z_mean = np.mean(np.abs(ptu.get_numpy(self.agent.z_means[0])))
                z_sig = np.mean(ptu.get_numpy(self.agent.z_vars[0]))
                self.eval_statistics['Z mean train'] = z_mean
                self.eval_statistics['Z variance train'] = z_sig

            # 'QF Loss' includes the weighted z-loss
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'V Predictions',
                ptu.get_numpy(v_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))

    def _sac_step_no_context(self, indices, agent, qf1, qf2, vf, target_vf,
                             policy_optimizer, qf1_optimizer, qf2_optimizer, vf_optimizer,
                             min_q_fn, update_target_fn):
        """Run one SAC update for an exploration agent on a fresh ``sample_sac`` batch.

        The update covers both Q heads, then V (with a target soft-update via
        ``update_target_fn``), then the policy. The agent is run through
        ``forward_no_context``.
        """
        num_tasks = len(indices)

        # data is (task, batch, feat)
        obs, actions, rewards, next_obs, terms = self.sample_sac(indices)

        # run inference in networks
        policy_outputs, task_z = agent.forward_no_context(obs)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

        # flattens out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)

        # Q and V networks
        q1_pred = qf1(obs, actions, task_z)
        q2_pred = qf2(obs, actions, task_z)
        v_pred = vf(obs, task_z.detach())
        # get targets for use in V and Q updates
        with torch.no_grad():
            target_v_values = target_vf(next_obs, task_z)

        qf1_optimizer.zero_grad()
        qf2_optimizer.zero_grad()
        q_target = self._q_target(rewards, terms, target_v_values, num_tasks)
        qf_loss = torch.mean((q1_pred - q_target) ** 2) + torch.mean((q2_pred - q_target) ** 2)
        qf_loss.backward()
        qf1_optimizer.step()
        qf2_optimizer.step()

        # compute min Q on the new actions
        min_q_new_actions = min_q_fn(obs, new_actions, task_z)

        # vf update
        v_target = min_q_new_actions - log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        vf_optimizer.zero_grad()
        vf_loss.backward()
        vf_optimizer.step()
        update_target_fn()

        # policy update
        policy_loss = self._policy_loss(policy_outputs, log_pi, policy_mean, policy_log_std, min_q_new_actions)
        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()

    def _take_step_expl(self, indices, context):
        """Take one SAC step for exploration agent 1, then one for agent 2.

        Each step uses its own freshly sampled batch. ``context`` is not used,
        because the exploration agents act through ``forward_no_context``.
        """
        self._sac_step_no_context(
            indices, self.expl_agent1, self.qf1_expl1, self.qf2_expl1, self.vf_expl1, self.target_vf_expl1,
            self.expl_policy_optimizer1, self.expl_qf1_optimizer1, self.expl_qf2_optimizer1,
            self.expl_vf_optimizer1, self._min_q_expl1, self._update_target_network_expl1,
        )
        self._sac_step_no_context(
            indices, self.expl_agent2, self.qf1_expl2, self.qf2_expl2, self.vf_expl2, self.target_vf_expl2,
            self.expl_policy_optimizer2, self.expl_qf1_optimizer2, self.expl_qf2_optimizer2,
            self.expl_vf_optimizer2, self._min_q_expl2, self._update_target_network_expl2,
        )

    def get_epoch_snapshot(self, epoch):
        """Return state dicts of the main critics, policy and context encoder.

        NOTE: this overrides the parent method, which can also save the env.
        Cluster models and exploration agents are not included.
        """
        snapshot = OrderedDict(
            qf1=self.qf1.state_dict(),
            qf2=self.qf2.state_dict(),
            policy=self.agent.policy.state_dict(),
            vf=self.vf.state_dict(),
            target_vf=self.target_vf.state_dict(),
            context_encoder=self.agent.context_encoder.state_dict(),
        )
        return snapshot
