import argparse
import json
from pathlib import Path
import time
import os
from openai import OpenAI
from dotenv import load_dotenv
from tqdm import tqdm


def process_batch_file(
    client: OpenAI, batch_file: Path, output_dir: Path, request_delay: float = 0.1
):
    """Send every request in a JSONL batch file to the chat completions API.

    Each non-blank line of ``batch_file`` is parsed as a JSON object and passed
    as keyword arguments to ``client.chat.completions.create``. The reply text
    is expected to be a bare integer; an API error or a reply that ``int()``
    cannot parse marks that single request as failed instead of aborting the
    whole file.

    Results are written, in input order, to
    ``output_dir / f"{batch_file.stem}_results.json"`` as a JSON list of
    ``{"success": True, "prediction": int, "request": ...}`` or
    ``{"success": False, "error": str, "request": ...}`` records.

    Args:
        client: Initialised OpenAI client.
        batch_file: JSONL file with one request object per line.
        output_dir: Directory for the results file; created if missing.
        request_delay: Seconds to pause between consecutive requests as a
            simple rate limit (default 0.1). Values <= 0 disable the pause.

    Returns:
        Path of the written results file.

    Raises:
        OSError: If the batch file cannot be read or the results cannot be
            written.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Skip blank lines (e.g. a trailing empty line): json.loads("") would
    # otherwise raise and abort the whole file before any request is sent.
    with open(batch_file, encoding="utf-8") as f:
        requests = [json.loads(line) for line in f if line.strip()]

    # NOTE(review): each line is forwarded verbatim as keyword arguments, so the
    # file must hold bare chat.completions parameters (model, messages, ...),
    # not an envelope such as custom_id/method/url/body — confirm input format.
    results = []
    progress = tqdm(requests, desc=f"Processing {batch_file.name}")
    for index, request in enumerate(progress):
        if index and request_delay > 0:
            # Rate limit between requests only; no wait after the last one.
            time.sleep(request_delay)
        try:
            response = client.chat.completions.create(**request)
            prediction = int(response.choices[0].message.content.strip())
            results.append(
                {"success": True, "prediction": prediction, "request": request}
            )
        except Exception as e:
            # Best-effort: record the failure and keep processing the batch.
            print(f"Error processing request: {e}")
            results.append({"success": False, "error": str(e), "request": request})

    output_file = output_dir / f"{batch_file.stem}_results.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    return output_file


def combine_results(
    task: str,
    sample_size: int,
    val_set: str,
    batch_results_dir: Path,
    labels_file: Path,
    output_dir: Path,
):
    """Combine per-batch prediction files with their labels and save a summary.

    Loads ``labels_file`` (a JSON object read for its ``labels``,
    ``sequences`` and ``examples_used`` keys), concatenates the predictions
    from every ``{task}_n{sample_size}_{val_set}_batch*_results.json`` file in
    ``batch_results_dir`` in numeric batch order, and writes a summary to
    ``output_dir / f"results_{task}_{val_set}_n{sample_size}.json"``.

    Predictions are matched to labels purely by position, so batches must be
    concatenated in the order the labels were written. A failed request
    (``success`` false/missing, or no ``prediction``) is recorded as ``-1``,
    reported with ``correct: None`` and excluded from the accuracy
    denominator.

    Args:
        task: Task name used in the batch/result file names.
        sample_size: Sample size used in the batch/result file names.
        val_set: Validation-set name used in the batch/result file names.
        batch_results_dir: Directory holding the per-batch result files.
        labels_file: JSON file with the ground-truth labels and metadata.
        output_dir: Directory for the summary file; created if missing.

    Returns:
        Path of the written summary file.

    Raises:
        KeyError: If the labels file lacks a required key.
        OSError, json.JSONDecodeError: If an input file cannot be read/parsed.
    """
    with open(labels_file, encoding="utf-8") as f:
        label_data = json.load(f)
    labels = label_data["labels"]

    prefix = f"{task}_n{sample_size}_{val_set}_batch"
    suffix = "_results.json"

    def batch_order(path: Path):
        # Sort by the numeric batch index so "batch10" follows "batch9" rather
        # than "batch1"; plain string order would misalign predictions/labels.
        index = path.name[len(prefix):len(path.name) - len(suffix)]
        if index.isdecimal():
            return (0, int(index), path.name)
        return (1, 0, path.name)

    batch_files = sorted(
        batch_results_dir.glob(prefix + "*" + suffix), key=batch_order
    )

    all_predictions = []
    for batch_file in batch_files:
        with open(batch_file, encoding="utf-8") as f:
            results = json.load(f)
        all_predictions.extend(
            r.get("prediction", -1) if r.get("success") else -1 for r in results
        )

    if len(all_predictions) != len(labels):
        # zip() below silently truncates; surface the mismatch instead.
        print(
            f"Warning: {len(all_predictions)} predictions for {len(labels)} "
            f"labels in {labels_file}; results may be misaligned or truncated"
        )

    # Pair each prediction with its own label BEFORE discarding failures;
    # filtering first would shift later predictions onto the wrong labels.
    scored = [(p, label) for p, label in zip(all_predictions, labels) if p != -1]
    correct = sum(1 for p, label in scored if p == label)
    accuracy = correct / len(scored) if scored else 0
    valid_count = sum(1 for p in all_predictions if p != -1)

    final_results = {
        "task": task,
        "sample_size": sample_size,
        "validation_set": val_set,
        "accuracy": accuracy,
        "total_examples": len(labels),
        "valid_predictions": valid_count,
        "examples_used": label_data["examples_used"],
        "raw_results": [
            {
                "sequence": seq,
                "true_label": label,
                "predicted_label": pred,
                "correct": pred == label if pred != -1 else None,
            }
            for seq, label, pred in zip(
                label_data["sequences"], labels, all_predictions
            )
        ],
    }

    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"results_{task}_{val_set}_n{sample_size}.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(final_results, f, indent=2)

    return output_file


def main():
    """Entry point: run all batch files through the API, then merge with labels.

    Command-line options:
        --batch_dir   (required) directory holding the batch files and the
                      ``*_labels.json`` metadata files.
        --output_dir  directory for per-batch result files (default
                      ``batch_results``); combined summaries are written to
                      the parent of this directory.
        --pattern     glob selecting batch files inside ``--batch_dir``
                      (default ``*.jsonl``).

    ``OPENAI_API_KEY`` is read from the environment after ``load_dotenv()``
    has run. A failure on an individual batch or labels file is printed and
    skipped so the remaining files are still handled.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset or empty.
    """
    cli = argparse.ArgumentParser(description="Submit batch files to OpenAI API")
    cli.add_argument("--batch_dir", required=True, help="Directory containing batch files")
    cli.add_argument("--output_dir", default="batch_results", help="Directory to save results")
    cli.add_argument("--pattern", default="*.jsonl", help="Pattern to match batch files")
    opts = cli.parse_args()

    # Credentials: a local .env file may supply the key.
    load_dotenv()
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise ValueError("OPENAI_API_KEY not found in environment variables")
    client = OpenAI(api_key=key)

    results_dir = Path(opts.output_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    source_dir = Path(opts.batch_dir)
    pending = sorted(source_dir.glob(opts.pattern))
    if not pending:
        print(f"No batch files found matching pattern '{opts.pattern}' in {source_dir}")
        return
    print(f"Found {len(pending)} batch files")

    # Stage 1: query the API for each batch file; one bad file must not stop the rest.
    for batch_path in pending:
        try:
            saved = process_batch_file(client, batch_path, results_dir)
            print(f"Processed {batch_path} -> {saved}")
        except Exception as err:
            print(f"Error processing {batch_path}: {err}")

    # Stage 2: merge every labels file with its per-batch results. Summaries
    # are written next to (not inside) the per-batch results directory.
    print("\nCombining results...")
    for labels_path in source_dir.glob("*_labels.json"):
        try:
            with open(labels_path) as handle:
                meta = json.load(handle)
            summary = combine_results(
                task=meta["task"],
                sample_size=meta["sample_size"],
                val_set=meta["validation_set"],
                batch_results_dir=results_dir,
                labels_file=labels_path,
                output_dir=results_dir.parent,
            )
            print(f"Combined results saved to {summary}")
        except Exception as err:
            print(f"Error combining results for {labels_path}: {err}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
