import argparse
import numpy as np
from itertools import islice
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import requests
from scipy.special import logsumexp 

# Assuming 'utils' contains the necessary helper functions
from utils import *
def _topk_truncated_entropy(top_logprobs):
    """Return the entropy of each row of ``top_logprobs`` after renormalizing it.

    Each row's candidates are treated as a closed distribution: the log-probs are
    shifted by the row's ``logsumexp`` (numerically stable normalization) so the
    row sums to 1, then ``-sum(p * log p)`` is taken along the last axis.
    """
    log_norm = top_logprobs - logsumexp(top_logprobs, axis=-1, keepdims=True)
    return -np.sum(np.exp(log_norm) * log_norm, axis=-1)


def _extract_span_info(input_token_logprobs_all, input_top_logprobs_all, spans, top_k):
    """Gather token and top-k candidate data for the ``[start, end)`` spans and score them.

    Returns:
        tuple ``(entropy, token_index, token_logprobs, top_index, top_logprobs)``.
        All five are empty 1-D arrays when the spans select no tokens.
    """
    token_info = []
    top_info = []
    for start, end in spans:
        token_info.extend(input_token_logprobs_all[start:end])
        top_info.extend(input_top_logprobs_all[start:end])

    if not token_info:
        return tuple(np.array([]) for _ in range(5))

    # Entries are read as item[0] = logprob, item[1] = token id.
    token_logprobs = np.array([item[0] for item in token_info])
    token_index = np.array([item[1] for item in token_info])
    top_index = np.array([[item[1] for item in cands[:top_k]] for cands in top_info])
    top_logprobs = np.array([[item[0] for item in cands[:top_k]] for cands in top_info])

    entropy = _topk_truncated_entropy(top_logprobs)
    return entropy, token_index, token_logprobs, top_index, top_logprobs


def message2api_topk_entropy(url, data, user_indices, llm_indices, tokenizer, top_k=30):
    """
    Query the API and compute Top-K truncated entropy for user and LLM token spans.

    ``data`` is POSTed as JSON to ``url``. The response's ``meta_info`` must contain
    ``input_token_logprobs`` (one entry per input token, read as
    ``(logprob, token_id, ...)``) and ``input_top_logprobs`` (per position, a list of
    candidate entries in the same layout). For each part, the entropy is computed over
    the first ``actual_top_k`` candidates after renormalizing them with ``logsumexp``.

    Args:
        url: API endpoint URL.
        data: JSON request payload.
        user_indices: Half-open spans ``[(start, end), ...]`` of user tokens.
        llm_indices: Half-open spans ``[(start, end), ...]`` of LLM tokens.
        tokenizer: Tokenizer object (kept for interface compatibility; not used here).
        top_k: Number of top candidates kept per position and used for entropy.
            Reduced, with a warning, when the API returned fewer candidates.

    Returns:
        dict with keys ``user_entropy``, ``llm_entropy``, ``user_token_index``,
        ``llm_token_index``, ``user_token_logprobs``, ``llm_token_logprobs``,
        ``user_top_index``, ``llm_top_index``, ``user_top_logprobs`` and
        ``llm_top_logprobs``. A part whose spans select no tokens gets empty arrays.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        ValueError: If the response does not contain ``meta_info``.
    """
    raw_response = requests.post(url, json=data)
    # Fail with the HTTP error itself rather than later on a non-JSON / partial body.
    raw_response.raise_for_status()
    response = raw_response.json()

    if 'meta_info' not in response:
        raise ValueError("Response does not contain 'meta_info'")
    meta_info = response['meta_info']

    # Actual logprob and token id for each input token.
    input_token_logprobs_all = meta_info['input_token_logprobs']
    # Top-k candidate list for each input token.
    input_top_logprobs_all = meta_info['input_top_logprobs']

    # The API may return fewer candidates than requested. Use the smallest non-empty
    # candidate list so every row has the same width. None/empty positions are skipped
    # (the previous code measured position 1 but read position 0, which may be empty).
    returned_widths = [len(cands) for cands in input_top_logprobs_all if cands]
    actual_top_k = top_k
    if returned_widths and min(returned_widths) < top_k:
        actual_top_k = min(returned_widths)
        print(f"Warning: API returned top-{actual_top_k}, but requested top-{top_k}")

    (user_entropy, user_token_index, user_token_logprobs,
     user_top_index, user_top_logprobs) = _extract_span_info(
        input_token_logprobs_all, input_top_logprobs_all, user_indices, actual_top_k
    )
    (llm_entropy, llm_token_index, llm_token_logprobs,
     llm_top_index, llm_top_logprobs) = _extract_span_info(
        input_token_logprobs_all, input_top_logprobs_all, llm_indices, actual_top_k
    )

    return {
        'user_entropy': user_entropy,
        'llm_entropy': llm_entropy,
        'user_token_index': user_token_index,
        'llm_token_index': llm_token_index,
        'user_token_logprobs': user_token_logprobs,
        'llm_token_logprobs': llm_token_logprobs,
        'user_top_index': user_top_index,
        'llm_top_index': llm_top_index,
        'user_top_logprobs': user_top_logprobs,
        'llm_top_logprobs': llm_top_logprobs
    }

def main(url, path_data, save_path, model_path, topk_for_entropy=100, topk_for_output=10,
         max_tokens=8100, min_rounds=1, max_texts=500):
    """Score up to ``max_texts`` conversations and save per-token entropy data to ``.npz``.

    For each record of the JSONL at ``path_data``, the conversation in
    ``line['messages']`` is split into fragments by
    ``extract_dynamic_fragments_with_precise_indices``. Each fragment is sent to
    ``url``, and its user/LLM token spans are scored with ``message2api_topk_entropy``.
    The per-fragment arrays are concatenated per text and written to ``save_path``
    under the keys ``text_{i}_{user|llm}_{field}``.

    Args:
        url: API endpoint URL.
        path_data: Path of the JSONL input, read with ``read_jsonl``.
        save_path: Output path passed to ``np.savez``.
        model_path: Path passed to ``AutoTokenizer.from_pretrained``.
        topk_for_entropy: Number of leading candidates, in the order the API returned
            them, used for the truncated entropy. Values >= the returned width use all
            candidates.
        topk_for_output: Number of top logprobs requested from the API and saved.
        max_tokens: Maximum fragment length forwarded to the fragmenter.
        min_rounds: Minimum conversation rounds forwarded to the fragmenter.
        max_texts: Maximum number of input records to process (formerly hard-coded 500).

    Raises:
        ValueError: If ``topk_for_entropy`` is smaller than 1.
    """
    if topk_for_entropy < 1:
        raise ValueError(f"topk_for_entropy must be >= 1, got {topk_for_entropy}")

    print(f"Max tokens set to: {max_tokens}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    parts = ("user", "llm")
    # Saved field name -> suffix of the key returned by message2api_topk_entropy.
    # The insertion order here fixes the key order of the saved .npz.
    field_sources = {
        "entropy": "entropy",
        "token_index": "token_index",
        "token_logprobs": "token_logprobs",
        "topk_index": "top_index",
        "topk_logprobs": "top_logprobs",
    }
    no_token_messages = {
        "user": "\nNo user tokens collected",
        "llm": "No LLM tokens collected",
    }

    all_results = []

    # Stream the input once. The former islice(read_jsonl(...), i, None) per index
    # re-read the file from the start on every iteration (quadratic I/O).
    for i, line in enumerate(islice(read_jsonl(path_data), max_texts)):
        text = line['messages']
        fragments = extract_dynamic_fragments_with_precise_indices(
            text, tokenizer, max_tokens=max_tokens, min_rounds=min_rounds
        )
        print(f"Processing Text {i}: {len(fragments)} fragments")

        # Per part and field: the list of fragment arrays, concatenated after the loop.
        collected = {part: {field: [] for field in field_sources} for part in parts}

        for fragment in fragments:
            user_indices = fragment['user_token_indices']
            llm_indices = fragment['llm_token_indices']
            merged_content = fragment['merged_content']

            # Prepare API payload
            data = {
                "model": "model",
                "stream": False,
                "text": merged_content,
                "sampling_params": {
                    "temperature": 0,
                    "top_p": 1,
                    "top_k": 1,
                    "max_new_tokens": 1
                },
                "return_logprob": True,
                "top_logprobs_num": topk_for_output,
                "logprob_start_len": 0,
                "return_text_in_logprobs": True,
            }

            results = message2api_topk_entropy(
                url, data, user_indices, llm_indices, tokenizer, top_k=topk_for_output
            )

            for part in parts:
                # Fragments without tokens for this part contribute nothing.
                if len(results[f"{part}_entropy"]) == 0:
                    continue
                fragment_arrays = {
                    field: results[f"{part}_{source}"]
                    for field, source in field_sources.items()
                }
                top_logprobs = fragment_arrays["topk_logprobs"]
                if topk_for_entropy < top_logprobs.shape[-1]:
                    # Honor topk_for_entropy, which was previously ignored: recompute the
                    # truncated entropy from only the first topk_for_entropy candidates
                    # (API order, presumably highest-probability first -- confirm).
                    head = top_logprobs[:, :topk_for_entropy]
                    log_norm = head - logsumexp(head, axis=-1, keepdims=True)
                    fragment_arrays["entropy"] = -np.sum(np.exp(log_norm) * log_norm, axis=-1)
                for field, array in fragment_arrays.items():
                    collected[part][field].append(array)

        text_result = {}
        for part in parts:
            chunks = collected[part]
            if chunks["entropy"]:
                text_result[part] = {
                    field: np.concatenate(arrays, axis=0) for field, arrays in chunks.items()
                }
            else:
                text_result[part] = {
                    "entropy": np.array([]),
                    "token_index": np.array([]),
                    "token_logprobs": np.array([]),
                    # Keep candidate arrays 2-D so their rank matches the non-empty case.
                    "topk_index": np.array([]).reshape(0, topk_for_output),
                    "topk_logprobs": np.array([]).reshape(0, topk_for_output),
                }
                print(no_token_messages[part])
        all_results.append(text_result)

    # Flatten into text_{i}_{part}_{field} keys for np.savez.
    save_dict = {}
    for i, result in enumerate(all_results):
        for part in parts:
            for field, array in result[part].items():
                save_dict[f"text_{i}_{part}_{field}"] = array

    np.savez(save_path, **save_dict)
    print(f"\n{'='*50}")
    print(f"Processing complete. Saved {len(all_results)} texts to {save_path}")
    print(f"{'='*50}")

def _build_arg_parser():
    """Create the command-line parser for this script.

    The four path/URL flags are required strings; the remaining flags are integers
    with defaults. Flags are registered in a fixed order, which is also the order
    shown in ``--help``.
    """
    cli = argparse.ArgumentParser(description="Calculate Top-K Truncated Entropy from LLM API.")

    required_string_flags = (
        ("--url", "API endpoint URL"),
        ("--data_path", "Path to input data file"),
        ("--save_path", "Path to save the output .npz file"),
        ("--model_path", "Path to the model/tokenizer"),
    )
    for flag, help_text in required_string_flags:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    optional_int_flags = (
        ("--topk_for_entropy", 10, "Top-K value for entropy calculation"),
        ("--topk_for_output", 30, "Number of top logprobs to request from API"),
        ("--max_tokens", 32700, "Maximum token length for fragmentation"),
        ("--min_rounds", 1, "Minimum conversation rounds"),
    )
    for flag, default_value, help_text in optional_int_flags:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    return cli


if __name__ == "__main__":
    opts = _build_arg_parser().parse_args()

    # The CLI flag --data_path maps onto main's path_data parameter.
    main(
        url=opts.url,
        path_data=opts.data_path,
        save_path=opts.save_path,
        model_path=opts.model_path,
        topk_for_entropy=opts.topk_for_entropy,
        topk_for_output=opts.topk_for_output,
        max_tokens=opts.max_tokens,
        min_rounds=opts.min_rounds,
    )