import torch 
import torch.nn as nn 

class LogRewardLinear(nn.Module):
    """Seeded linear log-reward over padded token sequences.

    Two fixed random weight vectors are drawn once at construction:

    * ``p_encoding`` -- one weight per position, ``max_size + 1`` entries,
      sampled uniformly from ``[-2, -1)`` (``torch.rand(...) - 2``).
    * ``t_encoding`` -- one weight per token id, ``vocab_size + 1`` entries,
      sampled uniformly from ``[0, 1)``.

    The log-reward of a sequence is the sum over its positions of
    ``mask_padding * p_encoding[position] * t_encoding[token]``, divided by
    ``temperature``.

    Args:
        max_size: Sizes the position weights (``max_size + 1`` entries).
        vocab_size: Sizes the token weights (``vocab_size + 1`` entries).
        seed: Seed for the generator used to draw both weight vectors.
        device: Device of the generator and of the weight vectors.
        temperature: Divisor applied to the summed log-reward.
    """

    def __init__(self, max_size, vocab_size, seed, device, temperature=1.):
        super().__init__()
        self.max_size = max_size
        self.vocab_size = vocab_size
        self.seed = seed
        self.temperature = temperature

        g = torch.Generator(device=device)
        g.manual_seed(seed)
        # The output tensor must live on the generator's device; without an
        # explicit ``device=`` torch.rand targets the CPU and rejects a
        # non-CPU generator.
        self.p_encoding = torch.rand(
            (self.max_size + 1,), generator=g, device=device
        ) - 2
        self.t_encoding = torch.rand(
            (self.vocab_size + 1,), generator=g, device=device
        )

    @torch.no_grad()
    def forward(self, batch_state):
        """Return one log-reward per row of ``batch_state``.

        Args:
            batch_state: Object exposing ``state`` (integer token ids used to
                index ``t_encoding``), ``mask_padding`` (multiplied
                element-wise with the weights) and ``size`` (number of leading
                position weights to use). Assumes ``state`` and
                ``mask_padding`` are ``(batch, size)`` -- TODO confirm
                against callers.

        Returns:
            Tensor with the per-row sum over ``dim=1``, divided by
            ``self.temperature``.
        """
        tokens = batch_state.state
        # Align the fixed weights with the incoming batch's device (no-op
        # when they already match).
        position_weights = self.p_encoding[:batch_state.size].to(tokens.device)
        token_weights = self.t_encoding.to(tokens.device)[tokens]
        # position_weights broadcasts across the batch dimension.
        log_rewards = (
            batch_state.mask_padding * position_weights * token_weights
        ).sum(dim=1)
        return log_rewards / self.temperature

class Sequence:
    """Batched state for building token sequences one action at a time.

    Every row starts as the single token ``0``. Action ``0`` is also the stop
    action, so the same token marks the first element and the end of a
    sequence. ``mask_padding`` flags which positions of ``state`` hold real
    tokens, and ``mask`` flags which actions are still allowed per row
    (column 0 is the stop action and is never masked out here).

    Args:
        max_size: Once ``size`` reaches this value only the stop action
            remains allowed; ``state`` is padded to ``max_size + 1`` columns
            when every row has stopped.
        vocab_size: Number of non-stop tokens; actions range over
            ``0..vocab_size``.
        log_reward: Callable invoked with this object by ``log_reward()``.
        batch_size: Number of sequences built in parallel.
        device: Device on which all state tensors are allocated.
    """

    def __init__(self, max_size, vocab_size, log_reward, batch_size=512, device='cpu'):
        self.max_size = max_size
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.device = device
        self.batch_ids = torch.arange(self.batch_size, device=device)

        self._log_reward = log_reward

        # Token ids built so far; every sequence starts with token 0.
        self.state = torch.zeros(
            (self.batch_size, 1), dtype=torch.long, device=device
        )
        # Padding mask for sequences of variable size (1 = real token).
        self.mask_padding = torch.ones((self.batch_size, 1), device=device)
        # Mask over actions (plus one for the interrupting action).
        self.mask = torch.ones(
            (self.batch_size, self.vocab_size + 1), device=device
        )

        # Number of stop actions each row has taken so far.
        self.stopped = torch.zeros((self.batch_size,), device=device)
        self.actions = torch.arange(self.vocab_size + 1, device=device)
        # Current number of columns in ``state``.
        self.size = 1
        self.lengths = torch.ones((self.batch_size,), device=device)

    @torch.no_grad()
    def is_stop_action(self, actions):
        """Return 1.0 where ``actions`` equals the stop token (0), else 0.0.

        The result uses torch's current default floating dtype.
        """
        # The same token corresponds to the first and interrupting elements
        # of the built sequence.
        return (actions == 0).to(dtype=torch.get_default_dtype())

    @torch.no_grad()
    def apply(self, indices):
        """Append one action per row and update masks, lengths and counters.

        Args:
            indices: Per-row indices into ``self.actions`` (one per sequence).

        Returns:
            Boolean tensor, ``True`` for rows that have taken fewer than two
            stop actions, i.e. rows that had not already stopped before this
            call.
        """
        actions = self.actions[indices]
        stop_action_mask = self.is_stop_action(actions)
        self.stopped += stop_action_mask
        self.lengths += (1 - stop_action_mask)

        self.size += 1

        updated_mask = self.mask.clone()
        # Rows that just chose the stop action lose every non-stop action.
        updated_mask[:, 1:] = (1 - stop_action_mask[:, None]) * updated_mask[:, 1:]
        # Once the size limit is reached, only the stop action stays allowed.
        if self.size >= self.max_size:
            updated_mask[:, 1:] = 0
        self.mask = updated_mask

        # Update state
        self.state = torch.hstack([self.state, actions.unsqueeze(-1)])
        # Update paddings: a stop action does not count as a real token.
        self.mask_padding = torch.hstack(
            [self.mask_padding, (1 - stop_action_mask).unsqueeze(-1)]
        )

        if (self.stopped >= 1.).all() and self.size != self.max_size + 1:
            # Every row has stopped early: pad up to max_size + 1 columns.
            # Each pad tensor takes the dtype/device of the tensor it extends
            # instead of relying on long -> float promotion.
            missing = self.max_size + 1 - self.size
            self.size = self.max_size + 1
            self.state = torch.hstack(
                [self.state, self.state.new_zeros((self.batch_size, missing))]
            )
            self.mask_padding = torch.hstack(
                [self.mask_padding,
                 self.mask_padding.new_zeros((self.batch_size, missing))]
            )
        return (self.stopped < 2)

    @torch.no_grad()
    def log_reward(self):
        """Evaluate the log-reward callable supplied at construction on ``self``."""
        return self._log_reward(self)

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the rows of ``batch_state`` to this batch, in place.

        Only ``state``, ``mask_padding``, ``batch_ids``, ``lengths`` and
        ``batch_size`` are merged; ``mask``, ``stopped`` and ``size`` are left
        untouched. ``state`` and ``mask_padding`` are stacked row-wise, so
        both batches must have the same number of columns.

        Args:
            batch_state: Object exposing ``state``, ``mask_padding``,
                ``batch_size``, ``batch_ids`` and ``lengths``.
        """
        # Offset the incoming ids by the batch size *before* it grows, so the
        # merged ids stay contiguous and cannot collide on repeated merges.
        id_offset = self.batch_size

        self.state = torch.vstack([self.state, batch_state.state])
        self.mask_padding = torch.vstack(
            [self.mask_padding, batch_state.mask_padding]
        )
        self.batch_ids = torch.hstack(
            [self.batch_ids, batch_state.batch_ids + id_offset]
        )
        self.lengths = torch.hstack([self.lengths, batch_state.lengths])
        self.batch_size += batch_state.batch_size