import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder
import torch.nn.functional as F
from typing import Tuple


def build_transformer(model, max_length, vocab_size):
    """Assemble a ``TransformerWrapper`` around a ``Decoder`` stack.

    Args:
        model: Config object; this function reads ``use_abs_pos_emb``,
            ``width``, ``depth``, ``num_heads`` and ``pos_emb`` from it.
        max_length: Forwarded as ``max_seq_len``.
        vocab_size: Forwarded as ``num_tokens``.

    Returns:
        The constructed ``TransformerWrapper``.

    ``model.pos_emb`` selects at most one positional scheme: ``'rope'``
    turns on ``rotary_pos_emb``, ``'alibi'`` turns on ``alibi_pos_bias`` and
    ``'t5'`` turns on ``rel_pos_bias``. Any other value leaves all three off.
    """
    use_abs_pos_emb = model.use_abs_pos_emb
    pos_scheme = model.pos_emb

    decoder_stack = Decoder(
        dim=model.width,
        depth=model.depth,
        heads=model.num_heads,
        rotary_pos_emb=pos_scheme == 'rope',
        alibi_pos_bias=pos_scheme == 'alibi',
        rel_pos_bias=pos_scheme == 't5',
    )

    wrapper_kwargs = {
        'num_tokens': vocab_size,
        'max_seq_len': max_length,
        'use_abs_pos_emb': use_abs_pos_emb,
        'attn_layers': decoder_stack,
    }
    return TransformerWrapper(**wrapper_kwargs)

class Transformer(nn.Module):

    def __init__(self, model_config, tokenizer):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.autoregressive = model_config.autoregressive
        self.padding_idx = tokenizer.pad_token_id
        self.output_dim = model_config.target_max_length
        max_length = model_config.max_length if not model_config.autoregressive else model_config.max_length + model_config.target_max_length
        self.model = build_transformer(model_config, max_length, tokenizer.vocab_size).to(self.device)
    
    def forward(self, input_ids=None, attention_mask=None):
        if not self.autoregressive:
            padding = torch.full((input_ids.shape[0], self.output_dim), self.padding_idx, device=self.device)
            input_ids = torch.cat((input_ids, padding), dim=1)
        
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
            
        logits = self.model(input_ids, mask=attention_mask)
        pred = logits.transpose(1, 2)
        
        return pred.squeeze(-1)

    @torch.no_grad()
    def generate(self, input_ids, attention_mask=None, max_new_tokens=None, num_return_sequences=1, do_sample=False, pad_token_id=None):
        batch_size = input_ids.shape[0]
        device = input_ids.device
        
        generated = input_ids.clone()
        for _ in range(max_new_tokens):
            outputs = self.forward(input_ids=generated, attention_mask=attention_mask)
            next_token_logits = outputs[:, :, -1]
            next_tokens = next_token_logits.argmax(dim=-1)
            
            generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([
                    attention_mask,
                    torch.ones((batch_size, 1), dtype=torch.bool, device=device)
                ], dim=-1)
        
        return generated
    

    @torch.no_grad()
    def batch_generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        *,
        max_new_tokens: int,
        num_return_sequences: int = 1,
        temperature: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert 0.0 <= temperature <= 1.0
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if attention_mask is not None and attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)

        device = input_ids.device
        B, L = input_ids.shape
        T = max_new_tokens
        flat_B = B * num_return_sequences   

        delimiter_mask = (input_ids == self.tokenizer.delimiter_id)
        delimiter_pos = delimiter_mask.nonzero(as_tuple=True)[1]
        if len(delimiter_pos) != B:
            raise ValueError(f"Expected exactly one delimiter token per sequence, got {len(delimiter_pos)} for batch size {B}")
        
        input_ids = torch.stack([ids[:pos+1] for ids, pos in zip(input_ids, delimiter_pos)])
        if attention_mask is not None:
            attention_mask = torch.stack([mask[:pos+1] for mask, pos in zip(attention_mask, delimiter_pos)])

        generated = input_ids.repeat_interleave(num_return_sequences, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)

        logprob_buf = []

        for _ in range(T):
            logits = self.forward(input_ids=generated,
                                attention_mask=attention_mask)
            next_logits = logits[:, :, -1]

            if temperature > 0.0:
                dist_logits = next_logits / temperature
                next_tokens = torch.multinomial(F.softmax(dist_logits, dim=-1), 1).squeeze(-1)
            else:
                dist_logits = next_logits
                next_tokens = dist_logits.argmax(dim=-1)

            log_probs_step = F.log_softmax(dist_logits, dim=-1).gather(1, next_tokens[:, None]).squeeze(-1)
            logprob_buf.append(log_probs_step)         

            generated = torch.cat([generated, next_tokens[:, None]], dim=-1)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask,
                                            torch.ones_like(next_tokens)[:, None]], dim=-1)

        log_probs = torch.stack(logprob_buf, dim=1)

        generated_responses = generated[:, -T:]
        generated_responses = generated_responses.view(B, num_return_sequences, T)
        log_probs = log_probs.view(B, num_return_sequences, T)

        return generated_responses, log_probs
