from __future__ import print_function

from copy import deepcopy
import numpy as np
import tree
import torch
import torch.nn as nn

import helpers.layers as layers
import helpers.utils as utils
import helpers.distributions as dist
import models.vae.flows as flows
from .spatial_transformer import SpatialTransformer
from models.vae.abstract_vae import AbstractVAE
from models.vae.reparameterizers import get_reparameterizer


class KeyReparameterizer(nn.Module):
    """A fully flow based reparameterizer."""

    # Flow types understood by __init__; anything else is rejected up front
    # instead of silently leaving ``self.flow`` undefined.
    SUPPORTED_FLOW_TYPES = ("maf", "maf_split", "maf_split_glow", "glow", "realnvp")

    def __init__(self, continuous_size, latent_size, config, flow_type='maf'):
        """Builds a continuous --> flow reparameterizer.

        :param continuous_size: the size of the continuous projection (3 for spatial xformer)
        :param latent_size: latent size for flows (forwarded as ``num_hidden``)
        :param config: argparse config
        :param flow_type: maf, maf_split, maf_split_glow, glow, realnvp
        :raises ValueError: if flow_type is not one of SUPPORTED_FLOW_TYPES
        :returns: reparameterizer
        :rtype: KeyReparameterizer

        """
        super(KeyReparameterizer, self).__init__()

        if flow_type not in self.SUPPORTED_FLOW_TYPES:
            raise ValueError(
                "unknown flow_type '{}', expected one of {}".format(
                    flow_type, self.SUPPORTED_FLOW_TYPES))

        self.num_reads = config["memory_read_steps"]

        # Arguments shared by every flow builder.
        shared = dict(
            num_inputs=continuous_size,
            num_hidden=latent_size,
            num_cond_inputs=None,
            num_blocks=5,
        )
        # Scale / translation activations used by the split and realnvp flows.
        split_activations = dict(s_activation_str="tanh", t_activation_str="relu")

        # Build the flow
        if flow_type == "maf":
            self.flow = flows.build_maf_flow(activation_str="tanh", **shared)
        elif flow_type == "maf_split":
            self.flow = flows.build_maf_split_flow(**shared, **split_activations)
        elif flow_type == "maf_split_glow":
            self.flow = flows.build_maf_split_glow_flow(**shared, **split_activations)
        elif flow_type == "glow":
            self.flow = flows.build_glow_flow(
                activation_str=config["encoder_activation"],
                normalization_str=config["dense_normalization"],
                layer_modifier=config["encoder_layer_modifier"],
                cuda=config["cuda"],
                **shared
            )
        else:  # "realnvp" -- the only remaining supported type
            self.flow = flows.build_realnvp_flow(
                normalization_str=config["dense_normalization"],
                cuda=config["cuda"],
                **shared,
                **split_activations
            )

        # The flow maps continuous_size inputs to continuous_size outputs.
        self.input_size = continuous_size
        self.output_size = continuous_size

    def prior(self, batch_size, scale_var=1.0, **kwargs):
        """Draws samples by pushing gaussian noise through ``self.flow.sample``.

        :param batch_size: the number of samples to draw
        :param scale_var: used as the ``std`` of the base gaussian noise
        :param kwargs: accepted for interface compatibility; not used here
        :returns: the result of ``self.flow.sample`` for the drawn noise
        :rtype: presumably torch.Tensor -- depends on the flow implementation

        """
        # NOTE(review): torch.Tensor(...) allocates with the default tensor type;
        # any device placement is left to flow.sample -- confirm.
        noise = torch.Tensor(batch_size, self.flow.num_inputs).normal_(std=scale_var)
        return self.flow.sample(num_samples=batch_size, noise=noise)

    def get_reparameterizer_scalars(self):
        """ Returns any scalars used in reparameterization.

        :returns: dict of scalars (always empty for this reparameterizer)
        :rtype: dict

        """
        return {}

    def kl(self, dist_a, prior=None):
        """Negative (standard-normal log-density + log-determinant) of the flow output.

        :param dist_a: dict holding 'recon_x' (flow output) and 'logdet'
        :param prior: unused; kept for interface compatibility
        :returns: the objective averaged over the last remaining dimension
        :rtype: torch.Tensor

        """
        recon, logdet = dist_a['recon_x'], dist_a['logdet']
        # log N(recon; 0, I), summed over the feature dimension.
        log_prob = (-0.5 * recon.pow(2) - 0.5 * np.log(2 * np.pi)).sum(-1)
        return torch.mean(-log_prob - logdet, -1)

    def forward(self, logits, force=False):
        """Runs the flow over every read key.

        :param logits: 3d tensor unpacked as [batch, num_reads, input_size]
        :param force: unused; kept for interface compatibility
        :returns: the flow output and a dict with 'recon_x' and 'logdet'
        :rtype: torch.Tensor, dict

        """
        _, num_reads, input_size = logits.shape  # also enforces a 3d input

        # The flow consumes 2d inputs, so fold the reads into the batch dim.
        flow, logdet = self.flow(logits.contiguous().view(-1, input_size))
        flow = flow.view([-1, num_reads, input_size])
        logdet = logdet.view([-1, num_reads])
        return flow, {'recon_x': flow, 'logdet': logdet}


class KanervaPlusPlus(AbstractVAE):
    def __init__(self, input_shape, **kwargs):
        """Implements Kanerva++: sets up memory bookkeeping and all sub-networks.

        :param input_shape: the input shape
        :param kwargs: forwarded unchanged to the AbstractVAE constructor
        :returns: an object of AbstractVAE
        :rtype: AbstractVAE

        """
        super().__init__(input_shape, **kwargs)
        cfg = self.config

        # Network used to read crops out of the memory.
        self.spatial_transformer = SpatialTransformer(cfg)

        # Memory geometry: [batch, channels, size, size].
        self.memory_chans = cfg['memory_channels']
        self.memory_shape = [
            cfg["batch_size"],
            self.memory_chans,
            cfg["memory_size"],
            cfg["memory_size"],
        ]

        # The memory itself; it is overwritten by write_loop.
        self.memory = None

        # Shape of a single read window: [channels, window, window].
        self.trace_shape = [
            self.memory_chans,
            cfg["window_size"],
            cfg["window_size"],
        ]

        # The read-key reparameterizer and projection model
        self.reparameterizer, self.key_encoder = self.build_key_model()

        # The memory writer is a decoder configured through the writer_* options.
        episode_len = cfg['episode_length']
        writer_output_shape = (self.memory_chans, *self.memory_shape[-2:])
        writer_config = deepcopy(cfg)
        writer_overrides = (
            ('decoder_layer_type', 'writer_layer_type'),
            ('conv_normalization', 'writer_conv_normalization'),
            ('dense_normalization', 'writer_dense_normalization'),
            ('decoder_layer_modifier', 'writer_layer_modifier'),
        )
        for target_key, source_key in writer_overrides:
            writer_config[target_key] = cfg[source_key]
        writer_config['decoder_base_channels'] = 256           # XXX: parameterize

        writer_net = layers.get_decoder(output_shape=writer_output_shape, **writer_config)(
            input_size=cfg['latent_size']
        )
        self.mem_writer = nn.Sequential(
            writer_net,
            layers.View([-1, episode_len, *writer_output_shape]),  # un-fold the episode dim
        )

        # build the encoder and decoder
        self.encoder = self.build_encoder()
        self.enc_of_dec, self.decoder = self.build_decoder()
        self.merger = self.build_merger_model()

        # Report parameter counts in millions.
        reported_modules = (self.encoder, self.decoder, self.mem_writer, self.key_encoder)
        millions = [utils.number_of_parameters(module) / 1e6 for module in reported_modules]
        print(
            "Params\n\tencoder: {}\n\tdecoder: {}\n\tmem_writer: {}\n\tkey_encoder: {}".format(
                *millions
            )
        )


    def build_key_model(self):
        """Builds the read-key reparameterizer and the network projecting onto it.

        :returns: tuple(key_reparam, key_encoder)
        :rtype: tuple(Reparameterizer, nn.Module)

        """
        key_config = deepcopy(self.config)
        key_config['encoder_layer_type'] = 'dense'
        key_config['input_shape'] = [self.config['latent_size']]
        key_config['continuous_size'] = 6  # fixed size for spatial transformer, [s, x, y] ~ N
        key_config['discrete_size'] = 3    # fixed size for spatial transformer, [s, x, y] ~ D

        flow_type = key_config['flow_type']
        if flow_type is not None:
            # full flow version
            print("using flow model of type {}".format(flow_type))
            key_reparameterizer = KeyReparameterizer(
                continuous_size=3,
                latent_size=512,
                config=key_config,
                flow_type=flow_type,
            )
        else:
            # standard reparam
            reparam_type = self.config['reparam_type']
            print("using standard {} reparameterizer".format(reparam_type))
            key_reparameterizer = get_reparameterizer(reparam_type)(config=key_config)

        # One key of width `key_width` is produced for every read step.
        num_reads = key_config['memory_read_steps']
        key_width = key_reparameterizer.input_size
        flatten_inputs = layers.View([-1, *key_config['input_shape']])
        projection = layers.get_encoder(**key_config, name='key')(
            output_size=key_width * num_reads)
        split_reads = layers.View([-1, num_reads, key_width])
        key_encoder = nn.Sequential(flatten_inputs, projection, split_reads)

        return key_reparameterizer, key_encoder

    def build_encoder(self):
        """Builds the episode encoder: fold, encode each frame, un-fold.

        :returns: an encoder
        :rtype: nn.Module

        """
        layer_type = self.config['encoder_layer_type']
        is_3d_model = any(tag in layer_type for tag in ('tsm', 's3d'))
        import torchvision

        latent_size = self.config['latent_size']

        # Non-3d models get the leading dims folded together; 3d models see the input as-is.
        if is_3d_model:
            fold = layers.Identity()
        else:
            fold = layers.View([-1, *self.input_shape])

        backbone = layers.get_encoder(
            norm_first_layer=True, norm_last_layer=False,
            layer_fn=torchvision.models.resnet18,
            pretrained=False, num_segments=1, shift_div=4,
            temporal_pool=False, **self.config)(output_size=latent_size)

        # un-fold to [-1, episode_length, latent_size]
        unfold = layers.View([-1, self.config['episode_length'], latent_size])
        encoder = nn.Sequential(fold, backbone, unfold)

        # Encoders wrapping an inner `.model` can be converted to sine modules.
        wrapped_encoder_types = (layers.TSMResnetEncoder,
                                 layers.S3DEncoder,
                                 layers.TorchvisionEncoder)
        is_torchvision_encoder = isinstance(encoder[1], wrapped_encoder_types)
        if is_torchvision_encoder and self.config['encoder_layer_modifier'] == 'sine':
            layers.convert_to_sine_module(encoder[1].model)

        return encoder

    def build_merger_model(self):
        """Builds the dense 'episode_merger' network (latent_size --> latent_size).

        :returns: the merger model
        :rtype: nn.Module

        """
        latent_size = self.config['latent_size']

        merger_config = deepcopy(self.config)
        merger_config['encoder_layer_type'] = 'dense'
        merger_config['input_shape'] = [latent_size]

        flatten = layers.View([-1, *merger_config['input_shape']])
        project = layers.get_encoder(**merger_config, name='episode_merger')(
            output_size=latent_size)
        return nn.Sequential(flatten, project)

    def build_decoder(self):
        """Builds the trace encoder ("encoder of the decoder") and the image decoder.

        As a side effect this also creates self.enc_of_dec_reparameterizer.

        :returns: tuple(enc_of_dec_model, decoder)
        :rtype: tuple(nn.Module, nn.Module)

        """
        latent_size = self.config['latent_size']
        episode_len = self.config['episode_length']

        dec_conf = deepcopy(self.config)
        if dec_conf['nll_type'] == 'pixel_wise':
            dec_conf['input_shape'][0] *= 256

        # The trace encoder input is [B, reads*trace_chans, trace_x, trace_y]
        input_chans = self.config['memory_read_steps'] * self.memory_chans
        trace_conf = deepcopy(self.config)
        trace_conf['input_shape'] = [input_chans, *self.trace_shape[1:]]
        trace_conf['encoder_layer_type'] = 'conv'

        flatten_traces = layers.View([-1, *trace_conf['input_shape']])
        trace_encoder = layers.get_encoder(pretrained=False, **trace_conf)(
            output_size=latent_size * 2
        )
        enc_of_dec_model = nn.Sequential(flatten_traces, trace_encoder)

        # Reparameterizer applied to the trace encoder output.
        reparam_conf = deepcopy(self.config)
        reparam_conf['continuous_size'] = latent_size * 2
        reparam_conf['discrete_size'] = latent_size
        reparam_builder = get_reparameterizer(self.config["reparam_type"])
        self.enc_of_dec_reparameterizer = reparam_builder(config=reparam_conf)

        decoder = nn.Sequential(
            layers.get_decoder(output_shape=dec_conf['input_shape'], **dec_conf)(
                input_size=latent_size
            )
        )

        # append the variance as necessary
        decoder = self._append_variance_projection(decoder)

        # Channel count doubles when the likelihood carries a variance.
        output_shape = dec_conf['input_shape']
        variance_factor = 2 if dist.nll_has_variance(self.config['nll_type']) else 1
        output_shape[0] *= variance_factor
        decoder = nn.Sequential(decoder, layers.View([-1, episode_len, *output_shape]))

        return enc_of_dec_model, decoder

    def encode(self, x):
        """ Encodes a tensor x to a set of logits.

        :param x: the input tensor
        :returns: logits
        :rtype: torch.Tensor

        """
        encoded = self.encoder(x)  # [B, T, C, H, W] -> [B, T, F], no temporal pooling
        if encoded.dim() < 2:
            return encoded.unsqueeze(-1)

        return encoded

    def cleanup_read_loop(self, num_loops, inputs):
        """ Run the reconstruct loop (no writing), using the memory to cleanup the inputs.

        :param num_loops: the number of repeat forward pass loops.
        :param inputs: the inputs
        :returns: tensor with base inputs and each subsequent num_loop of cleanups
        :rtype: torch.Tensor

        """
        with torch.no_grad():
            cleanups = [inputs]  # original input is the one passed to this fn
            for _ in range(num_loops):
                inputs_i = [m.squeeze(1) for m in cleanups[-1].split(1, dim=1)]
                inputs_i = torch.cat([i.unsqueeze(1) for i in inputs_i], 1)
                encoded = self.encode(inputs_i)                            # result is [B, F]
                decoded, _, _ = self.read_loop(encoded, reparam_key=True)  # read stage of the process
                cleanups.append(self.nll_activation(decoded))

            return [r.view(-1, *r.shape[-3:]) for r in cleanups]

    def generate_synthetic_samples(self, batch_size, **kwargs):
        """ Generates synthetic samples by decoding sampled read keys.

        Two sets of keys are decoded through read_loop and then refined with
        cleanup_read_loop: 'mean_perturbed' keys (one prior draw per read step,
        repeated over the episode with additive gaussian noise) and 'mosiac' keys
        (drawn per episode step from the prior, or taken from the aggregate posterior).

        NOTE: nothing here calls write_loop, so the reads use whatever is
        currently stored in self.memory.

        :param batch_size: the number of samples to generate.
        :param kwargs: optional 'generative_scale_var' and 'use_aggregate_posterior';
                       the full kwargs are also forwarded to reparameterizer.prior.
        :returns: dict of image name --> activated generations, moved to cpu
        :rtype: dict

        """
        num_reads = self.config['memory_read_steps']
        episode_len = self.config['episode_length']
        generative_scale_var = kwargs.get(
            'generative_scale_var', self.config['generative_scale_var'])

        def generate_single_batch(batch_size):
            """Generates one batch; called repeatedly until enough samples exist."""
            with torch.no_grad():
                if 'use_aggregate_posterior' in kwargs and kwargs['use_aggregate_posterior']:
                    z_logits = self.reparameterize_aggregate_posterior()       # [NR*T, B, 6] per original note -- TODO confirm
                    z_mosiac_samples, _ = self.reparameterizer(z_logits, force=True)  # [NR*T, B, 3] per original note -- TODO confirm
                else:
                    # Temporarily set the reparameterizer to eval mode [only useful for flow models]
                    training_tmp = self.reparameterizer.training
                    self.reparameterizer.train(False)
                    z_mosiac_samples = self.reparameterizer.prior(
                        batch_size * num_reads * episode_len,
                        scale_var=generative_scale_var,
                        **kwargs
                    ).view([batch_size, episode_len, num_reads, self.reparameterizer.output_size])
                    self.reparameterizer.train(training_tmp)  # restore the previous mode

                def _generate(z_samples):
                    """Decodes keys directly (no key reparameterization) and activates the result."""
                    with torch.no_grad():
                        unactivated, _, _ = self.read_loop(z_samples, reparam_key=False)
                        return self.nll_activation(unactivated)

                # create a single z_mean vector for the keys (one per batch entry and read step)
                z_mean_single = self.reparameterizer.prior(
                    batch_size * num_reads,
                    scale_var=generative_scale_var,
                    **kwargs
                ).view([batch_size, num_reads, self.reparameterizer.output_size])

                # duplicate the same random vector episode_len times, then add gaussian noise scaled by 0.2
                z_mean_samples = torch.cat([z_mean_single.clone().unsqueeze(1) for _ in range(episode_len)], 1)
                z_mean_samples = torch.cat([(z_i + torch.randn_like(z_i) * 0.2).unsqueeze(0)  # noise is drawn per batch entry
                                            for z_i in z_mean_samples], 0)

                # Converge to a fixed point using the read-loop again and again
                num_loops = self.config['num_fixed_point_generation_iterations']
                mean = self.cleanup_read_loop(num_loops, _generate(z_mean_samples))
                mosiac = self.cleanup_read_loop(num_loops, _generate(z_mosiac_samples))

                # collect the mean and mosaic generations; idx 0 is the initial decode,
                # idx k > 0 is the result after k cleanup iterations
                generations = {}
                for idx, (mean_i, mosiac_i) in enumerate(zip(mean, mosiac)):
                    generations['mean_perturbed_generated{}_imgs'.format(idx)] \
                        = mean_i.view(-1, *mean_i.shape[-3:])
                    generations['mosiac_generated{}_imgs'.format(idx)] \
                        = mosiac_i.view(-1, *mosiac_i.shape[-3:])

                return generations

        # To prevent OOM-ing we will generate a single batch (of config batch_size) many times
        full_generations, num_generated = {}, 0
        def detach_to_cpu(t): return t.detach().cpu()  # move the tensor to cpu memory
        while num_generated < batch_size:
            gen = tree.map_structure(detach_to_cpu, generate_single_batch(self.config['batch_size']))
            for k, v in gen.items():
                full_generations[k] = v if k not in full_generations \
                    else torch.cat([full_generations[k], v], 0)

            # increment generation count by using the first key in the generation map;
            # entries are flattened to 4d above, so this counts rows of that tensor
            num_generated += gen[list(gen.keys())[0]].shape[0]

        # Return only the requested amount of samples (the last batch_size rows of each entry).
        def reduce_to_requested(t): return t[-batch_size:]
        return tree.map_structure(reduce_to_requested, full_generations)

    def _expand_memory(self, key):
        """Internal helper to expand the memory view for the ST

        :param key: the indexing key: [B*T*NR, 3]
        :returns: an expanded view of the memory
        :rtype: torch.Tensor

        """
        if self.memory.shape[0] >= key.shape[0]:
            # Base case where we don't need to expand.
            # Useful for generations of less than batch_size.
            return self.memory[0:key.shape[0]]

        expand_count = int(np.ceil(key.shape[0] / float(self.memory.shape[0])))  # expand by NR
        expanded_memory_shape = [-1] + self.memory_shape[1:]
        expanded_memory = self.memory.unsqueeze(1).expand(  # [B, 1, C, H, W] -> [B, NR, C, H, W]
            -1, expand_count, -1, -1, -1).contiguous().view(expanded_memory_shape)  # [-1, C, H, W]
        return expanded_memory[0:key.shape[0]]                                      # [NR*B, C, H, W]

    def _read_memory_traces(self, keys):
        """Read the existing locations from memory.

        :param keys: the reparameterized keys [B, NR, 3]
        :returns: memory traces [B*NR, trace_c, trace_x, trace_y], sampling grids
        :rtype: torch.Tensor, torch.Tensor

        """
        flattened_keys = keys.view(-1, keys.shape[-1])            # flatten [B*NR, 3]
        expanded_memory = self._expand_memory(flattened_keys)     # [NR*B, C, H, W]
        memory_trace, grid = self.spatial_transformer(
            flattened_keys.contiguous(), expanded_memory.contiguous())   # [NR*B, C, H', W']
        return memory_trace, grid

    def write_loop(self, logits):
        """Simple memory writer that pools output channels to build the memory."""
        batch_size, episode_length, feature_size = logits.shape
        memory_logits = self.mem_writer(logits.view(-1, feature_size))

        # Reduce appropriately
        self.memory = torch.mean(memory_logits, 1)  # pool over episode

    def read(self, encoded, reparam_key=True):
        """Process the reading loop and returns a decoded image.

        :param encoded: base logits
        :param reparam_key: reparams the key logits (useful as False during generation)
        :returns: decoded images and memory traces
        :rtype: torch.Tensor, torch.Tensor

        """
        if reparam_key:  # reparameterize the read_keys (sample).
            read_key_logits = self.key_encoder(encoded)
            read_keys, read_key_params = self.reparameterizer(
                read_key_logits, force=False
            )  # reparam to [NR*T, B, 3]
        else:            # used in generation branch
            read_keys = encoded
            read_key_params = None

        # read the existing locations and decode
        memory_trace, _ = self._read_memory_traces(read_keys.contiguous())                  # [T*NR*B, C', H', W']
        memory_trace = memory_trace.view((encoded.shape[0], -1, *memory_trace.shape[-2:]))  # [B, T*C', H', W']

        enc_of_dec = self.enc_of_dec(memory_trace)
        enc_of_dec, enc_of_dec_params = self.enc_of_dec_reparameterizer(enc_of_dec.squeeze())
        merged = enc_of_dec
        if reparam_key:
            merged = self.merger(
                torch.cat([encoded.view(-1, self.config['latent_size'])], -1))
            read_key_params['enc_of_dec'] = enc_of_dec
            read_key_params['merged'] = merged
            read_key_params['enc_of_dec_params'] = enc_of_dec_params
            read_key_params['enc_of_dec_prior'] = {
                'gaussian': {'mu': merged, 'sigma': torch.ones_like(merged) * 0.1}
            }

        decoded = self.decoder(merged)                                                    # [T*NR*B, C, H, W]
        return [
            decoded,
            read_key_params,
            memory_trace.view([-1, *memory_trace.shape[-3:]]),
        ]

    def read_loop(self, encoded, reparam_key=True):
        """Process the reading loop and returns a decoded image.

        :param encoded: base logits
        :param reparam_key: reparams the key logits (useful as False during generation)
        :returns: decoded images and memory traces
        :rtype: torch.Tensor, torch.Tensor

        """
        decodes, read_key_params = [], []
        decodes, read_key_params, mem_traces = self.read(encoded, reparam_key)
        mem_traces_for_visualization = mem_traces[0, ::]
        mem_traces_for_visualization = mem_traces_for_visualization.view(
            [-1, 1, *mem_traces_for_visualization.shape[-2:]]
        )
        return [decodes,
                read_key_params,
                mem_traces_for_visualization]

    def forward(self, inputs, labels=None):
        """Accepts input, gets posterior and latent and decodes.

        :param inputs: input tensors as a list.
        :param labels: (optional) labels
        :returns: decoded logits and reparam dict
        :rtype: torch.Tensor, dict

        """
        assert isinstance(
            inputs, (list, tuple)
        ), "need a set of inputs as list[Tensor], got {}".format(type(inputs))

        inputs = torch.cat([i.unsqueeze(1) for i in inputs], 1)     # expand to [B, T, C, W, H]
        encoded = self.encode(inputs)                      # result is [B, F]

        self.write_loop(encoded)                                              # write stage of the process
        decoded, read_params, read_imgs = self.read_loop(encoded)             # read stage of the process
        full_params = {
            'read_imgs': read_imgs,
            'key_reparam': read_params,
        }
        return decoded, full_params

    def kld(self, dist_a):
        """ KL-Divergence of the distribution dict and the prior of that distribution.

        :param dist_a: the distribution dict.
        :returns: tensor that is of dimension batch_size
        :rtype: torch.Tensor

        """
        episode_len = self.config['episode_length']
        num_reads = self.config['memory_read_steps']
        kld_z = self.enc_of_dec_reparameterizer.kl(
            dist_a["enc_of_dec_params"], prior=dist_a["enc_of_dec_prior"]
        )
        kld_z = torch.mean(kld_z.view(-1, episode_len), 1)

        if self.config['flow_type']is not None:
            kld = self.reparameterizer.kl(dist_a)
            return torch.mean(kld.view(-1, episode_len), 1) + kld_z

        def unsqueezer(x):
            return (
                x.view([-1, episode_len, num_reads, x.shape[-1]]) if x.dim() > 1 else x
            )

        dist_a = tree.map_structure(unsqueezer, dist_a)
        kld = self.reparameterizer.kl(dist_a)
        return torch.mean(kld, (1, 2)) + kld_z

    def preprocess_minibatch_and_labels(self, minibatch, labels):
        """Simple helper to push the minibatch to the correct device and shape."""
        minibatch = minibatch.cuda(non_blocking=True) if self.config['cuda'] else minibatch
        labels = labels.cuda(non_blocking=True) if self.config['cuda'] else labels
        return minibatch, labels

    def loss_function(self, recon_x, x, params, K=1):
        """ Loss function for Kanerva ++

        Delegates to the parent loss using the key reparam params and then adds
        tracking metrics (memory difference norm, bits-per-dim).

        :param recon_x: the reconstruction logits
        :param x: the original samples, one tensor per episode step
        :param params: the reparameterized parameter dict containing 'key_reparam'
        :param K: number of monte-carlo samples to use.
        :returns: loss dict
        :rtype: dict

        """
        # Stack the episode steps along a new dim 1.
        x = torch.cat([xi.unsqueeze(1) for xi in x], 1).contiguous()
        loss_dict = super().loss_function(
            recon_x=recon_x, x=x, params=params['key_reparam'], K=K)

        # Add extra tracking metrics.
        # The difference norm compares the first two memories; with fewer than
        # two there is nothing to compare, so report zero instead of raising.
        if self.memory.shape[0] > 1:
            memory_difference = torch.norm(self.memory[0] - self.memory[1])
        else:
            memory_difference = self.memory.new_zeros(())
        loss_dict['memory_difference_norm_mean'] = memory_difference

        # nats per input --> bits per input dimension
        loss_dict['bits_per_dim_recon_mean'] = (loss_dict['nll_mean'] / np.prod(self.input_shape)) / np.log(2)
        loss_dict['bits_per_dim_elbo_mean'] = (loss_dict['elbo_mean'] / np.prod(self.input_shape)) / np.log(2)

        # Return full loss dict
        return loss_dict

    def get_images_from_reparam(self, reparam_dict):
        """ returns a dictionary of images from the reparam map

        :param reparam_maps_list: a list of reparam dicts
        :returns: a dictionary of images
        :rtype: dict

        """
        read_shape = reparam_dict['read_imgs'].shape
        return {
            'memory_imgs': self.memory.view(-1, 1, *self.memory.shape[-2:]),  # Flatten mem channels
            'read_traces_minibatch0_imgs': reparam_dict['read_imgs'].view(-1, *read_shape[-3:]),
        }

    def get_activated_reconstructions(self, reconstr_container):
        """ Returns activated reconstruction

        :param reconstr: unactivated reconstr logits list
        :returns: activated reconstr
        :rtype: dict

        """
        activated_recon = self.nll_activation(reconstr_container)
        shape = activated_recon.shape
        return {
            'reconstruction_imgs': activated_recon.view(-1, *shape[-3:])
        }
