import torch
import torch.nn as nn
from torch.nn import functional as F


from src.models.gpt import GPT


class BCQ_TransformerGPT(GPT):
    """GPT variant that predicts subgoals from a sequence of observations.

    Observations are normalized with fixed (non-trainable) statistics, embedded
    with a single linear layer followed by LayerNorm, summed with learned
    positional embeddings, and the transformer outputs are decoded back into the
    subgoal space using the subgoal statistics.

    ``config`` attributes read directly by this class: ``observation_mean``,
    ``observation_std``, ``subgoal_mean``, ``subgoal_std`` (tensors),
    ``n_embd``, optional ``observation_weight`` (loss multiplier, default 1) and
    optional ``output_dim`` (defaults to ``n_embd``). Attributes such as
    ``observation_dim``, ``embedding_dim``, ``block_size``, ``output_dim`` and
    ``transition_dim`` are presumably set on ``self`` by the ``GPT`` base
    class — confirm.
    """

    def __init__(self, config):
        # Plain Python attribute; assigning it before nn.Module.__init__ is safe
        # because it is neither a Parameter, Module nor buffer.
        self.observation_weight = getattr(config, 'observation_weight', 1)

        print('self observation_weight: ', self.observation_weight)

        if not hasattr(config, 'output_dim'):
            config.output_dim = config.n_embd
        # NOTE(review): presumably consumed by GPT.__init__ to disable input
        # value masking for this model — confirm.
        config.mask_values = False

        super().__init__(config)

        # Normalization statistics are wrapped in frozen Parameters so they are
        # part of the state dict and follow .to(device), but are never trained.
        # The 1e-6 keeps the divisions below finite for constant dimensions.
        self.observation_mean = nn.Parameter(config.observation_mean, requires_grad=False)
        self.observation_std = nn.Parameter(config.observation_std + 1.e-6, requires_grad=False)
        self.subgoal_mean = nn.Parameter(config.subgoal_mean, requires_grad=False)
        self.subgoal_std = nn.Parameter(config.subgoal_std + 1.e-6, requires_grad=False)

    def create_layers(self, config):
        """Create the embedding and decoding layers, then the base GPT layers."""
        # Embedding: observation -> embedding_dim, followed by LayerNorm.
        self.observation_embed = nn.Sequential(
            nn.Linear(self.observation_dim, self.embedding_dim)
        )
        self.embed_ln = nn.LayerNorm(self.embedding_dim)
        # Learned positional embeddings for at most `block_size` positions.
        self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size, self.embedding_dim))

        # Decoder: transformer output -> normalized observation/subgoal space.
        self.observation_decoder = nn.Sequential(
            nn.LayerNorm(self.output_dim),
            nn.Linear(self.output_dim, self.observation_dim)
        )

        super().create_layers(config)

    def pad_to_full_observation(self, x):
        """Reshape ``x`` to ``(-1, transition_dim, embedding_dim)``.

        Returns the reshaped view and ``0``; the second value is presumably the
        number of padding positions added (none here) — confirm against GPT.
        """
        x_view = x.view(-1, self.transition_dim, self.embedding_dim)
        return x_view, 0

    def embed_inputs(self, inputs):
        """Embed normalized observations and add positional embeddings.

        Args:
            inputs: dict with key ``'observations'`` (presumably shaped
                ``(batch, time, observation_dim)`` — TODO confirm) and an
                optional ``'embedding_offset'`` that is added to the positional
                embeddings.

        Returns:
            Token embeddings with positional embeddings added.

        Raises:
            ValueError: if the sequence is longer than the positional-embedding
                table (``block_size``).
        """
        observations = (inputs['observations'] - self.observation_mean) / self.observation_std
        _, obs_t, *_ = observations.shape

        # Without this check an over-long sequence only surfaces later as an
        # opaque broadcasting error (or silently reuses one position when the
        # table has a single entry).
        max_positions = self.pos_emb.shape[1]
        if obs_t > max_positions:
            raise ValueError(
                f"Cannot forward sequence of length {obs_t}: "
                f"model block size is {max_positions}."
            )

        embeddings = self.embed_ln(self.observation_embed(observations))

        # [ 1 x T x embedding_dim ], broadcast over the batch dimension.
        position_embeddings = self.pos_emb[:, :obs_t, :]
        if 'embedding_offset' in inputs:
            position_embeddings = position_embeddings + inputs['embedding_offset']
        return embeddings + position_embeddings

    def decode_outputs(self, outputs, inputs):
        """Decode transformer outputs into de-normalized subgoal predictions.

        Args:
            outputs: transformer output features (last dim ``output_dim``).
            inputs: unused; kept for interface compatibility with GPT.

        Returns:
            dict with key ``'subgoals'``.
        """
        normalized = self.observation_decoder(outputs)
        return {'subgoals': self.subgoal_std * normalized + self.subgoal_mean}

    def compute_loss(self, outputs, inputs, targets, mask=None):
        """Weighted, variance-normalized squared error on predicted subgoals.

        The per-dimension squared error is divided by ``subgoal_std ** 2`` (the
        same statistics ``decode_outputs`` uses to de-normalize predictions), so
        the loss is measured in the normalized subgoal space, then summed over
        the last dimension and averaged over the selected positions.

        Args:
            outputs: dict with key ``'subgoals'`` (predictions).
            inputs: unused; kept for interface compatibility with GPT.
            targets: dict with key ``'subgoals'`` (same shape as predictions).
            mask: optional boolean tensor; ``mask[:, 1:]`` selects the positions
                that contribute to the mean. When ``None`` every position
                contributes.

        Returns:
            Scalar loss tensor multiplied by ``observation_weight``.
        """
        squared_error = F.mse_loss(outputs['subgoals'], targets['subgoals'], reduction='none')
        per_step_loss = torch.sum(squared_error / (self.subgoal_std ** 2), dim=-1, keepdim=True)

        if mask is not None:
            # NOTE(review): the first mask step is dropped, presumably to align
            # the mask with the prediction positions — confirm against caller.
            per_step_loss = per_step_loss[mask[:, 1:]]

        return self.observation_weight * per_step_loss.mean()
