from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from einops import rearrange
from torch import nn
import torch.nn.functional as F
import torch
import random
import math
import json

from pathlib import Path
import sys
path_root = Path(__file__).parents[1]
sys.path.append(str(path_root))

# logits processors
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

def find_first_true_index(bool_tensor, dim = -1):
    """Return, along ``dim``, the index of the first True entry.

    This is the number of leading False entries. Where a slice has no True
    entry, the result equals that slice's length along ``dim``.
    """
    running_count = torch.cumsum(bool_tensor, dim=dim)
    # Positions before the first True still have a running count of zero.
    before_first_true = running_count.eq(0)
    return before_first_true.sum(dim=dim)

def _sample_prompt_drafts(prompt_token_ids, count, vocab_size, device, batch=1):
    """Draw ``count`` draft tokens uniformly (with replacement) from ``prompt_token_ids``.

    Returns ``(q_sampled, q_logits)``:

    - ``q_sampled`` is a ``(1, count)`` long tensor of the drawn ids.
    - ``q_logits`` is a ``(batch, count, vocab_size)`` tensor holding 0 at each
      drawn id and ``-inf`` elsewhere, i.e. a one-hot (0-1) draft distribution
      q(x) per position.
    """
    if count <= 0:
        # torch.cat / torch.stack reject empty lists, so build empty results explicitly.
        return (
            torch.empty((1, 0), dtype=torch.long, device=device),
            torch.empty((batch, 0, vocab_size), device=device),
        )
    sampled, logits = [], []
    for _ in range(count):
        token = torch.tensor([random.choice(prompt_token_ids)], dtype=torch.long, device=device).unsqueeze(0)
        one_hot_logits = torch.full((batch, vocab_size), float('-inf'), device=device)
        one_hot_logits.scatter_(1, token, 0.0)
        sampled.append(token)
        logits.append(one_hot_logits)
    return torch.cat(sampled, dim=1), torch.stack(logits, dim=-2)


def _build_logits_processors(temperature, top_p, top_k, repetition_penalty):
    """Build the sampling LogitsProcessorList, skipping every knob that is None or a no-op."""
    processors = LogitsProcessorList()
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    if temperature is not None and temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        processors.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
    if top_p is not None and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
    return processors


@torch.inference_mode()
def diffusion_decoding(
    model,
    tokenizer,
    input_ids,
    attention_mask,
    n_token_seq_len,
    temperature = 0.9,
    top_p = 0.9,
    top_k = 20,
    repetition_penalty = 1.05,
    lenience = 1.,
    accept_threshold = 0.8,
    confidence_threshold = 0.4,
    ):
    """Generate blocks of ``n_token_seq_len`` tokens with draft-and-verify ("diffusion") decoding.

    Draft initialization: each block starts as ``n_token_seq_len`` draft tokens
    drawn uniformly from ``input_ids``, with a one-hot draft distribution q(x).

    Verification: every iteration runs one forward pass over the prompt plus
    all blocks (each block's accepted tokens followed by its remaining drafts).
    For each unfinished block, the model's processed distribution p(x) is
    compared with q(x) at each draft position.

    - Drafts are accepted left to right while ``r <= p / q`` (``r ~ U(0, 1)``)
      and ``p >= accept_threshold``. The first failure stops acceptance.
    - If nothing was accepted, one token is sampled from p so the block always
      progresses.
    - The rejected tail is redrafted by sampling from p, and p's logits become
      the block's new q.

    Block spawning: after each iteration in which the last block was processed,
    the max probability of the token predicted right after that block is
    recorded. Once more than two values have been recorded and the latest two
    exceed ``confidence_threshold``, a new block is appended. It is seeded with
    that argmax token, followed by fresh prompt-sampled drafts.

    Only batch size 1 is supported (``batch`` is fixed to 1).

    Args:
        model: causal LM, called as ``model(ids, mask).logits``.
        tokenizer: only ``len(tokenizer)`` is used, as the vocabulary size.
        input_ids: prompt token ids; presumably shape ``(1, prompt_len)`` — TODO confirm.
        attention_mask: the sum of its first row is used as the prompt length
            (assumes no padding — TODO confirm).
        n_token_seq_len: number of tokens per block.
        temperature, top_p, top_k, repetition_penalty: sampling knobs applied
            to both p and q through transformers logits processors. Each knob
            is skipped when it is None or at its no-op value.
        lenience: multiplier on q(x) in the acceptance ratio ``p / q``.
        accept_threshold: minimum p(x) for a draft token to be accepted.
        confidence_threshold: confidence required to spawn a new block.

    Returns:
        Long tensor of shape ``(1, L)``: the prompt followed by every block's
        accepted tokens. When the last block finishes without a rejection, one
        extra token sampled from p is appended after it.
    """
    batch, prompt_len, out = 1, int(torch.sum(attention_mask[0])), input_ids.clone()
    vocab_size = len(tokenizer)
    # Hoisted: the pool that random draft tokens are drawn from never changes here.
    prompt_token_ids = input_ids[0].tolist()

    ### Initialization draft distribution q(x) with 0-1 distribution from prompt
    q_sampled, q_logits = _sample_prompt_drafts(prompt_token_ids, n_token_seq_len, vocab_size, model.device, batch)
    out = torch.cat((out, q_sampled), dim=1)

    ### Initialize LogitsProcessor with GenerationConfig
    logits_processors = _build_logits_processors(temperature, top_p, top_k, repetition_penalty)

    ### Diffusion decoding for multi-block
    total_accepted_all_blocks = [0]
    q_logits_all_blocks = [q_logits]
    q_sampled_all_blocks = [q_sampled]
    # dtype=long so concatenating this placeholder can never promote token ids to float.
    out_accepted_all_blocks = [torch.empty((batch, 0), dtype=torch.long, device=model.device)]
    iteration_all_blocks = [0]
    num_blocks = 1
    confidence_of_first_token = []
    while any(t < n_token_seq_len for t in total_accepted_all_blocks):

        ### verify and speculate with larger network within a forward pass
        out_attention_mask = torch.full_like(out, 1).to(model.device)
        logits = model(out, out_attention_mask).logits
        # Next-token distribution right after the *last* block. Only this one describes
        # the position where a new block would be appended.
        last_block_prob_next = None

        for block_id in range(num_blocks):
            if total_accepted_all_blocks[block_id] >= n_token_seq_len:
                continue

            block_position = block_id * n_token_seq_len
            # Logits from the position before the first unaccepted token through the block's
            # last token: one prediction per remaining draft, plus one for the token after it.
            p_logits_per_block = logits[:, prompt_len + total_accepted_all_blocks[block_id] - 1 + block_position:prompt_len + block_position + n_token_seq_len, :]
            q_logits_per_block = q_logits_all_blocks[block_id]

            # only support bsz=1 now
            p_scores = logits_processors(out, p_logits_per_block.squeeze(dim=0)).unsqueeze(dim=0)
            q_scores = logits_processors(out, q_logits_per_block.squeeze(dim=0)).unsqueeze(dim=0)

            ### prob and prob of draft distribution (p(x) and q(x))
            p_prob = nn.functional.softmax(p_scores, dim=-1)[:, :, :vocab_size]
            q_prob = nn.functional.softmax(q_scores, dim=-1)[:, :, :vocab_size]

            p, prob_next = p_prob[:, :-1], p_prob[:, -1]
            if block_id == num_blocks - 1:
                last_block_prob_next = prob_next

            # Probability of each current draft token under p and (scaled) q.
            p = p.gather(-1, q_sampled_all_blocks[block_id].unsqueeze(dim=-1))
            q = q_prob.gather(-1, q_sampled_all_blocks[block_id].unsqueeze(dim=-1)) * lenience

            p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)]
            r = torch.zeros_like(q).float().uniform_(0, 1)
            threshold = torch.ones_like(q).float() * accept_threshold

            # Number of leading drafts that pass both the ratio test and the p threshold.
            accepted = find_first_true_index(
                    (r > (p / q)) | (p < threshold)
                )

            num_accepted = int(accepted[0])
            total_accepted_all_blocks[block_id] += num_accepted
            out_accepted_all_blocks[block_id] = out[:, prompt_len + block_position:prompt_len + block_position + total_accepted_all_blocks[block_id]]

            has_rejected = num_accepted < q.shape[1]

            ### sample the additional token to better bound the worst case
            sample_additional_token = False
            if num_accepted == 0:
                next_token = torch.multinomial(p_prob[:, num_accepted, :], num_samples=1)
                out_accepted_all_blocks[block_id] = torch.cat((out_accepted_all_blocks[block_id], next_token), dim = -1)
                total_accepted_all_blocks[block_id] += 1
                sample_additional_token = True

            if not has_rejected and all(t >= n_token_seq_len for t in total_accepted_all_blocks):
                # prob_next predicts the token right after this block. That position only
                # extends the output for the last block; for an earlier block it is
                # already occupied by the next block's first token.
                if block_id == num_blocks - 1:
                    next_token = torch.multinomial(prob_next, num_samples=1)
                    out_accepted_all_blocks[block_id] = torch.cat((out_accepted_all_blocks[block_id], next_token), dim = -1)
                    total_accepted_all_blocks[block_id] += 1
                out = out[:, :prompt_len]
                for accepted_block in out_accepted_all_blocks:
                    out = torch.cat((out, accepted_block), dim=-1)
                return out

            # Redraft only blocks that still have positions left; a block completed by the
            # additional token has nothing left to resample.
            if has_rejected and total_accepted_all_blocks[block_id] < n_token_seq_len:
                ### update q(x) with self-speculated p(x) and sample new drafts tokens
                start = num_accepted + 1 if sample_additional_token else num_accepted
                q_logits_all_blocks[block_id] = p_logits_per_block[:, start:-1, :]
                q_probs = p_prob[:, start:-1, :]
                q_sampled_all_blocks[block_id] = torch.multinomial(q_probs.squeeze(dim=0), num_samples=1).reshape(1, -1)

            iteration_all_blocks[block_id] += 1
            print(f'Block id: {block_id}, Iteration step: {iteration_all_blocks[block_id]}, Accepted tokens: {total_accepted_all_blocks[block_id]}')

        ### Confidence of the token after the last block; add a new block when confident
        if last_block_prob_next is not None:
            confidence_of_first_token.append(torch.max(last_block_prob_next[0]).item())
            if len(confidence_of_first_token) > 2 and all(c > confidence_threshold for c in confidence_of_first_token[-2:]):
                num_blocks += 1
                confidence_of_first_token = []
                print(f'-----------New block added. Currently {num_blocks} blocks in decoding-----------')
                # The new block's first token is accepted immediately; the rest are random drafts.
                top_token = torch.argmax(last_block_prob_next[0]).reshape(1, 1).to(dtype=torch.long, device=model.device)
                new_drafts, new_draft_logits = _sample_prompt_drafts(prompt_token_ids, n_token_seq_len - 1, vocab_size, model.device, batch)
                total_accepted_all_blocks.append(1)
                q_logits_all_blocks.append(new_draft_logits)
                q_sampled_all_blocks.append(new_drafts)
                out_accepted_all_blocks.append(top_token)
                iteration_all_blocks.append(0)

        # Rebuild the sequence for the next forward pass: prompt, then for each block
        # its accepted tokens followed by its outstanding drafts (if any).
        out = out[:, :prompt_len]
        for block_id in range(num_blocks):
            if out_accepted_all_blocks[block_id].numel() > 0:
                out = torch.cat((out, out_accepted_all_blocks[block_id]), dim=-1)
            if total_accepted_all_blocks[block_id] < n_token_seq_len:
                out = torch.cat((out, q_sampled_all_blocks[block_id]), dim=-1)

    return out

### Load dataset...
with open("/data/phd/kousiqi/kousiqi/CLLM2/data/bucketed_prompts_opencodeinstruct/OpenCodeInstruct_bucketed/bucket_0003_avg255_min250_max260.json", 'r', encoding='utf-8') as f:
    data = json.load(f)

model_name = "/data/phd/kousiqi/kousiqi/ckpts/shiftedattn-8-31-coder-7B-ntok16_soft_ce_oci_datav1_59k_stp_ar_10_cyclic_prog_noise_all_lr5e-6"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='cuda',
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
)
tokenizer = AutoTokenizer.from_pretrained("/data/phd/kousiqi/kousiqi/ckpts/OpenThinker2-7B")

prompt = data[8000]

messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print(f'Prompt from user: {text}')
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
input_ids = model_inputs['input_ids']
attention_mask = torch.full_like(input_ids, 1).to(model.device)

### Decoding with Diffusion decoding
# Repeatedly decode until a stop token appears among the generated tokens. 151645 is
# presumably the Qwen chat-template `<|im_end|>` id — confirm against the tokenizer.
prompt_lengths = input_ids.shape[1]
generated_part = input_ids[0, prompt_lengths:]
while not ((generated_part == tokenizer.eos_token_id) | (generated_part == 151645)).any():

    generated_ids = diffusion_decoding(
        model,
        tokenizer,
        input_ids=input_ids,
        attention_mask=attention_mask,
        n_token_seq_len=16,
        temperature = 0.9,
        top_p = 0.9,
        top_k = 20,
        repetition_penalty = 1.2,
        lenience = 1.,
        accept_threshold = 0.2,
        confidence_threshold = 0.9,
        )

    input_ids = generated_ids
    attention_mask = torch.full_like(input_ids, 1).to(model.device)
    # Refresh the generated tail *after* updating input_ids so the stop check sees the
    # tokens produced this round (otherwise one extra round runs after EOS).
    generated_part = input_ids[0, prompt_lengths:]
    generated_str = ''.join(tokenizer.batch_decode(generated_ids, skip_special_tokens=False))
    print(generated_str)

# Strip the prompt from each output row before decoding the final response.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# print(f'---------Generated Answer----------')
# print(response)