import os
import copy
import math
from functools import partial
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tarp.modules.layers import BaseProcessingNet, ConvBlockEnc, \
    ConvBlockDec, init_weights_xavier, get_num_conv_layers, ConvBlockHeadDec, ConvBlock, LayerBuilderParams, ConvBlockEnc, ConvBlockLastEnc, ConvBlockFirstEnc, ConvBlockFirstDec, ConvBlockMiddleDec, ConvBlockLastDec, DotAttn
from tarp.modules.recurrent_modules import BaseProcessingLSTM, \
    BidirectionalLSTM, ForwardLSTMCell, ForwardGRUCell, BareGRUCell
from tarp.modules.variational_inference import Gaussian
from tarp.modules.variational_inference import SequentialGaussian_SharedPQ
from tarp.modules.variational_inference import UnitGaussian, stack
from tarp.utils.general_utils import SkipInputSequential, GetIntermediatesSequential, \
    remove_spatial, batchwise_index, batch_apply, map_recursive, apply_linear, ParamDict, listdict2dictlist, map_dict
from tarp.utils.pytorch_utils import like, AttrDictPredictor, batchwise_assign, make_one_hot, mask_out
from tarp.utils.general_utils import broadcast_final, AttrDict
from torch import Tensor
from tarp.components.checkpointer import CheckpointHandler


class ParamLayer(nn.Module):
    """Learnable ``(1, n_dim)`` row that is tiled once per batch element of the input."""

    def __init__(self, n_dim, init_value):
        super().__init__()
        initial_row = torch.zeros(1, n_dim) + init_value
        self.param = nn.Parameter(initial_row)

    def forward(self, input):
        # Only the leading size of `input` is used; its values are ignored.
        n_rows = input.size()[0]
        return self.param.repeat(n_rows, 1)


class Predictor(BaseProcessingNet):
    """Thin wrapper around ``BaseProcessingNet`` that fills layer sizes from the hyperparameters.

    :param hp: hyperparameter container; ``nz_mid``, ``n_processing_layers`` and ``builder`` are read
    :param input_size: forwarded to ``BaseProcessingNet``
    :param output_size: forwarded to ``BaseProcessingNet``
    :param num_layers: defaults to ``hp.n_processing_layers`` when None
    :param detached: forwarded to ``BaseProcessingNet``
    :param spatial: when False, the output is passed through ``remove_spatial(..., yes=True)``
    :param final_activation: forwarded to ``BaseProcessingNet``
    :param mid_size: defaults to ``hp.nz_mid`` when None
    :param use_convs: see the review note in the constructor
    """
    def __init__(self, hp, input_size, output_size, num_layers=None, detached=False, spatial=True,
                 final_activation=None, mid_size=None, use_convs=None):
        self.spatial = spatial
        if mid_size is None:
            mid_size = hp.nz_mid
        num_layers = hp.n_processing_layers if num_layers is None else num_layers

        # NOTE(review): any non-None `use_convs` (including True) switches convs OFF, and it does so on
        # the builder object owned by `hp` (not a copy) -- confirm this is the intended behavior.
        if use_convs is not None:
            hp.builder.use_convs = False
        super().__init__(input_size, mid_size, output_size, num_layers=num_layers, builder=hp.builder,
                         detached=detached, final_activation=final_activation)

    def forward(self, *inp):
        """Run the base network, then hand the result to ``remove_spatial`` (active when ``spatial`` is False)."""
        return remove_spatial(super().forward(*inp), yes=not self.spatial)

class TwinPredictor(nn.Module):
    """Holds two independently constructed ``Predictor`` networks and evaluates both on the same inputs."""

    def __init__(self, hp, input_size, output_size, num_layers=None, detached=False, spatial=True,
                 final_activation=None, mid_size=None):
        super().__init__()
        self._hp = hp
        # Both networks receive exactly the same constructor arguments.
        shared_args = (hp, input_size, output_size, num_layers, detached, spatial, final_activation, mid_size)
        self.net1 = Predictor(*shared_args)
        self.net2 = Predictor(*shared_args)

    def forward(self, *inp):
        """Return the pair ``(net1(*inp), net2(*inp))``."""
        return self.net1(*inp), self.net2(*inp)


class Encoder(nn.Module):
    """Wrapper that builds either a ``ConvEncoder`` or a ``Predictor`` and optionally freezes it.

    The wrapped network is evaluated without gradients (and switched to eval mode) when a checkpoint
    was supplied and ``finetune`` is off, or whenever ``detach_encoder`` is set.
    """

    def __init__(self, hp):
        super().__init__()
        self._hp = self._default_hparams().overwrite(hp)
        self.net = ConvEncoder(hp) if hp.builder.use_convs \
            else Predictor(hp, hp.state_dim, hp.nz_enc, num_layers=hp.builder.get_num_layers())

        if self._hp.encoder_checkpoint is not None:
            self._load_checkpoint()

    def _default_hparams(self):
        """Default hyperparameters; they are overwritten by the values passed to the constructor."""
        defaults = {
            'finetune': False,
            'encoder_checkpoint': None,
            'encoder_epoch': 'latest',
            'detach_encoder': False,
        }
        return ParamDict(defaults)

    def _load_checkpoint(self):
        """Loads weights for a given model from the given checkpoint directory."""
        assert self._hp.builder.use_convs
        ckpt_root = self._hp.encoder_checkpoint
        # checkpoints live in a 'weights' directory; accept either that directory or its parent
        if os.path.basename(ckpt_root) == 'weights':
            weights_dir = ckpt_root
        else:
            weights_dir = os.path.join(ckpt_root, 'weights')
        weights_file = CheckpointHandler.get_resume_ckpt_file(self._hp.encoder_epoch, weights_dir)
        CheckpointHandler.load_weights(weights_file, model=self, model_key='encoder')

    def forward(self, input):
        """Encode ``input``; the result is detached when the encoder is frozen (see class docstring)."""
        frozen = self._hp.encoder_checkpoint is not None and not self._hp.finetune
        if not (frozen or self._hp.detach_encoder):
            return self.net(input)

        with torch.no_grad():
            # note: eval mode is set on the wrapped net and is not reverted afterwards
            self.net.eval()
            return self.net(input).detach()

    def tie_conv_from(self, source):
        """Delegate weight tying to the wrapped conv encoder; ``source`` must be of the same class."""
        assert self._hp.builder.use_convs
        assert type(self) == type(source)
        self.net.tie_conv_from(source.net)

    @property
    def device(self):
        """CUDA device if CUDA is available, CPU otherwise (not derived from the module's parameters)."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

class ConvEncoder(nn.Module):
    """Stack of ``ConvBlockEnc`` layers with doubling channel count, followed by a 4x4 conv head."""

    def __init__(self, hp):
        super().__init__()
        self._hp = hp

        n_layers = hp.builder.get_num_layers(hp.img_sz)
        n_pyramid = n_layers - 3
        if hp.use_skips:
            self.net = GetIntermediatesSequential(hp.skips_stride)
        else:
            self.net = nn.Sequential()

        input_block = ConvBlockEnc(in_dim=hp.input_nc, out_dim=hp.ngf, normalization=None,
                                   builder=hp.builder)
        self.net.add_module('input', input_block)

        # each pyramid level doubles the number of filters
        for level in range(n_pyramid):
            width = hp.ngf * 2 ** level
            pyramid_block = ConvBlockEnc(in_dim=width, out_dim=width*2, normalize=hp.builder.normalize,
                                         builder=hp.builder)
            self.net.add_module('pyramid-{}'.format(level), pyramid_block)

        # add output layer
        self.net.add_module('head', nn.Conv2d(hp.ngf * 2 ** n_pyramid, hp.nz_enc, 4))
        self.net.apply(init_weights_xavier)

    def forward(self, input):
        return self.net(input)

    def tie_conv_from(self, source):
        """Make every layer of this encoder share weight/bias objects with the matching layer of ``source``."""
        assert type(self) == type(source)
        for idx in range(len(self.net)):
            own_layer = self.net[idx]
            if type(own_layer) == ConvBlockEnc:
                # conv blocks keep their convolution in the `.conv` attribute
                self._tie_weights(own_layer.conv, source.net[idx].conv)
            else:
                self._tie_weights(own_layer, source.net[idx])

    def _tie_weights(self, target, source):
        """Point ``target.weight`` / ``target.bias`` at the parameters of ``source`` (same layer type required)."""
        assert type(source) == type(target)
        target.weight, target.bias = source.weight, source.bias

    @property
    def device(self):
        """CUDA device if CUDA is available, CPU otherwise."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

class Decoder(nn.Module):
    """ A thin wrapper class that decides which decoder to build """
    def __init__(self, hp):
        super().__init__()
        self._hp = hp

        if not hp.builder.use_convs:
            # the non-convolutional decoder supports neither skips nor pixel copying/shifting
            assert not self._hp.use_skips
            assert not self._hp.add_weighted_pixel_copy
            assert not self._hp.pixel_shift_decoder
            state_predictor = Predictor(hp, hp.nz_enc, hp.state_dim, num_layers=hp.builder.get_num_layers())
            self.net = AttrDictPredictor({'images': state_predictor})
            return

        # pixel copy and pixel shift are mutually exclusive
        assert not (self._hp.add_weighted_pixel_copy & self._hp.pixel_shift_decoder)

        if self._hp.pixel_shift_decoder:
            decoder_cls = PixelShiftDecoder
        elif self._hp.add_weighted_pixel_copy:
            decoder_cls = PixelCopyDecoder
        elif self._hp.mask_pred_decoder:
            decoder_cls = ConvMaskDecoder
        else:
            decoder_cls = ConvDecoder
        self.net = decoder_cls(hp)

    def forward(self, input, **kwargs):
        """Run the wrapped decoder, first dropping keyword inputs the configured decoder does not use."""
        uses_pixel_source = self._hp.pixel_shift_decoder or self._hp.add_weighted_pixel_copy
        if not uses_pixel_source:
            kwargs.pop('pixel_source', None)

        if not self._hp.use_skips:
            kwargs.pop('skips', None)

        return self.net(input, **kwargs)

    def decode_seq(self, inputs, encodings):
        """ Decodes a sequence of images given the encodings

        :param inputs: {'skips', 'pixel_sources'} - pixel_sources is list of source images
        :param encodings: encodings to decode; dimension 1 gives the length the other inputs are expanded to
        :return: result of applying this decoder through ``batch_apply``
        """

        def tile_over_sequence(tensor):
            # insert a new axis at position 1 and expand it to the sequence length of `encodings`
            target_shape = [tensor.shape[0], encodings.shape[1]] + list(tensor.shape[1:])
            return tensor[:, None].expand(target_shape).contiguous()

        decoder_inputs = AttrDict(input=encodings)
        if 'skips' in inputs:
            decoder_inputs.skips = map_recursive(tile_over_sequence, inputs.skips)
        if 'pixel_sources' in inputs:
            decoder_inputs.pixel_source = map_recursive(tile_over_sequence, inputs.pixel_sources)

        return batch_apply(decoder_inputs, self, separate_arguments=True)


class ConvDecoder(nn.Module):
    """Convolutional decoder: head block, pyramid of ``ConvBlockDec`` layers, and an image head.

    ``forward`` returns an ``AttrDict`` with the final feature map (``feat``) and the output of the
    image head (``images``).
    """

    def __init__(self, hp):
        super().__init__()

        self._hp = hp
        n_layers = get_num_conv_layers(hp.img_sz)
        n_pyramid = n_layers - 3
        if hp.use_skips:
            self.net = SkipInputSequential(hp.skips_stride)
        else:
            self.net = nn.Sequential()

        head_block = ConvBlockHeadDec(in_dim=hp.nz_enc, out_dim=hp.ngf * 2 ** n_pyramid,
                                      normalize=hp.builder.normalize, builder=hp.builder)
        self.net.add_module('net', head_block)

        for level in reversed(range(n_pyramid)):
            width_out = hp.ngf * 2 ** level
            width_in = width_out * 2
            # input width is doubled at the levels selected by the skip stride
            if self._hp.use_skips and (level + 1) % hp.skips_stride == 0:
                width_in = width_in * 2

            pyramid_block = ConvBlockDec(in_dim=width_in, out_dim=width_out, normalize=hp.builder.normalize,
                                         builder=hp.builder)
            self.net.add_module('pyramid-{}'.format(level), pyramid_block)

        self.head_filters_out = hp.ngf
        last_in = self.head_filters_out
        if self._hp.use_skips and 0 % hp.skips_stride == 0:
            last_in = last_in * 2

        last_block = ConvBlockDec(in_dim=last_in, out_dim=self.head_filters_out,
                                  normalization=None, activation=nn.Tanh(), builder=hp.builder)
        self.net.add_module('additional_conv_layer', last_block)

        # the output activation can be overridden through the hyperparameters
        out_activation = self._hp.dec_last_activation if 'dec_last_activation' in self._hp else nn.Tanh()
        self.gen_head = ConvBlockDec(in_dim=self.head_filters_out, out_dim=hp.input_nc, normalization=None,
                                     activation=out_activation, builder=hp.builder, upsample=False)

        self.net.apply(init_weights_xavier)
        self.gen_head.apply(init_weights_xavier)

    def forward(self, *args, **kwargs):
        """Decode the inputs into ``feat`` and ``images``."""
        features = self.net(*args, **kwargs)
        output = AttrDict()
        output.feat = features
        output.images = self.gen_head(features)
        return output

class ConvMaskDecoder(nn.Module):
    """Convolutional decoder whose head emits ``2 * input_nc`` channels, split into ``mean`` and ``mask``.

    ``forward`` returns an ``AttrDict`` with ``feat``, ``mean`` and ``mask``.
    """

    def __init__(self, hp):
        super().__init__()

        self._hp = hp
        n_layers = get_num_conv_layers(hp.img_sz)
        n_pyramid = n_layers - 3
        if hp.use_skips:
            self.net = SkipInputSequential(hp.skips_stride)
        else:
            self.net = nn.Sequential()

        head_block = ConvBlockHeadDec(in_dim=hp.nz_enc, out_dim=hp.ngf * 2 ** n_pyramid,
                                      normalize=hp.builder.normalize, builder=hp.builder)
        self.net.add_module('net', head_block)

        for level in reversed(range(n_pyramid)):
            width_out = hp.ngf * 2 ** level
            width_in = width_out * 2
            # input width is doubled at the levels selected by the skip stride
            if self._hp.use_skips and (level + 1) % hp.skips_stride == 0:
                width_in = width_in * 2

            pyramid_block = ConvBlockDec(in_dim=width_in, out_dim=width_out, normalize=hp.builder.normalize,
                                         builder=hp.builder)
            self.net.add_module('pyramid-{}'.format(level), pyramid_block)

        self.head_filters_out = hp.ngf
        last_in = self.head_filters_out
        if self._hp.use_skips and 0 % hp.skips_stride == 0:
            last_in = last_in * 2

        last_block = ConvBlockDec(in_dim=last_in, out_dim=self.head_filters_out,
                                  normalization=None, activation=nn.Tanh(), builder=hp.builder)
        self.net.add_module('additional_conv_layer', last_block)

        # the output activation can be overridden through the hyperparameters
        out_activation = self._hp.dec_last_activation if 'dec_last_activation' in self._hp else nn.Tanh()
        # twice the image channels: one half for the mean, the other for the mask
        self.gen_head = ConvBlockDec(in_dim=self.head_filters_out, out_dim=hp.input_nc*2, normalization=None,
                                     activation=out_activation, builder=hp.builder, upsample=False)

        self.net.apply(init_weights_xavier)
        self.gen_head.apply(init_weights_xavier)

    def forward(self, *args, **kwargs):
        """Decode the inputs and split the head output along the channel axis into ``mean`` and ``mask``."""
        features = self.net(*args, **kwargs)
        mean, mask = torch.chunk(self.gen_head(features), 2, dim=1)
        output = AttrDict()
        output.feat = features
        output.mean = mean
        output.mask = mask
        # output.pred = torch.distributions.independent.Independent(torch.distributions.Normal(output.mean, 1), 3)
        return output

class PixelCopyDecoder(ConvDecoder):
    """``ConvDecoder`` that blends its generated image with given source images using softmax masks."""

    def __init__(self, hp, n_masks=None):
        super().__init__(hp)
        self.n_pixel_sources = hp.n_pixel_sources
        if not n_masks:
            # one mask per pixel source plus one for the generated image
            n_masks = self.n_pixel_sources + 1

        self.mask_head = ConvBlockDec(in_dim=self.head_filters_out, out_dim=n_masks,
                                      normalization=None, activation=nn.Softmax(dim=1), builder=hp.builder,
                                      upsample=False)
        # note: this re-applies the initializer to the whole module, not only to the mask head
        self.apply(init_weights_xavier)

    def forward(self, *args, pixel_source, **kwargs):
        """Decode, then replace ``images`` by the mask-weighted mix of the sources and the generated image."""
        output = super().forward(*args, **kwargs)
        assert len(pixel_source) == self.n_pixel_sources    # number of pixel sources does not correspond to param

        candidates = pixel_source + [output.images]
        output.pixel_copy_mask, output.images = self.mask_and_merge(output.feat, candidates)
        return output

    # @torch.jit.script_method
    def mask_and_merge(self, feat, pixel_source):
        # type: (Tensor, List[Tensor]) -> Tuple[Tensor, Tensor]
        """Predict masks from ``feat`` and return ``(mask, weighted sum of the stacked candidates)``."""
        mask = self.mask_head(feat)
        stacked = torch.stack(pixel_source, dim=1)
        # the extra axis on the mask lets each mask broadcast over dimension 2 of the stacked candidates
        merged = torch.sum(mask.unsqueeze(2) * stacked, dim=1)
        return mask, merged


class PixelShiftDecoder(PixelCopyDecoder):
    """Pixel-copy decoder that additionally predicts one 2-channel flow field per source to warp it.

    The final image mixes the raw sources, the warped sources and the generated image
    (``1 + 2 * n_pixel_sources`` masks).
    """

    def __init__(self, hp):
        self.n_pixel_sources = hp.n_pixel_sources
        super().__init__(hp, n_masks=1 + self.n_pixel_sources * 2)

        heads = [ConvBlockDec(in_dim=self.head_filters_out, out_dim=2, normalization=None,
                              activation=None, builder=hp.builder, upsample=False)
                 for _ in range(self.n_pixel_sources)]
        self.flow_heads = nn.ModuleList(heads)

        self.apply(init_weights_xavier)

    @staticmethod
    def apply_flow(image, flow):
        """ Modified from
        https://github.com/febert/visual_mpc/blob/dev/python_visual_mpc/pytorch/goalimage_warping/goalimage_warper.py#L81

        Samples ``image`` at the identity sampling grid offset by ``flow``.
        """
        n_images = image.size()[0]
        # one identity affine transform per image
        identity = image.new_tensor([[1, 0, 0], [0, 1, 0]]).reshape(1, 2, 3)
        theta = identity.repeat_interleave(n_images, dim=0)
        # NOTE(review): `align_corners` is left at the library default for both calls -- confirm that the
        # default of the installed PyTorch version is the intended sampling convention.
        base_grid = F.affine_grid(theta, image.size())
        # flow is permuted (0, 2, 3, 1) so it can be added to the sampling grid
        shifted_grid = base_grid + flow.permute(0, 2, 3, 1)
        return F.grid_sample(image, shifted_grid)

    def forward(self, *args, pixel_source, **kwargs):
        """Decode, warp every pixel source with its predicted flow and merge all candidates by mask."""
        output = ConvDecoder.forward(self, *args, **kwargs)
        assert len(pixel_source) == self.n_pixel_sources  # number of pixel sources does not correspond to param

        flows = [head(output.feat) for head in self.flow_heads]
        warped = [self.apply_flow(source, flow) for source, flow in zip(pixel_source, flows)]
        output.flow_fields = flows
        output.warped_sources = warped

        candidates = pixel_source + warped + [output.images]
        _, output.images = self.mask_and_merge(output.feat, candidates)
        return output


class Attention(nn.Module):
    """Stack of multi-head attention layers whose query is refined after every layer."""

    def __init__(self, hp):
        super().__init__()
        self._hp = hp
        time_cond_length = self._hp.max_seq_len if self._hp.one_hot_attn_time_cond else 1
        pair_dim = hp.nz_enc * 2
        query_in_dim = pair_dim + time_cond_length if hp.timestep_cond_attention else pair_dim
        self.query_net = get_predictor(hp, query_in_dim, hp.nz_attn_key)

        n_layers = hp.n_attention_layers
        self.attention_layers = nn.ModuleList([MultiheadAttention(hp) for _ in range(n_layers)])
        self.predictor_layers = nn.ModuleList([get_predictor(hp, hp.nz_enc, hp.nz_attn_key, num_layers=2)
                                               for _ in range(n_layers)])
        self.out = nn.Linear(hp.nz_enc, hp.nz_enc)

    def forward(self, enc_demo_seq, enc_demo_key_seq, e_l, e_r, start_ind, end_ind, inputs, timestep=None):
        """Performs multi-layered, multi-headed attention.

        :return: tuple of (linearly projected output of the last attention layer, its attention weights);
                 with ``hp.forced_attention`` the sequence is indexed at ``timestep`` and the weights are None
        """
        if self._hp.forced_attention:
            return batchwise_index(enc_demo_seq, timestep[:,0].long()), None

        # Get (initial) attention key
        time_cond = timestep
        if self._hp.one_hot_attn_time_cond and timestep is not None:
            time_cond = make_one_hot(timestep.long(), self._hp.max_seq_len).float()
        extra_inputs = [time_cond] if self._hp.timestep_cond_attention else []
        query = self.query_net(e_l, e_r, *extra_inputs)

        # Attend
        if self._hp.mask_inf_attention:
            s_ind, e_ind = torch.floor(start_ind), torch.ceil(end_ind)
        else:
            s_ind, e_ind = inputs.start_ind, inputs.end_ind
        key_norm_shape = query.shape[1:]
        value_norm_shape = enc_demo_seq.shape[2:]

        raw_attn_output, att_weights = None, None
        for attend, refine in zip(self.attention_layers, self.predictor_layers):
            forced_step = timestep if self._hp.forced_attention else None
            raw_attn_output, att_weights = attend(query, enc_demo_key_seq, enc_demo_seq, s_ind, e_ind,
                                                  forced_attention_step=forced_step)
            normed = F.layer_norm(raw_attn_output, value_norm_shape)
            # skip connections around attention and predictor
            query = F.layer_norm(refine(normed) + query, key_norm_shape)

        # output non-normalized output of final attention layer
        return apply_linear(self.out, raw_attn_output, dim=1), att_weights


class MultiheadAttention(nn.Module):
    """Masked multi-head attention of a single query over a sequence of keys/values.

    Keys have ``hp.nz_attn_key`` features and values ``hp.nz_enc`` features; both are split evenly
    across ``hp.n_attention_heads`` heads.
    """
    def __init__(self, hp, dropout=0.0):
        super().__init__()
        self._hp = hp
        self.nz = hp.nz_enc
        self.nz_attn_key = hp.nz_attn_key
        self.n_heads = hp.n_attention_heads
        assert self.nz % self.n_heads == 0  # number of attention heads needs to evenly divide latent
        assert self.nz_attn_key % self.n_heads == 0  # number of attention heads needs to evenly divide latent
        # per-head feature sizes for values and keys
        self.nz_v_i = self.nz // self.n_heads
        self.nz_k_i = self.nz_attn_key // self.n_heads
        # temperature is a learnable parameter only if requested, otherwise the plain hyperparameter value
        self.temperature = nn.Parameter(self._hp.attention_temperature * torch.ones(1)) if self._hp.learn_attn_temp \
            else self._hp.attention_temperature

        # set up transforms for inputs / outputs
        self.q_linear = nn.Linear(self.nz_attn_key, self.nz_attn_key)
        self.k_linear = nn.Linear(self.nz_attn_key, self.nz_attn_key)
        self.v_linear = nn.Linear(self.nz, self.nz)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(self.nz, self.nz)

    def forward(self, q, k, v, start_ind, end_ind, forced_attention_step=None):
        """Project q/k/v, attend per head, and project the concatenated heads.

        Batch size and sequence length are taken from the first two dimensions of ``k``.
        # assumes q is (batch, nz_attn_key, ...) and k/v are (batch, time, features, ...) -- TODO confirm

        :return: tuple of (output of the final linear layer, attention weights averaged over the last dim)
        """
        batch_size, time = list(k.shape)[:2]
        latent_shape = list(v.shape[2:])

        # perform linear operation and split into h heads
        q = apply_linear(self.q_linear, q, dim=1).view(batch_size, self.n_heads, self.nz_k_i, *latent_shape[1:])
        k = apply_linear(self.k_linear, k, dim=2).view(batch_size, time, self.n_heads, self.nz_k_i, *latent_shape[1:])
        v = apply_linear(self.v_linear, v, dim=2).view(batch_size, time, self.n_heads, self.nz_v_i, *latent_shape[1:])

        # compute masked, multi-headed attention
        vals, att_weights = self.attention(q, k, v, self.nz_k_i, start_ind, end_ind, self.dropout, forced_attention_step)

        # concatenate heads and put through final linear layer
        concat = vals.contiguous().view(batch_size, *latent_shape)
        return apply_linear(self.out, concat, dim=1), att_weights.mean(dim=-1)

    def attention(self, q, k, v, nz_k, start_ind, end_ind, dropout=None, forced_attention_step=None):
        """Compute attention-weighted values.

        :return: tuple of (values summed over dim 1 with the applied weights, softmax attention scores).
                 The second element is always the softmax output, even when forced attention or dropout
                 changed the weights that were actually applied.
        """

        def tensor_product(key, sequence):
            # product of the key (with an added axis 1) and the sequence, summed over all dims from index 3 on
            dims = list(range(len(list(sequence.shape)))[3:])
            return (key[:, None] * sequence).sum(dim=dims)

        # note: scores are multiplied (not divided) by the temperature
        attn_scores = tensor_product(q, k) / math.sqrt(nz_k) * self.temperature
        attn_scores = MultiheadAttention.mask_out(attn_scores, start_ind, end_ind)
        attn_scores = F.softmax(attn_scores, dim=1)

        if forced_attention_step is not None:
            # replace the learned weights by zeros, with 1.0 assigned via `batchwise_assign`
            # presumably this yields a one-hot over time at the forced step -- verify against the helper
            scores_f = torch.zeros_like(attn_scores)
            batchwise_assign(scores_f, forced_attention_step[:, 0].long(), 1.0)

        scores = scores_f if forced_attention_step is not None else attn_scores
        if dropout is not None and dropout.p > 0.0:
            scores = dropout(scores)

        return (broadcast_final(scores, v) * v).sum(dim=1), attn_scores

    @staticmethod
    def mask_out(scores, start_ind, end_ind):
        """Mask scores using the module-level ``mask_out`` helper with a fill value of ``-inf``.

        The ``mask_out`` called inside resolves to the imported module-level function, not to this method.
        ``scores`` is modified in place by the final assignment.
        """
        # Mask out the frames that are not in the range
        _, mask = mask_out(scores, start_ind, end_ind, -np.inf)
        scores[mask.all(dim=1)] = 1  # When the sequence is empty, fill ones to prevent crashing in Multinomial
        return scores

    def log_outputs_stateful(self, step, log_images, phase, logger):
        """Log the attention temperature as a scalar during the 'train' phase; ``log_images`` is unused."""
        if phase == 'train':
            logger.log_scalar(self.temperature, 'attention_softmax_temp', step, phase)


class GaussianPredictor(Predictor):
    """Predictor with ``2 * gaussian_dim`` outputs that are wrapped in ``Gaussian`` and returned via ``.tensor()``."""

    def __init__(self, hp, input_dim, gaussian_dim=None, spatial=False):
        latent_dim = hp.nz_vae if gaussian_dim is None else gaussian_dim
        # the network output is twice the latent size
        super().__init__(hp, input_dim, latent_dim * 2, spatial=spatial)

    def forward(self, *inputs):
        raw_params = super().forward(*inputs)
        return Gaussian(raw_params).tensor()


class ApproximatePosterior(GaussianPredictor):
    """Gaussian predictor whose input size is three times ``hp.nz_enc``."""

    def __init__(self, hp):
        input_dim = hp.nz_enc * 3
        super().__init__(hp, input_dim)


class LearnedPrior(GaussianPredictor):
    """Gaussian predictor whose input size is twice ``hp.nz_enc``."""

    def __init__(self, hp):
        input_dim = hp.nz_enc * 2
        super().__init__(hp, input_dim)


class FixedPrior(nn.Module):
    """Parameter-free prior that returns a ``UnitGaussian`` of shape ``[batch, hp.nz_vae]``."""

    def __init__(self, hp):
        super().__init__()
        self.hp = hp

    def forward(self, e_l, *args):  # ignored because fixed prior
        # only the batch size of the first input is used
        batch_size = e_l.shape[0]
        prior_shape = [batch_size, self.hp.nz_vae]
        return UnitGaussian(prior_shape, self.hp.device).tensor()


class VariationalInference2LayerSharedPQ(nn.Module):
    """Two-stage inference network: ``q1`` produces the first Gaussian, ``q2`` conditions on its sample."""

    def __init__(self, hp):
        super().__init__()
        # NOTE(review): `GaussianPredictor` doubles its `gaussian_dim` argument internally, while the
        # values passed here are already doubled -- confirm the resulting sizes are intended.
        first_in = hp.nz_enc * 3
        second_in = hp.nz_vae + 2 * hp.nz_enc  # inputs are two parents and z1
        self.q1 = GaussianPredictor(hp, first_in, hp.nz_vae * 2)
        self.q2 = GaussianPredictor(hp, second_in, hp.nz_vae2 * 2)

    def forward(self, e_l, e_r, e_tilde):
        first_gaussian = self.q1(e_l, e_r, e_tilde)
        z1 = Gaussian(first_gaussian).sample()
        second_gaussian = self.q2(z1, e_l, e_r)
        return SequentialGaussian_SharedPQ(first_gaussian, z1, second_gaussian)


class TwolayerPriorSharedPQ(nn.Module):
    """Two-stage prior: ``p1`` gives the first Gaussian, a shared network gives the second one.

    :param hp: unused by this class
    :param p1: module called as ``p1(e_l, e_r)``
    :param q_p_shared: module called as ``q_p_shared(z1, e_l, e_r)``
    """

    def __init__(self, hp, p1, q_p_shared):
        super().__init__()
        self.p1 = p1
        self.q_p_shared = q_p_shared

    def forward(self, e_l, e_r):
        first_gaussian = self.p1(e_l, e_r)
        z1 = Gaussian(first_gaussian).sample()
        # make sure its the same order of arguments as in the inference network!!
        second_gaussian = self.q_p_shared(z1, e_l, e_r)
        return SequentialGaussian_SharedPQ(first_gaussian, z1, second_gaussian)


def get_prior(hp):
    """Build the prior network selected by ``hp.prior_type``.

    :param hp: hyperparameters; ``prior_type`` must be ``'learned'`` or ``'fixed'``
    :return: a ``LearnedPrior`` or ``FixedPrior`` instance
    :raises ValueError: if ``hp.prior_type`` is not one of the supported values
    """
    if hp.prior_type == 'learned':
        return LearnedPrior(hp)
    if hp.prior_type == 'fixed':
        return FixedPrior(hp)
    # fail loudly instead of silently returning None, which would only crash later at call time
    raise ValueError("Unknown prior_type {!r}; expected 'learned' or 'fixed'".format(hp.prior_type))


def setup_variation_inference(hp):
    """Build the inference network ``q`` and the prior ``p`` selected by ``hp.var_inf``.

    :param hp: hyperparameters; ``var_inf`` must be ``'2layer'``, ``'standard'`` or ``'deterministic'``
    :return: tuple ``(q, p)``
    :raises ValueError: if ``hp.var_inf`` is not one of the supported values
    """
    if hp.var_inf == '2layer':
        q = VariationalInference2LayerSharedPQ(hp)
        # the second-stage network is shared between posterior and prior; it is stored as `q2`
        # and is called with (z1, e_l, e_r) in both modules
        p = TwolayerPriorSharedPQ(hp, get_prior(hp), q.q2)

    elif hp.var_inf == 'standard':
        q = ApproximatePosterior(hp)
        p = get_prior(hp)

    elif hp.var_inf == 'deterministic':
        q = FixedPrior(hp)
        p = FixedPrior(hp)

    else:
        raise ValueError("Unknown var_inf {!r}; expected '2layer', 'standard' or 'deterministic'"
                         .format(hp.var_inf))

    return q, p


class SeqEncodingModule(nn.Module):
    """Base class for sequence encoders; subclasses create ``self.net`` in ``build_network``.

    When ``add_time`` is set, one extra feature holding the step index is appended to every element.
    """

    def __init__(self, hp, add_time=True):
        super().__init__()
        self.hp = hp
        self.add_time = add_time
        # the boolean adds one input feature when time is appended
        input_size = hp.nz_enc + add_time
        self.build_network(input_size, hp)

    def build_network(self, input_size, hp):
        """ This has to define self.net """
        raise NotImplementedError()

    def run_net(self, seq):
        """ Run the network here """
        return self.net(seq)

    def forward(self, seq):
        """Flatten everything after the first two dims, run the network and restore the input shape."""
        original_shape = list(seq.shape)
        flat_seq = seq.view(original_shape[:2] + [-1])

        if self.add_time:
            n_steps = flat_seq.shape[1]
            step_index = like(torch.arange, flat_seq)(n_steps)[None, :, None]
            step_index = step_index.repeat([original_shape[0], 1, 1])
            flat_seq = torch.cat([flat_seq, step_index], dim=2)

        return self.run_net(flat_seq).view(original_shape)


class ConvSeqEncodingModule(SeqEncodingModule):
    """Sequence encoder built from 1d convolution blocks with size-preserving padding."""

    def build_network(self, input_size, hp):
        k_size = hp.conv_inf_enc_kernel_size
        assert k_size % 2 != 0     # need uneven kernel size for padding
        pad = int(np.floor(k_size / 2))
        conv1d_block = partial(ConvBlock, d=1, kernel_size=k_size, padding=pad)
        self.net = BaseProcessingNet(input_size, hp.nz_mid, hp.nz_enc, hp.conv_inf_enc_layers, hp.builder,
                                     block=conv1d_block)

    def run_net(self, seq):
        # 1d convolutions expect length-last
        length_last = seq.transpose(1, 2)
        return self.net(length_last).transpose(1, 2)

class ConvFrameMaskDecoder(nn.Module):
    """Recurrent (LSTM-cell) decoder emitting per-step action scores, masks and progress estimates.

    At every step the module embeds the current frame encoding, computes an attention read-out over
    ``enc`` (concatenated to the cell input unless ``hp.exclude_lang``), feeds these together with the
    embedding of the previous action into an LSTM cell, and decodes from the concatenation of the new
    hidden state and the cell input:

    * action scores (dot product with the rows of ``action_emb.weight``),
    * a single-channel mask resized to ``(hp.pframe, hp.pframe)``,
    * two sigmoid-squashed scalars named ``subgoal`` and ``progress``.

    :param hp: hyperparameter container
    :param action_emb: embedding module; called with action indices and its ``weight`` is read
    """

    def __init__(self, hp, action_emb):
        super().__init__()
        self.action_emb = action_emb
        self._hp = hp

        # sizes shared by the layers below
        hid_dim = self._hp.nz_mid_lstm * 2
        if self._hp.exclude_lang:
            inp_dim = self._hp.emb_vis_dim + self._hp.emb_dim
        else:
            # the attention read-out adds another `hid_dim` features to the cell input
            inp_dim = self._hp.emb_dim + self._hp.emb_vis_dim + hid_dim
        cont_dim = hid_dim + inp_dim    # hidden state concatenated with the cell input

        self.cell = nn.LSTMCell(inp_dim, hid_dim)

        if self._hp.use_conv_feat:
            from torchvision import models
            resnet = models.resnet18(pretrained=True).eval()
            # drop the last two children of the ResNet and keep the rest as feature extractor
            self.encoder = nn.Sequential(*list(resnet.children())[:-2])
            self.vis_encoder = nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0),
                nn.ReLU(),
                nn.Conv2d(256, 64, kernel_size=1, stride=1, padding=0),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(64*8*8, self._hp.emb_vis_dim)
            )
        else:
            self.encoder = Encoder(self._hp)
            self.vis_encoder = Predictor(hp, input_size=self._hp.nz_enc, output_size=self._hp.emb_vis_dim)
        self.attn = DotAttn()
        # learned embedding fed as "previous action" at the first step; initialized at the end
        self.go = nn.Parameter(torch.Tensor(self._hp.emb_dim))

        self.actor = nn.Linear(cont_dim, self._hp.emb_dim)

        # the mask decoder is a regular Decoder configured for 1-channel output from `cont_dim` inputs
        decoder_hp = copy.deepcopy(self._hp)
        decoder_hp.input_nc = 1
        decoder_hp.dec_last_activation = None
        decoder_hp.builder.normalization = 'batch'
        decoder_hp.builder.normalize = True
        decoder_hp.nz_enc = cont_dim
        self.mask_dec = Decoder(decoder_hp)
        self.h_tm1_fc = nn.Linear(hid_dim, hid_dim)

        self.subgoal = nn.Linear(cont_dim, 1)
        self.progress = nn.Linear(cont_dim, 1)
        self.vis_dropout = nn.Dropout(0.3)
        self.hstate_dropout = nn.Dropout(0.3)
        nn.init.uniform_(self.go, -0.1, 0.1)

    def step(self, enc, enc_vis, e_t, state_tm1):
        """Advance the decoder by one step.

        :param enc: features attended over (passed to the attention module as-is)
        :param enc_vis: visual encoding of the current frame
        :param e_t: embedding of the previous action
        :param state_tm1: previous LSTM state; its first element is the hidden state
        :return: ``(action_t, mask_t, state_t, lang_attn_t, subgoal_t, progress_t)``
        """
        h_tm1 = state_tm1[0]

        lang_feat_t = enc
        vis_feat_t = remove_spatial(self.vis_encoder(enc_vis))
        # attention is always computed (its weights are returned) even if the read-out is not used below
        weighted_lang_t, lang_attn_t = self.attn(lang_feat_t, self.h_tm1_fc(h_tm1))

        # concat visual feats, weight lang, and previous action embedding
        if self._hp.exclude_lang:
            inp_t = torch.cat([vis_feat_t, e_t], dim=1)
        else:
            inp_t = torch.cat([vis_feat_t, weighted_lang_t, e_t], dim=1)

        # update hidden state
        state_t = self.cell(inp_t, state_tm1)
        state_t = [self.hstate_dropout(x) for x in state_t]
        h_t = state_t[0]

        # decode action and mask
        cont_t = torch.cat([h_t, inp_t], dim=1)
        action_emb_t = self.actor(cont_t)
        action_t = action_emb_t.mm(self.action_emb.weight.t())
        mask_t = self.mask_dec(cont_t.unsqueeze(-1).unsqueeze(-1)).images
        mask_t = F.interpolate(mask_t, size=(self._hp.pframe, self._hp.pframe), mode='bilinear')

        # predict subgoals completed and task progress
        # (torch.sigmoid replaces the deprecated F.sigmoid; the computed values are the same)
        subgoal_t = torch.sigmoid(self.subgoal(cont_t))
        progress_t = torch.sigmoid(self.progress(cont_t))

        return action_t, mask_t, state_t, lang_attn_t, subgoal_t, progress_t

    def encode(self, obs):
        """Encode frames; the result is detached for ``use_conv_feat`` / ``detach_encoder``, then dropout is applied."""
        enc_vis = self.encoder(obs)
        if self._hp.use_conv_feat or self._hp.detach_encoder:
            enc_vis = enc_vis.detach()
        enc_vis = self.vis_dropout(enc_vis)
        return enc_vis

    def forward(self, enc, obs, gold=None, max_decode=300, state_0=None):
        """Decode a whole sequence.

        :param enc: features attended over at every step; its first dimension is the batch size
        :param obs: frames with five dimensions, unpacked as ``(n, t, c, w, h)``
        :param gold: ground-truth action indices; required in training mode, where it fixes the number
                     of steps and (with ``hp.teacher_forcing``) provides the previous action
        :param max_decode: upper bound on the number of steps outside training mode
        :param state_0: initial LSTM state
        :return: ``AttrDict`` with ``actions``, ``action_masks``, ``attn_scores``, ``subgoal`` and
                 ``progress`` stacked along dimension 1, plus the final ``state_t``
        """
        n, t, c, w, h = obs.shape
        obs = obs.reshape((n*t, c, w, h))
        enc_vis = self.encode(obs)
        # use the encoder's actual channel count instead of assuming it equals hp.nz_enc
        _, enc_c, enc_w, enc_h = enc_vis.shape
        enc_vis = enc_vis.reshape((n, t, enc_c, enc_w, enc_h))

        max_t = gold.size(1) if self.training else min(max_decode, enc_vis.shape[1])
        batch = enc.size(0)
        e_t = self.go.repeat(batch, 1)

        state_t = state_0
        actions = []
        masks = []
        attn_scores = []
        subgoals = []
        progresses = []
        for step_idx in range(max_t):
            action_t, mask_t, state_t, attn_score_t, subgoal_t, progress_t = \
                self.step(enc, enc_vis[:, step_idx], e_t, state_t)
            masks.append(mask_t)
            actions.append(action_t)
            attn_scores.append(attn_score_t)
            subgoals.append(subgoal_t)
            progresses.append(progress_t)

            # next input: ground-truth action under teacher forcing, otherwise the arg-max prediction
            if self._hp.teacher_forcing and self.training:
                w_t = gold[:, step_idx]
            else:
                w_t = action_t.max(1)[1]
            e_t = self.action_emb(w_t)
        results = AttrDict(
            actions=torch.stack(actions, dim=1),
            action_masks=torch.stack(masks, dim=1),
            attn_scores=torch.stack(attn_scores, dim=1),
            state_t=state_t,
            subgoal=torch.stack(subgoals, dim=1),
            progress=torch.stack(progresses, dim=1)
        )
        return results




class RecurrentSeqEncodingModule(SeqEncodingModule):
    """Sequence encoder backed by ``BaseProcessingLSTM``."""

    def build_network(self, input_size, hp):
        output_size = hp.nz_enc
        self.net = BaseProcessingLSTM(hp, input_size, output_size)


class BidirectionalSeqEncodingModule(SeqEncodingModule):
    """Sequence encoder backed by ``BidirectionalLSTM``."""

    def build_network(self, input_size, hp):
        output_size = hp.nz_enc
        self.net = BidirectionalLSTM(hp, input_size, output_size)


class GeneralizedPredictorModel(nn.Module):
    """Predicts a list of output values, each with its own optional activation.

    A single ``Predictor`` produces ``sum(output_dims)`` values; ``forward`` splits
    them along dim 1 into consecutive chunks of the requested sizes.
    """
    def __init__(self, hp, input_dim, output_dims, activations, detached=False):
        """
        Args:
            hp: hyperparameters forwarded to the underlying ``Predictor``.
            input_dim: input size of the underlying ``Predictor``.
            output_dims: non-empty list with the size of every output value.
            activations: one activation callable (or None) per entry of ``output_dims``.
            detached: forwarded to the underlying ``Predictor``.
        """
        super().__init__()
        assert output_dims  # need non-empty list of output dims defining the number of output values
        assert len(output_dims) == len(activations)     # need one activation for every output dim
        self._hp = hp
        self.activations = activations
        self.output_dims = output_dims
        self.num_outputs = len(output_dims)
        self._build_model(hp, input_dim, detached)

    def _build_model(self, hp, input_dim, detached):
        """Builds the single predictor that emits all outputs concatenated along dim 1."""
        self.p = Predictor(hp, input_dim, sum(self.output_dims), detached=detached)

    def forward(self, *inputs):
        """Runs the predictor and splits its output into the configured chunks.

        Returns:
            A list with one tensor per entry of ``output_dims`` (outputs of size 1 are
            flattened with ``view(-1)``), or the single tensor itself if only one
            output is configured.
        """
        net_outputs = self.p(*inputs)
        outputs = []
        start = 0
        for output_dim, activation in zip(self.output_dims, self.activations):
            stop = start + output_dim
            output = net_outputs[:, start:stop]
            if activation is not None:
                output = activation(output)
            if output_dim == 1:
                output = output.view(-1)      # reduce spatial dimensions for scalars
            outputs.append(output)
            start = stop    # advance, otherwise every output would be cut from column 0
        return outputs[0] if len(outputs) == 1 else outputs


class HybridConvMLPEncoder(nn.Module):
    """Encodes image and vector input, fuses features using MLP to produce output feature.

    The vector input goes through an MLP ``Predictor``, the image input through a
    convolutional ``Encoder``; both features are concatenated along the last dimension
    and mapped to the output by a two-layer MLP head.
    """
    def __init__(self, hp):
        """Builds vector encoder, image encoder and fusion head from ``hp``.

        ``hp`` overrides the values of ``_default_hparams``.
        """
        super().__init__()
        self._hp = self._default_hparams().overwrite(hp)
        # Fall back to a square input of side `input_res` if no explicit size is given.
        if self._hp.input_width is None and self._hp.input_height is None:
            self._hp.input_width = self._hp.input_res
            self._hp.input_height = self._hp.input_res
        self._hp.builder = LayerBuilderParams(use_convs=False, normalization=self._hp.normalization)

        self._vector_enc = Predictor(self._hp,
                                     input_size=self._hp.input_dim,
                                     output_size=self._hp.nz_enc,
                                     mid_size=self._hp.nz_mid,
                                     num_layers=self._hp.n_layers,
                                     final_activation=None,
                                     spatial=False)
        self._image_enc = Encoder(self._updated_encoder_params())

        # The image feature size is scaled by the squared integer aspect ratio.
        # NOTE(review): presumably this matches the flattened Encoder output for
        # non-square inputs -- confirm against Encoder.
        ratio = max(self._hp.input_width//self._hp.input_height, self._hp.input_height//self._hp.input_width)
        img_enc_size = self._hp.nz_enc * (ratio**2)
        input_size = self._hp.nz_enc + img_enc_size
        self._head = Predictor(self._hp,
                               input_size=input_size,
                               output_size=self._hp.output_dim,
                               mid_size=self._hp.nz_mid,
                               num_layers=2,
                               final_activation=None,
                               spatial=False)

    def _default_hparams(self):
        """Returns the default hyperparameters of this module."""
        return ParamDict({
            'input_dim': None,          # dimensionality of the vector input
            'input_res': None,          # resolution of image input
            'output_dim': None,         # dimensionality of output tensor
            'input_nc': 3,              # number of input channels
            'ngf': 8,                   # number of channels in shallowest layer of image encoder
            'nz_enc': 32,               # number of dimensions in encoder-latent space
            'nz_mid': 32,               # number of dimensions for internal feature spaces
            'n_layers': 3,              # number of layers in MLPs
            'normalization': 'none',    # normalization used in encoder network ['none', 'batch']
            'use_convs': False,
            'device': None,
            'finetune': False,          # if False and a checkpoint is given, image encoder is frozen in forward
            'encoder_checkpoint': None,
            'input_width': None,
            'input_height': None
        })

    @property
    def encoder(self):
        """The image encoder sub-module."""
        return self._image_enc

    def _load_checkpoint(self):
        """Loads weights for a given model from the given checkpoint directory.

        Reads ``encoder_checkpoint`` and ``encoder_epoch`` from the hyperparameters.
        ``encoder_epoch`` has no default in ``_default_hparams`` and must therefore be
        supplied through ``hp``.
        """
        checkpoint = self._hp.encoder_checkpoint
        epoch = self._hp.encoder_epoch
        # This class never sets `self.device`; use it if it was attached externally,
        # otherwise fall back to the `device` hyperparameter.
        self._image_enc.device = getattr(self, 'device', self._hp.device)
        checkpoint_dir = checkpoint if os.path.basename(checkpoint) == 'weights' \
                            else os.path.join(checkpoint, 'weights')     # checkpts in 'weights' dir
        checkpoint_path = CheckpointHandler.get_resume_ckpt_file(epoch, checkpoint_dir)
        model_weights = torch.load(checkpoint_path)
        self._image_enc.load_state_dict(model_weights['encoder'])

    def forward(self, inputs, **kwargs):
        """Encodes ``inputs.vector`` and ``inputs.image`` and returns the fused head output.

        If an encoder checkpoint is configured and ``finetune`` is off, the image encoder
        is switched to eval mode and run without gradient tracking.
        """
        vector_feature = self._vector_enc(inputs.vector)
        freeze_encoder = self._hp.encoder_checkpoint is not None and not self._hp.finetune
        if freeze_encoder:
            with torch.no_grad():
                self._image_enc.eval()
                img_feature = remove_spatial(self._image_enc(inputs.image)).detach()
        else:
            img_feature = remove_spatial(self._image_enc(inputs.image))
        return self._head(torch.cat((vector_feature, img_feature), dim=-1))

    def _updated_encoder_params(self):
        """Returns a deep copy of the hyperparameters adapted for the conv image encoder."""
        params = copy.deepcopy(self._hp)
        return params.overwrite(AttrDict(
            use_convs=True,
            use_skips=False,                  # no skip connections needed bc we are not reconstructing
            img_sz=self._hp.input_res,        # image resolution
            input_nc=self._hp.input_nc,       # number of input feature maps
            builder=LayerBuilderParams(use_convs=True, normalization=self._hp.normalization)
        ))

class MultiHeadHybridConvMLPEncoder(nn.Module):
    """Encodes image and vector inputs and fuses them with one MLP head per head key.

    A single vector encoder is shared across all heads. Every named vector input is
    encoded with it, concatenated with the image feature along the last dimension and
    passed through the head of the same name.
    """
    def __init__(self, hp):
        """Builds vector encoder, image encoder and one fusion head per ``head_keys`` entry.

        ``hp`` overrides the values of ``_default_hparams``.
        """
        super().__init__()
        self._hp = self._default_hparams().overwrite(hp)
        self._hp.builder = LayerBuilderParams(use_convs=False, normalization=self._hp.normalization)

        self._vector_enc = Predictor(self._hp,
                                     input_size=self._hp.input_dim,
                                     output_size=self._hp.nz_enc,
                                     mid_size=self._hp.nz_mid,
                                     num_layers=self._hp.n_layers,
                                     final_activation=None,
                                     spatial=False)

        self._image_enc = Encoder(self._updated_encoder_params())

        # Each head consumes the concatenation of the vector and the image feature.
        input_size = 2*self._hp.nz_enc
        self._heads = nn.ModuleDict({
            name: Predictor(self._hp,
                       input_size=input_size,
                       output_size=self._hp.output_dim,
                       mid_size=self._hp.nz_mid,
                       num_layers=2,
                       final_activation=None,
                       spatial=False)
            for name in self._hp.head_keys
        })

    def _default_hparams(self):
        """Returns the default hyperparameters of this module."""
        return ParamDict({
            'input_dim': None,          # dimensionality of the vector input
            'input_res': None,          # resolution of image input
            'output_dim': None,         # dimensionality of output tensor
            'input_nc': 3,              # number of input channels
            'ngf': 8,                   # number of channels in shallowest layer of image encoder
            'nz_enc': 32,               # number of dimensions in encoder-latent space
            'nz_mid': 32,               # number of dimensions for internal feature spaces
            'n_layers': 3,              # number of layers in MLPs
            'normalization': 'none',    # normalization used in encoder network ['none', 'batch']
            'use_convs': False,
            'device': None,
            'finetune': False,          # if False and a checkpoint is given, image encoder is frozen in forward
            'encoder_checkpoint': None,
            'head_keys': ['main']       # names of the output heads
        })

    @property
    def encoder(self):
        """The image encoder sub-module."""
        return self._image_enc

    def _load_checkpoint(self):
        """Loads weights for a given model from the given checkpoint directory.

        Reads ``encoder_checkpoint`` and ``encoder_epoch`` from the hyperparameters.
        ``encoder_epoch`` has no default in ``_default_hparams`` and must therefore be
        supplied through ``hp``.
        """
        checkpoint = self._hp.encoder_checkpoint
        epoch = self._hp.encoder_epoch
        # This class never sets `self.device`; use it if it was attached externally,
        # otherwise fall back to the `device` hyperparameter.
        self._image_enc.device = getattr(self, 'device', self._hp.device)
        checkpoint_dir = checkpoint if os.path.basename(checkpoint) == 'weights' \
                            else os.path.join(checkpoint, 'weights')     # checkpts in 'weights' dir
        checkpoint_path = CheckpointHandler.get_resume_ckpt_file(epoch, checkpoint_dir)
        model_weights = torch.load(checkpoint_path)
        self._image_enc.load_state_dict(model_weights['encoder'])

    def forward(self, inputs, **kwargs):
        """Returns an AttrDict mapping every key of ``inputs.vector`` to its head output.

        If an encoder checkpoint is configured and ``finetune`` is off, the image encoder
        is switched to eval mode and run without gradient tracking.
        """
        vector_features = {name: self._vector_enc(vec) for name, vec in inputs.vector.items()}
        freeze_encoder = self._hp.encoder_checkpoint is not None and not self._hp.finetune
        if freeze_encoder:
            with torch.no_grad():
                self._image_enc.eval()
                img_feature = remove_spatial(self._image_enc(inputs.image)).detach()
        else:
            img_feature = remove_spatial(self._image_enc(inputs.image))
        return AttrDict({
            name: self._heads[name](torch.cat((vec_feature, img_feature), dim=-1))
            for name, vec_feature in vector_features.items()
        })

    def _updated_encoder_params(self):
        """Returns a deep copy of the hyperparameters adapted for the conv image encoder."""
        params = copy.deepcopy(self._hp)
        return params.overwrite(AttrDict(
            use_convs=True,
            use_skips=False,                  # no skip connections needed bc we are not reconstructing
            img_sz=self._hp.input_res,        # image resolution
            input_nc=self._hp.input_nc,       # number of input feature maps
            builder=LayerBuilderParams(use_convs=True, normalization=self._hp.normalization)
        ))

class HybridConvTwinMLPEncoder(nn.Module):
    """Encodes image and vector input and fuses the features with a ``TwinPredictor`` head.

    The image feature is always detached before fusion, so this module's head does not
    propagate gradients into the image encoder.
    """
    def __init__(self, hp):
        """Builds vector encoder, image encoder and the twin fusion head from ``hp``.

        ``hp`` overrides the values of ``_default_hparams``.
        """
        super().__init__()
        self._hp = self._default_hparams().overwrite(hp)
        self._hp.builder = LayerBuilderParams(use_convs=False, normalization=self._hp.normalization)

        self._vector_enc = Predictor(self._hp,
                                     input_size=self._hp.input_dim,
                                     output_size=self._hp.nz_enc,
                                     mid_size=self._hp.nz_mid,
                                     num_layers=self._hp.n_layers,
                                     final_activation=None,
                                     spatial=False)
        self._image_enc = Encoder(self._updated_encoder_params())
        self._head = TwinPredictor(self._hp,
                               input_size=2*self._hp.nz_enc,
                               output_size=self._hp.output_dim,
                               mid_size=self._hp.nz_mid,
                               num_layers=2,
                               final_activation=None,
                               spatial=False)

    def _default_hparams(self):
        """Returns the default hyperparameters of this module."""
        return ParamDict({
            'input_dim': None,          # dimensionality of the vector input
            'input_res': None,          # resolution of image input
            'output_dim': None,         # dimensionality of output tensor
            'input_nc': 3,              # number of input channels
            'ngf': 8,                   # number of channels in shallowest layer of image encoder
            'nz_enc': 32,               # number of dimensions in encoder-latent space
            'nz_mid': 32,               # number of dimensions for internal feature spaces
            'n_layers': 3,              # number of layers in MLPs
            'normalization': 'none',    # normalization used in encoder network ['none', 'batch']
            'use_convs': False,
            'device': None,
            'finetune': False,          # read in forward(); if False the image encoder is put in eval mode
        })

    def forward(self, inputs, **kwargs):
        """Encodes ``inputs.vector`` and ``inputs.image`` and returns the twin head output."""
        vector_feature = self._vector_enc(inputs.vector)
        if not self._hp.finetune:
            self._image_enc.eval()
        # The image feature is detached regardless of the `finetune` setting.
        img_feature = remove_spatial(self._image_enc(inputs.image)).detach()
        return self._head(torch.cat((vector_feature, img_feature), dim=-1))

    def _updated_encoder_params(self):
        """Returns a deep copy of the hyperparameters adapted for the conv image encoder."""
        params = copy.deepcopy(self._hp)
        return params.overwrite(AttrDict(
            use_convs=True,
            use_skips=False,                  # no skip connections needed bc we are not reconstructing
            img_sz=self._hp.input_res,        # image resolution
            input_nc=self._hp.input_nc,       # number of input feature maps
            builder=LayerBuilderParams(use_convs=True, normalization=self._hp.normalization)
        ))

class DummyModule(nn.Module):
    """Placeholder module: accepts any arguments and produces empty results."""

    def __init__(self, *args, **kwargs):
        """Ignores all arguments; only initializes the ``nn.Module`` base."""
        super(DummyModule, self).__init__()

    def forward(self, *args, **kwargs):
        """Returns an empty ``AttrDict`` regardless of the inputs."""
        empty_output = AttrDict()
        return empty_output

    def loss(self, *args, **kwargs):
        """Returns an empty dict regardless of the inputs."""
        return dict()
