"""
Step0 generates LLM responses for the given queries. It serves as the reference reasoning path.

Inputs:
- Query huggingface dataset, with columns: id and question.
    - This dataset can be generated by init_dataset_conversion.py.

Outputs:
- A csv file with the model responses.
- A huggingface dataset with the model responses.
    - The dataset and csv share the same content in different formats (the csv additionally has a "timestamp" column). Both contain the columns: "id", "input_text", "model_reasoning", "model_response", and "is_finished". Each row corresponds to a query.
"""

from datetime import datetime
import pandas as pd
import requests
from datasets import load_dataset, Dataset, load_from_disk
from tqdm import tqdm
import torch
from transformers import AutoTokenizer
import os
import argparse
import json
import time
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple
import concurrent.futures
import threading
from time import sleep

# os.environ["HF_DATASETS_OFFLINE"] = "1"

@dataclass
class InputItem:
    """One query together with whatever the model produced for it.

    Only ``id`` and ``input_text`` are required; the remaining fields default
    to ``None`` until a response has been parsed.
    """

    # Identifier of the query.
    id: str
    # Text of the query.
    input_text: str
    # Reasoning text extracted from the response, if any.
    model_reasoning: Optional[str] = None
    # Response text, if any.
    model_response: Optional[str] = None
    # Completion flag, if known.
    is_finished: Optional[bool] = None

    def __str__(self) -> str:
        # Three sections separated by a blank line: query, reasoning, response.
        sections = (
            f"Item {self.id}:\n{self.input_text}",
            f"Reasoning:\n{self.model_reasoning}",
            f"Response:\n{self.model_response}",
        )
        return "\n\n".join(sections)


def _str2bool(value):
    """Convert a command-line token such as "true" / "false" into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is True for
    every non-empty string (including "False"). This converter interprets the
    usual spellings instead and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: the parsed options. ``item_ids`` is converted from
        a comma-separated string into a list of stripped strings when given,
        and is ``None`` otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Process inputs with models using API requests"
    )

    # Model configuration
    parser.add_argument(
        "--model_path",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
        help="Model name to use for API requests and to load tokenizers.",
    )

    # API configuration
    parser.add_argument(
        "--api_url",
        type=str,
        required=True,
        help="Base URL for the API endpoint (e.g., http://localhost:30000/v1)",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default="",
        help="API key for authentication (if required)",
    )
    parser.add_argument(
        "--max_concurrent_requests",
        type=int,
        default=16,
        help="Maximum number of concurrent API requests",
    )
    parser.add_argument(
        "--request_timeout",
        type=int,
        default=6000,
        help="Request timeout in seconds",
    )
    parser.add_argument(
        "--max_retries",
        type=int,
        default=3,
        help="Maximum number of retries for failed requests",
    )
    parser.add_argument(
        "--retry_delay",
        type=float,
        default=1.0,
        help="Delay between retries in seconds",
    )

    # Dataset configuration
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to dataset. Supports: HF repo name, HF dataset saved to disk, or local file (.json/.jsonl/.csv)",
    )

    # ``type=bool`` would turn "--use_hf_dataset False" into True; use an
    # explicit converter. A bare "--use_hf_dataset" (no value) means True.
    parser.add_argument(
        "--use_hf_dataset",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Use HuggingFace dataset as input",
    )

    # Generation configuration
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=32768,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.0, help="Temperature for generation"
    )
    parser.add_argument(
        "--top_p", type=float, default=1.0, help="Top-p sampling parameter for generation"
    )
    parser.add_argument(
        "--top_k", type=int, default=-1, help="Top-k sampling parameter for generation"
    )

    # Output configuration
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Base directory to save results",
    )

    parser.add_argument(
        "--is_print",
        action="store_true",
        default=False,
        help="Print all model responses to standard output",
    )

    # Debug configuration
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Run in debug mode (only process first item)",
    )
    parser.add_argument(
        "--num_items",
        type=int,
        default=None,
        help="Number of items to process (for testing)",
    )

    # Recovery configuration
    parser.add_argument(
        "--item_ids",
        type=str,
        default=None,
        help="Comma-separated list of specific item IDs to process",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from last checkpoint, processing only failed or missing items",
    )

    args = parser.parse_args()

    # Convert item IDs string to list if provided
    if args.item_ids:
        args.item_ids = [raw_id.strip() for raw_id in args.item_ids.split(",")]

    return args


def save_results(problems, output_dir):
    """Write the processed items to disk as a CSV and as HuggingFace datasets.

    Three artefacts are written below ``output_dir``:
    ``LLM_response_results.csv`` (all items plus a ``timestamp`` column),
    ``dataset`` (all items) and ``dataset_finished`` (only rows whose
    ``is_finished`` equals True). Nothing is written, and the directory is not
    created, when ``problems`` is empty.
    """
    if not problems:
        print("No problems were successfully processed")
        return

    os.makedirs(output_dir, exist_ok=True)

    # --- CSV ---------------------------------------------------------------
    frame = pd.DataFrame([vars(entry) for entry in problems])
    # Every row gets the same timestamp, taken once at save time.
    frame["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = os.path.join(output_dir, "LLM_response_results.csv")
    frame.to_csv(csv_path, index=False)

    print(f"Results saved to: {csv_path}")

    # --- Full HuggingFace dataset -------------------------------------------
    # The timestamp column is deliberately left out of the dataset.
    exported_columns = (
        "id",
        "input_text",
        "model_reasoning",
        "model_response",
        "is_finished",
    )
    hf_dataset = Dataset.from_dict(
        {column: frame[column].tolist() for column in exported_columns}
    )

    dataset_path = os.path.join(output_dir, "dataset")
    hf_dataset.save_to_disk(dataset_path)
    print(
        f"Saved HuggingFace dataset with {len(hf_dataset)} problems to {dataset_path}"
    )

    # --- Finished-only subset -----------------------------------------------
    def _is_finished(row):
        # Equality (not identity) comparison, exactly as before.
        return row["is_finished"] == True  # noqa: E712

    hf_dataset = hf_dataset.filter(_is_finished)
    print(f"Filtered dataset with {len(hf_dataset)} problems")

    dataset_path = os.path.join(output_dir, "dataset_finished")
    hf_dataset.save_to_disk(dataset_path)
    print(
        f"Saved HuggingFace dataset with {len(hf_dataset)} problems to {dataset_path}"
    )


def get_completed_items(output_dir: str) -> set:
    """Collect the ``id`` values recorded in checkpoint CSVs under ``output_dir``.

    The directory tree is walked recursively. Only files whose name ends with
    ``.csv`` and contains ``processed_items`` are read; files without an
    ``id`` column are ignored. Read errors are printed and skipped, so the
    function always returns a (possibly empty) set.

    NOTE(review): save_results in this file writes ``LLM_response_results.csv``,
    whose name does not contain ``processed_items``, so that file is never
    picked up here. Presumably another step writes the ``processed_items``
    files -- confirm before relying on ``--resume``.
    """
    done_ids = set()
    try:
        for dirpath, _subdirs, filenames in os.walk(output_dir):
            checkpoint_names = [
                name
                for name in filenames
                if name.endswith(".csv") and "processed_items" in name
            ]
            for name in checkpoint_names:
                try:
                    frame = pd.read_csv(os.path.join(dirpath, name))
                    if "id" in frame.columns:
                        done_ids.update(frame["id"].unique())
                except Exception as e:
                    print(f"Error reading {name}: {e}")
    except Exception as e:
        print(f"Error getting completed items: {e}")
    return done_ids


def _parse_jsonl_file(file_path: str) -> List[Dict]:
    """Read one JSON value per non-blank line; wrap non-dict values."""
    records: List[Dict] = []
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            value = json.loads(stripped)
            # Non-dict entries are wrapped minimally under the "problem" key.
            records.append(value if isinstance(value, dict) else {"problem": str(value)})
    return records


def _parse_json_file(file_path: str) -> List[Dict]:
    """Read a JSON document that is a list, a {"data": [...]} wrapper, or a mapping."""
    with open(file_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if isinstance(payload, list):
        return payload
    if not isinstance(payload, dict):
        raise ValueError("Unsupported JSON structure. Expected list or dict.")

    # Common pattern: {"data": [...]}
    nested = payload.get("data")
    if isinstance(nested, list):
        return nested

    # Fallback: treat the dict as an id -> record mapping.
    records: List[Dict] = []
    for key, value in payload.items():
        if isinstance(value, dict):
            # Keys inside the record win over the mapping key, including "id".
            records.append({"id": key, **value})
        else:
            records.append({"id": key, "problem": str(value)})
    return records


def load_items_from_local_file(file_path: str) -> List[Dict]:
    """Load a local ``.json``, ``.jsonl`` or ``.csv`` file as a list of records.

    The format is chosen from the (case-insensitive) file suffix. Only the
    common layouts are handled; see the private parsers for details.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        ValueError: for an unsupported suffix or an unsupported JSON top-level
            structure.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    suffix_source = file_path.lower()

    # ".jsonl" does not end with ".json", so the order of the checks is free.
    if suffix_source.endswith(".json"):
        return _parse_json_file(file_path)
    if suffix_source.endswith(".jsonl"):
        return _parse_jsonl_file(file_path)
    if suffix_source.endswith(".csv"):
        return pd.read_csv(file_path).to_dict(orient="records")

    raise ValueError("Unsupported file extension. Use .json, .jsonl, or .csv")


def initialize_tokenizer(model_path):
    """Load the tokenizer used for chat-template prompt formatting.

    Returns the tokenizer, or ``None`` when loading fails for any reason; the
    failure is reported on stdout rather than raised.

    NOTE: the tokenizer is loaded with ``trust_remote_code=True``, which lets
    the model repository execute its own code -- only point ``model_path`` at
    sources you trust.
    """
    print(f"Initializing tokenizer from {model_path}")
    try:
        return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except Exception as e:
        # Loading is best effort: callers treat None as "no tokenizer".
        print(f"Warning: Could not load tokenizer from {model_path}: {e}")
        print("Using default tokenizer, prompt formatting may be affected")
    return None


def prepare_prompts(items, tokenizer=None, enable_thinking=True):
    """Build one prompt per usable item, using the tokenizer's chat template if available.

    For every item the id is taken from ``id``, ``_id`` or ``uid`` (falling
    back to the item's position) and the text from ``problem`` or, if that is
    empty, ``question``. Items without text, and items that raise while being
    processed (e.g. entries that are not mappings), are reported on stdout and
    skipped.

    Args:
        items: iterable of mapping-like records.
        tokenizer: object exposing ``apply_chat_template``; when ``None`` the
            raw text is used as the prompt.
        enable_thinking: forwarded to ``apply_chat_template``.

    Returns:
        tuple(list[str], list[str]): ``(prompts, item_ids)``, aligned with each
        other. Because items can be skipped, they are NOT guaranteed to align
        positionally with ``items``.
    """
    prompts = []
    item_ids = []

    for index, item in enumerate(items):
        # Positional label, used for error reporting until the real id is
        # known. The handler must not touch ``item`` again: the item itself
        # may be what failed.
        item_id = str(index)
        try:
            # Get the item ID (fallback to index if missing)
            item_id = str(item.get("id", item.get("_id", item.get("uid", index))))

            # Get the input text (support both keys, "problem" first)
            input_text = item.get("problem") or item.get("question", "")
            if not input_text:
                # Skip if no usable text
                print(f"Skipping item {item_id}: no 'problem' or 'question' found")
                continue

            if tokenizer is not None:
                messages = [{"role": "user", "content": input_text}]

                # Apply the tokenizer's chat template
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=enable_thinking,
                )
            else:
                # Fallback to simple formatting if tokenizer is not available
                formatted_prompt = input_text

        except Exception as e:
            print(f"Error processing item {item_id}: {e}")
            continue

        prompts.append(formatted_prompt)
        item_ids.append(item_id)

    return prompts, item_ids


def make_api_request(api_url, api_key, model_name, prompt, max_new_tokens=8192, temperature=0.0, top_p=1.0, top_k=-1, timeout=300):
    """Send one non-streaming chat-completion request and return the reply text.

    The request is POSTed to ``<api_url>/chat/completions`` with an
    OpenAI-style payload containing a single user message.

    Args:
        api_url: base URL of the API; a trailing slash is tolerated.
        api_key: sent as a Bearer token when non-empty.
        model_name: value of the ``model`` field.
        prompt: content of the single user message.
        max_new_tokens: value of the ``max_tokens`` field.
        temperature, top_p: sampling parameters, always sent.
        top_k: only sent when different from -1.
        timeout: passed to ``requests.post``.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        Exception: on timeout, transport/HTTP failure, or a response without
            usable choices. The original error is attached as ``__cause__``.

    NOTE(review): main in this file passes prompts that already went through
    the tokenizer's chat template; a server that applies its own template to
    ``messages`` would presumably template the text a second time -- confirm
    this is intended.
    """
    headers = {
        "Content-Type": "application/json",
    }

    # Add authorization header if API key is provided
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # Prepare request payload (OpenAI-compatible format)
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
    }

    # Add top_k if specified
    if top_k != -1:
        payload["top_k"] = top_k

    # Avoid ".../v1//chat/completions" when the base URL ends with a slash.
    endpoint = f"{api_url.rstrip('/')}/chat/completions"

    try:
        response = requests.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()

        result = response.json()

        # Extract the response content
        if "choices" in result and len(result["choices"]) > 0:
            return result["choices"][0]["message"]["content"]
        else:
            raise Exception("No response content in API result")

    # Timeout must be handled before its parent class RequestException.
    except requests.exceptions.Timeout as e:
        raise Exception("Request timeout") from e
    except requests.exceptions.RequestException as e:
        raise Exception(f"Request failed: {str(e)}") from e
    except Exception as e:
        raise Exception(f"Error processing API response: {str(e)}") from e


def api_request_with_retry(api_url, api_key, model_name, prompt, max_retries=3, retry_delay=1.0, **kwargs):
    """Call make_api_request, retrying failed attempts with exponential backoff.

    Args:
        api_url, api_key, model_name, prompt: forwarded to make_api_request.
        max_retries: total number of attempts. Values below 1 are treated as 1
            so that at least one request is always sent.
        retry_delay: base delay in seconds; attempt ``k`` (0-based) is followed
            by a sleep of ``retry_delay * 2**k``.
        **kwargs: forwarded to make_api_request.

    Returns:
        Whatever make_api_request returns on the first successful attempt.

    Raises:
        Exception: the exception raised by the final failed attempt.
    """
    # With zero attempts the loop would never run and the function would end
    # in ``raise None`` (a TypeError) without having tried anything.
    attempts = max(1, max_retries)
    last_exception = None

    for attempt in range(attempts):
        try:
            return make_api_request(api_url, api_key, model_name, prompt, **kwargs)
        except Exception as e:
            last_exception = e
            if attempt < attempts - 1:
                print(f"Request failed (attempt {attempt + 1}/{attempts}): {str(e)}")
                sleep(retry_delay * (2 ** attempt))  # Exponential backoff
            else:
                print(f"Request failed after {attempts} attempts: {str(e)}")

    raise last_exception


def process_single_item(args, item_data, prompt, item_id):
    """Request a completion for one prompt and parse it into a result.

    Any failure (request error, parsing error, missing option on ``args``) is
    printed and turned into a ``None`` return value, so one bad item cannot
    take down a batch.
    """
    try:
        # Option lookups stay inside the try so a missing attribute is
        # reported like any other per-item failure.
        request_options = {
            "api_url": args.api_url,
            "api_key": args.api_key,
            "model_name": args.model_path,
            "prompt": prompt,
            "max_new_tokens": args.max_new_tokens,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "top_k": args.top_k,
            "timeout": args.request_timeout,
            "max_retries": args.max_retries,
            "retry_delay": args.retry_delay,
        }
        raw_response = api_request_with_retry(**request_options)
        parsed = process_single_response(raw_response, item_id, item_data, args.is_print)
    except Exception as e:
        print(f"Error processing item {item_id}: {str(e)}")
        return None
    return parsed


def process_single_response(response_content, item_id, item_data, is_print=False):
    """Split a raw model reply into reasoning and answer and wrap it in an InputItem.

    Everything before the first ``</think>`` marker is treated as reasoning and
    everything after it as the answer (both stripped). A reply without the
    marker is kept as-is, with empty reasoning and ``is_finished`` False.

    Args:
        response_content: raw reply text; ``None`` is treated as an empty reply.
        item_id: identifier stored on the result.
        item_data: mapping describing the source item. The recorded input text
            is taken from ``problem``, then ``question``, then ``input_text``
            -- the same precedence prepare_prompts uses when choosing the text
            to send, so the recorded text matches what was sent.
        is_print: when True, echo input, reasoning and response to stdout.

    Returns:
        An InputItem, or ``None`` when both reasoning and answer are empty.
    """
    # Be tolerant to different input schemas
    input_text = (
        item_data.get("problem")
        or item_data.get("question")
        or item_data.get("input_text", "")
    )

    # A missing message content arrives as None; the substring test below
    # would raise TypeError on it.
    if response_content is None:
        response_content = ""

    # Extract reasoning from <think> tags if present
    reasoning_content = ""
    is_finished = False
    before_marker, marker, after_marker = response_content.partition("</think>")
    if marker:
        reasoning_content = before_marker.strip()
        # Remove the thinking part from the main response
        response_content = after_marker.strip()
        is_finished = True

    # Print full responses if requested
    if is_print:
        print(f"\n===== INPUT =====\n{input_text}\n")
        if reasoning_content:
            print(f"===== REASONING =====\n{reasoning_content}\n")
        print(f"===== RESPONSE =====\n{response_content}\n")
        print(f"{'='*50}\n")

    if response_content or reasoning_content:
        return InputItem(
            id=item_id,
            input_text=input_text,
            model_reasoning=reasoning_content,
            model_response=response_content,
            is_finished=is_finished,
        )

    return None


def concurrent_api_requests(args, items, prompts, item_ids, max_concurrent=10):
    """Run process_single_item for every (item, prompt, id) triple on a thread pool.

    Args:
        args: options object forwarded to process_single_item.
        items, prompts, item_ids: parallel sequences; entry ``i`` of each
            describes the same request.
        max_concurrent: number of worker threads.

    Returns:
        list: the non-``None`` results, in the order of the input sequences
        (not in completion order), so repeated runs give the same ordering.
    """
    items, prompts, item_ids = list(items), list(prompts), list(item_ids)

    # zip() below stops at the shortest input; make a mismatch visible because
    # it means some entries are dropped and the pairing may be wrong.
    if not (len(items) == len(prompts) == len(item_ids)):
        print(
            f"Warning: got {len(items)} items, {len(prompts)} prompts and "
            f"{len(item_ids)} item ids; inputs may be misaligned"
        )

    indexed_results = []

    # Create a thread pool executor
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        # Submit all tasks, remembering each request's submission position
        future_to_item = {}
        for i, (item_data, prompt, item_id) in enumerate(zip(items, prompts, item_ids)):
            future = executor.submit(process_single_item, args, item_data, prompt, item_id)
            future_to_item[future] = (item_id, i)

        # Process completed requests
        for future in tqdm(concurrent.futures.as_completed(future_to_item), total=len(future_to_item), desc="Processing API requests"):
            item_id, index = future_to_item[future]
            try:
                result = future.result()
            except Exception as e:
                print(f"Error processing item {item_id}: {str(e)}")
                continue
            if result is not None:
                indexed_results.append((index, result))

    # Restore submission order; completion order depends on request timing.
    indexed_results.sort(key=lambda pair: pair[0])
    return [result for _, result in indexed_results]


def save_args_to_json(args, output_dir):
    """Write the run's options to ``<output_dir>/run_args.json``.

    The JSON contains every attribute of ``args`` plus a ``timestamp`` entry
    taken at save time. ``args`` itself is left untouched.

    Args:
        args: namespace-like object whose attributes are the options.
        output_dir: directory to write into; created if missing.
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Work on a copy: vars() returns the namespace's own __dict__, so editing
    # it in place would add/alter attributes on the caller's ``args``.
    args_dict = dict(vars(args))

    # Handle non-serializable objects
    if "year_range" in args_dict:
        args_dict["year_range"] = list(args_dict["year_range"])

    # Add timestamp
    args_dict["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Save to JSON
    json_path = os.path.join(output_dir, "run_args.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(args_dict, f, indent=2)

    print(f"Arguments saved to: {json_path}")


def _item_has_prompt_text(item) -> bool:
    """Return True when the item has a non-empty 'problem' or 'question' text."""
    return bool(item.get("problem") or item.get("question"))


def main():
    """Run the full pipeline: load the dataset, query the model, save the results.

    Steps:
      1. Parse options and load the dataset (local .json/.jsonl/.csv file,
         HuggingFace ``load_dataset`` or ``load_from_disk``).
      2. Select the items to process (``--item_ids``, ``--resume``,
         ``--debug``, ``--num_items``) and drop items without prompt text.
      3. Build prompts, send them concurrently and save results and options.

    Problems while loading the dataset, or an empty selection, are reported on
    stdout and end the run early without raising.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the unified dataset
    try:
        # Auto-detect local file formats; otherwise fall back to HF loading logic
        is_local_file = os.path.isfile(args.dataset_path) and args.dataset_path.lower().endswith((".json", ".jsonl", ".csv"))
        if is_local_file:
            print(f"Loading local dataset file from {args.dataset_path}")
            dataset = load_items_from_local_file(args.dataset_path)
        else:
            if args.use_hf_dataset:
                print(f"Loading HuggingFace dataset from {args.dataset_path}")
                dataset = load_dataset(args.dataset_path, split="train")
            else:
                dataset = load_from_disk(args.dataset_path)
        print(f"Loaded dataset from {args.dataset_path}")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    all_items = list(dataset)
    print(f"Dataset contains {len(all_items)} items")

    # Get completed items if resuming
    completed_items = get_completed_items(args.output_dir) if args.resume else set()

    # Filter items based on arguments. Ids are compared as strings: the ids
    # from --item_ids are always strings and ids read back from CSV may have
    # been re-typed, while dataset ids can be ints or strings.
    if args.item_ids:
        # Only process specified items
        wanted_ids = set(args.item_ids)
        items = [p for p in all_items if str(p.get("id")) in wanted_ids]
        print(f"Processing {len(items)} specified items")
    elif args.resume:
        # Only process items that haven't been completed
        done_ids = {str(done_id) for done_id in completed_items}
        items = [p for p in all_items if str(p.get("id")) not in done_ids]
        print(f"Resuming with {len(items)} remaining items")
    else:
        items = all_items

    # Handle debug mode and num_items
    if args.debug:
        items = items[:1]
        print("Debug mode: processing only first item")
    elif args.num_items:
        items = items[: args.num_items]
        print(f"Processing first {args.num_items} items")

    # prepare_prompts skips items without text, which would leave its output
    # shorter than ``items`` and shift every later item/prompt pairing.
    # Remove those items up front so the three lists stay aligned.
    sendable_items = [item for item in items if _item_has_prompt_text(item)]
    if len(sendable_items) != len(items):
        print(
            f"Skipping {len(items) - len(sendable_items)} items without 'problem' or 'question' text"
        )
    items = sendable_items

    if not items:
        print("No items to process!")
        return

    print(f"Processing {len(items)} items from dataset")
    print(f"Output directory: {args.output_dir}")
    print(f"API URL: {args.api_url}")
    print(f"Max concurrent requests: {args.max_concurrent_requests}")

    # Initialize tokenizer (optional, for prompt formatting)
    tokenizer = initialize_tokenizer(args.model_path)

    # Prepare prompts
    prompts, item_ids = prepare_prompts(items=items, tokenizer=tokenizer)
    print(f"Prepared {len(prompts)} prompts")

    # Any remaining mismatch means prompt construction failed for some items
    # and the positional pairing can no longer be trusted.
    if len(prompts) != len(items):
        print(
            f"Error: prepared {len(prompts)} prompts for {len(items)} items; "
            "aborting to avoid pairing responses with the wrong items"
        )
        return

    # Process items using concurrent API requests
    processed_items = concurrent_api_requests(
        args=args,
        items=items,
        prompts=prompts,
        item_ids=item_ids,
        max_concurrent=args.max_concurrent_requests,
    )

    # Save results
    save_results(processed_items, args.output_dir)

    # Save arguments to JSON
    save_args_to_json(args, args.output_dir)

    print("All processing complete!")


if __name__ == "__main__":
    main()
