from .dataset import Gsm8kDataset, MathDataset, Math500Dataset, AimeDataset, AMC23Dataset
from .reward import Reward_Prm, Reward_Skywork
from pmpd import PMPDForCausalLM, Scheduler
from tqdm import tqdm
from transformers import AutoTokenizer
from collections import defaultdict
import torch
import gc
import os
# Root directory under which every dataset, model checkpoint and reward-model
# path below is resolved. NOTE(review): when DATA_ROOT is unset this is None
# and every path silently becomes "None/<...>"; nothing here validates it.
data_root = os.environ.get("DATA_ROOT")

# Dataset name -> dataset class.
dataset = dict(
    gsm8k=Gsm8kDataset,
    math=MathDataset,
    math500=Math500Dataset,
    aime=AimeDataset,
    amc23=AMC23Dataset,
)

# Dataset name -> on-disk location of that dataset.
dataset_path = dict(
    gsm8k=f"{data_root}/gsm8k",
    math=f"{data_root}/efficient-reasoning/competition_math",
    math500=f"{data_root}/MATH-500",
    amc23=f"{data_root}/AMC-23",
    aime=f"{data_root}/aime-2021-2025",
)

# Model name -> packed checkpoint directory (the tokenizer is loaded from the
# same directory).
model_path = dict(
    qwen7b=f"{data_root}/quantize_model/packed/qwen7b-distill",
    qwen38=f"{data_root}/quantize_model/packed/qwen3-8b",
)

# Reward-model name -> wrapper class, and name -> weights directory.
reward_model = dict(
    prm=Reward_Prm,
    skywork=Reward_Skywork,
)
reward_model_path = dict(
    prm=f"{data_root}/Qwen2.5-Math-PRM-7B",
    skywork=f"{data_root}/Skywork-Reward-Llama-3.1-8B-v0.2",
)

# Per-model sampling settings; er_model copies them into the keyword
# arguments given to the scheduler and to generate().
top_p = dict(qwen7b=1, qwen38=0.95)
top_k = dict(qwen7b=0, qwen38=20)
min_p = dict(qwen7b=0, qwen38=0)

# Per-model position within tokenizer.encode(...) output at which er_model
# reads the id of "<think>" / "</think>".
token_place = dict(qwen7b=1, qwen38=0)

# Text passed as the scheduler's ``descent_prompt`` setting, selected by the
# ``descent_prompt`` keyword of er_model; the string "None" is the default key.
descent_prompt = {
    "None": "</think>",
    "enterthink": "\n</think>",
    "entertime": "\n</think>Time is limited, let's directly answer this question.",
    "enterfinal": "\n</think> **Final Answer**",
    "enternothink": " Okay, I think I have finished thinking. \n<think>",
    "nothink": " Okay, I think I have finished thinking. </think>",
}
class er_model:
    """Evaluate a ``PMPDForCausalLM`` generator on a math benchmark.

    Wires together a reward model, the generator (built with precisions
    ``[2, ..., 8]``), its tokenizer, a benchmark dataset and a precision
    scheduler, all looked up by short name in the module-level registries.
    ``evaluate`` runs generation over the dataset and aggregates accuracy,
    length and precision statistics.
    """

    # Scheduler settings forwarded verbatim from the constructor kwargs
    # (``None`` when the caller omits them). Order matches the original
    # hand-written dict.
    _SCHEDULER_PASSTHROUGH_KEYS = (
        "mean_score",
        "max_score",
        "block_with_split_mean_score",
        "block_with_split_max_score",
        "alpha_split",
        "minus_score",
        "problem_split_mean_score",
        "problem_split_max_score",
        "computation_split_mean_score",
        "computation_split_max_score",
        "verification_split_mean_score",
        "verification_split_max_score",
        "sol_precision",
        "windows",
        "text_type",
        "dewey_text_type",
        "calon_mean_score",
        "calon_max_score",
        "calve_mean_score",
        "calve_max_score",
        "seek_mean_score",
        "seek_max_score",
        "problem_mean_score",
        "problem_max_score",
        "computation_mean_score",
        "computation_max_score",
        "verification_mean_score",
        "verification_max_score",
    )

    # Number of leading positions of each ``list_split_prob_15`` entry that
    # are averaged separately for correct and incorrect answers.
    _NUM_SPLIT_POSITIONS = 15

    def __init__(self, **kwargs):
        """Build the reward model, generator, tokenizer, dataset and scheduler.

        Required keyword arguments are ``device`` (e.g. ``"cuda:0"``),
        ``reward_model`` (a key of ``reward_model``), ``model`` (a key of
        ``model_path``) and ``dataset`` (a key of ``dataset``). All other
        keywords are optional and read with ``kwargs.get``; the scheduler
        settings are stored in ``self.kw_dict``.

        Raises:
            ValueError: if a required argument is missing or the generator's
                config is ``None``.
            KeyError: if a name is absent from its registry.
        """
        self.device = kwargs.get("device")
        if self.device is None:
            raise ValueError("device parameter is required")

        reward_model_name = kwargs.get("reward_model")
        if reward_model_name is None:
            raise ValueError("reward_model parameter is required")

        model_name = kwargs.get("model")
        if model_name is None:
            raise ValueError("model parameter is required")

        dataset_name = kwargs.get("dataset")
        if dataset_name is None:
            raise ValueError("dataset parameter is required")

        self.reward_model = reward_model[reward_model_name](
            model_path=reward_model_path[reward_model_name],
            device=self.device,
        )
        # NOTE(review): the generator defaults ``split`` to False while the
        # scheduler settings below default it to True; confirm which default
        # is intended when the caller omits ``split``.
        self.model = PMPDForCausalLM(
            model_path[model_name],
            precisions=[2, 3, 4, 5, 6, 7, 8],
            use_anyprec=True,
            Solution=True,
            prune_func=kwargs.get("prune_func", None),
            reward_func=self.reward_model,
            split=kwargs.get("split", False),
        ).eval().to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path[model_name])

        # An xVerify judge path is only configured when requested; the dataset
        # switches its scoring path on it (see ``evaluate``).
        self.xverify_path = f"{data_root}/xVerify-9B-C" if kwargs.get("xverify", False) else None
        self.dataset = dataset[dataset_name](
            dataset_path=dataset_path[dataset_name],
            prompt_type=kwargs.get("prompt_type"),
            shuffle=kwargs.get("shuffle", False),
            xverify_path=self.xverify_path,
            device=self.device,
        )
        self.max_steps = kwargs.get("max_steps", 255)
        self.prefill_bit = kwargs.get("prefill_bit", 8)
        self.past_key_values = kwargs.get("past_key_values", None)
        self.record_token_count = kwargs.get("record_token_count", False)

        # Token id of each think marker, taken at a per-model position of the
        # encoded output (presumably skipping a prepended special token for
        # qwen7b — TODO confirm).
        place = token_place[model_name]
        self.think_token = self.tokenizer.encode("<think>")[place]
        self.think_end_token = self.tokenizer.encode("</think>")[place]

        # Make the requested GPU the current CUDA device. A device string with
        # no index (e.g. "cuda") leaves the current device unchanged instead
        # of failing to parse.
        gpu_idx = torch.device(self.device).index
        if gpu_idx is not None:
            torch.cuda.set_device(gpu_idx)

        config = self.model.model.config
        if config is None:
            raise ValueError("Model config is None")

        part = kwargs.get("part")
        kw_dict = {
            "precisions": kwargs.get("naive_bit"),
            "high_bit_steps": kwargs.get("high_bit_steps"),
            "precision_switch_points": [],
            "save_dir": "phpd_test",
            "dim": config.hidden_size // config.num_attention_heads,
            "num_heads": config.num_key_value_heads,
            "cot": part == "cot",
            "solution": part == "solution",
            "answer": part == "answer",
            "do_sample": kwargs.get("do_sample", True),
            "temperature": kwargs.get("temperature", 0.6),
            "top_p": top_p[model_name],
            "top_k": top_k[model_name],
            "min_p": min_p[model_name],
            "split": kwargs.get("split", True),
        }
        kw_dict.update((key, kwargs.get(key)) for key in self._SCHEDULER_PASSTHROUGH_KEYS)
        kw_dict["descent_prompt"] = descent_prompt[kwargs.get("descent_prompt", "None")]
        # The same settings build the scheduler and are re-sent on every
        # ``generate`` call in ``evaluate``.
        self.kw_dict = kw_dict
        self.model.scheduler = Scheduler.get_scheduler(kwargs.get("scheduler"), **kw_dict)

    @staticmethod
    def _mean(values, default=0):
        """Return the arithmetic mean of *values*, or *default* if empty."""
        return sum(values) / len(values) if values else default

    @staticmethod
    def _partition_by_correctness(values, correct_list):
        """Split *values* into ``(correct, wrong)`` lists by aligned flags.

        ``values`` and ``correct_list`` are paired position by position; any
        tail beyond the shorter of the two is ignored. A flag equal to 1
        counts as correct, anything else as wrong.
        """
        correct, wrong = [], []
        for value, is_correct in zip(values, correct_list):
            (correct if is_correct == 1 else wrong).append(value)
        return correct, wrong

    def _split_thinking(self, output_ids):
        """Split generated ids at the think markers.

        Returns ``(thinking_ids, solution_ids)``: the ids strictly between the
        first ``think_token`` and the first ``think_end_token``, and every id
        after that ``think_end_token``. Returns ``None`` if either marker is
        absent.
        """
        ids = output_ids.tolist()
        try:
            start_idx = ids.index(self.think_token)
            end_idx = ids.index(self.think_end_token)
        except ValueError:
            return None
        return output_ids[start_idx + 1:end_idx], output_ids[end_idx + 1:]

    def evaluate(self, num_samples=None):
        """Generate for every dataset prompt and aggregate evaluation metrics.

        When CUDA is available the generator, reward model and tokenizer are
        deleted before scoring to release GPU memory, so a second call on the
        same instance will fail.

        Args:
            num_samples: forwarded as ``index`` to ``self.dataset.get_prompt``.

        Returns:
            ``(results, content)``: ``results`` is a dict of metrics and
            ``content`` is the list of decoded generations, one per prompt.
        """
        content = []
        cot_precision = defaultdict(int)
        answer_token_len = []
        thinking_chain_token_len = []
        answer = []
        thinking_chain = []
        no_thinking_chain_count = 0
        real_time_token_counts = []

        prompt = self.dataset.get_prompt(index=num_samples)

        for item in tqdm(prompt, desc="Processing prompts", unit="prompt"):
            self.model.reset(item, "")
            messages = [{"role": "user", "content": item}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            input_ids = self.tokenizer([text], return_tensors="pt").to(self.device).input_ids
            outputs = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=self.max_steps,
                prefill_bit=self.prefill_bit,
                past_key_values=self.past_key_values,
                **self.kw_dict,
            )
            prompt_len = len(input_ids[0])

            if self.record_token_count:
                real_time_token_counts.append(len(outputs.input_ids[0]) - prompt_len)

            for bit, count in outputs.cot_precision.items():
                cot_precision[bit] += count

            # Keep the last two prompt tokens in the slice — presumably the
            # chat template already ends with the opening think marker, which
            # ``_split_thinking`` must be able to find (TODO confirm).
            output_ids = outputs.input_ids[0][prompt_len - 2:]
            content.append(self.tokenizer.decode(output_ids))

            parts = self._split_thinking(output_ids)
            if parts is None:
                no_thinking_chain_count += 1
                thinking_chain.append("Error Thinking Chain")
                answer.append("Error Answer")
                continue
            thinking_chain_ids, solution_ids = parts
            thinking_chain_token_len.append(len(thinking_chain_ids))
            answer_token_len.append(len(solution_ids))
            thinking_chain.append(self.tokenizer.decode(thinking_chain_ids, skip_special_tokens=True))
            answer.append(self.tokenizer.decode(solution_ids, skip_special_tokens=True))

        avg_thinking_chain_len = self._mean(thinking_chain_token_len)
        avg_solution_len = self._mean(answer_token_len)

        # NOTE(review): counts recorded under bit 0 are included in
        # ``total_tokens`` (and so in every denominator below) but excluded
        # from the distribution; confirm bit 0 should count as a token.
        weighted_sum = sum(bit * count for bit, count in cot_precision.items())
        total_tokens = sum(cot_precision.values())
        avg_precision = weighted_sum / total_tokens if total_tokens > 0 else 0
        distribution = {}
        if total_tokens > 0:  # avoid 0/0 when every recorded count is zero
            for bit, count in sorted(cot_precision.items()):
                if bit != 0:
                    distribution[bit] = {"probability": count / total_tokens}
        precision_stats = {
            "total_tokens": total_tokens,
            "avg_precision": avg_precision,
            "distribution": distribution,
        }

        # Statistics the generator may have recorded; read them before the
        # generator is released. The ``getattr`` defaults keep every name
        # bound whether or not CUDA is available.
        all_prob_means = getattr(self.model, "list_prob", [])
        all_split_probs = getattr(self.model, "list_split_prob", [])
        all_split_prob_15 = getattr(self.model, "list_split_prob_15", [])

        if torch.cuda.is_available():
            # Drop the large models so scoring starts with free GPU memory.
            # Collect garbage before emptying the cache so that memory freed
            # by the collection is actually returned to the device.
            del self.model
            del self.reward_model
            del self.tokenizer
            torch.cuda.synchronize()
            gc.collect()
            torch.cuda.empty_cache()

        if self.xverify_path is not None:
            accuracy, correct_list = self.dataset.eval_math_is_correct(answer)
        else:
            accuracy = self.dataset.result_eval(answer)
            # No per-answer correctness available, so nothing gets bucketed.
            correct_list = []

        list_prob_correct, list_prob_false = self._partition_by_correctness(all_prob_means, correct_list)
        list_split_correct, list_split_false = self._partition_by_correctness(all_split_probs, correct_list)

        n_pos = self._NUM_SPLIT_POSITIONS
        split_prob_15_correct = [[] for _ in range(n_pos)]
        split_prob_15_false = [[] for _ in range(n_pos)]
        for row, is_correct in zip(all_split_prob_15, correct_list):
            buckets = split_prob_15_correct if is_correct == 1 else split_prob_15_false
            for pos in range(min(n_pos, len(row))):
                buckets[pos].append(row[pos])
        avg_split_prob_15_correct = [self._mean(bucket, 0.0) for bucket in split_prob_15_correct]
        avg_split_prob_15_false = [self._mean(bucket, 0.0) for bucket in split_prob_15_false]

        results = {
            "accuracy": accuracy,
            "no_thinking_chain_count": no_thinking_chain_count,
            "cot_precision": avg_precision,
            "avg_thinking_chain_len": avg_thinking_chain_len,
            "avg_solution_len": avg_solution_len,
            "precision_stats": precision_stats,
            # Never populated in this method; kept for result-schema stability.
            "text_type_stats": {},
            "avg_prob_correct": self._mean(list_prob_correct),
            "avg_prob_false": self._mean(list_prob_false),
            "avg_split_prob_correct": self._mean(list_split_correct),
            "avg_split_prob_false": self._mean(list_split_false),
            "avg_split_prob_15_correct": avg_split_prob_15_correct,
            "avg_split_prob_15_false": avg_split_prob_15_false,
        }

        if self.record_token_count:
            results["real_time_token_counts"] = real_time_token_counts
            results["avg_tokens_per_question"] = self._mean(real_time_token_counts)
            results["total_tokens_generated"] = sum(real_time_token_counts)

        return results, content


if __name__ == "__main__":
    # Smoke run: evaluate the qwen7b model on GSM8K. er_model requires the
    # ``model`` key (a key of ``model_path``). It resolves checkpoint,
    # reward-model and dataset paths, and the top_p/top_k/min_p sampling
    # settings, from the module-level tables, so those are not passed here.
    args = {
        "model": "qwen7b",
        "dataset": "gsm8k",
        "reward_model": "skywork",
        "prompt_type": "better",
        "scheduler": "part_split",
        "device": "cuda:0",
        "max_steps": 2048,
        "prefill_bit": 8,
        "past_key_values": None,
        "naive_bit": [7, 6],
        "high_bit_steps": 512,
        "part": "cot",
        "do_sample": True,
        "temperature": 0.6,
        "record_token_count": True,
    }
    model_instance = er_model(**args)
    results, content = model_instance.evaluate(num_samples=0)
    print(results)
