import json
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from transformers import AutoTokenizer
import argparse
import logging
from typing import List, Dict, Any
import re


# JSON Schema for structured reasoning output.
# The top-level object must contain exactly two keys: "reasoning_steps"
# (an array of step objects) and "final_answer" (an integer). No extra keys
# are allowed at any object level ("additionalProperties": False throughout).
REASONING_SCHEMA = {
    "type": "object",
    "properties": {
        "reasoning_steps": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    # Identifier of the intermediate conclusion: "int_" followed by digits.
                    "step_id": {"type": "string", "pattern": "^int_\\d+$"},
                    # One or more rule_N / fact_N / int_N references separated by "&"
                    # (optional whitespace around the "&").
                    "rule_facts": {"type": "string", "pattern": "^(rule_\\d+|fact_\\d+|int_\\d+)(\\s*&\\s*(rule_\\d+|fact_\\d+|int_\\d+))*$"},
                    # Exactly one of the two shapes below; the single-value enum on
                    # "type" acts as the discriminator between them.
                    "conclusion": {
                        "oneOf": [
                            {
                                # AttributeFact: an entity's attribute has an integer value.
                                "type": "object",
                                "properties": {
                                    "type": {"type": "string", "enum": ["AttributeFact"]},
                                    "entity": {"type": "string", "minLength": 1},
                                    "attribute": {"type": "string", "minLength": 1},
                                    "value": {"type": "integer"}
                                },
                                "required": ["type", "entity", "attribute", "value"],
                                "additionalProperties": False
                            },
                            {
                                # RelationFact: a named relation between two entities.
                                "type": "object",
                                "properties": {
                                    "type": {"type": "string", "enum": ["RelationFact"]},
                                    "relation": {"type": "string", "minLength": 1},
                                    "entity1": {"type": "string", "minLength": 1},
                                    "entity2": {"type": "string", "minLength": 1}
                                },
                                "required": ["type", "relation", "entity1", "entity2"],
                                "additionalProperties": False
                            }
                        ]
                    }
                },
                "required": ["step_id", "rule_facts", "conclusion"],
                "additionalProperties": False
            }
        },
        "final_answer": {"type": "integer"}
    },
    "required": ["reasoning_steps", "final_answer"],
    "additionalProperties": False
}


def create_structured_chat_messages(original_answer: str, problem: str = None, attributes: list = None, relations: list = None) -> list:
    """Build the system/user chat turns for the JSON (structured output) mode.

    Args:
        original_answer: Free-form reasoning text to be converted.
        problem: Optional problem statement; when truthy it is placed before
            the conversion request under a "Problem Context:" heading.
        attributes: Optional attribute names; when non-empty they are listed
            after the reasoning as "Valid Attributes".
        relations: Optional relation names; when non-empty they are listed
            as "Valid Relations".

    Returns:
        A two-element list of ``{"role", "content"}`` dicts: the fixed system
        prompt followed by the assembled user message.
    """
    # The system prompt is fixed text; it is sent verbatim to the model.
    system_prompt = """You are a logical reasoning assistant. Your task is to analyze the given reasoning process and extract it into a structured JSON format.

The system will automatically ensure valid JSON format, so focus on the content accuracy.

Required JSON Structure:
{
  "reasoning_steps": [
    {
      "step_id": "int_1",
      "rule_facts": "combination of rules/facts/intermediates",
      "conclusion": {
        "type": "AttributeFact" | "RelationFact",
        ...specific fields based on type
      }
    }
  ],
  "final_answer": integer_value
}

Field Requirements:
1. "step_id": Sequential format "int_1", "int_2", "int_3", etc.

2. "rule_facts": Reference input elements using:
   - "fact_N" for given facts (e.g., fact_1, fact_13)
   - "rule_N" for reasoning rules (e.g., rule_5, rule_15)
   - "int_N" for intermediate conclusions from previous steps
   - Combine with " & " (e.g., "rule_15 & fact_13 & fact_4")

3. "conclusion": Two types:
   
   AttributeFact (for entity attributes):
   {
     "type": "AttributeFact",
     "entity": "exact_entity_name",
     "attribute": "exact_attribute_name",
     "value": integer_value
   }
   
   RelationFact (for relationships):
   {
     "type": "RelationFact", 
     "relation": "exact_relation_name",
     "entity1": "first_entity_name",
     "entity2": "second_entity_name"
   }

4. "final_answer": The integer answer to the query

Guidelines:
- MERGE multi-step calculations for the same attribute into one step with final value
- Use exact entity/attribute/relation names from the available lists
- Each step must produce a complete, concrete conclusion
- Sequential step_id numbering

Example:
{
  "reasoning_steps": [
    {
      "step_id": "int_1",
      "rule_facts": "rule_3 & fact_1 & fact_5",
      "conclusion": {
        "type": "RelationFact",
        "relation": "taller",
        "entity1": "Alice",
        "entity2": "Bob"
      }
    },
    {
      "step_id": "int_2", 
      "rule_facts": "rule_1 & fact_2 & int_1",
      "conclusion": {
        "type": "AttributeFact",
        "entity": "Alice",
        "attribute": "height",
        "value": 175
      }
    }
  ],
  "final_answer": 175
}"""

    # Assemble the user turn front-to-back: optional problem context, the
    # conversion request, then the optional vocabulary constraints.
    segments = []
    if problem:
        segments.append(f"Problem Context:\n{problem}\n\n")
    segments.append(
        "Analyze the following reasoning and convert to the required JSON format.\n\n"
        f"Original Reasoning:\n{original_answer}"
    )
    if attributes:
        segments.append("\n\nValid Attributes: " + ", ".join(attributes))
    if relations:
        # Single newline here: the relations line sits directly under the
        # attributes line when both are present.
        segments.append("\nValid Relations: " + ", ".join(relations))
    user_prompt = "".join(segments)

    system_turn = {"role": "system", "content": system_prompt}
    user_turn = {"role": "user", "content": user_prompt}
    return [system_turn, user_turn]


def create_chat_messages(original_answer: str, problem: str = None, attributes: list = None, relations: list = None) -> list:
    """Create chat messages asking the LLM to restate an answer in the text format.

    The system prompt describes the ``rule/fact => int_n: conclusion`` summary
    format and the ``Answer: \\boxed{...}`` convention; the user message carries
    the answer to reformat plus optional context.

    Args:
        original_answer: Reasoning text to be reformatted.
        problem: Optional problem statement; when truthy it is placed before
            the request under a "Problem:" heading.
        attributes: Optional list of attribute names. ``None`` (the default)
            omits the "Available Attributes" line; a list, even an empty one,
            emits it.
        relations: Optional list of relation names, handled like
            ``attributes``.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts (system, user).
    """

    system_message = """You are a logical reasoning assistant. Your task is to analyze the given reasoning process and reformat it into a specific structured format.

Requirements:
1. After completing your analysis, summarize the key reasoning steps in the specified structured format
2. The structured format is only required at the end as a summary - your main explanation can be in natural language
3. For the final answer, always use: "Answer: \\boxed{[value]}"

Structured Summary Requirements:
"Reasoning:
rule_15 & fact_13 & fact_4 => int_1: relation_name exists between first_entity and second_entity.  
rule_5 & fact_1 & fact_10 & fact_11 & fact_5 => int_2: entity_name's attribute_name is attribute_value.  
...  
Answer: \\boxed{answer_value}"

Format Guidelines:
- Each reasoning step should be expressed as: [rule/fact combinations] => int_[n]: [intermediate conclusion]
- Express relationships as "[relation] exists between [X] and [Y]"
- Express attributes as "[X]'s [attribute] is [value]"
- Use logical operators: & (and)
- Number intermediate conclusions sequentially (int_1, int_2, etc.)"""

    user_message = f"""Please analyze the following reasoning process and reformat it into the structured format specified above.

Original Answer:
{original_answer}"""

    if problem:
        user_message = f"""Problem:
{problem}

{user_message}"""

    # Both parameters default to None, so joining unconditionally raised
    # TypeError for callers relying on the defaults. Guard on None only, so
    # the output for list arguments (including empty lists) is unchanged.
    if attributes is not None:
        attributes_list = ', '.join(attributes)
        user_message += f"\n\nAvailable Attributes: {attributes_list}"

    if relations is not None:
        relations_list = ', '.join(relations)
        user_message += f"\n\nAvailable Relations: {relations_list}"

    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message}
    ]


def prepare_input_data(data, include_reasoning=False):
    """Render a dataset record as a numbered Facts / Rules / Query prompt.

    Args:
        data: Mapping with 'facts-tuned-nl' and 'rules-tuned-nl' (iterables of
            sentences) and 'query' (a two-item entity/attribute pair). When
            ``include_reasoning`` is true, 'reasoning_process_nl' and
            'values' (indexed as values[entity][attribute]) are read as well.
        include_reasoning: If true, append the reference reasoning and the
            boxed reference answer after the query.

    Returns:
        The assembled prompt text.
    """
    fact_lines = [
        f"{number}. {sentence}\n"
        for number, sentence in enumerate(data['facts-tuned-nl'], start=1)
    ]
    rule_lines = [
        f"{number}. {sentence}\n"
        for number, sentence in enumerate(data['rules-tuned-nl'], start=1)
    ]
    entity, attribute = data['query']

    problem_text = "".join([
        "Facts:\n", *fact_lines,
        "Rules:\n", *rule_lines,
        f"Query:\nWhat is the value of {entity}'s {attribute}?\n",
    ])
    if not include_reasoning:
        return problem_text

    # Reasoning text is read before the reference value, matching the order
    # in which missing keys would surface.
    reasoning_body = data["reasoning_process_nl"] + "\n"
    reference_value = data["values"][entity][attribute]
    return "".join([
        problem_text,
        "After a detailed explanation, you would conclude as follows.\n",
        "Reasoning:\n",
        reasoning_body,
        f"Answer: \\boxed{{{reference_value}}}\n",
    ])


def load_input_data(input_file: str) -> Dict[str, Any]:
    """Read the UTF-8 JSON file at *input_file* and return the parsed content."""
    with open(input_file, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def _coerce_to_list(value: Any) -> list:
    """Return *value* as a list of names without splitting a string into characters."""
    if isinstance(value, str):
        # A bare string is one name; list("height") would yield single letters.
        return [value] if value else []
    if hasattr(value, '__iter__'):
        return list(value)
    return []


def validate_and_clean_item(item: Dict[str, Any], dataset_name: str, item_idx: int) -> tuple[bool, Dict[str, Any]]:
    """Validate a data item and normalise its attribute/relation lists.

    Args:
        item: Raw record; must contain 'llm_output', 'attributes' and
            'relations'.
        dataset_name: Dataset label, used only in log messages.
        item_idx: Index of the item within its dataset, used only in log
            messages.

    Returns:
        ``(True, cleaned)`` where ``cleaned`` is a shallow copy of ``item``
        whose 'attributes' and 'relations' are lists, or ``(False, {})`` when
        a required field is missing or 'llm_output' is not a non-blank string.
        The input ``item`` is never mutated.
    """
    # Check for required fields
    for field in ('llm_output', 'attributes', 'relations'):
        if field not in item:
            logging.error(
                f"Missing '{field}' field in dataset {dataset_name}, item {item_idx}")
            return False, {}

    # Reject non-string output (e.g. JSON null) instead of crashing on .strip()
    llm_output = item["llm_output"]
    if not isinstance(llm_output, str):
        logging.warning(
            f"Non-string llm_output found in dataset {dataset_name}, item {item_idx}")
        return False, {}

    # Validate llm_output is not empty
    if not llm_output.strip():
        logging.warning(
            f"Empty llm_output found in dataset {dataset_name}, item {item_idx}")
        return False, {}

    # Ensure attributes and relations are lists
    attributes = item['attributes']
    relations = item['relations']

    if not isinstance(attributes, list):
        logging.warning(
            f"Attributes not a list in dataset {dataset_name}, item {item_idx}, converting...")
        attributes = _coerce_to_list(attributes)

    if not isinstance(relations, list):
        logging.warning(
            f"Relations not a list in dataset {dataset_name}, item {item_idx}, converting...")
        relations = _coerce_to_list(relations)

    # Create cleaned item (shallow copy so the caller's record stays untouched)
    cleaned_item = item.copy()
    cleaned_item['attributes'] = attributes
    cleaned_item['relations'] = relations

    return True, cleaned_item


def save_output_data(output_file: str, data: Dict[str, Any]) -> None:
    """Write *data* to *output_file* as 2-space-indented UTF-8 JSON.

    Non-ASCII characters are written as-is rather than escaped.
    """
    with open(output_file, mode='w', encoding='utf-8') as sink:
        json.dump(data, sink, indent=2, ensure_ascii=False)


def extract_structured_reasoning(response: str) -> Dict[str, str]:
    """Pull the "Reasoning:" section and the boxed answer out of a text response.

    Both searches are case-insensitive and use the first occurrence. The
    reasoning section runs from the line after "Reasoning:" up to the next
    "Answer:" (or the end of the text when there is none).

    Returns:
        A dict with 'full_response' (the input unchanged),
        'structured_reasoning' and 'final_answer'; the latter two are
        whitespace-stripped, or '' when the corresponding pattern is absent.
    """
    reasoning_hit = re.search(
        r'Reasoning:\s*\n(.*?)(?=Answer:|$)',
        response,
        flags=re.IGNORECASE | re.DOTALL,
    )
    answer_hit = re.search(
        r'Answer:\s*\\boxed\{([^}]+)\}', response, flags=re.IGNORECASE)

    return {
        'full_response': response,
        'structured_reasoning': reasoning_hit.group(1).strip() if reasoning_hit else '',
        'final_answer': answer_hit.group(1).strip() if answer_hit else '',
    }


def _format_conclusion(conclusion: Any) -> str:
    """Render one step's conclusion as the legacy natural-language clause."""
    if not isinstance(conclusion, dict):
        return str(conclusion)

    conclusion_type = conclusion.get("type", "")
    if conclusion_type == "AttributeFact":
        entity = conclusion.get("entity", "")
        attribute = conclusion.get("attribute", "")
        value = conclusion.get("value", "")
        return f"{entity}'s {attribute} is {value}"
    if conclusion_type == "RelationFact":
        relation = conclusion.get("relation", "")
        entity1 = conclusion.get("entity1", "")
        entity2 = conclusion.get("entity2", "")
        return f"{relation} exists between {entity1} and {entity2}"
    # Unknown type: fall back to the raw representation.
    return str(conclusion)


def extract_structured_json(response: str) -> Dict[str, Any]:
    """Parse a JSON response and derive the legacy text fields from it.

    Args:
        response: Raw model output, expected to be a JSON object with
            "reasoning_steps" and "final_answer".

    Returns:
        On success, a dict with 'full_response', 'structured_reasoning' (one
        ``rule_facts => step_id: conclusion.`` line per step), 'final_answer'
        (stringified) and 'json_data' (the parsed object).
        On failure (invalid JSON, or a top-level value that is not an
        object), the same keys with empty strings, ``'json_data': None`` and
        an extra 'error' message. This function does not raise for malformed
        model output.
    """
    def _failure(reason: str) -> Dict[str, Any]:
        # Diagnostics go through logging, like the rest of this module.
        logging.warning(f"Unexpected JSON parsing failed: {reason}")
        logging.warning(f"Response content: {response[:200]}...")
        return {
            'full_response': response,
            'structured_reasoning': '',
            'final_answer': '',
            'json_data': None,
            'error': reason
        }

    try:
        # With guided decoding, the response should already be valid JSON;
        # truncated generations (token limit) are the usual failure.
        json_data = json.loads(response.strip())
    except json.JSONDecodeError as e:
        return _failure(str(e))

    if not isinstance(json_data, dict):
        # Valid JSON but not an object (e.g. a list or number): previously
        # this crashed on .get(); report it as a parse failure instead.
        return _failure(
            f"expected a JSON object, got {type(json_data).__name__}")

    # Format reasoning steps into text format for backward compatibility
    step_lines = []
    steps = json_data.get("reasoning_steps")
    if isinstance(steps, list):
        for step in steps:
            if not isinstance(step, dict):
                logging.warning(f"Skipping malformed reasoning step: {step!r}")
                continue
            step_id = step.get("step_id", "")
            rule_facts = step.get("rule_facts", "")
            conclusion_text = _format_conclusion(step.get("conclusion", {}))
            step_lines.append(f"{rule_facts} => {step_id}: {conclusion_text}.")

    return {
        'full_response': response,
        'structured_reasoning': "\n".join(step_lines).strip(),
        'final_answer': str(json_data.get("final_answer", "")),
        'json_data': json_data
    }


def process_responses(llm: LLM, prompts: List[str], sampling_params: SamplingParams) -> List[str]:
    """Run all *prompts* through ``llm.generate`` and collect the generated text.

    For each returned output, only the text of its first candidate
    (``outputs[0]``) is kept.
    """
    completions = []
    for request_output in llm.generate(prompts, sampling_params):
        first_candidate = request_output.outputs[0]
        completions.append(first_candidate.text)
    return completions


# Script entry point: load a model with vLLM, turn every valid input item into
# a chat prompt, generate all responses in one batch, parse them, and write the
# original items enriched with the parsed fields to the output file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate a conclusion for the given answer.")
    parser.add_argument("--model-path", type=str,
                        required=True, help="Path to the model.")
    parser.add_argument("--tensor-parallel-size", type=int,
                        default=1, help="Tensor parallel size.")
    parser.add_argument("--input-file", type=str, required=True,
                        help="Path to the input file containing the answer.")
    parser.add_argument("--output-file", type=str, required=True,
                        help="Path to the output file to save the conclusion.")
    parser.add_argument("--use-problem", action="store_true",
                        help="Whether to use the problem in the prompt.")
    parser.add_argument("--max-tokens", type=int, default=2048,
                        help="Maximum number of tokens to generate.")
    parser.add_argument("--temperature", type=float,
                        default=0.0, help="Temperature for sampling.")
    parser.add_argument("--top-p", type=float, default=1.0,
                        help="Top-p for sampling.")
    parser.add_argument("--structured-output", action="store_true",
                        help="Use structured JSON output with guided decoding.")
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    # Initialize tokenizer and LLM
    # The tokenizer is used only to render chat templates into prompt strings;
    # generation itself goes through the vLLM engine.
    logger.info(f"Loading model from {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    llm = LLM(model=args.model_path, tensor_parallel_size=args.tensor_parallel_size,
              trust_remote_code=True, gpu_memory_utilization=0.85)

    # Setup sampling parameters
    # Structured mode differs only by attaching REASONING_SCHEMA as a guided
    # decoding constraint; temperature/top-p/max-tokens are shared.
    if args.structured_output:
        logger.info("Using structured output mode with JSON schema")
        guided_decoding = GuidedDecodingParams(json=REASONING_SCHEMA)
        sampling_params = SamplingParams(
            temperature=args.temperature,
            top_p=args.top_p,
            max_tokens=args.max_tokens,
            guided_decoding=guided_decoding
        )
    else:
        logger.info("Using natural language output mode")
        sampling_params = SamplingParams(
            temperature=args.temperature,
            top_p=args.top_p,
            max_tokens=args.max_tokens
        )

    # Load and process data
    # The input is treated as a mapping of dataset name -> list of items.
    logger.info(f"Loading input data from {args.input_file}")
    input_data = load_input_data(args.input_file)

    # Count total items across all datasets
    total_items = sum(len(dataset_items)
                      for dataset_items in input_data.values())
    logger.info(
        f"Loaded {len(input_data)} datasets with {total_items} total items")

    # Prepare prompts and track dataset mapping
    prompts = []
    item_mapping = []  # Track which dataset and item index each prompt corresponds to

    for dataset_name, dataset_items in input_data.items():
        logger.info(
            f"Processing dataset: {dataset_name} with {len(dataset_items)} items")

        for item_idx, item in enumerate(dataset_items):
            # Validate and clean item
            # Invalid items are skipped here, so they get no prompt and do not
            # appear in the output file.
            is_valid, cleaned_item = validate_and_clean_item(
                item, dataset_name, item_idx)
            if not is_valid:
                continue

            # Get original answer from llm_output field
            original_answer = cleaned_item["llm_output"]

            # Prepare problem using the prepare_input_data function
            problem = None
            if args.use_problem:
                try:
                    problem = prepare_input_data(
                        cleaned_item, include_reasoning=False)
                except (KeyError, TypeError) as e:
                    logger.warning(
                        f"Could not prepare problem for dataset {dataset_name}, item {item_idx}: {e}")
                    # Fallback to original problem field if exists
                    # (an empty string makes the prompt builders omit the
                    # problem section)
                    problem = cleaned_item.get("problem", "")

            # Get validated attributes and relations
            attributes = cleaned_item['attributes']
            relations = cleaned_item['relations']

            if args.structured_output:
                # Use structured prompt for JSON output
                messages = create_structured_chat_messages(
                    original_answer, problem, attributes, relations)
                # When the model path mentions "qwen3", enable_thinking=False
                # is passed to the chat template; note this is done in
                # structured mode only.
                if "qwen3" in args.model_path.lower():
                    prompt = tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
                else:
                    prompt = tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True)
            else:
                # Use natural language prompt
                messages = create_chat_messages(
                    original_answer, problem, attributes, relations)
                prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True)

            prompts.append(prompt)
            item_mapping.append((dataset_name, item_idx))

    # Generate responses
    # All prompts are submitted in a single generate call.
    logger.info(f"Processing {len(prompts)} prompts")
    responses = process_responses(llm, prompts, sampling_params)

    # Create output data structure maintaining original format
    output_data = {}

    # Initialize output datasets
    # Every input dataset gets an entry, even if all of its items were skipped.
    for dataset_name in input_data.keys():
        output_data[dataset_name] = []

    # Process responses and merge with original data
    # Pairing relies on responses coming back in prompt order and in equal
    # number — assumed from llm.generate; zip would silently truncate otherwise.
    for i, (response, (dataset_name, item_idx)) in enumerate(zip(responses, item_mapping)):
        original_item = input_data[dataset_name][item_idx]

        if args.structured_output:
            structured_info = extract_structured_json(response)
        else:
            structured_info = extract_structured_reasoning(response)

        # Create new item with original data plus formatted output
        # (copied from the raw input item, not from the cleaned one)
        new_item = original_item.copy()

        # Add formatted output fields
        new_item["formatted_response"] = structured_info['full_response']
        new_item["structured_reasoning"] = structured_info['structured_reasoning']
        new_item["final_answer"] = structured_info['final_answer']

        # Add JSON data if available
        # NOTE(review): when parsing fails, a "json_data" key already present
        # on the input item is left as-is rather than cleared.
        if 'json_data' in structured_info and structured_info['json_data']:
            new_item["json_data"] = structured_info['json_data']

        output_data[dataset_name].append(new_item)

    # Save results
    logger.info(f"Saving results to {args.output_file}")
    save_output_data(args.output_file, output_data)

    # Statistics
    # total_items is recomputed over the output, so it excludes skipped items.
    # A field counts as present only when it is truthy (non-empty).
    total_items = sum(len(dataset_items)
                      for dataset_items in output_data.values())
    valid_structured = 0
    valid_answers = 0
    valid_json = 0

    for dataset_items in output_data.values():
        for item in dataset_items:
            if item.get('structured_reasoning'):
                valid_structured += 1
            if item.get('final_answer'):
                valid_answers += 1
            if 'json_data' in item and item['json_data']:
                valid_json += 1

    logger.info(f"Processing completed for {len(output_data)} datasets:")
    for dataset_name, dataset_items in output_data.items():
        logger.info(f"  - {dataset_name}: {len(dataset_items)} items")

    if args.structured_output:
        logger.info(
            f"Total: {total_items} items, {valid_structured} structured, {valid_answers} with answers, {valid_json} valid JSON")
    else:
        logger.info(
            f"Total: {total_items} items, {valid_structured} structured, {valid_answers} with answers")
