import numpy as np
import torch
import ast
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import transformers

from hrl.trajectory_model import TrajectoryModel
from hrl.trajectory_gpt2 import GPT2Model
from hrl.img_encoder import Encoder

from hrl import iq
from hrl.utils import pad


class DecisionTransformer(TrajectoryModel):

    """
    GPT-based trajectory model.

    Exactly one of two token layouts is active, chosen at construction time:

    * options  (``use_options=True``):  (option_1, state_1, action_1, option_2, ...)
    * language (``use_language=True``): (lang tokens..., state_1, action_1, state_2, ...)

    ``self.state_dim``, ``self.act_dim`` and ``self.max_length`` are read throughout
    this class but never assigned here; they are expected to be set by
    ``TrajectoryModel.__init__`` (defined outside this file).
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            option_dim,
            lang_dim,
            discrete,
            hidden_size,
            use_language=False,
            use_options=True,
            option_il=False,
            predict_q=False,
            max_length=None,
            max_ep_len=4096,
            action_tanh=False,
            no_states=False,
            no_actions=False,
            **kwargs):
        """
        Build the embedding layers, the GPT2 backbone and the prediction heads.

        Args:
            state_dim: int for vector states, or a tuple (or its string repr, which is
                parsed with ``ast.literal_eval``) for image states.
            action_dim: size of the action vector.
            option_dim: size of the option vector (used when ``use_options``).
            lang_dim: size of each word embedding (used when ``use_language``).
            discrete: if True, no Tanh is appended to the action head even when
                ``action_tanh`` is set.
            hidden_size: transformer embedding width (``n_embd``).
            use_language: use the language-prefixed layout.
            use_options: use the option-interleaved layout.
            option_il: in the options layout, also compute an MSE loss between the
                predicted and the given options.
            predict_q: add a Q head of width ``act_dim`` read from the state tokens.
            max_length: context length used by ``get_action`` (formerly called K).
            max_ep_len: number of rows of the timestep embedding table.
            action_tanh: append ``nn.Tanh`` to the action head (continuous actions only).
            no_states: in the options layout, feed all-zero states (ablation switch).
            no_actions: in the options layout, feed all-zero actions (ablation switch).
            **kwargs: forwarded to ``transformers.GPT2Config``.

        Raises:
            ValueError: if both or neither of ``use_language`` / ``use_options`` are set.
        """
        # max_length used to be K
        super().__init__(state_dim, action_dim, max_length=max_length)

        self.use_options = use_options
        self.use_language = use_language
        self.option_il = option_il
        # NOTE: this attribute is a plain bool here and is replaced by the Q-head
        # module at the end of __init__ when predict_q is True. `if self.predict_q:`
        # in forward() therefore works for both cases (False vs. a truthy module).
        self.predict_q = predict_q

        if use_language and use_options:
            raise ValueError("Cannot use language and options!")
        if not use_language and not use_options:
            raise ValueError("Have to use language or options!")
        self.option_dim = option_dim
        self.discrete = discrete

        self.hidden_size = hidden_size
        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            **kwargs
        )
        # state_dim may arrive as a string such as "(3, 64, 64)"
        if isinstance(self.state_dim, str):
            self.state_dim = ast.literal_eval(self.state_dim)

        # IMPORTANT: keep the construction order of the submodules below unchanged;
        # it determines parameter initialisation order under a fixed seed.
        if isinstance(self.state_dim, tuple):
            # LORL
            if self.state_dim[0] == 3:
                # LORL Sawyer
                self.embed_state = Encoder(self.state_dim, hidden_size=hidden_size, ch=3, robot=False)
            else:
                # LORL Franka
                self.embed_state = Encoder(self.state_dim, hidden_size=hidden_size, ch=12, robot=True)
        else:
            self.embed_state = nn.Linear(self.state_dim, hidden_size)

        # note: the only difference between this GPT2Model and the default Huggingface version
        # is that the positional embeddings are removed (since we'll add those ourselves)
        self.transformer = GPT2Model(config)

        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)

        self.embed_action = nn.Linear(self.act_dim, hidden_size)

        self.no_states = no_states
        self.no_actions = no_actions

        if use_options:
            self.embed_option = nn.Linear(self.option_dim, hidden_size)

        if use_language:
            self.embed_lang = nn.Linear(lang_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)
        # note: we don't predict states or returns for the paper
        # The state head only exists for vector (int) states; image states get no head.
        if isinstance(self.state_dim, int):
            self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
        self.predict_action = nn.Sequential(
            *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh and not discrete else []))
        )
        if use_options:
            self.predict_option = torch.nn.Linear(hidden_size, self.option_dim)
        if predict_q:
            self.predict_q = torch.nn.Linear(hidden_size, self.act_dim)

    def forward(self, states, actions, timesteps, options=None, word_embeddings=None, attention_mask=None):
        """
        Run the transformer over a batch of trajectories.

        Args:
            states, actions, timesteps: batched sequences; ``states.shape[0]`` and
                ``states.shape[1]`` are used as batch size and sequence length.
            options: required in the options layout.
            word_embeddings: required in the language layout; ``shape[1]`` is the
                number of language tokens.
            attention_mask: per-timestep mask (1 = attend, 0 = ignore). Mandatory.

        Returns:
            dict with ``'state_preds'`` (None for non-int ``state_dim``) and
            ``'action_preds'``; the options layout adds ``'options_loss'`` (None unless
            ``option_il``); ``'q_preds'`` is added when the Q head is enabled.

        Raises:
            ValueError: if ``attention_mask`` is None.
        """
        if attention_mask is None:
            raise ValueError('Should not have attention_mask NONE')

        if self.use_options:
            return self._forward_options(states, actions, timesteps, options, attention_mask)

        if self.use_language:
            return self._forward_language(states, actions, timesteps, word_embeddings, attention_mask)

        # Not reachable for instances built through __init__, which enforces one mode.
        return None

    def _forward_options(self, states, actions, timesteps, options, attention_mask):
        """Forward pass for the interleaved (option, state, action) token layout."""
        batch_size, seq_length = states.shape[0], states.shape[1]

        assert options is not None
        option_embeddings = self.embed_option(options)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        option_embeddings = option_embeddings + time_embeddings

        if self.no_states:
            # IMP: MAKE SURE THIS IS NOT SET ON BY DEFAULT
            # Ablation: zeroed states are embedded and no time embedding is added.
            state_embeddings = self.embed_state(torch.zeros_like(states))
        else:
            state_embeddings = self.embed_state(states)
            state_embeddings = state_embeddings + time_embeddings

        if self.no_actions:
            # IMP: MAKE SURE THIS IS NOT SET ON BY DEFAULT
            # Ablation: zeroed actions are embedded and no time embedding is added.
            action_embeddings = self.embed_action(torch.zeros_like(actions))
        else:
            action_embeddings = self.embed_action(actions)
            action_embeddings = action_embeddings + time_embeddings

        # this makes the sequence look like (o1, s1, a1, o2, s2, a2, ...)
        # which works nice in an autoregressive sense since states predict actions
        # note that o1 and o2 need not be different
        stacked_inputs = torch.stack(
            (option_embeddings, state_embeddings, action_embeddings),
            dim=1).permute(
            0, 2, 1, 3).reshape(
            batch_size, 3 * seq_length, self.hidden_size)
        # LAYERNORM
        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask, attention_mask), dim=1
        ).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs['last_hidden_state']

        # reshape x so that the second dimension corresponds to the original
        # options (0), states (1) or actions (2); i.e. traj_out[:, 1, t] is the token for s_t
        traj_out = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions
        # predict next state given option, state and action. skip the last state for prediction
        if isinstance(self.state_dim, int):
            state_preds = self.predict_state(traj_out[:, 2])[:, :-1, :]
        else:
            state_preds = None
        # predict next action given state and option
        action_preds = self.predict_action(traj_out[:, 1])

        # reconstruct the current option from the hidden state at the option token
        if self.option_il:
            option_preds = self.predict_option(traj_out[:, 0])
            options_loss = F.mse_loss(option_preds, options.detach())
        else:
            options_loss = None

        outputs = {'state_preds': state_preds,
                   'action_preds': action_preds,
                   'options_loss': options_loss}

        if self.predict_q:
            # predict next Q given state and option   ## IMP: Don't use current action
            q_preds = self.predict_q(traj_out[:, 1])
            outputs.update({'q_preds': q_preds})

        return outputs

    def _forward_language(self, states, actions, timesteps, word_embeddings, attention_mask):
        """Forward pass for the language-prefixed (state, action) token layout."""
        batch_size, seq_length = states.shape[0], states.shape[1]

        assert word_embeddings is not None
        num_tokens = word_embeddings.shape[1]
        state_embeddings = self.embed_state(states)
        lang_embeddings = self.embed_lang(word_embeddings)
        action_embeddings = self.embed_action(actions)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        # (the language tokens get no time embedding)
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings

        # interleave to (s1, a1, s2, a2, ...) and prepend the language tokens
        stacked_inputs = torch.stack(
            (state_embeddings, action_embeddings),
            dim=1).permute(
            0, 2, 1, 3).reshape(
            batch_size, 2 * seq_length, self.hidden_size)
        lang_and_inputs = torch.cat([lang_embeddings, stacked_inputs], dim=1)
        # LAYERNORM AFTER LANGUAGE
        stacked_inputs = self.embed_ln(lang_and_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask), dim=1
        ).permute(0, 2, 1).reshape(batch_size, 2*seq_length)
        # language tokens are always attended to
        lang_attn_mask = torch.cat(
            [torch.ones((batch_size, num_tokens), device=states.device), stacked_attention_mask], dim=1)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=lang_attn_mask,
        )
        x = transformer_outputs['last_hidden_state']

        # drop the language prefix, then reshape so that the second dimension corresponds
        # to states (0) or actions (1); i.e. traj_out[:, 0, t] is the token for s_t
        traj_out = x[:, num_tokens:, :].reshape(
            batch_size, seq_length, 2, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions
        # predict state given state, action. skip the last prediction
        if isinstance(self.state_dim, int):
            state_preds = self.predict_state(traj_out[:, 1])[:, :-1, :]
        else:
            state_preds = None
        action_preds = self.predict_action(traj_out[:, 0])  # predict next action given state

        outputs = {'state_preds': state_preds,
                   'action_preds': action_preds}

        if self.predict_q:
            # predict next Q given state   ## IMP: Don't use current action
            q_preds = self.predict_q(traj_out[:, 0])
            outputs.update({'q_preds': q_preds})

        return outputs

    def _clip_and_pad(self, tensor, dtype):
        """Keep the last ``max_length`` steps of ``tensor``, pad to ``max_length`` and cast."""
        # `pad` is the project helper from hrl.utils; its padding side/value are
        # not visible in this file -- TODO confirm against hrl.utils.
        return pad(tensor[:, -self.max_length:], self.max_length).to(dtype=dtype)

    def _prepare_context(self, states, actions, timesteps):
        """
        Shape a single trajectory into a batch of one with a fixed-length context.

        Returns:
            (states, actions, timesteps, attention_mask), each covering ``max_length``
            steps after padding; the mask is built by padding a row of ones whose
            length is the number of real (unpadded) steps.

        Raises:
            ValueError: if ``self.max_length`` is None.
        """
        if isinstance(self.state_dim, tuple):
            states = states.reshape(1, -1, *self.state_dim)
        else:
            states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        timesteps = timesteps.reshape(1, -1)

        # forward() requires an attention mask, which can only be built for a known
        # context length, so fail here with the actual cause.
        if self.max_length is None:
            raise ValueError('Should not have max_length NONE')

        states = states[:, -self.max_length:]
        # pad all tokens to sequence length
        attention_mask = pad(torch.ones(1, states.shape[1]), self.max_length).to(
            dtype=torch.long, device=states.device).reshape(1, -1)
        states = pad(states, self.max_length).to(dtype=torch.float32)
        actions = self._clip_and_pad(actions, torch.float32)
        timesteps = self._clip_and_pad(timesteps, torch.long)
        return states, actions, timesteps, attention_mask

    def get_action(self, states, actions, timesteps, options=None, word_embeddings=None, **kwargs):
        """
        Run ``forward`` on a single (unbatched) trajectory for inference.

        Inputs are reshaped to a batch of one, truncated to the last ``max_length``
        steps and padded to exactly ``max_length`` steps.

        Args:
            states, actions, timesteps: trajectory so far.
            options: required in the options layout.
            word_embeddings: required in the language layout.
            **kwargs: forwarded to ``forward`` in the language layout only; ignored in
                the options layout.

        Returns:
            The prediction dict produced by ``forward``.

        Raises:
            ValueError: if ``self.max_length`` is None.
        """
        if self.use_options:
            assert options is not None
            states, actions, timesteps, attention_mask = self._prepare_context(
                states, actions, timesteps)
            options = self._clip_and_pad(
                options.reshape(1, -1, self.option_dim), torch.float32)

            preds = self.forward(
                states, actions, timesteps, options=options, attention_mask=attention_mask)

        if self.use_language:
            assert word_embeddings is not None
            states, actions, timesteps, attention_mask = self._prepare_context(
                states, actions, timesteps)

            preds = self.forward(
                states, actions, timesteps, word_embeddings=word_embeddings, attention_mask=attention_mask, **kwargs)

        return preds


class MLPPolicy(nn.Module):
    """
    Feed-forward policy that maps an observation, optionally concatenated with a
    goal/option vector, to an action prediction.

    The network is: Linear -> act -> Linear -> act -> Linear -> act -> Linear, where
    ``act`` is the shared ``dropout_activation`` block (ReLU followed by Dropout).
    """

    def __init__(self, obs_dim: int, action_dim: int, embed_dim, option_dim: int, image_embed=64):
        """
        Args:
            obs_dim: observation size; if a tuple is given, observations are treated
                as images and passed through an ``Encoder`` first.
            action_dim: size of the predicted action vector.
            embed_dim: width of the hidden layers.
            option_dim: size of the goal/option vector concatenated to the observation.
            image_embed: output width of the image encoder (only used for tuple ``obs_dim``).
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.option_dim = option_dim
        self.embed_dim = embed_dim
        self.image_embed = image_embed

        if isinstance(obs_dim, tuple):
            # Image observations: encode to a flat vector of width image_embed, and
            # from here on treat that width as the observation size.
            self.embed_state = Encoder(self.obs_dim, hidden_size=self.image_embed, ch=3, robot=False)
            self.obs_dim = self.image_embed

        self.obs_goal_embed = nn.Linear(self.obs_dim + self.option_dim, self.embed_dim)
        self.hidden1_layer = nn.Linear(self.embed_dim, self.embed_dim)
        self.hidden2_layer = nn.Linear(self.embed_dim, self.embed_dim)
        self.action_decoder = nn.Linear(self.embed_dim, self.action_dim)
        # NOTE: the second ReLU is redundant (ReLU is idempotent) but is kept so the
        # module structure stays exactly as before. nn.Dropout() uses its default rate.
        self.dropout_activation = nn.Sequential(nn.ReLU(), nn.ReLU(), nn.Dropout())

        # The same layer objects registered above are reused here (shared, not copied).
        self.mlp_policy = nn.Sequential(
            self.obs_goal_embed,
            self.dropout_activation,
            self.hidden1_layer,
            self.dropout_activation,
            self.hidden2_layer,
            self.dropout_activation,
            self.action_decoder
        )

    def forward(self, observations: Tensor, goal: Tensor = None):
        """
        Predict actions for the given observations.

        Args:
            observations: observation tensor. Tensors with more than 3 dimensions are
                passed through ``self.embed_state``, which only exists when the policy
                was built with a tuple ``obs_dim``.
            goal: optional tensor concatenated to the observations along the last
                dimension. When omitted, the observations alone must already have
                ``obs_dim + option_dim`` features.

        Returns:
            The output of ``self.mlp_policy`` (last dimension ``action_dim``).
        """
        if observations.dim() > 3:
            observations = self.embed_state(observations)

        # Identity check: comparing a tensor to None with != only works through
        # PyTorch's comparison fallback and is not the intended test.
        if goal is not None:
            obs_goal = torch.cat([observations, goal], dim=-1)
        else:
            obs_goal = observations

        pred_a = self.mlp_policy(obs_goal)

        return pred_a