import argparse
import random
import os
from typing import List, Tuple, Dict
from openai import OpenAI
from pathlib import Path
import json
import time
import sys
from dotenv import load_dotenv
from tqdm import tqdm
from src.openai.substitute_tokens import (
    create_modified_prompt,
    create_token_substitution,
    apply_token_mapping,
    get_unique_symbols,
    create_formal_language_prompt,
)

# Load environment variables via python-dotenv at import time (by default it
# looks for a .env file); main() later reads OPENAI_API_KEY with os.getenv.
load_dotenv()


def load_train_data(data_dir: str, task: str) -> Tuple[List[str], List[int]]:
    """Read the training split of a task from disk.

    Reads ``<data_dir>/<task>/data.train`` (one sequence per line) and
    ``<data_dir>/<task>/labels.train`` (one integer label per line).

    Args:
        data_dir: Directory that contains one sub-directory per task.
        task: Name of the task sub-directory.

    Returns:
        A ``(sequences, labels)`` pair. Each sequence is its line with
        surrounding whitespace stripped (blank lines are kept as empty
        strings); each label is the integer parsed from its line.

    Raises:
        FileNotFoundError: If either file is missing.
        ValueError: If a label line is not an integer.
    """
    task_dir = Path(f"{data_dir}/{task}")

    with open(task_dir / "data.train") as data_file:
        sequences = list(map(str.strip, data_file))

    with open(task_dir / "labels.train") as label_file:
        labels = [int(raw.strip()) for raw in label_file]

    return sequences, labels


def load_validation_data(
    data_dir: str, task: str, validation_type: str
) -> Tuple[List[str], List[int]]:
    """Read one evaluation split of a task from disk.

    Reads ``<data_dir>/<task>/data.<validation_type>`` (one sequence per
    line) and ``<data_dir>/<task>/labels.<validation_type>`` (one integer
    label per line).

    Args:
        data_dir: Directory that contains one sub-directory per task.
        task: Name of the task sub-directory.
        validation_type: File-name suffix selecting the split.

    Returns:
        A ``(sequences, labels)`` pair. Each sequence is its line with
        surrounding whitespace stripped (blank lines are kept as empty
        strings); each label is the integer parsed from its line.

    Raises:
        FileNotFoundError: If either file is missing.
        ValueError: If a label line is not an integer.
    """
    split_dir = Path(f"{data_dir}/{task}/")
    data_path = split_dir / f"data.{validation_type}"
    labels_path = split_dir / f"labels.{validation_type}"

    with open(data_path) as data_file:
        sequences = list(map(str.strip, data_file))

    with open(labels_path) as label_file:
        labels = [int(raw.strip()) for raw in label_file]

    return sequences, labels


def create_prompt(
    base_prompt: str, examples: List[Tuple[str, int]], test_sequence: str
) -> str:
    """Fill a prompt template with few-shot examples and the query string.

    Args:
        base_prompt: Template text; every ``{examples}`` and
            ``{test_sequence}`` placeholder in it is substituted.
        examples: ``(sequence, label)`` pairs, rendered in order as
            ``String: <sequence>`` / ``Label: <label>`` followed by a blank
            line.
        test_sequence: Text substituted for ``{test_sequence}``.

    Returns:
        The template with both placeholders replaced.
    """
    rendered_examples = "".join(
        f"String: {seq}\nLabel: {label}\n\n" for seq, label in examples
    )

    # Examples are substituted first, so a literal "{test_sequence}" inside an
    # example is also replaced by the second substitution.
    with_examples = base_prompt.replace("{examples}", rendered_examples)
    return with_examples.replace("{test_sequence}", test_sequence)


def print_label_stats(data: List[Tuple[int, str]], name: str = "") -> None:
    """Print how many items carry label 0 and label 1, with percentages.

    Args:
        data: Pairs whose FIRST element is the label; the second element is
            ignored. This matches the ``(label, sequence)`` pairs built in
            ``main``.
        name: Heading printed before the statistics.

    Nothing is printed when ``data`` is empty, which also avoids dividing by
    zero in the percentage computation. Labels other than 0 and 1 count
    towards the total but towards neither class.
    """
    total = len(data)
    if not total:
        return

    zeros = sum(1 for label, _ in data if label == 0)
    ones = sum(1 for label, _ in data if label == 1)

    print(f"\n{name} statistics:")
    print(f"Total examples: {total}")
    print(f"Label 0: {zeros} ({zeros/total*100:.1f}%)")
    print(f"Label 1: {ones} ({ones/total*100:.1f}%)")


def create_batch_requests(
    examples: List[Tuple[str, int]],
    val_sequences: List[str],
    model: str,
    prompt_before_exemplars: str,
    prompt_after_exemplars: str,
    substitution_strategy: str,
    mapping: Dict[str, str] | None = None,
    one2many_timestep: int | None = None,
) -> List[dict]:
    """Build one Batch API chat-completion request per validation sequence.

    The part of the prompt shared by all requests is built once: with
    ``create_modified_prompt`` when a token ``mapping`` is given, otherwise
    with ``create_formal_language_prompt``. Each request's prompt is that
    shared prefix followed by the validation sequence, which is passed
    through ``apply_token_mapping`` first when a mapping is given.

    Args:
        examples: Few-shot exemplars, forwarded unchanged to the prompt
            builders.
        val_sequences: Sequences to classify; one request is emitted per
            sequence, in order.
        model: Model name written into every request body.
        prompt_before_exemplars: Prompt text placed before the exemplars.
        prompt_after_exemplars: Prompt text placed after the exemplars.
        substitution_strategy: Strategy name forwarded to the substitution
            helpers (only used when ``mapping`` is given).
        mapping: Token substitution mapping, or ``None`` for no substitution.
        one2many_timestep: Forwarded to the substitution helpers.

    Returns:
        A list of request dicts with ``custom_id`` ``request-<index>``,
        method ``POST``, url ``/v1/chat/completions`` and a body holding the
        model, a single user message and ``temperature`` 0.
    """
    print(f"Creating batch requests for model: {model}")  # For debugging
    use_mapping = mapping is not None

    if use_mapping:
        prompt_prefix = create_modified_prompt(
            prompt_before_exemplars,
            prompt_after_exemplars,
            examples,
            substitution_strategy,
            mapping,
            model_name=model,
            one2many_timestep=one2many_timestep,
        )
    else:
        prompt_prefix = create_formal_language_prompt(
            prompt_before_exemplars, prompt_after_exemplars, examples
        )

    batch_requests = []
    for idx, test_seq in enumerate(val_sequences):
        if use_mapping:
            query = apply_token_mapping(
                test_seq,
                substitution_strategy,
                mapping,
                model_name=model,
                one2many_timestep=one2many_timestep,
            )
        else:
            query = test_seq
        prompt = prompt_prefix + query

        # Show the first fully assembled prompt so a run can be sanity-checked.
        if idx == 0:
            print(f"prompt_before_exemplars: {prompt_before_exemplars}")
            print(f"prompt_after_exemplars: {prompt_after_exemplars}")
            print(f"prompt: {prompt}")

        batch_requests.append(
            {
                # custom_id is derived from the position in val_sequences.
                "custom_id": f"request-{idx}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0,
                },
            }
        )

    return batch_requests


def run_batch_validation(
    client: OpenAI,
    examples: List[Tuple[str, int]],
    val_sequences: List[str],
    val_labels: List[int],
    model: str,
    prompt_before_exemplars: str,
    prompt_after_exemplars: str,
    substitution_strategy: str,
    task: str,
    size: str,
    mapping: Dict[str, str],
    exp_name: str = "default",
    validation_info: Dict | None = None,
    one2many_timestep: int | None = None,
) -> Dict:
    """Submit one validation run as an OpenAI batch job and record its metadata.

    Steps:
      1. Build the requests with ``create_batch_requests``.
      2. Write them as JSONL to a temporary file in the current directory and
         copy that file to ``batch_requests/<exp_name>/`` as a timestamped
         backup.
      3. Upload the temporary file and create a batch for
         ``/v1/chat/completions`` with a 24h completion window.
      4. Save a metadata JSON (batch id, task, size, model, exemplars,
         validation data, status ``pending``, ``validation_info``) under
         ``batch_jobs/<exp_name>/``.

    The temporary file is removed whether or not the submission succeeds; the
    backup copy is kept.

    Args:
        client: OpenAI client used for the upload and the batch creation.
        examples: Few-shot exemplars. They are unpacked as
            ``(label, sequence)`` when written to the metadata file.
            NOTE(review): the annotation says ``Tuple[str, int]`` -- confirm
            the order the prompt helpers expect.
        val_sequences: Sequences to classify (one request each).
        val_labels: Gold labels, stored in the metadata only.
        model: Model name used in the requests and in the file names.
        prompt_before_exemplars: Prompt text placed before the exemplars.
        prompt_after_exemplars: Prompt text placed after the exemplars.
        substitution_strategy: Forwarded to ``create_batch_requests``.
        task: Task name, used in file names and metadata.
        size: Run label, used in file names and metadata.
        mapping: Token substitution mapping forwarded to
            ``create_batch_requests``.
        exp_name: Sub-directory name for the backup and metadata files.
        validation_info: Optional extra data stored verbatim in the metadata.
        one2many_timestep: Forwarded to ``create_batch_requests``.

    Returns:
        A dict with ``batch_id``, ``metadata_file`` (path as a string) and
        ``status`` set to ``"submitted"``.

    Raises:
        Exception: Any error raised while writing, uploading, creating the
            batch or saving the metadata is printed and re-raised.
    """
    import shutil  # Local import: the module is not imported at file level.

    print(f"Creating batch validation for model: {model}")
    batch_requests = create_batch_requests(
        examples,
        val_sequences,
        model,
        prompt_before_exemplars,
        prompt_after_exemplars,
        substitution_strategy,
        mapping,
        one2many_timestep,
    )

    # Temporary upload file; the model name keeps concurrent runs apart.
    temp_input_file = f"batch_input_{task}_{size}_{model}.jsonl"

    # Permanent, timestamped backup of exactly what was submitted.
    batch_requests_dir = Path("batch_requests") / exp_name
    batch_requests_dir.mkdir(parents=True, exist_ok=True)
    saved_input_file = (
        batch_requests_dir
        / f"{task}_{size}_{model}_{time.strftime('%Y%m%d_%H%M%S')}.jsonl"
    )

    try:
        with open(temp_input_file, "w") as f:
            for request in batch_requests:
                f.write(json.dumps(request) + "\n")

        shutil.copy2(temp_input_file, saved_input_file)

        print("Uploading batch file...")
        # The context manager closes the handle once the upload returns;
        # an open handle would also block the removal below on Windows.
        with open(temp_input_file, "rb") as upload:
            file = client.files.create(file=upload, purpose="batch")

        print("Creating batch job...")
        batch = client.batches.create(
            input_file_id=file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

        print(f"Created batch {batch.id}")

        metadata = {
            "batch_id": batch.id,
            "task": task,
            "size": size,
            "model": model,
            "examples": [{"sequence": seq, "label": label} for label, seq in examples],
            "validation_data": {"sequences": val_sequences, "labels": val_labels},
            "status": "pending",
            "validation_info": validation_info,
        }

        metadata_file = (
            Path("batch_jobs") / exp_name / f"{task}_{size}_{batch.id}_metadata.json"
        )
        # parents=True: "batch_jobs" itself may not exist yet. Without it the
        # call fails after the batch was submitted and the batch id is lost.
        metadata_file.parent.mkdir(parents=True, exist_ok=True)
        with open(metadata_file, "w") as f:
            json.dump(metadata, f, indent=2)

        return {
            "batch_id": batch.id,
            "metadata_file": str(metadata_file),
            "status": "submitted",
        }

    except Exception as e:
        print(f"Error in batch processing: {e}")
        raise

    finally:
        # Remove only the temporary file; the backup copy is kept.
        if os.path.exists(temp_input_file):
            os.remove(temp_input_file)


def main():
    """Command-line entry point: submit one batch job per sample size and test set.

    Parses the arguments, reads the two prompt template files, creates the
    OpenAI client from ``OPENAI_API_KEY`` and loads the training data. For
    every requested sample size it takes the first ``size`` training items as
    exemplars; for every requested test set it loads that split, builds a
    token substitution mapping and submits a batch job through
    ``run_batch_validation``.

    Exits with status 1 when a prompt file is missing or ``OPENAI_API_KEY``
    is not set.
    """
    parser = argparse.ArgumentParser(
        description="Batch experiment with LLMs on formal languages"
    )
    parser.add_argument("--data_dir", required=True, help="Path to the data directory")
    parser.add_argument("--task", required=True, help="Formal language task name")
    parser.add_argument("--model", required=True, help="OpenAI model name")
    parser.add_argument(
        "--sample_sizes",
        type=str,
        default="100",
        help="Comma-separated list of sample sizes",
    )
    parser.add_argument(
        "--test_sets",
        type=str,
        default="test",
        help="Comma-separated list of test sets (val-short, val-long, test)",
    )
    parser.add_argument(
        "--prompt_before_exemplars",
        required=True,
        help="Path to prompt template file before the examples",
    )
    parser.add_argument(
        "--prompt_after_exemplars",
        required=True,
        help="Path to prompt template file after the examples",
    )
    # NOTE: --results_dir is accepted but not read anywhere in this function.
    parser.add_argument(
        "--results_dir",
        default="results",
        help="Directory to save results (default: results)",
    )
    parser.add_argument(
        "--exp_name", type=str, default="default", help="Experiment name"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--token_substitution_strategy",
        type=str,
        default="many2one",
        help="Token substitution strategy (many2one, one2one,one2many, one2many-sym, one2one-sym)",
    )
    parser.add_argument(
        "--one2many_timestep", type=int, default=5, help="One2many timestep"
    )

    args = parser.parse_args()
    try:
        with open(args.prompt_before_exemplars, "r") as f:
            prompt_before_exemplars = f.read()
        with open(args.prompt_after_exemplars, "r") as f:
            prompt_after_exemplars = f.read()
    except FileNotFoundError:
        print(
            f"Error: Prompt template file '{args.prompt_before_exemplars}' or '{args.prompt_after_exemplars}' not found"
        )
        sys.exit(1)

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("Error: OPENAI_API_KEY not found in environment variables")
        sys.exit(1)

    client = OpenAI(api_key=api_key)
    sample_sizes = [int(size) for size in args.sample_sizes.split(",")]
    # Strip each name so "val-short, val-long" (the spacing used in the help
    # text) does not produce file names such as "data. val-long".
    test_sets = [name.strip() for name in args.test_sets.split(",")]

    train_sequences, train_labels = load_train_data(args.data_dir, args.task)
    # Items are (label, sequence) pairs from here on.
    train_data = list(zip(train_labels, train_sequences))
    print_label_stats(train_data, "Full training data")

    for size in sample_sizes:
        print(f"\nTesting with sample size: {size}")

        # The first `size` training items, in file order, are the exemplars.
        examples = train_data[:size]
        print_label_stats(examples, f"Training examples (size {size})")

        # Create separate batch job for each validation set
        for test_set in test_sets:
            print(f"\nProcessing {test_set} dataset")
            val_sequences, val_labels = load_validation_data(
                args.data_dir, args.task, test_set
            )

            # The mapping is rebuilt per test set because the symbol set is
            # collected from both the exemplars and that set's sequences.
            unique_symbols = get_unique_symbols(examples, val_sequences)
            mapping = create_token_substitution(
                original_symbols=unique_symbols,
                prompt_before_exemplars=prompt_before_exemplars,
                prompt_after_exemplars=prompt_after_exemplars,
                substitution_strategy=args.token_substitution_strategy,
                model_name=args.model,
                seed=args.seed,
            )

            batch_info = run_batch_validation(
                client=client,
                examples=examples,
                val_sequences=val_sequences,
                val_labels=val_labels,
                model=args.model,
                prompt_before_exemplars=prompt_before_exemplars,
                prompt_after_exemplars=prompt_after_exemplars,
                substitution_strategy=args.token_substitution_strategy,
                task=args.task,
                size=f"{size}_{test_set}",  # Combine size and validation set uniquely
                exp_name=args.exp_name,
                mapping=mapping,
                one2many_timestep=args.one2many_timestep,
            )

            print(f"Created batch job: {batch_info['batch_id']}")
            print(f"Metadata saved to: {batch_info['metadata_file']}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
