"""BenchmarkEvaluator for running datasets and saving predictions.

Encapsulates the core of the prior get_pred flow in a reusable class.
"""

from __future__ import annotations

import json
import os
import time
from typing import Iterable, Dict, Any
import torch

from tqdm import tqdm

from ..utils.prompting import build_chat, post_process
from ..utils.normalization import SafeDict, normalize_prompt_fields
from ..utils.timer import GlobalTimer
from ..patching.patcher import patch_model_for_mask_topk


class BenchmarkEvaluator:
    """Run a dataset through a model and save the predictions as JSON Lines.

    Three entry points are provided:

    * ``run`` generates with ``model.generate`` and appends to a single file,
      skipping samples that are already present in that file.
    * ``run_mask_topk_multi`` runs one forward pass over the prompt per sample
      and then greedy-decodes once per sparsity ratio, one file per ratio.
    * ``run_mask_topk_recall_multi`` does the same per recall value with a
      fixed base sparsity ratio.

    Every output line is a JSON object with the keys ``pred``, ``answers``,
    ``all_classes``, ``length``, ``request_time`` and ``input_tokens``.
    """

    # Datasets whose prompt is sent to the model as-is, without build_chat.
    _NO_CHAT_DATASETS = ("trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p")

    def __init__(self, model, tokenizer, timer: GlobalTimer):
        """Store the model, tokenizer and timer used by the run methods."""
        self.model = model
        self.tokenizer = tokenizer
        self.timer = timer

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_parent_dir(out_path: str) -> None:
        """Create the directory that will contain *out_path*, if it has one."""
        parent = os.path.dirname(out_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _count_lines(path: str) -> int:
        """Return the number of lines in *path*, or 0 when it does not exist."""
        try:
            with open(path, "r", encoding="utf-8") as f:
                return sum(1 for _ in f)
        except FileNotFoundError:
            return 0

    @staticmethod
    def _write_record(f, pred: str, json_obj: Dict[str, Any], input_token_count: int) -> None:
        """Append one prediction record to the open text file *f*."""
        record = {
            "pred": pred,
            "answers": json_obj.get("answers"),
            "all_classes": json_obj.get("all_classes"),
            "length": json_obj.get("length"),
            "request_time": {"batch_time": 0, "batch_size": 1},
            "input_tokens": int(input_token_count),
        }
        # Serialise first and write once, so a serialisation error cannot
        # leave a half-written line behind for the resume logic to count.
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
        # Each record follows a full generation; flushing keeps finished
        # results on disk if a later sample fails.
        f.flush()

    def _truncate_prompt(self, prompt: str, max_length: int):
        """Middle-truncate *prompt* when it tokenizes to more than *max_length*.

        Returns:
            ``(prompt, input_token_count)``. When truncated, the prompt is the
            decoded first ``max_length // 2`` tokens concatenated with the
            decoded last ``max_length // 2`` tokens, and the reported count is
            ``max_length``.
        """
        token_ids = self.tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        token_count = len(token_ids)
        if token_count <= max_length:
            return prompt, token_count

        half = max(int(max_length / 2), 0)
        head = self.tokenizer.decode(token_ids[:half], skip_special_tokens=True)
        # Slice from an explicit start index: ``token_ids[-0:]`` would return
        # the whole sequence when ``half`` is 0.
        tail = self.tokenizer.decode(token_ids[token_count - half:], skip_special_tokens=True)
        return head + tail, max_length

    def _prepare_inputs(self, json_obj: Dict[str, Any], prompt_format: str, max_length: int,
                        model_name: str, dataset_name: str):
        """Format, truncate, optionally chat-wrap and tokenize one sample.

        Returns:
            ``(inputs, context_length, input_token_count)`` where ``inputs`` is
            the tokenizer output moved to the model device and
            ``context_length`` is the last dimension of ``inputs.input_ids``.

        Raises:
            KeyError: if ``prompt_format`` names a field missing from
                ``json_obj``.
        """
        prompt = prompt_format.format(**json_obj)
        prompt, input_token_count = self._truncate_prompt(prompt, max_length)

        # Use the explicit dataset_name passed by the runner; json_obj may not
        # carry a 'dataset' field.
        if dataset_name not in self._NO_CHAT_DATASETS:
            prompt = build_chat(self.tokenizer, prompt, model_name)

        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
        return inputs, inputs.input_ids.shape[-1], input_token_count

    def _mask_topk_managers(self):
        """Yield each layer's ``mask_topk_manager`` that is present and not None."""
        for layer in self.model.model.layers:
            mgr = getattr(layer, "mask_topk_manager", None)
            if mgr is not None:
                yield mgr

    def _prefill(self, inputs) -> None:
        """Run one forward pass over the full prompt, discarding the output.

        Autograd is disabled and ``use_cache=False`` is passed so that no HF
        ``past_key_values`` are kept, to reduce memory.
        """
        with torch.no_grad():
            self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                use_cache=False,
            )

    def _greedy_decode(self, inputs, context_length: int, max_gen: int) -> str:
        """Greedy-decode up to *max_gen* tokens one at a time and return the text.

        Each step feeds a single token with ``position_ids`` set to
        ``context_length + step`` and a ``(1, 1)`` attention mask of ones, and
        picks the argmax token. Decoding stops early on the tokenizer's EOS id.
        The returned string has leading/trailing newlines stripped.

        NOTE(review): the first step re-feeds the last prompt token at position
        ``context_length``; presumably this matches how the patched model keeps
        its prefill state — confirm against the manager implementation.
        """
        device = self.model.device
        eos_id = self.tokenizer.eos_token_id
        generated = []
        last_token = inputs.input_ids[:, -1:].to(device)

        with torch.no_grad():
            for step in range(max_gen):
                pos_ids = torch.tensor([[context_length + step]], device=device, dtype=torch.long)
                attn_mask = torch.ones((1, 1), device=device, dtype=torch.long)
                out = self.model(
                    input_ids=last_token,
                    position_ids=pos_ids,
                    attention_mask=attn_mask,
                    use_cache=False,
                )
                next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
                generated.append(next_token)
                last_token = next_token
                if eos_id is not None and int(next_token.item()) == int(eos_id):
                    break

        if generated:
            gen_ids = torch.cat(generated, dim=1)
        else:
            gen_ids = torch.empty((1, 0), dtype=torch.long, device=device)
        return self.tokenizer.decode(gen_ids[0], skip_special_tokens=True).strip("\n")

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------
    def run(self, dataset: Iterable[Dict[str, Any]], prompt_format: str, max_gen: int, out_path: str,
            *, model_name: str, max_length: int, dataset_name: str) -> None:
        """Generate a prediction for every sample and append it to *out_path*.

        If *out_path* already exists, as many leading samples as it has lines
        are skipped, so an interrupted run can be resumed.

        Args:
            dataset: Samples; each is expanded into ``prompt_format`` with
                ``str.format(**sample)``.
            prompt_format: Format string for the prompt.
            max_gen: ``max_new_tokens`` passed to ``model.generate``.
            out_path: JSONL file to append to; its parent directory is created
                when missing.
            model_name: Forwarded to ``build_chat`` and ``post_process``.
            max_length: Token budget above which the prompt is middle-truncated.
            dataset_name: Decides whether ``build_chat`` is applied.
        """
        self._ensure_parent_dir(out_path)

        # Resume: one output line per already-processed sample.
        already_done = self._count_lines(out_path)
        tokens_generated = 0

        for i, json_obj in tqdm(enumerate(dataset)):
            if i < already_done:
                continue

            inputs, context_length, input_token_count = self._prepare_inputs(
                json_obj, prompt_format, max_length, model_name, dataset_name
            )

            outputs = self.model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pad_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=max_gen,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                output_scores=False,
            )[0]

            # Everything after the prompt is newly generated.
            new_tokens = outputs[context_length:]
            tokens_generated += int(new_tokens.shape[0])

            pred = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip("\n")
            pred = post_process(pred, model_name)

            # Re-open per sample so each result is on disk before the next one.
            with open(out_path, "a", encoding="utf-8") as f:
                self._write_record(f, pred, json_obj, input_token_count)

        if tokens_generated > 0:
            print("time per token decode:", self.timer.decode_total / tokens_generated)
            print("time per append:", self.timer.append_time / tokens_generated)
            print("time per search:", self.timer.search_time / tokens_generated)
            print("time per fi kv:", self.timer.skip_search_time / tokens_generated)
            print("time per attn:", self.timer.attn_time / tokens_generated)
            print("time per mlp:", self.timer.mlp_time / tokens_generated)
            # NOTE(review): 'attnetion' spelling presumably matches the
            # attribute declared on GlobalTimer — confirm before renaming.
            print("time per attention_layer_total:", self.timer.attnetion_layer_total / tokens_generated)

    def run_mask_topk_multi(self,
                            dataset: Iterable[Dict[str, Any]],
                            prompt_format: str,
                            max_gen: int,
                            out_dir: str,
                            sparsity_list,
                            *, model_name: str, max_length: int, dataset_name: str) -> None:
        """Prefill once per sample, then decode multiple times with different sparsity ratios.

        Produces one jsonl per sparsity ratio in out_dir, named
        ``mask_topk_top{ratio:.4f}.jsonl``. Files are opened in append mode and
        there is no resume logic here, so re-running adds further lines.

        Args:
            dataset: Samples; each is expanded into ``prompt_format``.
            prompt_format: Format string for the prompt.
            max_gen: Maximum number of tokens decoded per ratio.
            out_dir: Output directory (created when missing).
            sparsity_list: Ratios passed to each manager's
                ``set_sparsity_ratio``.
            model_name: Forwarded to ``build_chat`` and ``post_process``.
            max_length: Token budget above which the prompt is middle-truncated.
            dataset_name: Decides whether ``build_chat`` is applied.
        """
        os.makedirs(out_dir, exist_ok=True)

        # Materialise once: the ratios are iterated for every sample.
        runs = [(r, f"mask_topk_top{r:.4f}.jsonl") for r in sparsity_list]

        # Keyed by filename so ratios that format identically share one handle.
        writers = {}
        try:
            for _, fname in runs:
                if fname not in writers:
                    writers[fname] = open(os.path.join(out_dir, fname), "a", encoding="utf-8")

            for _, json_obj in tqdm(enumerate(dataset)):
                inputs, context_length, input_token_count = self._prepare_inputs(
                    json_obj, prompt_format, max_length, model_name, dataset_name
                )
                self._prefill(inputs)

                for mgr in self._mask_topk_managers():
                    mgr.reset_to_prefill()

                for ratio, fname in runs:
                    # Reset every manager first, then set the ratio on each.
                    for mgr in self._mask_topk_managers():
                        mgr.reset_to_prefill()
                    for mgr in self._mask_topk_managers():
                        mgr.set_sparsity_ratio(float(ratio))

                    pred = self._greedy_decode(inputs, context_length, max_gen)
                    pred = post_process(pred, model_name)
                    self._write_record(writers[fname], pred, json_obj, input_token_count)
        finally:
            # Close whatever was opened, even if a sample raised.
            for handle in writers.values():
                handle.close()

    def run_mask_topk_recall_multi(self,
                                   dataset: Iterable[Dict[str, Any]],
                                   prompt_format: str,
                                   max_gen: int,
                                   out_dir: str,
                                   base_ratio: float,
                                   recall_list,
                                   *, model_name: str, max_length: int, dataset_name: str) -> None:
        """Prefill once, decode multiple times under recall-aware top-k selection.

        File naming convention:
          mask_topk_recall_top{base_ratio:.4f}_recall{recall:.0f}.jsonl

        A recall value greater than 1.0 is divided by 100 before being passed
        to ``set_recall_ratio``; other values are passed through unchanged.
        Manager methods are only called when the manager defines them. Files
        are opened in append mode and there is no resume logic here.

        Args:
            dataset: Samples; each is expanded into ``prompt_format``.
            prompt_format: Format string for the prompt.
            max_gen: Maximum number of tokens decoded per recall value.
            out_dir: Output directory (created when missing).
            base_ratio: Ratio passed to each manager's ``set_sparsity_ratio``.
            recall_list: Recall values, one output file per formatted value.
            model_name: Forwarded to ``build_chat`` and ``post_process``.
            max_length: Token budget above which the prompt is middle-truncated.
            dataset_name: Decides whether ``build_chat`` is applied.
        """
        os.makedirs(out_dir, exist_ok=True)

        # Materialise once: the recall values are iterated for every sample.
        # NOTE(review): ':.0f' rounds fractional recalls (0.9 -> "1"), so
        # several values can map to one file; they then share a single handle.
        runs = [
            (float(recall), f"mask_topk_recall_top{base_ratio:.4f}_recall{float(recall):.0f}.jsonl")
            for recall in recall_list
        ]

        writers = {}
        try:
            for _, fname in runs:
                if fname not in writers:
                    writers[fname] = open(os.path.join(out_dir, fname), "a", encoding="utf-8")

            for _, json_obj in tqdm(enumerate(dataset)):
                inputs, context_length, input_token_count = self._prepare_inputs(
                    json_obj, prompt_format, max_length, model_name, dataset_name
                )
                self._prefill(inputs)

                for mgr in self._mask_topk_managers():
                    if hasattr(mgr, "reset_to_prefill"):
                        mgr.reset_to_prefill()

                for recall, fname in runs:
                    recall_ratio = recall / 100.0 if recall > 1.0 else recall
                    for mgr in self._mask_topk_managers():
                        if hasattr(mgr, "reset_to_prefill"):
                            mgr.reset_to_prefill()
                        if hasattr(mgr, "set_sparsity_ratio"):
                            mgr.set_sparsity_ratio(float(base_ratio))
                        if hasattr(mgr, "set_recall_ratio"):
                            mgr.set_recall_ratio(recall_ratio)

                    pred = self._greedy_decode(inputs, context_length, max_gen)
                    pred = post_process(pred, model_name)
                    self._write_record(writers[fname], pred, json_obj, input_token_count)
        finally:
            # Close whatever was opened, even if a sample raised.
            for handle in writers.values():
                handle.close()




