import os
import json
import warnings
import logging
import argparse
import random
from tqdm import tqdm
import numpy as np
from typing import Iterable, Tuple, List, Dict, Any

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, AutoModelForCausalLM
import trl 

warnings.filterwarnings("ignore", category=FutureWarning)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Compute the entropy (in nats) of the softmax distribution over the last dim.

    Uses the identity ``H = logsumexp(z) - sum_i softmax(z)_i * z_i`` so the
    log-probabilities never have to be materialised separately.

    Args:
        logits: Unnormalised scores; the distribution is taken over ``dim=-1``.

    Returns:
        A tensor with the last dimension reduced away, holding one entropy
        value per distribution.
    """
    probs = F.softmax(logits, dim=-1)
    # A masked logit of -inf has probability exactly 0, but 0 * -inf is NaN and
    # would poison the whole row. Such terms contribute nothing to the entropy,
    # so drop them explicitly. Finite logits are unaffected by this guard.
    weighted_logits = torch.where(probs > 0, probs * logits, torch.zeros_like(logits))
    return torch.logsumexp(logits, dim=-1) - weighted_logits.sum(dim=-1)

class CustomDataCollatorForCompletionOnlyLM:
    """Collate ``{"text": ...}`` examples into a batch with completion-only labels.

    The texts are tokenized together (padded and truncated to
    ``max_seq_length``). ``labels`` starts as a copy of ``input_ids`` and every
    position that is not part of a response is set to ``-100``:

    * everything up to the end of the first response template,
    * each later span from an instruction template to the end of the response
      template that follows it,
    * a trailing instruction that has no response after it,
    * padding positions.

    Samples in which either template cannot be found are masked entirely.
    """

    def __init__(self, tokenizer, instruction_template, response_template, max_seq_length, **kwargs):
        """Store the configuration and pre-tokenize both templates.

        Args:
            tokenizer: Object providing ``encode(text, add_special_tokens=False)``,
                ``pad_token_id`` and a batch ``__call__`` returning ``input_ids``
                as a tensor (``return_tensors="pt"``).
            instruction_template: Text that marks the start of a user turn.
            response_template: Text that marks the start of an assistant turn.
            max_seq_length: Truncation length passed to the tokenizer.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If either template tokenizes to an empty sequence, since
                it could then never be located in a sample.
        """
        self.tokenizer = tokenizer
        self.instruction_template = instruction_template
        self.response_template = response_template
        self.max_seq_length = max_seq_length
        self.label_pad_token_id = -100

        self.instruction_token_ids = self.tokenizer.encode(instruction_template, add_special_tokens=False)
        self.response_token_ids = self.tokenizer.encode(response_template, add_special_tokens=False)

        # Fail fast here instead of hitting an opaque IndexError on `[0]` while
        # collating the first batch.
        if len(self.instruction_token_ids) == 0 or len(self.response_token_ids) == 0:
            raise ValueError(
                "instruction_template and response_template must each tokenize to at least one token."
            )

        logging.info(f"Collator: instruction_token_ids: {self.instruction_token_ids}")
        logging.info(f"Collator: response_token_ids: {self.response_token_ids}")

    @staticmethod
    def _find_template_starts(sequence: List[int], template) -> List[int]:
        """Return every index at which ``template`` occurs as a contiguous run in ``sequence``."""
        template = list(template)
        if not template:
            return []
        first_token = template[0]
        width = len(template)
        return [
            pos
            for pos, token in enumerate(sequence)
            if token == first_token and sequence[pos : pos + width] == template
        ]

    def __call__(self, examples: List[dict]) -> dict:
        """Tokenize ``examples`` and build completion-only labels.

        Args:
            examples: Dicts that each carry the sample under the ``"text"`` key.

        Returns:
            The tokenizer's batch with two extra entries: ``"labels"`` (masked
            copy of ``input_ids``) and ``"raw_text"`` (the input strings).
        """
        texts = [e["text"] for e in examples]
        batch = self.tokenizer(
            texts, padding=True, truncation=True, max_length=self.max_seq_length, return_tensors="pt"
        )
        labels = batch["input_ids"].clone()
        response_width = len(self.response_token_ids)

        for i, input_ids_list in enumerate(batch["input_ids"].tolist()):
            human_token_ids_idxs = self._find_template_starts(input_ids_list, self.instruction_token_ids)
            # Store the position just past each response template: that is
            # where the tokens to be learned begin.
            response_token_ids_idxs = [
                start + response_width
                for start in self._find_template_starts(input_ids_list, self.response_token_ids)
            ]

            if not human_token_ids_idxs:
                logging.warning("Could not find instruction template in sample.")
                labels[i, :] = self.label_pad_token_id
                continue

            if not response_token_ids_idxs:
                logging.warning("Could not find response template in sample.")
                labels[i, :] = self.label_pad_token_id
                continue

            # The sample starts with a response (no preceding instruction):
            # pair that response with a virtual instruction at position 0.
            if human_token_ids_idxs[0] > response_token_ids_idxs[0]:
                human_token_ids_idxs = [0] + human_token_ids_idxs

            for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
                if idx == 0:
                    # Also hides anything (e.g. a system prompt) before the first turn.
                    labels[i, :end] = self.label_pad_token_id
                else:
                    labels[i, start:end] = self.label_pad_token_id

            # A final instruction without a response must not be learned.
            if len(response_token_ids_idxs) < len(human_token_ids_idxs):
                labels[i, human_token_ids_idxs[-1] :] = self.label_pad_token_id

        if "attention_mask" in batch:
            # Mask padding by position rather than by token id, so a real token
            # that happens to share the pad id (e.g. pad == eos) keeps its label.
            labels[batch["attention_mask"] == 0] = self.label_pad_token_id
        else:
            labels[labels == self.tokenizer.pad_token_id] = self.label_pad_token_id

        batch["labels"] = labels
        batch["raw_text"] = texts
        return batch

def _head_tensor_views(attn, head_rows: slice, include_kv: bool) -> List[torch.Tensor]:
    """Return the parameter views that must be zero for one head to contribute nothing.

    Args:
        attn: Attention module exposing ``q_proj`` (and ``k_proj``/``v_proj``/``o_proj``).
        head_rows: Slice selecting the head's rows in ``q_proj`` (equivalently
            its input columns in ``o_proj``).
        include_kv: Whether ``k_proj``/``v_proj`` rows at ``head_rows`` belong to
            this head alone and may therefore be zeroed as well.

    Raises:
        ValueError: If key/value heads are shared and the module has no
            ``o_proj`` through which the head's output could be removed.
    """
    row_projs = [attn.q_proj]
    if include_kv:
        row_projs.extend((attn.k_proj, attn.v_proj))

    views = []
    for proj in row_projs:
        views.append(proj.weight[head_rows, :])
        if proj.bias is not None:
            views.append(proj.bias[head_rows])

    o_proj = getattr(attn, "o_proj", None)
    if o_proj is not None:
        # Zeroing the head's input columns of the output projection removes its
        # contribution regardless of what the attention itself computes.
        views.append(o_proj.weight[:, head_rows])
    elif not include_kv:
        raise ValueError(
            "Cannot isolate a single head: key/value heads are shared and the attention module has no o_proj."
        )
    return views


def zero_out_heads(model: PreTrainedModel, heads_to_zero: Iterable[Tuple[int, int]]) -> None:
    """Ablate attention heads in place by zeroing their projection weights.

    For every ``(layer_idx, head_idx)`` the head's rows of ``q_proj`` (weight
    and bias) and its input columns of ``o_proj`` are set to zero. The matching
    rows of ``k_proj``/``v_proj`` are zeroed only when the config reports one
    key/value head per query head; with grouped-query attention those rows are
    shared between several query heads (or do not exist at that index), so
    touching them would ablate the wrong heads.

    After zeroing, all touched tensors are re-read and the outcome is logged.

    Args:
        model: Model exposing ``config`` and ``model.layers[i].self_attn``.
        heads_to_zero: ``(layer_idx, head_idx)`` pairs. An empty iterable logs a
            warning and leaves the model untouched.

    Raises:
        IndexError: If a head index is negative or not below
            ``config.num_attention_heads``. Raised before any weight is modified.
    """
    heads_list = list(heads_to_zero)
    if not heads_list:
        logging.warning("heads_to_zero list is empty, no weights will be zeroed out.")
        return

    config = model.config
    num_heads = config.num_attention_heads
    # Some configs declare head_dim explicitly and it need not equal hidden_size / num_heads.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
    kv_is_per_head = num_kv_heads == num_heads

    # Validate everything first so a bad entry cannot leave the model half-modified.
    for layer_idx, head_idx in heads_list:
        if head_idx >= num_heads:
            raise IndexError(f"Head index {head_idx} in layer {layer_idx} exceeds total heads ({num_heads})")
        if head_idx < 0:
            # A negative index would select an empty slice and silently do nothing.
            raise IndexError(f"Head index {head_idx} in layer {layer_idx} must be non-negative")

    def views_for(layer_idx, head_idx):
        attn = model.model.layers[layer_idx].self_attn
        head_rows = slice(head_idx * head_dim, (head_idx + 1) * head_dim)
        return _head_tensor_views(attn, head_rows, kv_is_per_head)

    logging.info(f"Zeroing out weights for {len(heads_list)} attention heads...")
    for layer_idx, head_idx in tqdm(heads_list, desc="Zeroing heads"):
        with torch.no_grad():
            for view in views_for(layer_idx, head_idx):
                view.zero_()

    logging.info("Zeroing operation completed.")

    logging.info("Verifying zeroed weights...")
    verified_zeros = 0
    for layer_idx, head_idx in tqdm(heads_list, desc="Verifying"):
        if all(view.abs().sum().item() == 0 for view in views_for(layer_idx, head_idx)):
            verified_zeros += 1

    if verified_zeros == len(heads_list):
        logging.info(f"Verification successful: {verified_zeros}/{len(heads_list)} heads confirmed zeroed.")
    else:
        logging.error(f"Verification failed: Only {verified_zeros}/{len(heads_list)} heads were successfully zeroed.")

class TokenLossCalculator:
    """Compute per-token loss and entropy of a causal LM over a text dataset.

    Optionally zeroes a percentage of attention heads (top-scored or random)
    before scoring, and writes one JSON line per sample.
    """

    def __init__(self, model_name: str, **kwargs):
        """Load tokenizer, config and model for ``model_name``.

        Args:
            model_name: Hugging Face model id or local path.
            **kwargs: Accepted and ignored.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        logging.info(f"Loading model from {model_name}...")
        config = AutoConfig.from_pretrained(model_name)
        self.layer_num, self.head_num = config.num_hidden_layers, config.num_attention_heads
        logging.info(f"Model Structure: {self.layer_num} layers, {self.head_num} heads/layer")

        logging.info("Loading model with AutoModelForCausalLM and Flash Attention 2...")
        model_load_args = {"torch_dtype": "auto", "device_map": 'auto', "attn_implementation": "flash_attention_2"}
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_args, trust_remote_code=True)
        self.model.eval()

    def _load_head_scores(self) -> List[List[int]]:
        """Load ``<model>.json`` from the working directory and rank its heads.

        The first line of the file is parsed as a JSON object mapping
        ``"layer-head"`` keys to lists of scores. Heads are ordered by mean
        score, highest first.

        Returns:
            ``[layer, head]`` pairs in ranked order, or an empty list when the
            score file does not exist.
        """
        model_name_for_score = self.model_name.split('/')[-1]
        if model_name_for_score == 'Mistral-7B-Instruct-v0.2':
            model_name_for_score = "Mistral-7B-v0.2-hf"

        score_path = f"{model_name_for_score}.json"
        logging.info(f"Loading head scores from {score_path}...")
        try:
            with open(score_path, "r") as file:
                head_scores = json.loads(file.readline())
        except FileNotFoundError:
            logging.warning(f"Note: Score file {score_path} not found.")
            logging.warning("If using positive mask_percentage (top-k mask), no heads will be masked.")
            logging.warning("If using negative mask_percentage (random mask), ignore this warning.")
            return []

        sorted_heads = sorted(head_scores.items(), key=lambda x: np.mean(x[1]), reverse=True)
        logging.info(f"Successfully loaded and sorted scores for {len(sorted_heads)} heads.")
        return [[int(ll) for ll in h[0].split("-")] for h in sorted_heads]

    def _construct_random_heads(self, n: int, block_list: List[List[int]]) -> List[Tuple[int, int]]:
        """Sample ``n`` distinct ``(layer, head)`` pairs that are not in ``block_list``.

        If fewer than ``n`` heads remain after the exclusion, a warning is
        logged and all remaining heads are returned (in random order).
        """
        known_heads = {tuple(h) for h in block_list}
        pool = [
            (l, h)
            for l in range(self.layer_num)
            for h in range(self.head_num)
            if (l, h) not in known_heads
        ]
        if len(pool) < n:
            logging.warning(f"Available random heads ({len(pool)}) less than requested ({n}). Using all available.")
            n = len(pool)
        return random.sample(pool, n)

    def compute_per_token_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the next-token cross-entropy for every position.

        Logits at position ``t`` are scored against the label at ``t + 1``. A
        zero is prepended along the last dim so the result has the same
        sequence length as ``labels``; entry ``t`` is the loss of predicting
        token ``t``. Positions labelled ``-100`` (the ``CrossEntropyLoss``
        default ``ignore_index``) come out as 0.
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(reduction='none')
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss = loss.view(shift_labels.shape)
        # Left-pad by one: the first token has no preceding prediction.
        padded_loss = F.pad(loss, (1, 0), "constant", 0)
        return padded_loss

    def compute_entropy_from_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the predictive entropy aligned the same way as the token loss.

        The last position's logits are dropped and a zero is prepended, so
        entry ``t`` is the entropy of the distribution that predicted token ``t``.
        """
        shifted_logits = logits[..., :-1, :].contiguous()
        entropy = entropy_from_logits(shifted_logits)
        padded_token_entropy = F.pad(entropy, (1, 0), "constant", 0)
        return padded_token_entropy

    def _select_heads_to_mask(self, args: argparse.Namespace) -> list:
        """Choose the heads to zero according to ``args.mask_percentage``.

        Positive values take that percentage of all heads from the top of the
        ranked score list; negative values sample that percentage at random;
        zero selects nothing.
        """
        block_list = self._load_head_scores()
        total_heads = self.layer_num * self.head_num

        if args.mask_percentage > 0:
            percentage = args.mask_percentage / 100.0
            num_heads_to_mask = int(total_heads * percentage)
            logging.info(f"Masking top {num_heads_to_mask} ({args.mask_percentage}%) heads based on scores.")
            return block_list[:num_heads_to_mask]

        if args.mask_percentage < 0:
            random_percentage = abs(args.mask_percentage) / 100.0
            num_heads_to_mask = int(total_heads * random_percentage)
            logging.info(f"Randomly masking {num_heads_to_mask} ({abs(args.mask_percentage)}%) heads.")
            # NOTE(review): the random pool excludes EVERY scored head, so a score
            # file that covers all heads leaves nothing to sample. Possibly only
            # the top-k heads were meant to be excluded - confirm intent.
            return self._construct_random_heads(num_heads_to_mask, block_list)

        return []

    @staticmethod
    def _load_dataset(data_dir: str):
        """Load the dataset from a single json/jsonl/parquet file, or via ``load_dataset(data_dir)``.

        For the latter the ``train`` split is preferred, falling back to the
        first available split.

        Raises:
            ValueError: If ``data_dir`` is a file with an unsupported extension.
        """
        logging.info(f"Preparing to load data: {data_dir}")

        if os.path.isfile(data_dir):
            file_ext = data_dir.split(".")[-1].lower()
            if file_ext in ["json", "jsonl"]:
                logging.info("Detected single JSON/JSONL file, loading with json loader...")
                return load_dataset("json", data_files=data_dir, split="train")
            if file_ext == "parquet":
                logging.info("Detected Parquet file, loading with parquet loader...")
                return load_dataset("parquet", data_files=data_dir, split="train")
            raise ValueError(f"Unsupported file format: {file_ext}. Please provide .json, .jsonl, .parquet, or a dataset directory.")

        logging.info("Input treated as HuggingFace Repo ID or dataset directory...")
        raw_data = load_dataset(data_dir)

        if isinstance(raw_data, dict) or hasattr(raw_data, "keys"):
            if "train" in raw_data:
                return raw_data["train"]
            first_split = list(raw_data.keys())[0]
            logging.warning(f"'train' split not found, defaulting to '{first_split}'.")
            return raw_data[first_split]
        return raw_data

    @staticmethod
    def _output_path(args: argparse.Namespace) -> str:
        """Derive the output filename from ``args.output_file`` and the masking mode."""
        base_name, ext = os.path.splitext(args.output_file)
        if args.apply_mask:
            mask_type = "top" if args.mask_percentage > 0 else "random"
            return f"{base_name}_masked_{mask_type}{abs(args.mask_percentage)}pct{ext}"
        return f"{base_name}_vanilla{ext}"

    def run(self, args: argparse.Namespace):
        """Score the dataset and write the results as JSON lines.

        Reads ``apply_mask``, ``mask_percentage``, ``data_dir``, ``test_run``,
        ``batch_size``, ``max_length`` and ``output_file`` from ``args``. Each
        output line holds ``text``, ``token_loss`` and ``token_entropy`` for one
        sample. The chat templates and pad token are hard-coded to the
        ``<|im_start|>`` format.
        """
        instruction_template = "<|im_start|>user"
        response_template = "<|im_start|>assistant\n"
        self.tokenizer.pad_token = "<|fim_pad|>"
        logging.info("Qwen model detected, using '<|im_start|>user' template and '<|fim_pad|>' pad token.")

        collator = CustomDataCollatorForCompletionOnlyLM(
            self.tokenizer,
            instruction_template,
            response_template,
            args.max_length
        )

        if args.apply_mask:
            heads_to_zero_out = self._select_heads_to_mask(args)
            logging.info(f"Finalizing zero-out for {len(heads_to_zero_out)} heads.")

            if heads_to_zero_out:
                zero_out_heads(self.model, heads_to_zero_out)
            else:
                logging.warning("No heads found to mask. Calculating vanilla loss. Check score file or mask percentage.")

        dataset = self._load_dataset(args.data_dir)

        if args.test_run:
            logging.info("--- Starting Test Mode ---")
            dataset = dataset.select(range(min(60, len(dataset))))

        dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collator, num_workers=4, shuffle=False)

        all_output = []
        logging.info("Starting Token Loss and Entropy calculation...")
        for batch in tqdm(dataloader, total=len(dataloader), desc="Calculating Loss & Entropy"):
            device = self.model.device
            model_inputs = {"input_ids": batch["input_ids"].to(device)}
            # Forward the padding mask so padded positions in multi-sample
            # batches are not treated as real tokens by the model.
            if "attention_mask" in batch:
                model_inputs["attention_mask"] = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            with torch.no_grad():
                logits = self.model(**model_inputs).logits
                token_loss = self.compute_per_token_loss(logits, labels)
                token_entropy = self.compute_entropy_from_logits(logits)

            rows = zip(batch["raw_text"], token_loss.cpu().tolist(), token_entropy.cpu().tolist())
            for text, loss_row, entropy_row in rows:
                all_output.append({
                    "text": text,
                    "token_loss": loss_row,
                    "token_entropy": entropy_row
                })

        output_filename = self._output_path(args)
        output_dir = os.path.dirname(output_filename)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_filename, "w", encoding="utf-8") as f:
            for entry in all_output:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")

        logging.info(f"Processing complete. Saved to {output_filename}")

def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the loss/entropy script."""
    parser = argparse.ArgumentParser(description="Calculate per-token loss and entropy for a model (optionally with masking).")
    parser.add_argument('--model_name', type=str, required=True, help='Hugging Face model path')
    parser.add_argument('--data_dir', type=str, required=True, help='Input data file path or directory')
    parser.add_argument('--output_file', type=str, required=True, help='Base path for output jsonl file')
    parser.add_argument('--apply_mask', action='store_true', help='Flag to apply attention head masking.')
    # argparse %-formats help strings, so a literal percent sign must be written
    # as '%%'; a bare '%' makes `--help` raise ValueError.
    parser.add_argument('--mask_percentage', type=float, default=5.0,
                        help='Percentage of heads to mask. Positive: top-k%%; Negative: random k%%. Only valid if --apply_mask is set.')
    parser.add_argument('--max_length', type=int, default=32768, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--test_run', action='store_true', help='Run a quick test using only the first 60 samples.')
    return parser


def main() -> None:
    """Parse the command line, load the model and run the calculation."""
    args = build_arg_parser().parse_args()

    calculator = TokenLossCalculator(model_name=args.model_name)
    calculator.run(args)


if __name__ == "__main__":
    main()