import re
import sys
import math
import torch
import types
import lmppl
import random
import numpy as np
from tqdm import trange
from pathlib import Path
import torch.nn.functional as F
from math_verify import parse, verify, ExprExtractionConfig
from transformers import AutoConfig, AutoModel, AutoTokenizer


# Prepend the parent of this file's directory to sys.path so that absolute imports of
# project packages resolve (presumably `diffusion_prediction`, which is imported lazily
# inside TestTimeScaling.inference_llada_from_to) — NOTE(review): confirm the intended root.
sys.path.insert(0, Path(__file__).parent.parent.as_posix())


class TestTimeScaling:
    """Reward-guided beam search over the denoising steps of a masked-diffusion LM (LLaDA).

    Generation is split into chunks of ``search_every_steps_n`` denoising steps. After
    each chunk every surviving candidate is expanded ``n`` times, each expansion is
    scored (correlation / format / accuracy / perplexity rewards) and the ``topk``
    highest-scoring candidates are kept for the next chunk.
    """

    # Id of the [MASK] token per supported model type.
    _MASK_IDS = {
        "llada": 126336,
    }

    def __init__(self, model_path, max_input_length=None, device=None, model_to_calculate_ppl=None):
        """Load the diffusion model/tokenizer and, optionally, a perplexity scorer.

        Args:
            model_path: model id or local path passed to ``from_pretrained``.
            max_input_length: if set, prompts are truncated (from the left) to this many tokens.
            device: ``device_map`` for ``AutoModel.from_pretrained``; defaults to ``"auto"``.
            model_to_calculate_ppl: optional model name for ``lmppl.LM``. Enables the
                perplexity reward and (see ``beam_search``) the correlation reward.
        """
        self.model_path = model_path
        if device is None:
            device = "auto"
        self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map=device)

        self.max_input_length = max_input_length
        self.extra_tokenizer_config = {}
        if self.max_input_length:
            self.extra_tokenizer_config["model_max_length"] = self.max_input_length
            # drop tokens from the start of over-long prompts, keeping their end
            self.extra_tokenizer_config["truncation_side"] = "left"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, **self.extra_tokenizer_config)
        self.model = self.model.eval()
        self.model_type = self.model.config.model_type
        assert self.model_type in ["llada"]
        self.mask_id = self._MASK_IDS[self.model_type]

        self.scorer = None
        # https://github.com/asahi417/lmppl
        if model_to_calculate_ppl:
            self.scorer = lmppl.LM(model_to_calculate_ppl)

    def to(self, device):
        """Move the diffusion model to ``device``; returns ``self`` for chaining."""
        self.model.to(device)
        return self

    @staticmethod
    def _sample_reward_steps(n_steps, ratio):
        """Randomly pick which relative steps ``0..n_steps-1`` get a correlation reward.

        ``ratio`` is the fraction of steps to pick; the count is clamped to
        ``[0, n_steps]`` so that ``ratio=1`` selects every step.
        """
        n_steps = max(int(n_steps), 0)
        k = min(n_steps, max(0, int(ratio * n_steps)))
        return set(random.sample(range(n_steps), k))

    def _correlation_reward(self, _x, transfer_elements_idx):
        """Mean probability that the model re-predicts tokens decoded in the same step.

        Each just-decoded position is re-masked on its own, the model is re-run, and the
        softmax probability of the decoded token at that position is recorded. When
        exactly two tokens were decoded only the last one is probed (kept as in the
        original implementation — NOTE(review): unclear why this case is asymmetric).

        Args:
            _x: current sequence (prompt + response) of token ids.
            transfer_elements_idx: ``torch.nonzero`` of the positions decoded this step.
        """
        if len(transfer_elements_idx) == 2:
            positions = transfer_elements_idx[-1:]
        else:
            positions = transfer_elements_idx
        probs = []
        for pos in positions:
            idx = tuple(pos.tolist())
            x_temp = _x.clone()
            x_temp[idx] = self.mask_id
            logits = self.model(x_temp).logits
            # softmax over the vocabulary at this single position only
            probs.append(F.softmax(logits[idx], dim=-1)[_x[idx]].item())
        return float(np.mean(probs))

    # https://github.com/ML-GSAI/LLaDA/blob/4aa0dd2402fb9fec1137648cf768b56103a4a849/generate.py#L44
    @torch.no_grad()
    def inference_llada_from_to(self, x, r_t, step_start, step_end,
                                total_steps=128, block_length=128, temperature=0.,
                                cfg_scale=0., remasking='low_confidence',
                                return_confidence_reward_per_step=False, reward_batch_size=2,
                                sample_ratio_calculating_correlation_inside_response=1):
        '''Run LLaDA denoising steps ``[step_start, step_end)`` on a partially masked response.

        Args:
            x: input_ids of the prompt after applying the chat template, (batch, prompt_len).
            r_t: response at time step ``step_start`` (masked positions hold ``self.mask_id``),
                (batch, gen_length).
            step_start: first global denoising step to run (inclusive).
            step_end: global step at which to stop (exclusive).
            total_steps: number of denoising steps of the whole schedule.
            block_length: semi-autoregressive block size; must divide ``gen_length``.
            temperature: Gumbel-noise temperature passed to ``add_gumbel_noise``.
            cfg_scale: classifier-free guidance scale; 0 disables CFG.
            remasking: ``'low_confidence'`` or ``'random'`` choice of tokens to unmask.
            return_confidence_reward_per_step: also compute correlation rewards.
            reward_batch_size: currently unused.
            sample_ratio_calculating_correlation_inside_response: fraction of the steps of
                this call for which a correlation reward is computed.

        Returns:
            ``(response, rewards)``: the response after ``step_end`` (prompt stripped) and
            the list of correlation rewards, or ``None`` when they were not requested.
        '''
        from diffusion_prediction.LLaDA.generate import add_gumbel_noise, get_num_transfer_tokens

        prompt_len = x.shape[1]
        _x = torch.cat((x, r_t), 1)

        # CFG conditions on the prompt only: already-decoded response tokens must not be
        # masked in the unconditional branch.
        prompt_index = torch.zeros_like(_x, dtype=torch.bool)
        prompt_index[:, :prompt_len] = (x != self.mask_id)

        gen_length = r_t.shape[-1]
        assert gen_length % block_length == 0
        num_blocks = gen_length // block_length

        assert total_steps % num_blocks == 0
        steps = total_steps // num_blocks  # denoising steps per block
        start_block, start_block_start_step = divmod(step_start, steps)
        end_block, end_block_end_step = divmod(step_end, steps)
        if end_block_end_step == 0:
            # step_end lies on a block boundary: run the previous block to its end instead
            end_block -= 1
            end_block_end_step = steps

        res_correlation_rewards = []
        steps_for_reward = set()
        if return_confidence_reward_per_step:
            steps_for_reward = self._sample_reward_steps(
                step_end - step_start, sample_ratio_calculating_correlation_inside_response)

        generating_step = 0  # step counter relative to step_start
        for num_block in range(start_block, end_block + 1):
            block_start_step = start_block_start_step if num_block == start_block else 0
            block_end_step = end_block_end_step if num_block == end_block else steps
            block_lo = prompt_len + num_block * block_length
            block_hi = block_lo + block_length

            block_mask_index = (_x[:, block_lo:block_hi] == self.mask_id)
            # Spread the masks still present in this block over the steps still left for
            # it, so a block resumed mid-way is fully unmasked by its last step.
            num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps - block_start_step)

            for i in range(block_start_step, block_end_step):
                mask_index = (_x == self.mask_id)
                if cfg_scale > 0.:
                    un_x = _x.clone()
                    un_x[prompt_index] = self.mask_id
                    x_ = torch.cat([_x, un_x], dim=0)
                    logits = self.model(x_).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = self.model(_x).logits

                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

                if remasking == 'low_confidence':
                    p = F.softmax(logits.to(torch.float64), dim=-1)
                    x0_p = torch.squeeze(
                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
                elif remasking == 'random':
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
                else:
                    raise NotImplementedError(remasking)

                # positions of later blocks are never selected in this block
                x0_p[:, block_hi:] = -np.inf

                # only masked positions take the new prediction / compete for unmasking
                x0 = torch.where(mask_index, x0, _x)
                confidence = torch.where(mask_index, x0_p, -np.inf)

                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                for j in range(confidence.shape[0]):
                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i - block_start_step])
                    transfer_index[j, select_index] = True

                _x[transfer_index] = x0[transfer_index]

                if generating_step in steps_for_reward:
                    transfer_elements_idx = torch.nonzero(transfer_index)
                    if len(transfer_elements_idx) >= 2:
                        res_correlation_rewards.append(
                            self._correlation_reward(_x, transfer_elements_idx))

                generating_step += 1

        response = _x[:, prompt_len:]
        if return_confidence_reward_per_step:
            return response, res_correlation_rewards
        return response, None

    # https://github.com/huggingface/open-r1/blob/50590a41b9c3c97dcca9018a1778f8dd1645f525/src/open_r1/rewards.py#L89
    def format_reward(self, text_list, **kwargs):
        """Score whether each text follows ``<think>...</think>\\n<answer>...</answer>``.

        Returns:
            ``(scores, matches)``: ``1.0``/``0.0`` per text, and per text the list of
            captured ``<answer>`` bodies returned by ``re.findall``.
        """
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n(.*?)\n</answer>$"
        matches = [re.findall(pattern, content, re.DOTALL | re.MULTILINE) for content in text_list]
        return [1.0 if match else 0.0 for match in matches], matches

    def _score_response(self, res_str, ground_truth, reward_list):
        """Return ``(format_reward, accuracy_reward, ppl_reward)`` for one decoded response.

        A component is ``None`` when its name is not in ``reward_list``.
        """
        format_reward, accuracy_reward, ppl_reward = None, None, None
        if "format" in reward_list or "accuracy" in reward_list:
            format_scores, all_matches = self.format_reward([res_str])
            format_reward, matches = format_scores[0], all_matches[0]
            if "accuracy" in reward_list:
                # https://github.com/huggingface/Math-Verify
                accuracy_reward = 0
                if len(matches) > 0 and ground_truth is not None:
                    gold = parse(ground_truth, extraction_config=[ExprExtractionConfig()])
                    answer = parse(matches[-1], extraction_config=[ExprExtractionConfig()])
                    accuracy_reward = int(verify(gold, answer) or (ground_truth in matches[-1]))

        if "format" not in reward_list:
            # format was only computed to extract the answer for the accuracy reward
            format_reward = None

        if "ppl" in reward_list:
            if len(res_str) > 0 and self.scorer:
                ppl = self.scorer.get_perplexity([res_str])[0]
                # 1 at ppl == 0, linearly down to 0 at ppl >= 100
                ppl_reward = max(0, (100 - ppl) / 100)
            else:
                ppl_reward = 0
        return format_reward, accuracy_reward, ppl_reward

    @staticmethod
    def _combine_rewards(confidence_reward, format_reward, accuracy_reward, ppl_reward):
        """Sum reward components with unit weights.

        ``None``/falsy components count as 0; a NaN perplexity reward is ignored.
        """
        reward = 1.0 * (confidence_reward if confidence_reward else 0)
        reward += 1.0 * (format_reward if format_reward else 0)
        reward += 1.0 * (accuracy_reward if accuracy_reward else 0)
        reward += 1.0 * (ppl_reward if (ppl_reward is not None and not math.isnan(ppl_reward)) else 0)
        return reward

    @staticmethod
    def _select_topk(res_score, topk):
        """Return the ``topk`` entries of ``res_score`` with the highest ``reward_final``, best first."""
        ranked = sorted(res_score.items(), key=lambda item: item[1]["reward_final"], reverse=True)
        return dict(ranked[:topk])

    def beam_search(self, x, r_t_list, step_start, step_end, search_every_steps_n, topk, n, reward_batch_size=16,
                    gen_length=128, ground_truth=None, reward_list=("correlation", "format", "accuracy", "ppl"), **kwargs):
        """Reward-guided beam search over denoising steps ``[step_start, step_end)``.

        Args:
            x: prompt input_ids.
            r_t_list: a response tensor or a list of them (the initial beams).
            search_every_steps_n: number of denoising steps between two selections.
            topk: number of beams kept after each selection.
            n: expansions sampled per beam per chunk.
            gen_length: currently unused (kept for interface compatibility).
            ground_truth: reference answer for the accuracy reward, or ``None``.
            reward_list: names of the rewards to use.
            **kwargs: forwarded to the model-specific ``inference_*_from_to`` method.

        Returns:
            ``(generations, trajectory)``: decoded texts of the final beams (best first) and,
            per chunk, a dict mapping decoded text to its reward breakdown.
        """
        if self.model_type == "llada":
            inference_from_to = self.inference_llada_from_to
        else:
            raise NotImplementedError()

        if not isinstance(r_t_list, list):
            r_t_list = [r_t_list]

        # NOTE(review): the correlation reward only uses self.model, yet it is gated on a
        # perplexity scorer being configured — confirm this is intentional.
        use_correlation = self.scorer is not None and "correlation" in reward_list

        trajectory = []
        for step_iter in range(step_start, step_end, search_every_steps_n):
            # do not run past step_end when it is not a multiple of search_every_steps_n
            step_stop = min(step_iter + search_every_steps_n, step_end)
            res_score = {}
            for r_t in r_t_list:
                for _ in range(n):
                    res_, rewards = inference_from_to(x, r_t, step_iter, step_stop,
                                                      return_confidence_reward_per_step=use_correlation,
                                                      reward_batch_size=reward_batch_size, **kwargs)
                    # reward from the model's confidence between tokens decoded together
                    confidence_reward = np.mean(rewards).item() if rewards else None

                    res_str = self.tokenizer.batch_decode(res_, skip_special_tokens=True)[0]
                    format_reward, accuracy_reward, ppl_reward = self._score_response(
                        res_str, ground_truth, reward_list)
                    reward = self._combine_rewards(confidence_reward, format_reward, accuracy_reward, ppl_reward)

                    # keyed by decoded text, so identical candidates are deduplicated
                    res_score[res_str] = {"reward_final": reward,
                                          "reward_format": format_reward,
                                          "reward_accuracy": accuracy_reward,
                                          "reward_ppl": ppl_reward,
                                          "reward_confidence": confidence_reward,
                                          "r_t": res_[0]}

            topk_res = self._select_topk(res_score, topk)
            # the tensors become the next beams; they are removed from the logged trajectory
            r_t_list = [v.pop("r_t").unsqueeze(0) for v in topk_res.values()]
            trajectory.append(topk_res)

        generations = list(trajectory[-1].keys()) if trajectory else []
        return generations, trajectory

    @torch.no_grad()
    def complete(self, chats, gen_length=128, steps=128, search_every_steps_n=16, topk=4, n=8,
                 reward_batch_size=4, ground_truth=None, **kwargs):
        """Run reward-guided beam search for every chat.

        Args:
            chats: chat conversations passed to ``tokenizer.apply_chat_template``.
            gen_length: number of response tokens to generate.
            steps: total number of denoising steps.
            ground_truth: optional list of reference answers, one per chat.
            **kwargs: forwarded to ``beam_search`` / the inference method.

        Returns:
            ``(all_generations, all_trajectories)``, one entry per chat.
        """
        all_generations, all_trajectories = [], []
        if ground_truth is None:
            ground_truth = [None for _ in chats]

        for chat, ground_truth_item in zip(chats, ground_truth):
            if self.model_type not in ["llada"]:
                raise NotImplementedError()
            input_ids = self.tokenizer.apply_chat_template(
                chat, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True,
                truncation=self.max_input_length is not None, max_length=self.max_input_length)['input_ids']
            x = input_ids.to(self.model.device)
            # the response starts fully masked
            r_mask = torch.full((1, gen_length), self.mask_id, dtype=torch.long).to(self.model.device)
            generations, trajectory = self.beam_search(x, r_mask, 0, steps, search_every_steps_n, topk, n,
                                                       reward_batch_size=reward_batch_size, gen_length=gen_length,
                                                       ground_truth=ground_truth_item, total_steps=steps,
                                                       **kwargs)
            all_generations.append(generations)
            all_trajectories.append(trajectory)

        return all_generations, all_trajectories
