import argparse
import json
import os
import re

import torch
from datasets import load_dataset
from peft import PeftConfig, PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaTokenizer,
)
from trl import PPOConfig, AutoModelForCausalLMWithValueHead

from abcrl.datasets import build_anthropic_dataset, collator

# Default base-model identifier. Used as the default for `base_model_name` in
# `main_inference` and for the `--base_model_name` command-line flag.
BASE_MODEL_NAME = "VMware/open-llama-7b-open-instruct"


def build_anthropic_dataset(
    config: PPOConfig,
    max_length: int = 256,
    split: str = "test"
):
    """
    Build a single-turn prompt dataset from ``Anthropic/hh-rlhf``.

    Only dialogues whose ``"chosen"`` text contains exactly one ``"Human:"``
    turn followed by an ``"Assistant:"`` marker are kept. The human text is
    wrapped in an instruction template and tokenized, and prompts whose token
    count is not below ``max_length`` are dropped.

    Each returned sample carries these extra columns:
        * ``input_ids`` / ``query``: the templated prompt, tokenized and
          decoded back to text.
        * ``rm_input_ids`` / ``rm_query``: the same instruction in the
          ``###Human: ... ###Assistant: `` layout, tokenized and decoded.

    NOTE: this module-level definition shadows the function of the same name
    imported from ``abcrl.datasets`` at the top of the file.

    Args:
        config: The configuration object for a PPOTrainer; only
            ``config.model_name`` is read (to load the tokenizer).
        max_length (int, optional): Exclusive upper bound on the number of
            prompt tokens. Defaults to 256.
        split (str, optional): Dataset split passed to ``load_dataset``.
            Defaults to "test".

    Returns:
        The filtered dataset, with its format set to ``"torch"``.

    Raises:
        Exception: If the tokenizer fails to load with both ``LlamaTokenizer``
            and ``AutoTokenizer``.

    Example:
        >>> config = PPOConfig(model_name="VMware/open-llama-7b-open-instruct")
        >>> dataset = build_anthropic_dataset(config, max_length=512)
    """
    # NOTE(review): "force_redownload" re-fetches the dataset on every call.
    # Kept as-is because it may be a deliberate cache workaround -- confirm.
    ds = load_dataset("Anthropic/hh-rlhf", download_mode="force_redownload", split=split)

    # Captures the text between the "Human:" marker and the following
    # "Assistant:" marker; DOTALL lets the instruction span several lines.
    turn_pattern = re.compile(r"\s*Human:\s*(.*?)\s*Assistant:\s*", re.DOTALL)

    def is_single_turn(sample):
        text = sample["chosen"]
        # Requiring a regex match here guarantees `tokenize` never receives a
        # dialogue without an "Assistant:" marker (which used to crash on
        # `None.group`).
        return text.count("Human:") == 1 and turn_pattern.search(text) is not None

    ds = ds.filter(is_single_turn, batched=False)

    try:
        tokenizer = LlamaTokenizer.from_pretrained(config.model_name, use_fast=False)
    except Exception:
        # Fall back to the generic loader; if this also fails the error
        # propagates to the caller.
        tokenizer = AutoTokenizer.from_pretrained(config.model_name, use_fast=False)

    def tokenize(sample):
        instruction = turn_pattern.search(sample["chosen"]).group(1).strip()
        prompt = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:"
        sample["input_ids"] = tokenizer.encode(prompt)
        sample["query"] = tokenizer.decode(sample["input_ids"])
        sample["rm_input_ids"] = tokenizer.encode(
            f"###Human: {instruction} ###Assistant: "
        )
        sample["rm_query"] = tokenizer.decode(sample["rm_input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds = ds.filter(lambda x: len(x["input_ids"]) < max_length, batched=False)
    ds.set_format(type="torch")
    return ds


def _strip_leading_bos(prompt, bos_token):
    """Drop a decoded BOS marker (and whitespace right after it) from the start of `prompt`."""
    if bos_token and prompt.startswith(bos_token):
        return prompt[len(bos_token):].lstrip()
    return prompt


def _load_model_with_adapter(adapter_checkpoint_path, base_model_name, quantization_config):
    """Load the quantized base model, attach the saved adapter weights, and wrap it for TRL."""
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )
    # Passing `peft_config=` to the TRL wrapper together with a base-model name
    # initialises a *new* adapter; the trained weights have to be loaded
    # explicitly from the checkpoint directory.
    peft_model = PeftModel.from_pretrained(base_model, adapter_checkpoint_path)
    return AutoModelForCausalLMWithValueHead.from_pretrained(peft_model)


def main_inference(
    adapter_checkpoint_path: str,
    base_model_name: str = BASE_MODEL_NAME,
    max_instruction_length: int = 256,
    min_generation: int = 8,
    max_generation: int = 128,
    repetition_penalty: float = 1.0,
    num_test_samples: int = 5,
    batch_size: int = 1,
    output_file: str = "inference_outputs.json",
    use_greedy_decoding: bool = False,
):
    """
    Generate responses with a 4-bit base model plus a saved PEFT adapter.

    Prompts come from `build_anthropic_dataset` (split "test"). For every
    prompt the generated continuation is printed and collected, and all
    results are written to `output_file` as a JSON list of
    ``{"instruction", "output", "generator"}`` records.

    Args:
        adapter_checkpoint_path: Directory holding the adapter (including
            ``adapter_config.json``) and the tokenizer files.
        base_model_name: Base model the adapter is applied to.
        max_instruction_length: Maximum prompt length in tokens; used both to
            filter the dataset and to truncate at tokenization time.
        min_generation: Minimum number of new tokens to generate.
        max_generation: Maximum number of new tokens to generate.
        repetition_penalty: Repetition penalty passed to ``generate``.
        num_test_samples: Number of prompts (taken from the start of the
            dataset) to run.
        batch_size: Number of prompts per ``generate`` call; must be >= 1.
        output_file: Path of the JSON file the results are written to.
        use_greedy_decoding: Greedy decoding when True, otherwise sampling.

    Returns:
        None. Returns early (after printing an error) if the adapter config
        cannot be loaded.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    # Fail fast, before spending time on loading the model.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    print(f"Loading adapter from: {adapter_checkpoint_path}")
    print(f"Using base model: {base_model_name}")
    decoding_strategy = "greedy" if use_greedy_decoding else "sampling"
    print(f"Using decoding strategy: {decoding_strategy}")
    print(f"Output will be saved to: {output_file}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load Tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(adapter_checkpoint_path, use_fast=False)
    # Left padding keeps every prompt flush against its generated tokens.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    print("Tokenizer loaded.")

    # 2. Configure Quantization
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    print("BitsAndBytesConfig configured.")

    # 3. Validate the adapter directory before loading any weights
    try:
        peft_config = PeftConfig.from_pretrained(adapter_checkpoint_path)
        print("PeftConfig loaded.")
    except Exception as e:
        print(f"Error loading PeftConfig from {adapter_checkpoint_path}: {e}")
        print("Ensure that 'adapter_config.json' exists in this directory.")
        return

    adapter_base = getattr(peft_config, "base_model_name_or_path", None)
    if adapter_base and adapter_base != base_model_name:
        print(
            f"Warning: adapter config names base model '{adapter_base}', "
            f"but '{base_model_name}' was requested."
        )

    # 4. Load the base model and attach the trained adapter
    model = _load_model_with_adapter(adapter_checkpoint_path, base_model_name, nf4_config)
    print("Model with PEFT adapter loaded using AutoModelForCausalLMWithValueHead.")

    # If not using device_map="auto", explicitly move to device
    if not hasattr(model, 'hf_device_map') or not model.hf_device_map:
        model.to(device)

    model.eval()
    print("Model set to evaluation mode.")

    # 5. Prepare Test Data
    print("Loading test prompts...")
    # Only `model_name` is read by build_anthropic_dataset; the other values
    # are dummies. mini_batch_size=1 keeps the config's batch-size
    # divisibility check satisfied regardless of the library default.
    dummy_config = PPOConfig(
        model_name=base_model_name,
        batch_size=1,
        mini_batch_size=1,
        ppo_epochs=1,
    )
    anthropic_test_dataset = build_anthropic_dataset(dummy_config, max_instruction_length, split="test")

    test_prompts_all = [item['query'] for item in anthropic_test_dataset]
    if not test_prompts_all:
        print("Warning: No prompts loaded from dataset. Using placeholder prompts.")
        test_prompts_all = [
            "Human: Explain the concept of black holes to a 5-year-old.\n\nAssistant:",
            "Human: Write a haiku about a rainy day.\n\nAssistant:",
        ]

    test_prompts = test_prompts_all[:num_test_samples]
    print(f"Loaded {len(test_prompts)} test prompts.")

    # 6. Define Generation Parameters
    generation_kwargs = {
        "min_length": -1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_generation,
        "min_new_tokens": min_generation,
    }

    if use_greedy_decoding:
        generation_kwargs["do_sample"] = False
    else:
        # Plain sampling from the full distribution: top-k and nucleus
        # filtering are both disabled.
        generation_kwargs["do_sample"] = True
        generation_kwargs["top_k"] = 0
        generation_kwargs["top_p"] = 1.0

    print(f"Generation kwargs: {generation_kwargs}")

    # 7. Perform Inference
    results = []
    print(f"\nStarting inference on {len(test_prompts)} prompts (batch size {batch_size})...")

    # normpath drops a trailing separator so basename is never empty.
    generator_name = os.path.basename(os.path.normpath(adapter_checkpoint_path))
    bos_token = tokenizer.bos_token or "<s>"

    for start in tqdm(range(0, len(test_prompts), batch_size)):
        # Dataset queries are decoded token ids and may begin with the BOS
        # text; remove it so the tokenizer does not produce a doubled BOS.
        batch_prompts = [
            _strip_leading_bos(prompt, bos_token)
            for prompt in test_prompts[start:start + batch_size]
        ]

        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_instruction_length
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        # With left padding every row's prompt occupies the first
        # `prompt_length` positions, so the remainder is the continuation.
        # Slicing token ids avoids relying on decode(encode(text)) == text.
        prompt_length = inputs["input_ids"].shape[1]
        generated_texts = tokenizer.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=True
        )

        for prompt, generated_text in zip(batch_prompts, generated_texts):
            print("-" * 50)
            print(f"Prompt:\n{prompt}")
            print(f"Generated Response:\n{generated_text}")

            results.append({
                "instruction": prompt,
                "output": generated_text,
                "generator": f"{generator_name}_{decoding_strategy}",
            })

    # 8. Save results to a JSON file
    try:
        with open(output_file, "w", encoding="utf-8") as f_out:
            json.dump(results, f_out, indent=4, ensure_ascii=False)
        print(f"\nInference complete. {len(results)} outputs saved to {output_file}")
    except Exception as e:
        # Best effort: report the failure instead of crashing after a long run.
        print(f"Error saving results to JSON: {e}")


# Command-line entry point: parse the flags and forward them unchanged to
# main_inference.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load a QLoRA fine-tuned model and perform inference.")
    parser.add_argument(
        "--adapter_checkpoint_path",
        type=str,
        required=True,
        help="Path to the saved adapter checkpoint directory (e.g., 'experiments/saved_models/your_run_name')."
    )
    parser.add_argument(
        "--base_model_name",
        type=str,
        default=BASE_MODEL_NAME,
        help="Name of the base model (e.g., 'VMware/open-llama-7b-open-instruct')."
    )
    parser.add_argument(
        "--max_instruction_length",
        type=int,
        default=256,
        help="Maximum length of the input instruction/prompt."
    )
    parser.add_argument(
        "--min_generation",
        type=int,
        default=8,
        help="Minimum number of new tokens to generate."
    )
    parser.add_argument(
        "--max_generation",
        type=int,
        default=256,
        help="Maximum number of new tokens to generate."
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Repetition penalty for generation."
    )
    parser.add_argument(
        "--num_test_samples",
        type=int,
        default=10,
        help="Number of test samples to run inference on."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2,
        help="Batch size for inference."
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="inference_outputs.json",
        # The path is used exactly as given; no decoding-strategy suffix is added.
        help="Path of the output JSON file."
    )
    parser.add_argument(
        "--use_greedy_decoding",
        action="store_true",
        help="Use greedy decoding. If not set, sampling will be used."
    )

    args = parser.parse_args()

    main_inference(
        adapter_checkpoint_path=args.adapter_checkpoint_path,
        base_model_name=args.base_model_name,
        max_instruction_length=args.max_instruction_length,
        min_generation=args.min_generation,
        max_generation=args.max_generation,
        repetition_penalty=args.repetition_penalty,
        num_test_samples=args.num_test_samples,
        batch_size=args.batch_size,
        output_file=args.output_file,
        use_greedy_decoding=args.use_greedy_decoding,
    )