import torch
import copy 

import numpy as np
import torch.nn.functional as F
import torch.nn as nn

from tqdm import tqdm

from sgcrl.models.quantizer import special_tokens
from sgcrl.utils.imports import instantiate_class, get_arguments, get_class
from sgcrl.data.torch_datasets.relabel import image_data_augmentation
from sgcrl.evaluation.d4rl_evaluation import serial_evaluation_loop
from sgcrl.models.dual_policy import DualPolicy

# Goal-type identifiers, also collected in GOAL_TYPES (only named via `global` in the visible code).
CURRENT, FUTURE, RANDOM = 0, 1, 2
GOAL_TYPES = [CURRENT, FUTURE, RANDOM]

#############
# quantizer #
#############
def train_quantizer(quantizer, optimizer_quantizer, epochs, dataloader, device, model_db, logger, offset, norm, contrastive_coef, commit_coef, 
                    reconstruction_coef, save_every=100, use_visual_data_augmentation=False, p_aug=0.5, square_rotation=False, vertical_flip=False,
                    horizontal_flip=False, padded_random_crop=False, padding_size=4, reshaped_random_crop=False, crop_size=(16,16), grad_max_norm=None, noise=0):
    """Fit ``quantizer`` on (sample, positive, negative) triplets from ``dataloader``.

    Per batch the optimised objective is::

        commit_coef * commit + contrastive_coef * triplet + reconstruction_coef * reconstruction

    where ``commit`` is the sum of the commitment terms returned by the quantizer for the
    three inputs, ``reconstruction`` is the sum of the three MSEs between each input and its
    quantized output multiplied by the corresponding mask, and ``triplet`` is
    ``F.triplet_margin_loss`` on the three latents.

    Each input is perturbed by ``noise * torch.rand_like(x)`` (before being moved to
    ``device``), then shifted by ``offset`` and divided by ``norm`` (``[1.0]`` when ``norm``
    is empty). NOTE(review): ``rand_like`` draws from [0, 1), so this noise is not zero-mean
    — confirm that is intended.

    When ``use_visual_data_augmentation`` is set, ``image_data_augmentation`` is applied to
    each input and supplies the masks; otherwise the masks are all ones.

    Every ``save_every`` epochs a CPU copy is pushed to ``model_db`` under ``"quantizer"``, and
    the epoch means of all four losses are written to ``logger``.

    Returns:
        The trained ``quantizer`` (same object, updated in place).
    """
    quantizer.train()

    scale = torch.tensor(norm if len(norm) else [1.0], device=device)
    aug_args = (square_rotation, vertical_flip, horizontal_flip, padded_random_crop,
                padding_size, reshaped_random_crop, crop_size)

    for epoch in range(1, epochs + 1):
        # Keys double as the progress-bar postfix names.
        history = {"loss": [], "commit": [], "recon": [], "contrastive": []}
        with tqdm(dataloader, desc=f"Epoch {epoch}:") as pbar:
            for sample, positive, negative in pbar:
                triplet = (sample, positive, negative)

                # data augmentation: noise first (on the loader's device), then move and normalise
                triplet = tuple(x + noise * torch.rand_like(x) for x in triplet)
                triplet = tuple(x.to(device) for x in triplet)
                triplet = tuple((x + offset) / scale for x in triplet)

                if use_visual_data_augmentation:
                    augmented = [image_data_augmentation(p_aug, x, *aug_args) for x in triplet]
                    triplet = tuple(view for view, _ in augmented)
                    masks = tuple(mask for _, mask in augmented)
                else:
                    masks = tuple(torch.ones_like(x) for x in triplet)

                # Forward in the fixed order sample -> positive -> negative.
                outputs = [quantizer(x, return_latent=True) for x in triplet]
                (q_anchor, commit_anchor, z_anchor), (q_pos, commit_pos, z_pos), (q_neg, commit_neg, z_neg) = outputs
                anchor, pos, neg = triplet
                mask_anchor, mask_pos, mask_neg = masks

                commit_loss = commit_anchor + commit_pos + commit_neg
                reconstruction_loss = (
                    F.mse_loss(anchor, q_anchor * mask_anchor)
                    + F.mse_loss(pos, q_pos * mask_pos)
                    + F.mse_loss(neg, q_neg * mask_neg)
                )
                contrastive_loss = F.triplet_margin_loss(z_anchor, positive=z_pos, negative=z_neg)

                loss = commit_coef * commit_loss + contrastive_coef * contrastive_loss + reconstruction_coef * reconstruction_loss

                optimizer_quantizer.zero_grad(set_to_none=True)
                loss.backward()
                if grad_max_norm is not None:
                    torch.nn.utils.clip_grad_norm_(quantizer.parameters(), grad_max_norm)
                optimizer_quantizer.step()

                for key, value in (("loss", loss), ("commit", commit_loss),
                                   ("recon", reconstruction_loss), ("contrastive", contrastive_loss)):
                    history[key].append(value.item())
                pbar.set_postfix(**{key: np.mean(values) for key, values in history.items()})

            if epoch % save_every == 0:
                model_db.push("quantizer", copy.deepcopy(quantizer).to("cpu"))

            for tag, key in (("train/loss", "loss"),
                             ("train/commit_loss", "commit"),
                             ("train/reconstruction_loss", "recon"),
                             ("train/contrastive_loss", "contrastive")):
                logger.add_scalar(tag, np.mean(history[key]), epoch)

    return quantizer

##############
# high_value #
##############
def train_high_value(
    cfg,
    tokenzier,
    high_reward_scale,
    high_discount,
    high_expectile,
    high_beta,
    high_clip_score,
    high_v_update_period,
    high_policy_update_period,
    high_target_update_period,
    polyak_coef,
    epochs,
    max_gradient_step,
    train_dataloader, 
    val_dataloader, 
    special_tokens, 
    model_db, 
    logger, 
    save_every, 
    device,
    high_p_random_goal=0.3,
):
    """Train the high-level token policy with goal-conditioned IQL values and AWR weighting.

    For each compressed token sequence, two positions strictly between SOS and EOS are drawn.
    The earlier one is the current token and the later one the goal. The token right after
    the current one, clipped to the goal, is the "next" token.

    Two value networks ``hvf1``/``hvf2``, each with a Polyak-averaged target copy, are
    fitted with an expectile loss. The reward is 0 when the current token equals the goal
    and -1 otherwise. ``high_policy`` receives codebook embeddings of the current and goal
    tokens and is trained with cross-entropy on the next token. Each term is weighted by
    ``clamp(exp(high_beta * A), max=high_clip_score)``, where ``A`` is the
    online-value advantage.

    Args:
        cfg: config with ``high_policy``, ``hvf``, ``high_policy_optimizer``,
            ``hvf_optimizer`` and ``device`` entries.
        tokenzier: tokenizer whose ``_quantizer`` codebook provides the token embeddings
            (parameter name kept for caller compatibility).
        high_reward_scale, high_discount: reward scale and discount of the TD target.
        high_expectile: expectile of the value loss.
        high_beta, high_clip_score: AWR temperature and clip on the exponentiated advantage.
        high_v_update_period, high_policy_update_period, high_target_update_period:
            gradient-step periods of value updates, policy updates and target soft updates.
        polyak_coef: weight of the online parameters in each target soft update.
        epochs: unused; training length is governed by ``max_gradient_step``.
        max_gradient_step: number of batches to train on in total.
        train_dataloader: yields dicts containing a ``'token/compressed'`` tensor.
        val_dataloader: unused.
        special_tokens: mapping with ``'SOS_TOKEN'``, ``'EOS_TOKEN'``, ``'PADDING_VALUE'``.
        model_db: receives CPU snapshots ``'high_goal_policy'`` and ``'hvf1'`` every
            ``save_every`` gradient steps.
        logger: receives per-pass scalars and an AWR-weight histogram.
        save_every: snapshot period in gradient steps.
        device: device the token batches are moved to.
        high_p_random_goal: probability of taking a sample's goal from a randomly permuted
            batch index instead of its own sequence.

    Returns:
        ``(high_policy, hvf1, hvf2)``.

    Raises:
        ValueError: if ``train_dataloader`` is empty (the loop could never terminate).
    """
    if len(train_dataloader) == 0:
        raise ValueError("train_dataloader is empty; training could never reach max_gradient_step")

    # Per-token cross-entropy so every term can be AWR-weighted
    # (`reduce=False` is the deprecated spelling of reduction='none').
    loss_fn = nn.CrossEntropyLoss(ignore_index=special_tokens['PADDING_VALUE'], reduction='none')

    # SETTING MODELS
    high_policy = instantiate_class(cfg.high_policy)
    hvf1 = instantiate_class(cfg.hvf)
    hvf2 = instantiate_class(cfg.hvf)

    # NOTE(review): models live on cfg.device while batches go to `device` — assumed identical.
    high_policy.to(cfg.device)
    hvf1.to(cfg.device)
    hvf2.to(cfg.device)

    print(high_policy)
    print(hvf1)

    target_hvf1 = copy.deepcopy(hvf1)
    target_hvf2 = copy.deepcopy(hvf2)

    # SETTING OPTIMIZERS (each optimizer uses the arguments of its own config entry)
    policy_optimizer = get_class(cfg.high_policy_optimizer)(
        high_policy.parameters(), **get_arguments(cfg.high_policy_optimizer)
    )
    vf_optimizer_args = get_arguments(cfg.hvf_optimizer)
    vf1_optimizer = get_class(cfg.hvf_optimizer)(hvf1.parameters(), **vf_optimizer_args)
    vf2_optimizer = get_class(cfg.hvf_optimizer)(hvf2.parameters(), **vf_optimizer_args)

    # Embedding table of the quantizer, indexed by token id.
    codebook = tokenzier._quantizer.quantize._codebook.embed.squeeze(0).to(cfg.device)

    def _soft_update(target, source):
        # target <- polyak_coef * source + (1 - polyak_coef) * target, parameter by parameter.
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

    # Accumulated per pass and logged (in this order) as averages over processed batches.
    stat_names = ('abs_adv_mean', 'adv_mean', 'adv_max', 'adv_min', 'accept_prob',
                  'v1_min', 'v1_max', 'v1_mean', 'v2_min', 'v2_max', 'v2_mean', 'log_prob_loss')

    gradient_step = 0
    epoch = 0
    done = False
    while not done:
        high_policy.train()
        total_vf1_loss = 0
        total_vf2_loss = 0
        policy_losses = []
        stats = dict.fromkeys(stat_names, 0)
        weights = None
        it = 0
        with tqdm(train_dataloader, desc=f'Train {epoch}:') as pbar:
            for batch in pbar:

                # Build batch
                token_seq = batch['token/compressed'].to(device).long()
                # Column of the SOS / EOS token in every row.
                # NOTE(review): assumes exactly one SOS and one EOS per row with >= 1 token between.
                sos_positions = (token_seq == special_tokens['SOS_TOKEN']).nonzero(as_tuple=False)[:, 1].cpu().numpy()
                eos_positions = (token_seq == special_tokens['EOS_TOKEN']).nonzero(as_tuple=False)[:, 1].cpu().numpy()
                indices_1 = np.random.randint(low=sos_positions + 1, high=eos_positions)
                indices_2 = np.random.randint(low=sos_positions + 1, high=eos_positions)
                current_indices = np.minimum(indices_1, indices_2)
                goal_indices = np.maximum(indices_1, indices_2)
                next_indices = np.minimum(current_indices + 1, goal_indices)  # stays at goal if current is goal

                index_current = torch.as_tensor(current_indices, dtype=torch.long, device=token_seq.device)
                index_next = torch.as_tensor(next_indices, dtype=torch.long, device=token_seq.device)
                index_goal = torch.as_tensor(goal_indices, dtype=torch.long, device=token_seq.device)

                current_token = token_seq.gather(1, index_current.unsqueeze(1))
                next_token = token_seq.gather(1, index_next.unsqueeze(1))
                goal_token = token_seq.gather(1, index_goal.unsqueeze(1))

                # Policy inputs are embeddings looked up by token id (not sequence position); the
                # goal embedding comes from the goal token so the policy is not shown its own target.
                # NOTE(review): assumes non-special token ids index the codebook directly — confirm.
                current_representation = codebook[current_token.squeeze(1)].detach()
                goal_representation = codebook[goal_token.squeeze(1)].detach()

                # Random-goal relabelling: with prob. high_p_random_goal take another sample's goal.
                B = goal_token.size(0)
                shuffled_idx = torch.randperm(B)
                keep_own_goal = torch.rand(B) > high_p_random_goal
                idx = torch.where(keep_own_goal, torch.arange(start=0, end=B), shuffled_idx)
                goal_token = goal_token[idx]
                goal_representation = goal_representation[idx]

                rewards = (goal_token == current_token).float() - 1  # r(s,g) = 0 if (s=g) and -1 otherwise
                masks = -rewards  # masks are 0 if terminal, 1 otherwise

                # Value targets from the target networks (no gradient needed).
                with torch.no_grad():
                    next_v1_t, next_v2_t = target_hvf1(next_token, goal_token), target_hvf2(next_token, goal_token)
                    next_v_t = torch.min(next_v1_t, next_v2_t)
                    q_t = high_reward_scale * rewards + masks * high_discount * next_v_t
                    v_t = (target_hvf1(current_token, goal_token) + target_hvf2(current_token, goal_token)) / 2
                    value_adv = q_t - v_t
                    q1_t = high_reward_scale * rewards + masks * high_discount * next_v1_t
                    q2_t = high_reward_scale * rewards + masks * high_discount * next_v2_t

                v1, v2 = hvf1(current_token, goal_token), hvf2(current_token, goal_token)
                vf1_loss = expectile_loss(value_adv, q1_t - v1, high_expectile).mean()
                vf2_loss = expectile_loss(value_adv, q2_t - v2, high_expectile).mean()

                # AWR weights from the online value networks (detached).
                with torch.no_grad():
                    vf_pred = (hvf1(current_token, goal_token) + hvf2(current_token, goal_token)) / 2
                    next_vf_pred = (hvf1(next_token, goal_token) + hvf2(next_token, goal_token)) / 2
                    adv = next_vf_pred - vf_pred
                    weights = torch.clamp(torch.exp(adv * high_beta), max=high_clip_score)[:, 0]

                logits = high_policy(current_representation, goal_representation)
                policy_loss = (weights * loss_fn(logits, next_token.squeeze(1))).mean()

                # Update networks
                if gradient_step % high_v_update_period == 0:
                    vf1_optimizer.zero_grad()
                    vf1_loss.backward()
                    vf1_optimizer.step()

                    vf2_optimizer.zero_grad()
                    vf2_loss.backward()
                    vf2_optimizer.step()

                if gradient_step % high_policy_update_period == 0:
                    policy_optimizer.zero_grad()
                    policy_loss.backward()
                    policy_optimizer.step()

                # Soft updates
                if gradient_step % high_target_update_period == 0:
                    _soft_update(target_hvf1, hvf1)
                    _soft_update(target_hvf2, hvf2)

                # Update metrics for logging
                adv_cpu = adv.detach().cpu()
                v1_cpu = v1.detach().cpu()
                v2_cpu = v2.detach().cpu()
                stats['abs_adv_mean'] += adv_cpu.abs().mean()
                stats['adv_mean'] += adv_cpu.mean()
                stats['adv_max'] += adv_cpu.max()
                stats['adv_min'] += adv_cpu.min()
                stats['accept_prob'] += (adv_cpu >= 0).float().mean()
                stats['v1_min'] += v1_cpu.min()
                stats['v1_max'] += v1_cpu.max()
                stats['v1_mean'] += v1_cpu.mean()
                stats['v2_min'] += v2_cpu.min()
                stats['v2_max'] += v2_cpu.max()
                stats['v2_mean'] += v2_cpu.mean()
                stats['log_prob_loss'] += policy_loss.detach().cpu().mean()

                gradient_step += 1
                it += 1  # number of batches processed in this pass

                total_vf1_loss += vf1_loss.item()
                total_vf2_loss += vf2_loss.item()
                policy_losses.append(policy_loss.item())
                pbar.set_postfix(vf1_loss=total_vf1_loss / it,
                                 vf2_loss=total_vf2_loss / it,
                                 policy_loss=np.mean(policy_losses))

                if gradient_step % save_every == 0:
                    m = copy.deepcopy(high_policy).to('cpu')
                    model_db.push('high_goal_policy', m)

                    m = copy.deepcopy(hvf1).to('cpu')
                    model_db.push('hvf1', m)

                if gradient_step >= max_gradient_step:
                    done = True
                    break

            epoch += 1

        logger.add_scalar('vf1_loss/train', vf1_loss.item(), gradient_step)
        logger.add_scalar('vf2_loss/train', vf2_loss.item(), gradient_step)
        logger.add_scalar('policy_loss/train', np.mean(policy_losses), gradient_step)
        # Average over the batches actually processed (the final pass may stop early).
        for name, total in stats.items():
            logger.add_scalar(name, total / it, gradient_step)
        logger.add_histogram(name='Histogram_awr', values=weights.cpu().detach(), epoch=gradient_step)

    return high_policy, hvf1, hvf2

###############
# transformer #
###############
def train_transformer(transformer, optimizer, epochs, train_dataloader, val_dataloader, special_tokens, model_db, logger, save_every, device):
    """Train ``transformer`` with teacher forcing, validating after every epoch.

    The encoder input is ``batch['goal/token'][:, 0]`` concatenated with the first two
    entries of ``batch['token/compressed']`` (goal token + SOS token + initial token). The
    decoder is fed ``target[:, :-1]`` and scored against ``target[:, 1:]`` with a
    cross-entropy that ignores ``special_tokens['PADDING_VALUE']``. Accuracy is the
    argmax-over-dim-1 hit rate on non-padding targets.

    Per epoch, the mean train/val loss and accuracy are logged. Every ``save_every`` epochs
    a CPU copy is pushed to ``model_db`` under ``"transformer"``.

    Returns:
        The trained ``transformer`` (same object, updated in place).
    """
    padding = special_tokens['PADDING_VALUE']
    loss_fn = nn.CrossEntropyLoss(ignore_index=padding)

    def _forward(batch):
        """Teacher-forced pass on one batch; returns (loss, detached CPU accuracy)."""
        # goal token + sos token + initial token
        prompt = torch.cat((batch['goal/token'][:, 0], batch['token/compressed'][:, :2]), dim=1).to(device).long()
        targets = batch['token/compressed'].to(device).long()
        shifted = targets[:, 1:]

        logits = transformer(prompt, targets[:, :-1])
        loss = loss_fn(logits, shifted)

        valid = shifted != padding
        hits = logits.argmax(dim=1)[valid] == shifted[valid]
        return loss, hits.float().mean().detach().cpu()

    for epoch in range(1, epochs + 1):
        transformer.train()
        losses, accuracies = [], []
        with tqdm(train_dataloader, desc=f'Train {epoch}:') as pbar:
            for batch in pbar:
                loss, accuracy = _forward(batch)

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                accuracies.append(accuracy)
                pbar.set_postfix(loss=np.mean(losses), accuracy=np.mean(accuracies))

        logger.add_scalar("train/loss", np.mean(losses), epoch)
        logger.add_scalar("train/acc", np.mean(accuracies), epoch)

        transformer.eval()
        losses, accuracies = [], []
        with torch.no_grad(), tqdm(val_dataloader, desc=f'Val   {epoch}:') as pbar:
            for batch in pbar:
                loss, accuracy = _forward(batch)
                losses.append(loss.item())
                accuracies.append(accuracy)
                pbar.set_postfix(loss=np.mean(losses), accuracy=np.mean(accuracies))

        logger.add_scalar("val/loss", np.mean(losses), epoch)
        logger.add_scalar("val/acc", np.mean(accuracies), epoch)

        if epoch % save_every == 0:
            model_db.push("transformer", copy.deepcopy(transformer).to("cpu"))

    return transformer

#################
# Goal policies #
#################
def expectile_loss(adv, diff, expectile):
    """Element-wise asymmetric squared loss used for expectile value regression.

    Each squared residual ``diff ** 2`` is weighted by ``expectile`` where ``adv >= 0``
    and by ``1 - expectile`` everywhere else (including where ``adv`` is NaN).

    Args:
        adv: tensor whose sign selects the weight (callers in this module pass a
            detached target-network advantage).
        diff: residual tensor, e.g. ``q_target - v``; broadcast against ``adv``.
        expectile: scalar weight for the non-negative-advantage side.

    Returns:
        Unreduced tensor of weighted squared residuals.
    """
    squared_residual = diff ** 2
    asymmetric_weight = torch.where(adv >= 0, expectile, 1 - expectile)
    return asymmetric_weight * squared_residual

def train_low_hiql_goal(
    model, 
    transformer,
    tokenizer,
    keys_to_tokenize,
    low_level_policy_goals,
    save_dual_policy,
    dual_policy_relabellers,
    optimizer, 
    model_db, 
    dataloader, 
    device, 
    save_every,
    reward_scale,
    discount,
    expectile,
    beta,
    clip_score,
    v_update_period,
    policy_update_period,
    target_update_period,
    max_gradient_step,
    polyak_coef,
    logger, 
    serial_evaluation=False, 
    cfg_evaluation=None, 
    evaluation_logger=None,
    is_godot=False
):
    """Trains pi(a|s,i) with iql.

    Two value networks ``model.vf1``/``model.vf2`` are fitted with an expectile loss against
    targets from ``model.target_vf1``/``model.target_vf2``, which are Polyak-averaged copies.
    ``model.policy_high`` is trained by advantage-weighted log-likelihood, with weights
    ``clamp(exp(beta * A), max=clip_score)``.

    Args:
        model: exposes ``vf1``, ``vf2``, ``target_vf1``, ``target_vf2`` and ``policy_high``.
        transformer, tokenizer, keys_to_tokenize, low_level_policy_goals,
        dual_policy_relabellers: used only to assemble a ``DualPolicy`` snapshot when
            ``save_dual_policy`` is true.
        optimizer: namespace with ``vf1``, ``vf2``, ``policy_low`` and ``policy_high`` optimizers.
        model_db: receives ``'low_goal_policy'`` (and optionally ``"model"``) CPU snapshots.
        dataloader: yields dicts of tensors whose dim 1 holds (current, next) steps, with
            keys ``'obs/complete'``, ``'goal'``, ``'action'`` and ``'reward'``.
        device: device each batch is moved to.
        save_every: snapshot period in gradient steps.
        reward_scale, discount, expectile: IQL value-target parameters.
        beta, clip_score: AWR temperature and clip on the exponentiated advantage.
        v_update_period, policy_update_period, target_update_period: step periods.
        max_gradient_step: training stops once the step counter reaches this value.
        polyak_coef: weight of the online parameters in each target soft update.
        logger: receives per-pass scalars and an AWR-weight histogram.
        serial_evaluation, cfg_evaluation, evaluation_logger: if enabled, run
            ``serial_evaluation_loop`` after every snapshot.
        is_godot: if true, the policy log-likelihood comes from ``model.policy_high.log_prob(batch)``.

    Returns:
        The trained ``model``.

    Raises:
        ValueError: if ``dataloader`` is empty (the loop could never terminate).
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty; training could never reach max_gradient_step")

    transformer.eval()
    model.train()

    target_vf1 = model.target_vf1
    target_vf2 = model.target_vf2

    def _soft_update(target, source):
        # target <- polyak_coef * source + (1 - polyak_coef) * target, parameter by parameter.
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

    # Accumulated per pass and logged (in this order) as averages over processed batches.
    stat_names = ('abs_adv_mean', 'adv_mean', 'adv_max', 'adv_min', 'accept_prob',
                  'v1_min', 'v1_max', 'v1_mean', 'v2_min', 'v2_max', 'v2_mean', 'log_prob_loss')

    gradient_steps = 1
    epoch = 0
    done = False
    while not done:
        epoch += 1

        model.train()
        total_vf1_loss = 0
        total_vf2_loss = 0
        policy_losses = []
        stats = dict.fromkeys(stat_names, 0)
        weights = None
        n_iters = 0
        with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}:") as pbar:
            for it, batch in pbar:
                n_iters = it + 1

                # Form batch: index 0 along dim 1 is the current step, index 1 the next step.
                batch = {k: v.to(device) for k, v in batch.items()}
                batch_next = {k: v[:, 1] for k, v in batch.items()}
                batch = {k: v[:, 0] for k, v in batch.items()}

                # Get data
                observations = batch['obs/complete']
                goals = batch['goal']  # hiql given subgoal ?
                actions = batch['action']
                next_observations = batch_next['obs/complete']
                next_goals = batch_next['goal']
                # NOTE(review): assumes batch['reward'] is a {0, 1} goal-reached indicator.
                success = batch['reward'].unsqueeze(1)
                rewards = success - 1.0  # 0 at the goal, -1 otherwise
                # masks are 0 if terminal, 1 otherwise; deriving them as `1 - rewards` would give
                # 2 off-goal and double the effective discount.
                masks = 1.0 - success

                # vf_loss: value targets from the target networks (no gradient needed)
                with torch.no_grad():
                    next_v1_t, next_v2_t = target_vf1(next_observations, next_goals), target_vf2(next_observations, next_goals)
                    next_v_t = torch.min(next_v1_t, next_v2_t)
                    q_t = reward_scale * rewards + masks * discount * next_v_t
                    v_t = (target_vf1(observations, goals) + target_vf2(observations, goals)) / 2
                    value_adv = q_t - v_t
                    q1_t = reward_scale * rewards + masks * discount * next_v1_t
                    q2_t = reward_scale * rewards + masks * discount * next_v2_t

                v1, v2 = model.vf1(observations, goals), model.vf2(observations, goals)
                vf1_loss = expectile_loss(value_adv, q1_t - v1, expectile).mean()
                vf2_loss = expectile_loss(value_adv, q2_t - v2, expectile).mean()

                # policy loss: AWR weights from the online value networks (detached)
                with torch.no_grad():
                    vf_pred = (model.vf1(observations, goals) + model.vf2(observations, goals)) / 2
                    next_vf_pred = (model.vf1(next_observations, next_goals) + model.vf2(next_observations, next_goals)) / 2
                    adv = next_vf_pred - vf_pred
                    weights = torch.clamp(torch.exp(adv * beta), max=clip_score)[:, 0]

                if is_godot:
                    log_prob = model.policy_high.log_prob(batch)
                else:
                    _, dist = model.policy_high(observations, goals)
                    log_prob = dist.log_prob(actions)
                # NOTE(review): summed over the batch while the value losses are averaged — confirm intended.
                policy_loss = (-log_prob * weights).sum()

                # Update networks
                if gradient_steps % v_update_period == 0:
                    optimizer.vf1.zero_grad(set_to_none=True)
                    vf1_loss.backward()
                    optimizer.vf1.step()

                    optimizer.vf2.zero_grad(set_to_none=True)
                    vf2_loss.backward()
                    optimizer.vf2.step()

                if gradient_steps % policy_update_period == 0:
                    optimizer.policy_low.zero_grad(set_to_none=True)
                    optimizer.policy_high.zero_grad(set_to_none=True)
                    policy_loss.backward()
                    optimizer.policy_low.step()
                    optimizer.policy_high.step()

                # Soft updates
                if gradient_steps % target_update_period == 0:
                    _soft_update(target_vf1, model.vf1)
                    _soft_update(target_vf2, model.vf2)

                # Update metrics for logging
                adv_cpu = adv.detach().cpu()
                v1_cpu = v1.detach().cpu()
                v2_cpu = v2.detach().cpu()
                stats['abs_adv_mean'] += adv_cpu.abs().mean()
                stats['adv_mean'] += adv_cpu.mean()
                stats['adv_max'] += adv_cpu.max()
                stats['adv_min'] += adv_cpu.min()
                stats['accept_prob'] += (adv_cpu >= 0).float().mean()
                stats['v1_min'] += v1_cpu.min()
                stats['v1_max'] += v1_cpu.max()
                stats['v1_mean'] += v1_cpu.mean()
                stats['v2_min'] += v2_cpu.min()
                stats['v2_max'] += v2_cpu.max()
                stats['v2_mean'] += v2_cpu.mean()
                stats['log_prob_loss'] += log_prob.detach().cpu().mean()

                gradient_steps += 1

                total_vf1_loss += vf1_loss.item()
                total_vf2_loss += vf2_loss.item()
                policy_losses.append(policy_loss.item())
                pbar.set_postfix(vf1_loss=total_vf1_loss / n_iters,
                                 vf2_loss=total_vf2_loss / n_iters,
                                 policy_loss=np.mean(policy_losses))

                if gradient_steps % save_every == 0:
                    m = copy.deepcopy(model).to('cpu')
                    model_db.push('low_goal_policy', m)

                    if save_dual_policy:
                        # NOTE(review): the policy trained here is passed as the *subgoal* policy and
                        # `low_level_policy_goals` as the goal policy — confirm these roles.
                        dual_policy = DualPolicy(
                            low_level_policy_subgoal=m,
                            low_level_policy_goal=low_level_policy_goals,
                            high_level_policy=transformer,
                            tokenizer=tokenizer,
                            sos_token=special_tokens['SOS_TOKEN'],
                            eos_token=special_tokens['EOS_TOKEN'],
                            keys_to_tokenize=keys_to_tokenize,
                            frame_relabellers=[instantiate_class(r) for r in (dual_policy_relabellers or [])]
                        )
                        m = copy.deepcopy(dual_policy).cpu()
                        model_db.push("model", m)

                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)

                if gradient_steps >= max_gradient_step:
                    done = True
                    break

        logger.add_scalar('vf1_loss/train', vf1_loss.item(), gradient_steps)
        logger.add_scalar('vf2_loss/train', vf2_loss.item(), gradient_steps)
        logger.add_scalar('policy_loss/train', np.mean(policy_losses), gradient_steps)
        # Average over the batches actually processed (the final pass may stop early).
        for name, total in stats.items():
            logger.add_scalar(name, total / n_iters, gradient_steps)
        logger.add_histogram(name='Histogram_awr', values=weights.cpu().detach(), epoch=gradient_steps)

    return model

####################
# Subgoal policies #
####################
def train_low_hiql_subgoal(
    model, 
    transformer,
    tokenizer,
    use_obs_representation,
    keys_to_tokenize,
    low_level_policy_goals,
    save_dual_policy,
    dual_policy_relabellers,
    optimizer, 
    model_db, 
    dataloader, 
    device, 
    save_every,
    reward_scale,
    discount,
    expectile,
    beta,
    clip_score,
    v_update_period,
    policy_update_period,
    target_update_period,
    max_gradient_step,
    polyak_coef,
    logger, 
    serial_evaluation=False, 
    cfg_evaluation=None, 
    evaluation_logger=None,
    is_godot=False
):
    """
    Trains the subgoal-conditioned policy pi(a|s,i) with IQL-style updates.

    Per batch:
      * ``model.vf1`` / ``model.vf2`` are fitted with ``expectile_loss`` against
        TD targets built from ``model.target_vf1`` / ``model.target_vf2``
        (conditioned on ``batch['value/goal']``).
      * ``model.policy_high`` is trained by advantage-weighted regression. The
        advantage is the change in the averaged online value from the current to
        the next observation, both conditioned on the tokenized subgoal
        (``'token/next'`` or ``'token/next/obs_representation'``). Weights are
        ``clamp(exp(beta * adv), max=clip_score)``.
      * The target value networks are Polyak-averaged towards the online ones
        every ``target_update_period`` steps.

    Every ``save_every`` gradient steps, a CPU deep copy of ``model`` is pushed
    to ``model_db`` under ``'low_subgoal_policy'``. If ``save_dual_policy`` is
    set, a ``DualPolicy`` wrapping it is also pushed under ``"model"``. A serial
    evaluation runs afterwards if ``serial_evaluation`` is set. Training stops
    once ``max_gradient_step`` gradient steps have been taken; ``model`` is
    updated in place.

    Raises:
        ValueError: if ``dataloader`` has length 0. Training could then never
            reach ``max_gradient_step``.
    """
    transformer.eval()
    model.train()

    target_vf1 = model.target_vf1
    target_vf2 = model.target_vf2

    n_batches_per_epoch = len(dataloader)
    if n_batches_per_epoch == 0:
        raise ValueError("train_low_hiql_subgoal: dataloader yields no batches")

    gradient_steps = 1
    epoch = 0
    done = False
    while not done:
        epoch += 1

        # Re-enable training mode each epoch. NOTE(review): presumably because
        # evaluation/saving may switch modules to eval mode - confirm.
        model.train()
        total_vf1_loss = 0
        total_vf2_loss = 0
        policy_losses = []

        # Running sums of per-batch statistics. They are averaged over the batches
        # actually processed (the last epoch may stop early at max_gradient_step).
        # Insertion order fixes the logging order.
        stats = {
            'abs_adv_mean': 0.0,
            'adv_mean': 0.0,
            'adv_max': 0.0,
            'adv_min': 0.0,
            'accept_prob': 0.0,
            'v1_min': 0.0,
            'v1_max': 0.0,
            'v1_mean': 0.0,
            'v2_min': 0.0,
            'v2_max': 0.0,
            'v2_mean': 0.0,
            'log_prob_loss': 0.0,
        }
        n_batches = 0
        weights = None
        with tqdm(enumerate(dataloader), total=n_batches_per_epoch, desc=f"Epoch {epoch}:") as pbar:
            for it, batch in pbar:
                n_batches = it + 1

                # Form batch: dim 1 of every entry stacks the current (index 0)
                # and next (index 1) step.
                batch = {k: v.to(device) for k, v in batch.items()}
                batch_next = {k: v[:, 1] for k, v in batch.items()}
                batch = {k: v[:, 0] for k, v in batch.items()}

                # Get data
                observations = batch['obs/partial']
                value_goals = batch['value/goal']  # separate because of relabelling
                goals = batch['token/next/obs_representation'] if use_obs_representation else batch['token/next']
                actions = batch['action']
                next_observations = batch_next['obs/partial']
                next_value_goals = batch_next['value/goal']
                next_goals = batch_next['token/next/obs_representation'] if use_obs_representation else batch_next['token/next']
                rewards = batch['reward'].unsqueeze(1) - 1.0
                # masks are 0 if terminal, 1 otherwise.
                # NOTE(review): this assumes raw rewards are in {0, 1} - confirm.
                masks = - rewards

                # Value-function targets. They only involve the target networks
                # and are detached in use, so no autograd graph is needed.
                with torch.no_grad():
                    next_v1_t = target_vf1(next_observations, next_value_goals)
                    next_v2_t = target_vf2(next_observations, next_value_goals)
                    next_v_t = torch.min(next_v1_t, next_v2_t)

                    q_t = reward_scale * rewards + masks * discount * next_v_t
                    v_t = (target_vf1(observations, value_goals) + target_vf2(observations, value_goals)) / 2
                    adv = q_t - v_t

                    q1_t = reward_scale * rewards + masks * discount * next_v1_t
                    q2_t = reward_scale * rewards + masks * discount * next_v2_t

                v1, v2 = model.vf1(observations, value_goals), model.vf2(observations, value_goals)
                vf1_loss = expectile_loss(adv, q1_t - v1, expectile).mean()
                vf2_loss = expectile_loss(adv, q2_t - v2, expectile).mean()

                # Policy loss (AWR). The weights are constants w.r.t. the policy
                # update, so they are computed without autograd.
                with torch.no_grad():
                    vf_pred = (model.vf1(observations, goals) + model.vf2(observations, goals)) / 2
                    next_vf1 = model.vf1(next_observations, next_goals)
                    next_vf2 = model.vf2(next_observations, next_goals)
                    next_vf_pred = (next_vf1 + next_vf2) / 2

                    adv = next_vf_pred - vf_pred
                    exp_adv = torch.clamp(torch.exp(adv * beta), max=clip_score)
                    weights = exp_adv[:, 0]

                if is_godot:
                    log_prob = model.policy_high.log_prob(batch)
                else:
                    _, dist = model.policy_high(observations, goals)
                    log_prob = dist.log_prob(actions)
                # NOTE(review): summed (not averaged) over the batch, so the
                # effective step size scales with batch size - confirm intended.
                policy_loss = (-log_prob * weights).sum()

                # Update networks
                if gradient_steps % v_update_period == 0:
                    optimizer.vf1.zero_grad(set_to_none=True)
                    vf1_loss.backward()
                    optimizer.vf1.step()

                    optimizer.vf2.zero_grad(set_to_none=True)
                    vf2_loss.backward()
                    optimizer.vf2.step()

                if gradient_steps % policy_update_period == 0:
                    # NOTE(review): only model.policy_high appears in the loss;
                    # stepping policy_low is a no-op unless it shares parameters.
                    optimizer.policy_low.zero_grad(set_to_none=True)
                    optimizer.policy_high.zero_grad(set_to_none=True)
                    policy_loss.backward()
                    optimizer.policy_low.step()
                    optimizer.policy_high.step()

                # Soft (Polyak) updates of the target value networks
                if gradient_steps % target_update_period == 0:
                    for target_param, param in zip(target_vf1.parameters(), model.vf1.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                    for target_param, param in zip(target_vf2.parameters(), model.vf2.parameters()):
                        target_param.data.copy_(polyak_coef * param.data + (1.0 - polyak_coef) * target_param.data)

                # Update metrics for logging (one device->host copy per tensor)
                adv_cpu = adv.detach().cpu()
                v1_cpu = v1.detach().cpu()
                v2_cpu = v2.detach().cpu()
                stats['abs_adv_mean'] += adv_cpu.abs().mean()
                stats['adv_mean'] += adv_cpu.mean()
                stats['adv_max'] += adv_cpu.max()
                stats['adv_min'] += adv_cpu.min()
                stats['accept_prob'] += (adv_cpu >= 0).float().mean()
                stats['v1_min'] += v1_cpu.min()
                stats['v1_max'] += v1_cpu.max()
                stats['v1_mean'] += v1_cpu.mean()
                stats['v2_min'] += v2_cpu.min()
                stats['v2_max'] += v2_cpu.max()
                stats['v2_mean'] += v2_cpu.mean()
                stats['log_prob_loss'] += log_prob.detach().cpu().mean()

                gradient_steps += 1

                total_vf1_loss += vf1_loss.item()
                total_vf2_loss += vf2_loss.item()
                policy_losses.append(policy_loss.item())
                pbar.set_postfix(vf1_loss=total_vf1_loss/(it+1), 
                                 vf2_loss=total_vf2_loss/(it+1),
                                 policy_loss=np.mean(policy_losses))

                if gradient_steps % save_every == 0:
                    m = copy.deepcopy(model).to('cpu')
                    model_db.push('low_subgoal_policy', m)

                    if save_dual_policy:
                        dual_policy = DualPolicy(
                            low_level_policy_subgoal=m, 
                            low_level_policy_goal=low_level_policy_goals, 
                            high_level_policy=transformer, 
                            tokenizer=tokenizer, 
                            sos_token=special_tokens['SOS_TOKEN'], 
                            eos_token=special_tokens['EOS_TOKEN'], 
                            keys_to_tokenize=keys_to_tokenize, 
                            frame_relabellers=[instantiate_class(r) for r in (dual_policy_relabellers or [])]
                        )
                        m = copy.deepcopy(dual_policy).cpu()
                        model_db.push("model", m)

                    if serial_evaluation:
                        serial_evaluation_loop(model_db, cfg_evaluation, evaluation_logger, only_last=True)

                if gradient_steps >= max_gradient_step:
                    done = True
                    break

        # The value losses are the last batch's values; the policy loss and all
        # other statistics are averages over this epoch's processed batches.
        logger.add_scalar('vf1_loss/train', vf1_loss.item(), gradient_steps)
        logger.add_scalar('vf2_loss/train', vf2_loss.item(), gradient_steps)
        logger.add_scalar('policy_loss/train', np.mean(policy_losses), gradient_steps)
        for stat_name, stat_total in stats.items():
            logger.add_scalar(stat_name, stat_total / n_batches, gradient_steps)
        logger.add_histogram(name='Histogram_awr', values=weights.cpu().detach(), epoch=gradient_steps)