import os
import json
import warnings
import logging
import argparse
import random
import time 
from tqdm import tqdm
import numpy as np
from typing import Iterable, Tuple, List, Dict, Any

# --- Core Libraries ---
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, AutoModelForCausalLM
import trl 

warnings.filterwarnings("ignore", category=FutureWarning)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def _assemble_qwen_text(example: Dict[str, Any]) -> Dict[str, Any]:
    """Render one dataset record as a Qwen-style chat transcript.

    The transcript is a fixed system prompt, the user question, and an
    assistant turn made of a "think" section followed by an "answer"
    section; all pieces are joined with newlines.

    Args:
        example: Record that may contain ``question``, ``attempt``,
            ``thinking_trajectories`` (a sequence; only its first element is
            used) and ``cot_type``.  Missing, empty or ``None`` fields are
            rendered as empty strings; other non-string values are passed
            through ``str()``.

    Returns:
        A dict with the assembled ``text`` and the record's ``cot_type``
        (``None`` when the record has none).
    """
    SYSTEM_PROMPT = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>"

    def _as_text(value: Any) -> str:
        # str.join() rejects None and other non-str items, and dict.get()
        # only applies its default when the key is absent, not when it is None.
        return "" if value is None else str(value)

    trajectories = example.get("thinking_trajectories")
    think_text = _as_text(trajectories[0]) if trajectories else ""

    parts = [
        SYSTEM_PROMPT,
        "<|im_start|>user",
        _as_text(example.get("question")),
        "<|im_end|>",
        "<|im_start|>assistant",
        "<|im_start|>think",
        think_text,
        "<|im_start|>answer",
        _as_text(example.get("attempt")),
        "<|im_end|>",
    ]

    return {"text": "\n".join(parts), "cot_type": example.get("cot_type")}


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Return the entropy (in nats) of the softmax distribution over the last axis.

    Evaluated as ``logsumexp(z) - sum_i softmax(z)_i * z_i``; the last
    dimension of ``logits`` is reduced away.
    """
    probabilities = torch.softmax(logits, dim=-1)
    expected_logit = (probabilities * logits).sum(dim=-1)
    log_partition = logits.logsumexp(dim=-1)
    return log_partition - expected_logit

class CustomDataCollatorForCompletionOnlyLM:
    """Tokenise raw-text examples and keep labels only on response tokens.

    Labels start as a copy of ``input_ids``.  Everything from an instruction
    template up to the end of the following response template is replaced
    by ``label_pad_token_id`` (-100), as are positions whose label equals the
    tokenizer's pad token id.  Samples in which either template cannot be
    found are masked completely.
    """

    def __init__(self, tokenizer, instruction_template, response_template, max_seq_length, **kwargs):
        """Store the settings and pre-tokenise both templates.

        Args:
            tokenizer: Callable tokenizer that also provides ``encode``,
                ``decode`` and ``pad_token_id``.
            instruction_template: Text that opens a user turn.
            response_template: Text that opens an assistant turn.
            max_seq_length: Truncation length handed to the tokenizer.
            **kwargs: Accepted but not used.
        """
        self.tokenizer = tokenizer
        self.instruction_template = instruction_template
        self.response_template = response_template
        self.max_seq_length = max_seq_length
        # Label value written wherever a position must not be scored.
        self.label_pad_token_id = -100

        encode = self.tokenizer.encode
        self.instruction_token_ids = encode(instruction_template, add_special_tokens=False)
        self.response_token_ids = encode(response_template, add_special_tokens=False)

        logging.info(f"Collator: instruction_token_ids: {self.instruction_token_ids}")
        logging.info(f"Collator: response_token_ids: {self.response_token_ids}")

    @staticmethod
    def _template_starts(token_ids, template_ids):
        """Return every index of ``token_ids`` at which ``template_ids`` begins."""
        first_token = template_ids[0]
        width = len(template_ids)
        return [
            pos
            for pos, token in enumerate(token_ids)
            if token == first_token and token_ids[pos : pos + width] == template_ids
        ]

    def __call__(self, examples: List[dict]) -> dict:
        """Collate ``examples`` into a padded batch with completion-only labels.

        Args:
            examples: Dicts holding a ``text`` entry and optionally ``cot_type``.

        Returns:
            The tokenizer's batch extended with ``labels``, ``raw_text`` (the
            input strings) and ``cot_type`` (one entry per example, ``None``
            when absent).
        """
        texts = [example["text"] for example in examples]
        cot_types = [example.get("cot_type", None) for example in examples]

        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        labels = batch["input_ids"].clone()
        ignore = self.label_pad_token_id
        response_width = len(self.response_token_ids)

        for row, row_ids in enumerate(batch["input_ids"].tolist()):
            instruction_starts = self._template_starts(row_ids, self.instruction_token_ids)
            # For responses we track the index just past the template, i.e.
            # where the completion itself begins.
            response_ends = [
                start + response_width
                for start in self._template_starts(row_ids, self.response_token_ids)
            ]

            if not instruction_starts:
                logging.warning(f"Could not find instruction template in sample: {self.tokenizer.decode(row_ids)}")
                labels[row, :] = ignore
                continue

            if not response_ends:
                logging.warning(f"Could not find response template in sample: {self.tokenizer.decode(row_ids)}")
                labels[row, :] = ignore
                continue

            # A response that precedes the first instruction is paired with a
            # virtual instruction at position 0.
            if instruction_starts[0] > response_ends[0]:
                instruction_starts = [0] + instruction_starts

            spans = zip(instruction_starts, response_ends)
            # Both lists are non-empty here, so a first pair always exists;
            # it masks everything from the very start of the row.
            _, first_end = next(spans)
            labels[row, :first_end] = ignore
            for start, end in spans:
                labels[row, start:end] = ignore

            # A trailing instruction without a response is masked to the end.
            if len(response_ends) < len(instruction_starts):
                labels[row, instruction_starts[-1] :] = ignore

        labels[labels == self.tokenizer.pad_token_id] = ignore
        batch["labels"] = labels
        batch["raw_text"] = texts
        batch["cot_type"] = cot_types
        return batch

def zero_out_heads(model: PreTrainedModel, heads_to_zero: Iterable[Tuple[int, int]]) -> None:
    """Hard-mask attention heads by zeroing their weights in place, then verify.

    For every ``(layer_idx, head_idx)`` pair the head's rows of ``q_proj``
    (weight and bias) are zeroed and, when the attention module has an
    ``o_proj``, so are the matching input columns of ``o_proj.weight``; the
    head's output is then multiplied by zero and cannot reach the residual
    stream.  The head's ``k_proj`` / ``v_proj`` rows are zeroed too, but only
    when the config reports as many key/value heads as attention heads.
    With fewer key/value heads (grouped-query attention) those projections
    are indexed by key/value head, so slicing them with a query-head index
    would either modify parameters shared with other query heads or select
    an empty range.

    Args:
        model: Model exposing ``config`` and ``model.layers[i].self_attn``
            with ``q_proj`` / ``k_proj`` / ``v_proj`` projections.
        heads_to_zero: ``(layer_idx, head_idx)`` pairs to disable.  An empty
            iterable only logs a warning.

    Raises:
        IndexError: If a head index lies outside ``[0, num_attention_heads)``.
            All pairs are validated before any weight is modified.
    """
    heads_list = [(int(layer_idx), int(head_idx)) for layer_idx, head_idx in heads_to_zero]
    if not heads_list:
        logging.warning("heads_to_zero list is empty, no weights will be zeroed out.")
        return

    config = model.config
    num_heads = config.num_attention_heads
    # Prefer an explicit head_dim from the config; fall back to the usual
    # hidden_size / num_heads split when it is absent or None.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
    kv_rows_match_query_heads = num_kv_heads == num_heads
    layers = model.model.layers

    # Validate everything up front so a bad index cannot leave the model
    # partially modified.
    for layer_idx, head_idx in heads_list:
        if not 0 <= head_idx < num_heads:
            raise IndexError(f"Head index {head_idx} in layer {layer_idx} exceeds total heads ({num_heads})")

    def _head_slice(head_idx: int) -> slice:
        return slice(head_idx * head_dim, (head_idx + 1) * head_dim)

    logging.info(f"Zeroing out weights for {len(heads_list)} attention heads...")
    with torch.no_grad():
        for layer_idx, head_idx in tqdm(heads_list, desc="Zeroing heads"):
            attn = layers[layer_idx].self_attn
            head_slice = _head_slice(head_idx)

            projections = [attn.q_proj]
            if kv_rows_match_query_heads:
                projections.extend((attn.k_proj, attn.v_proj))
            for proj in projections:
                proj.weight[head_slice, :].zero_()
                if proj.bias is not None:
                    proj.bias[head_slice].zero_()

            # Zeroing Q alone leaves all attention scores equal, so the head
            # would still emit an average of V; cutting its o_proj input
            # columns removes that contribution.
            o_proj = getattr(attn, "o_proj", None)
            if o_proj is not None:
                o_proj.weight[:, head_slice].zero_()

    logging.info("Zeroing operation completed.")

    logging.info("Verifying zeroed weights...")
    verified_zeros = 0
    for layer_idx, head_idx in tqdm(heads_list, desc="Verifying"):
        attn = layers[layer_idx].self_attn
        head_slice = _head_slice(head_idx)

        head_is_zero = attn.q_proj.weight[head_slice, :].abs().sum().item() == 0
        if attn.q_proj.bias is not None:
            head_is_zero = head_is_zero and attn.q_proj.bias[head_slice].abs().sum().item() == 0
        o_proj = getattr(attn, "o_proj", None)
        if o_proj is not None:
            head_is_zero = head_is_zero and o_proj.weight[:, head_slice].abs().sum().item() == 0

        if head_is_zero:
            verified_zeros += 1

    if verified_zeros == len(heads_list):
        logging.info(f"Verification successful: {verified_zeros}/{len(heads_list)} heads confirmed zeroed.")
    else:
        logging.error(f"Verification failed: Only {verified_zeros}/{len(heads_list)} heads were successfully zeroed. Check model structure and code.")


class TokenLossCalculator:
    """Score a dataset with a causal LM: per-token loss, item loss and entropy.

    Optionally zeroes a selection of attention heads (top-scored or random)
    before scoring, so the effect of the ablation on the loss can be measured.
    """

    def __init__(self, model_name: str, **kwargs):
        """Load the tokenizer, config and model for ``model_name``.

        Args:
            model_name: Hugging Face model identifier or local path.
            **kwargs: Accepted but not used.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        logging.info(f"Loading model from {model_name}...")
        config = AutoConfig.from_pretrained(model_name)
        self.layer_num, self.head_num = config.num_hidden_layers, config.num_attention_heads
        logging.info(f"Model Structure: {self.layer_num} layers, {self.head_num} heads/layer")

        logging.info("Loading model using AutoModelForCausalLM with Flash Attention 2...")
        model_load_args = {"torch_dtype": "auto", "device_map": 'auto', "attn_implementation": "flash_attention_2"}
        # NOTE(security): trust_remote_code=True runs Python code shipped with
        # the checkpoint; only use this with repositories you trust.
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_args, trust_remote_code=True)
        self.model.eval()

        # Label value marking positions excluded from the loss.
        self.label_pad_token_id = -100

    def _load_head_scores(self) -> List[List[int]]:
        """Load pre-computed attention head importance scores.

        Reads ``<model basename>.json`` from the working directory.  The
        first line of the file must be a JSON object mapping
        ``"<layer>-<head>"`` to a list of scores.

        Returns:
            ``[layer, head]`` pairs sorted by mean score, best first, or an
            empty list when the score file does not exist.
        """
        model_name_for_score = self.model_name.split('/')[-1]
        if model_name_for_score == 'Mistral-7B-Instruct-v0.2':
            model_name_for_score = "Mistral-7B-v0.2-hf"
        score_path = f"{model_name_for_score}.json"

        try:
            logging.info(f"Loading head scores from {score_path}...")
            with open(score_path, "r") as file:
                head_scores = json.loads(file.readline())
        except FileNotFoundError:
            logging.warning(f"Note: Score file {score_path} not found.")
            logging.warning("If using positive mask_percentage (top-k mask), no heads will be masked.")
            logging.warning("If using negative mask_percentage (random mask), ignore this warning.")
            return []

        sorted_heads = sorted(head_scores.items(), key=lambda item: np.mean(item[1]), reverse=True)
        logging.info(f"Successfully loaded and sorted scores for {len(sorted_heads)} heads.")
        return [[int(part) for part in name.split("-")] for name, _ in sorted_heads]

    def _construct_random_heads(self, n: int, block_list: List[List[int]]) -> List[Tuple[int, int]]:
        """Sample ``n`` distinct ``(layer, head)`` pairs not listed in ``block_list``.

        If fewer than ``n`` heads remain after exclusion, a warning is logged
        and all remaining heads are returned (in random order).
        """
        excluded = {tuple(head) for head in block_list}
        pool = [
            (layer, head)
            for layer in range(self.layer_num)
            for head in range(self.head_num)
            if (layer, head) not in excluded
        ]
        if len(pool) < n:
            logging.warning(f"Available random heads ({len(pool)}) less than requested ({n}). Using all available.")
            n = len(pool)
        return random.sample(pool, n)

    def compute_per_token_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Calculate the next-token cross-entropy per position.

        Logits at position ``t`` are scored against the label at ``t + 1``.
        The result is left-padded with one zero so it has the same sequence
        length as ``labels``: entry ``t`` is the loss for predicting token
        ``t``.  Labels equal to -100 (CrossEntropyLoss's default
        ``ignore_index``) yield a loss of 0.
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(reduction='none')
        flat_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss = flat_loss.view(shift_labels.shape)
        return F.pad(loss, (1, 0), "constant", 0)

    def compute_entropy_from_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Calculate the predictive entropy per position.

        Uses the same alignment as ``compute_per_token_loss``: the last
        position's logits are dropped and a zero is prepended.
        """
        shifted_logits = logits[..., :-1, :].contiguous()
        entropy = entropy_from_logits(shifted_logits)
        return F.pad(entropy, (1, 0), "constant", 0)

    def _configure_templates(self) -> Tuple[str, str]:
        """Choose the chat-template markers and set the tokenizer's pad token.

        Returns:
            ``(instruction_template, response_template)`` for the collator.
        """
        if "Llama" in self.model_name:
            # NOTE(review): the dataset text is always produced by
            # _assemble_qwen_text, which emits "<|im_start|>" markers, so
            # these Llama markers will not occur in it and every sample
            # would be fully masked -- confirm the intended Llama handling.
            instruction_template = "<|start_header_id|>user<|end_header_id|>"
            response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
            self.tokenizer.pad_token = "<|reserved_special_token_5|>"
            logging.info("Detected Llama model, using 'user' template and '<|reserved_special_token_5|>' pad token.")
            return instruction_template, response_template

        # Qwen and unrecognised models share the same settings; only the
        # first log line differs.
        if "Qwen" in self.model_name:
            logging.info("Detected Qwen model, using '<|im_start|>user' template and '<|fim_pad|>' pad token.")
        else:
            logging.info("Detected Qwen model (default), using '<|im_start|>user' template and '<|fim_pad|>' pad token.")
        instruction_template = "<|im_start|>user"
        response_template = "<|im_start|>assistant\n"
        self.tokenizer.pad_token = "<|fim_pad|>"
        logging.info(f"Response template overwritten to match new assembly logic: '{response_template}'")
        return instruction_template, response_template

    def _apply_head_mask(self, mask_percentage: float) -> None:
        """Select heads according to ``mask_percentage`` and zero them.

        A positive value masks that percentage of all heads taken from the
        top of the score ranking; a negative value masks the same share of
        randomly drawn heads; zero selects nothing.
        """
        ranked_heads = self._load_head_scores()
        heads_to_zero_out = []
        total_heads = self.layer_num * self.head_num

        if mask_percentage > 0:
            percentage = mask_percentage / 100.0
            num_heads_to_mask = int(total_heads * percentage)
            logging.info(f"Masking top {num_heads_to_mask} ({mask_percentage}%) heads based on scores.")
            heads_to_zero_out = ranked_heads[:num_heads_to_mask]
        elif mask_percentage < 0:
            random_percentage = abs(mask_percentage) / 100.0
            num_heads_to_mask = int(total_heads * random_percentage)
            logging.info(f"Randomly masking {num_heads_to_mask} ({abs(mask_percentage)}%) heads.")
            # NOTE(review): every head present in the score file is excluded
            # from the random pool; if the file scores all heads the pool is
            # empty -- confirm whether only the top-k should be excluded.
            heads_to_zero_out = self._construct_random_heads(num_heads_to_mask, ranked_heads)

        logging.info(f"Finalizing zero-out for {len(heads_to_zero_out)} heads.")

        if heads_to_zero_out:
            zero_out_heads(self.model, heads_to_zero_out)
        else:
            logging.warning("No heads found to mask. Calculating vanilla loss. Check score file or mask percentage.")

    def _prepare_dataset(self, args: argparse.Namespace):
        """Load the train split and reduce it to ``text`` / ``cot_type`` columns."""
        dataset = load_dataset(args.data_dir)['train']

        if args.test_run:
            logging.info("--- Starting Test Mode ---")
            dataset = dataset.select(range(min(60, len(dataset))))
        logging.info("Detected new data structure, assembling 'text' field...")

        kept_columns = {"text", "cot_type"}
        cols_to_remove = [name for name in dataset.column_names if name not in kept_columns]

        dataset = dataset.map(
            _assemble_qwen_text,
            num_proc=1,
            remove_columns=cols_to_remove,
        )
        logging.info("'text' field assembly complete.")
        logging.info(f"Processed dataset fields: {dataset.column_names}")
        return dataset

    def _score_batches(self, dataloader: DataLoader) -> List[Dict[str, Any]]:
        """Run the model over ``dataloader`` and collect one record per example.

        Each record holds the raw text, the per-token loss and entropy lists
        (padding positions included), the mean loss over unmasked label
        positions, and the example's ``cot_type``.
        """
        records: List[Dict[str, Any]] = []
        device = self.model.device
        logging.info("Starting Token Loss and Entropy calculation...")
        with tqdm(total=len(dataloader), desc="Calculating Loss & Entropy") as pbar:
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                # Without the mask, padded batches are scored as if the pad
                # tokens were real context.
                attention_mask = batch.get("attention_mask")
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)

                with torch.no_grad():
                    logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
                    # With device_map='auto' the logits may live on a
                    # different device than the inputs.
                    labels = batch["labels"].to(logits.device)

                    token_loss = self.compute_per_token_loss(logits, labels)
                    token_entropy = self.compute_entropy_from_logits(logits)

                    scored_positions = (labels != self.label_pad_token_id).float()
                    loss_sum = (token_loss * scored_positions).sum(dim=-1)
                    # Clamp so fully masked samples give 0 instead of NaN.
                    scored_count = scored_positions.sum(dim=-1).clamp(min=1e-9)
                    item_loss = loss_sum / scored_count

                rows = zip(
                    batch["raw_text"],
                    batch["cot_type"],
                    token_loss.cpu().tolist(),
                    item_loss.cpu().tolist(),
                    token_entropy.cpu().tolist(),
                )
                for text, cot_type, losses, mean_loss, entropies in rows:
                    records.append({
                        "text": text,
                        "token_loss": losses,
                        "item_loss": mean_loss,
                        "cot_type": cot_type,
                        "token_entropy": entropies,
                    })
                pbar.update(1)
        return records

    def _write_results(self, args: argparse.Namespace, records: List[Dict[str, Any]]) -> str:
        """Write ``records`` as JSON lines and return the file name used.

        The name is ``args.output_file`` with a ``_vanilla`` or
        ``_masked_<top|random><pct>pct`` suffix inserted before the extension.
        """
        base_name, ext = os.path.splitext(args.output_file)
        if args.apply_mask:
            mask_type = "top" if args.mask_percentage > 0 else "random"
            output_filename = f"{base_name}_masked_{mask_type}{abs(args.mask_percentage)}pct{ext}"
        else:
            output_filename = f"{base_name}_vanilla{ext}"

        output_dir = os.path.dirname(output_filename)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_filename, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(entry, ensure_ascii=False) + "\n" for entry in records)
        return output_filename

    def run(self, args: argparse.Namespace):
        """Main process for calculation and saving.

        Configures the templates and collator, optionally masks heads, loads
        and formats the dataset, scores it, writes the results as JSON lines,
        and logs timing statistics.

        Args:
            args: Parsed options; reads ``max_length``, ``apply_mask``,
                ``mask_percentage``, ``data_dir``, ``test_run``,
                ``batch_size`` and ``output_file``.
        """
        start_time_run = time.time()

        instruction_template, response_template = self._configure_templates()
        collator = CustomDataCollatorForCompletionOnlyLM(
            self.tokenizer,
            instruction_template,
            response_template,
            args.max_length,
        )

        if args.apply_mask:
            self._apply_head_mask(args.mask_percentage)

        dataset = self._prepare_dataset(args)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collator, num_workers=4, shuffle=False)

        all_output = self._score_batches(dataloader)
        output_filename = self._write_results(args, all_output)
        logging.info(f"Processing complete. Saved to {output_filename}")

        total_time_run = time.time() - start_time_run
        total_items = len(all_output)

        if total_items > 0:
            avg_time_per_item = total_time_run / total_items
            logging.info("--- Performance Statistics (Processing Only) ---")
            logging.info(f"Total items processed: {total_items}")
            logging.info(f"Total processing time: {total_time_run:.2f} seconds.")
            logging.info(f"Average time per item: {avg_time_per_item:.4f} seconds.")
        else:
            logging.warning("No items processed.")

def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the loss/entropy script.

    Returns:
        A parser with the required ``--model_name``, ``--data_dir`` and
        ``--output_file`` options plus the optional masking, length, batch
        size and test-run options.
    """
    parser = argparse.ArgumentParser(description="Calculate per-token loss and entropy for a model (optionally with masking).")
    parser.add_argument('--model_name', type=str, required=True, help='Hugging Face model path')
    parser.add_argument('--data_dir', type=str, required=True, help='Input data file path or directory')
    parser.add_argument('--output_file', type=str, required=True, help='Base path for output jsonl file')
    parser.add_argument('--apply_mask', action='store_true', help='Flag to apply attention head masking.')
    # argparse %-formats help strings, so a literal percent sign has to be
    # written as "%%"; a bare "%" makes --help fail with a ValueError.
    parser.add_argument('--mask_percentage', type=float, default=5.0,
                        help='Percentage of heads to mask. Positive: top-k%%; Negative: random k%%. Only valid if --apply_mask is set.')
    parser.add_argument('--max_length', type=int, default=32768, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--test_run', action='store_true', help='Run a quick test using only the first 60 samples.')
    return parser


def main() -> None:
    """Parse the command line, run the calculation and log the total runtime."""
    start_time_global = time.time()

    args = build_arg_parser().parse_args()

    calculator = TokenLossCalculator(model_name=args.model_name)
    calculator.run(args)

    end_time_global = time.time()
    logging.info("--- Total script runtime (including model loading) ---")
    logging.info(f"Total execution time: {end_time_global - start_time_global:.2f} seconds.")


if __name__ == "__main__":
    main()