import torch
from transformers import LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria
from torch.nn.utils.rnn import pad_sequence
from utils.logits_processors import DFABeamSearchScorer
import numpy as np

def log_softmax(x, dim=-1):
    """Return log-probabilities of ``x`` normalized along ``dim`` (last axis by default).

    ``torch.log_softmax`` requires an explicit ``dim``; it is exposed here with a
    default so existing ``log_softmax(x)`` call sites work. The function is not
    wrapped in ``torch.compile``: it is a single built-in op, so compilation adds
    first-call overhead and a compiler-toolchain requirement without a speed-up.
    """
    return torch.log_softmax(x, dim=dim)

def next_token_logits_for_beams(beams, model, tokenizer, device, dtype, batch_size=16, neginf=None, debug=False,
                                extra_stop_token_ids=(128255,)):
    """Compute next-token logits for every beam with one batched forward pass.

    A beam is "finished" when its last token is ``tokenizer.eos_token_id`` or one
    of ``extra_stop_token_ids``. A finished beam's row is ``neginf`` everywhere
    except the EOS column, which gets 1e4, so it can only continue with EOS.
    The remaining ("alive") beams are right-padded, run through ``model``
    together, and the logits at each beam's *last real token* are scattered
    back into the output.

    Args:
        beams: sequence of dicts, each holding a 1-D ``"input_ids"`` tensor.
        model: callable returning an object with ``.logits`` indexed as
            (row, position, vocab); ``model.config.vocab_size`` sizes the output.
        tokenizer: provides ``eos_token_id`` (also used as the padding id).
        device: device of the returned logits tensor (inputs are moved there too).
        dtype: dtype of the returned logits tensor.
        batch_size: forwarded verbatim to ``model(...)`` as ``batch_size=``.
            NOTE(review): a stock HF ``forward`` does not take this kwarg;
            presumably ``model`` is a custom wrapper — confirm.
        neginf: fill value for disallowed tokens; defaults to the most
            negative finite value of ``dtype``.
        debug: when true, print the last beam's token ids.
        extra_stop_token_ids: ids besides EOS that mark a beam as finished.
            The default keeps the previously hard-coded id 128255.

    Returns:
        ``(next_logits, finished)``: a ``(len(beams), vocab_size)`` tensor and a
        list with one bool per beam.
    """
    if neginf is None:
        neginf = torch.finfo(dtype).min

    eos_id = tokenizer.eos_token_id
    vocab_size = model.config.vocab_size
    stop_ids = {eos_id, *extra_stop_token_ids}

    # 1) Extract the sequences and find which beams are finished.
    seqs = [b["input_ids"] for b in beams]
    finished = [seq[-1].item() in stop_ids for seq in seqs]

    if debug:
        print(seqs[-1])

    # 2) Output logits, everything disallowed by default.
    next_logits = torch.full(
        (len(beams), vocab_size),
        fill_value=neginf,
        device=device,
        dtype=dtype,
    )

    # 3) Finished beams may only emit EOS.
    for i, fin in enumerate(finished):
        if fin:
            next_logits[i, eos_id] = 1e4

    # 4) Batch the alive beams through the model.
    alive_idx = [i for i, fin in enumerate(finished) if not fin]
    if alive_idx:
        alive_seqs = [seqs[i] for i in alive_idx]
        lengths = torch.tensor([len(s) for s in alive_seqs], dtype=torch.long)
        batch_ids = pad_sequence(alive_seqs, batch_first=True, padding_value=eos_id)
        # Build the mask from lengths, not by comparing with the pad id: a genuine
        # EOS inside a sequence (e.g. from a chat template) must stay visible.
        positions = torch.arange(batch_ids.size(1))
        attn_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        with torch.no_grad():
            outputs = model(
                input_ids=batch_ids.to(device),
                attention_mask=attn_mask.to(device),
                use_cache=False,
                batch_size=batch_size,
            )

        # Sequences are right-padded, so row r's last real token sits at
        # lengths[r] - 1; position -1 would be a pad token for shorter rows.
        logits = outputs.logits
        rows = torch.arange(len(alive_idx), device=logits.device)
        logits_alive = logits[rows, (lengths - 1).to(logits.device)]  # [A, V]

        # Scatter back; cast because index assignment requires matching dtypes
        # (e.g. a float32 model with dtype=float16 scores).
        idx = torch.tensor(alive_idx, dtype=torch.long, device=next_logits.device)
        next_logits[idx] = logits_alive.to(device=next_logits.device, dtype=next_logits.dtype)

    return next_logits, finished

def generate_with_dfa(
    model,
    tokenizer,
    dfa_processor,
    prompt: str,
    num_beams: int,
    max_new_tokens: int,
    length_penalty: float = 1.0,
    device: str = "cuda:0",
    half: bool = False,
    inference_batch_size: int = 8,
    beam_scorer_batch_size: int = 1,
    accepting_states=None,
    check_it: bool = False,
    final_normal=True,
    chat_model: bool = False,
    **gen_kwargs):
    """Beam-search generation that keeps a DFA-aware LogitsProcessor in sync.

    Each step does the following:
      * computes raw next-token logits for every beam;
      * masks them through ``dfa_processor``;
      * accumulates the masked log-probabilities in ``beam_scores``;
      * accumulates the *unmasked* log-probabilities of the same chosen
        tokens in ``normal_beam_scores``;
      * tells the DFA processor which beams survived and which tokens they
        appended.

    A beam is treated as finished when ``next_token_logits_for_beams`` reports
    it so. A finished beam is forced to emit the last logits column, and
    neither score track changes for it.

    Args:
      model:            a causal LM, queried via ``next_token_logits_for_beams``.
      tokenizer:        its tokenizer.
      dfa_processor:    the DFA logits processor. It is called with
                        ``neginf=`` and must provide
                        ``update_beam_states(beam_indices, tokens)``.
      prompt:           text prompt to start from.
      num_beams:        beam width.
      max_new_tokens:   maximum number of new tokens to generate (>= 1).
      length_penalty:   forwarded to ``DFABeamSearchScorer``.
      device:           device for every tensor created here.
      half:             build score tensors in float16 (the model is not cast here).
      inference_batch_size: forwarded to the model call as ``batch_size``.
      beam_scorer_batch_size: number of independent prompts (B). It must equal
                        the tokenized batch size.
      accepting_states: currently unused.
      check_it:         at step 0, run the model on a single beam and replicate
                        its logits (all beams hold the same prompt then).
      final_normal:     rank final hypotheses by the unmasked scores (True) or
                        by the DFA-masked scores (False).
      chat_model:       wrap ``prompt`` with the tokenizer's chat template.
      **gen_kwargs:     accepted for compatibility; currently unused.

    Returns:
      Whatever ``beam_scorer.finalize`` returns.

    Raises:
      ValueError: if ``max_new_tokens < 1``.
      AssertionError: if the batch size differs from ``beam_scorer_batch_size``.
      RuntimeError: if every DFA candidate is masked, or if generation stops
        before any token is selected.
    """
    if max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")

    # 1) Prepare inputs.
    text = prompt
    if chat_model:
        text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
    input_ids = tokenizer(text, return_tensors='pt')['input_ids'].to(device)
    B, cur_len = input_ids.shape
    assert B == beam_scorer_batch_size, f"Batch size mismatch: {B} != {beam_scorer_batch_size}"

    # 2) Precision. -1e30 is not representable in float16, hence the smaller sentinel.
    dtype = torch.float16 if half else torch.float32
    neginf = -1e+30 if dtype == torch.float32 else -1e+4
    # Candidates scoring above this count as "allowed". A threshold is needed
    # because log_softmax and the accumulated beam scores shift masked entries
    # away from exactly ``neginf`` (notably in float16 mode).
    allowed_threshold = neginf / 2

    # 3) HF helpers.
    logits_processor = LogitsProcessorList([dfa_processor])
    beam_scorer = DFABeamSearchScorer(
        dfa_processor     = dfa_processor,
        batch_size        = beam_scorer_batch_size,
        num_beams         = num_beams,
        device            = device,
        length_penalty    = length_penalty,
        do_early_stopping = True,
        num_beam_hyps_to_keep= num_beams,
    )

    # One copy of the prompt per beam. Beam i is row i of every flat
    # (B*num_beams, ...) tensor below.
    beams = [{'input_ids': input_ids[0].clone()} for _ in range(num_beams)]

    # 4) Beam search loop.
    beam_scores = torch.zeros(B, num_beams, device=device, dtype=dtype)
    normal_beam_scores = torch.zeros(B, num_beams, device=device, dtype=dtype)
    sequences = next_token_ids = next_beam_ids = None
    for step in range(max_new_tokens):
        # 4.1) Current hypotheses: (B*num_beams, cur_len).
        inp = torch.stack([b['input_ids'] for b in beams], dim=0)

        # 4.2) Model forward for next-token logits.
        if step == 0 and check_it:
            next_logits, finished = next_token_logits_for_beams(
                [beams[0]], model, tokenizer, device=device, dtype=dtype,
                batch_size=inference_batch_size, neginf=neginf, debug=False,
            )
            next_logits = next_logits.repeat(num_beams, 1)  # (num_beams, V)
            finished = finished * num_beams
        else:
            next_logits, finished = next_token_logits_for_beams(
                beams, model, tokenizer, device=device, dtype=dtype,
                batch_size=inference_batch_size, neginf=neginf, debug=False,
            )  # (B*num_beams, V)

        if all(finished):
            break

        # Use the logits width, which may differ from len(tokenizer).
        vocab_size = next_logits.size(-1)
        normal_scores = torch.log_softmax(next_logits, dim=-1)

        # 4.3) Apply the DFA logits processor.
        scores = logits_processor(inp, next_logits, neginf=neginf)

        # Finished beams may only emit the last logits column (log-prob 0).
        # NOTE(review): presumably this matches the stop id 128255 that
        # next_token_logits_for_beams recognizes for 128256-wide vocabularies;
        # confirm for other models.
        all_on_eos = torch.full((len(finished), vocab_size), fill_value=neginf, device=device)
        all_on_eos[:, -1] = 0.0
        finished_mask = torch.tensor(finished, device=device).unsqueeze(1)
        scores = torch.log_softmax(torch.where(finished_mask, all_on_eos, scores), dim=-1)
        # Apply the same forcing to the unmasked track. Otherwise a finished beam's
        # normal score drops to ~neginf when it is forced off its EOS column.
        normal_scores = torch.where(finished_mask, all_on_eos, normal_scores)

        # 4.4) Add the running beam scores.
        scores = scores.view(B, num_beams, vocab_size) + beam_scores.reshape(B, num_beams, 1)
        normal_scores = (normal_scores.view(B, num_beams, vocab_size)
                         + normal_beam_scores.reshape(B, num_beams, 1))

        # 4.5) Pick candidates.
        if step == 0:
            # All beams hold the same prompt, so expand from beam 0 only to avoid
            # num_beams identical hypotheses.
            next_scores, next_token_ids = scores[:, 0].topk(k=num_beams, dim=-1)  # (B, num_beams)
            next_beam_ids = (torch.arange(num_beams, device=next_token_ids.device)
                             .unsqueeze(0).expand(B, -1))
        else:
            flat = scores.view(B, num_beams * vocab_size)
            n_allowed = int((flat > allowed_threshold).sum(dim=1).min())
            if n_allowed == 0:
                raise RuntimeError("DFA processor masked every candidate continuation")
            k = min(n_allowed, num_beams)
            next_scores, topk_ids = flat.topk(k, dim=1)  # (B, k)
            if k < num_beams:
                # Fewer allowed continuations than beams: tile the best ones so the
                # scorer still receives num_beams candidates per row.
                reps = -(-num_beams // k)  # ceil division
                next_scores = next_scores.repeat(1, reps)[:, :num_beams]
                topk_ids = topk_ids.repeat(1, reps)[:, :num_beams]
            next_beam_ids = topk_ids // vocab_size  # (B, num_beams)
            next_token_ids = topk_ids % vocab_size  # (B, num_beams)

        # 4.6) Pass through the beam scorer.
        beam_out = beam_scorer.process(
            input_ids=inp.view(B * num_beams, cur_len),
            next_scores=next_scores,
            next_tokens=next_token_ids,
            next_indices=next_beam_ids,
        )
        beam_scores = beam_out['next_beam_scores']       # (B*num_beams,)
        next_beam_ids = beam_out['next_beam_indices']    # (B*num_beams,)
        next_token_ids = beam_out['next_beam_tokens']    # (B*num_beams,)

        # The unmasked track follows the same (beam, token) choices.
        flat_normal = normal_scores.view(B * num_beams, vocab_size)
        normal_beam_scores = flat_normal[next_beam_ids, next_token_ids]

        # 4.7) Update the DFA states.
        dfa_processor.update_beam_states(next_beam_ids, next_token_ids)

        # 4.8) Reorder surviving prefixes and append the chosen tokens.
        sequences = torch.cat([inp[next_beam_ids], next_token_ids.unsqueeze(-1)], dim=-1)
        beams = [{'input_ids': row} for row in sequences]
        cur_len += 1

        if beam_scorer.is_done:
            break

    if sequences is None:
        raise RuntimeError("generation stopped before any token was selected "
                           "(all beams were already finished at step 0)")

    # 5) Finalize.
    final_beam_scores = normal_beam_scores if final_normal else beam_scores
    return beam_scorer.finalize(
        input_ids=sequences,
        final_beam_scores=final_beam_scores,
        final_beam_tokens=next_token_ids,
        final_beam_indices=next_beam_ids,
        max_length=cur_len,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )


