import json
import numpy as np
import os
import argparse
import html
from tqdm import tqdm
from transformers import AutoTokenizer
import warnings
from multiprocessing import Pool, cpu_count
from functools import partial

warnings.filterwarnings('ignore', category=RuntimeWarning)

def get_tokens_and_ids(text, tokenizer, max_length):
    """Tokenize ``text`` with an EOS marker appended.

    If the tokenizer defines no ``eos_token``, a fallback is picked from the
    lower-cased ``tokenizer.name_or_path``: ``"<|im_end|>"`` for names
    containing "qwen", ``"</s>"`` for names containing "llama", and an empty
    string otherwise.

    Args:
        text: Raw text to tokenize.
        tokenizer: Callable tokenizer exposing ``eos_token``,
            ``name_or_path`` and ``convert_ids_to_tokens``.
        max_length: Truncation limit forwarded to the tokenizer.

    Returns:
        A ``(tokens, input_ids)`` tuple for the truncated, unpadded encoding
        (special tokens are requested from the tokenizer).
    """
    eos = tokenizer.eos_token
    if eos is None:
        # Checked in order; the first matching model family wins.
        fallbacks = (("qwen", "<|im_end|>"), ("llama", "</s>"))
        model_name = tokenizer.name_or_path.lower()
        eos = next((marker for family, marker in fallbacks if family in model_name), "")

    encoding = tokenizer(
        text + eos,
        truncation=True,
        max_length=max_length,
        padding=False,
        add_special_tokens=True,
    )

    ids = encoding['input_ids']
    return tokenizer.convert_ids_to_tokens(ids), ids


def _read_jsonl_keyed_by_text(path):
    """Read a JSONL file into a dict mapping each record's 'text' to the record."""
    records = {}
    # The context manager closes the handle even if a line fails to parse.
    with open(path, 'r', encoding='utf-8') as handle:
        for line in handle:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            record = json.loads(line)
            # A later duplicate of the same text replaces the earlier record.
            records[record['text']] = record
    return records


def load_and_pair_data(vanilla_path, masked_path):
    """Load two per-token loss JSONL files and pair records with identical text.

    Each line of either file must be a JSON object with a ``'text'`` key and a
    ``'token_loss'`` sequence. Records are matched on exact ``'text'``
    equality; texts present in only one file are ignored. For every match the
    two loss sequences are converted to NumPy arrays and truncated to the
    shorter of the two lengths.

    Args:
        vanilla_path: Path to the JSONL file providing ``L_base``.
        masked_path: Path to the JSONL file providing ``L_ref``.

    Returns:
        A list of dicts with keys ``"text"``, ``"L_base"`` and ``"L_ref"``,
        in the order the texts first appear in the vanilla file.

    Raises:
        ValueError: If no text appears in both files.
    """
    print(f"Loading L_base data from '{vanilla_path}'...")
    vanilla_data = _read_jsonl_keyed_by_text(vanilla_path)
    print(f"Loading L_ref data from '{masked_path}'...")
    masked_data = _read_jsonl_keyed_by_text(masked_path)

    paired_data = []
    print("Pairing data...")
    for text, v_item in tqdm(vanilla_data.items(), desc="Pairing"):
        m_item = masked_data.get(text)
        if m_item is None:
            continue

        L_base = np.array(v_item['token_loss'])
        L_ref = np.array(m_item['token_loss'])

        # Keep only the prefix both sequences cover so they stay index-aligned.
        min_len = min(len(L_base), len(L_ref))
        paired_data.append({
            "text": text,
            "L_base": L_base[:min_len],
            "L_ref": L_ref[:min_len],
        })

    if not paired_data:
        raise ValueError("Unable to pair any data. Please check if input files match.")

    return paired_data


def _normalize_weights(raw_weights, N):
    """Rescale ``raw_weights`` so that they sum to ``N``.

    Args:
        raw_weights: NumPy array of unnormalized weights.
        N: Target sum; also the length of the all-ones fallback.

    Returns:
        ``raw_weights`` multiplied by ``N / sum(raw_weights)``, or an array of
        ``N`` ones when the weights sum to exactly zero.
    """
    total = np.sum(raw_weights)
    if total != 0:
        scale = N / total
        return raw_weights * scale
    # A zero sum cannot be rescaled; fall back to uniform weights.
    return np.ones(N)


def _resolve_eos_token(tokenizer):
    """Return the tokenizer's EOS string, guessing from the model name when unset."""
    eos_token = tokenizer.eos_token
    if eos_token is not None:
        return eos_token
    model_name = tokenizer.name_or_path.lower()
    if "qwen" in model_name:
        return "<|im_end|>"
    if "llama" in model_name:
        return "</s>"
    return ""


def _pad_losses(values, length):
    """Right-pad ``values`` with zeros to ``length``; longer inputs are returned unchanged."""
    if len(values) >= length:
        return values
    padded = np.zeros(length)
    padded[:len(values)] = values
    return padded


def _find_newline_token_indices(text, offset_mapping):
    """Return the set of token indices that are the first token covering some '\\n' in ``text``."""
    claimed_positions = set()
    newline_tokens = set()
    for token_idx, (start, end) in enumerate(offset_mapping):
        pos = text.find('\n', start, end)
        while pos != -1:
            # If several tokens report the same character span, only the
            # first one is treated as owning the newline.
            if pos not in claimed_positions:
                claimed_positions.add(pos)
                newline_tokens.add(token_idx)
            pos = text.find('\n', pos + 1, end)
    return newline_tokens


def compute_step_level_scores(item, tokenizer, max_length):
    """Score each newline-delimited step by its mean loss difference.

    The text (with an EOS marker appended) is tokenized once, requesting
    character offsets. Tokens covering a ``'\\n'`` character of the original
    text act as step separators: they keep a score of 0 and split the token
    sequence into steps. Every token of a step receives the step score
    ``mean(L_ref[step]) - mean(L_base[step])``.

    Loss arrays shorter than the token count are right-padded with zeros
    (each array independently); longer arrays are simply indexed by token
    position.

    Args:
        item: Dict with keys ``'text'``, ``'L_base'`` and ``'L_ref'``.
        tokenizer: Tokenizer that supports ``return_offsets_mapping=True``
            and exposes ``eos_token``, ``name_or_path`` and
            ``convert_ids_to_tokens``.
        max_length: Truncation limit forwarded to the tokenizer.

    Returns:
        A tuple ``(tokens, step_level_token_scores, step_boundaries_list)``
        where ``tokens`` is the token list, ``step_level_token_scores`` is a
        NumPy array with one score per token, and ``step_boundaries_list``
        holds ``(start, end, score)`` triples with ``end`` exclusive.
    """
    text = item['text']

    # One tokenizer call yields both the ids and the character offsets.
    tokenized_output = tokenizer(
        text + _resolve_eos_token(tokenizer),
        truncation=True,
        max_length=max_length,
        padding=False,
        add_special_tokens=True,
        return_offsets_mapping=True
    )
    tokens = tokenizer.convert_ids_to_tokens(tokenized_output['input_ids'])
    N = len(tokens)
    offset_mapping = tokenized_output['offset_mapping'][:N]

    L_base = _pad_losses(item['L_base'], N)
    L_ref = _pad_losses(item['L_ref'], N)

    # Separator tokens are never overwritten below, so they stay at 0.
    step_level_token_scores = np.zeros(N)
    step_boundaries_list = []

    # Newlines are searched in the original text, not in the appended EOS.
    newline_tokens = _find_newline_token_indices(text, offset_mapping)
    step_boundaries_idx = sorted({0, N} | newline_tokens)

    for start_token_idx, end_token_idx in zip(step_boundaries_idx, step_boundaries_idx[1:]):
        # A step that begins on a separator token excludes that token.
        actual_start = start_token_idx + 1 if start_token_idx in newline_tokens else start_token_idx
        actual_end = end_token_idx

        if actual_end <= actual_start:
            # Adjacent separators leave nothing in between.
            continue

        score_sub = np.mean(L_ref[actual_start:actual_end]) - np.mean(L_base[actual_start:actual_end])

        step_level_token_scores[actual_start:actual_end] = score_sub
        step_boundaries_list.append((actual_start, actual_end, score_sub))

    return tokens, step_level_token_scores, step_boundaries_list


def process_item_for_weights(item, tokenizer, max_length, percentile, low_weight, high_weight, do_normalize):
    """Turn one paired sample into a per-token weight list.

    Steps whose score is at or above the ``percentile``-th percentile of the
    sample's own step scores receive ``high_weight``; every other token
    (including step separators) receives ``low_weight``.

    Args:
        item: Dict with keys ``'text'``, ``'L_base'`` and ``'L_ref'``.
        tokenizer: Tokenizer forwarded to ``compute_step_level_scores``.
        max_length: Truncation limit forwarded to the tokenizer.
        percentile: Percentile (0-100) of step scores used as the threshold.
        low_weight: Weight for tokens outside the selected steps.
        high_weight: Weight for tokens inside the selected steps.
        do_normalize: If true, rescale the weights so they sum to the token
            count (mean 1).

    Returns:
        A ``(text, weights)`` tuple where ``weights`` is a list of floats
        with one entry per token. ``weights`` is empty when the sample has
        no tokens or when processing raised an exception.
    """
    try:
        tokens, _, step_boundaries = compute_step_level_scores(item, tokenizer, max_length)

        N = len(tokens)
        if N == 0:
            return (item['text'], [])

        # Force a float array: with an integer low_weight np.full would infer
        # an integer dtype and silently truncate a fractional high_weight.
        raw_weights = np.full(N, low_weight, dtype=float)

        step_scores = [score for _, _, score in step_boundaries]
        if not step_scores:
            # No scorable step (e.g. only separator tokens): uniform weights.
            final_weights = np.ones(N) if do_normalize else raw_weights
            return (item['text'], final_weights.tolist())

        threshold = np.percentile(step_scores, percentile)
        for start, end, score in step_boundaries:
            if score >= threshold:
                raw_weights[start:end] = high_weight

        final_weights = _normalize_weights(raw_weights, N) if do_normalize else raw_weights
        return (item['text'], final_weights.tolist())

    except Exception as e:
        # Deliberately broad: this runs inside a worker pool and a single bad
        # sample should be reported and skipped, not abort the whole run.
        print(f"Failed to process item (text: {item['text'][:50]}...): {e}")
        return (item['text'], [])


def main():
    """Command-line entry point: compute step-level token weights and save them.

    Parses the CLI arguments, loads the tokenizer, pairs the two loss files,
    computes per-token weights in a process pool, and writes one JSON object
    per sample (keys ``"text"`` and ``"token_weight"``) to
    ``<save_weights_dir>/<score_key>/<weight_key>/lo<low>_hi<high>_<norm|raw>/train.jsonl``.
    Samples for which no weights could be produced are left out of the file
    and reported in a summary line.
    """
    parser = argparse.ArgumentParser(description="[Step-Level] Flexible Weight Calculation Script")

    parser.add_argument('--vanilla_file', type=str, required=True, help='Vanilla loss file path (L_base)')
    parser.add_argument('--masked_file', type=str, required=True, help='Masked loss file path (L_ref)')
    parser.add_argument('--model_path', type=str, required=True, help='Model path (for tokenizer)')
    parser.add_argument('--max_length', type=int, default=32768, help='Maximum sequence length')
    parser.add_argument('--save_weights_dir', type=str, default="./step_weights_output", help='Output root directory')

    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' makes `--help` raise ValueError.
    parser.add_argument('--percentile', type=float, default=80.0, help='Top N%% Threshold (e.g., 80.0 for Top 20%%)')

    parser.add_argument('--low_weight', type=float, default=1.0, help='Weight for normal steps (default: 1.0)')

    parser.add_argument('--high_weight', type=float, default=2.0, help='Weight for hard steps (default: 2.0)')

    parser.add_argument('--normalize', action='store_true', help='[Flag] Normalize weights (mean=1). If not set, keeps raw weights.')

    parser.add_argument('--score_key', type=str, default="step_sub", help='Sub-directory name (score source)')
    parser.add_argument('--weight_key', type=str, default="threshold", help='Sub-directory name (weight method)')

    args = parser.parse_args()

    print(f"Loading Tokenizer from {args.model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # Match the model family case-insensitively, like the tokenizer helpers
    # in this file do, so a path such as ".../qwen2.5" is still recognized.
    model_path_lower = args.model_path.lower()

    if tokenizer.pad_token is None:
        if "llama" in model_path_lower:
            tokenizer.pad_token = "<|reserved_special_token_5|>"
        elif "qwen" in model_path_lower:
            tokenizer.pad_token = "<|fim_pad|>"
        else:
            tokenizer.pad_token = "<|fim_pad|>"

    if tokenizer.eos_token is None:
        if "llama" in model_path_lower:
            tokenizer.eos_token = "</s>"
        elif "qwen" in model_path_lower:
            tokenizer.eos_token = "<|im_end|>"
        else:
            tokenizer.eos_token = "</s>"

    print("\n--- Loading and Pairing Data ---")
    paired_data = load_and_pair_data(args.vanilla_file, args.masked_file)
    print(f"Successfully paired {len(paired_data)} samples")

    print("\n--- Calculating Weights in Parallel ---")
    print("Configuration:")
    print(f"  - Top {100.0 - args.percentile:.0f}% step weight: {args.high_weight}")
    print(f"  - Other step weight: {args.low_weight}")
    print(f"  - Normalization (Mean=1): {'Enabled' if args.normalize else 'Disabled'}")

    process_func = partial(process_item_for_weights,
                           tokenizer=tokenizer,
                           max_length=args.max_length,
                           percentile=args.percentile,
                           low_weight=args.low_weight,
                           high_weight=args.high_weight,
                           do_normalize=args.normalize)

    # Leave two cores free, but always use at least one worker.
    num_workers = max(1, cpu_count() - 2)

    all_results = []
    with Pool(num_workers) as pool:
        with tqdm(total=len(paired_data), desc="Calculating Weights") as pbar:
            # imap yields results in input order.
            for result in pool.imap(process_func, paired_data):
                # An empty weight list marks a failed or empty sample.
                if result[1]:
                    all_results.append(result)
                pbar.update(1)

    skipped = len(paired_data) - len(all_results)
    if skipped:
        print(f"Warning: {skipped} of {len(paired_data)} samples produced no weights and were omitted from the output")

    norm_str = "norm" if args.normalize else "raw"
    weight_info = f"lo{args.low_weight}_hi{args.high_weight}_{norm_str}"

    output_folder = os.path.join(args.save_weights_dir, args.score_key, args.weight_key, weight_info)
    os.makedirs(output_folder, exist_ok=True)
    output_path = os.path.join(output_folder, "train.jsonl")

    print(f"\nSaving to: {output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        for text, weights_list in tqdm(all_results, desc="Saving JSONL"):
            jsonl_entry = {
                "text": text,
                "token_weight": weights_list
            }
            f.write(json.dumps(jsonl_entry, ensure_ascii=False) + "\n")

    print("\n--- Completed! ---")

# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()