import argparse
import os
import json
from tqdm import tqdm
import copy
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures
import time
from datasets import load_dataset
from openai import OpenAI


def preprocess_data(test_case_string):
    """Extract the body of the first ```python fenced block from a model reply.

    Args:
        test_case_string: Raw text returned by the model.

    Returns:
        The text between the first "```python" marker and the next "```"
        marker. If there is no "```python" marker the input is returned
        unchanged. If the fence is never closed (e.g. a truncated reply),
        everything after the opening marker is returned.
    """
    fence_open = "```python"
    start = test_case_string.find(fence_open)
    if start == -1:
        return test_case_string

    body = test_case_string[start + len(fence_open):]
    end = body.find("```")
    # str.find returns -1 when the closing fence is missing; slicing with
    # that value would silently drop the last character of the code.
    return body if end == -1 else body[:end]

def parse_livecodebench_response(response_text):
    """Parse and validate a LiveCodeBench test-input response.

    The reply is expected to be a JSON array of objects that each carry an
    "input" key, optionally wrapped in a ```json (or plain ```) fence.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        A JSON string (indent=2) containing only the items that are dicts
        with an "input" key, or "" when the reply cannot be parsed, is not
        a list, or contains no valid item. Problems are reported via print.
    """
    try:
        # Strip a Markdown fence if present, preferring an explicit json fence.
        for fence_open in ("```json", "```"):
            start = response_text.find(fence_open)
            if start == -1:
                continue
            response_text = response_text[start + len(fence_open):]
            end = response_text.find("```")
            # Only cut at the closing fence when it exists; with a truncated
            # reply find() returns -1 and slicing would drop the final
            # character (typically the closing bracket of the array).
            if end != -1:
                response_text = response_text[:end]
            break

        test_inputs = json.loads(response_text.strip())

        if not isinstance(test_inputs, list):
            print(f"Warning: Response is not a list, got {type(test_inputs)}")
            return ""

        # Keep only items that have the required "input" field.
        valid_inputs = []
        for item in test_inputs:
            if isinstance(item, dict) and "input" in item:
                valid_inputs.append(item)
            else:
                print(f"Warning: Invalid test case format: {item}")

        if not valid_inputs:
            print("Warning: No valid test inputs found in response")
            return ""

        return json.dumps(valid_inputs, indent=2)

    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        return ""
    except Exception as e:
        print(f"Unexpected error parsing response: {e}")
        return ""

# Function to fetch completion
def fetch_completion(data_entry, model, dataset_name):
    """Ask the chat model for test inputs and store them on ``data_entry``.

    Args:
        data_entry: Dataset record (dict). It is mutated: "test_inputs" is
            always set, and for the "humaneval-pro"/"mbpp-pro" datasets
            "entry_point" is set as well.
        model: Model name passed to the chat completions endpoint.
        dataset_name: One of "livecodebench", "humaneval", "mbpp", "ult",
            "humaneval-pro", "mbpp-pro" or "testeval".

    Returns:
        The same ``data_entry`` object. "test_inputs" is "" when the request
        or the parsing of its reply fails.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
        KeyError: If ``data_entry`` lacks a field the chosen dataset needs.
    """
    # Build the prompt first so an unsupported dataset fails fast with a
    # clear error instead of an unbound-variable error further down.
    if dataset_name == "livecodebench":
        text = _build_livecodebench_prompt(data_entry)
    else:
        text = _build_function_prompt(data_entry, dataset_name)

    # NOTE(review): the key below is a hard-coded placeholder; consider
    # loading the real credential from the environment instead.
    client = OpenAI(
        api_key="API-Key",
    )

    try:
        completions = client.chat.completions.create(
            model=model,
            stream=False,
            messages=[
                {"role": "user", "content": text},
            ],
        )
        test_case = completions.choices[0].message.content

        # LiveCodeBench replies are JSON; the others are fenced Python calls.
        if dataset_name == "livecodebench":
            data_entry["test_inputs"] = parse_livecodebench_response(test_case) or ""
        else:
            data_entry["test_inputs"] = preprocess_data(test_case)

    except Exception as e:
        # Best effort: a failed request must not abort the whole batch.
        print(f"Error fetching completion: {e}")
        data_entry["test_inputs"] = ""

    return data_entry


def _build_livecodebench_prompt(data_entry):
    """Build the stdin/stdout test-input prompt for a LiveCodeBench entry."""
    task_description = f"### Question:\n{data_entry['question_content']}\n\n"
    if data_entry.get("starter_code"):
        task_description += f"### Starter Code:\n```python\n{data_entry['starter_code']}\n```\n\n"

    # Show up to three public test inputs as format examples. This is
    # optional context, so malformed test-case data only drops the examples.
    examples_text = ""
    try:
        test_cases_list = json.loads(data_entry.get("public_test_cases", "[]"))
        test_input_list = [{"input": test_case["input"]} for test_case in test_cases_list]
        examples_text = (
            "\n\nSample test inputs from the problem (in JSON format):\n```json\n"
            + json.dumps(test_input_list[:3], indent=2)
            + "\n```\n"
        )
    except (ValueError, TypeError, KeyError) as e:
        print(f"Warning: could not read public test cases, omitting examples: {e!r}")

    return f"""You are an expert test case generator for competitive programming problems.

Problem Description:
{task_description}{examples_text}

This is a standard input/output problem where the solution reads from stdin and writes to stdout.

Generate 20 diverse test inputs in JSON format exactly like the examples shown.
The format should be a JSON array where each test case is an object with:
- "input": the stdin input string (include \\n for newlines)

Return ONLY a valid JSON array. For example:
[
  {{"input": "5\\nTTAAT\\n"}},
  {{"input": "1\\nA\\n"}}
]

IMPORTANT: 
- Use the exact input format shown in the problem examples
- Include \\n for newlines in input strings
- Ensure the JSON is valid and properly formatted
- Return ONLY the JSON array, no explanations or markdown code blocks
"""


def _build_function_prompt(data_entry, dataset_name):
    """Build the function-call test-input prompt for the non-LiveCodeBench datasets."""
    if dataset_name == "humaneval":
        prompt = data_entry["prompt"] + "\n" + data_entry["canonical_solution"]
        entry_point = data_entry["entry_point"]
    elif dataset_name in ("mbpp", "ult"):
        prompt = f'"""\n{data_entry["prompt"]}\n"""{data_entry["code"]}'
        # The function name is whatever precedes "(" in the first assert.
        entry_point = data_entry["test_list"][0].replace("assert ", "").split("(")[0]
    elif dataset_name in ("humaneval-pro", "mbpp-pro"):
        prompt = (
            data_entry["raw_problem"] + "\n" + data_entry["raw_solution"] + "\n"
            + data_entry["new_problem"] + "\n" + data_entry["new_solution"]
        )
        test_example = [test for test in data_entry["test_code"].split("\n") if "assert " in test][0]
        entry_point = test_example.replace("assert ", "").split("(")[0]
        # Record the derived entry point on the entry so it ends up in the output.
        data_entry["entry_point"] = entry_point
    elif dataset_name == "testeval":
        prompt = f'"""\n{data_entry["description"]}\n"""{data_entry["python_solution"]}\n\nsolution = Solution()'
        entry_point = f'solution.{data_entry["func_name"]}'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    return f"""
Please generate 20 diverse test inputs for the following function.
The test input should be the format of:
```python
{entry_point}(input_parameters)
```
Each line should be a separate test input.
```python
{prompt}
```
"""



if __name__ == "__main__":
    # Script entry point: load the selected dataset, request test inputs for
    # every entry in parallel, and write the augmented entries to
    # ./datasets/<dataset_name>_test_inputs.json.

    dataset_name = "ult"
    if dataset_name == "mbpp":
        dataset = load_dataset("google-research-datasets/mbpp", "sanitized", split="test")
    elif dataset_name == "humaneval":
        dataset = load_dataset("openai_humaneval", split="test")
    elif dataset_name == "humaneval-pro":
        dataset = load_dataset("CodeEval-Pro/humaneval-pro", split="train")
    elif dataset_name == "mbpp-pro":
        dataset = load_dataset("CodeEval-Pro/mbpp-pro", split="train")
    elif dataset_name == "livecodebench":
        with open("./datasets/livecodebench_test_inputs.json", "r", encoding="utf-8") as f:
            dataset = json.load(f)
    elif dataset_name == "ult":
        with open("./datasets/ult_lite.json", "r", encoding="utf-8") as f:
            dataset = json.load(f)
    elif dataset_name == "testeval":
        with open("./datasets/leetcode-py.jsonl", "r", encoding="utf-8") as f:
            # One JSON object per line; blank lines are skipped.
            dataset = [json.loads(line) for line in f if line.strip()]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    # Materialize as a plain list so results can be written back by position.
    dataset = list(dataset)
    model = "gpt-4.1-nano"
    with ThreadPoolExecutor() as executor:
        # Remember each future's position instead of searching the list with
        # dataset.index(): that is an O(n) deep-equality scan per result and
        # can pick the wrong slot when two entries compare equal.
        future_to_index = {
            executor.submit(fetch_completion, copy.deepcopy(entry), model, dataset_name): idx
            for idx, entry in enumerate(tqdm(dataset))
        }
        for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index)):
            idx = future_to_index[future]
            try:
                dataset[idx] = future.result()
            except Exception as e:
                print(f"Error processing entry: {repr(e)}")
                # Ensure failed entries have empty test_inputs
                dataset[idx]["test_inputs"] = ""

    with open(f"./datasets/{dataset_name}_test_inputs.json", "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=4)