from trajectory.models.transformers import *
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """MLP encoder producing the mean and log-variance of a diagonal Gaussian.

    Args:
        layer_sizes: widths of the MLP, starting with the width of the
            un-conditioned input. ``condition_size`` is added to the first entry
            internally, so the network expects inputs of width
            ``layer_sizes[0] + condition_size``. Any sequence of ints is accepted
            and the caller's object is never modified.
        latent_size: dimensionality of the latent Gaussian.
        condition_size: width of the conditioning features that the caller has
            already concatenated onto the input.
    """

    def __init__(self, layer_sizes, latent_size, condition_size):

        super().__init__()
        # Work on a private copy: adjusting the first width in place would leak
        # into the caller's list (and grow on every reuse) and fail for tuples.
        layer_sizes = list(layer_sizes)
        layer_sizes[0] += condition_size

        self.MLP = nn.Sequential()

        # Every hidden Linear layer is followed by a ReLU.
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_log_var = nn.Linear(layer_sizes[-1], latent_size)

    def forward(self, x):
        """Map ``x`` (last dim ``layer_sizes[0] + condition_size``) to ``(means, log_vars)``.

        Both outputs have last dimension ``latent_size``.
        """
        x = self.MLP(x)

        means = self.linear_means(x)
        log_vars = self.linear_log_var(x)

        return means, log_vars


class Decoder(nn.Module):
    """MLP mapping a concatenated ``[condition ; latent]`` vector to an output vector.

    Args:
        layer_sizes: list of output widths of the successive ``nn.Linear`` layers;
            the last entry is the decoder's output width.
        latent_size: width of the latent part of the input.
        condition_size: width of the conditioning part of the input.

    A ReLU follows every Linear layer except the final one, so the output is
    unbounded.
    """

    def __init__(self, layer_sizes, latent_size, condition_size):
        super().__init__()
        self.MLP = nn.Sequential()

        # Fan-in of layer k is the fan-out of layer k-1; the first layer reads
        # the concatenated latent + condition vector.
        widths_in = [latent_size + condition_size] + layer_sizes[:-1]
        last = len(layer_sizes) - 1

        for idx, width_out in enumerate(layer_sizes):
            self.MLP.add_module(
                name="L{:d}".format(idx),
                module=nn.Linear(widths_in[idx], width_out),
            )
            if idx != last:
                self.MLP.add_module(name="A{:d}".format(idx), module=nn.ReLU())

    def forward(self, z):
        """Run the MLP on ``z`` (last dim ``latent_size + condition_size``)."""
        return self.MLP(z)


class DiscreteVAE(nn.Module):
    """Conditional VAE over discretized (tokenized) trajectories.

    The first ``observation_dim`` tokens of a sequence are the conditioning
    state. The remaining tokens are embedded, flattened together with the state
    embeddings and encoded by an MLP into a Gaussian latent of size
    ``10 * n_embd``. The MLP decoder maps ``[state ; z]`` to logits over
    ``vocab_size + 1`` classes for every token (the extra class is the stop
    token).
    """

    def __init__(self, config):
        super().__init__()

        # input embedding stem: one row per (dimension, bin) pair, +1 for the stop token
        self.tok_emb = nn.Embedding(config.vocab_size * config.transition_dim + 1, config.n_embd)

        # Flattened widths: the state tokens form the condition, the other
        # (block_size - 1 - observation_dim) tokens form the encoded trajectory.
        self.condition_size = config.n_embd*config.observation_dim
        self.trajectory_input_length = config.n_embd*(config.block_size-1) - self.condition_size

        encoder_layer_sizes = [self.trajectory_input_length, 1024]
        # the decoder emits vocab_size + 1 logits for each of the block_size - 1 tokens
        decoder_layer_sizes = [1024, (config.vocab_size+1)*(config.block_size-1)]

        # latent size is fixed at 10 * n_embd
        self.encoder = Encoder(
            encoder_layer_sizes, config.n_embd*10, self.condition_size)
        self.decoder = Decoder(
            decoder_layer_sizes, config.n_embd*10, self.condition_size)

        self.vocab_size = config.vocab_size
        self.stop_token = config.vocab_size * config.transition_dim
        self.block_size = config.block_size
        self.observation_dim = config.observation_dim

        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight

        self.embedding_dim = config.n_embd
        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the configured ``block_size``."""
        return self.block_size

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights ~ N(0, 0.02), zero biases, and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """Build an AdamW optimizer with weight decay only on Linear/EinLinear weights.

        Biases and LayerNorm/Embedding weights go into a group with zero weight
        decay. Every parameter must land in exactly one group (asserted).
        ``train_config`` must provide ``weight_decay``, ``learning_rate`` and ``betas``.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, EinLinear)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def encode(self, X, state):
        """Encode trajectory and state embeddings into Gaussian posterior parameters.

        Args:
            X: [B x (T - observation_dim) x EmbeddingSize] trajectory embeddings.
            state: [B x observation_dim x EmbeddingSize] state embeddings.

        Returns:
            ``(means, log_vars)`` from the MLP encoder.
        """
        B, _, _ = X.shape
        inputs = torch.cat([X, state], dim=1)
        inputs = torch.reshape(inputs, shape=[B, -1])
        latents = self.encoder(inputs)
        return latents

    def decode(self, latents, state):
        """Decode ``[flattened state ; latents]`` into flat per-token logits.

        Args:
            latents: [B x latent_size]
            state: [B x observation_dim x EmbeddingSize]

        Returns:
            [B x (block_size - 1) * (vocab_size + 1)] logits.
        """
        B, _ = latents.shape
        state_flat = torch.reshape(state, shape=[B, -1])
        inputs = torch.cat([state_flat, latents], dim=-1)
        reconstructed = self.decoder(inputs)
        return reconstructed

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""

        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return mu + eps * std

    def offset_tokens(self, idx):
        """Shift token ``idx[:, j]`` by ``(j % transition_dim) * vocab_size``.

        Each dimension gets its own embedding range. Tokens equal to
        ``vocab_size`` are mapped to the shared stop token instead.
        """
        _, t = idx.shape
        n_states = int(np.ceil(t / self.transition_dim))
        offsets = torch.arange(self.transition_dim) * self.vocab_size
        offsets = offsets.repeat(n_states).to(idx.device)
        offset_idx = idx + offsets[:t]
        offset_idx[idx == self.vocab_size] = self.stop_token
        return offset_idx

    def forward(self, idx, targets=None, mask=None, returnx=False):
        """Reconstruct a token sequence through the VAE.

        Args:
            idx: [B x T] token indices in ``[0, vocab_size]`` (``vocab_size``
                marks a stop token). ``T`` must equal ``block_size - 1`` for the
                logits reshape to succeed.
            targets: only compared against ``None``. When given, the loss is the
                cross-entropy of the logits against ``idx`` itself.
            mask: optional per-token loss weight with ``B * T`` elements;
                ``None`` weights every token equally.
            returnx: unused; kept for interface compatibility.

        Returns:
            ``(logits, loss)``. ``logits`` has shape [B x T x (vocab_size + 1)].
            ``loss`` is a scalar, or ``None`` when ``targets`` is ``None``.
        """
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        offset_idx = self.offset_tokens(idx)

        # condition: [B x observation_dim x n_embd]; trajectory: [B x (T - observation_dim) x n_embd]
        state = self.tok_emb(offset_idx[:, :self.observation_dim])
        trajectory = self.tok_emb(offset_idx[:, self.observation_dim:])
        means, log_var = self.encode(trajectory, state)
        z = self.reparameterize(means, log_var)
        logits = self.decode(z, state)
        logits = torch.reshape(logits, shape=[b, t, self.vocab_size+1])

        if targets is None:
            return logits, None

        # NOTE: the reconstruction target is the unshifted ``idx``; ``targets``
        # only signals that a loss is wanted.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), idx.view(-1), reduction='none')
        if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1:
            n_states = int(np.ceil(t / self.transition_dim))
            weights = torch.cat([
                torch.ones(self.observation_dim, device=idx.device),
                torch.ones(self.action_dim, device=idx.device) * self.action_weight,
                torch.ones(1, device=idx.device) * self.reward_weight,
                torch.ones(1, device=idx.device) * self.value_weight,
            ])
            ## [ n_states * transition_dim ] >= t
            weights = weights.repeat(n_states)
            ## [ b x t ]: token j gets the weight of dimension j % transition_dim
            ## (same alignment as offset_tokens); slicing to t works for any t.
            weights = weights[:t].repeat(b, 1)
            loss = loss * weights.view(-1)
        if mask is not None:
            loss = loss * mask.view(-1)
        return logits, loss.mean()


class MLPModel(nn.Module):
    """Fully-connected VAE backbone (selected by ``ContinuousVAE`` for ``config.model == "MLP"``).

    The encoder flattens the whole ``[T x transition_dim]`` trajectory. The
    decoder maps ``[state ; latent]`` back to a ``[T x transition_dim]``
    trajectory, with ``T * transition_dim == block_size - transition_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.condition_size = config.observation_dim
        self.transition_dim = config.transition_dim
        self.trajectory_input_length = config.block_size - config.transition_dim

        encoder_layer_sizes = [self.trajectory_input_length, 512, 256]
        decoder_layer_sizes = [256, 512, self.trajectory_input_length]

        self.encoder = Encoder(
            encoder_layer_sizes, config.trajectory_embd, 0)
        self.decoder = Decoder(
            decoder_layer_sizes, config.trajectory_embd, self.condition_size)

    def encode(self, X):
        """Flatten and encode a trajectory.

        Args:
            X: [B x T x transition_dim] with ``T * transition_dim`` equal to
                ``block_size - transition_dim``.

        Returns:
            Whatever the encoder returns (``(means, log_vars)`` for ``Encoder``).
        """
        B, _, _ = X.shape
        inputs = torch.reshape(X, shape=[B, -1])
        latents = self.encoder(inputs)
        return latents

    def decode(self, latents, state):
        """Decode a latent, conditioned on the initial state, into a trajectory.

        Args:
            latents: [B x latent_size]
            state: [B x observation_dim]

        Returns:
            [B x T x transition_dim] reconstruction. The last channel is passed
            through a sigmoid so it can be used directly as a terminal
            probability, matching ``StepWiseTranformer.decode``;
            ``ContinuousVAE.forward`` slices it as 3-D and feeds the last
            channel to ``binary_cross_entropy``.
        """
        B, _ = latents.shape
        state_flat = torch.reshape(state, shape=[B, -1])
        inputs = torch.cat([state_flat, latents], dim=-1)
        reconstructed = self.decoder(inputs)
        # flat [B x T*transition_dim] -> [B x T x transition_dim]
        reconstructed = reconstructed.reshape(B, -1, self.transition_dim)
        return torch.cat(
            [reconstructed[:, :, :-1], torch.sigmoid(reconstructed[:, :, -1:])], dim=-1)

class SymbolWiseTransformer(nn.Module):
    """Transformer VAE backbone that embeds each transition step as four tokens.

    Every time step contributes separate state, action, reward and value tokens
    (interleaved in that order), so ``T`` steps become ``4T`` tokens. The encoder
    max-pools the transformer output over tokens into Gaussian posterior
    parameters. The decoder broadcasts one mixed ``[state ; latent]`` embedding
    over all ``trajectory_length`` token positions.
    NOTE(review): ``ContinuousVAE.__init__`` in this file instantiates only
    ``StepWiseTranformer`` or ``MLPModel``; this class is referenced there only
    in ``configure_optimizers``.
    """
    def __init__(self, config):
        super().__init__()
        self.latent_size = config.trajectory_embd
        self.condition_size = config.observation_dim
        self.trajectory_input_length = config.block_size - config.transition_dim
        self.embedding_dim = config.n_embd
        # token positions: 4 tokens per step, for (block_size // transition_dim - 1) steps
        self.trajectory_length = 4*(config.block_size//config.transition_dim-1)
        self.block_size = config.block_size
        self.observation_dim = config.observation_dim
        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim

        self.encoder = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.decoder = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])

        # learned position embedding, one row per token position
        self.pos_emb = nn.Parameter(torch.zeros(1, self.trajectory_length, config.n_embd))

        # per-symbol input projections
        self.state_emb = nn.Linear(self.observation_dim, self.embedding_dim)
        self.action_emb = nn.Linear(self.action_dim, self.embedding_dim)
        self.reward_emb = nn.Linear(1, self.embedding_dim)
        self.value_emb = nn.Linear(1, self.embedding_dim)

        # per-symbol output heads
        self.pred_state = nn.Linear(self.embedding_dim, self.observation_dim)
        self.pred_action = nn.Sequential(nn.Linear(self.embedding_dim, self.action_dim))
        self.pred_reward = nn.Linear(self.embedding_dim, 1)
        self.pred_value = nn.Linear(self.embedding_dim, 1)

        self.linear_means = nn.Linear(self.embedding_dim, self.latent_size)
        self.linear_log_var = nn.Linear(self.embedding_dim, self.latent_size)
        # projects the concatenated [state ; latent] vector into the token embedding space
        self.latent_mixing = nn.Linear(self.latent_size+self.observation_dim, self.embedding_dim)

        self.ln_f = nn.LayerNorm(config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)


    def encode(self, joined_inputs):
        """Encode a continuous trajectory into Gaussian posterior parameters.

        Args:
            joined_inputs: [B x T x joined_dimension]. Along the last axis the
                first ``observation_dim`` channels are read as the state and the
                next ``action_dim`` as the action. The second-to-last and last
                channels are read as reward and value.
                NOTE(review): if a terminal flag is appended as the final channel
                (as ``ContinuousVAE.forward`` does for its backbone), these two
                slices would pick value and terminal instead — confirm intended input.

        Returns:
            ``(means, log_vars)``, each [B x latent_size].
        """
        b, t, joined_dimension = joined_inputs.size()
        # NOTE(review): compares the number of steps to block_size; pos_emb only
        # has 4*(block_size//transition_dim-1) positions.
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # split each step into its four symbols
        states = joined_inputs[:, :, :self.observation_dim]
        actions = joined_inputs[:, :, self.observation_dim:self.observation_dim + self.action_dim]
        rewards = joined_inputs[:, :, -2, None]
        values = joined_inputs[:, :, -1, None]

        state_embeddings = self.state_emb(states)
        action_embeddings = self.action_emb(actions)
        reward_embeddings = self.reward_emb(rewards)
        value_embeddings = self.value_emb(values)

        # stack -> [B x 4 x T x E], permute -> [B x T x 4 x E], reshape -> [B x 4T x E]
        # giving the interleaved order s_0, a_0, r_0, v_0, s_1, ...
        token_embeddings = torch.stack([state_embeddings, action_embeddings, reward_embeddings, value_embeddings],
                                       dim=1) \
            .permute([0, 2, 1, 3]).reshape(b, 4 * t, self.embedding_dim)
        ## [ 1 x 4T x embedding_dim ]
        position_embeddings = self.pos_emb[:, :4 * t, :]  # each position maps to a (learnable) vector
        ## [ B x 4T x embedding_dim ]
        x = self.drop(token_embeddings + position_embeddings)
        x = self.encoder(x)
        ## [ B x 4T x embedding_dim ]
        # max-pool over all tokens -> [ B x embedding_dim ]
        trajectory_feature = x.max(dim=1).values
        means = self.linear_means(trajectory_feature)
        log_vars = self.linear_log_var(trajectory_feature)
        return means, log_vars

    def decode(self, latents, state):
        """Decode a latent, conditioned on the initial state, into per-step predictions.

        Args:
            latents: [B x latent_size]
            state: [B x observation_dimension]

        Returns:
            [B x T x (observation_dim + action_dim + 2)] with
            ``T = trajectory_length // 4``. Along the last axis: state, action,
            reward and value predictions, in that order.
        """
        B, _ = latents.shape
        state_flat = torch.reshape(state, shape=[B, -1])
        inputs = torch.cat([state_flat, latents], dim=-1)
        inputs = self.latent_mixing(inputs)
        # broadcast the single mixed vector over every token position: [B x 4T x E]
        inputs = inputs[:, None, :] + self.pos_emb[:, :]
        x = self.decoder(inputs)
        x = self.ln_f(x)

        # [B x 4T x E] -> [B x 4 x T x E]; slot 0 = state-token positions, slot 1 = action-token positions
        x = x.reshape(B, -1, 4, self.embedding_dim).permute([0,2,1,3])

        ## [B x T x obs_dim]
        state_pred = self.pred_state(x[:,1]) # next state
        action_pred = self.pred_action(x[:,0]) # current action
        reward_pred = self.pred_reward(x[:,1]) # current reward
        value_pred = self.pred_value(x[:,1]) # current value
        joined_pred = torch.cat([state_pred, action_pred, reward_pred, value_pred], dim=-1)

        return joined_pred


class StepWiseTranformer(nn.Module):
    """Transformer VAE backbone that embeds one token per transition step.

    Each input row (a full transition vector of width ``transition_dim``) is
    linearly projected to a single ``n_embd`` token. The encoder max-pools the
    transformer outputs over time into Gaussian posterior parameters. The
    decoder broadcasts one mixed ``[state ; latent]`` embedding across all
    ``trajectory_length`` positions and projects each back to ``transition_dim``
    channels. Only the last of those channels is squashed with a sigmoid.
    """

    def __init__(self, config):
        super().__init__()
        n_embd = config.n_embd
        step_dim = config.transition_dim

        # plain bookkeeping attributes (no parameters)
        self.block_size = config.block_size
        self.transition_dim = step_dim
        self.observation_dim = config.observation_dim
        self.action_dim = config.action_dim
        self.condition_size = config.observation_dim
        self.latent_size = config.trajectory_embd
        self.embedding_dim = n_embd
        self.trajectory_input_length = config.block_size - step_dim
        # number of step tokens: one step fewer than fits in block_size
        self.trajectory_length = config.block_size // step_dim - 1

        # submodules, registered in a fixed order (parameter order / init order)
        self.encoder = nn.Sequential(*(Block(config) for _ in range(config.n_layer)))
        self.decoder = nn.Sequential(*(Block(config) for _ in range(config.n_layer)))
        self.pos_emb = nn.Parameter(torch.zeros(1, self.trajectory_length, n_embd))
        self.embed = nn.Linear(step_dim, n_embd)
        self.predict = nn.Linear(n_embd, step_dim)
        self.linear_means = nn.Linear(n_embd, self.latent_size)
        self.linear_log_var = nn.Linear(n_embd, self.latent_size)
        self.latent_mixing = nn.Linear(self.latent_size + self.observation_dim, n_embd)
        self.ln_f = nn.LayerNorm(n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

    def encode(self, joined_inputs):
        """Encode a [B x T x transition_dim] trajectory into ``(means, log_vars)``.

        Each step is one token, so the sequence is T tokens long. Both outputs
        are [B x latent_size].
        """
        _, steps, _ = joined_inputs.size()
        assert steps <= self.block_size, "Cannot forward, model block size is exhausted."

        tokens = self.embed(joined_inputs)                        # [B x T x E]
        hidden = self.drop(tokens + self.pos_emb[:, :steps, :])   # add learned positions
        hidden = self.encoder(hidden)                             # [B x T x E]
        pooled, _ = hidden.max(dim=1)                             # max over time -> [B x E]
        return self.linear_means(pooled), self.linear_log_var(pooled)

    def decode(self, latents, state):
        """Decode a latent, conditioned on the initial state.

        Args:
            latents: [B x latent_size]
            state: [B x observation_dimension]

        Returns:
            [B x trajectory_length x transition_dim]; the last channel is in (0, 1).
        """
        batch, _ = latents.shape
        conditioned = torch.cat([torch.reshape(state, shape=[batch, -1]), latents], dim=-1)
        mixed = self.latent_mixing(conditioned)                   # [B x E]
        # the same mixed vector at every position, distinguished only by pos_emb
        hidden = mixed.unsqueeze(1) + self.pos_emb
        hidden = self.ln_f(self.decoder(hidden))
        raw = self.predict(hidden)
        return torch.cat([raw[:, :, :-1], torch.sigmoid(raw[:, :, -1:])], dim=-1)


class ContinuousVAE(nn.Module):
    """Conditional VAE over continuous trajectories with a terminal flag.

    ``config.model`` selects the backbone: ``"Transformer"`` uses
    ``StepWiseTranformer`` and ``"MLP"`` uses ``MLPModel``. The backbone encodes
    the (padded) trajectory concatenated with its terminal flags. It then
    decodes from a sampled latent, conditioned on the first state of the
    trajectory.
    """

    def __init__(self, config):
        super().__init__()

        if config.model == "Transformer":
            self.model = StepWiseTranformer(config)
        elif config.model == "MLP":
            self.model = MLPModel(config)
        else:
            # fail fast rather than leaving ``self.model`` undefined until forward()
            raise ValueError("Unknown config.model %r; expected 'Transformer' or 'MLP'" % (config.model,))
        self.trajectory_embd = config.trajectory_embd
        self.vocab_size = config.vocab_size
        self.stop_token = config.vocab_size * config.transition_dim
        self.block_size = config.block_size
        self.observation_dim = config.observation_dim

        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight

        # One padding value per joined (non-terminal) channel. ``forward`` blends
        # it with ``joined_inputs``, which is ``transition_dim - 1`` wide because
        # the backbone consumes ``joined_inputs`` plus one terminal channel.
        self.padding_vector = torch.zeros(self.transition_dim - 1)
        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the configured ``block_size``."""
        return self.block_size

    def set_padding_vector(self, padding):
        """Set the vector that replaces terminal steps in ``forward`` (array-like, one value per joined channel)."""
        self.padding_vector = padding

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights ~ N(0, 0.02), zero biases, and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """Build an AdamW optimizer with weight decay only on Linear/EinLinear weights.

        Biases, LayerNorm/Embedding weights and the transformer backbone's
        ``pos_emb`` get zero weight decay. Every parameter must land in exactly
        one group (asserted). ``train_config`` must provide ``weight_decay``,
        ``learning_rate`` and ``betas``.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, EinLinear)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # the transformer backbones own a raw nn.Parameter position embedding, which no module rule above covers
        if isinstance(self.model, (SymbolWiseTransformer, StepWiseTranformer)):
            no_decay.add('model.pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def decode(self, latent, state):
        """Delegate decoding to the backbone."""
        return self.model.decode(latent, state)

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""

        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return mu + eps * std

    def forward(self, joined_inputs, targets=None, mask=None, terminals=None, returnx=False):
        """Encode, sample a latent, decode, and optionally compute the losses.

        Args:
            joined_inputs: [B x T x joined_dimension] continuous trajectory. Along
                the last axis: observation, action, reward and value channels.
            targets: only compared against ``None``. When given, the losses
                compare against ``joined_inputs`` / ``terminals`` themselves.
            mask: optional multiplicative weight broadcastable to
                [B x T x joined_dimension]; ``None`` means no masking.
            terminals: [B x T x 1] terminal flags, presumably 0/1. Steps whose
                flag is 1 are replaced by the padding vector. ``None`` is treated
                as all zeros.
            returnx: unused; kept for interface compatibility.

        Returns:
            ``(reconstructed, reconstruction_loss, kl, 0)``. ``reconstructed`` is
            the backbone's decoder output; its last channel is the terminal
            prediction. The two losses are ``None`` when ``targets`` is ``None``.
        """
        b, t, joined_dimension = joined_inputs.size()
        device = joined_inputs.device
        if terminals is None:
            terminals = torch.zeros(b, t, 1, dtype=joined_inputs.dtype, device=device)

        # Blend in the padding vector wherever a step is flagged terminal.
        padded = torch.as_tensor(self.padding_vector, dtype=torch.float32,
                                 device=device).repeat(b, t, 1)
        terminal_mask = (1 - terminals).repeat(1, 1, joined_dimension)
        joined_inputs = joined_inputs*terminal_mask+(1-terminal_mask)*padded

        # condition the decoder on the first state of the (padded) trajectory
        state = joined_inputs[:, 0, :self.observation_dim]
        means, log_var = self.model.encode(torch.cat([joined_inputs, terminals], dim=2))
        z = self.reparameterize(means, log_var)
        reconstructed = self.model.decode(z, state)
        pred_trajectory = torch.reshape(reconstructed[:, :, :-1], shape=[b, t, joined_dimension])
        pred_terminals = reconstructed[:, :, -1, None]

        if targets is None:
            return reconstructed, None, None, 0

        kl = -0.5 * torch.mean(1 + log_var - means.pow(2) - log_var.exp())
        weights = torch.cat([
            torch.ones(self.observation_dim, device=device),
            torch.ones(self.action_dim, device=device) * self.action_weight,
            torch.ones(1, device=device) * self.reward_weight,
            torch.ones(1, device=device) * self.value_weight,
        ])
        mse = F.mse_loss(pred_trajectory, joined_inputs, reduction='none')*weights[None, None, :]
        if mask is not None:
            # apply the mask once (it was previously multiplied in twice, squaring fractional masks)
            mse = mse*mask
        cross_entropy = F.binary_cross_entropy(pred_terminals, torch.clip(terminals.float(), 0.0, 1.0))
        reconstruction_loss = (mse*terminal_mask).mean()+cross_entropy
        return reconstructed, reconstruction_loss, kl, 0