import torch
import numpy as np
import torch.nn.functional as F
import json
import os
from transformers import AutoTokenizer, AutoModel


def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits so that an argmax over the result samples from the
    temperature-scaled categorical distribution (Gumbel-max trick).

    The computation is done in float64: per arXiv:2409.02908, low-precision
    Gumbel-max sampling in masked diffusion models improves perplexity but
    hurts generation quality.

    Args:
        logits: Tensor of unnormalised scores.
        temperature: Sampling temperature; 0 disables the noise entirely.

    Returns:
        ``logits`` itself when ``temperature == 0``; otherwise a float64 tensor
        ``exp(logits) / (-log U) ** temperature`` with U drawn uniformly in [0, 1).
    '''
    if temperature != 0:
        as_double = logits.to(torch.float64)
        uniform = torch.rand_like(as_double, dtype=torch.float64)
        denominator = torch.pow(-uniform.log(), temperature)
        return torch.exp(as_double) / denominator
    return logits


def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many masked tokens to unmask at each reverse-process step.

    LLaDA discretises [0, 1] into ``steps`` equal intervals and uses a linear
    noise schedule (Eq. (8)), so every step should unmask the same expected
    number of tokens. Each row's masked count is split as evenly as possible,
    and the leftover tokens are assigned one each to the earliest steps.

    Args:
        mask_index: Boolean tensor of shape (batch, length); True marks masked positions.
        steps: Number of sampling steps.

    Returns:
        int64 tensor of shape (batch, steps); each row sums to that row's masked count.
    '''
    masked_per_row = mask_index.sum(dim=1, keepdim=True)  # (batch, 1)
    per_step, leftover = masked_per_row // steps, masked_per_row % steps
    step_ids = torch.arange(steps, device=mask_index.device)  # (steps,)
    # Broadcast (1, steps) < (batch, 1): 1 for the first `leftover` steps of each row.
    extra = (step_ids.unsqueeze(0) < leftover).to(torch.int64)
    return per_step + extra


def _predict_logits(model, x, prompt_index, cfg_scale, mask_id):
    '''Return ``model(x).logits``, applying unsupervised classifier-free guidance when ``cfg_scale > 0``.'''
    if cfg_scale > 0.:
        un_x = x.clone()
        un_x[prompt_index] = mask_id  # unconditional branch: the prompt is fully masked
        logits, un_logits = torch.chunk(model(torch.cat([x, un_x], dim=0)).logits, 2, dim=0)
        return un_logits + (cfg_scale + 1) * (logits - un_logits)
    return model(x).logits


def _token_confidence(logits, x0, remasking):
    '''Score each predicted token in ``x0``; the highest-scoring masked positions are unmasked first.'''
    if remasking == 'low_confidence':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        return torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)  # (b, l)
    if remasking == 'random':
        return torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    raise NotImplementedError(remasking)


@ torch.no_grad()
def generate(model, prompt, steps_all=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', tokenizer=None, mask_id=126336, step_temp=0,
             num_block_temp=0, save_midden_step=None, x0_midden=None):
    '''
    Sample a response with LLaDA's (semi-autoregressive) masked-diffusion decoding.

    Two optional modes support collecting intermediate states:
      * save mode: stop right after global step ``save_midden_step`` and return the
        partially denoised sequence;
      * resume mode: start from such a saved sequence (``x0_midden``) and finish the generation.

    Args:
        model: Mask predictor, called as ``model(x).logits``.
        prompt: Token ids of shape (1, L).
        steps_all: Total sampling steps over all blocks; must be divisible by the number of blocks.
        gen_length: Number of tokens to generate.
        block_length: Block length; must divide gen_length. Smaller than gen_length means
            semi-autoregressive remasking.
        temperature: Gumbel-noise sampling temperature (0 = greedy argmax).
        cfg_scale: Unsupervised classifier-free guidance scale (0 disables guidance).
        remasking: Remasking strategy, 'low_confidence' or 'random'.
        tokenizer: Unused; kept for API compatibility.
        mask_id: Token id of [MASK] (126336 for LLaDA).
        step_temp: Resume mode only - in-block step index after which ``x0_midden`` was saved.
        num_block_temp: Resume mode only - block index in which ``x0_midden`` was saved.
        save_midden_step: 0-based global step after which to stop and return the intermediate
            state. None (default) or any value >= steps_all disables save mode.
        x0_midden: If not None, enables resume mode. This is the sequence returned as
            "input_ids_w_mask" by a save-mode call with the same prompt and settings. It is
            copied, never modified.

    Returns:
        The full (1, L + gen_length) token tensor. In save mode, returns instead a dict with
        "input_ids" (the prompt), "input_ids_wo_mask" (argmax prediction for every position at
        the saved step), "input_ids_w_mask" (sequence after the saved step), "temp_b" (block index)
        and "temp_s" (in-block step index).
    '''
    prompt_len = prompt.shape[1]
    x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_len] = prompt.clone()
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps_all % num_blocks == 0
    steps = steps_all // num_blocks

    save_mode = save_midden_step is not None and save_midden_step < steps_all
    if save_mode:  # save for get rewards
        temp_b, temp_s = divmod(save_midden_step, steps)

    resume = x0_midden is not None
    if resume and not (0 <= num_block_temp < num_blocks and 0 <= step_temp < steps):
        raise ValueError(
            f"resume position (block {num_block_temp}, step {step_temp}) is outside "
            f"{num_blocks} blocks x {steps} steps")

    for num_block in range(num_blocks):
        if resume and num_block < num_block_temp:
            continue  # these blocks are already fully decoded inside x0_midden
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length

        # Build the schedule from the block's mask count at block start. When resuming, x still
        # holds the untouched all-mask block here, so the schedule matches the one of the saved run.
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        first_step = 0
        if resume and num_block == num_block_temp:
            # Copy so the caller's tensor is not modified by the in-place updates below.
            x = x0_midden.to(x.device).clone()
            first_step = step_temp + 1  # the saved state already includes step `step_temp`

        for i in range(first_step, steps):
            mask_index = (x == mask_id)
            logits = _predict_logits(model, x, prompt_index, cfg_scale, mask_id)
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # (b, l)
            x_wo_mask = x0.clone()  # raw prediction for every position, before re-inserting known tokens

            x0_p = _token_confidence(logits, x0, remasking)
            x0_p[:, block_end:] = -np.inf  # never unmask positions of later blocks

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

            if save_mode and num_block == temp_b and i == temp_s:
                return {"input_ids": prompt, "input_ids_wo_mask": x_wo_mask, "input_ids_w_mask": x,
                        "temp_b": temp_b, "temp_s": temp_s}

    return x


def main_for_get_midden_step(model_path='./models/llada_test', data_path="./data/alpaca_data_en_52k.json",
                             save_dir="midden_features", num_examples=22, save_step_tag=42):
    '''
    Stage 1: run generation on Alpaca prompts and save a randomly chosen intermediate state.

    For each of the first ``num_examples`` instructions in ``data_path``, ``generate`` is stopped
    after a random global step in [0, steps_all // 2 - 1] (only the first half of the schedule is
    sampled). The returned dict, with its tensors moved to CPU, is collected.

    Collected states are flushed to ``save_dir/for_save_midden_{first}-{last}.pt`` whenever the
    example index is a non-zero multiple of ``save_step_tag``. Any remainder is flushed once at
    the end. The decoded raw predictions ("input_ids_wo_mask") are written to results_text.json.
    '''
    import random

    device = 'cuda'
    steps_all = 128
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    with open(data_path, "r", encoding="utf-8") as scr:
        datasets = json.load(scr)
    # generate() builds a batch of exactly one sequence, so each "batch" holds a single conversation.
    batchs_prompt = [
        [[{"role": "user", "content": item["instruction"] + "\n" + item["input"]}]]
        for item in datasets[:num_examples]
    ]

    os.makedirs(save_dir, exist_ok=True)

    def _flush(states, last_idx):
        # Save the collected states under a name spanning their example indices; skip empty chunks.
        if states:
            first_idx = last_idx - len(states) + 1
            torch.save(states, os.path.join(save_dir, f"for_save_midden_{first_idx}-{last_idx}.pt"))

    for_save_midden = []
    results_text = []
    idx = -1
    for idx, batch in enumerate(batchs_prompt):
        prompt = tokenizer.apply_chat_template(batch, add_generation_prompt=True, tokenize=False)
        input_ids = tokenizer(prompt, return_tensors="pt")['input_ids'].to(device)

        save_midden_step_random = random.randint(0, steps_all // 2 - 1)
        out = generate(model, input_ids, steps_all=steps_all, gen_length=128, block_length=32, temperature=0.2,
                       cfg_scale=0., remasking='low_confidence', tokenizer=tokenizer,
                       save_midden_step=save_midden_step_random)
        results_text.append(tokenizer.batch_decode(out["input_ids_wo_mask"], skip_special_tokens=True)[0])
        for_save_midden.append({k: v.cpu() if torch.is_tensor(v) else v for k, v in out.items()})
        if idx % save_step_tag == 0 and idx != 0:
            _flush(for_save_midden, idx)
            for_save_midden = []

    with open("results_text.json", "w", encoding="utf-8") as f:
        json.dump(results_text, f, ensure_ascii=False, indent=2)
    _flush(for_save_midden, idx)

def main_for_get_end(model_path='./models/llada_test', data_midden_path="midden_features", sample_number=4,
                     output_path="result_all.json"):
    '''
    Stage 2: finish generation from saved intermediate states.

    Loads every ``*.pt`` file in ``data_midden_path``, in sorted (lexicographic) file-name order.
    The default matches the directory that main_for_get_midden_step writes to. For each saved
    state, it samples ``sample_number`` completions by resuming ``generate`` from
    "input_ids_w_mask" at ("temp_b", "temp_s"). The results are written to ``output_path`` as a
    list of {"instruction": decoded prompt, "candidate_res": list of decoded completions}.
    '''
    device = 'cuda'
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    result_all = []
    file_names = sorted(name for name in os.listdir(data_midden_path) if name.endswith(".pt"))
    for file_name in file_names:
        # NOTE: torch.load unpickles; only load files produced by main_for_get_midden_step.
        data = torch.load(os.path.join(data_midden_path, file_name))
        for data_one in data:
            temp_b = data_one["temp_b"]
            temp_s = data_one["temp_s"]
            input_ids = data_one["input_ids"].to(device)
            input_ids_w_mask = data_one["input_ids_w_mask"].to(device)
            res = []
            for _ in range(sample_number):
                # Defensive copy: every candidate must start from the same saved state, even if
                # generate() writes into the tensor it resumes from.
                out = generate(model, input_ids, steps_all=128, gen_length=128, block_length=32, temperature=0.2,
                               cfg_scale=0., remasking='low_confidence', tokenizer=tokenizer, save_midden_step=128,
                               step_temp=temp_s, num_block_temp=temp_b, x0_midden=input_ids_w_mask.clone())
                res.append(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
            # The prompt positions are never rewritten by generate(), so decode the prompt directly.
            instruction = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
            result_all.append(
                {
                    "instruction": instruction,
                    "candidate_res": res
                }
            )

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result_all, f, ensure_ascii=False, indent=2)

if __name__ == '__main__':
    # Usage
    # ========================================================================================
    # Stage 1 - main_for_get_midden_step(): capture a randomly chosen intermediate state.
    #   generate() is called with save_midden_step < steps_all; save_midden_step is the random
    #   global step after which generation stops. Each saved dict contains:
    #   - "input_ids": the prompt token ids
    #   - "input_ids_wo_mask": the argmax of the (Gumbel-noised) logits for every position at
    #     that step, i.e. the full response the model predicts at the intermediate step
    #   - "input_ids_w_mask": the partially unmasked sequence after that step; this is the state
    #     generation resumes from
    #   - "temp_b": block index of the saved step
    #   - "temp_s": step index within that block
    # ========================================================================================
    # Stage 2 - main_for_get_end(): finish generation from each saved intermediate state.
    #   generate() is called with save_midden_step = steps_all (no early stop),
    #   step_temp = temp_s, num_block_temp = temp_b and x0_midden = input_ids_w_mask.
    #   Each record written to result_all.json contains:
    #   - "instruction": the decoded prompt
    #   - "candidate_res": the decoded completions sampled from that intermediate state
    # NOTE(review): confirm that the directory main_for_get_end reads from is the one
    # main_for_get_midden_step writes to.
    # ========================================================================================
    main_for_get_midden_step()
    main_for_get_end()