import argparse
import json
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from dotenv import load_dotenv
from openai import OpenAI

# Load environment variables from a local .env file (python-dotenv) at import
# time so that os.getenv("OPENAI_API_KEY") in main() can pick them up.
load_dotenv()


def _parse_prediction(raw_response: str, prompt_type: str) -> int:
    """Convert one model reply into an integer prediction, or -1 if unparseable."""
    if prompt_type == "io_prompt":
        # The reply should be just the label, possibly with periods around it.
        cleaned = raw_response.replace(".", "")
        # int() parses Unicode decimal digits (category Nd), which is exactly
        # what isdecimal() checks; isdigit() also accepts e.g. "²", which
        # int() rejects.
        return int(cleaned) if cleaned.isdecimal() else -1
    if prompt_type == "zsr_prompt":
        # Free-text reasoning: the answer is the last "0" or "1" in the reply.
        return next((int(ch) for ch in reversed(raw_response) if ch in "01"), -1)
    raise ValueError(f"Unsupported prompt_type: {prompt_type!r}")


def process_batch_results(
    results_file: str, prompt_type: str
) -> Tuple[List, List[int]]:
    """Parse an OpenAI batch output JSONL file into raw replies and predictions.

    Each line is one request result. Errored or malformed results get the
    prediction -1. Both returned lists always have exactly one entry per
    processed line, in file order.

    Args:
        results_file: Path to the downloaded batch output (.jsonl).
        prompt_type: "io_prompt" (reply is the bare label) or "zsr_prompt"
            (the last 0/1 character in the reply is the label).

    Returns:
        ``(raw_responses, predictions)``. A raw response is the stripped reply
        text, the ``error`` object of a failed request, or the response body
        (possibly ``None``) when the body could not be parsed.

    Raises:
        ValueError: if ``prompt_type`` is not supported.
    """
    if prompt_type not in ("io_prompt", "zsr_prompt"):
        raise ValueError(f"Unsupported prompt_type: {prompt_type!r}")

    raw_responses = []
    predictions = []
    with open(results_file, "r") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            result = json.loads(line)
            custom_id = result.get("custom_id")

            if result.get("error"):
                print(f"Error in request {custom_id}: {result['error']}")
                predictions.append(-1)
                raw_responses.append(result["error"])
                continue

            body = None
            try:
                body = result["response"]["body"]
                raw_response = body["choices"][0]["message"]["content"].strip()
            except (KeyError, IndexError, TypeError, AttributeError) as e:
                print(f"Error parsing result for {custom_id}: {e}")
                predictions.append(-1)
                # Append exactly one entry so both lists stay aligned.
                raw_responses.append(body)
                continue

            raw_responses.append(raw_response)
            predictions.append(_parse_prediction(raw_response, prompt_type))

    return raw_responses, predictions


def evaluate_model(predictions: List[int], labels: List[int]) -> float:
    """Return the accuracy over predictions that are not the -1 sentinel.

    Predictions equal to -1 (unparseable or errored replies) are excluded from
    both numerator and denominator. Pairs are matched positionally, and any
    surplus entries in the longer list are ignored. The result is rounded to
    two decimals; 0.0 is returned when there is nothing to score.
    """
    hits = [pred == gold for pred, gold in zip(predictions, labels) if pred != -1]
    if len(hits) == 0:
        return 0.0
    return round(sum(hits) / len(hits), 2)


def _split_size(size_field: str) -> Tuple[str, str]:
    """Split a metadata "size" value ("<size>" or "<size>_<val_set>...") in two."""
    size_info = size_field.split("_")
    # Only the second field is used as the validation set; default is "short".
    return size_info[0], (size_info[1] if len(size_info) > 1 else "short")


def _move_metadata(metadata_file: Path, dir_name: str) -> None:
    """Move *metadata_file* into ``metadata_file.parent.parent / dir_name``.

    If the file already lives in that directory (e.g. re-processing a job
    that is already under ``failed/``), it is left where it is.
    """
    target_dir = metadata_file.parent.parent / dir_name
    target_dir.mkdir(exist_ok=True)
    target = target_dir / metadata_file.name
    if target != metadata_file:
        metadata_file.rename(target)


def _read_job_set(jobs_file: Path) -> set:
    """Return the job paths listed one per line in *jobs_file* (empty if missing)."""
    if not jobs_file.exists():
        return set()
    with open(jobs_file) as f:
        return set(f.read().splitlines())


def _remove_failed_job(jobs_file: Path, job_file: str) -> None:
    """Drop *job_file* from *jobs_file*; delete the file once no jobs remain."""
    jobs = _read_job_set(jobs_file)
    if job_file not in jobs:
        return
    jobs.remove(job_file)
    if jobs:
        with open(jobs_file, "w") as f:
            f.writelines(f"{j}\n" for j in sorted(jobs))
    else:
        jobs_file.unlink()


def _add_failed_job(jobs_file: Path, job_file: str) -> None:
    """Append *job_file* to *jobs_file* unless it is already listed."""
    if job_file in _read_job_set(jobs_file):
        return
    jobs_file.parent.mkdir(parents=True, exist_ok=True)
    with open(jobs_file, "a") as f:
        f.write(f"{job_file}\n")


def _save_completed_batch(
    client: OpenAI,
    batch_status,
    metadata: Dict,
    exp_name: str,
    prompt_type: str,
    size: str,
    val_set: str,
) -> Dict:
    """Download, score and persist a completed batch; return its results dict."""
    batch_id = metadata["batch_id"]
    task = metadata["task"]
    task_dir = Path("results") / exp_name / "tasks" / task
    output_dir = task_dir / "output"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save the raw batch output next to the task results.
    raw_output_file = output_dir / f"batch_output_{batch_id}.jsonl"
    with open(raw_output_file, "w") as f:
        f.write(client.files.content(batch_status.output_file_id).text)

    raw_responses, predictions = process_batch_results(
        str(raw_output_file), prompt_type
    )
    val_sequences = metadata["validation_data"]["sequences"]
    val_labels = metadata["validation_data"]["labels"]

    # NOTE(review): responses are paired with validation examples purely by
    # line order. The OpenAI Batch API docs say output order may differ from
    # input order (match on custom_id instead) -- confirm how custom_ids were
    # assigned and re-order before scoring if needed.
    if len(predictions) != len(val_labels):
        print(
            f"Warning: batch {batch_id} has {len(predictions)} responses for "
            f"{len(val_labels)} validation examples; unmatched entries are dropped"
        )

    examples = metadata["examples"]
    results = {
        "accuracy": evaluate_model(predictions, val_labels),
        "raw_results": [
            {
                "sequence": seq,
                "true_label": label,
                "raw_response": raw_response,
                "model_output": pred,
                "correct": pred == label,
            }
            for seq, label, pred, raw_response in zip(
                val_sequences, val_labels, predictions, raw_responses
            )
        ],
        "examples_used": examples,
        "batch_id": batch_id,
        "model": metadata["model"],
        "validation_size": len(val_sequences),
        "validation_balanced": metadata.get("validation_balanced", False),
        "validation_set": val_set,
        "label_distribution": {
            "training": {
                "label_0": sum(1 for ex in examples if ex["label"] == 0),
                "label_1": sum(1 for ex in examples if ex["label"] == 1),
            },
            "validation": {
                "label_0": sum(1 for label in val_labels if label == 0),
                "label_1": sum(1 for label in val_labels if label == 1),
            },
        },
    }

    # Merge into results[model][size][val_set] of the per-task combined file.
    combined_file = task_dir / f"results_{task}_combined.json"
    combined_results = {}
    if combined_file.exists():
        with open(combined_file) as f:
            combined_results = json.load(f)
    combined_results.setdefault(metadata["model"], {}).setdefault(size, {})[
        val_set
    ] = results
    with open(combined_file, "w") as f:
        json.dump(combined_results, f, indent=2)

    print(f"Results saved to {combined_file}")
    print(f"Batch output saved to {raw_output_file}")
    return results


def process_completed_batch(
    client: OpenAI, metadata_file: Path, exp_name: str, prompt_type: str
) -> Optional[Tuple[str, str, Dict]]:
    """Check one batch job and, if it has finished, record the outcome.

    Reads the job's metadata JSON and queries ``client.batches.retrieve``:

    * ``completed`` with an output file: downloads the raw JSONL into
      ``results/<exp_name>/tasks/<task>/output/``, scores it, merges the
      summary into ``results_<task>_combined.json``, removes the job from
      ``batch_jobs/<exp_name>/failed_jobs.txt`` and moves the metadata file to
      the sibling ``completed`` directory.
    * ``failed`` / ``expired`` / ``cancelled``, or ``completed`` without an
      output file: moves the metadata file to the sibling ``failed``
      directory and adds the job to ``failed_jobs.txt``.
    * any other status: only reports that the batch is still running.

    Args:
        client: OpenAI client used for the batch and file APIs.
        metadata_file: Path to the job's ``*_metadata.json``.
        exp_name: Experiment name (sub-directory under batch_jobs/ and results/).
        prompt_type: Forwarded to ``process_batch_results``.

    Returns:
        ``(size, val_set, results)`` for a successfully processed batch,
        otherwise ``None``.
    """
    with open(metadata_file) as f:
        metadata = json.load(f)

    batch_id = metadata["batch_id"]
    batch_status = client.batches.retrieve(batch_id)
    status = batch_status.status

    # The job identifier is needed by both the success and failure paths.
    size, val_set = _split_size(metadata["size"])
    job_file = (
        f"batch_jobs/{exp_name}/{metadata['task']}_"
        f"{metadata['model'].replace(' ', '_')}_{size}_{val_set}.job"
    )
    failed_jobs_file = Path(f"batch_jobs/{exp_name}/failed_jobs.txt")

    if status == "completed" and batch_status.output_file_id is not None:
        results = _save_completed_batch(
            client, batch_status, metadata, exp_name, prompt_type, size, val_set
        )
        _remove_failed_job(failed_jobs_file, job_file)
        _move_metadata(metadata_file, "completed")
        return size, val_set, results

    if status == "completed":
        # Completed, but there is nothing to download (no output file).
        print(
            f"Batch {batch_id} completed but has no output file "
            f"(error file: {batch_status.error_file_id})"
        )
    elif status in ("failed", "expired", "cancelled"):
        # Terminal states: without this they were reported as "still <status>"
        # on every run and never queued for a retry.
        print(f"Batch {batch_id} {status}: {batch_status.errors}")
    else:
        print(f"Batch {batch_id} is still {status}")
        return None

    _move_metadata(metadata_file, "failed")
    _add_failed_job(failed_jobs_file, job_file)
    return None


# Characters with special meaning in LaTeX text mode and their escaped forms.
_LATEX_SPECIAL_CHARS = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _latex_escape(text) -> str:
    """Escape LaTeX special characters in ``str(text)``, one character at a time."""
    # Per-character mapping avoids re-escaping the backslashes we insert.
    return "".join(_LATEX_SPECIAL_CHARS.get(ch, ch) for ch in str(text))


def generate_table(results: Dict) -> tuple[str, str]:
    """Render results as a markdown table and a LaTeX ``table`` environment.

    Args:
        results: Nested mapping ``results[model][size][val_set]`` whose leaf
            dicts contain an ``"accuracy"`` fraction in [0, 1]. ``size`` keys
            must be int-convertible; they are sorted numerically.

    Returns:
        ``(markdown, latex)``. Accuracy is shown as a percentage with one
        decimal. In the LaTeX output every cell is escaped. Without escaping,
        the ``%`` in the accuracy would start a LaTeX comment and swallow the
        row terminator, and ``_`` in model names would break compilation.
    """
    md_lines = [
        "| Model | Sample Size | Validation Set | Accuracy |",
        "|-------|-------------|----------------|----------|",
    ]
    latex_lines = [
        "\\begin{table}[h]",
        "\\centering",
        "\\begin{tabular}{|c|c|c|c|}",
        "\\hline",
        "Model & Sample Size & Validation Set & Accuracy \\\\ \\hline",
    ]

    for model in sorted(results):
        for size in sorted(results[model], key=int):
            for val_set in sorted(results[model][size]):
                result = results[model][size][val_set]
                accuracy = f"{result['accuracy'] * 100:.1f}%"

                md_lines.append(f"| {model} | {size} | {val_set} | {accuracy} |")
                latex_lines.append(
                    f"{_latex_escape(model)} & {_latex_escape(size)} & "
                    f"{_latex_escape(val_set)} & {_latex_escape(accuracy)} "
                    "\\\\ \\hline"
                )

    latex_lines += [
        "\\end{tabular}",
        "\\caption{Model Performance Comparison}",
        "\\label{tab:model-performance}",
        "\\end{table}",
    ]

    return "\n".join(md_lines) + "\n", "\n".join(latex_lines)

def parse_exp_name(
    input_dir_name: str,
) -> Tuple[str, str, str, str, str, str, Optional[str]]:
    """Split an experiment name into its underscore-separated components.

    Expected layout::

        <model>_<prompt>_<prompt>_<substitution>_seed<N>_sample<N>_<test_set>[_k=<K>]

    for example ``mymodel_io_prompt_many2one_seed42_sample100_long_k=5``.
    The model name itself must therefore not contain underscores.

    Returns:
        ``(model_name, prompt_name, substitution_strategy, seed, sample_size,
        test_set, k)``. ``prompt_name`` joins the two prompt fields with "_".
        ``seed`` and ``sample_size`` have their "seed"/"sample" markers
        removed. ``k`` is the value of the first trailing ``k=`` field, or
        ``None`` if there is none.

    Raises:
        ValueError: if the name has fewer than 7 underscore-separated fields.
    """
    parts = input_dir_name.split("_")
    if len(parts) < 7:
        raise ValueError(
            f"Cannot parse experiment name {input_dir_name!r}: expected at "
            "least 7 underscore-separated fields"
        )

    model_name = parts[0]
    prompt_name = "_".join(parts[1:3])
    substitution_strategy = parts[3]
    seed = parts[4].replace("seed", "")
    sample_size = parts[5].replace("sample", "")
    test_set = parts[6]
    # Only look for "k=" among the optional trailing fields, so a "k=" elsewhere
    # in the name cannot cause an unrelated field to be read as k.
    k = next((p.replace("k=", "") for p in parts[7:] if p.startswith("k=")), None)
    return (
        model_name,
        prompt_name,
        substitution_strategy,
        seed,
        sample_size,
        test_set,
        k,
    )

def map_substitution_strategy(substitution_strategy: str) -> str:
    """Translate a compact strategy tag (e.g. "many2one") to its directory name.

    Raises:
        ValueError: if the tag is not one of the known strategies.
    """
    canonical_names = {
        "many2one": "many_to_one",
        "one2one": "one_to_one",
        "one2many": "one_to_many",
        "one2many-sym": "one_to_many_sym",
        "one2one-sym": "one_to_one_sym",
    }
    mapped = canonical_names.get(substitution_strategy)
    if mapped is None:
        raise ValueError(f"Invalid substitution strategy: {substitution_strategy}")
    return mapped


def main():
    """Process every pending or previously failed batch job of one experiment.

    Scans ``batch_jobs/<exp_name>/`` recursively for ``*_metadata.json`` files
    not located in a ``completed`` directory and runs
    ``process_completed_batch`` on each. For every batch that returns results,
    it also writes them to
    ``<results_dir>/<model>/<prompt>/<strategy>[/k=<K>]/<task>/<sample>_<test_set>_<seed>.json``,
    where the path components are parsed from ``--exp_name``.
    """
    parser = argparse.ArgumentParser(description="Process completed batch experiments")
    parser.add_argument(
        "--results_dir",
        default="results",
        help="Directory to save results (default: results)",
    )
    parser.add_argument("--exp_name", required=True, help="Experiment name")
    parser.add_argument("--prompt_type", required=True, help="Prompt type")
    args = parser.parse_args()

    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    # Pending and failed jobs are both (re)checked; only completed ones are skipped.
    batch_jobs_dir = Path(f"batch_jobs/{args.exp_name}")
    metadata_files = sorted(
        f
        for f in batch_jobs_dir.glob("**/*_metadata.json")
        if f.parent.name != "completed"
    )

    if not metadata_files:
        print("No pending or failed batch jobs found")
        return

    print(f"Found {len(metadata_files)} batch jobs to process")

    # Resolve the output location BEFORE processing any batch. A processed
    # batch's metadata is moved to "completed", so an unparseable experiment
    # name discovered afterwards would lose these per-run results with no
    # retry.
    model_name, prompt_name, substitution_strategy, seed, sample_size, test_set, k = (
        parse_exp_name(args.exp_name)
    )
    output_base_dir = (
        Path(args.results_dir)
        / model_name
        / prompt_name
        / map_substitution_strategy(substitution_strategy)
    )
    if k is not None:
        output_base_dir = output_base_dir / f"k={k}"

    for metadata_file in metadata_files:
        status = "failed" if metadata_file.parent.name == "failed" else "pending"
        with open(metadata_file) as f:
            metadata = json.load(f)
        task = metadata["task"]

        print(f"\nProcessing {status} batch {metadata['batch_id']}...")
        result = process_completed_batch(
            client, metadata_file, args.exp_name, args.prompt_type
        )
        if not result:
            continue  # still running, or failed and queued for retry

        _, _, batch_results = result
        output_file = output_base_dir / task / f"{sample_size}_{test_set}_{seed}.json"
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w") as f:
            json.dump(batch_results, f, indent=2)
        print(f"Results saved to {output_file}")

    print("\nProcessing completed!")


if __name__ == "__main__":
    main()
