from trajectory.models.transformers import *
from trajectory.models.continous_transformers import ContinuousGPT
import torch
import torch.nn as nn
from trajectory.models.autoencoders import Encoder, Decoder
from trajectory.search.sampling import sample_n_continuous

class MLPActionModel(nn.Module):
    """MLP encoder/decoder pair over a flattened action sequence and a state.

    The encoder consumes the flattened actions concatenated with the
    conditioning state; the decoder consumes the state concatenated with a
    latent and emits ``action_dim * horizon`` values, where
    ``horizon = block_size // transition_dim - 1``.
    """

    def __init__(self, config):
        """Build the encoder and decoder from ``config``.

        Reads ``observation_dim``, ``action_dim``, ``block_size``,
        ``transition_dim`` and ``trajectory_embd`` from ``config``.
        """
        super().__init__()
        self.condition_size = config.observation_dim

        # Number of action steps covered by one encoded sequence.
        horizon = config.block_size // config.transition_dim - 1
        action_span = config.action_dim * horizon
        self.trajectory_input_length = action_span + self.condition_size

        hidden_width = 256
        # NOTE(review): the third positional argument looks like a
        # conditioning size (none for the encoder, the state size for the
        # decoder) - confirm against trajectory.models.autoencoders.
        self.encoder = Encoder(
            [self.trajectory_input_length, hidden_width],
            config.trajectory_embd,
            0,
        )
        self.decoder = Decoder(
            [hidden_width, action_span],
            config.trajectory_embd,
            self.condition_size,
        )

    def encode(self, actions, state):
        """
            actions: [B x T x action_dimension]
            state: [B x observation_dimension]

            Returns the encoder output unchanged (elsewhere in this file it
            is unpacked as a ``(means, log_var)`` pair).
        """
        batch_size, _, _ = actions.shape
        flat_actions = actions.reshape(batch_size, -1)
        encoder_input = torch.cat((flat_actions, state), dim=-1)
        return self.encoder(encoder_input)

    def decode(self, latents, state):
        """
            latents: [B x latent_size]
            state: [B x observation_dimension]

            Returns the decoder output for ``[state, latents]`` joined along
            the last dimension.
        """
        batch_size, _ = latents.shape
        # The state is flattened per sample before it is joined to the latent.
        flat_state = state.reshape(batch_size, -1)
        decoder_input = torch.cat((flat_state, latents), dim=-1)
        return self.decoder(decoder_input)


@torch.no_grad()
def model_rollout_continuous(init_state, actions, forward_model):
    """Roll a fixed action sequence through ``forward_model`` autoregressively.

    At every step a transition ``[state, action, reward, value]`` is appended
    to the running trajectory, the forward model is evaluated on the whole
    trajectory so far, and its last-step predictions fill in the reward and
    value slots and provide the state for the next step.

        init_state: [B x observation_dimension]
        actions: [B x T x action_dimension]
        forward_model: callable returning a 6-tuple whose items 0, 2 and 3
            are the state, reward and value predictions (indexed ``[:, -1]``
            here; the other items are ignored).

    Returns the trajectory ``[B x T x (observation + action + 2)]``, or
    ``None`` when ``actions`` has no time steps.
    """
    state = init_state[:, None, :]
    batch_size, nb_steps, _ = actions.shape

    # Reward/value slots, filled in after each model call. They take the
    # state's dtype so a non-float32 rollout is not silently upcast by cat.
    blank_reward_value = torch.zeros(
        [batch_size, 1, 2], dtype=state.dtype, device=state.device
    )
    x = None

    for step in range(nb_steps):
        new_transition = torch.cat(
            [state, actions[:, step, None, :], blank_reward_value], dim=-1
        )
        if x is not None:
            x = torch.cat([x, new_transition], dim=1)
        else:
            x = new_transition

        state_pred, _, reward_pred, value_pred, _, _ = forward_model(x)
        # Write the predictions for the newest transition into its last two
        # feature slots (reward at -2, value at -1).
        x[:, -1, -2, None] = reward_pred[:, -1]
        x[:, -1, -1, None] = value_pred[:, -1]
        state = state_pred[:, -1, None]
    return x


class ContinuousActionVAE(nn.Module):
    """Action-sequence VAE combined with an autoregressive forward model.

    ``action_vae`` (an ``MLPActionModel``) encodes and decodes the action
    sequence conditioned on the first state, and ``forward_model`` (a
    ``ContinuousGPT``) is run on the joined trajectory to provide the
    autoregressive loss and to roll decoded actions forward.
    """

    def __init__(self, config):
        """Build both sub-models and copy the dimensions/weights from ``config``."""
        super().__init__()

        self.action_vae = MLPActionModel(config)
        self.forward_model = ContinuousGPT(config)
        self.trajectory_embd = config.trajectory_embd
        self.block_size = config.block_size
        self.observation_dim = config.observation_dim

        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight

        # Re-initialise every submodule, including those of both sub-models.
        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the configured block size."""
        return self.block_size

    def _init_weights(self, module):
        """Initialise ``module`` in place; intended for ``nn.Module.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02), Linear biases
        are zeroed, and LayerNorm is reset to weight 1 / bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """Create an AdamW optimizer with two parameter groups.

        Parameters are split into those that receive weight decay (weights of
        Linear / EinLinear modules) and those that do not (all biases, and
        LayerNorm / Embedding weights). ``train_config`` must provide
        ``weight_decay``, ``learning_rate`` and ``betas``.

        Raises ``AssertionError`` if a parameter lands in both groups or in
        neither.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, EinLinear)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # Added by name because the bias/weight rules above never match it.
        # NOTE(review): presumably a bare nn.Parameter on the forward model -
        # confirm against ContinuousGPT.
        no_decay.add('forward_model.pos_emb')

        # validate that we considered every parameter
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # Sorted names keep the parameter order inside each group deterministic.
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def decode(self, latent, state):
        """Decode ``latent`` into actions and roll them out from ``state``.

        The decoder output is reshaped to ``[B x -1 x action_dim]`` and passed
        to ``model_rollout_continuous`` with the forward model; that
        function's result is returned.
        """
        actions = self.action_vae.decode(latent, state).reshape([latent.shape[0], -1, self.action_dim])
        return model_rollout_continuous(state, actions, self.forward_model)

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return mu + eps * std

    def forward(self, joined_inputs, targets=None, mask=None, returnx=False):
        """Encode/decode the action sequence and evaluate the forward model.

            joined_inputs : [ B x T x joined_dimension]
            targets : passed to the forward model; when not None the VAE
                losses are computed as well
            mask : only passed to the forward model (the reconstruction loss
                is an unmasked mean)
            returnx : accepted for interface compatibility; unused

        Returns ``(reconstructed, reconstruction_loss, kl,
        autoregressive_loss)`` where ``reconstructed`` is
        ``[B x T x action_dim]`` and both VAE losses are ``None`` when
        ``targets`` is ``None``.
        """
        b, t, _ = joined_inputs.size()

        # Condition on the first state; actions follow the observation slice.
        state = joined_inputs[:, 0, :self.observation_dim]
        actions = joined_inputs[:, :, self.observation_dim:self.observation_dim + self.action_dim]

        means, log_var = self.action_vae.encode(actions, state)
        z = self.reparameterize(means, log_var)

        reconstructed = self.action_vae.decode(z, state)
        reconstructed = torch.reshape(reconstructed, shape=[b, t, self.action_dim])

        # if we are given some desired targets also calculate the loss
        if targets is not None:
            # KL term averaged over both the batch and the latent dimensions.
            kl = -0.5 * torch.mean(1 + log_var - means.pow(2) - log_var.exp())
            # Reached through the explicitly imported `nn` rather than a
            # name that is only in scope via a wildcard import.
            mse = nn.functional.mse_loss(reconstructed, actions, reduction='none')
            reconstruction_loss = mse.mean()
        else:
            reconstruction_loss = None
            kl = None

        _, _, _, _, _, autoregressive_loss = self.forward_model(joined_inputs, targets, mask)
        return reconstructed, reconstruction_loss, kl, autoregressive_loss