from trajectory.models.transformers import *
import torch

class ContinuousGPT(nn.Module):
    """GPT-style transformer over continuous transitions.

    Each timestep of the input is split into four tokens (state, action,
    reward, value). Each token type has its own linear embedding, a learned
    positional embedding is added, and the interleaved sequence of ``4 * T``
    tokens is run through a stack of transformer ``Block``s. Separate linear
    heads then decode next-state, action, reward and value predictions.

    ``config.block_size`` is the number of rows in the positional embedding
    table, i.e. the maximum number of *tokens* (``4 * T``) per sequence.
    """

    def __init__(self, config):
        super().__init__()

        # NOTE: module creation order is significant (parameter init consumes
        # the RNG in this order); keep it stable.

        # learned positional embedding, one vector per token position
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # final layer norm applied before the decoder heads
        self.ln_f = nn.LayerNorm(config.n_embd)
        # NOTE: `head` is not used by forward(); it is kept so the module's
        # parameter set (and therefore its state_dict keys) stays unchanged.
        self.head = nn.Linear(config.n_embd, config.transition_dim)
        # Number of action heads (Gaussian mixture components). Falls back to
        # the previously hard-coded value of 1 when the config does not set it.
        self.nb_action_heads = getattr(config, "nb_action_heads", 1)

        self.block_size = config.block_size
        self.observation_dim = config.observation_dim

        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight

        self.embedding_dim = config.n_embd

        # per-token-type input embeddings
        self.state_emb = nn.Linear(self.observation_dim, self.embedding_dim)
        self.action_emb = nn.Linear(self.action_dim, self.embedding_dim)
        self.reward_emb = nn.Linear(1, self.embedding_dim)
        self.value_emb = nn.Linear(1, self.embedding_dim)

        # output heads
        self.pred_state = nn.Linear(self.embedding_dim, self.observation_dim)
        self.pred_action = nn.Sequential(nn.Linear(self.embedding_dim, self.action_dim * self.nb_action_heads))
        if self.nb_action_heads > 1:
            # logits over the action heads (latent mixture component)
            self.action_probs = nn.Sequential(nn.Linear(self.embedding_dim, self.nb_action_heads))
        self.pred_reward = nn.Linear(self.embedding_dim, 1)
        self.pred_value = nn.Linear(self.embedding_dim, 1)

        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the maximum number of tokens (length of ``pos_emb``)."""
        return self.block_size

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02) and reset LayerNorms.

        Biases are zeroed and LayerNorm scales set to 1. Missing parameters
        (``bias=False`` Linear, non-affine LayerNorm) are skipped.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False (or bias=False)
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """Build an AdamW optimizer with weight decay only on matmul weights.

        Parameters are split into two groups: weights of Linear/EinLinear
        modules are decayed with ``train_config.weight_decay``; biases,
        LayerNorm/Embedding weights and the positional embedding are not.
        Every parameter must land in exactly one group (checked by asserts).

        Returns:
            torch.optim.AdamW configured with ``train_config.learning_rate``
            and ``train_config.betas``.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, EinLinear)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # the position embedding parameter lives directly on this module; never decay it
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        missing_params = param_dict.keys() - union_params
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(missing_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(missing_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def pad_to_full_observation(self, x, verify=False):
        """Zero-pad the token axis to a multiple of transition_dim and regroup.

        Args:
            x: tensor [B x T x embedding_dim].
            verify: if True, run the ``verify`` debug check on the result.

        Returns:
            (x_pad, n_pad): x_pad is
            [(B * T' / transition_dim) x transition_dim x embedding_dim] where
            T' = T + n_pad, and n_pad is the number of padding tokens added.
        """
        b, t, _ = x.shape
        n_pad = (self.transition_dim - t % self.transition_dim) % self.transition_dim
        # new_zeros matches x's dtype and device (e.g. under half precision)
        padding = x.new_zeros(b, n_pad, self.embedding_dim)
        ## [ B x T' x embedding_dim ]
        x_pad = torch.cat([x, padding], dim=1)
        ## [ (B * T' / transition_dim) x transition_dim x embedding_dim ]
        x_pad = x_pad.view(-1, self.transition_dim, self.embedding_dim)
        if verify:
            self.verify(x, x_pad)
        return x_pad, n_pad

    def verify(self, x, x_pad):
        """Debug check that ``pad_to_full_observation`` regrouped tokens correctly.

        For each slot i in [0, transition_dim), compares the tokens of ``x``
        at positions congruent to i (mod transition_dim) with column i of
        ``x_pad``, prints their shapes, and drops into pdb on a mismatch.
        """
        b, t, embedding_dim = x.shape
        n_states = int(np.ceil(t / self.transition_dim))
        inds = torch.arange(0, self.transition_dim).repeat(n_states)[:t]
        for i in range(self.transition_dim):
            x_ = x[:, inds == i]
            t_ = x_.shape[1]
            x_pad_ = x_pad[:, i].view(b, n_states, embedding_dim)[:, :t_]
            print(i, x_.shape, x_pad_.shape)
            # explicit check instead of `assert` + bare `except:` so it is not
            # stripped under -O and does not swallow unrelated exceptions
            if not (x_ == x_pad_).all():
                pdb.set_trace()

    def forward(self, joined_inputs, targets=None, mask=None, returnx=False):
        """Run the model on a batch of continuous transitions.

        Args:
            joined_inputs: tensor [B x T x joined_dimension]. Along the last
                axis the layout is ``[observation | action | reward | value]``:
                the first ``observation_dim`` entries are the state, the next
                ``action_dim`` the action, and the last two the reward and
                value.
            targets: optional tensor with the same layout as ``joined_inputs``.
                Only its observation slice is used, as the next-state target;
                action, reward and value targets come from ``joined_inputs``.
            mask: optional tensor broadcastable to [B x T x transition_dim],
                multiplied into the element-wise loss. If None, the loss is
                returned unmasked.
            returnx: unused; kept for interface compatibility.

        Returns:
            (state_pred, ensemble_action_pred, reward_pred, value_pred,
            action_prob, loss).
            The action distribution is modelled as a mixture of Gaussians with
            a fixed variance. When nb_action_heads == 1, ensemble_action_pred
            is [B x T x action_dim] and action_prob is None. Otherwise
            ensemble_action_pred is [B x T x nb_action_heads x action_dim],
            each head being the mean of one Gaussian component, and
            action_prob is the softmax over heads (the latent categorical
            probabilities). ``loss`` is the un-reduced element-wise MSE
            [B x T x transition_dim] (times ``mask`` when given), or None when
            ``targets`` is None.

        Raises:
            AssertionError: if the 4 * T interleaved tokens exceed block_size.
        """
        b, t, _ = joined_inputs.size()
        # every timestep becomes 4 tokens, and pos_emb has only block_size rows
        assert 4 * t <= self.block_size, "Cannot forward, model block size is exhausted."

        # split the joined vector into its components
        states = joined_inputs[:, :, :self.observation_dim]
        actions = joined_inputs[:, :, self.observation_dim:self.observation_dim + self.action_dim]
        rewards = joined_inputs[:, :, -2, None]
        values = joined_inputs[:, :, -1, None]

        state_embeddings = self.state_emb(states)
        action_embeddings = self.action_emb(actions)
        reward_embeddings = self.reward_emb(rewards)
        value_embeddings = self.value_emb(values)

        ## [ B x 4T x embedding_dim ], interleaved per timestep as s, a, r, v
        token_embeddings = torch.stack([state_embeddings, action_embeddings, reward_embeddings, value_embeddings], dim=1)\
            .permute([0, 2, 1, 3]).reshape(b, 4 * t, self.embedding_dim)
        ## [ 1 x 4T x embedding_dim ]
        position_embeddings = self.pos_emb[:, :4 * t, :]  # each position maps to a (learnable) vector
        ## [ B x 4T x embedding_dim ]
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        ## [ B x 4T x embedding_dim ]
        x = self.ln_f(x)

        ## [ B x 4 x T x embedding_dim ]: x[:, k] are the outputs at token type k
        x = x.reshape(b, t, 4, self.embedding_dim).permute([0, 2, 1, 3])

        ## [B x T x obs_dim]
        state_pred = self.pred_state(x[:, 1])  # next state, read at the action token
        ensemble_action_pred = self.pred_action(x[:, 0])  # current action, read at the state token
        reward_pred = self.pred_reward(x[:, 1])  # current reward, read at the action token
        value_pred = self.pred_value(x[:, 1])  # current value, read at the action token

        if self.nb_action_heads > 1:
            logits = self.action_probs(x[:, 0])
            # hard one-hot sample of a head; gradients flow via gumbel-softmax
            onehot = torch.nn.functional.gumbel_softmax(logits, hard=True)
            ensemble_action_pred = ensemble_action_pred.reshape([b, t, self.nb_action_heads, self.action_dim])
            action_pred = torch.sum(ensemble_action_pred * onehot[:, :, :, None], dim=-2)
            action_prob = torch.nn.functional.softmax(logits, dim=-1)
        else:
            action_pred = ensemble_action_pred
            action_prob = None

        # if we are given some desired targets also calculate the loss
        if targets is not None:
            ## [B x T x transition_dim]
            joined_pred = torch.cat([state_pred, action_pred, reward_pred, value_pred], dim=-1)

            # next state comes from `targets`; action/reward/value are the
            # current-step quantities, so they come from the inputs themselves
            target_states = targets[:, :, :self.observation_dim]
            target_actions = joined_inputs[:, :, self.observation_dim:self.observation_dim + self.action_dim]
            target_rewards = joined_inputs[:, :, -2, None]
            target_values = joined_inputs[:, :, -1, None]

            joined_target = torch.cat([target_states, target_actions, target_rewards, target_values], dim=-1)

            loss = torch.nn.functional.mse_loss(joined_pred, joined_target, reduction="none")
            if mask is not None:
                loss = loss * mask
        else:
            loss = None

        return state_pred, ensemble_action_pred, reward_pred, value_pred, action_prob, loss
