import argparse
import random
import os
from typing import List, Tuple, Dict
from openai import OpenAI, AsyncOpenAI
import asyncio
from pathlib import Path
import json
import time
import sys
from dotenv import load_dotenv
from tqdm import tqdm
from src.deepseek.substitute_tokens import (
    create_modified_prompt,
    create_token_substitution,
    apply_token_mapping,
    get_unique_symbols,
    create_formal_language_prompt,
    print_token_mapping,
)
from src.deepseek.utils import print_label_stats, last_binary_digit
import transformers

# Load environment variables (via python-dotenv) before the API key is read.
load_dotenv()
# API key handed to the AsyncOpenAI client in batch_inference; None when the
# DEEPSEEK_API_KEY environment variable is not set.
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
# Size of the semaphore in batch_inference, i.e. the maximum number of API
# requests that may be in flight at the same time.
CONCURRENCY_LIMIT = 100  # Adjust this based on API performance


def load_data(language: str, train_test: str) -> Tuple[List[str], List[int]]:
    """Read one data split and its labels for a language from disk.

    Files are looked up relative to the current working directory, under
    ``data/flare_subsampled/<language>/``: ``data.<train_test>`` holds one
    sequence per line and ``labels.<train_test>`` one integer label per line.

    Args:
        language: Name of the sub-directory to read from.
        train_test: Suffix selecting the split (e.g. ``"test"``).

    Returns:
        ``(sequences, labels)`` where sequences are the whitespace-stripped
        lines of the data file and labels are the parsed integers.

    Raises:
        FileNotFoundError: If either file is missing.
        ValueError: If a label line is not an integer.
    """
    split_dir = Path(f"data/flare_subsampled/{language}/")

    with open(split_dir / f"data.{train_test}") as data_file:
        sequences = list(map(str.strip, data_file))

    labels = []
    with open(split_dir / f"labels.{train_test}") as label_file:
        for raw_label in label_file:
            labels.append(int(raw_label.strip()))

    return sequences, labels


def sample_balanced_data(
    sequences: List[str], labels: List[int], size: int
) -> List[Tuple[str, int]]:
    """Pair every sequence with its label, keeping the input order.

    Despite the name, nothing is sampled or balanced here: ``size`` is
    accepted but never read, and pairing stops at the shorter of the two
    inputs.
    """
    return [(sequence, label) for sequence, label in zip(sequences, labels)]


def create_prompt(
    examples: List[Tuple[str, int]],
    val_sequence: List[str],
    prompt_template: str,
    mapping: Dict[str, str] = None,
    encoding_technique: str = "none",
    k: int = 2,
) -> Dict:
    """Build the prompt for one validation sequence.

    The template is first filled with the few-shot ``examples`` and then its
    ``test_string`` placeholder is replaced by the (possibly encoded)
    validation sequence.

    * ``"many_to_one"``: the template is always built through the token
      mapping, with a fixed ``k`` of 1, while the sequence itself is encoded
      with the caller's ``k``.
    * ``"one_to_one"``: ``k`` is forced to 1 for both steps.
    * Any technique other than ``"many_to_one"`` with ``mapping=None``: no
      substitution is applied; the plain formal-language prompt is used and
      the raw sequence is inserted.

    Returns:
        The result of ``.replace`` on the filled template.
        NOTE(review): that is presumably a ``str`` rather than the annotated
        ``Dict``, and ``val_sequence`` is likewise used as a single string;
        confirm against the helpers in ``substitute_tokens``.
    """
    if encoding_technique == "many_to_one":
        template_k = 1
        substitute = True
    else:
        if encoding_technique == "one_to_one":
            k = 1
        template_k = k
        substitute = mapping is not None

    if not substitute:
        plain_prompt = create_formal_language_prompt(prompt_template, examples)
        return plain_prompt.replace("test_string", val_sequence)

    filled_template = create_modified_prompt(
        prompt_template, examples, mapping, encoding_technique, template_k
    )
    encoded_sequence = apply_token_mapping(
        val_sequence, mapping, encoding_technique=encoding_technique, k=k
    )
    return filled_template.replace("test_string", encoded_sequence)


def compute_accuracy(raw_results) -> float:
    """Return the fraction of result entries whose prediction matches the label.

    Args:
        raw_results: Iterable of mappings, each with a ``"model_output"`` and
            a ``"true_label"`` key. Entries that failed to decode carry
            ``-1`` as output and therefore count as incorrect.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when there are no entries
        (instead of raising ``ZeroDivisionError``).

    Raises:
        KeyError: If an entry lacks one of the two keys.
    """
    correct = 0
    total = 0
    for entry in raw_results:
        total += 1
        if entry["model_output"] == entry["true_label"]:
            correct += 1
    # An empty result set has no meaningful accuracy; report 0.0 rather than
    # dividing by zero.
    return correct / total if total else 0.0


# write the result of a single prompt to the file
async def write_response_to_file(file_path, response):
    """Append ``response`` to the ``raw_results`` list of a JSON results file.

    The whole file is read, updated in memory and rewritten on every call.
    If the file is missing or does not contain valid JSON, a message is
    printed and the response is dropped without raising.

    Args:
        file_path: Path of an existing JSON file whose top-level object has a
            ``raw_results`` list.
        response: JSON-serialisable record to append.
    """
    # No lock is taken here. Nothing in this function awaits, so the event
    # loop cannot switch to another coroutine between the read and the write.
    # (A lock instantiated inside the call would be private to that call and
    # could never exclude another writer anyway.)
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        print(f"Error reading file: {file_path}")
        return

    data["raw_results"].append(response)

    # Serialise first so that an unserialisable response fails before any
    # file is touched.
    payload = json.dumps(data, ensure_ascii=False, indent=4)

    # Write to a sibling temp file and swap it in: an interruption mid-write
    # must not leave truncated JSON behind, because later runs resume from
    # this file.
    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(payload)
    os.replace(tmp_path, file_path)


def output_decoder(output, prompt_template):
    """Turn the raw model reply into an integer label.

    For ``"io_prompt"`` the reply itself is expected to be the label: every
    ``"."`` is removed and the remainder is parsed as an integer, with ``-1``
    returned (and a message printed) when parsing fails.

    For every other template the label is taken from
    ``last_binary_digit(output)``; with ``"zsr_prompt"`` the reply is also
    echoed to stdout first.
    """
    if prompt_template != "io_prompt":
        if prompt_template == "zsr_prompt":
            print(f"Output: {output}")
        return int(last_binary_digit(output))

    without_dots = output.replace(".", "")
    try:
        return int(without_dots)
    except ValueError:
        print(f"Error decoding output: {output}")
        return -1


async def fetch(client, model, semaphore, path, data, prompt, prompt_template):
    """Send one chat-completion request and persist the decoded answer.

    The semaphore is held for the whole call, so it bounds the number of
    requests being sent and processed at the same time.

    Args:
        client: Async client exposing ``chat.completions.create``.
        model: Model name forwarded to the API.
        semaphore: Concurrency limiter shared by all requests of a batch.
        path: Results file the record is appended to.
        data: ``(sequence, true_label)`` pair this prompt was built from.
        prompt: Full prompt text sent as the single user message.
        prompt_template: Template name, used to pick the output decoder.

    Returns:
        The record written to ``path``. When reading or decoding the reply
        fails, a reduced record with ``model_output`` set to ``-1`` (and no
        ``raw_response``) is written and returned instead.

    Note:
        The API call itself is outside the ``try``; an exception raised by
        the request propagates to the caller and nothing is recorded for
        this sequence.
    """
    async with semaphore:
        seq, label = data
        completion = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            stream=False,
        )

        try:
            content = completion.choices[0].message.content
            if content is None:
                # The API returned no text at all: record it as undecodable.
                decoded = -1
            else:
                decoded = int(output_decoder(content, prompt_template))
            record = {
                "sequence": seq,
                "raw_response": content,
                "model_output": decoded,
                "true_label": label,
            }
            await write_response_to_file(path, record)
            return record
        except Exception as e:
            print(f"❌ Error processing sequence {seq}: {e}")
            fallback = {"sequence": seq, "model_output": -1, "true_label": label}
            await write_response_to_file(path, fallback)
            return fallback


async def batch_inference(file_path, prompts, model, prompt_template):
    """Run all prompts against the DeepSeek endpoint with bounded concurrency.

    One task per prompt is scheduled up front; a semaphore of size
    ``CONCURRENCY_LIMIT`` shared by all tasks caps how many run at once. The
    elapsed wall-clock time is printed when the batch finishes.

    Args:
        file_path: Results file every task appends its record to.
        prompts: Iterable of ``((sequence, label), prompt_text)`` pairs.
        model: Model name forwarded to the API.
        prompt_template: Template name forwarded to ``fetch`` for decoding.

    Returns:
        The list of records returned by ``fetch``, in the order of
        ``prompts``.
    """
    started_at = time.time()
    limiter = asyncio.Semaphore(CONCURRENCY_LIMIT)

    api_client = AsyncOpenAI(
        api_key=DEEPSEEK_API_KEY,
        base_url="https://api.deepseek.com",
    )
    async with api_client as client:
        pending = []
        for data, prompt in prompts:
            request = fetch(
                client, model, limiter, file_path, data, prompt, prompt_template
            )
            pending.append(asyncio.create_task(request))
        results = await asyncio.gather(*pending)

    finished_at = time.time()
    print(f"Batch inference completed in {finished_at - started_at:.2f} seconds")
    return results


def run_validation(
    examples: List[Tuple[str, int]],
    val_sequences: List[str],
    val_labels: List[int],
    model: str,
    prompt_template: str,
    mapping: Dict[str, str] = None,
    file_path: str = "default",
    encoding_technique: str = "none",
    k: int = 2,
) -> None:
    """Query the model for all unanswered validation sequences and score the run.

    The results file at ``file_path`` must already exist and contain a
    ``raw_results`` list. Sequences that already appear there are skipped, so
    an interrupted run can be resumed. After the batch completes, the file is
    reloaded and ``validation_size`` and ``accuracy`` are recomputed over all
    recorded results and written back.

    Args:
        examples: Few-shot ``(sequence, label)`` pairs placed in every prompt.
        val_sequences: Sequences to classify.
        val_labels: True labels, aligned with ``val_sequences``.
        model: Model name forwarded to the API.
        prompt_template: Template name; also selects the output decoder.
        mapping: Token substitution mapping, or ``None`` for no substitution.
        file_path: Path of the JSON results file to read and update.
        encoding_technique: Name of the encoding technique for the prompt.
        k: Expansion factor forwarded to the prompt builder.

    Returns:
        None. All output goes to the results file.
    """

    print(f"Creating validation for model: {model}")

    # The per-response writer stores this file as UTF-8 with non-ASCII
    # characters unescaped, so it must be read back as UTF-8 instead of the
    # platform's default encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        results = json.load(file)

    # Sequences that already have a recorded result are not sent again.
    raw_sequences = {entry["sequence"] for entry in results["raw_results"]}
    all_requests = []
    for i, (seq, label) in enumerate(zip(val_sequences, val_labels)):
        if seq in raw_sequences:
            print(f"Skipping sequence {i}: {seq}")
            continue
        all_requests.append(
            (
                (seq, label),
                create_prompt(
                    examples, seq, prompt_template, mapping, encoding_technique, k
                ),
            )
        )
    asyncio.run(batch_inference(file_path, all_requests, model, prompt_template))

    # Reload: the batch appended its records directly to the file.
    with open(file_path, "r", encoding="utf-8") as file:
        results = json.load(file)

    results["validation_size"] = len(results["raw_results"])
    results["accuracy"] = compute_accuracy(results["raw_results"])

    # Default ensure_ascii keeps the final file pure ASCII, so it stays
    # readable whatever encoding a later reader assumes.
    with open(file_path, "w", encoding="utf-8") as file:
        json.dump(results, file, indent=4)

    return None


# Example: python -m scripts.deepseek_experiment --language "<language>" --model "deepseek-chat" --prompt_template "io_prompt" --encoding_technique "one_to_many" --k 2 --seed 42 --num_test_samples 100
def main():
    """Command-line entry point: evaluate one model on one formal language.

    Steps:
      1. Parse arguments and read the prompt template (the ``io_prompt``
         template for ``--prompt_template io_prompt``, the ``zsr_prompt``
         template for anything else).
      2. Load the ``val-long`` split as few-shot examples and the first
         ``--num_test_samples`` items of the ``test`` split for validation.
      3. Derive the results path from the arguments. If that file exists
         with an accuracy other than ``-1``, the run is skipped.
      4. Build the token mapping, create the results file (with accuracy
         ``-1`` as the "not finished" marker) if it does not exist yet, and
         run the validation.

    Exits with status 1 when the prompt template or API key is missing or the
    model is unsupported. Other errors are printed and swallowed.
    """
    parser = argparse.ArgumentParser(
        description="Experiment with LLMs on formal languages"
    )
    parser.add_argument("--language", required=True, help="Formal language task name")
    parser.add_argument("--model", required=True, help="Model name")
    parser.add_argument(
        "--prompt_template", required=True, help="Type of prompt template"
    )
    parser.add_argument(
        "--encoding_technique", required=True, help="Encoding technique"
    )
    parser.add_argument(
        "--tokenizer_path", type=str, default="", help="Path to tokenizer"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--k", type=int, default=2, help="K for one-to-many encoding technique"
    )
    parser.add_argument(
        "--base_path", type=str, default="results", help="Base path for results"
    )
    parser.add_argument(
        "--num_test_samples", type=int, default=100, help="Number of test samples"
    )

    args = parser.parse_args()

    prompt_template_path = (
        "prompts/io_prompt_template.txt"
        if args.prompt_template == "io_prompt"
        else "prompts/zsr_prompt_template.txt"
    )
    try:
        with open(prompt_template_path, "r") as f:
            base_prompt = f.read()
    except FileNotFoundError:
        print(f"Error: Prompt template file '{prompt_template_path}' not found")
        sys.exit(1)

    DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")

    if not DEEPSEEK_API_KEY:
        print("Error: DEEPSEEK_API_KEY not found in environment variables")
        sys.exit(1)

    # The few-shot examples come from the "val-long" split.
    train_sequences, train_labels = load_data(args.language, "val-long")

    examples = list(zip(train_sequences, train_labels))
    print_label_stats(examples, "Full training data")

    print(
        f"Processing language {args.language} for model {args.model} with prompt template {args.prompt_template} and encoding technique {args.encoding_technique}"
    )

    # Only one_to_many results are split by k. The "100_long" part of the
    # file name is fixed and does not follow --num_test_samples.
    file_path = (
        f"{args.base_path}/{args.model}/{args.prompt_template}/{args.encoding_technique}/k={args.k}/{args.language}/100_long_{args.seed}.json"
        if args.encoding_technique == "one_to_many"
        else f"{args.base_path}/{args.model}/{args.prompt_template}/{args.encoding_technique}/{args.language}/100_long_{args.seed}.json"
    )
    if os.path.exists(file_path):
        # A finished run has a real accuracy; -1 marks an unfinished one that
        # should be resumed. The file may hold unescaped non-ASCII text, so
        # read it as UTF-8.
        with open(file_path, "r", encoding="utf-8") as file:
            results = json.load(file)
            if results["accuracy"] != -1:
                print(
                    f"Skipping model {args.model} and language {args.language} and prompt template {args.prompt_template} and encoding technique {args.encoding_technique}: Already processed with accuracy {results['accuracy']}"
                )
                return

    try:
        val_sequences, val_labels = load_data(args.language, "test")
        val_sequences = val_sequences[: args.num_test_samples]
        val_labels = val_labels[: args.num_test_samples]

        unique_symbols = get_unique_symbols(examples, val_sequences)

        mapping = create_token_substitution(
            encoding_technique=args.encoding_technique,
            symbols=unique_symbols,
            prompt_text=base_prompt,
            tokenizer_path=args.tokenizer_path,
            model_name=args.model,
            seed=args.seed,
            verbose=False,
        )

        metrics_dict = {
            "accuracy": -1,
            # NOTE: the key is misspelled, but it is kept as is because
            # existing result files and their readers use this spelling.
            "encoding_techinique": args.encoding_technique,
            "prompt_template": args.prompt_template,
            "raw_results": [],
            "examples_used": examples,
            "model": args.model,
            "validation_size": len(val_sequences),
            "validation_set": val_sequences,
            "label_distribution": {
                "training": {
                    "label_0": sum(1 for ex in examples if ex[1] == 0),
                    "label_1": sum(1 for ex in examples if ex[1] == 1),
                },
                "validation": {
                    "label_0": sum(1 for label in val_labels if label == 0),
                    "label_1": sum(1 for label in val_labels if label == 1),
                },
            },
        }

        # Check if the file already exists
        if not os.path.exists(file_path):
            # The nested results directory may not exist yet on a first run;
            # without this, open() fails with FileNotFoundError.
            Path(file_path).parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(metrics_dict, f, indent=4)
            print(f"File created: {file_path}")
        else:
            print(f"File already exists, not overwriting: {file_path}")

        if args.model == "deepseek-chat":
            try:
                run_validation(
                    examples=examples,
                    val_sequences=val_sequences,
                    val_labels=val_labels,
                    model=args.model,
                    prompt_template=args.prompt_template,
                    mapping=mapping,
                    file_path=file_path,
                    encoding_technique=args.encoding_technique,
                    k=args.k,
                )
            except Exception as e:
                print(
                    f"Error processing language {args.language} for {args.model}: {e}"
                )
        else:
            print(f"Error: Model {args.model} not supported")
            # SystemExit is not an Exception subclass, so the handler below
            # does not intercept this exit.
            sys.exit(1)
    except Exception as e:
        print(f"Error processing : {e}")

# Run the experiment only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
