import argparse
import os
import pickle
import math
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import itertools
from functools import partial
import random
import collections
from torch.utils.data import DataLoader, ConcatDataset, Subset, Dataset, IterableDataset
import numpy as np
from positional_encodings.torch_encodings import get_emb, PositionalEncoding1D
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.loggers.logger import rank_zero_experiment
from tqdm.auto import tqdm, trange
from torch.utils.data import Sampler
import sys
import csv


def recursive_mod(sequence, depth, func):
    """Apply ``func`` to the elements nested ``depth`` levels inside ``sequence``.

    With ``depth == 0`` the function is applied to ``sequence`` itself.
    Otherwise every element of ``sequence`` is processed recursively with
    ``depth - 1`` and the results are gathered into a new list.
    """
    if depth != 0:
        return [recursive_mod(child, depth - 1, func) for child in sequence]

    return func(sequence)


def pad_subsequence_to(subsequence, length, pad):
    """Truncate or right-pad ``subsequence`` along its first axis to ``length``.

    ``subsequence`` is used as a NumPy array (``ndim``, slicing, ``np.pad``).
    Entries beyond ``length`` are dropped; missing entries are filled with the
    constant ``pad``. All other axes are left untouched.
    """
    n_missing = length - min(len(subsequence), length)
    # Only the leading axis grows; every remaining axis gets a (0, 0) width.
    widths = ((0, n_missing),) + ((0, 0),) * (subsequence.ndim - 1)
    clipped = subsequence[:length]
    return np.pad(clipped, widths, mode="constant", constant_values=pad)


def pad_to(sequence, length, pad=-1):
    """Pad (and truncate) a possibly ragged nested sequence to fixed sizes.

    Args:
        sequence: the data to pad. Sub-sequences at the padded depths are
            handed to ``pad_subsequence_to``, so they are expected to be
            NumPy arrays (or stackable into one).
        length: ``None`` to return ``sequence`` untouched, an ``int`` to pad
            only the first dimension, or a sequence of ints giving the target
            size of each leading dimension.
        pad: constant used to fill the padded positions.

    Returns:
        ``sequence`` itself when ``length`` is ``None``, otherwise a stacked
        NumPy array.
    """
    if length is None:
        return sequence

    length = (length,) if isinstance(length, int) else length

    # Pad the innermost requested dimension first, then stack the now
    # equally-sized children so the next outer dimension can be padded as a
    # regular array.
    for dim_idx in reversed(range(len(length))):
        target = length[dim_idx]
        sequence = recursive_mod(
            sequence,
            dim_idx,
            lambda x, target=target: pad_subsequence_to(x, target, pad),
        )

        if dim_idx != 0:
            sequence = recursive_mod(sequence, dim_idx - 1, np.stack)

    return np.stack(sequence)


class SamplingDataset(IterableDataset):
    """Iterable view that keeps drawing random items from a map-style dataset.

    ``__next__`` never raises ``StopIteration`` itself, so iteration only ends
    when the consumer stops asking for items.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __iter__(self):
        # The object is its own iterator; there is no per-pass state to reset.
        return self

    def __next__(self):
        # Draw an index in [0, len(dataset)) from NumPy's global RNG.
        choice = np.random.randint(len(self.dataset))
        return self.dataset[choice]


class MapDataset(Dataset):
    """Map-style dataset that applies ``func`` to each item when it is fetched."""

    def __init__(self, dataset, func):
        super().__init__()
        self.func = func
        self.dataset = dataset

    def __getitem__(self, idx):
        raw_item = self.dataset[idx]
        return self.func(raw_item)

    def __len__(self):
        # Mapping is one-to-one, so the length is that of the wrapped dataset.
        return len(self.dataset)


class IterMapDataset(IterableDataset):
    """Iterable dataset that applies ``func`` to every item of ``iter_dataset``."""

    def __init__(self, iter_dataset, func):
        super().__init__()
        self.iter_dataset = iter_dataset
        self.func = func
        # Iterator over ``iter_dataset``; created by ``__iter__``.
        self.iter = None

    def __iter__(self):
        # Restart the wrapped iterable on every new pass.
        self.iter = iter(self.iter_dataset)
        return self

    def __next__(self):
        # StopIteration from the wrapped iterator propagates and ends the pass.
        upstream_item = next(self.iter)
        return self.func(upstream_item)


class IterComposeDataset(IterableDataset):
    """Flatten an iterable dataset through a per-item instance generator.

    For every item of ``iter_dataset``, ``instance_generator_func(item)`` is
    called and each element it yields is emitted in turn.

    Args:
        iter_dataset: the outer iterable.
        instance_generator_func: callable mapping one outer item to an
            iterable of instances.
    """

    def __init__(self, iter_dataset, instance_generator_func):
        super().__init__()
        self.iter_dataset = iter_dataset
        self.instance_generator_func = instance_generator_func
        self.iter_dataset_iter = None
        self.instance_generator_iter = None

    def __iter__(self):
        self.iter_dataset_iter = iter(self.iter_dataset)
        # Drop any half-consumed generator left over from a previous pass so
        # a new pass starts cleanly from the first outer item.
        self.instance_generator_iter = None
        return self

    def __next__(self):
        while True:
            if self.instance_generator_iter is not None:
                try:
                    return next(self.instance_generator_iter)
                except StopIteration:
                    # Current generator is exhausted (or was empty): move on
                    # to the next outer item instead of ending the pass.
                    self.instance_generator_iter = None

            # StopIteration raised here means the outer iterable is done and
            # is deliberately allowed to propagate.
            self.instance_generator_iter = iter(
                self.instance_generator_func(next(self.iter_dataset_iter))
            )


class PaddingDataset(Dataset):
    """Map-style dataset that pads every field of an item with ``pad_to``.

    Args:
        dataset: wrapped map-style dataset.
        paddings: per-field target lengths handed to ``pad_to``.
        pad_values: per-field fill constants handed to ``pad_to``.
    """

    def __init__(self, dataset, paddings, pad_values):
        super().__init__()
        self.dataset = dataset
        self.paddings = paddings
        self.pad_values = pad_values

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # Fetch exactly once: the wrapped dataset may be expensive or stateful.
        fetched = self.dataset[i]

        if isinstance(i, np.ndarray):
            # Batched lookup: every fetched entry is padded with the full
            # ``paddings`` spec.
            # NOTE(review): assumes the wrapped dataset returns an object with
            # a ``shape`` attribute for array indices - confirm against caller.
            per_entry_paddings = [self.paddings] * fetched.shape[0]
        else:
            per_entry_paddings = self.paddings

        return tuple(
            pad_to(a, p, v)
            for a, p, v in zip(fetched, per_entry_paddings, self.pad_values)
        )


class PaddingIterableDataset(IterableDataset):
    """Iterable dataset that pads each item via ``recursive_pad_array``."""

    def __init__(self, dataset, paddings, pad_values):
        super().__init__()
        self.dataset = dataset
        self.paddings = paddings
        self.pad_values = pad_values
        # Iterator over ``dataset``; created by ``__iter__``.
        self.iterable = None

    def __iter__(self):
        # Restart the wrapped iterable on every new pass.
        self.iterable = iter(self.dataset)
        return self

    def __next__(self):
        raw_item = next(self.iterable)
        padded = recursive_pad_array(
            raw_item, self.paddings, pad_value=self.pad_values
        )
        return padded


class BatchShuffleIndexSampler(Sampler):
    """Yield whole batches by indexing ``dataset`` with arrays of indices.

    Each yielded element is ``dataset[index_array]`` (not the indices), so the
    wrapped dataset must accept a NumPy index array. A trailing partial batch
    is dropped.
    """

    def __init__(self, dataset, batch_size, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        # Number of full batches.
        return len(self.dataset) // self.batch_size

    def __iter__(self):
        order = np.arange(0, len(self.dataset))

        if self.shuffle:
            np.random.shuffle(order)

        n_batches = len(order) // self.batch_size
        for batch_no in range(n_batches):
            start = batch_no * self.batch_size
            stop = start + self.batch_size
            yield self.dataset[order[start:stop]]


def extract_padding(x, pad_val=-1, pad_replace=0):
    """Replace padding entries in a copy of ``x`` and return the padding mask.

    Returns:
        ``(cleaned, is_pad)`` where ``is_pad`` is ``x == pad_val`` and
        ``cleaned`` is a clone of ``x`` with those positions set to
        ``pad_replace``. The input ``x`` is not modified.
    """
    is_pad = x == pad_val
    cleaned = x.clone()
    cleaned[is_pad] = pad_replace

    return cleaned, is_pad


def recursive_pad_array(item, max_lengths, pad_value):
    """Pad every NumPy array found inside a nested container.

    Args:
        item: a NumPy array, a mapping, a sequence, or any other object.
        max_lengths: ``None`` to return ``item`` unchanged; an ``int`` when
            ``item`` is an array; a sequence aligned with ``item`` when
            ``item`` is a sequence. For mappings the same ``max_lengths`` is
            applied to every value.
        pad_value: fill constant. May be a mapping (looked up by key) or a
            sequence (looked up by position) mirroring ``item``; otherwise it
            is used as-is for every nested element.

    Returns:
        A container of the same type as ``item`` with arrays padded through
        ``pad_to``; objects of any other type are returned unchanged.

    Raises:
        AssertionError: if ``max_lengths`` does not match the kind of ``item``.
    """
    # Identity test: ``== None`` would dispatch to ``__eq__`` and is
    # elementwise (and therefore ambiguous) for array-like ``max_lengths``.
    if max_lengths is None:
        return item

    if isinstance(item, np.ndarray):
        assert isinstance(max_lengths, int)
        return pad_to(item, max_lengths, pad=pad_value)
    elif isinstance(item, collections.abc.Mapping):
        return type(item)(
            {
                k: recursive_pad_array(
                    item[k],
                    max_lengths,
                    pad_value[k]
                    if isinstance(pad_value, collections.abc.Mapping)
                    else pad_value,
                )
                for k in item
            }
        )
    elif isinstance(item, collections.abc.Sequence):
        assert isinstance(max_lengths, collections.abc.Sequence)
        # zip() stops at the shorter of ``item`` and ``max_lengths``.
        return type(item)(
            [
                recursive_pad_array(
                    e,
                    l,
                    pad_value[i]
                    if isinstance(pad_value, collections.abc.Sequence)
                    else pad_value,
                )
                for i, (e, l) in enumerate(zip(item, max_lengths))
            ]
        )
    else:
        return item


def split_dataset(dataset, pct=0.01):
    """Randomly split ``dataset`` into train and test lists.

    Args:
        dataset: indexable collection supporting ``len``.
        pct: fraction of items placed in the test split; the count is
            ``int(len(dataset) * pct)`` clamped to ``[0, len(dataset)]``.

    Returns:
        ``(train_items, test_items)`` as two lists. The shuffle uses NumPy's
        global RNG.
    """
    n_items = len(dataset)
    indices = np.arange(n_items, dtype=int)
    np.random.shuffle(indices)

    # Split on an explicit boundary: slicing with ``-n_test`` breaks when
    # ``n_test`` is 0, because ``[:-0]`` is empty and ``[-0:]`` is everything.
    n_test = min(max(int(n_items * pct), 0), n_items)
    boundary = n_items - n_test

    train = indices[:boundary]
    test = indices[boundary:]
    return [dataset[x] for x in train], [dataset[x] for x in test]


def load_pickle_file(path):
    """Unpickle and return the object stored at ``path``.

    Prints a ``Loading ...`` line once the file has been opened.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only use this on trusted files.
    with open(path, "rb") as stream:
        print(f"Loading {path}")
        payload = pickle.load(stream)
    return payload


def load_data(
    train_meta_trajectories_path, valid_trajectories_directory, dictionary_path
):
    """Load training demonstrations, validation splits and dictionaries.

    Args:
        train_meta_trajectories_path: pickle file with the training
            demonstrations; they are shuffled in place with NumPy's global RNG.
        valid_trajectories_directory: directory whose files are each
            unpickled; results are keyed by file name without extension.
        dictionary_path: pickle file holding a 4-tuple
            ``(WORD2IDX, ACTION2IDX, color_dictionary, noun_dictionary)``.

    Returns:
        ``((WORD2IDX, ACTION2IDX, color_dictionary, noun_dictionary),
        (meta_train_demonstrations, valid_trajectories_dict))``.
    """
    meta_train_demonstrations = load_pickle_file(train_meta_trajectories_path)
    np.random.shuffle(meta_train_demonstrations)

    # Files are visited in sorted order.
    valid_trajectories_dict = {}
    for fname in sorted(os.listdir(valid_trajectories_directory)):
        split_name = os.path.splitext(fname)[0]
        valid_trajectories_dict[split_name] = load_pickle_file(
            os.path.join(valid_trajectories_directory, fname)
        )

    with open(dictionary_path, "rb") as f:
        dictionaries = pickle.load(f)
    WORD2IDX, ACTION2IDX, color_dictionary, noun_dictionary = dictionaries

    vocabularies = (WORD2IDX, ACTION2IDX, color_dictionary, noun_dictionary)
    trajectories = (meta_train_demonstrations, valid_trajectories_dict)
    return vocabularies, trajectories


def initialize_parameters(m):
    """Re-initialise modules whose class name contains ``"Linear"``.

    Weights are drawn from a standard normal and each row (summing over
    dim 1) is rescaled to unit L2 norm; a bias, if present, is zeroed.
    Modules with any other class name are left untouched. Intended for use
    with ``nn.Module.apply``.
    """
    if "Linear" not in type(m).__name__:
        return

    weight = m.weight.data
    weight.normal_(0, 1)
    row_norms = torch.sqrt(weight.pow(2).sum(1, keepdim=True))
    weight *= 1 / row_norms

    if m.bias is not None:
        m.bias.data.fill_(0)


class BOWEmbedding(nn.Module):
    """Embed multi-channel categorical inputs with one shared table.

    Channel ``c`` uses rows ``[c * max_value, (c + 1) * max_value)`` of a
    single ``nn.Embedding``; the per-channel embeddings are concatenated on
    the last axis, giving ``n_channels * embedding_dim`` features.
    """

    def __init__(self, max_value, n_channels, embedding_dim):
        super().__init__()
        self.max_value = max_value
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(n_channels * max_value, embedding_dim)
        self.n_channels = n_channels
        self.apply(initialize_parameters)

    def forward(self, inputs):
        # The last axis of ``inputs`` holds the channels; all leading axes
        # are collapsed for the lookup and restored at the end.
        leading_shape = inputs.shape[:-1]
        rows = inputs.flatten(0, -2)

        channel_starts = [
            channel * self.max_value for channel in range(self.n_channels)
        ]
        offsets = torch.Tensor(channel_starts).to(inputs.device)

        # Shift each channel into its own slice of the embedding table.
        table_indices = (rows + offsets[None, :]).long()
        per_channel = self.embedding(table_indices)
        concatenated = per_channel.flatten(-2, -1)

        return concatenated.unflatten(0, leading_shape)


class StateEncoderDecoderTransformer(nn.Module):
    """Encode a token sequence against a state with an ``nn.Transformer``.

    The embedded state is the transformer source and the embedded tokens are
    the target, so the output has one vector per token position.

    Args:
        n_state_components: number of channels in each state entry.
        input_size: vocabulary size of the token embedding.
        embedding_dim: model width.
        nlayers: number of encoder layers and of decoder layers.
        dropout_p: dropout probability for embeddings and the transformer.
        pad_word_idx: token index treated as padding.
        bidirectional: accepted for interface compatibility but ignored
            (``self.bi`` is always ``False``).
    """

    def __init__(
        self,
        n_state_components,
        input_size,
        embedding_dim,
        nlayers,
        dropout_p,
        pad_word_idx,
        bidirectional=False,
    ):
        super().__init__()
        self.n_state_components = n_state_components
        self.embedding_dim = embedding_dim
        self.state_embedding = BOWEmbedding(
            64, self.n_state_components, self.embedding_dim
        )
        self.state_projection = nn.Linear(
            self.n_state_components * self.embedding_dim, self.embedding_dim
        )
        self.embedding = nn.Embedding(input_size, embedding_dim)
        self.embedding_projection = nn.Linear(embedding_dim * 2, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.encoding = nn.Parameter(torch.randn(embedding_dim))
        self.pos_encoding = PositionalEncoding1D(embedding_dim)
        self.bi = False
        self.pad_word_idx = pad_word_idx
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=4,
            dim_feedforward=embedding_dim * 4,
            dropout=dropout_p,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
        )

    def forward(self, state_padded, z_padded):
        """Return the transformer output for the token sequence.

        NOTE(review): assumes ``state_padded`` is (batch, n_states,
        n_state_components) and ``z_padded`` is (batch, seq) - confirm against
        callers. Inputs are transposed to sequence-first before the
        transformer and the result is returned without transposing back.
        """
        z_padding_bits = z_padded == self.pad_word_idx

        state_embed_seq = self.state_projection(self.state_embedding(state_padded))

        # Positional information is concatenated to the token embedding and
        # projected back to the model width rather than added.
        z_embed_seq = self.embedding(z_padded)
        z_embed_seq = torch.cat([self.pos_encoding(z_embed_seq), z_embed_seq], dim=-1)
        z_embed_seq = self.embedding_projection(z_embed_seq)
        state_embed_seq = self.dropout(state_embed_seq)
        z_embed_seq = self.dropout(z_embed_seq)

        # Only token padding is masked; no mask is passed for the state.
        encoded_seq = self.transformer(
            state_embed_seq.transpose(1, 0),
            z_embed_seq.transpose(1, 0),
            tgt_key_padding_mask=z_padding_bits,
        )

        return encoded_seq


class StateEncoderTransformer(nn.Module):
    """Jointly encode a state and a token sequence with a transformer encoder.

    State embeddings and token embeddings are concatenated along the sequence
    axis (state first) and passed through ``nn.TransformerEncoder``.

    The ``bidirectional`` argument is accepted but ignored (``self.bi`` is
    always ``False``).
    """

    def __init__(
        self,
        n_state_components,
        input_size,
        embedding_dim,
        nlayers,
        dropout_p,
        pad_word_idx,
        bidirectional=False,
    ):
        super().__init__()
        self.n_state_components = n_state_components
        self.embedding_dim = embedding_dim
        self.state_embedding = BOWEmbedding(
            64, self.n_state_components, self.embedding_dim
        )
        self.state_projection = nn.Linear(
            self.n_state_components * self.embedding_dim, self.embedding_dim
        )
        self.embedding = nn.Embedding(input_size, embedding_dim)
        self.embedding_projection = nn.Linear(embedding_dim * 2, embedding_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.encoding = nn.Parameter(torch.randn(embedding_dim))
        self.pos_encoding = PositionalEncoding1D(embedding_dim)
        self.bi = False
        self.pad_word_idx = pad_word_idx
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=4,
            dim_feedforward=embedding_dim * 4,
            dropout=dropout_p,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=nlayers
        )

    def forward(self, state_padded, z_padded):
        """Return ``(encoded_seq, padding_bits)``.

        ``encoded_seq`` is the sequence-first encoder output over the
        concatenated state and token positions; ``padding_bits`` is the
        matching key-padding mask (state positions are never masked).
        """
        # State entries are never padding; tokens are padding when they
        # equal pad_word_idx.
        state_is_pad = torch.zeros_like(state_padded[..., 0]).bool()
        token_is_pad = z_padded == self.pad_word_idx

        state_tokens = self.state_projection(self.state_embedding(state_padded))

        # Concatenate positional encoding with the token embedding and project
        # back down to the model width.
        word_tokens = self.embedding(z_padded)
        word_tokens = self.embedding_projection(
            torch.cat([self.pos_encoding(word_tokens), word_tokens], dim=-1)
        )

        state_tokens = self.dropout(state_tokens)
        word_tokens = self.dropout(word_tokens)

        joint_tokens = torch.cat([state_tokens, word_tokens], dim=1)
        padding_bits = torch.cat([state_is_pad, token_is_pad], dim=-1)

        encoded_seq = self.transformer_encoder(
            joint_tokens.transpose(1, 0),
            src_key_padding_mask=padding_bits,
        )

        return encoded_seq, padding_bits


class EncoderDecoderTransformer(nn.Module):
    """Causal action decoder built on a full ``nn.Transformer``.

    The already-encoded inputs are used as the transformer source (they pass
    through the transformer's own encoder stack) and the embedded action
    tokens are the target, decoded under a causal mask.

    Args:
        n_state_components: stored on the module; not used by ``forward``.
        hidden_size: model width and embedding size of the output symbols.
        output_size: number of output symbols.
        nlayers: number of encoder layers and of decoder layers.
        pad_action_idx: action index treated as padding.
        dropout_p: dropout applied to the embeddings and the transformer.
    """

    def __init__(
        self,
        n_state_components,
        hidden_size,
        output_size,
        nlayers,
        pad_action_idx,
        dropout_p=0.1,
    ):
        super().__init__()
        self.n_state_components = n_state_components
        self.nlayers = nlayers
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.tanh = nn.Tanh()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.embedding_projection = nn.Linear(hidden_size * 2, hidden_size)
        self.pos_encoding = PositionalEncoding1D(hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.pad_action_idx = pad_action_idx
        self.transformer = nn.Transformer(
            d_model=hidden_size,
            dim_feedforward=hidden_size * 4,
            dropout=dropout_p,
            nhead=4,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
        )
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, inputs, encoder_outputs, encoder_padding):
        """Decode a whole sequence of action tokens in one pass.

        Args:
            inputs: action token indices, batch first; positions equal to
                ``pad_action_idx`` are masked as padding.
            encoder_outputs: sequence-first encoded source.
            encoder_padding: key-padding mask for ``encoder_outputs``.

        Returns:
            ``(logits, encoder_padding)`` where ``logits`` are unnormalised
            scores over the output symbols, batch first, one row per input
            position.
        """
        input_padding_bits = inputs == self.pad_action_idx

        # Concatenate the positional encoding to the token embedding and
        # project back to the model width.
        embedding = self.embedding(inputs)
        embedding = self.embedding_projection(
            torch.cat([embedding, self.pos_encoding(embedding)], dim=-1)
        )
        embedding = self.dropout(embedding)

        # Causal mask: position i may only attend to positions <= i. Built
        # directly on the input device to avoid a host-to-device copy.
        seq_len = inputs.shape[-1]
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=inputs.device),
            diagonal=1,
        )

        decoded = self.transformer(
            encoder_outputs,
            embedding.transpose(0, 1),
            src_key_padding_mask=encoder_padding,
            tgt_key_padding_mask=input_padding_bits,
            tgt_mask=causal_mask,
        ).transpose(0, 1)

        return self.out(decoded), encoder_padding


class DecoderTransformer(nn.Module):
    """Causal action decoder built on ``nn.TransformerDecoder``.

    Embedded action tokens attend to themselves under a causal mask and to
    the supplied encoder memory; a final linear layer maps to output symbols.

    Args:
        n_state_components: stored on the module; not used by ``forward``.
        hidden_size: model width and embedding size of the output symbols.
        output_size: number of output symbols.
        nlayers: number of decoder layers.
        pad_action_idx: action index treated as padding.
        dropout_p: dropout applied to the embeddings and the decoder.
    """

    def __init__(
        self,
        n_state_components,
        hidden_size,
        output_size,
        nlayers,
        pad_action_idx,
        dropout_p=0.1,
    ):
        super().__init__()
        self.n_state_components = n_state_components
        self.nlayers = nlayers
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.tanh = nn.Tanh()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.embedding_projection = nn.Linear(hidden_size * 2, hidden_size)
        self.pos_encoding = PositionalEncoding1D(hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.pad_action_idx = pad_action_idx
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=hidden_size,
                dim_feedforward=hidden_size * 4,
                dropout=dropout_p,
                nhead=4,
            ),
            num_layers=nlayers,
        )
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, inputs, encoder_outputs, encoder_padding):
        """Decode a whole sequence of action tokens in one pass.

        Args:
            inputs: action token indices, batch first; positions equal to
                ``pad_action_idx`` are masked as padding.
            encoder_outputs: sequence-first encoder memory.
            encoder_padding: key-padding mask for ``encoder_outputs``.

        Returns:
            Unnormalised scores over the output symbols, batch first, one row
            per input position.
        """
        input_padding_bits = inputs == self.pad_action_idx

        # Concatenate the positional encoding to the token embedding and
        # project back to the model width.
        embedding = self.embedding(inputs)
        embedding = self.embedding_projection(
            torch.cat([embedding, self.pos_encoding(embedding)], dim=-1)
        )
        embedding = self.dropout(embedding)

        # Causal mask: position i may only attend to positions <= i. Built
        # directly on the input device to avoid a host-to-device copy.
        seq_len = inputs.shape[-1]
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=inputs.device),
            diagonal=1,
        )

        decoded = self.decoder(
            tgt=embedding.transpose(0, 1),
            memory=encoder_outputs,
            memory_key_padding_mask=encoder_padding,
            tgt_key_padding_mask=input_padding_bits,
            tgt_mask=causal_mask,
        ).transpose(0, 1)

        return self.out(decoded)


def linear_with_warmup_schedule(
    optimizer, num_warmup_steps, num_training_steps, min_lr_scale, last_epoch=-1
):
    """Build a ``LambdaLR`` with linear warmup followed by log-linear decay.

    Args:
        optimizer: wrapped optimizer.
        num_warmup_steps: steps over which the multiplier rises linearly from
            0 to 1.
        num_training_steps: step at which the decay reaches its floor.
        min_lr_scale: base-10 exponent of the final multiplier, e.g. ``-3``
            decays the multiplier from 1 down to ``1e-3``.
        last_epoch: forwarded to ``LambdaLR``.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    min_lr_logscale = min_lr_scale
    decay_steps = num_training_steps - num_warmup_steps

    def lr_lambda(current_step):
        # Warmup: scale linearly from 0 to 1.
        if current_step <= num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # Decay: the exponent moves linearly from 0 to min_lr_logscale, so for
        # min_lr_logscale == -3 the multiplier passes 1e-1, 1e-2 and ends at
        # 1e-3.
        if decay_steps <= 0:
            # Degenerate schedule (no room to decay): go straight to the
            # floor instead of dividing by zero or by a negative span.
            progress = 1.0
        else:
            progress = min(
                1,
                float(current_step - num_warmup_steps) / float(decay_steps),
            )

        return 10 ** (progress * min_lr_logscale)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)


def transformer_optimizer_config(
    harness, lr, warmup_proportion=0.14, decay_power=-2, weight_decay=0
):
    """Return an optimizer / LR-scheduler config dict for ``harness``.

    An ``AdamW`` optimizer over ``harness.parameters()`` is paired with a
    per-step warmup/decay schedule. The warmup lasts ``warmup_proportion`` of
    ``harness.trainer.max_steps`` and the final multiplier is
    ``10 ** decay_power``.
    """
    optimizer = optim.AdamW(harness.parameters(), lr=lr, weight_decay=weight_decay)

    total_steps = harness.trainer.max_steps
    scheduler = linear_with_warmup_schedule(
        optimizer,
        total_steps * warmup_proportion,
        total_steps,
        decay_power,
    )

    # The scheduler is stepped every optimizer step, not once per epoch.
    scheduler_config = {
        "scheduler": scheduler,
        "interval": "step",
        "frequency": 1,
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}


class TransformerLearner(pl.LightningModule):
    """Lightning module mapping (state, query) to an action sequence.

    A ``StateEncoderTransformer`` encodes the state together with the query
    tokens and a ``DecoderTransformer`` decodes action tokens from that
    encoding. Training and validation use teacher forcing; prediction uses
    greedy autoregressive decoding.

    Args:
        n_state_components: number of channels per state entry.
        x_categories: size of the query vocabulary.
        y_categories: size of the action vocabulary.
        embed_dim: model width.
        dropout_p: dropout probability.
        nlayers: number of layers in both encoder and decoder.
        pad_word_idx: query padding index.
        pad_action_idx: action padding index.
        sos_action_idx: start-of-sequence action index.
        eos_action_idx: end-of-sequence action index.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        warmup_proportion: fraction of ``trainer.max_steps`` used for warmup.
        decay_power: base-10 exponent of the final LR multiplier.
        predict_steps: number of decoding steps used by ``decode_recursive``.
    """

    def __init__(
        self,
        n_state_components,
        x_categories,
        y_categories,
        embed_dim,
        dropout_p,
        nlayers,
        pad_word_idx,
        pad_action_idx,
        sos_action_idx,
        eos_action_idx,
        lr=1e-4,
        wd=1e-2,
        warmup_proportion=0.001,
        decay_power=-1,
        predict_steps=64,
    ):
        super().__init__()
        self.encoder = StateEncoderTransformer(
            n_state_components,
            x_categories,
            embed_dim,
            nlayers,
            dropout_p,
            pad_word_idx,
        )
        self.decoder = DecoderTransformer(
            n_state_components,
            embed_dim,
            y_categories,
            nlayers,
            pad_action_idx,
            dropout_p,
        )
        self.y_categories = y_categories
        self.pad_word_idx = pad_word_idx
        self.pad_action_idx = pad_action_idx
        self.sos_action_idx = sos_action_idx
        self.eos_action_idx = eos_action_idx
        self.save_hyperparameters()

    def configure_optimizers(self):
        """Return the AdamW + warmup/decay scheduler configuration."""
        return transformer_optimizer_config(
            self,
            self.hparams.lr,
            warmup_proportion=self.hparams.warmup_proportion,
            weight_decay=self.hparams.wd,
            decay_power=self.hparams.decay_power,
        )

    def encode(self, states, queries):
        """Run the encoder; returns ``(encoded, key_padding_mask)``."""
        return self.encoder(states, queries)

    def decode_autoregressive(self, decoder_in, encoder_outputs, encoder_padding):
        """Return decoder logits for every position of ``decoder_in``."""
        return self.decoder(decoder_in, encoder_outputs, encoder_padding)

    def _greedy_decode(self, batch_size, n_steps, encoder_outputs, encoder_padding):
        """Greedily decode ``n_steps`` tokens starting from a batch of SOS.

        Returns ``(decoder_in, logits)``: the SOS-prefixed token matrix of
        shape (batch_size, n_steps + 1) and the per-step logits stacked on
        dim 1.
        """
        decoder_in = torch.tensor(
            self.sos_action_idx, dtype=torch.long, device=self.device
        )[None].expand(batch_size, 1)

        logits = []

        with torch.no_grad():
            for _ in range(n_steps):
                # Re-run the decoder on the whole prefix and keep only the
                # prediction for the newest position.
                logits.append(
                    self.decode_autoregressive(
                        decoder_in, encoder_outputs, encoder_padding
                    )[:, -1]
                )
                decoder_out = logits[-1].argmax(dim=-1)
                decoder_in = torch.cat([decoder_in, decoder_out[:, None]], dim=1)

            logits = torch.stack(logits, dim=1)

        return decoder_in, logits

    def decode_recursive(self, query_state, encoder_outputs, encoder_padding):
        """Greedy-decode ``hparams.predict_steps`` tokens; return the logits.

        ``query_state`` is only used for its batch size (first dimension).
        """
        _, logits = self._greedy_decode(
            query_state.shape[0],
            self.hparams.predict_steps,
            encoder_outputs,
            encoder_padding,
        )
        return logits

    def forward(self, states, queries, decoder_in):
        """Teacher-forced pass: encode, then decode ``decoder_in``."""
        encoded, encoder_padding = self.encoder(states, queries)
        return self.decode_autoregressive(decoder_in, encoded, encoder_padding)

    def _teacher_forced_metrics(self, batch):
        """Compute ``(loss, exact_match, token_accuracy)`` for one batch."""
        query, targets, state = batch
        actions_mask = targets == self.pad_action_idx

        # Decoder input is the target sequence with SOS prepended.
        decoder_in = torch.cat(
            [torch.ones_like(targets)[:, :1] * self.sos_action_idx, targets], dim=-1
        )

        # Drop the last position: it has no target to be scored against.
        preds = self.forward(state, query, decoder_in)[:, :-1]

        loss = F.cross_entropy(
            preds.flatten(0, -2),
            targets.flatten().long(),
            ignore_index=self.pad_action_idx,
        )

        argmax_preds = preds.argmax(dim=-1)
        token_acc = (
            (argmax_preds[~actions_mask] == targets[~actions_mask]).float().mean()
        )

        # Force padded positions to match so they cannot break exact match.
        argmax_preds[actions_mask] = self.pad_action_idx
        exacts = (argmax_preds == targets).all(dim=-1).to(torch.float).mean()

        return loss, exacts, token_acc

    def training_step(self, x, idx):
        """Log ``tloss`` / ``texact`` / ``tacc`` and return the loss."""
        loss, exacts, token_acc = self._teacher_forced_metrics(x)

        self.log("tloss", loss, prog_bar=True)
        self.log("texact", exacts, prog_bar=True)
        self.log("tacc", token_acc, prog_bar=True)

        return loss

    def validation_step(self, x, idx, dl_idx=0):
        """Log ``vloss`` / ``vexact`` / ``vacc`` for one validation batch."""
        loss, exacts, token_acc = self._teacher_forced_metrics(x)

        self.log("vloss", loss, prog_bar=True)
        self.log("vexact", exacts, prog_bar=True)
        self.log("vacc", token_acc, prog_bar=True)

    def predict_step(self, x, idx, dl_idx=0):
        """Greedy-decode a batch and return predictions with diagnostics.

        Returns a tuple ``(state, instruction, decoded, logits, exacts,
        target, *extras)`` where ``extras`` are any batch fields after the
        first three. ``instruction``, ``decoded``, ``logits`` and ``target``
        are per-example lists of NumPy arrays with ``-1`` entries removed.
        """
        instruction, target, state = x[:3]

        encodings, key_padding_mask = self.encode(state, instruction)

        # Decode as many steps as the (padded) target is long.
        decoder_in, logits = self._greedy_decode(
            instruction.shape[0], target.shape[1], encodings, key_padding_mask
        )

        # The mask is built on the SOS-prefixed sequence and therefore shifted
        # by one: it marks positions strictly after the first EOS.
        decoded_eq_mask = (
            (decoder_in == self.eos_action_idx).int().cumsum(dim=-1).bool()[:, :-1]
        )
        decoded = decoder_in[:, 1:]
        # NOTE(review): -1 is hard-coded as the padding marker here and in the
        # filters below, whereas training uses pad_action_idx / pad_word_idx -
        # confirm the prediction data is padded with -1.
        decoded[decoded_eq_mask] = -1

        exacts = (decoded == target).all(dim=-1).cpu().numpy()

        decoded = decoded.cpu().numpy()
        decoded_select_mask = decoded != -1
        decoded = [d[m] for d, m in zip(decoded, decoded_select_mask)]

        target = target.cpu().numpy()
        target = [d[d != -1] for d in target]

        instruction = instruction.cpu().numpy()
        instruction = [i[i != -1] for i in instruction]

        logits = logits.cpu().numpy()
        logits = [l[m] for l, m in zip(logits, decoded_select_mask)]

        # list(...) so that both list and tuple batches can be concatenated.
        return tuple(
            [state, instruction, decoded, logits, exacts, target] + list(x[3:])
        )


def get_most_recent_version(experiment_dir):
    """Return the highest-numbered version directory that has checkpoints.

    Looks in ``<experiment_dir>/lightning_logs`` for entries containing a
    ``checkpoints`` sub-directory and returns the name whose integer after
    the first underscore (e.g. ``version_12``) is largest.
    """
    logs_root = os.path.join(experiment_dir, "lightning_logs")

    with_checkpoints = []
    for entry in os.listdir(logs_root):
        if os.path.isdir(os.path.join(logs_root, entry, "checkpoints")):
            with_checkpoints.append(entry)

    def version_number(name):
        return int(name.split("_")[1])

    ordered = sorted(with_checkpoints, key=version_number)
    return ordered[-1]


class LoadableCSVLogger(CSVLogger):
    """``CSVLogger`` that reloads previously written metrics on first use.

    The first time ``experiment`` is accessed, the existing metrics CSV (if
    any) is read back into ``experiment.metrics``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def on_train_start(self, trainer, pl_module):
        # NOTE(review): relies on ``self.metrics`` existing on the logger and
        # on something invoking this hook - confirm against the installed
        # pytorch_lightning version.
        if self.metrics:
            trainer.callback_metrics = self.metrics[-1]

    @property
    @rank_zero_experiment
    def experiment(self):
        # Only restore when the underlying experiment is being created now.
        load_csv = self._experiment is None

        experiment = super().experiment

        if load_csv:
            try:
                with open(experiment.metrics_file_path, "r", newline="") as f:
                    restored = list(csv.DictReader(f))
            except IOError:
                print("No csv log files to restore")
            else:
                experiment.metrics = restored
                if restored:
                    print(
                        f"Restored CSV ({len(restored)} lines to step {restored[-1]['step']}) logs from {experiment.metrics_file_path}"
                    )
                else:
                    # The file exists but has no rows: nothing to report, and
                    # indexing restored[-1] would raise IndexError.
                    print(
                        f"CSV log {experiment.metrics_file_path} is empty, nothing restored"
                    )

        return experiment


class ReshuffleOnIndexZeroDataset(Dataset):
    """Dataset view served through a random permutation of indices.

    A fresh permutation is drawn at construction and again every time index
    0 is requested.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.indices = torch.randperm(len(dataset))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        starting_new_pass = i == 0
        if starting_new_pass:
            self.indices = torch.randperm(len(self.dataset))

        permuted_index = self.indices[i]
        return self.dataset[permuted_index]


def main():
    """Train a TransformerLearner on gSCAN demonstrations from the command line.

    Steps, in order:
      1. Parse CLI options and derive the experiment name and model path.
      2. Return early if the final model file already exists.
      3. Load dictionaries and demonstrations via ``load_data``.
      4. Build the padded, reshuffling training dataset and the model.
      5. Configure validation frequency, checkpointing, loggers and the
         Lightning ``Trainer`` (optionally resuming from the newest checkpoint).
      6. Fit the model and save the final checkpoint to ``model_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-demonstrations", type=str, required=True)
    parser.add_argument("--valid-demonstrations-directory", type=str, required=True)
    parser.add_argument("--dictionary", type=str, required=True)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--train-batch-size", type=int, default=64)
    parser.add_argument("--valid-batch-size", type=int, default=128)
    # Number of batches whose gradients are accumulated per optimizer step
    # (passed to the Trainer as accumulate_grad_batches below).
    parser.add_argument("--batch-size-mult", type=int, default=16)
    parser.add_argument("--hidden-size", type=int, default=128)
    parser.add_argument("--nlayers", type=int, default=8)
    parser.add_argument("--dropout-p", type=float, default=0.0)
    parser.add_argument("--lr", type=float, default=1e-4)
    # NOTE(review): --wd is parsed but not used anywhere in this function.
    parser.add_argument("--wd", type=float, default=1e-2)
    parser.add_argument("--warmup-proportion", type=float, default=0.1)
    parser.add_argument("--decay-power", type=int, default=-1)
    parser.add_argument("--iterations", type=int, default=2500000)
    parser.add_argument("--check-val-every", type=int, default=1000)
    parser.add_argument("--enable-progress", action="store_true")
    parser.add_argument("--restore-from-checkpoint", action="store_true")
    parser.add_argument("--version", type=int, default=None)
    parser.add_argument("--tag", type=str, default="none")
    args = parser.parse_args()

    # Experiment naming: seed, model, iteration budget, effective batch size
    # and tag are all encoded into exp_name, which drives both the model and
    # the log directory layout.
    exp_name = "gscan"
    model_name = "transformer_encoder_only_decode_actions"
    dataset_name = "gscan"
    effective_batch_size = args.train_batch_size * args.batch_size_mult
    exp_name = f"{exp_name}_s_{args.seed}_m_{model_name}_it_{args.iterations}_b_{effective_batch_size}_d_gscan_t_{args.tag}"
    model_dir = f"models/{exp_name}/{model_name}"
    model_path = f"{model_dir}/{exp_name}.pt"
    print(model_path)
    print(
        f"Batch size {args.train_batch_size}, mult {args.batch_size_mult}, total {args.train_batch_size * args.batch_size_mult}"
    )

    os.makedirs(model_dir, exist_ok=True)

    # A finished run leaves its final checkpoint at model_path; do not redo it.
    if os.path.exists(f"{model_path}"):
        print(f"Skipping {model_path} as it already exists")
        return

    seed = args.seed
    iterations = args.iterations

    pl.seed_everything(seed)

    # color_dictionary and noun_dictionary are unpacked but not used below.
    (
        (
            WORD2IDX,
            ACTION2IDX,
            color_dictionary,
            noun_dictionary,
        ),
        (train_demonstrations, valid_demonstrations_dict),
    ) = load_data(
        args.train_demonstrations, args.valid_demonstrations_directory, args.dictionary
    )

    # Inverse vocabularies; only their sizes are used in this function.
    IDX2WORD = {i: w for w, i in WORD2IDX.items()}
    IDX2ACTION = {i: w for w, i in ACTION2IDX.items()}

    pad_word = WORD2IDX["[pad]"]
    pad_action = ACTION2IDX["[pad]"]
    sos_action = ACTION2IDX["[sos]"]
    eos_action = ACTION2IDX["[eos]"]

    # NOTE(review): presumably (8, 72, None) are per-field pad lengths
    # (instruction, actions, unpadded third field) and the second tuple the
    # matching pad values -- confirm against PaddingDataset.
    train_dataset = ReshuffleOnIndexZeroDataset(
        PaddingDataset(
            train_demonstrations,
            (8, 72, None),
            (pad_word, pad_action, None),
        )
    )

    # Re-seed before building the model: constructing the training dataset
    # above already drew from the torch RNG (torch.randperm in
    # ReshuffleOnIndexZeroDataset), so this presumably keeps the model
    # initialisation a function of the seed alone.
    pl.seed_everything(seed)
    # NOTE(review): the meaning of the leading positional 7 is defined by
    # TransformerLearner's signature, which is not visible here.
    meta_module = TransformerLearner(
        7,
        len(IDX2WORD),
        len(IDX2ACTION),
        args.hidden_size,
        args.dropout_p,
        args.nlayers,
        pad_word,
        pad_action,
        sos_action,
        eos_action,
        lr=args.lr,
        decay_power=args.decay_power,
        warmup_proportion=args.warmup_proportion,
    )
    print(meta_module)

    # No shuffle/sampler is passed: the dataset wrapper itself permutes the
    # order each time index 0 is requested.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        pin_memory=True,
    )

    check_val_opts = {}
    # Requested validation period expressed in epochs (batches per epoch is
    # len(train_dataloader)).
    interval = args.check_val_every / len(train_dataloader)

    # Validate roughly every --check-val-every training batches regardless of
    # how large the training dataloader is: if that spans more than one epoch,
    # validate every floor(interval) epochs; otherwise pass the value as
    # val_check_interval (a fraction of an epoch).
    if interval > 1.0:
        check_val_opts["check_val_every_n_epoch"] = math.floor(interval)
    else:
        check_val_opts["val_check_interval"] = interval

    # Keep the 5 checkpoints with the highest "vexact" on the first
    # validation dataloader.
    checkpoint_cb = ModelCheckpoint(
        monitor="vexact/dataloader_idx_0",
        auto_insert_metric_name=False,
        save_top_k=5,
        mode="max",
    )

    logs_root_dir = f"logs/{exp_name}/{model_name}/{dataset_name}/{seed}"

    if args.restore_from_checkpoint:
        # Resume: locate the newest version directory that has checkpoints and
        # pick the checkpoint whose file name ends (before the extension) in
        # the largest "-"-separated integer -- presumably the training step.
        most_recent_version = get_most_recent_version(logs_root_dir)
        most_recent_version_path = os.path.join(
            logs_root_dir, "lightning_logs", most_recent_version, "checkpoints"
        )
        most_recent_checkpoint = sorted(
            [
                (int(os.path.splitext(c)[0].split("-")[-1]), c)
                for c in os.listdir(most_recent_version_path)
            ],
            key=lambda x: x[0],
        )[-1][-1]
        checkpoint_path = os.path.join(most_recent_version_path, most_recent_checkpoint)
        print(f"Restore checkpoint {checkpoint_path}")
        restore_checkpoint_opts = {"ckpt_path": checkpoint_path}
    else:
        # Fresh run: use the explicitly requested version (may be None).
        most_recent_version = args.version
        restore_checkpoint_opts = {}

    # Both loggers are given the same version so they write into the same
    # version of the log directory; when resuming, that is the version being
    # restored.
    trainer = pl.Trainer(
        logger=[
            TensorBoardLogger(logs_root_dir, version=most_recent_version),
            LoadableCSVLogger(
                logs_root_dir, version=most_recent_version, flush_logs_every_n_steps=10
            ),
        ],
        callbacks=[pl.callbacks.LearningRateMonitor(), checkpoint_cb],
        max_steps=iterations,
        num_sanity_val_steps=10,
        # Single GPU with 16-bit precision when CUDA is available, CPU otherwise.
        gpus=1 if torch.cuda.is_available() else 0,
        precision=16 if torch.cuda.is_available() else None,
        default_root_dir=logs_root_dir,
        accumulate_grad_batches=args.batch_size_mult,
        # Progress bar only on an interactive terminal unless forced by flag.
        enable_progress_bar=sys.stdout.isatty() or args.enable_progress,
        gradient_clip_val=0.2,
        **check_val_opts,
    )

    # One validation dataloader per entry of valid_demonstrations_dict, each
    # over a random subset of at most 1024 demonstrations, padded the same way
    # as the training data.
    trainer.fit(
        meta_module,
        train_dataloader,
        [
            DataLoader(
                PaddingDataset(
                    Subset(demonstrations, torch.randperm(len(demonstrations))[:1024]),
                    (8, 72, None),
                    (pad_word, pad_action, None),
                ),
                batch_size=max([args.train_batch_size, args.valid_batch_size]),
                pin_memory=True,
            )
            for demonstrations in valid_demonstrations_dict.values()
        ],
        **restore_checkpoint_opts,
    )
    # The existence of this file is what the early-exit check above looks for.
    print(f"Done, saving {model_path}")
    trainer.save_checkpoint(f"{model_path}")


# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
