import os
import json
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Optional

import torch
import tyro
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
)

# Helper functions adapted from the training script for use at inference time:
# first_true_indices, truncate_response, and generate_sequences (adapted from the training `generate`).

def first_true_indices(bools, dtype=torch.long):
    """Locate the first ``True`` along the last dimension of a boolean tensor.

    The result has one fewer dimension than ``bools``. Each entry is the index of
    the first ``True`` in the corresponding row. A row containing no ``True``
    yields the row length, ``bools.size(-1)``.
    """
    width = bools.size(-1)
    positions = torch.arange(width, dtype=dtype, device=bools.device)
    # Every False entry is pushed past the end of the row by adding ``width``.
    # min() therefore lands on the earliest True. An all-False row bottoms out
    # at 0 + width == width.
    penalty = width * (~bools).type(dtype)
    return (penalty + positions).min(dim=-1).values


def truncate_response(truncate_config_ns, tokenizer, responses):
    """Pad out everything that follows the first truncate token in each row.

    Args:
        truncate_config_ns: Any object exposing a ``truncate_token_id`` attribute
            (for example a ``SimpleNamespace``).
        tokenizer: Only ``pad_token_id`` is read.
        responses: Token ids as a tensor, or anything ``torch.tensor`` accepts
            (in which case it is converted on the CPU). Expected to be 2-D
            ``(batch, seq_len)``; the code uses ``shape[1]`` as the sequence length.

    Returns:
        A tensor shaped like ``responses``. Each row keeps its tokens up to and
        including the first ``truncate_token_id``, and every later position holds
        ``pad_token_id``. A row with no truncate token comes back unchanged, and a
        row that starts with it keeps only that first token.
    """
    if not isinstance(responses, torch.Tensor):
        responses = torch.tensor(responses, device='cpu')
    where = responses.device

    stop_id = torch.tensor(truncate_config_ns.truncate_token_id, device=where)
    # Column of the first stop token per row (seq_len when absent). The trailing
    # singleton dimension lets it broadcast against the position index below.
    cut_at = first_true_indices(responses == stop_id).unsqueeze(-1)

    seq_len = responses.shape[1]
    view_shape = [1] * (responses.dim() - 1) + [seq_len]
    positions = torch.arange(seq_len, device=where).view(*view_shape)

    # Strictly after the stop token -> pad. The stop token itself survives.
    past_stop = positions > cut_at
    pad = torch.tensor(tokenizer.pad_token_id, device=where)
    return torch.where(past_stop, pad, responses)


@torch.no_grad()
def generate_sequences(
    model: PreTrainedModel,
    tokenized_queries: torch.Tensor,
    tokenizer: AutoTokenizer,
    generation_config: GenerationConfig,
) -> torch.Tensor:
    """Generate continuations for a batch of tokenized queries and return only the new tokens.

    Args:
        model: Causal LM whose ``generate`` returns the prompt ids followed by
            the generated ids.
        tokenized_queries: ``(batch, query_len)`` token ids padded with
            ``tokenizer.pad_token_id``. Padding should be on the left so that every
            prompt ends at the last column, because batched decoder-only generation
            continues from the final position.
        tokenizer: Only ``pad_token_id`` is read, to build the attention mask.
        generation_config: Passed straight through to ``model.generate``.

    Returns:
        ``(batch, new_len)`` tensor holding only the generated token ids. The
        ``query_len`` prompt columns are stripped.
    """
    context_length = tokenized_queries.shape[1]
    attention_mask = tokenized_queries != tokenizer.pad_token_id
    # Pad positions are excluded by the attention mask, so their ids are
    # irrelevant. Replacing them with 0 prevents an out-of-range embedding lookup
    # when the pad id lies outside the model's embedding table (e.g. a '[PAD]'
    # added to the tokenizer without resizing the model).
    input_ids = torch.masked_fill(tokenized_queries, ~attention_mask, 0)

    full_sequences = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
    )

    # generate() echoes the prompt before the new tokens. The prompt width is
    # the same for every row, so a single column slice removes it.
    return full_sequences[:, context_length:]


@dataclass
class InferenceArgs:
    """Command-line options for summary generation (parsed with ``tyro.cli`` in ``main``)."""

    model_path: str = "models/ppo_model"
    """Path to the saved PPO model directory (containing policy model and tokenizer)."""
    query_dataset_name: str = "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144"
    """Name of the dataset on Hugging Face Hub."""
    dataset_split: str = "test"
    """Dataset split to use for inference (e.g., 'test', 'validation')."""
    output_file: str = "generated_summaries.json"
    """File path to save the generated summaries."""

    # Generation parameters (mirroring training defaults where appropriate)
    response_length: int = 53
    """The desired max length of the response (max_new_tokens)."""
    temperature: float = 0.01  # Low temperature for near-deterministic output, as in validation
    """Sampling temperature for generation."""
    # Must be an int: Hugging Face's top-k warper rejects float values such as 50.0.
    top_k: int = 0
    """Top-k sampling; 0 disables top-k filtering."""
    top_p: float = 1.0
    """Top-p (nucleus) sampling."""
    do_sample: bool = True  # Consistent with validation_generation_config
    """Whether to use sampling; must be True for temperature to have an effect."""
    truncate_token: str = "eos"
    """Token to truncate responses at ('eos' or a specific token string)."""

    # Batching and device
    batch_size: int = 8
    """Batch size for inference."""
    cuda: bool = True
    """Whether to use CUDA if available."""
    # Optional because main() skips seeding when the value is None.
    seed: Optional[int] = 42
    """Random seed for reproducibility of sampling (if do_sample=True); None skips seeding."""


def main():
    """Generate summaries for a dataset split with a trained policy and save them as JSON.

    The function performs these steps:
      1. Parse ``InferenceArgs`` from the command line.
      2. Load the tokenizer and causal LM from ``args.model_path``.
      3. Batch the split's ``query_token`` column and generate responses.
      4. Truncate each response at the truncate token and decode it.
      5. Write a list of ``{"document": <post>, "summary": <text>}`` records to
         ``args.output_file``.

    Raises:
        ValueError: If the truncate token cannot be resolved to a token id.
    """
    args = tyro.cli(InferenceArgs)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load tokenizer and model.
    # The tokenizer's saved padding_side does not matter here, because
    # `collate_fn` below left-pads the pre-tokenized queries itself.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # Ensure a pad token is set, consistent with training.
    if tokenizer.pad_token is None:
        print("Tokenizer does not have a pad token. Adding '[PAD]' as pad token.")
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    model = AutoModelForCausalLM.from_pretrained(args.model_path).to(device)
    # Hugging Face `generate` feeds pad_token_id back into the model for rows
    # that have already finished, so every tokenizer id needs an embedding row.
    # The table grows only when the tokenizer has outgrown it, e.g. a '[PAD]'
    # added above or during training without a resize. Otherwise the lookup
    # would be out of range.
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))
    model.eval()

    # Token id used by `truncate_response`.
    truncate_config_ns = SimpleNamespace()
    if args.truncate_token == "eos":
        if tokenizer.eos_token_id is None:
            raise ValueError("truncate_token is 'eos' but tokenizer has no eos_token_id.")
        truncate_config_ns.truncate_token_id = tokenizer.eos_token_id
    else:
        truncate_config_ns.truncate_token_id = tokenizer.convert_tokens_to_ids(args.truncate_token)
        # convert_tokens_to_ids can return None (e.g. when there is no unk
        # token). Fail here with a clear message instead of later inside
        # truncate_response.
        if truncate_config_ns.truncate_token_id is None:
            raise ValueError(f"truncate_token '{args.truncate_token}' could not be converted to a token id.")
        if truncate_config_ns.truncate_token_id == tokenizer.unk_token_id:
            print(f"Warning: truncate_token '{args.truncate_token}' resolved to UNK token.")

    print(f"Truncate token ID: {truncate_config_ns.truncate_token_id}")

    # 2. Load dataset.
    # Each row must provide 'query_token' (pre-tokenized prompt ids) and
    # 'post' (the original document text). These are the two columns read below.
    dataset = load_dataset(args.query_dataset_name, split=args.dataset_split)
    pad_token_id = tokenizer.pad_token_id

    def collate_fn(batch):
        """Left-pad the batch's 'query_token' lists into one LongTensor and collect the 'post' texts."""
        query_tensors = [torch.tensor(item["query_token"], dtype=torch.long) for item in batch]
        max_len = max(query.size(0) for query in query_tensors)
        padded = torch.full((len(query_tensors), max_len), pad_token_id, dtype=torch.long)
        for row, query in enumerate(query_tensors):
            # Left padding keeps every prompt flush with the last column, which
            # batched decoder-only generation continues from. With right padding,
            # shorter prompts would be continued after their pad tokens.
            padded[row, max_len - query.size(0):] = query
        return {"query_token": padded, "post_text": [item["post"] for item in batch]}

    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

    # 3. Set up generation.
    # min_new_tokens is deliberately not set. With min_new_tokens ==
    # max_new_tokens and eos_token_id set, Hugging Face masks EOS on every step:
    # the model could never end a summary, and the EOS truncation below would
    # never fire. Without it, EOS can be sampled, generation stops there, and
    # truncate_response keeps everything up to and including it.
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),  # Epsilon as in training; avoids temperature == 0
        top_k=int(args.top_k),  # The top-k warper rejects non-int values such as 0.0 or 50.0
        top_p=args.top_p,
        do_sample=args.do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    print("Generation Config:")
    print(generation_config)

    # 4. Generation loop.
    results = []
    print(f"Starting generation for {len(dataset)} samples...")
    for batch_data in tqdm(dataloader, desc="Generating Summaries"):
        queries_tensor = batch_data["query_token"].to(device)
        original_documents = batch_data["post_text"]

        # Only the new tokens, shape (batch_size, <= response_length). The
        # result is shorter than response_length when every row stops at EOS early.
        generated_responses_ids = generate_sequences(model, queries_tensor, tokenizer, generation_config)

        # Replace everything after the first truncate token with padding.
        postprocessed_responses_ids = truncate_response(truncate_config_ns, tokenizer, generated_responses_ids)

        decoded_summaries = tokenizer.batch_decode(postprocessed_responses_ids, skip_special_tokens=True)

        results.extend(
            {"document": doc_text, "summary": summary_text.strip()}
            for doc_text, summary_text in zip(original_documents, decoded_summaries)
        )

    # 5. Save results.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"Saved {len(results)} generated summaries to {args.output_file}")
    if results:
        print("\nSample output:")
        print(json.dumps(results[0], indent=2, ensure_ascii=False))


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()