# Standard library imports
import json
import os
import re
import time
import random
import traceback
import concurrent.futures
from typing import *

# Third-party imports
import numpy as np
import yaml
from tqdm import tqdm
from loguru import logger
from openai import OpenAI
import json_repair

# Local imports
from thinktime.sft.utils.rewrite_tool_prompt import (
    TOOL_DEFINITION,
    create_multi_tool_thinking_prompt
)

# CONFIG
# Parse the data-generation config exactly once (and close the handle); every
# setting below is read from this single dict.
with open("config/datagen_config.yaml", encoding="utf-8") as _config_file:
    config = yaml.safe_load(_config_file)

MODEL_PATH = config["local_llm_path"]
num_gpus = config["num_gpus"]
gpu_per_model = config["gpu_per_model"]
SEQ_LEN = config["seq_len"]
DATA_OUTPUT_DIR = config["data_output_dir"]
ENCODING_METHOD = config['encoding_method']
# Number of successful samples to write to OUTPUT_FILE before generation stops.
TOTAL_CNT = 10000
# (qa_jsonl_path, label_json_path) pairs; each QA line is paired with the label
# at the same position in the label file.
INPUT_FILES = [
    (f'{DATA_OUTPUT_DIR}/ift_rlvr_None_{ENCODING_METHOD}.jsonl', f'{DATA_OUTPUT_DIR}/evol_labels/ift_rlvr_None_{ENCODING_METHOD}.json'),
]
OUTPUT_FILE = f'{DATA_OUTPUT_DIR}/warmup_sft_{TOTAL_CNT}_{ENCODING_METHOD}.jsonl'

# Remote API Configuration
USE_REMOTE_API = True

# Module-level defaults so these names always exist, even when remote mode is
# off (otherwise later references such as API_MODEL / API_MAX_WORKERS raise
# NameError). setup_remote_api() fills in the endpoint settings when enabled.
API_BASE_URL: Optional[str] = None
API_KEY: Optional[str] = None
API_MODEL: Optional[str] = None
API_MAX_WORKERS = 64  # Number of concurrent API calls
API_TIMEOUT = 120     # Timeout in seconds


def setup_remote_api() -> Optional[OpenAI]:
    """Configure the remote OpenAI-compatible API client if enabled.

    When ``USE_REMOTE_API`` is false this does nothing and returns ``None``.
    Otherwise it sets the module-level ``API_BASE_URL`` / ``API_KEY`` /
    ``API_MODEL`` and ``client`` globals, logs the configuration, and returns
    the newly created client.
    """
    global API_BASE_URL, API_KEY, API_MODEL, client

    if not USE_REMOTE_API:
        return None

    # API Configuration - update these values as needed
    API_BASE_URL = "[OPENAI_BASE_URL]"
    API_KEY = "[OPENAI_API_KEY]"
    API_MODEL = "[OPENAI_API_MODEL]"

    client = OpenAI(
        api_key=API_KEY,
        base_url=API_BASE_URL
    )

    logger.info("Remote API mode enabled")
    logger.info(f"API Base URL: {API_BASE_URL}")
    logger.info(f"API Model: {API_MODEL}")
    logger.info(f"Max Workers: {API_MAX_WORKERS}")
    logger.info(f"Timeout: {API_TIMEOUT}s")

    return client


# Initialize API client (None when remote mode is disabled)
client = setup_remote_api()


def call_remote_api(prompt: str, model: str = None, max_tokens: int = 4000, temperature: float = 1.0) -> str:
    """Send ``prompt`` as one user turn to the remote Chat Completions endpoint.

    The prompt is sent as-is (no chat template is applied upstream) after a
    fixed system message, with the provider's ``thinking`` feature disabled
    through ``extra_body``.

    Args:
        prompt: User message content.
        model: Model id; when falsy, the module-level ``API_MODEL`` is used.
        max_tokens: Completion token cap.
        temperature: Sampling temperature.

    Returns:
        The message content of the first choice. This function never raises:
        on any exception it logs the error and returns
        ``"Error: API call failed - <reason>"``. If the error text mentions
        "limit", the calling thread first sleeps 20-60 seconds as a crude
        rate-limit backoff.
    """
    chosen_model = model or API_MODEL

    try:
        request_options = {
            "model": chosen_model,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": temperature,
            "timeout": API_TIMEOUT,
            "extra_body": {"thinking": {"type": "disabled"}},
        }
        completion = client.chat.completions.create(**request_options)
        return completion.choices[0].message.content
    except Exception as exc:
        logger.error(f"API call failed: {exc}")
        if "limit" in str(exc).lower():
            # Back off before this worker slot is reused.
            time.sleep(random.randint(20, 60))
        return f"Error: API call failed - {str(exc)}"


# ============================================================================
# Response Parsing and Validation Functions  
# ============================================================================

# "### THINKING ###" runs until the "### ANSWER ###" marker (or end of text).
_THINKING_SECTION_RE = re.compile(r'### THINKING ###\s*(.*?)(?=### ANSWER ###|$)', re.DOTALL)
# "### ANSWER ###" runs until the next "### NAME ###"-style section marker (or
# end of text). Stopping only at such markers -- not at any "###" -- keeps
# markdown headings like "### Details" or "#### Step" inside the answer intact.
_ANSWER_SECTION_RE = re.compile(r'### ANSWER ###\s*(.*?)(?=###\s*[A-Z][A-Z0-9_ ]*###|$)', re.DOTALL)


def parse_llm_response(response_text: str) -> Dict[str, str]:
    """Extract the THINKING and ANSWER sections from an LLM response.

    The primary format is ``### THINKING ###`` ... ``### ANSWER ###`` ...
    If either section is missing or empty, the whole text is tried as a JSON
    object with ``thinking`` / ``answer`` keys.

    Args:
        response_text: Raw model output.

    Returns:
        Dict with stripped ``'thinking'`` and ``'answer'`` strings. Via the
        JSON fallback, either value may be ``''`` if the key was absent.

    Raises:
        ValueError: if ``response_text`` is not a string, or neither the
            section format nor the JSON fallback yields both sections.
    """
    if not isinstance(response_text, str):
        raise ValueError(
            f"Failed to process LLM response: expected str, got {type(response_text).__name__}"
        )

    sections: Dict[str, str] = {}

    thinking_match = _THINKING_SECTION_RE.search(response_text)
    if thinking_match:
        sections['thinking'] = thinking_match.group(1).strip()

    answer_match = _ANSWER_SECTION_RE.search(response_text)
    if answer_match:
        sections['answer'] = answer_match.group(1).strip()

    missing_sections = [name for name in ('thinking', 'answer') if not sections.get(name)]
    if not missing_sections:
        return sections

    # Fallback: the model may have replied with a JSON object instead.
    try:
        parsed_json = json.loads(response_text)
        return {
            'thinking': parsed_json.get('thinking', '').strip(),
            'answer': parsed_json.get('answer', '').strip(),
        }
    except (ValueError, AttributeError, TypeError) as err:
        # ValueError covers JSONDecodeError; AttributeError/TypeError cover a
        # non-object JSON value or non-string field values.
        raise ValueError(
            f"Failed to process LLM response: Missing sections: {missing_sections}. "
            f"Response: {response_text[:200]}..."
        ) from err


def _build_output_entry(seed_item: Dict, response: str) -> Dict[str, Any]:
    """Validate one LLM response and turn it into an SFT JSONL record.

    Args:
        seed_item: Seed dict; reads 'metrics', 'timeseries' (array-like with
            ``.tolist()``), 'question', 'answer' and 'instruction'.
        response: Raw text returned by ``call_remote_api``.

    Returns:
        Dict with 'input', 'output', 'timeseries' and 'tools' keys.

    Raises:
        ValueError: the response cannot be parsed, a section is empty, the
            thinking is shorter than 300 characters, or tool-call validation
            fails (that message starts with "Tool call validation failed").
    """
    available_metrics = seed_item['metrics']
    max_ts_length = max([len(ts) for ts in seed_item['timeseries']]) if len(seed_item['timeseries']) > 0 else 1000

    # Parse and validate - only extract thinking and answer
    sections = parse_llm_response(response)
    thinking = sections.get('thinking', '')
    answer = sections.get('answer', '')

    if not thinking or not answer:
        raise ValueError("Thinking or answer section is empty")
    if len(thinking) < 300:
        raise ValueError(f"Thinking process too short: {len(thinking)} characters (expected at least 300)")

    is_valid, errors, _tool_call_count = validate_thinking_with_tools(
        thinking, available_metrics, max_ts_length
    )
    if not is_valid:
        raise ValueError(f"Tool call validation failed: {'; '.join(errors)}")

    original_question = seed_item['question']
    original_answer = seed_item['answer']
    original_instruction = seed_item['instruction']

    # Take the gold final answer from a "\answer{...}" marker (no nested braces);
    # fall back to the whole original answer when the marker is absent.
    final_answer_match = re.search(r'\\answer\{([^}]+)\}', original_answer)
    original_final_answer = final_answer_match.group(1) if final_answer_match else original_answer

    return {
        'input': original_instruction + original_question,
        'output': f"<think>\n{thinking}\n</think>\n<answer>{answer}\n\n\\answer{{{original_final_answer}}}</answer>",
        'timeseries': seed_item['timeseries'].tolist(),
        'tools': json.dumps(TOOL_DEFINITION, ensure_ascii=False)
    }


def llm_batch_generate_remote_api(seed_data: List[Dict]) -> None:
    """Generate up to TOTAL_CNT SFT samples from ``seed_data`` via the remote API.

    One prompt is built per seed item up front. Requests are then kept at most
    API_MAX_WORKERS deep in a thread pool. Each response that parses and
    validates is appended to OUTPUT_FILE as one JSON line. Failures are counted
    (tool-call validation vs. everything else) and logged.

    Once TOTAL_CNT lines are written, no further requests are submitted and
    results still in flight are discarded, so the file never exceeds
    TOTAL_CNT lines.
    """
    from collections import deque

    # Make sure the directory holding OUTPUT_FILE exists before opening it.
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Build every prompt up front (fails fast on a bad seed item); no DFS expansion.
    work_queue = deque((create_multi_tool_thinking_prompt(item), item) for item in seed_data)

    parse_failed = 0
    validation_failed = 0
    success_cnt = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=API_MAX_WORKERS) as executor, \
         tqdm(total=TOTAL_CNT) as pbar, \
         open(OUTPUT_FILE, 'wt', encoding='utf-8') as fo:

        inflight: Dict[concurrent.futures.Future, Dict] = {}

        def submit_from_queue() -> None:
            """Top up the pool so at most API_MAX_WORKERS requests are outstanding."""
            while len(inflight) < API_MAX_WORKERS and work_queue:
                prompt, seed = work_queue.popleft()
                inflight[executor.submit(call_remote_api, prompt)] = seed

        submit_from_queue()

        # inflight is refilled after every batch, so it can only become empty
        # once the work queue is exhausted.
        while success_cnt < TOTAL_CNT and inflight:
            done, _ = concurrent.futures.wait(
                list(inflight), return_when=concurrent.futures.FIRST_COMPLETED
            )

            for fut in done:
                seed_item = inflight.pop(fut)
                if success_cnt >= TOTAL_CNT:
                    # Target already reached by an earlier future in this batch.
                    continue
                try:
                    line = json.dumps(_build_output_entry(seed_item, fut.result()), ensure_ascii=False)
                except Exception as err:
                    if "Tool call validation failed" in str(err):
                        validation_failed += 1
                        traceback.print_exc()
                        logger.warning(f"Validation failed: {err}")
                    else:
                        parse_failed += 1
                        logger.warning(f"Parse failed: {err}")
                    continue

                fo.write(line + '\n')
                fo.flush()
                success_cnt += 1
                pbar.update(1)

            # Only top up while samples are still needed; requests submitted
            # after the target is hit would just be waited on and discarded.
            if success_cnt < TOTAL_CNT:
                submit_from_queue()

        # Drop any leftover request that has not started yet so executor
        # shutdown does not wait on it (running ones cannot be cancelled).
        for fut in inflight:
            fut.cancel()

        logger.info(
            f"Remote API processing completed. Success: {success_cnt}, Parse failed: {parse_failed}, Validation failed: {validation_failed}"
        )


# Tool name -> arguments that must be present in its "arguments" object.
_TOOL_REQUIRED_ARGS: Dict[str, Tuple[str, ...]] = {
    "get_timeseries_slice": ("metric_name", "start", "end"),
    "compare_timeseries_slice": ("metric_name_1", "start_1", "end_1", "metric_name_2", "start_2", "end_2"),
}
# Argument-name suffixes per tool ("" -> metric_name / start / end).
_TOOL_ARG_SUFFIXES: Dict[str, Tuple[str, ...]] = {
    "get_timeseries_slice": ("",),
    "compare_timeseries_slice": ("_1", "_2"),
}
# Message used when a tool's start/end arguments are not integer-convertible.
_TOOL_INT_ERROR: Dict[str, str] = {
    "get_timeseries_slice": "start and end must be integers",
    "compare_timeseries_slice": "all indices must be integers",
}
_TOOL_CALL_RE = re.compile(r'<tool_start>\s*(.*?)\s*<tool_end>', re.DOTALL)


def _metric_name_errors(metric_name: Any, suffix: str, available_metrics: List[str]) -> List[str]:
    """Return problems with one metric-name argument (empty, or not a known metric)."""
    if not metric_name:
        return [f"Empty metric_name{suffix}"]
    if available_metrics and metric_name not in available_metrics:
        wanted = metric_name.lower()
        # Lenient: accept a case-insensitive substring match in either direction.
        if not any(wanted in m.lower() or m.lower() in wanted for m in available_metrics):
            return [f"Metric '{metric_name}' not in available metrics {available_metrics}"]
    return []


def _slice_range_errors(args: Dict[str, Any], suffix: str, max_timeseries_length: int) -> List[str]:
    """Return problems with one start/end pair; raises if either is not integer-convertible."""
    start = int(args[f"start{suffix}"])
    end = int(args[f"end{suffix}"])
    problems = []
    if start < 0:
        problems.append(f"start{suffix} index {start} cannot be negative")
    if end <= start:
        problems.append(f"end{suffix} index {end} must be greater than start{suffix} index {start}")
    if end > max_timeseries_length:
        problems.append(f"end{suffix} index {end} exceeds timeseries length {max_timeseries_length}")
    return problems


def _tool_call_errors(call_text: str, available_metrics: List[str], max_timeseries_length: int) -> List[str]:
    """Return every problem in one tool-call body (an empty list means the call is valid)."""
    try:
        call = json.loads(call_text)
    except json.JSONDecodeError as e:
        return [f"Invalid JSON format in [{call_text}] - {str(e)}"]
    if not isinstance(call, dict):
        return ["Tool call must be a JSON object"]
    if "name" not in call:
        return ["Missing 'name' field"]
    if "arguments" not in call:
        return ["Missing 'arguments' field"]

    tool_name = call["name"]
    if not isinstance(tool_name, str) or tool_name not in _TOOL_REQUIRED_ARGS:
        return [f"Invalid tool name '{tool_name}'"]

    args = call["arguments"]
    if not isinstance(args, dict):
        return ["'arguments' must be a JSON object"]
    missing = [f"Missing argument '{arg}'" for arg in _TOOL_REQUIRED_ARGS[tool_name] if arg not in args]
    if missing:
        # The value checks below would only fail again on the absent keys.
        return missing

    suffixes = _TOOL_ARG_SUFFIXES[tool_name]
    problems: List[str] = []
    try:
        for suffix in suffixes:
            problems.extend(_metric_name_errors(args[f"metric_name{suffix}"], suffix, available_metrics))
    except Exception as e:
        # e.g. a non-string metric name (no .lower()).
        return problems + [f"Unexpected error - {str(e)}"]

    try:
        for suffix in suffixes:
            problems.extend(_slice_range_errors(args, suffix, max_timeseries_length))
    except (ValueError, TypeError, OverflowError):
        problems.append(_TOOL_INT_ERROR[tool_name])
    return problems


def validate_thinking_with_tools(thinking_text: str, available_metrics: List[str], max_timeseries_length: int, max_tool_calls: int = 5) -> Tuple[bool, List[str], int]:
    """
    Validate the ``<tool_start> ... <tool_end>`` tool calls inside a thinking text.

    Each call body must be a JSON object with ``name`` (``get_timeseries_slice``
    or ``compare_timeseries_slice``) and an ``arguments`` object that holds the
    tool's required arguments. Metric names must match ``available_metrics``
    (case-insensitive substring matches are accepted). Start/end must be
    integers with ``0 <= start < end <= max_timeseries_length``.

    Args:
        thinking_text: The THINKING section produced by the model.
        available_metrics: Allowed metric names; an empty list disables the check.
        max_timeseries_length: Upper bound (inclusive) for any ``end`` index.
        max_tool_calls: Maximum number of valid tool calls allowed (default 5).

    Returns:
        - is_valid: True when no call has errors and the number of valid calls
          is at most ``max_tool_calls``. NOTE(review): a text with no tool
          calls at all is reported valid; confirm that is intended.
        - error_messages: Errors, each prefixed with "Tool call <n>: ".
        - tool_call_count: Number of tool calls with no errors.
    """
    errors: List[str] = []
    tool_call_count = 0

    for index, call_text in enumerate(_TOOL_CALL_RE.findall(thinking_text), start=1):
        call_problems = _tool_call_errors(call_text.strip(), available_metrics, max_timeseries_length)
        if call_problems:
            errors.extend(f"Tool call {index}: {problem}" for problem in call_problems)
        else:
            tool_call_count += 1

    is_valid = tool_call_count <= max_tool_calls and not errors
    return is_valid, errors, tool_call_count


def _load_seed_items(input_file: str, label_file: str) -> List[Dict]:
    """Pair each QA record of ``input_file`` with the label at the same index in ``label_file``.

    ``input_file`` is JSONL (one JSON record per line). ``label_file`` is
    loaded with ``json.load`` and is expected to be a JSON array of dicts. For
    paths containing 'rlvr', each record's 'output' is itself a JSON string
    carrying 'ability_type' and 'answer'; otherwise the raw 'output' is the
    answer.
    """
    with open(input_file, encoding='utf-8') as qa_file:
        qa_dataset = [json.loads(line.rstrip()) for line in tqdm(qa_file)]
    with open(label_file, encoding='utf-8') as lbl_file:
        labels = json.load(lbl_file)

    if len(qa_dataset) != len(labels):
        # zip() below truncates to the shorter side; make that visible.
        logger.warning(
            f"{input_file} has {len(qa_dataset)} records but {label_file} has "
            f"{len(labels)} labels; unmatched entries are ignored"
        )

    is_rlvr = 'rlvr' in input_file
    file_data = []
    for data, label in zip(qa_dataset, labels):
        if is_rlvr:
            # Parse the JSON-encoded output once for both fields.
            rlvr_output = json.loads(data['output'])
            ability_type = rlvr_output['ability_type']
            answer = rlvr_output['answer']
        else:
            if 'mts_local' in input_file:
                ability_type = 'mts_local'
            elif 'mts_shape' in input_file:
                ability_type = 'mts_shape'
            else:
                ability_type = None
            answer = data['output']

        file_data.append({
            'question': label['question'],
            'answer': answer,
            'timeseries': np.array(data['timeseries']),
            'description': label.get('attribute_pool', []),
            'metrics': label['metrics'],
            'fields': label.get('fields', {}),
            'corr_pool': label.get('corr_pool', []),
            'instruction': label.get('instruction', ''),
            'ability_type': ability_type
        })
    return file_data


def generate_rewrite_tool_thinking_dataset():
    """Build the seed pool from INPUT_FILES and run remote-API generation.

    Steps:
        1. Load every (QA file, label file) pair in INPUT_FILES.
        2. Resample each file to the same size (3 * TOTAL_CNT split evenly
           across files). Sampling is without replacement when the file is
           large enough, with replacement otherwise.
        3. Shuffle the pool and hand it to ``llm_batch_generate_remote_api``,
           which writes OUTPUT_FILE.

    Raises:
        ValueError: if INPUT_FILES is empty, a file that needs upsampling
            produced no samples, or the remote API settings are missing.
    """
    print("Loading files...")
    all_file_data = []
    for input_file, label_file in INPUT_FILES:
        print(f"Loading {input_file} and {label_file}...")
        file_data = _load_seed_items(input_file, label_file)
        all_file_data.append((input_file, file_data))
        print(f"Loaded {len(file_data)} samples from {input_file}")

    if not all_file_data:
        raise ValueError("INPUT_FILES is empty; nothing to generate from")

    # Calculate sampling size for each file (no DFS expansion).
    # NOTE(review): the 3x factor presumably leaves headroom for responses
    # rejected during parsing/validation -- confirm.
    total_files = len(all_file_data)
    samples_per_file = TOTAL_CNT * 3 // total_files

    print("\nResampling to ensure balanced dataset:")
    print(f"Total files: {total_files}")
    print(f"Samples per file: {samples_per_file}")

    input_list = []
    for input_file, file_data in all_file_data:
        if len(file_data) >= samples_per_file:
            sampled_data = random.sample(file_data, samples_per_file)
        elif file_data:
            sampled_data = random.choices(file_data, k=samples_per_file)
        else:
            raise ValueError(f"No samples loaded from {input_file}; cannot resample it")

        input_list.extend(sampled_data)
        print(f"Resampled {len(sampled_data)} from {input_file} (original: {len(file_data)})")

    # Randomly shuffle the final input_list
    random.shuffle(input_list)

    print(f"{len(input_list)} seed QAs loaded from file.")

    # Run llm inference
    logger.info("Using remote API for generation")
    if not API_BASE_URL or not API_KEY:
        raise ValueError("Remote API enabled but API_BASE_URL or API_KEY not configured")
    llm_batch_generate_remote_api(input_list)

    print("-------------------------------------------")
    print(f"Finished! File saved to {OUTPUT_FILE}.")


def main() -> None:
    """Command-line entry point: build the warmup SFT dataset.

    All work (loading seeds, resampling, remote generation, writing
    OUTPUT_FILE) is delegated to ``generate_rewrite_tool_thinking_dataset``.
    """
    return generate_rewrite_tool_thinking_dataset()


if __name__ == "__main__":
    main()
