import torch
import copy 

import numpy as np
import torch.nn.functional as F
import torch.nn as nn

from tqdm import tqdm

from sgcrl.models.quantizer import special_tokens
from sgcrl.utils.imports import instantiate_class, get_arguments, get_class
from sgcrl.data.torch_datasets.relabel import image_data_augmentation
from sgcrl.evaluation.d4rl_evaluation import serial_evaluation_loop
from sgcrl.models.dual_policy import DualPolicy

# Integer identifiers named after three goal kinds; GOAL_TYPES lists them in order.
CURRENT, FUTURE, RANDOM = range(3)
GOAL_TYPES = [CURRENT, FUTURE, RANDOM]

#############
# quantizer #
#############
def train_quantizer(quantizer, optimizer_quantizer, epochs, dataloader, device, model_db, logger, offset, norm, contrastive_coef, commit_coef, 
                    reconstruction_coef, save_every=100, use_visual_data_augmentation=False, p_aug=0.5, square_rotation=False, vertical_flip=False,
                    horizontal_flip=False, padded_random_crop=False, padding_size=4, reshaped_random_crop=False, crop_size=(16,16), grad_max_norm=None, noise=0):
    """
    Trains `quantizer` on (sample, positive, negative) triplets.

    For every batch the three inputs are perturbed with `noise * torch.rand_like(x)`,
    moved to `device`, rescaled as `(x + offset) / norm` and, optionally, passed
    through `image_data_augmentation`. The optimized loss is

        commit_coef * commit + contrastive_coef * triplet_margin + reconstruction_coef * mse

    where each term is summed over the three inputs (the triplet margin loss uses
    the latents returned by the quantizer).

    Args:
        quantizer: module called as `quantizer(x, return_latent=True)` and expected
            to return `(reconstruction, commit_loss, latent)`.
        optimizer_quantizer: optimizer stepping the quantizer parameters.
        epochs: number of passes over `dataloader`.
        dataloader: iterable yielding `(sample, positive, negative)` tensors.
        device: device the batches and `norm` are placed on.
        model_db: object with a `push(name, model)` method; receives a CPU copy
            of the quantizer under the name "quantizer" every `save_every` epochs.
        logger: object with `add_scalar(name, value, step)`; called once per epoch.
        offset, norm: affine rescaling applied to the inputs; an empty `norm`
            is replaced by `[1.0]`.
        contrastive_coef, commit_coef, reconstruction_coef: loss weights.
        save_every: checkpoint period, in epochs.
        use_visual_data_augmentation, p_aug, square_rotation, vertical_flip,
        horizontal_flip, padded_random_crop, padding_size, reshaped_random_crop,
        crop_size: forwarded unchanged to `image_data_augmentation`.
        grad_max_norm: if not None, gradients are clipped to this norm.
        noise: scale of the additive noise.

    Returns:
        The trained quantizer (same object that was passed in).
    """
    quantizer.train()

    # An empty `norm` means "divide by one".
    norm = torch.tensor(norm if len(norm) != 0 else [1.0], device=device)

    def _augment(x):
        return image_data_augmentation(p_aug, x, square_rotation, vertical_flip, horizontal_flip, padded_random_crop, padding_size, reshaped_random_crop, crop_size)

    for epoch in range(1, epochs + 1):
        # Per-epoch histories; keys double as the progress-bar postfix names.
        history = {"loss": [], "commit": [], "recon": [], "contrastive": []}
        with tqdm(dataloader, desc=f"Epoch {epoch}:") as pbar:
            for sample, positive, negative in pbar:

                # data augmentation (noise is drawn before the transfer to `device`,
                # always in the order sample, positive, negative)
                triplet = [x + noise * torch.rand_like(x) for x in (sample, positive, negative)]
                triplet = [x.to(device) for x in triplet]
                triplet = [(x + offset) / norm for x in triplet]

                if use_visual_data_augmentation:
                    inputs, masks = [], []
                    for x in triplet:
                        x_aug, x_mask = _augment(x)
                        inputs.append(x_aug)
                        masks.append(x_mask)
                else:
                    inputs = triplet
                    masks = [torch.ones_like(x) for x in inputs]

                # update
                reconstructions, commits, latents = [], [], []
                for x in inputs:
                    quantized, commit, latent = quantizer(x, return_latent=True)
                    reconstructions.append(quantized)
                    commits.append(commit)
                    latents.append(latent)

                commit_loss = commits[0] + commits[1] + commits[2]
                recon_terms = [F.mse_loss(x, rec * mask) for x, rec, mask in zip(inputs, reconstructions, masks)]
                reconstruction_loss = recon_terms[0] + recon_terms[1] + recon_terms[2]
                contrastive_loss = F.triplet_margin_loss(latents[0], positive=latents[1], negative=latents[2])

                loss = commit_coef * commit_loss + contrastive_coef * contrastive_loss + reconstruction_coef * reconstruction_loss

                optimizer_quantizer.zero_grad(set_to_none=True)
                loss.backward()
                if grad_max_norm is not None:
                    torch.nn.utils.clip_grad_norm_(quantizer.parameters(), grad_max_norm)
                optimizer_quantizer.step()

                history["loss"].append(loss.item())
                history["commit"].append(commit_loss.item())
                history["recon"].append(reconstruction_loss.item())
                history["contrastive"].append(contrastive_loss.item())
                pbar.set_postfix(**{name: np.mean(values) for name, values in history.items()})

            # Periodic checkpoint: a CPU copy so the training module stays on `device`.
            if epoch % save_every == 0:
                snapshot = copy.deepcopy(quantizer).to("cpu")
                model_db.push("quantizer", snapshot)

            for tag, key in (
                ("train/loss", "loss"),
                ("train/commit_loss", "commit"),
                ("train/reconstruction_loss", "recon"),
                ("train/contrastive_loss", "contrastive"),
            ):
                logger.add_scalar(tag, np.mean(history[key]), epoch)

    return quantizer

###############
# transformer #
###############
def train_transformer(transformer, optimizer, epochs, train_dataloader, val_dataloader, special_tokens, model_db, logger, save_every, device):
    """
    Trains `transformer` with a cross-entropy loss on token sequences and
    evaluates it on a validation loader after every epoch.

    Each batch is a dict with the keys 'goal/token' and 'token/compressed'.
    The model is called as `transformer(input_seq, target_seq[:, :-1])` and its
    output is scored against `target_seq[:, 1:]`; positions equal to
    `special_tokens['PADDING_VALUE']` are ignored in both the loss and the
    accuracy.

    Args:
        transformer: model to train.
        optimizer: optimizer over the transformer parameters.
        epochs: number of epochs.
        train_dataloader, val_dataloader: iterables of batch dicts.
        special_tokens: mapping providing 'PADDING_VALUE'.
        model_db: object with `push(name, model)`; receives a CPU copy named
            "transformer" every `save_every` epochs.
        logger: object with `add_scalar(name, value, step)`.
        save_every: checkpoint period, in epochs.
        device: device the token tensors are moved to.

    Returns:
        The trained transformer (same object that was passed in).
    """
    pad_value = special_tokens['PADDING_VALUE']
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_value)

    def _forward(batch):
        # goal token + sos token + initial token
        prompt = torch.cat((batch['goal/token'][:, 0], batch['token/compressed'][:, :2]), dim=1).to(device).long()
        tokens = batch['token/compressed'].to(device).long()
        logits = transformer(prompt, tokens[:, :-1])
        targets = tokens[:, 1:]
        return logits, targets, loss_fn(logits, targets)

    def _masked_accuracy(logits, targets):
        # Class scores live on dim 1; padded positions are excluded.
        keep = (targets != pad_value)
        predicted = logits.argmax(dim=1)
        return (predicted[keep] == targets[keep]).float().mean().detach().cpu()

    for epoch in range(1, epochs + 1):
        # ---- training pass ----
        transformer.train()
        train_losses, train_accs = [], []
        with tqdm(train_dataloader, desc=f'Train {epoch}:') as pbar:
            for batch in pbar:
                logits, targets, loss = _forward(batch)

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                batch_acc = _masked_accuracy(logits, targets)
                train_losses.append(loss.item())
                train_accs.append(batch_acc)

                pbar.set_postfix(loss=np.mean(train_losses), accuracy=np.mean(train_accs))

        logger.add_scalar("train/loss", np.mean(train_losses), epoch)
        logger.add_scalar("train/acc", np.mean(train_accs), epoch)

        # ---- validation pass (no gradients, no optimizer step) ----
        transformer.eval()
        val_losses, val_accs = [], []
        with torch.no_grad():
            with tqdm(val_dataloader, desc=f'Val   {epoch}:') as pbar:
                for batch in pbar:
                    logits, targets, loss = _forward(batch)
                    batch_acc = _masked_accuracy(logits, targets)

                    val_losses.append(loss.item())
                    val_accs.append(batch_acc)

                    pbar.set_postfix(loss=np.mean(val_losses), accuracy=np.mean(val_accs))

        logger.add_scalar("val/loss", np.mean(val_losses), epoch)
        logger.add_scalar("val/acc", np.mean(val_accs), epoch)

        if epoch % save_every == 0:
            snapshot = copy.deepcopy(transformer).to("cpu")
            model_db.push("transformer", snapshot)

    return transformer

def expectile_loss(adv, diff, expectile):
    """
    Element-wise asymmetric squared error.

    The squared `diff` is scaled by `expectile` wherever `adv >= 0` and by
    `1 - expectile` elsewhere. No reduction is applied.
    """
    non_negative = adv >= 0
    squared_error = diff ** 2
    scale = torch.where(non_negative, expectile, 1 - expectile)
    return scale * squared_error

#########
# qphil #
#########
def train_iql_goal(
    policy_cfg,
    vf_cfg,
    qf_cfg,
    optimizer_cfg,
    device,
    dataloader,
    max_gradient_step,
    reward_scale,
    discount,
    expectile,
    beta,
    clip_score,
    q_update_period,
    policy_update_period,
    target_update_period,
    polyak_coef,
    model_db,
    save_every,
    logger,
    serial_evaluation,
    cfg_evaluation,
    evaluation_logger
):
    """
    Trains pi(a|s,g) with iql.

    A policy, a value function and two Q-functions are instantiated from their
    configs (the Q-functions are also deep-copied into target networks), each
    with its own optimizer built from `optimizer_cfg`. Training iterates over
    `dataloader` until `max_gradient_step` batches have been processed.

    Args:
        policy_cfg, vf_cfg, qf_cfg: configs passed to `instantiate_class`.
        optimizer_cfg: config giving the optimizer class and its arguments.
        device: device the networks and batches are placed on.
        dataloader: iterable of dict batches with the keys 'observations',
            'goals', 'actions', 'next_observations' and 'rewards'.
        max_gradient_step: number of processed batches after which training stops.
        reward_scale, discount: scaling of the reward and bootstrap terms.
        expectile: expectile used for the value regression.
        beta: the advantage is divided by `beta` before exponentiation.
        clip_score: upper clamp on the exponentiated advantage.
        q_update_period, policy_update_period, target_update_period: the
            critics / policy / target networks are updated when the step
            counter is a multiple of the respective period.
        polyak_coef: interpolation weight of the online network in the soft update.
        model_db: object with `push(name, model)`; receives a CPU copy of the
            policy named 'low_goal_policy' every `save_every` steps.
        save_every: checkpoint period, in gradient steps.
        logger: object with `add_scalar(name, value, step)`; may be None.
        serial_evaluation: if true, `serial_evaluation_loop` runs after each save.
        cfg_evaluation, evaluation_logger: forwarded to `serial_evaluation_loop`.

    Returns:
        The trained policy.
    """

    def _polyak_update(target_net, online_net):
        # target <- polyak_coef * online + (1 - polyak_coef) * target
        for target_param, param in zip(target_net.parameters(), online_net.parameters()):
            target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

    # Instantiate models
    policy = instantiate_class(policy_cfg).to(device)
    vf = instantiate_class(vf_cfg).to(device)
    qf1 = instantiate_class(qf_cfg).to(device)
    qf2 = instantiate_class(qf_cfg).to(device)
    target_qf1 = copy.deepcopy(qf1)
    target_qf2 = copy.deepcopy(qf2)

    # Instantiate optimizers
    optimizer_cls = get_class(optimizer_cfg)
    optimizer_args = get_arguments(optimizer_cfg)
    policy_optimizer = optimizer_cls(policy.parameters(), **optimizer_args)
    vf_optimizer = optimizer_cls(vf.parameters(), **optimizer_args)
    qf1_optimizer = optimizer_cls(qf1.parameters(), **optimizer_args)
    qf2_optimizer = optimizer_cls(qf2.parameters(), **optimizer_args)

    # Training
    epoch, gradient_step, done = 0, 0, False
    while not done:

        # Set epoch metrics
        epoch += 1

        total_vf_loss = 0.0
        total_qf1_loss = 0.0
        total_qf2_loss = 0.0
        total_policy_loss = 0.0

        abs_adv_mean = 0.0
        adv_mean = 0.0
        adv_max = 0.0
        adv_min = 0.0
        accept_prob = 0.0
        n_batches = 0

        with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}:") as pbar:
            for it, batch in pbar:

                # The networks live on `device`, so the batch must too.
                batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}

                # Get data
                observations = batch['observations']
                goals = batch['goals']
                actions = batch['actions']
                next_observations = batch['next_observations']
                rewards = batch['rewards'] - 1.0
                # NOTE(review): the mask is derived from the already shifted
                # rewards, so it equals `2 - batch['rewards']`. That is 0/1
                # ("0 if terminal, 1 otherwise") only if the raw rewards are 2/1.
                # batch['dones'] may be the intended source - confirm against the
                # dataset before changing.
                masks = 1.0 - rewards

                # qf_loss: regress Q(s, g, a) onto r + mask * discount * V(s', g)
                target_vf_pred = vf(next_observations, goals).detach()
                q_target = (reward_scale * rewards + masks * discount * target_vf_pred).detach()
                qf1_loss = F.mse_loss(qf1(observations, goals, actions), q_target)
                qf2_loss = F.mse_loss(qf2(observations, goals, actions), q_target)

                # vf_loss: expectile regression of V(s, g) onto min of the target Qs
                q_pred = torch.min(target_qf1(observations, goals, actions), target_qf2(observations, goals, actions)).detach()
                vf_pred = vf(observations, goals)
                vf_err = vf_pred - q_pred
                vf_sign = (vf_err > 0).float()
                vf_weight = (1 - vf_sign) * expectile + vf_sign * (1 - expectile)
                vf_loss = (vf_weight * (vf_err ** 2)).mean()

                # policy_loss: the advantage only provides (constant) weights,
                # so it is detached from the value network.
                adv = (q_pred - vf_pred).detach()
                exp_adv = torch.clamp(torch.exp(adv / beta), max=clip_score)
                weights = exp_adv[:, 0]
                policy_loss = policy.compute_loss(batch, weights)

                # Update networks
                if gradient_step % q_update_period == 0:
                    qf1_optimizer.zero_grad()
                    qf1_loss.backward()
                    qf1_optimizer.step()

                    qf2_optimizer.zero_grad()
                    qf2_loss.backward()
                    qf2_optimizer.step()

                    vf_optimizer.zero_grad()
                    vf_loss.backward()
                    vf_optimizer.step()

                if gradient_step % policy_update_period == 0:
                    policy_optimizer.zero_grad()
                    policy_loss.backward()
                    policy_optimizer.step()

                # Soft updates
                if gradient_step % target_update_period == 0:
                    _polyak_update(target_qf1, qf1)
                    _polyak_update(target_qf2, qf2)

                # Exactly one increment per processed batch.
                gradient_step += 1
                n_batches += 1

                # Update metrics for logging
                abs_adv_mean += adv.abs().mean().item()
                adv_mean += adv.mean().item()
                adv_max += adv.max().item()
                adv_min += adv.min().item()
                accept_prob += (adv >= 0).float().mean().item()

                total_vf_loss += vf_loss.item()
                total_qf1_loss += qf1_loss.item()
                total_qf2_loss += qf2_loss.item()
                total_policy_loss += policy_loss.item()
                pbar.set_postfix(
                    vf_loss=total_vf_loss/(it+1),
                    qf1_loss=total_qf1_loss/(it+1),
                    qf2_loss=total_qf2_loss/(it+1),
                    policy_loss=total_policy_loss/(it+1))

                # Saving
                if gradient_step % save_every == 0:
                    m = copy.deepcopy(policy).to('cpu')
                    model_db.push('low_goal_policy', m)

                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)
                # Done
                if gradient_step >= max_gradient_step:
                    done = True
                    break

        # Epoch-level logs, averaged over the batches actually processed.
        if logger is not None and n_batches > 0:
            logger.add_scalar('vf_loss/train', total_vf_loss / n_batches, gradient_step)
            logger.add_scalar('qf1_loss/train', total_qf1_loss / n_batches, gradient_step)
            logger.add_scalar('qf2_loss/train', total_qf2_loss / n_batches, gradient_step)
            logger.add_scalar('policy_loss/train', total_policy_loss / n_batches, gradient_step)
            logger.add_scalar('abs_adv_mean', abs_adv_mean / n_batches, gradient_step)
            logger.add_scalar('adv_mean', adv_mean / n_batches, gradient_step)
            logger.add_scalar('adv_max', adv_max / n_batches, gradient_step)
            logger.add_scalar('adv_min', adv_min / n_batches, gradient_step)
            logger.add_scalar('accept_prob', accept_prob / n_batches, gradient_step)

    return policy

def train_low_hiql_goal(
    model,
    optimizer,
    dataloader,
    max_gradient_step,
    reward_scale,
    discount,
    expectile,
    beta,
    clip_score,
    v_update_period,
    policy_update_period,
    target_update_period,
    polyak_coef,
    save_every,
    model_db,
    logger,
    serial_evaluation,
    cfg_evaluation,
    evaluation_logger,
    device
):  
    """
    Trains pi(a|s,g) with low policy loss of hiql.

    NOTE(review): a second function with this exact name is defined further
    down in this module (the subgoal-conditioned variant). Being defined later,
    it replaces this one in the module namespace, so this definition cannot be
    reached by name - confirm which of the two callers expect.

    Per batch, two value networks (`model.vf1`, `model.vf2`) are fitted with an
    expectile loss against targets built from `model.target_vf1/2`, and the
    policy is fitted with a log-likelihood loss weighted by a clamped
    exponentiated advantage. Target networks follow the online ones through a
    Polyak average.

    Args:
        model: object exposing `vf1`, `vf2`, `target_vf1`, `target_vf2` and
            `policy_high`.
        optimizer: object exposing the optimizers `vf1`, `vf2`, `policy_low`
            and `policy_high`.
        dataloader: iterable of dict batches; every value is indexed with
            `[:, 0]` (current step) and `[:, 1]` (next step).
        max_gradient_step: training stops once the step counter reaches it.
        reward_scale, discount: scaling of the reward and bootstrap terms.
        expectile: expectile passed to `expectile_loss`.
        beta: the policy advantage is multiplied by `beta` before `exp`.
        clip_score: upper clamp on the exponentiated advantage.
        v_update_period, policy_update_period, target_update_period: update
            periods, in gradient steps.
        polyak_coef: weight of the online parameters in the soft update.
        save_every: checkpoint period, in gradient steps.
        model_db: object with `push(name, model)`.
        logger: object with `add_scalar` and `add_histogram`.
        serial_evaluation, cfg_evaluation, evaluation_logger: control the call
            to `serial_evaluation_loop` after each checkpoint.
        device: device the batches are moved to.

    Returns:
        The trained model (same object that was passed in).
    """
    # Training
    target_vf1 = model.target_vf1
    target_vf2 = model.target_vf2
    # The step counter starts at 1, so `max_gradient_step - 1` batches are processed.
    epoch, gradient_steps, done = 0, 1, False
    while not done:
        epoch += 1

        model.train()
        # Per-epoch accumulators (reset at the start of every epoch).
        total_vf1_loss = 0
        total_vf2_loss = 0
        policy_losses = []

        abs_adv_mean = 0
        adv_mean = 0
        adv_max = 0
        adv_min = 0
        accept_prob = 0
        v1_min = 0
        v1_max = 0
        v1_mean = 0
        v2_min = 0
        v2_max = 0
        v2_mean = 0
        log_prob_loss = 0
        weights = None
        with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}:") as pbar:
            for it, batch in pbar:

                # Form batch: index 0 along dim 1 is the current step, index 1 the next one.
                batch = {k: v.to(device) for k, v in batch.items()} 
                batch_next = {k: v[:, 1] for k, v in batch.items()}
                batch = {k: v[:, 0] for k, v in batch.items()}

                # Get data
                observations = batch['obs/complete']
                goals = batch['goal']
                actions = batch['action']
                next_observations = batch_next['obs/complete']
                next_goals = batch_next['goal']
                rewards = batch['reward'].unsqueeze(1) - 1.0 
                # NOTE(review): computed from the already shifted rewards, this is
                # `2 - batch['reward']`; it matches the comment below only if the
                # raw rewards are 2 (terminal) / 1 (otherwise) - confirm.
                masks = 1 - rewards # masks are 0 if terminal, 1 otherwise

                # vf_loss
                next_v1_t, next_v2_t = target_vf1(next_observations,next_goals), target_vf2(next_observations,next_goals)
                next_v_t = torch.min(next_v1_t, next_v2_t).detach()

                # Advantage from the target networks; its sign selects the expectile weight.
                q_t = (reward_scale * rewards + masks * discount * next_v_t)
                v_t = (target_vf1(observations,goals) + target_vf2(observations,goals)) / 2
                adv = (q_t - v_t).detach()

                # Each online value network regresses onto its own target network's bootstrap.
                q1_t = reward_scale * rewards + masks * discount * next_v1_t
                q2_t = reward_scale * rewards + masks * discount * next_v2_t

                v1, v2 = model.vf1(observations,goals), model.vf2(observations,goals)
                vf1_loss = expectile_loss(adv, q1_t.detach() - v1, expectile).mean()
                vf2_loss = expectile_loss(adv, q2_t.detach() - v2, expectile).mean()
                
                # policy_loss
                vf_pred = (model.vf1(observations,goals) + model.vf2(observations,goals)) / 2
                next_vf1 = model.vf1(next_observations,next_goals)
                next_vf2 = model.vf2(next_observations,next_goals)
                next_vf_pred = (next_vf1 + next_vf2) / 2
                
                # `adv` is rebound here: from now on it is V(s') - V(s) from the online
                # networks, and this is the value the logging metrics below see.
                adv = (next_vf_pred - vf_pred)
                exp_adv = torch.exp(adv * beta)
                exp_adv = torch.clamp(exp_adv, max=clip_score)
                # Detached: the weights carry no gradient into the value networks.
                weights = exp_adv[:, 0].detach()

                # The loss is computed from `model.policy_high` only, while both the
                # `policy_low` and `policy_high` optimizers are stepped below.
                _, dist = model.policy_high(observations,goals)
                log_prob = dist.log_prob(actions)

                # Summed (not averaged) over the batch.
                policy_loss = (-log_prob * weights).sum()

                # Update networks
                if gradient_steps % v_update_period == 0:
                        optimizer.vf1.zero_grad(set_to_none=True)
                        vf1_loss.backward()
                        optimizer.vf1.step()

                        optimizer.vf2.zero_grad(set_to_none=True)
                        vf2_loss.backward()
                        optimizer.vf2.step()

                if gradient_steps % policy_update_period == 0:
                    optimizer.policy_low.zero_grad(set_to_none=True)
                    optimizer.policy_high.zero_grad(set_to_none=True)
                    policy_loss.backward()
                    optimizer.policy_low.step()
                    optimizer.policy_high.step()

                # Soft updates: target <- polyak_coef * online + (1 - polyak_coef) * target
                if gradient_steps % target_update_period == 0:
                    for target_param, param in zip(target_vf1.parameters(), model.vf1.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                    for target_param, param in zip(target_vf2.parameters(), model.vf2.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                # Update metrics for logging 
                abs_adv_mean += torch.abs(adv).cpu().detach().mean()
                adv_mean += adv.cpu().detach().mean()
                adv_max += adv.cpu().detach().max()
                adv_min += adv.cpu().detach().min()
                accept_prob += (adv >= 0).float().cpu().detach().mean()
                v1_min += v1.cpu().detach().min()
                v1_max += v1.cpu().detach().max()
                v1_mean += v1.cpu().detach().mean()
                v2_min += v2.cpu().detach().min()
                v2_max += v2.cpu().detach().max()
                v2_mean += v2.cpu().detach().mean()
                log_prob_loss += log_prob.cpu().detach().mean()

                gradient_steps += 1

                total_vf1_loss += vf1_loss.item()
                total_vf2_loss += vf2_loss.item()
                policy_losses.append(policy_loss.item())
                pbar.set_postfix(vf1_loss=total_vf1_loss/(it+1), 
                                 vf2_loss=total_vf2_loss/(it+1),
                                 policy_loss=np.mean(policy_losses))
                
                # Checkpoint a CPU copy of the whole model, optionally followed by an evaluation.
                if gradient_steps % save_every == 0:
                    m = copy.deepcopy(model).to('cpu')
                    model_db.push('low_goal_policy', m)
                    
                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)
                
                if gradient_steps >= max_gradient_step:
                    done = True
                    break

        # Epoch-level logs.
        # NOTE(review): the two value losses logged here are those of the last
        # batch (`total_vf*_loss` is only used for the progress bar), and the
        # accumulators are divided by len(dataloader) even when the epoch was
        # cut short by `max_gradient_step`.
        logger.add_scalar('vf1_loss/train', vf1_loss.item(), gradient_steps)
        logger.add_scalar('vf2_loss/train', vf2_loss.item(), gradient_steps)
        logger.add_scalar('policy_loss/train', np.mean(policy_losses), gradient_steps)
        logger.add_scalar('abs_adv_mean', abs_adv_mean/len(dataloader), gradient_steps)
        logger.add_scalar('adv_mean', adv_mean/len(dataloader), gradient_steps)
        logger.add_scalar('adv_max', adv_max/len(dataloader), gradient_steps)
        logger.add_scalar('adv_min', adv_min/len(dataloader), gradient_steps)
        logger.add_scalar('accept_prob', accept_prob/len(dataloader), gradient_steps)
        logger.add_scalar('v1_min', v1_min/len(dataloader), gradient_steps)
        logger.add_scalar('v1_max', v1_max/len(dataloader), gradient_steps)
        logger.add_scalar('v1_mean', v1_mean/len(dataloader), gradient_steps)
        logger.add_scalar('v2_min', v2_min/len(dataloader), gradient_steps)
        logger.add_scalar('v2_max', v2_max/len(dataloader), gradient_steps)
        logger.add_scalar('v2_mean', v2_mean/len(dataloader), gradient_steps)
        logger.add_scalar('log_prob_loss', log_prob_loss/len(dataloader), gradient_steps)
        # Histogram of the policy weights of the last processed batch.
        logger.add_histogram(name='Histogram_awr', values=weights.cpu().detach(), epoch=gradient_steps)

    return model

def train_gcbc_goal(
    model,
    optimizer,
    dataloader,
    max_gradient_step,
    save_every,
    save_as_policy,
    model_db,
    logger,
    serial_evaluation,
    cfg_evaluation,
    evaluation_logger,
    device,
    is_godot
): 
    """
    Trains pi(a|s,g) with gcbc.

    The loss is the negative mean log-probability of the dataset actions.
    Training loops over `dataloader` until the step counter (which starts at 1)
    reaches `max_gradient_step`.

    Args:
        model: object exposing `policy_high`. With `is_godot` the loss uses
            `model.policy_high.log_prob(batch)`; otherwise
            `model.policy_high(observations, goals)` must return a pair whose
            second element has a `log_prob(actions)` method.
        optimizer: object exposing the optimizers `policy_low` and `policy_high`.
        dataloader: iterable of dict batches of tensors; index 0 along dim 1
            of every entry is used.
        max_gradient_step: step count at which training stops; also the total
            of the progress bar.
        save_every: checkpoint period, in gradient steps.
        save_as_policy: if true, the checkpoint is also pushed under "model".
        model_db: object with `push(name, model)`.
        logger: object with `add_scalar(name, value, step)`; called once per epoch.
        serial_evaluation, cfg_evaluation, evaluation_logger: control the call
            to `serial_evaluation_loop` after each checkpoint.
        device: device the batches are moved to.
        is_godot: selects how the log-probability is computed (see `model`).

    Returns:
        The trained model (same object that was passed in).
    """
    # Training
    gradient_steps, done = 1, False

    # A single progress bar for the whole run: its total is `max_gradient_step`,
    # so it must not be recreated (and reset to 0) at every epoch.
    with tqdm(total=int(max_gradient_step), desc="Gradient step:") as pbar:
        while not done:
            model.train()
            policy_losses = []

            for batch in dataloader:

                # Form batch
                batch = {k: v.to(device) for k, v in batch.items()}
                batch = {k: v[:, 0] for k, v in batch.items()}

                # policy_loss
                if is_godot:
                    log_prob = model.policy_high.log_prob(batch)
                else:
                    _, dist = model.policy_high(batch['obs/complete'], batch['goal'])
                    log_prob = dist.log_prob(batch['action'])
                policy_loss = -log_prob.mean()

                # Update
                optimizer.policy_low.zero_grad(set_to_none=True)
                optimizer.policy_high.zero_grad(set_to_none=True)
                policy_loss.backward()
                optimizer.policy_low.step()
                optimizer.policy_high.step()

                gradient_steps += 1
                pbar.update(1)

                # The postfix shows the running mean over the current epoch.
                policy_losses.append(policy_loss.item())
                pbar.set_postfix(policy_loss=np.mean(policy_losses))

                if gradient_steps % save_every == 0:
                    m = copy.deepcopy(model).to('cpu')
                    model_db.push('low_goal_policy', m)
                    if save_as_policy:
                        model_db.push("model", m)

                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)

                if gradient_steps >= max_gradient_step:
                    done = True
                    break

            logger.add_scalar('policy_loss/train', np.mean(policy_losses), gradient_steps)

    return model

def train_iql_subgoal():
    """
    Placeholder for training pi(a|s,i) with iql.

    Raises:
        NotImplementedError: always.
    """
    not_implemented = NotImplementedError(0)
    raise not_implemented

def train_gcbc_subgoal(
    model, 
    transformer, 
    tokenizer,
    keys_to_tokenize,
    low_level_policy_goals,
    save_dual_policy,
    dual_policy_relabellers,
    optimizer, 
    model_db, 
    dataloader, 
    device, 
    save_every,
    max_gradient_step,
    logger, 
    serial_evaluation=False, 
    cfg_evaluation=None, 
    evaluation_logger=None,
    is_godot=False
):
    """
    Trains pi(a|s,i) with gcbc.

    The loss is the negative mean log-probability of the dataset actions, with
    the policy conditioned on the batch entry 'value/goal'. `transformer` is put
    in eval mode and is only used when a `DualPolicy` checkpoint is assembled.

    Args:
        model: object exposing `policy_high`.
        transformer, tokenizer, keys_to_tokenize, low_level_policy_goals:
            components handed to `DualPolicy` when `save_dual_policy` is set.
        save_dual_policy: if true, a `DualPolicy` is also pushed under "model".
        dual_policy_relabellers: configs instantiated into the `frame_relabellers`
            of the `DualPolicy`; None is treated as an empty list.
        optimizer: object exposing the optimizers `policy_low` and `policy_high`.
        model_db: object with `push(name, model)`.
        dataloader: iterable of dict batches of tensors; index 0 along dim 1
            of every entry is used.
        device: device the batches are moved to.
        save_every: checkpoint period, in gradient steps.
        max_gradient_step: step count (counter starts at 1) at which training stops.
        logger: object with `add_scalar(name, value, step)`; called once per epoch.
        serial_evaluation, cfg_evaluation, evaluation_logger: control the call
            to `serial_evaluation_loop` after each checkpoint.
        is_godot: if true, the log-probability comes from
            `model.policy_high.log_prob(batch)`.

    Returns:
        The trained model (same object that was passed in).
    """
    transformer.eval()
    model.train()

    step_count, epoch_idx, finished = 1, 0, False
    while not finished:
        epoch_idx += 1
        model.train()
        epoch_losses = []
        log_prob_total = 0
        with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch_idx}:") as pbar:
            for _, raw_batch in pbar:

                # Move to the device, then keep index 0 along dim 1 of every entry.
                on_device = {key: value.to(device) for key, value in raw_batch.items()}
                batch = {key: value[:, 0] for key, value in on_device.items()}

                observations, goals, actions = batch['obs/complete'], batch['value/goal'], batch['action']

                # Behaviour-cloning loss
                if not is_godot:
                    _, dist = model.policy_high(observations, goals)
                    log_prob = dist.log_prob(actions)
                else:
                    log_prob = model.policy_high.log_prob(batch)
                policy_loss = -log_prob.mean()

                policy_optimizers = (optimizer.policy_low, optimizer.policy_high)
                for opt in policy_optimizers:
                    opt.zero_grad(set_to_none=True)
                policy_loss.backward()
                for opt in policy_optimizers:
                    opt.step()

                # Accumulated for the epoch-level log
                log_prob_total += log_prob.cpu().detach().mean()

                step_count += 1

                epoch_losses.append(policy_loss.item())
                pbar.set_postfix(policy_loss=np.mean(epoch_losses))

                if step_count % save_every == 0:
                    snapshot = copy.deepcopy(model).to('cpu')
                    model_db.push('low_subgoal_policy', snapshot)
                    if save_dual_policy:
                        relabellers = [instantiate_class(r) for r in (dual_policy_relabellers or [])]
                        dual_policy = DualPolicy(
                            low_level_policy_subgoal=snapshot,
                            low_level_policy_goal=low_level_policy_goals,
                            high_level_policy=transformer,
                            tokenizer=tokenizer,
                            sos_token=special_tokens['SOS_TOKEN'],
                            eos_token=special_tokens['EOS_TOKEN'],
                            keys_to_tokenize=keys_to_tokenize,
                            frame_relabellers=relabellers
                        )
                        # The pushed object is an independent CPU copy of the assembled policy.
                        model_db.push("model", copy.deepcopy(dual_policy).cpu())

                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)

                if step_count >= max_gradient_step:
                    finished = True
                    break

        logger.add_scalar('log_prob_loss', log_prob_total/len(dataloader), step_count)

    return model

def train_low_hiql_goal(
    model, 
    transformer,
    tokenizer,
    keys_to_tokenize,
    low_level_policy_goals,
    save_dual_policy,
    dual_policy_relabellers,
    optimizer, 
    model_db, 
    dataloader, 
    device, 
    save_every,
    reward_scale,
    discount,
    expectile,
    beta,
    clip_score,
    v_update_period,
    policy_update_period,
    target_update_period,
    max_gradient_step,
    polyak_coef,
    logger, 
    serial_evaluation=False, 
    cfg_evaluation=None, 
    evaluation_logger=None
):
    """
    Trains pi(a|s,i) with iql.

    NOTE(review): this function reuses the name of an earlier function in this
    module (the goal-conditioned variant, with a different signature). Being
    defined later, this is the definition the module exposes under that name.
    It conditions on 'value/goal' and saves under 'low_subgoal_policy', so a
    subgoal-specific name was presumably intended - confirm with the callers
    before renaming.

    Per batch, two value networks (`model.vf1`, `model.vf2`) are fitted with an
    expectile loss against targets built from `model.target_vf1/2`, and the
    policy is fitted with a log-likelihood loss weighted by a clamped
    exponentiated advantage. Target networks follow the online ones through a
    Polyak average. `transformer` is put in eval mode and is only used when a
    `DualPolicy` checkpoint is assembled.

    Args:
        model: object exposing `vf1`, `vf2`, `target_vf1`, `target_vf2` and
            `policy_high`.
        transformer, tokenizer, keys_to_tokenize, low_level_policy_goals:
            components handed to `DualPolicy` when `save_dual_policy` is set.
        save_dual_policy: if true, a `DualPolicy` is also pushed under "model".
        dual_policy_relabellers: configs instantiated into the `frame_relabellers`
            of the `DualPolicy`; None is treated as an empty list.
        optimizer: object exposing the optimizers `vf1`, `vf2`, `policy_low`
            and `policy_high`.
        model_db: object with `push(name, model)`.
        dataloader: iterable of dict batches; every value is indexed with
            `[:, 0]` (current step) and `[:, 1]` (next step).
        device: device the batches are moved to.
        save_every: checkpoint period, in gradient steps.
        reward_scale, discount: scaling of the reward and bootstrap terms.
        expectile: expectile passed to `expectile_loss`.
        beta: the policy advantage is multiplied by `beta` before `exp`.
        clip_score: upper clamp on the exponentiated advantage.
        v_update_period, policy_update_period, target_update_period: update
            periods, in gradient steps.
        max_gradient_step: training stops once the step counter reaches it.
        polyak_coef: weight of the online parameters in the soft update.
        logger: object with `add_scalar` and `add_histogram`.
        serial_evaluation, cfg_evaluation, evaluation_logger: control the call
            to `serial_evaluation_loop` after each checkpoint.

    Returns:
        The trained model (same object that was passed in).
    """
    # NOTE(review): GOAL_TYPES is declared global but never used in this function.
    global GOAL_TYPES
    transformer.eval()
    model.train()

    target_vf1 = model.target_vf1
    target_vf2 = model.target_vf2

    # The step counter starts at 1, so `max_gradient_step - 1` batches are processed.
    gradient_steps = 1
    epoch = 0
    done = False
    while not done:
        epoch += 1

        model.train()
        # Per-epoch accumulators (reset at the start of every epoch).
        total_vf1_loss = 0
        total_vf2_loss = 0
        policy_losses = []

        abs_adv_mean = 0
        adv_mean = 0
        adv_max = 0
        adv_min = 0
        accept_prob = 0
        v1_min = 0
        v1_max = 0
        v1_mean = 0
        v2_min = 0
        v2_max = 0
        v2_mean = 0
        log_prob_loss = 0
        weights = None
        with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}:") as pbar:
            for it, batch in pbar:

                # Form batch: index 0 along dim 1 is the current step, index 1 the next one.
                batch = {k: v.to(device) for k, v in batch.items()} 
                batch_next = {k: v[:, 1] for k, v in batch.items()}
                batch = {k: v[:, 0] for k, v in batch.items()}

                # Get data
                observations = batch['obs/complete']
                goals = batch['value/goal']
                actions = batch['action']
                next_observations = batch_next['obs/complete']
                next_goals = batch_next['value/goal']
                rewards = batch['reward'].unsqueeze(1) - 1.0
                # NOTE(review): computed from the already shifted rewards, this is
                # `2 - batch['reward']`; it matches the comment below only if the
                # raw rewards are 2 (terminal) / 1 (otherwise) - confirm.
                masks = 1 - rewards # masks are 0 if terminal, 1 otherwise

                # vf_loss
                next_v1_t, next_v2_t = target_vf1(next_observations,next_goals), target_vf2(next_observations,next_goals)
                next_v_t = torch.min(next_v1_t, next_v2_t).detach()

                # Advantage from the target networks; its sign selects the expectile weight.
                q_t = (reward_scale * rewards + masks * discount * next_v_t)
                v_t = (target_vf1(observations,goals) + target_vf2(observations,goals)) / 2
                adv = (q_t - v_t).detach()

                # Each online value network regresses onto its own target network's bootstrap.
                q1_t = reward_scale * rewards + masks * discount * next_v1_t
                q2_t = reward_scale * rewards + masks * discount * next_v2_t

                v1, v2 = model.vf1(observations,goals), model.vf2(observations,goals)
                vf1_loss = expectile_loss(adv, q1_t.detach() - v1, expectile).mean()
                vf2_loss = expectile_loss(adv, q2_t.detach() - v2, expectile).mean()
                
                # policy loss
                vf_pred = (model.vf1(observations,goals) + model.vf2(observations,goals)) / 2
                next_vf1 = model.vf1(next_observations,next_goals)
                next_vf2 = model.vf2(next_observations,next_goals)
                next_vf_pred = (next_vf1 + next_vf2) / 2
                
                # `adv` is rebound here: from now on it is V(s') - V(s) from the online
                # networks, and this is the value the logging metrics below see.
                adv = (next_vf_pred - vf_pred)
                exp_adv = torch.exp(adv * beta)
                exp_adv = torch.clamp(exp_adv, max=clip_score)
                # Detached: the weights carry no gradient into the value networks.
                weights = exp_adv[:, 0].detach()

                # The loss is computed from `model.policy_high` only, while both the
                # `policy_low` and `policy_high` optimizers are stepped below.
                _, dist = model.policy_high(observations,goals)
                log_prob = dist.log_prob(actions)

                # Summed (not averaged) over the batch.
                policy_loss = (-log_prob * weights).sum()

                # Update networks
                if gradient_steps % v_update_period == 0:
                        optimizer.vf1.zero_grad(set_to_none=True)
                        vf1_loss.backward()
                        optimizer.vf1.step()

                        optimizer.vf2.zero_grad(set_to_none=True)
                        vf2_loss.backward()
                        optimizer.vf2.step()

                if gradient_steps % policy_update_period == 0:
                    optimizer.policy_low.zero_grad(set_to_none=True)
                    optimizer.policy_high.zero_grad(set_to_none=True)
                    policy_loss.backward()
                    optimizer.policy_low.step()
                    optimizer.policy_high.step()

                # Soft updates: target <- polyak_coef * online + (1 - polyak_coef) * target
                if gradient_steps % target_update_period == 0:
                    for target_param, param in zip(target_vf1.parameters(), model.vf1.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                    for target_param, param in zip(target_vf2.parameters(), model.vf2.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                # Update metrics for logging 
                abs_adv_mean += torch.abs(adv).cpu().detach().mean()
                adv_mean += adv.cpu().detach().mean()
                adv_max += adv.cpu().detach().max()
                adv_min += adv.cpu().detach().min()
                accept_prob += (adv >= 0).float().cpu().detach().mean()
                v1_min += v1.cpu().detach().min()
                v1_max += v1.cpu().detach().max()
                v1_mean += v1.cpu().detach().mean()
                v2_min += v2.cpu().detach().min()
                v2_max += v2.cpu().detach().max()
                v2_mean += v2.cpu().detach().mean()
                log_prob_loss += log_prob.cpu().detach().mean()

                gradient_steps += 1

                total_vf1_loss += vf1_loss.item()
                total_vf2_loss += vf2_loss.item()
                policy_losses.append(policy_loss.item())
                pbar.set_postfix(vf1_loss=total_vf1_loss/(it+1), 
                                 vf2_loss=total_vf2_loss/(it+1),
                                 policy_loss=np.mean(policy_losses))
                
                # Checkpoint a CPU copy of the model; optionally also a DualPolicy built around it.
                if gradient_steps % save_every == 0:
                    m = copy.deepcopy(model).to('cpu')
                    model_db.push('low_subgoal_policy', m)

                    if save_dual_policy:
                        dual_policy = DualPolicy(
                            low_level_policy_subgoal=m, 
                            low_level_policy_goal=low_level_policy_goals, 
                            high_level_policy=transformer, 
                            tokenizer=tokenizer, 
                            sos_token=special_tokens['SOS_TOKEN'], 
                            eos_token=special_tokens['EOS_TOKEN'], 
                            keys_to_tokenize=keys_to_tokenize, 
                            frame_relabellers=[instantiate_class(r) for r in (dual_policy_relabellers or [])]
                        )
                        # `m` is rebound to a deep CPU copy of the assembled policy before pushing.
                        m = copy.deepcopy(dual_policy).cpu()
                        model_db.push("model", m)
                    
                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)
                
                if gradient_steps >= max_gradient_step:
                    done = True
                    break

        # Epoch-level logs.
        # NOTE(review): the two value losses logged here are those of the last
        # batch (`total_vf*_loss` is only used for the progress bar), and the
        # accumulators are divided by len(dataloader) even when the epoch was
        # cut short by `max_gradient_step`.
        logger.add_scalar('vf1_loss/train', vf1_loss.item(), gradient_steps)
        logger.add_scalar('vf2_loss/train', vf2_loss.item(), gradient_steps)
        logger.add_scalar('policy_loss/train', np.mean(policy_losses), gradient_steps)
        logger.add_scalar('abs_adv_mean', abs_adv_mean/len(dataloader), gradient_steps)
        logger.add_scalar('adv_mean', adv_mean/len(dataloader), gradient_steps)
        logger.add_scalar('adv_max', adv_max/len(dataloader), gradient_steps)
        logger.add_scalar('adv_min', adv_min/len(dataloader), gradient_steps)
        logger.add_scalar('accept_prob', accept_prob/len(dataloader), gradient_steps)
        logger.add_scalar('v1_min', v1_min/len(dataloader), gradient_steps)
        logger.add_scalar('v1_max', v1_max/len(dataloader), gradient_steps)
        logger.add_scalar('v1_mean', v1_mean/len(dataloader), gradient_steps)
        logger.add_scalar('v2_min', v2_min/len(dataloader), gradient_steps)
        logger.add_scalar('v2_max', v2_max/len(dataloader), gradient_steps)
        logger.add_scalar('v2_mean', v2_mean/len(dataloader), gradient_steps)
        logger.add_scalar('log_prob_loss', log_prob_loss/len(dataloader), gradient_steps)
        # Histogram of the policy weights of the last processed batch.
        logger.add_histogram(name='Histogram_awr', values=weights.cpu().detach(), epoch=gradient_steps)

    return model

def get_empty_log_dicts():
    """
    Build fresh, empty metric containers for one training epoch.

    Returns a tuple ``(losses, value_infos, policy_infos)``. Each element is a
    dict mapping a metric name to its own new empty list, so callers can append
    per-step values without sharing state between calls.
    """
    loss_names = ("v1", "v2", "policy", "high_policy", "low_policy")
    value_names = (
        "accept_prob",
        "v1_min",
        "v1_max",
        "v1_mean",
        "abs adv mean",
        "adv mean",
        "adv max",
        "adv min",
    )
    policy_names = (
        "bc_log_probs",
        "adv",
        "mse",
        "high_bc_log_probs",
        "high_adv",
        "high_scale",
        "high_mse",
    )

    # One distinct list per key (a comprehension, not dict.fromkeys, to avoid aliasing).
    losses = {name: [] for name in loss_names}
    value_infos = {name: [] for name in value_names}
    policy_infos = {name: [] for name in policy_names}
    return losses, value_infos, policy_infos

def train_hiql_goal(
    model,
    use_waypoints,
    goal_keys,
    optimizer,
    dataloader,
    max_gradient_step,
    reward_scale,
    discount,
    expectile,
    beta,
    high_beta,
    clip_score,
    v_update_period,
    policy_update_period,
    target_update_period,
    polyak_coef,
    save_every,
    model_db,
    logger,
    serial_evaluation,
    cfg_evaluation,
    evaluation_logger,
    device
):  
    """
    Trains pi(a|s,g) with the HIQL losses: two value functions fitted with an
    expectile loss against target networks, a low-level (or flat) policy trained
    with advantage-weighted log-likelihood and, when `use_waypoints` is set, a
    high-level policy trained the same way on waypoint targets.

    Args:
        model: module exposing `vf1`, `vf2`, `target_vf1`, `target_vf2` and either
            `log_prob` (flat policy) or `policy_low` / `policy_high` (waypoints).
        use_waypoints: if True, train `policy_low` and `policy_high`; otherwise
            train the flat policy through `model.log_prob`.
        goal_keys: accepted for interface compatibility; not used by this function.
        optimizer: container with `vf1`, `vf2`, `policy_low` and `policy_high`
            optimizers (each is zeroed and stepped here).
        dataloader: yields dicts of tensors with keys 'obs', 'goal', 'low_goal',
            'action', 'reward' (plus 'high_goal' and 'high_target' when
            `use_waypoints`). Index 0 along dim 1 is used as the current step and
            index 1 as the next step (index 0 again if dim 1 has size 1).
        max_gradient_step: training stops once the step counter (which starts at 1)
            reaches this value.
        reward_scale, discount: used to build the bootstrapped value targets.
        expectile: forwarded to `expectile_loss`.
        beta, high_beta: temperatures of the low / high level advantage weights.
        clip_score: upper bound applied to the exponentiated advantages.
        v_update_period, policy_update_period, target_update_period: a given update
            is applied only on steps that are multiples of its period.
        polyak_coef: interpolation coefficient of the target-network soft update.
        save_every: a CPU copy of the model is pushed to `model_db` every
            `save_every` steps.
        logger: receives per-epoch scalars and a histogram of the last weights.
        serial_evaluation, cfg_evaluation, evaluation_logger: when
            `serial_evaluation` is truthy, `serial_evaluation_loop` is run after
            each checkpoint.
        device: device the batches are moved to.

    Returns:
        The trained `model`.
    """

    def advantage_weights(obs, next_obs, subgoals, next_subgoals, temperature):
        # Clipped exp(advantage * temperature) weights. They only scale the policy
        # log-likelihood and are never back-propagated, hence no_grad.
        with torch.no_grad():
            v_now = (model.vf1(obs, subgoals) + model.vf2(obs, subgoals)) / 2
            next_vf1 = model.vf1(next_obs, next_subgoals)
            next_vf2 = model.vf2(next_obs, next_subgoals)
            v_next = (next_vf1 + next_vf2) / 2
            advantage = v_next - v_now
            clipped = torch.clamp(torch.exp(advantage * temperature), max=clip_score)
        return advantage, clipped[:, 0]

    # (key in value_infos, logger tag), in the order the scalars are logged.
    value_tags = (
        ("abs adv mean", "abs_adv_mean"),
        ("adv mean", "adv_mean"),
        ("adv max", "adv_max"),
        ("adv min", "adv_min"),
        ("accept_prob", "accept_prob"),
        ("v1_min", "v1_min"),
        ("v1_max", "v1_max"),
        ("v1_mean", "v1_mean"),
        ("v2_min", "v2_min"),
        ("v2_max", "v2_max"),
        ("v2_mean", "v2_mean"),
    )

    # Training
    target_vf1 = model.target_vf1
    target_vf2 = model.target_vf2
    epoch, gradient_steps, done = 0, 1, False
    while not done:
        epoch += 1

        model.train()
        total_vf1_loss = 0
        total_vf2_loss = 0
        policy_losses = []
        weights = None
        _, value_infos, policy_infos = get_empty_log_dicts()
        # v2 statistics are logged below but have no slot in the shared dict layout.
        for key in ("v2_min", "v2_max", "v2_mean"):
            value_infos.setdefault(key, [])

        with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}:") as pbar:
            for it, batch in pbar:

                # Form batch
                batch = {k: v.to(device) for k, v in batch.items()}
                batch_next = {k: v[:, 1] if v.size()[1] > 1 else v[:, 0] for k, v in batch.items()}
                batch = {k: v[:, 0] for k, v in batch.items()}

                # Get data
                observations = batch['obs']
                goals = batch['goal']
                low_goals = batch['low_goal']
                actions = batch['action']
                next_observations = batch_next['obs']
                next_goals = batch_next['goal']
                next_low_goals = batch_next['low_goal']

                # vf_loss vanilla GC-IQL (vanilla == using action free vf as in HIQL, rather than qf & vf as in IQL paper)
                # makes sure reward is of shape (batch_size,1)
                reward = batch["reward"]
                if len(reward.shape) == 1:
                    reward = reward.unsqueeze(1)
                masks = 1.0 - reward  # masks are 0 if terminal, 1 otherwise
                rewards = reward - 1.0

                # vf_loss: every target below is a constant w.r.t. the optimised networks
                with torch.no_grad():
                    next_v1_t = target_vf1(next_observations, next_goals)
                    next_v2_t = target_vf2(next_observations, next_goals)
                    next_v_t = torch.min(next_v1_t, next_v2_t)

                    q_t = reward_scale * rewards + masks * discount * next_v_t
                    v_t = (target_vf1(observations, goals) + target_vf2(observations, goals)) / 2
                    adv = q_t - v_t

                    q1_t = reward_scale * rewards + masks * discount * next_v1_t
                    q2_t = reward_scale * rewards + masks * discount * next_v2_t

                v1, v2 = model.vf1(observations, goals), model.vf2(observations, goals)
                vf1_loss = expectile_loss(adv, q1_t - v1, expectile).mean()
                vf2_loss = expectile_loss(adv, q2_t - v2, expectile).mean()

                # Per-step value statistics, averaged and logged at the end of the epoch
                adv_cpu = adv.cpu()
                v1_cpu = v1.detach().cpu()
                v2_cpu = v2.detach().cpu()
                value_infos["abs adv mean"].append(torch.abs(adv_cpu).mean().item())
                value_infos["adv mean"].append(adv_cpu.mean().item())
                value_infos["adv max"].append(adv_cpu.max().item())
                value_infos["adv min"].append(adv_cpu.min().item())
                value_infos["accept_prob"].append((adv_cpu >= 0).float().mean().item())
                value_infos["v1_min"].append(v1_cpu.min().item())
                value_infos["v1_max"].append(v1_cpu.max().item())
                value_infos["v1_mean"].append(v1_cpu.mean().item())
                value_infos["v2_min"].append(v2_cpu.min().item())
                value_infos["v2_max"].append(v2_cpu.max().item())
                value_infos["v2_mean"].append(v2_cpu.mean().item())

                # low-level / flat policy_loss
                low_adv, weights = advantage_weights(
                    observations, next_observations, low_goals, next_low_goals, beta
                )
                if use_waypoints:
                    logpp, _ = model.policy_low.log_prob(observations, low_goals, actions)
                else:
                    logpp, _ = model.log_prob(observations, low_goals, actions)
                low_policy_loss = (-logpp * weights).mean()
                policy_loss = low_policy_loss
                policy_infos["bc_log_probs"].append(logpp.detach().mean().item())
                policy_infos["adv"].append(low_adv.mean().item())

                # high-level policy_loss
                if use_waypoints:
                    high_goals = batch['high_goal']
                    high_targets = batch['high_target']
                    next_high_goals = batch_next['high_goal']

                    # NOTE: `weights` now holds the high-level weights, so the epoch
                    # histogram shows those when waypoints are used.
                    high_adv, weights = advantage_weights(
                        observations, next_observations, high_goals, next_high_goals, high_beta
                    )

                    # in HIQL w/o repr learning they use "high_targets - obs" as target
                    # maybe because relative positions are easier to learn than absolute ones (smaller values)
                    # to "fix" this at inference time they add obs back to the high_pi output before feeding low pi
                    target = high_targets - observations
                    logpp, _ = model.policy_high.log_prob(observations, high_goals, target)
                    high_policy_loss = (-logpp * weights).mean()
                    policy_loss = low_policy_loss + high_policy_loss

                    policy_infos["high_bc_log_probs"].append(logpp.detach().mean().item())
                    policy_infos["high_adv"].append(high_adv.mean().item())

                # Update networks
                if gradient_steps % v_update_period == 0:
                    optimizer.vf1.zero_grad(set_to_none=True)
                    vf1_loss.backward()
                    optimizer.vf1.step()

                    optimizer.vf2.zero_grad(set_to_none=True)
                    vf2_loss.backward()
                    optimizer.vf2.step()

                if gradient_steps % policy_update_period == 0:
                    optimizer.policy_low.zero_grad(set_to_none=True)
                    optimizer.policy_high.zero_grad(set_to_none=True)
                    policy_loss.backward()
                    optimizer.policy_low.step()
                    optimizer.policy_high.step()

                # Soft updates
                if gradient_steps % target_update_period == 0:
                    for target_param, param in zip(target_vf1.parameters(), model.vf1.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                    for target_param, param in zip(target_vf2.parameters(), model.vf2.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                gradient_steps += 1

                total_vf1_loss += vf1_loss.item()
                total_vf2_loss += vf2_loss.item()
                policy_losses.append(policy_loss.item())
                pbar.set_postfix(vf1_loss=total_vf1_loss/(it+1), 
                                 vf2_loss=total_vf2_loss/(it+1),
                                 policy_loss=np.mean(policy_losses))
                
                if gradient_steps % save_every == 0:
                    m = copy.deepcopy(model).to('cpu')
                    model_db.push('model', m)
                    if serial_evaluation:
                        print('begin serial')
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)
                
                if gradient_steps >= max_gradient_step:
                    done = True
                    break

        # Epoch-level logging. Statistics are averaged over the steps actually run
        # in this epoch (the last epoch may stop before the dataloader is exhausted).
        logger.add_scalar('vf1_loss/train', vf1_loss.item(), gradient_steps)
        logger.add_scalar('vf2_loss/train', vf2_loss.item(), gradient_steps)
        logger.add_scalar('policy_loss/train', np.mean(policy_losses), gradient_steps)
        for key, tag in value_tags:
            logger.add_scalar(tag, float(np.mean(value_infos[key])), gradient_steps)
        logger.add_scalar('log_prob_loss', float(np.mean(policy_infos["bc_log_probs"])), gradient_steps)
        logger.add_histogram(name='Histogram_awr', values=weights.cpu().detach(), epoch=gradient_steps)

    return model

def train_hgcbc_goal(
    model,
    use_waypoints,
    optimizer,
    dataloader,
    max_gradient_step,
    policy_update_period,
    save_every,
    model_db,
    serial_evaluation,
    cfg_evaluation,
    evaluation_logger,
    device
):  
    """
    Trains pi(a|s,g) by goal-conditioned behaviour cloning: the loss is the mean
    negative log-likelihood returned by `log_prob(batch)`, with no value function
    or advantage weighting involved.

    Args:
        model: policy container. With `use_waypoints`, `model.policy_low` and
            `model.policy_high` are trained; otherwise `model` itself is.
        use_waypoints: selects the hierarchical (low + high) or flat policy.
        optimizer: container with `policy_low` and `policy_high` optimizers; both
            are zeroed and stepped on every policy update.
        dataloader: yields dicts of tensors; index 0 along dim 1 of every entry is
            used as the current step. A 'reward' entry is required.
        max_gradient_step: training stops once the step counter (which starts at 1)
            reaches this value.
        policy_update_period: the optimizers step only on steps that are multiples
            of this period.
        save_every: a CPU copy of the model is pushed to `model_db` every
            `save_every` steps.
        serial_evaluation, cfg_evaluation, evaluation_logger: when
            `serial_evaluation` is truthy, `serial_evaluation_loop` is run after
            each checkpoint.
        device: device the batches are moved to.

    Returns:
        The trained `model`.
    """
    # Training
    epoch, gradient_steps, done = 0, 1, False
    while not done:
        epoch += 1

        model.train()
        policy_losses = []
        with tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch}:") as pbar:
            for batch in pbar:

                # Form batch: move to device and keep the current step only
                batch = {k: v.to(device) for k, v in batch.items()}
                batch = {k: v[:, 0] for k, v in batch.items()}

                # makes sure reward is of shape (batch_size,1); the whole batch dict is
                # handed to log_prob below, so the reshaped reward is visible to it
                if len(batch["reward"].shape) == 1:
                    batch["reward"] = batch["reward"].unsqueeze(1)

                # low-level / flat policy_loss
                pi = model.policy_low if use_waypoints else model
                logpp, _ = pi.log_prob(batch)
                policy_loss = -logpp.mean()

                # high-level policy_loss
                if use_waypoints:
                    logpp, _ = model.policy_high.log_prob(batch)
                    policy_loss = policy_loss + (-logpp.mean())

                # Update networks
                if gradient_steps % policy_update_period == 0:
                    optimizer.policy_low.zero_grad(set_to_none=True)
                    optimizer.policy_high.zero_grad(set_to_none=True)
                    policy_loss.backward()
                    optimizer.policy_low.step()
                    optimizer.policy_high.step()

                gradient_steps += 1

                policy_losses.append(policy_loss.item())
                pbar.set_postfix(policy_loss=np.mean(policy_losses))
                
                if gradient_steps % save_every == 0:
                    m = copy.deepcopy(model).to('cpu')
                    model_db.push('model', m)
                    
                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)
                
                if gradient_steps >= max_gradient_step:
                    done = True
                    break
    return model

def train_low_hiql_subgoal(
    model, 
    transformer, 
    optimizer, 
    model_db, 
    dataloader, 
    device, 
    save_every,
    reward_scale,
    discount,
    expectile,
    beta,
    clip_score,
    v_update_period,
    policy_update_period,
    target_update_period,
    max_gradient_step,
    polyak_coef,
    logger, 
    serial_evaluation=False, 
    cfg_evaluation=None, 
    evaluation_logger=None
):
    """
    Trains pi(a|s,i) with iql.
    """
    global GOAL_TYPES
    transformer.eval()
    model.train()

    target_vf1 = model.target_vf1
    target_vf2 = model.target_vf2

    gradient_steps = 1
    epoch = 0
    done = False
    while not done:
        epoch += 1

        model.train()
        total_vf1_loss = 0
        total_vf2_loss = 0
        policy_losses = []

        abs_adv_mean = 0
        adv_mean = 0
        adv_max = 0
        adv_min = 0
        accept_prob = 0
        v1_min = 0
        v1_max = 0
        v1_mean = 0
        v2_min = 0
        v2_max = 0
        v2_mean = 0
        log_prob_loss = 0
        weights = None
        with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}:") as pbar:
            for it, batch in pbar:

                # Form batch
                batch = {k: v.to(device) for k, v in batch.items()} 
                batch_next = {k: v[:, 1] for k, v in batch.items()}
                batch = {k: v[:, 0] for k, v in batch.items()}

                # Get data
                observations = batch['obs/complete']
                goals = batch['value/goal']
                actions = batch['action']
                next_observations = batch_next['obs/complete']
                next_goals = batch_next['value/goal']
                rewards = batch['rewards'].unsqueeze(1) - 1.0
                masks = 1 - rewards # masks are 0 if terminal, 1 otherwise

                # vf_loss
                next_v1_t, next_v2_t = target_vf1(next_observations,next_goals), target_vf2(next_observations,next_goals)
                next_v_t = torch.min(next_v1_t, next_v2_t).detach()

                q_t = (reward_scale * rewards + masks * discount * next_v_t)
                v_t = (target_vf1(observations,goals) + target_vf2(observations,goals)) / 2
                adv = (q_t - v_t).detach()

                q1_t = reward_scale * rewards + masks * discount * next_v1_t
                q2_t = reward_scale * rewards + masks * discount * next_v2_t

                v1, v2 = model.vf1(observations,goals), model.vf2(observations,goals)
                vf1_loss = expectile_loss(adv, q1_t.detach() - v1, expectile).mean()
                vf2_loss = expectile_loss(adv, q2_t.detach() - v2, expectile).mean()
                
                # policy loss
                vf_pred = (model.vf1(observations,goals) + model.vf2(observations,goals)) / 2
                next_vf1 = model.vf1(next_observations,next_goals)
                next_vf2 = model.vf2(next_observations,next_goals)
                next_vf_pred = (next_vf1 + next_vf2) / 2
                
                adv = (next_vf_pred - vf_pred)
                exp_adv = torch.exp(adv * beta)
                exp_adv = torch.clamp(exp_adv, max=clip_score)
                weights = exp_adv[:, 0].detach()

                _, dist = model.policy_high(observations,goals)
                log_prob = dist.log_prob(actions)

                policy_loss = (-log_prob * weights).sum()

                # Update networks
                if gradient_steps % v_update_period == 0:
                        optimizer.vf1.zero_grad(set_to_none=True)
                        vf1_loss.backward()
                        optimizer.vf1.step()

                        optimizer.vf2.zero_grad(set_to_none=True)
                        vf2_loss.backward()
                        optimizer.vf2.step()

                if gradient_steps % policy_update_period == 0:
                    optimizer.policy_low.zero_grad(set_to_none=True)
                    optimizer.policy_high.zero_grad(set_to_none=True)
                    policy_loss.backward()
                    optimizer.policy_low.step()
                    optimizer.policy_high.step()

                # Soft updates
                if gradient_steps % target_update_period == 0:
                    for target_param, param in zip(target_vf1.parameters(), model.vf1.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                    for target_param, param in zip(target_vf2.parameters(), model.vf2.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                # Update metrics for logging 
                abs_adv_mean += torch.abs(adv).cpu().detach().mean()
                adv_mean += adv.cpu().detach().mean()
                adv_max += adv.cpu().detach().max()
                adv_min += adv.cpu().detach().min()
                accept_prob += (adv >= 0).float().cpu().detach().mean()
                v1_min += v1.cpu().detach().min()
                v1_max += v1.cpu().detach().max()
                v1_mean += v1.cpu().detach().mean()
                v2_min += v2.cpu().detach().min()
                v2_max += v2.cpu().detach().max()
                v2_mean += v2.cpu().detach().mean()
                log_prob_loss += log_prob.cpu().detach().mean()

                gradient_steps += 1

                total_vf1_loss += vf1_loss.item()
                total_vf2_loss += vf2_loss.item()
                policy_losses.append(policy_loss.item())
                pbar.set_postfix(vf1_loss=total_vf1_loss/(it+1), 
                                 vf2_loss=total_vf2_loss/(it+1),
                                 policy_loss=np.mean(policy_losses))
                
                if gradient_steps % save_every == 0:
                    m = copy.deepcopy(model).to('cpu')
                    model_db.push('low_subgoal_policy', m)
                    
                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)
                
                if gradient_steps >= max_gradient_step:
                    done = True
                    break

        logger.add_scalar('vf1_loss/train', vf1_loss.item(), gradient_steps)
        logger.add_scalar('vf2_loss/train', vf2_loss.item(), gradient_steps)
        logger.add_scalar('policy_loss/train', np.mean(policy_losses), gradient_steps)
        logger.add_scalar('abs_adv_mean', abs_adv_mean/len(dataloader), gradient_steps)
        logger.add_scalar('adv_mean', adv_mean/len(dataloader), gradient_steps)
        logger.add_scalar('adv_max', adv_max/len(dataloader), gradient_steps)
        logger.add_scalar('adv_min', adv_min/len(dataloader), gradient_steps)
        logger.add_scalar('accept_prob', accept_prob/len(dataloader), gradient_steps)
        logger.add_scalar('v1_min', v1_min/len(dataloader), gradient_steps)
        logger.add_scalar('v1_max', v1_max/len(dataloader), gradient_steps)
        logger.add_scalar('v1_mean', v1_mean/len(dataloader), gradient_steps)
        logger.add_scalar('v2_min', v2_min/len(dataloader), gradient_steps)
        logger.add_scalar('v2_max', v2_max/len(dataloader), gradient_steps)
        logger.add_scalar('v2_mean', v2_mean/len(dataloader), gradient_steps)
        logger.add_scalar('log_prob_loss', log_prob_loss/len(dataloader), gradient_steps)
        logger.add_histogram(name='Histogram_awr', values=weights.cpu().detach(), epoch=gradient_steps)
