import torch
import torch.nn as nn
import transformers
# Module-level side effect: seeds global RNGs on import (transformers.set_seed),
# presumably so the weight initialization of the models below is reproducible.
transformers.set_seed(0)
from transformers import GPT2Config, GPT2Model
# NOTE(review): `embed` is not used in this file chunk; it looks like a leftover
# debugging import that adds a hard IPython dependency — confirm and remove.
from IPython import embed

class action_learner(nn.Module):
    """GPT-2 policy predicting actions from a query state plus context transitions.

    The input sequence has one token per position: position 0 holds the
    query state (with zeroed action / next-state / reward slots) and
    positions 1..T hold the context transitions ``[s, a, s', r]``.  Every
    token is projected to ``n_embd`` by one linear layer, run through a
    GPT-2 stack, and mapped to ``action_dim`` outputs.

    Args:
        config: mapping with keys ``'test'``, ``'horizon'``, ``'n_embd'``,
            ``'n_layer'``, ``'n_head'``, ``'state_dim'``, ``'action_dim'``,
            ``'dropout'`` and ``'device'``.
    """

    def __init__(self, config):
        super().__init__()

        self.test = config['test']
        self.horizon = config['horizon']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.state_dim = config['state_dim']
        self.action_dim = config['action_dim']
        self.dropout = config['dropout']
        # Kept for backward compatibility; forward() allocates on the
        # inputs' device so the model keeps working after `.to(...)`.
        self.device = config['device']

        gpt2_config = GPT2Config(
            # forward() uses 1 + context-length positions; this is an upper bound.
            n_positions=4 * (1 + self.horizon),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=self.n_head,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(gpt2_config)

        # One token = [state, action, next_state, reward] concatenated.
        self.embed_transition = nn.Linear(
            2 * self.state_dim + self.action_dim + 1, self.n_embd)
        self.pred_actions = nn.Linear(self.n_embd, self.action_dim)

    def forward(self, x):
        """Predict actions for the query state given the context.

        Args:
            x: dict with ``'query_states'`` (presumably ``(B, state_dim)``) and
                ``'context_states'``, ``'context_actions'``,
                ``'context_next_states'``, ``'context_rewards'`` — presumably
                ``(B, T, state_dim / action_dim / state_dim / 1)``; confirm
                against the data loader.

        Returns:
            When ``self.test`` is set, the ``(B, action_dim)`` prediction at
            the last position; otherwise the ``(B, T, action_dim)`` predictions
            at positions 1..T (position 0 is the query token alone).
        """
        query_states = x['query_states'][:, None, :]
        lead_shape = tuple(query_states.shape[:-1])

        def query_pad(width):
            # Zero slot for the query token, on the input's device and dtype
            # (rather than the possibly stale `self.device`).
            return query_states.new_zeros((*lead_shape, width))

        state_seq = torch.cat([query_states, x['context_states']], dim=1)
        action_seq = torch.cat(
            [query_pad(self.action_dim), x['context_actions']], dim=1)
        next_state_seq = torch.cat(
            [query_pad(self.state_dim), x['context_next_states']], dim=1)
        reward_seq = torch.cat([query_pad(1), x['context_rewards']], dim=1)

        seq = torch.cat(
            [state_seq, action_seq, next_state_seq, reward_seq], dim=2)
        stacked_inputs = self.embed_transition(seq)
        transformer_outputs = self.transformer(inputs_embeds=stacked_inputs)
        preds = self.pred_actions(transformer_outputs['last_hidden_state'])

        if self.test:
            return preds[:, -1, :]
        return preds[:, 1:, :]



class reward_learner(nn.Module):
    """GPT-2 regressor over interleaved input tokens and cumulative-reward tokens.

    If ``config['type'] == 'Q'``, each input token is a state-action pair
    ``[s, a]``; for any other value it is a bare state ``s`` (V-style).
    Each cumulative reward is zero-padded to the input width, so both token
    kinds share ``embed_transition``.  The two streams are interleaved as
    ``x1, r1, x2, r2, ..., xT``; the final reward token is dropped because
    no input token follows it.  One scalar is read out at every ``x_t``
    position; under GPT-2's causal attention that position sees
    ``x1..x_t`` and ``r1..r_{t-1}`` only.

    Args:
        config: mapping with keys ``'test'``, ``'horizon'``, ``'n_embd'``,
            ``'n_layer'``, ``'n_head'``, ``'state_dim'``, ``'action_dim'``,
            ``'dropout'``, ``'device'`` and ``'type'``.
    """

    def __init__(self, config):
        super().__init__()
        self.test = config['test']
        self.horizon = config['horizon']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.state_dim = config['state_dim']
        self.action_dim = config['action_dim']
        self.reward_embd = 32  # hidden width of the reward-head MLP
        self.dropout = config['dropout']
        # Kept for backward compatibility; tensors are allocated on the
        # inputs' device so the model keeps working after `.to(...)`.
        self.device = config['device']
        # Read by forward() to pick the Q or V path. Not named `type`,
        # which would shadow nn.Module.type().
        self.value_type = config['type']
        # Registration order (embed -> transformer -> head) is kept so the
        # seeded initialization and state_dict ordering are unchanged.
        if self.value_type == 'Q':
            self.embed_transition = nn.Linear(
                self.state_dim + self.action_dim, self.n_embd)
        else:
            self.embed_transition = nn.Linear(self.state_dim, self.n_embd)

        gpt2_config = GPT2Config(
            n_positions=4 * (1 + self.horizon),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=self.n_head,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(gpt2_config)
        self.pred_rewards = nn.Sequential(
            nn.Linear(self.n_embd, self.reward_embd),
            nn.ReLU(),
            nn.Linear(self.reward_embd, 1))

    def forward(self, x):
        """Dispatch to forward_Q when config['type'] was 'Q', else to forward_V."""
        if self.value_type == 'Q':
            return self.forward_Q(x)
        return self.forward_V(x)

    def forward_Q(self, x):
        """Predict from ``[state, action]`` tokens interleaved with rewards.

        Args:
            x: dict with ``'context_states'`` ``(B, T, state_dim)``,
                ``'context_actions'`` ``(B, T, action_dim)`` and
                ``'cumulative_rewards'`` ``(B, T)``. These shapes are inferred
                from how the tensors are concatenated; confirm against the
                data loader.

        Returns:
            ``(B, T, 1)`` predictions, one per state-action token.
        """
        state_actions = torch.cat(
            [x['context_states'], x['context_actions']], dim=2)
        return self._predict(state_actions, x['cumulative_rewards'])

    def forward_V(self, x):
        """Predict from state tokens interleaved with rewards.

        Args:
            x: dict with ``'context_states'`` ``(B, T, state_dim)`` and
                ``'cumulative_rewards'`` ``(B, T)``. These shapes are inferred
                from how the tensors are concatenated; confirm against the
                data loader.

        Returns:
            ``(B, T, 1)`` predictions, one per state token.
        """
        return self._predict(x['context_states'], x['cumulative_rewards'])

    @staticmethod
    def _interleave(inputs, cumulative_rewards):
        """Build the token sequence ``x1, r1, x2, r2, ..., xT`` of shape ``(B, 2T - 1, D)``."""
        rewards = cumulative_rewards[..., None]
        # Pad each scalar reward with zeros up to the input token width D.
        padding = rewards.new_zeros(
            (*inputs.shape[:-1], inputs.shape[-1] - 1))
        reward_tokens = torch.cat([rewards, padding], dim=2)
        # stack -> (B, T, 2, D); flatten -> (B, 2T, D) alternating x_t, r_t.
        seq = torch.stack([inputs, reward_tokens], dim=2).flatten(1, 2)
        # The last reward token has no following input token; drop it.
        return seq[:, :-1]

    def _predict(self, inputs, cumulative_rewards):
        """Embed the interleaved sequence, run GPT-2, and read out at the input tokens."""
        seq = self._interleave(inputs, cumulative_rewards)
        stacked_inputs = self.embed_transition(seq)
        transformer_outputs = self.transformer(inputs_embeds=stacked_inputs)
        preds = self.pred_rewards(transformer_outputs['last_hidden_state'])
        # Even positions hold the x_t tokens.
        return preds[:, ::2]
    