import re
import math
import torch
import random
import numpy as np
import torch.nn.functional as F
from math_verify import parse, verify, ExprExtractionConfig
from trl.models.utils import unwrap_model_for_generation
from trl.core import PPODecorators


def add_gumbel_noise(logits, temperature):
    '''
    Perturb ``logits`` so that taking an argmax of the result draws a tempered sample.

    Computes ``exp(logits) / (-log u) ** temperature`` with ``u`` drawn uniformly in [0, 1)
    for every element. Since ``-log u`` is Exp(1)-distributed, the argmax of this quantity
    equals the argmax of ``logits / temperature`` plus Gumbel noise (the Gumbel-max trick).

    The work is done in float64: according to arXiv:2409.02908, low-precision Gumbel-max
    for masked diffusion models improves perplexity but reduces generation quality.
    A ``temperature`` of 0 returns ``logits`` itself, untouched (greedy decoding).
    '''
    if temperature != 0:
        wide = logits.to(torch.float64)
        uniform = torch.rand_like(wide, dtype=torch.float64)
        return wide.exp() / (-torch.log(uniform)) ** temperature
    return logits


def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many masked tokens to reveal at each of ``steps`` reverse-process steps.

    The reverse process discretizes [0, 1] into ``steps`` equal intervals, and LLaDA uses a
    linear noise schedule (Eq. (8) of the paper), so every step should reveal about the same
    number of tokens. For a row with ``n`` masked positions, each step gets ``n // steps``
    tokens and the first ``n % steps`` steps get one extra, so each row sums to ``n``.

    Args:
        mask_index: boolean tensor whose dim 1 is the position axis (the callers in this
            file pass shape (batch, length)); True marks a still-masked position.
        steps: number of steps to spread the masked positions across.

    Returns:
        int64 tensor of shape (batch, steps) on ``mask_index.device``.
    '''
    per_row = mask_index.sum(dim=1, keepdim=True)  # (batch, 1) masked-position counts
    quotient = per_row // steps
    leftover = per_row % steps

    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
    # One extra token for each of the first `leftover` steps of every row.
    bonus = (step_ids < leftover).to(torch.int64)

    schedule = torch.zeros(per_row.size(0), steps, device=mask_index.device, dtype=torch.int64)
    return schedule + quotient + bonus


def format_reward(text_list, **kwargs):
    '''
    Score each completion on whether it follows the think/answer template.

    A completion earns 1.0 when the whole text (anchored per line, with ``.`` matching
    newlines) has the form ``<think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>``,
    and 0.0 otherwise. Extra keyword arguments are accepted and ignored.

    Returns:
        A pair ``(scores, matches)``. ``scores`` is a list of floats. ``matches`` holds, per
        completion, the list of captured answer bodies from ``re.findall``.
    '''
    answer_pattern = re.compile(
        r"^<think>\n.*?\n</think>\n<answer>\n(.*?)\n</answer>$",
        re.DOTALL | re.MULTILINE,
    )
    all_matches = [answer_pattern.findall(text) for text in text_list]
    scores = [float(bool(found)) for found in all_matches]
    return scores, all_matches


def _masked_token_prob(model, x, idx, mask_id):
    '''
    Return the probability the model gives the token at ``x[idx]`` after re-masking only
    that position. ``idx`` is a ``(batch, position)`` tuple. The caller in this file runs
    this under ``torch.no_grad()``.
    '''
    x_temp = x.clone()
    x_temp[idx] = mask_id
    # Softmax over this position's vocabulary row only: same value as softmaxing the
    # whole logits tensor and indexing afterwards, but far cheaper.
    return F.softmax(model(x_temp).logits[idx], dim=-1)[x[idx]].item()


def llada_generating(model, x, prompt, prompt_index, mask_id,
                     ground_truth, start_step, end_step, tokenizer,
                     steps, block_length, steps_per_block,
                     steps_to_calculate_correlation_reward=None,
                     cfg_scale=0., remasking='low_confidence',
                     ppl_scorer=None, temperature=.5,
                     reward_list=None,
                     ):
    '''
    Run LLaDA reverse-diffusion steps ``[start_step, end_step)`` on ``x`` and score the result.

    Steps are counted globally across blocks: step ``s`` falls in block ``s // steps_per_block``.
    At each step the model predicts every masked position, and the most confident
    predictions inside the current block are committed. The number committed follows
    ``get_num_transfer_tokens``.

    Args:
        model: called as ``model(ids).logits``; the logits are indexed as
            (batch, position, vocab).
        x: token ids, prompt followed by the partially masked response (presumably shape
            (1, prompt_len + gen_len) — TODO confirm for batch > 1). It is not modified in
            place; an updated copy is returned.
        prompt: tensor whose ``shape[1]`` is the prompt length.
        prompt_index: boolean mask of prompt positions; only used for classifier-free guidance.
        mask_id: token id of the mask token.
        ground_truth: indexable whose element 0 is the reference answer string, or None.
        start_step, end_step: global step range to run.
        tokenizer: used to ``batch_decode`` the response (only batch row 0 is scored).
        steps: unused here; kept for interface compatibility.
        block_length: number of positions per semi-autoregressive block.
        steps_per_block: denoising steps per block.
        steps_to_calculate_correlation_reward: 0-based step indices, local to this call, at
            which the correlation reward is computed. Defaults to none.
        cfg_scale: classifier-free guidance scale; 0 disables guidance.
        remasking: 'low_confidence' (commit the highest softmax probability) or 'random'.
        ppl_scorer: optional object with ``get_perplexity(list_of_str)``.
        temperature: Gumbel sampling temperature (0 = greedy).
        reward_list: reward names to compute, drawn from "correlation", "format",
            "accuracy" and "ppl". Defaults to all four.

    Returns:
        ``(x, group_logprobs, reward)``:
        - ``group_logprobs`` is the summed log-probability (under the possibly
          CFG-adjusted logits) of every token committed during these steps, or the int 0
          if none was committed.
        - ``reward`` is a dict with keys final_response, final_reward, correlation_reward,
          ppl_reward, format_reward and accuracy_reward; a disabled reward is None.

    Raises:
        NotImplementedError: for an unknown ``remasking`` strategy.
    '''
    if steps_to_calculate_correlation_reward is None:
        steps_to_calculate_correlation_reward = []
    if reward_list is None:
        reward_list = ["correlation", "format", "accuracy", "ppl"]

    prompt_len = prompt.shape[1]
    generating_step = 0  # step counter local to this call
    group_correlation = []
    group_logprobs = 0

    start_block, start_block_start_step = divmod(start_step, steps_per_block)
    end_block, end_block_end_step = divmod(end_step, steps_per_block)
    if end_block_end_step == 0:
        # end_step lies on a block boundary: finish the previous block instead of
        # visiting an empty range of the next one.
        end_block -= 1
        end_block_end_step = steps_per_block

    for num_block in range(start_block, end_block + 1):
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length
        block_start_step = start_block_start_step if num_block == start_block else 0
        block_end_step = end_block_end_step if num_block == end_block else steps_per_block

        block_mask_index = (x[:, block_start:block_end] == mask_id)
        # Spread the block's *remaining* masks over its *remaining* steps. Entering at step 0,
        # this is the usual schedule. Resuming mid-block (a group boundary inside a block), it
        # equals the tail of the uninterrupted schedule. Spreading the remaining masks over all
        # `steps_per_block` steps front-loads them, leaves later steps with nothing to commit,
        # and the block never finishes.
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block - block_start_step)

        for i in range(block_start_step, block_end_step):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Classifier-free guidance: an extra pass with the prompt masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                logits = model(torch.cat([x, un_x], dim=0)).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never commit positions belonging to later blocks.
            x0_p[:, block_end:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i - block_start_step])
                transfer_index[j, select_index] = True

            # Write into a fresh copy. The previous `x` was fed to the model and may be saved
            # by autograd (e.g. embedding indices), so mutating it in place could break
            # backward. This also leaves the caller's tensor untouched.
            x = x.clone()
            x[transfer_index] = x0[transfer_index]

            # Log-probability of the tokens committed at this step. Use one log_softmax per
            # step, not one per token: every full-vocabulary result stays alive in the
            # autograd graph, so per-token calls scaled memory with the tokens committed.
            if transfer_index.any():
                log_probs = F.log_softmax(logits, dim=-1)
                committed = x[transfer_index]
                group_logprobs = group_logprobs + log_probs[transfer_index].gather(
                    -1, committed.unsqueeze(-1)).sum()

            transfer_elements_idx = torch.nonzero(transfer_index)
            if ("correlation" in reward_list
                    and len(transfer_elements_idx) >= 2
                    and generating_step in steps_to_calculate_correlation_reward):
                # Correlation reward: how strongly the model still predicts a committed token
                # when only that position is re-masked.
                with torch.no_grad():
                    if len(transfer_elements_idx) == 2:
                        # Only the last committed position (row-major order) is scored.
                        correlation_reward = _masked_token_prob(
                            model, x, tuple(transfer_elements_idx[-1].tolist()), mask_id)
                    else:
                        correlation_reward = np.mean([
                            _masked_token_prob(model, x, tuple(idx.tolist()), mask_id)
                            for idx in transfer_elements_idx
                        ])
                group_correlation.append(correlation_reward)

            generating_step += 1

    group_correlation_reward = None
    if "correlation" in reward_list:
        group_correlation_reward = np.mean(group_correlation) if group_correlation else 0

    res_str = tokenizer.batch_decode(x[:, prompt_len:], skip_special_tokens=True)[0]

    group_ppl_reward = None
    if "ppl" in reward_list:
        with torch.no_grad():
            if len(res_str) > 0 and ppl_scorer:
                # Map perplexity p to (100 - p) / 100, floored at 0.
                ppl_reward = ppl_scorer.get_perplexity([res_str])
                ppl_reward = max(0, (100 - ppl_reward[0]) / 100)
            else:
                ppl_reward = 0
        group_ppl_reward = ppl_reward

    group_format_reward, group_accuracy_reward = None, None
    if "format" in reward_list or "accuracy" in reward_list:
        format_scores, all_matches = format_reward([res_str])
        group_format_reward, matches = format_scores[0], all_matches[0]

        if "accuracy" in reward_list and len(matches) > 0 and ground_truth is not None:
            gt = ground_truth[0]
            gold = parse(gt, extraction_config=[ExprExtractionConfig()])
            answer = parse(matches[-1], extraction_config=[ExprExtractionConfig()])
            group_accuracy_reward = int(verify(gold, answer) or (gt in matches[-1]))

    if "format" not in reward_list:
        # The format check also runs for "accuracy" alone; do not report it in that case.
        group_format_reward = None

    group_final_reward = (
        (group_correlation_reward if group_correlation_reward else 0)
        + (group_format_reward if group_format_reward else 0)
        + (group_accuracy_reward if group_accuracy_reward else 0)
        + (group_ppl_reward if (group_ppl_reward is not None and not math.isnan(group_ppl_reward)) else 0)
    )
    assert not math.isnan(group_final_reward)

    reward = {
        "final_response": res_str,
        "final_reward": group_final_reward,
        "correlation_reward": group_correlation_reward,
        "ppl_reward": group_ppl_reward,
        "format_reward": group_format_reward,
        "accuracy_reward": group_accuracy_reward,
    }

    return x, group_logprobs, reward

def _relative_group_reward(current, previous):
    '''
    Reward for a policy-gradient update when an earlier group's reward exists.

    Returns the dense terms of ``current`` (correlation + perplexity) plus the *change* in
    the sparse terms (format + accuracy) relative to ``previous``. None/0 entries count as
    0, and a None or NaN perplexity reward also counts as 0.
    '''
    def or_zero(value):
        return value if value else 0

    ppl = current["ppl_reward"]
    ppl = ppl if (ppl is not None and not math.isnan(ppl)) else 0
    sparse_now = or_zero(current["format_reward"]) + or_zero(current["accuracy_reward"])
    sparse_prev = or_zero(previous["format_reward"]) + or_zero(previous["accuracy_reward"])
    return or_zero(current["correlation_reward"]) + ppl + sparse_now - sparse_prev


def _summarize_rollout_stats(loss_list, model_reward_list):
    '''
    Collapse per-group losses and reward dicts into scalar tensors.

    Each key except "final_response" is averaged with ``np.nanmean``. Missing values
    (None) are recorded as NaN and therefore ignored; genuine zero rewards are kept. When no
    group produced a loss, every statistic is NaN.
    '''
    if not loss_list:
        loss_list = [math.nan]
        model_reward_list = [{
            "final_reward": math.nan,
            "correlation_reward": math.nan,
            "ppl_reward": math.nan,
            "format_reward": math.nan,
            "accuracy_reward": math.nan,
        }]

    stats = {"loss": loss_list}
    stats.update({k: [] for k in model_reward_list[0] if k != "final_response"})
    for item in model_reward_list:
        for k, v in item.items():
            if k != "final_response":
                stats[k].append(math.nan if v is None else v)

    return {k: torch.tensor(np.nanmean(v)) for k, v in stats.items()}


@PPODecorators.empty_device_cache()
def generate_with_intermediates(model, prompt, tokenizer,
                                accelerator, config, optimizer, model_params,
                                ground_truth=None,
                                num_steps_per_group=4, sampling_size=4,
                                steps=128, gen_length=128, block_length=128, temperature=0.,
                                cfg_scale=0., remasking='low_confidence', mask_id=126336, steps_per_group=4,
                                sample_ratio_calculating_correlation_inside_response=100,
                                selected_groups_to_update_model=None,
                                num_of_groups_to_accumulate_grad=1,
                                ppl_scorer=None,
                                rejection_sampling=False,
                                reward_list=None):
    '''
    Generate one response group by group and update the model on selected groups.

    The ``steps`` denoising steps are cut into groups of ``num_steps_per_group``.

    For a group listed in ``selected_groups_to_update_model``, ``sampling_size``
    continuations are sampled in train mode and the one with the highest ``final_reward``
    is kept. Its loss is then back-propagated:
    - ``-log_prob`` with ``rejection_sampling``;
    - otherwise ``-reward * log_prob``, where ``reward`` is the best sample's
      ``final_reward`` for the first group, and afterwards its dense rewards plus the
      format/accuracy improvement over the previous group's continuation.

    Any other group advances the sequence with one continuation under ``torch.no_grad()``.
    Generation stops after the last selected group. The optimizer steps after every
    ``num_of_groups_to_accumulate_grad`` back-propagated groups and at the last selected
    group.

    ``selected_groups_to_update_model`` is assumed to be in increasing order; its last entry
    bounds the rollout. ``steps_per_group`` is unused (``num_steps_per_group`` is used).
    ``sample_ratio_calculating_correlation_inside_response`` is the fraction of each
    group's steps that get a correlation reward; it is clamped to the group size.

    Returns:
        dict of scalar tensors: mean "loss" plus the mean of each reward component over the
        back-propagated groups (all NaN when no group produced a gradient).
    '''
    if selected_groups_to_update_model is None:
        selected_groups_to_update_model = []
    if reward_list is None:
        reward_list = ["correlation", "format", "accuracy", "ppl"]
    # With no selected groups there is nothing to train on; -1 makes the loop exit at once.
    last_update_group = selected_groups_to_update_model[-1] if selected_groups_to_update_model else -1

    loss_list = []
    model_reward_list = []

    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    assert steps % num_steps_per_group == 0
    group_number = steps // num_steps_per_group

    # random.sample raises ValueError if asked for more items than the population holds
    # (the default ratio of 100 would), so clamp to [0, num_steps_per_group].
    num_correlation_steps = min(
        num_steps_per_group,
        max(0, int(num_steps_per_group * sample_ratio_calculating_correlation_inside_response)),
    )

    accumulated_grad = 0
    reward_previous_group = None
    for i in range(group_number):
        accelerator.wait_for_everyone()
        if i > last_update_group:
            break
        selected_steps_in_group_for_correlation = random.sample(
            list(range(num_steps_per_group)), num_correlation_steps)
        group_start, group_end = i * num_steps_per_group, (i + 1) * num_steps_per_group

        if i in selected_groups_to_update_model:
            accelerator.wait_for_everyone()
            model.train()
            max_reward = max_reward_logprob = max_reward_next_x = None
            with accelerator.accumulate(model):
                for _ in range(sampling_size):
                    _x, _logprob, _reward = llada_generating(
                        model, x.clone(), prompt, prompt_index, mask_id, ground_truth,
                        group_start, group_end, tokenizer=tokenizer,
                        steps=steps, block_length=block_length, steps_per_block=steps_per_block,
                        steps_to_calculate_correlation_reward=selected_steps_in_group_for_correlation,
                        cfg_scale=cfg_scale, remasking=remasking, ppl_scorer=ppl_scorer, temperature=temperature,
                        reward_list=reward_list,
                    )
                    if max_reward is None or max_reward["final_reward"] < _reward["final_reward"]:
                        max_reward = _reward
                        max_reward_logprob = _logprob
                        max_reward_next_x = _x

                if rejection_sampling:
                    loss = - max_reward_logprob
                elif reward_previous_group is not None:
                    loss = - _relative_group_reward(max_reward, reward_previous_group) * max_reward_logprob
                else:
                    loss = - max_reward["final_reward"] * max_reward_logprob

                reward_previous_group = max_reward

                # The log-prob is the int 0 when no token was committed: nothing to backprop.
                if isinstance(max_reward_logprob, torch.Tensor):
                    loss_list.append(loss.item())
                    model_reward_list.append(max_reward)
                    accelerator.backward(loss)
                    accumulated_grad += 1

                    if config.max_grad_norm is not None and accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model_params, config.max_grad_norm)

            x = max_reward_next_x
        else:
            model.eval()
            with torch.no_grad():
                _x, _logprob, _reward = llada_generating(
                    model, x.clone(), prompt, prompt_index, mask_id, ground_truth,
                    group_start, group_end, tokenizer=tokenizer,
                    steps=steps, block_length=block_length, steps_per_block=steps_per_block,
                    steps_to_calculate_correlation_reward=selected_steps_in_group_for_correlation,
                    cfg_scale=cfg_scale, remasking=remasking, ppl_scorer=ppl_scorer, temperature=temperature,
                    reward_list=reward_list,
                )
            reward_previous_group = _reward
            x = _x

        accelerator.wait_for_everyone()
        # Step only when gradients were actually accumulated. Previously 0 % n == 0 also
        # stepped after generation-only groups, which with zero-filled grads could still
        # move parameters through momentum or weight decay.
        if accumulated_grad > 0 and (
                accumulated_grad % num_of_groups_to_accumulate_grad == 0 or i == last_update_group):
            optimizer.step()
            optimizer.zero_grad()
            accumulated_grad = 0

        accelerator.wait_for_everyone()

    return _summarize_rollout_stats(loss_list, model_reward_list)
