# Must be run with OMP_NUM_THREADS=1

import logging
import os
import pprint
import threading
import time
import timeit
import traceback
from collections import Counter, defaultdict

import coolname
import hydra
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from tokenizers import BertWordPieceTokenizer
from torch import multiprocessing as mp
from torch import nn
from torch.nn import functional as F

from . import buffers as B
from . import envs, losses, models, optimizers, utils
from .torchbeast.core import prof

from einops import rearrange


# ---- Module-level state mutated by the learner functions below ----
# Goal-episode samples accumulated by learn_generator(); once it reports
# ready(FLAGS.generator_batch_size) a batch is drawn for learn_generator_policy().
generator_batcher = utils.Batcher()
# Number of consecutive generator updates whose mean reward reached
# FLAGS.generator_threshold; drives the difficulty-target curriculum.
generator_count = 0
# NOTE(review): message_counts, goal_counts and bert_tokenizer are not
# referenced in the code visible here — confirm whether they are still used.
message_counts = None
# Optimizer step counters reported in the training stats.
optimizer_steps = 0
discriminator_optimizer_steps = 0
generator_optimizer_steps = 0
goal_counts = None
bert_tokenizer = None


# Named locks created lazily on first lookup (e.g. "get_batch", "learn",
# "batch_and_learn").
LOCKS = defaultdict(threading.Lock)


def check_goal_completion(env_output, initial_env_state, action, goal, raw_goal):
    """Decide whether the goal an actor is currently pursuing has been reached.

    The frames of the episode's initial state and of the latest environment
    output are flattened over dims 2..3 and handed, together with the
    environment output, action and goal, to ``_check_goal_completion_babyai``.
    ``raw_goal`` is accepted for call-site compatibility but not used.
    """
    # Initial state first, then the current output (old frame, new frame).
    old_frame, new_frame = [
        torch.flatten(state["frame"], 2, 3)
        for state in (initial_env_state, env_output)
    ]
    return _check_goal_completion_babyai(
        old_frame, new_frame, env_output, action, goal
    )


def _check_goal_completion_babyai(old_frame, new_frame, env_output, action, goal):
    return env_output["subgoal_done"][goal.squeeze()]


@torch.no_grad()
def act(
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    generator_model,
    buffers: B.Buffers,
    initial_agent_state_buffers,
    all_proposed_goals,
    achieved_proposed_goals,
    achieved_proposed_goals_steps,
    generator_current_target,
):
    """Run one IMPALA actor: roll out the policy and fill rollout buffers.

    The actor repeatedly takes a free buffer index from ``free_queue``, writes
    ``FLAGS.unroll_length + 1`` timesteps (indices ``0..unroll_length``) of
    environment, policy and generator outputs into ``buffers`` at that index,
    stores the agent's recurrent state from the start of the unroll in
    ``initial_agent_state_buffers[index]``, and then puts the index on
    ``full_queue``. A new goal is requested from ``generator_model`` whenever
    the current goal is reached or the episode ends.

    The whole function runs under ``torch.no_grad()`` via the decorator.

    Args:
        actor_index: Identifier of this actor (used for seeding and logging).
        free_queue: Buffer indices available for writing; a ``None`` entry
            makes the actor exit its loop.
        full_queue: Receives the indices of filled buffers.
        model: Acting policy.
        generator_model: Goal generator (planner) that proposes goals.
        buffers: Rollout storage written via ``buffers.update``.
        initial_agent_state_buffers: Per-index storage for the agent state at
            the start of each unroll.
        all_proposed_goals: Mapping goal id -> times proposed.
        achieved_proposed_goals: Mapping goal id -> times reached.
        achieved_proposed_goals_steps: Mapping goal id -> summed
            ``env.extrinsic_episode_step`` values at which it was reached.
        generator_current_target: Shared difficulty target (not used here).

    Raises:
        Exception: any error raised while acting is logged and re-raised.
    """
    try:
        logging.info(f"Actor {actor_index} started.")
        timings = prof.Timings()
        gym_env = envs.create_env(FLAGS)
        # Mix the actor index with OS entropy so actors get distinct seeds.
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)

        env = envs.babyai.Observation_WrapperSetup(gym_env)

        env_output = env.reset()
        initial_env_state = env.get_initial_env_state(env_output)
        initial_env_output = env.get_initial_env_output(env_output)

        agent_state = model.initial_state(batch_size=1)
        generator_output = generator_model(env_output)
        goal = generator_output["goal"]
        raw_goal = generator_output["raw_goal"]

        # NOTE(review): `+=` on a Manager dict proxy is a separate get and set,
        # so increments from concurrent actor processes can be lost.
        all_proposed_goals[goal.item()] += 1

        agent_output, _ = model(env_output, agent_state, raw_goal)

        reached_condition = check_goal_completion(
            env_output, initial_env_state, agent_output["action"], goal, raw_goal
        )
        intrinsic_done = reached_condition

        while True:
            timings.reset()
            timings.time("get_target")
            index = free_queue.get()
            timings.time("get_queue")
            if index is None:
                print(f"Got None index in worker process {actor_index}, exiting")
                break

            # Step 0 of the unroll repeats the last state of the previous
            # unroll (or the initial reset).
            initial_env_state_with_head = utils.map_dict(
                lambda k: f"initial_{k}", initial_env_state, map_keys=True
            )
            buffers.update(
                index,
                0,
                **env_output,
                **agent_output,
                **generator_output,
                **initial_env_state_with_head,
                intrinsic_done=intrinsic_done,
                reached=reached_condition,
            )
            # Recurrent state at the start of this unroll, read by the learner.
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor
            timings.time("write")

            for t in range(FLAGS.unroll_length):
                timings.reset()

                agent_output, agent_state = model(env_output, agent_state, raw_goal)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                reached_condition = check_goal_completion(
                    env_output,
                    initial_env_state,
                    agent_output["action"],
                    goal,
                    raw_goal,
                )

                env_done = env_output["done"][0] == 1

                # The current goal ends when it is reached or the episode ends.
                intrinsic_done = reached_condition or env_done

                initial_env_state_with_head = utils.map_dict(
                    lambda k: f"initial_{k}", initial_env_state, map_keys=True
                )
                buffers.update(
                    index,
                    t + 1,
                    **env_output,
                    **agent_output,
                    **generator_output,
                    **initial_env_state_with_head,
                    intrinsic_done=intrinsic_done,
                    reached=reached_condition,
                )

                if intrinsic_done:
                    if reached_condition:
                        achieved_proposed_goals[goal.item()] += 1
                        # NOTE(review): records env.extrinsic_episode_step but
                        # resets env.intrinsic_episode_step; confirm which
                        # counter the "achieved_steps" statistic should use.
                        achieved_proposed_goals_steps[goal.item()] += int(
                            env.extrinsic_episode_step
                        )
                        env.intrinsic_episode_step = 0
                    if env_done:
                        env_output = env.reset()
                        initial_env_state = env.get_initial_env_state(env_output)
                        initial_env_output = env.get_initial_env_output(env_output)

                    # Propose the next goal from the episode's initial output
                    # (gradients are already disabled by the decorator).
                    generator_output = generator_model(initial_env_output)
                    goal = generator_output["goal"]
                    raw_goal = generator_output["raw_goal"]

                    all_proposed_goals[goal.item()] += 1

                timings.time("write")
            full_queue.put(index)

    except KeyboardInterrupt:
        print(f"Caught KeyboardInterrupt in worker process {actor_index}")
    except Exception:
        logging.info(f"Exception in worker process {actor_index}")
        logging.info(traceback.format_exc())
        raise


def get_batch(
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers: B.Buffers,
    initial_agent_state_buffers,
    timings,
):
    """Assemble one learner batch from ``FLAGS.batch_size`` filled buffers.

    Buffer indices are dequeued from ``full_queue`` while holding
    ``LOCKS["get_batch"]``, so each learner thread takes its indices as one
    group. The rollouts and initial agent states for those indices are
    gathered, the indices are returned to ``free_queue``, and the agent state
    is moved to ``FLAGS.device``.

    Returns:
        ``(batch, initial_agent_state)``. ``initial_agent_state`` is a tuple
        in which each state tensor is concatenated across the selected
        buffers along dim 1.
    """
    with LOCKS["get_batch"]:
        timings.time("lock")
        indices = [full_queue.get() for _ in range(FLAGS.batch_size)]
        timings.time("dequeue")
    # NOTE(review): assumes buffers.get_batch copies out of shared storage
    # (the indices are released below) — confirm.
    batch = buffers.get_batch(indices, device=FLAGS.device)
    # Materialize the concatenation *before* releasing the indices. A lazy
    # generator here would read the shared state buffers only after actors
    # may already have started overwriting them.
    initial_agent_state = tuple(
        torch.cat(ts, dim=1)
        for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
    )
    timings.time("batch")
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")
    initial_agent_state = tuple(
        t.to(device=FLAGS.device, non_blocking=True) for t in initial_agent_state
    )
    timings.time("device")

    return batch, initial_agent_state


def learn_generator_policy(
    generator_model,
    generator_optimizer,
    generator_scheduler,
    generator_batch,
    generator_current_target,
    max_steps,
    stats,
):
    """Take one policy-gradient step on the goal generator (planner).

    Reward for each proposed goal:
    - Base reward is ``FLAGS.generator_reward_negative``.
    - If the goal was reached, add ``1`` when the goal took at least the
      current difficulty target of steps, or ``FLAGS.easy_goal_reward``
      otherwise.
    - Proposing goal id ``0`` costs an extra ``1``.

    Curriculum: while the mean reward stays at or above
    ``FLAGS.generator_threshold`` for ``FLAGS.generator_counts`` consecutive
    updates, the shared ``generator_current_target`` is raised by one. This
    happens only as long as the target is at most
    ``max_steps * FLAGS.generator_maximum_rate``.

    Losses and diagnostics are written into ``stats`` in place; returns None.
    """
    global generator_count, generator_optimizer_steps

    target = generator_current_target.value

    outputs = generator_model(generator_batch)
    bootstrap_value = outputs["generator_baseline"][-1]

    episode_steps = generator_batch["intrinsic_episode_step"]
    reached = generator_batch["reached"]

    # ---- Goal reward ----
    rewards = FLAGS.generator_reward_negative * torch.ones(
        episode_steps.shape
    ).to(device=FLAGS.device)
    hard_enough = (episode_steps >= target).float()
    rewards += hard_enough * reached
    rewards += ((1 - hard_enough) * reached) * FLAGS.easy_goal_reward

    goals = generator_batch["goal"]
    rewards -= (goals == 0).float()

    # ---- Difficulty curriculum ----
    if torch.mean(rewards).item() >= FLAGS.generator_threshold:
        generator_count += 1
    else:
        generator_count = 0
    should_raise_target = (
        generator_count >= FLAGS.generator_counts
        and target <= max_steps * FLAGS.generator_maximum_rate
    )
    if should_raise_target:
        target += 1
        generator_current_target.value = target
        generator_count = 0

    clipped = (
        torch.clamp(rewards, -2, 2)
        if FLAGS.reward_clipping == "abs_one"
        else rewards
    )

    if not FLAGS.no_extrinsic_rewards:
        extrinsic = generator_batch["reward"]
        # Any positive extrinsic reward overrides the goal reward with 1.0.
        clipped = 1.0 * (extrinsic > 0).float() + clipped * (extrinsic <= 0).float()
        if FLAGS.combine_rewards:
            clipped += ((reached > 0) & (extrinsic > 0)).float()

    gg_loss, baseline_loss = losses.compute_actor_losses(
        behavior_policy_logits=generator_batch["generator_logits"],
        target_policy_logits=outputs["generator_logits"],
        actions=goals,
        discounts=torch.zeros_like(episode_steps),
        rewards=clipped,
        values=outputs["generator_baseline"],
        bootstrap_value=bootstrap_value,
        baseline_cost=FLAGS.baseline_cost,
    )
    entropy_loss = FLAGS.generator_entropy_cost * losses.compute_entropy_loss(
        outputs["generator_logits"]
    )

    total_loss = gg_loss + entropy_loss + baseline_loss

    # Logged for diagnostics only; not part of the generator loss.
    intrinsic_gen = reached * (1 - 0.9 * (episode_steps.float() / max_steps))

    stats.update(
        {
            "reached_goal": reached.float().mean().item(),
            "gen_rewards": torch.mean(clipped).item(),
            "gg_loss": gg_loss.item(),
            "generator_baseline_loss": baseline_loss.item(),
            "generator_entropy_loss": entropy_loss.item(),
            "generator_intrinsic_rewards": intrinsic_gen.mean().item(),
            "mean_intrinsic_episode_steps": torch.mean(
                episode_steps.float()
            ).item(),
            "ex_reward": torch.mean(generator_batch["reward"]).item(),
            "generator_current_target": target,
        }
    )

    generator_optimizer.zero_grad()
    total_loss.backward()
    nn.utils.clip_grad_norm_(generator_model.parameters(), 40.0)
    generator_optimizer.step()
    generator_scheduler.step()

    generator_optimizer_steps += 1
    stats["generator_optimizer_steps"] = generator_optimizer_steps
    stats["generator_lr"] = generator_optimizer.param_groups[0]["lr"]
    stats["generator_lr"] = generator_optimizer.param_groups[0]["lr"]



@utils.require_lock(LOCKS, "learn")
def learn(
    actor_model,
    model,
    actor_generator_model,
    generator_model,
    batch,
    initial_agent_state,
    optimizer,
    discriminator_optimizer,
    generator_optimizer,
    scheduler,
    discriminator_scheduler,
    generator_scheduler,
    generator_current_target,
    max_steps=100.0,
):
    """Run one learner update and return a dict of training statistics.

    Steps:
    1. Build intrinsic rewards: goal-reaching bonus, optional naive message
       bonus, optional mutual-information bonus.
    2. Compute IMPALA actor/critic losses for the extrinsic head and/or the
       intrinsic head.
    3. Step ``optimizer`` and ``scheduler``, then copy the learner weights
       into ``actor_model``.
    4. If enabled by ``FLAGS``, train the discriminator and the goal
       generator.

    Args:
        actor_model: Model used by actors; receives the updated weights.
        model: Learner model being optimized.
        actor_generator_model: Actors' generator copy (for learn_generator).
        generator_model: Learner copy of the generator.
        batch: Dict of tensors whose leading time dimension holds the
            ``unroll_length + 1`` steps written by ``act``.
        initial_agent_state: Recurrent state at the start of the unroll.
        optimizer, scheduler: Policy optimizer and LR scheduler.
        discriminator_optimizer, discriminator_scheduler: Used when
            ``FLAGS.mutual_information`` is set.
        generator_optimizer, generator_scheduler: Used when
            ``FLAGS.generator`` is set.
        generator_current_target: Shared difficulty target of the generator.
        max_steps: Episode step limit used to normalize step counts.

    Returns:
        dict: Scalar statistics. ``pg_loss``, ``baseline_loss`` and
        ``entropy_loss`` are ``None`` when the extrinsic head is not trained
        (``FLAGS.no_extrinsic_rewards``).

    Raises:
        ValueError: if ``FLAGS.reward_clipping`` is neither ``"abs_one"``
            nor ``"none"``.
    """
    global optimizer_steps
    stats = {}

    subgoal_done = batch["subgoal_done"][1:].to(device=FLAGS.device)
    subgoal_achievable = batch["subgoal_achievable"][1:].to(device=FLAGS.device)
    goal = batch["goal"][1:].to(device=FLAGS.device)
    reached = batch["reached"][1:].to(device=FLAGS.device)
    intrinsic_done = batch["intrinsic_done"][1:].to(device=FLAGS.device)

    # Fraction of steps whose current goal was marked achievable.
    goal_was_achievable = subgoal_achievable.gather(-1, goal.unsqueeze(-1)).squeeze(-1)
    goal_was_achievable = goal_was_achievable.float().mean().item()

    # ==== INTRINSIC REWARDS ====
    if FLAGS.generator:
        # NOTE(review): this is coef*reached*(coef - 0.9*step/max_steps). It
        # equals reached*(1 - 0.9*step/max_steps) (the formula logged in
        # learn_generator_policy) only when intrinsic_reward_coef == 1 — confirm.
        intrinsic_rewards = FLAGS.intrinsic_reward_coef * reached.float()
        intrinsic_rewards = intrinsic_rewards * (
            intrinsic_rewards
            - 0.9 * (batch["intrinsic_episode_step"][1:].float() / max_steps)
        )
    else:
        intrinsic_rewards = torch.zeros_like(reached, dtype=torch.float32)

    if FLAGS.naive_message_reward > 0:
        if FLAGS.is_babyai:
            encountered_message = batch["subgoal_done"].any(-1).float()
        else:
            # A message counts as encountered unless its first two tokens are
            # 101/102 (presumably BERT [CLS]/[SEP], i.e. an empty message).
            encountered_message = (
                ~((batch["message"][..., 0] == 101) & (batch["message"][..., 1] == 102))
            ).float()

        if FLAGS.naive_message_reward_format == "learn":
            # Potential-based shaping: gamma * phi(s_{t+1}) - phi(s_t).
            learn_reward = (
                FLAGS.discounting * encountered_message[1:]
            ) - encountered_message[:-1]
            intrinsic_rewards += learn_reward * FLAGS.naive_message_reward
        else:
            intrinsic_rewards += encountered_message[1:] * FLAGS.naive_message_reward

    train_extrinsic = not FLAGS.no_extrinsic_rewards
    train_intrinsic = FLAGS.int.twoheaded or FLAGS.no_extrinsic_rewards

    learner_outputs, _ = model(
        batch, initial_agent_state, batch["raw_goal"].squeeze(-1)
    )
    # Bootstrap values must be V(s_T), the value of the observation after the
    # last action: take them *before* the [:-1] slice below.
    bootstrap_value = learner_outputs["baseline"][-1] if train_extrinsic else None
    int_bootstrap_value = (
        learner_outputs["int_baseline"][-1] if train_intrinsic else None
    )
    # Realign so that batch[t] and learner_outputs[t] describe the same step.
    batch = utils.map_dict(lambda t: t[1:], batch)
    learner_outputs = utils.map_dict(lambda t: t[:-1], learner_outputs)
    rewards = batch["reward"]

    if FLAGS.mutual_information:
        mutual_information_reward = get_mutual_information(model, batch, stats)
        intrinsic_rewards += FLAGS.mutual_information_rate * mutual_information_reward

    if not FLAGS.int.twoheaded:
        total_rewards = rewards + intrinsic_rewards
    else:
        total_rewards = rewards

    if FLAGS.reward_clipping == "abs_one":
        clipped_rewards = torch.clamp(total_rewards, -1, 1)
    elif FLAGS.reward_clipping == "none":
        clipped_rewards = total_rewards
    else:
        raise ValueError(f"Unknown reward_clipping: {FLAGS.reward_clipping!r}")

    discounts = (~batch["done"]).float() * FLAGS.discounting
    # +1 on every step with positive extrinsic reward. Out-of-place on
    # purpose: clipped_rewards may alias batch["reward"] (reward_clipping ==
    # "none" with a two-headed critic), and an in-place add would corrupt the
    # rewards reported below and handed to the generator.
    clipped_rewards = clipped_rewards + 1.0 * (rewards > 0.0).float()

    total_loss = 0
    pg_loss = baseline_loss = entropy_loss = None

    if train_extrinsic:
        pg_loss, baseline_loss = losses.compute_actor_losses(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
            baseline_cost=FLAGS.baseline_cost,
        )
        entropy_loss = FLAGS.entropy_cost * losses.compute_entropy_loss(
            learner_outputs["policy_logits"]
        )

        total_loss += pg_loss + baseline_loss + entropy_loss

    # ==== INTRINSIC LOSS ====
    if train_intrinsic:
        int_pg_loss, int_baseline_loss = losses.compute_actor_losses(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,
            rewards=intrinsic_rewards,
            values=learner_outputs["int_baseline"],
            bootstrap_value=int_bootstrap_value,
            baseline_cost=FLAGS.int.baseline_cost,
        )
        stats.update(
            {
                "int_pg_loss": int_pg_loss.item(),
                "int_baseline_loss": int_baseline_loss.item(),
            }
        )
        total_loss += int_pg_loss + int_baseline_loss

    optimizer.zero_grad()
    total_loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 40.0)
    optimizer.step()
    scheduler.step()
    actor_model.load_state_dict(model.state_dict())

    # ==== LOG STATS ====
    optimizer_steps += 1
    episode_returns = batch["episode_return"][batch["done"]]
    episode_reward = batch["reward"][batch["done"]]
    goal_reached_rate = reached[intrinsic_done].float().mean().nan_to_num(0).item()

    stats.update(
        {
            "mean_episode_return": episode_returns.mean().nan_to_num(0).item(),
            "mean_episode_final_reward": episode_reward.mean().nan_to_num(0).item(),
            "intrinsic_rewards": intrinsic_rewards.mean().item(),
            "total_loss": total_loss.item(),
            "pg_loss": None if pg_loss is None else pg_loss.item(),
            "baseline_loss": None if baseline_loss is None else baseline_loss.item(),
            "entropy_loss": None if entropy_loss is None else entropy_loss.item(),
            "mean_subgoal_done": subgoal_done.float().mean().item(),
            "goal_achievable": goal_was_achievable,
            "goal_reached_rate": goal_reached_rate,
            "optimizer_steps": optimizer_steps,
            "lr": optimizer.param_groups[0]["lr"],
        }
    )

    # ==== OPTIMIZE DISCRIMINATOR ====
    if FLAGS.mutual_information:
        learn_discriminator(
            model,
            discriminator_optimizer,
            discriminator_scheduler,
            batch,
            stats,
        )

    # ==== OPTIMIZE GENERATOR (PLANNER) ====
    if FLAGS.generator:
        learn_generator(
            generator_model,
            actor_generator_model,
            generator_optimizer,
            generator_scheduler,
            batch,
            generator_current_target,
            max_steps,
            stats,
        )

    return stats


def get_mutual_information(model, batch, stats):
    """Per-step mutual-information reward for the actor, computed without grads.

    For each step, the result is ``log q(g) - log p(g)``, where:
    - ``g`` is the goal in ``batch["goal"]``;
    - ``q`` is the softmax of ``model.forward_discriminator(batch)["logits"]``;
    - ``p`` is the softmax of the stored ``batch["generator_logits"]``.

    The result is reshaped to the ``(T, B)`` shape of ``batch["goal"]``. Its
    mean is written to ``stats["mutual_information"]``.
    """
    with torch.no_grad():
        n_steps, n_batch = batch["goal"].shape

        disc_logits = model.forward_discriminator(batch)["logits"]
        targets = rearrange(batch["goal"], "T B -> (T B) 1")
        log_q = F.log_softmax(disc_logits, dim=-1).gather(1, targets)

        gen_logits = rearrange(batch["generator_logits"], "T B N -> (T B) N")
        log_p = F.log_softmax(gen_logits, dim=-1).gather(1, targets)

        per_step = (log_q - log_p).view(n_steps, n_batch)
        stats["mutual_information"] = per_step.float().mean().item()
        return per_step
    
    
def learn_discriminator(
    model, discriminator_optimizer, discriminator_scheduler, batch, stats
):
    """One cross-entropy step training the discriminator to predict the goal.

    The logits of ``model.forward_discriminator(batch)`` are trained against
    the flattened ``batch["goal"]``. Accuracy of the discriminator's
    ``"preds"``, the loss, the step counter and the current learning rate are
    written into ``stats``.
    """
    global discriminator_optimizer_steps

    outputs = model.forward_discriminator(batch)
    targets = rearrange(batch["goal"], "T B -> (T B)")
    loss = nn.CrossEntropyLoss()(outputs["logits"], targets)

    accuracy = (outputs["preds"] == targets).float().mean()
    stats["discriminator_acc"] = accuracy.item()
    stats["discriminator_loss"] = loss.item()

    discriminator_optimizer.zero_grad()
    loss.backward()
    # NOTE(review): clips over *all* of model's parameters. If
    # discriminator_optimizer does not cover (and zero) every parameter, stale
    # gradients from the policy update also enter this norm — confirm.
    nn.utils.clip_grad_norm_(model.parameters(), 40.0)
    discriminator_optimizer.step()
    discriminator_scheduler.step()

    discriminator_optimizer_steps += 1
    stats["discriminator_optimizer_steps"] = discriminator_optimizer_steps
    stats["discriminator_lr"] = discriminator_optimizer.param_groups[0]["lr"]
        

def learn_generator(
    generator_model,
    actor_generator_model,
    generator_optimizer,
    generator_scheduler,
    batch,
    generator_current_target,
    max_steps,
    stats,
):
    """Queue goal-episode samples and train the generator when enough exist.

    Steps:
    1. Update the generator's goal mask from ``batch`` and record its size.
    2. Append the selected fields to the module-level ``generator_batcher``,
       with ``batch["intrinsic_done"]`` as the mask. ``initial_*`` keys are
       renamed to the text after ``"initial_"``.
    3. Once the batcher is ready for ``FLAGS.generator_batch_size`` samples,
       run ``learn_generator_policy`` on such a batch.
    4. Always copy the learner generator's weights into
       ``actor_generator_model``.
    """
    initial_keys = [key for key in batch.keys() if key.startswith("initial")]

    generator_model.update_goals(batch)
    stats["generator_goals_seen"] = generator_model.goals_mask.sum().item()

    wanted_keys = (
        "goal",
        "subgoal_achievable",
        "subgoal_done",
        "intrinsic_episode_step",
        "generator_logits",
        "reached",
        "reward",
        "carried_obj",
        "carried_col",
        *initial_keys,
    )
    selected = {key: batch[key] for key in wanted_keys}
    selected = utils.map_dict(
        lambda key: key.split("initial_")[-1], selected, map_keys=True
    )
    generator_batcher.append(
        selected, mask=batch["intrinsic_done"], device=FLAGS.device
    )

    if generator_batcher.ready(FLAGS.generator_batch_size):
        sampled = generator_batcher.get_batch(FLAGS.generator_batch_size)
        # Prepend a length-1 leading (time) dimension to every tensor.
        sampled = utils.map_dict(lambda x: x.unsqueeze(0), sampled)
        learn_generator_policy(
            generator_model,
            generator_optimizer,
            generator_scheduler,
            sampled,
            generator_current_target,
            max_steps,
            stats,
        )

    actor_generator_model.load_state_dict(generator_model.state_dict())

from prettytable import PrettyTable

def count_parameters(model):
    """Tabulate the trainable parameters of ``model``.

    Returns:
        ``(table, total)``. ``table`` is a PrettyTable with one
        ``[name, numel]`` row per parameter that has ``requires_grad`` set.
        ``total`` is the sum of those element counts (0 if there are none).
    """
    trainable = [
        (name, param.numel())
        for name, param in model.named_parameters()
        if param.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])
    return table, sum(count for _, count in trainable)

def train():
    """Full training loop."""
    checkpointpath = os.getcwd() + "/checkpoint"
    if not os.path.exists(checkpointpath):
        os.makedirs(checkpointpath)

    if FLAGS.num_buffers is None:
        FLAGS.num_buffers = max(4 * FLAGS.num_actors, 2 * FLAGS.batch_size)
    if FLAGS.num_actors >= FLAGS.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")

    T = FLAGS.unroll_length
    B = FLAGS.batch_size

    logging.info(f"Using device {FLAGS.device}")

    env = envs.create_env(FLAGS)

    (
        model,
        generator_model,
        learner_model,
        learner_generator_model,
        buffers,
    ) = models.create_models_and_buffers(env, FLAGS)

    model.share_memory()
    generator_model.share_memory()
    
    logging.info(f"{FLAGS.group}")
    logging.info(f"lr:{FLAGS.lr},MI:{FLAGS.mutual_information_rate}")
    tableS, paraS = count_parameters(model)
    logging.info(f"\n{tableS}")
    logging.info(f"The parameters of Actor {paraS}")
    logging.info("")
    
    tabelT, paraT = count_parameters(generator_model)
    logging.info(f"\n{tabelT}")
    logging.info(f"The parameters of Planner {paraT}")
    logging.info("")

    (
        optimizer,
        discriminator_optimizer,
        generator_optimizer,
        scheduler,
        discriminator_scheduler,
        generator_scheduler,
    ) = optimizers.create_optimizers(
        learner_model, learner_generator_model, FLAGS.total_frames, FLAGS
    )

    manager = mp.Manager()
    proposed_goals = {
        "all": manager.dict(),
        "achieved": manager.dict(),
        "achieved_steps": manager.dict(),
    }
    generator_current_target = manager.Value("i", int(FLAGS.generator_start_target))

    num_possible_goals = generator_model.logits_size
    for i in range(num_possible_goals):
        proposed_goals["all"][i] = 0
        proposed_goals["achieved"][i] = 0
        proposed_goals["achieved_steps"][i] = 0

    initial_agent_state_buffers = []
    for _ in range(FLAGS.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(FLAGS.num_actors):
        if FLAGS.debug:
            procfn = threading.Thread
        else:
            procfn = ctx.Process
        actor = procfn(
            target=act,
            args=(
                i,
                free_queue,
                full_queue,
                model,
                generator_model,
                buffers,
                initial_agent_state_buffers,
                proposed_goals["all"],
                proposed_goals["achieved"],
                proposed_goals["achieved_steps"],
                generator_current_target,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
        "gen_rewards",
        "gg_loss",
        "generator_entropy_loss",
        "generator_baseline_loss",
        "mean_intrinsic_rewards",
        "mean_intrinsic_episode_steps",
        "ex_reward",
        "generator_current_target",
    ]
    logging.info("# Step\t{}".format("\t".join(stat_keys)))

    frames = 0
    stats = {}

    def batch_and_learn(i, generator_current_target):
        """Thread target for the learning process."""
        nonlocal frames, stats
        timings = prof.Timings()
        index = 0
        while frames < FLAGS.total_frames:
            timings.reset()
            batch, agent_state = get_batch(
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(
                model,
                learner_model,
                generator_model,
                learner_generator_model,
                batch,
                agent_state,
                optimizer,
                discriminator_optimizer,
                generator_optimizer,
                scheduler,
                discriminator_scheduler,
                generator_scheduler,
                generator_current_target,
                max_steps=env.max_steps if FLAGS.is_babyai else env._max_episode_steps,
            )

            timings.time("learn")
            with LOCKS["batch_and_learn"]:
                to_log = dict(frames=frames)
                to_log.update(stats)
                to_log = utils.filter_dict(
                    lambda k: k not in {"grounder_dist", "goal_dist"},
                    to_log,
                    filter_keys=True,
                )
                frames += T * B
            index += 1
            if FLAGS.verbose and index % 15 == 0:
                logging.info(f"Batch and learn {i}: {timings.summary()}")
                timings.reset()

    for m in range(FLAGS.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(FLAGS.num_threads):
        thread = threading.Thread(
            target=batch_and_learn,
            name="batch-and-learn-%d" % i,
            args=(i, generator_current_target),
        )
        thread.start()
        threads.append(thread)

    def checkpoint():
        if FLAGS.disable_checkpoint:
            return
        logging.info(f"Saving checkpoint to {checkpointpath}")
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "generator_model_state_dict": generator_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "discriminator_optimizer_state_dict": discriminator_optimizer.state_dict(),
                "generator_optimizer_state_dict": generator_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "generator_scheduler_state_dict": generator_scheduler.state_dict(),
                "FLAGS": vars(FLAGS),
            },
            f"{checkpointpath}/model.pth",
        )

    timer = timeit.default_timer
    template_plot_data = []
    template_plot_data_norm = []
    template_plot_data_steps = []
    all_templates_plot_norm = None
    achieved_templates_plot_norm = None
    achieved_templates_plot_step = None
    noveld_plot = None
    logged_templates_plot = True
    logged_templates_step_plot = True
    logged_noveld_plot = True
    percent_goals_achieved = None

    try:
        last_checkpoint_time = timer()
        last_template_time = timer()
        last_plot_update_time = timer()

        while frames < FLAGS.total_frames:
            frame_interval = frames
            time_interval = timer()
            if FLAGS.debug:
                sleep_time = 5 * 60
            else:
                sleep_time = 5
            time.sleep(sleep_time)
            this_frames = frames
            if timer() - last_checkpoint_time > 10 * 60:
                checkpoint()
                last_checkpoint_time = timer()

            if timer() - last_template_time > 1 * 60:
                proposed_goals_templates = {"all": Counter(), "achieved": Counter(), "achieved_steps": Counter(),}
                for i, template in enumerate(
                    learner_generator_model.lang_templates
                ):
                    proposed_goals_templates["all"][template] += proposed_goals["all"][i]
                    proposed_goals_templates["achieved"][template] += proposed_goals["achieved"][i]
                    proposed_goals_templates["achieved_steps"][template] += proposed_goals["achieved_steps"][i]
                    
                    proposed_goals["all"][i] = 0
                    proposed_goals["achieved"][i] = 0
                    proposed_goals["achieved_steps"][i] = 0

                for template in proposed_goals_templates["all"]:
                    n_all = proposed_goals_templates["all"][template]
                    n_achieved = proposed_goals_templates["achieved"][template]
                    n_achieved_steps = proposed_goals_templates["achieved_steps"][template]
                    if n_all == 0:
                        continue
                    template_plot_data.append(
                        {
                            "num_frames": this_frames,
                            "n_all": n_all,
                            "n_achieved": n_achieved,
                            "template": template,
                        }
                    )
                    if n_achieved != 0:
                        n_mean_achieved_steps = n_achieved_steps / n_achieved
                        template_plot_data_steps.append(
                            {
                                "num_frames": this_frames,
                                "n_mean_achieved_steps": n_mean_achieved_steps,
                                "template": template,
                            }
                        )

                percent_goals_achieved = sum(
                    proposed_goals_templates["achieved"].values()
                ) / (sum(proposed_goals_templates["all"].values()) + 1e-10)

                all_sum = sum(proposed_goals_templates["all"].values())
                achieved_sum = sum(proposed_goals_templates["achieved"].values())
                proposed_goals_templates["all"] = utils.map_dict(
                    lambda v: v / (all_sum + 1e-10), proposed_goals_templates["all"]
                )
                proposed_goals_templates["achieved"] = utils.map_dict(
                    lambda v: v / (achieved_sum + 1e-10),
                    proposed_goals_templates["achieved"],
                )

                for template in proposed_goals_templates["all"]:
                    n_all = proposed_goals_templates["all"][template]
                    n_achieved = proposed_goals_templates["achieved"][template]
                    if n_all == 0:
                        continue
                    template_plot_data_norm.append(
                        {
                            "num_frames": this_frames,
                            "n_all": n_all,
                            "n_achieved": n_achieved,
                            "template": template,
                        }
                    )

                last_template_time = timer()

            if timer() - last_plot_update_time > 5 * 60:
                if FLAGS.wandb:
                    templates_df_norm = pd.DataFrame(template_plot_data_norm)
                    if len(templates_df_norm.columns) != 0:
                        all_templates_plot_norm = px.line(
                            templates_df_norm,
                            x="num_frames",
                            y="n_all",
                            color="template",
                        )
                        achieved_templates_plot_norm = px.line(
                            templates_df_norm,
                            x="num_frames",
                            y="n_achieved",
                            color="template",
                        )
                        logged_templates_plot = False
                        
                    templates_df_steps = pd.DataFrame(template_plot_data_steps)
                    if len(templates_df_steps.columns) != 0:
                        achieved_templates_plot_step = px.line(
                            templates_df_steps,
                            x="num_frames",
                            y="n_mean_achieved_steps",
                            color="template",
                        )
                        logged_templates_step_plot = False

                last_plot_update_time = timer()

            fps = (frames - frame_interval) / (timer() - time_interval)
            if stats.get("episode_returns", None):
                mean_return = (
                    "Return per episode: %.1f. " % stats["mean_episode_return"]
                )
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            stats_not_none = utils.filter_dict(
                lambda v: v is not None and not isinstance(v, tuple),
                stats,
            )
            stats_not_none = utils.filter_dict(
                lambda k: k not in {"novelty_records"},
                stats_not_none,
                filter_keys=True,
            )
            if FLAGS.wandb:
                metrics_to_log = {
                    "num_frames": frames,
                    "fps": fps,
                    **stats_not_none,
                    "percent_goals_achieved": percent_goals_achieved,
                }
                if not logged_templates_plot:
                    metrics_to_log.update(
                        {
                            "all_templates_norm": all_templates_plot_norm,
                            "achieved_templates_norm": achieved_templates_plot_norm,
                        }
                    )
                if not logged_templates_step_plot:
                    metrics_to_log.update(
                        {
                            "achieved_templates_steps": achieved_templates_plot_step,
                        }
                    )
                if not logged_noveld_plot:
                    metrics_to_log.update(
                        {
                            "noveld_plot": noveld_plot,
                        }
                    )
                wandb.log(metrics_to_log, step=frames)
                logged_templates_plot = True
                logged_templates_step_plot = True
                logged_noveld_plot = True

            logging.info(
                f"After {frames} frames: loss {total_loss:f} @ {fps:.1f} fps. {mean_return}Stats: {pprint.pformat(stats_not_none)}"
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    except Exception:
        import traceback

        logging.info("Got exception in main process, exiting")
        logging.info(traceback.format_exc())
    else:
        for thread in threads:
            thread.join()
        logging.info(f"Learning finished after {frames} frames.")
    finally:
        for _ in range(FLAGS.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    return frames, model.state_dict()


def is_babyai(env_name):
    """Tell whether ``env_name`` belongs to the BabyAI environment family.

    Registered as the ``is_babyai`` OmegaConf resolver so that config values
    can depend on the environment name.

    Args:
        env_name: environment identifier string.

    Returns:
        True if the name begins with the literal prefix ``"BabyAI"``
        (case-sensitive), otherwise False.
    """
    babyai_prefix = "BabyAI"
    has_prefix = env_name.startswith(babyai_prefix)
    return has_prefix


# Custom OmegaConf resolvers. They are registered at import time so they
# exist before Hydra composes the config when main() runs:
#   ${is_babyai:<env_name>} -> True if the env name starts with "BabyAI".
#   ${uid:}                 -> a 3-word coolname slug.
# NOTE(review): register_new_resolver does not cache resolver results by
# default, so every access to a ${uid:} interpolation may yield a different
# slug. Confirm whether use_cache=True was intended.
OmegaConf.register_new_resolver("is_babyai", is_babyai)
OmegaConf.register_new_resolver("uid", lambda: coolname.generate_slug(3))


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: publish the config, optionally start wandb, then train.

    Args:
        cfg: the composed Hydra config (the ``config`` config under ``conf/``
            plus any command-line overrides).
    """
    print("Working directory : {}".format(os.getcwd()))
    # The rest of this module (train(), act(), ...) reads the configuration
    # through the module-level FLAGS global rather than via arguments.
    global FLAGS
    FLAGS = cfg

    if FLAGS.wandb:
        wandb.init(
            mode="offline",
            project=str(FLAGS.project),
            group=str(FLAGS.group),
            name=f"lr:{FLAGS.lr},MI:{FLAGS.mutual_information_rate}",
            # vars() on a DictConfig returns its internal attributes
            # (_content, _metadata, _parent, ...), not the user config.
            # Convert it to a plain container with interpolations resolved
            # so wandb records the actual hyperparameters.
            config=OmegaConf.to_container(FLAGS, resolve=True),
        )

    train()


# Script entry point. main() is called without arguments because the
# hydra.main decorator builds cfg from the config files and command-line
# overrides.
if __name__ == "__main__":
    main()