#################################################################
###         Demo for Salience-Aware Mark-Steered Prompting (MSP)         ###
#################################################################

import argparse
import json
import os
import sys
import time
import gzip
from typing import Iterable, Dict

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assume these are part of your MSP library
# You may need to adjust these import paths based on your project structure.
from msp import msp_tokenize, MSPLogitsProcessor
from attr import Analyse

def load_tasks_from_jsonl(jsonl_file):
    """
    Load benchmark tasks from a JSON Lines file.

    Each non-blank line is decoded as one JSON record. Blank or
    whitespace-only lines (such as a trailing empty line) are skipped instead
    of raising ``json.JSONDecodeError``. Files ending in ``.gz`` are read
    through gzip, the same convention ``write_jsonl`` uses, and a leading
    ``~`` in the path is expanded.

    Args:
        jsonl_file: Path to a ``.jsonl`` or ``.jsonl.gz`` file.

    Returns:
        list: The decoded records, in file order.
    """
    path = os.path.expanduser(jsonl_file)
    # Decode as UTF-8 explicitly: the platform default encoding (e.g. cp1252
    # on Windows) can fail on non-ASCII prompts.
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, 'rt', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

def filter_code(completion: str) -> str:
    """
    Trim a raw model completion down to its first chunk of code.

    Leading newline characters are discarded first. Everything from the
    first blank line (the first ``"\\n\\n"``) onward is then dropped. If
    there is no blank line, the whole (left-stripped) text is returned.
    """
    body = completion.lstrip("\n")
    first_chunk, _separator, _remainder = body.partition("\n\n")
    return first_chunk


def fix_indents(text: str) -> str:
    """
    Normalise indentation by turning every tab character into four spaces.

    This is a literal substitution (each ``\\t`` becomes exactly four
    spaces), not tab-stop alignment as ``str.expandtabs`` would perform.
    """
    four_spaces = "    "
    return four_spaces.join(text.split("\t"))


def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
    """
    Serialise each dict in ``data`` as one JSON line and write it to ``filename``.

    Args:
        filename: Destination path; ``~`` is expanded. A ``.gz`` suffix routes
            the lines through a gzip stream. When appending, each call adds a
            new gzip member to the file.
        data: Records to write; each is passed through ``json.dumps`` and
            encoded as UTF-8 with a trailing newline.
        append: Append to an existing file instead of truncating it.
    """
    target = os.path.expanduser(filename)

    def _encoded_lines():
        # Generator function on purpose: ``data`` is not touched until the
        # destination stream is already open, matching a plain for-loop.
        for record in data:
            yield (json.dumps(record) + "\n").encode('utf-8')

    with open(target, 'ab' if append else 'wb') as raw:
        if not target.endswith(".gz"):
            raw.writelines(_encoded_lines())
            return
        with gzip.GzipFile(fileobj=raw, mode='wb') as sink:
            sink.writelines(_encoded_lines())


_SUPPORTED_BENCHMARKS = ("humaneval", "mbpp")


def _load_completed_task_ids(generate_file: str) -> set:
    """
    Collect the ``task_id`` of every record already present in ``generate_file``.

    Used to resume an interrupted run. A missing file yields an empty set.
    Lines that are not valid JSON (e.g. a record truncated by a crash) or that
    are not JSON objects are ignored.
    """
    completed = set()
    if not os.path.exists(generate_file):
        return completed
    with open(generate_file, "r", encoding="utf-8") as f:
        for line in f:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # Ignore corrupted lines
            if isinstance(record, dict) and "task_id" in record:
                completed.add(record["task_id"])
    return completed


def _append_log_section(log_file: str, title: str, body: str) -> None:
    """Append ``body`` to the plain-text ``log_file`` under a ``-----title-----`` banner."""
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write('\n\n' + '-' * 50 + title + '-' * 50 + '\n\n')
        f.write(body)


def _identify_marks(benchmark_name, data, prompt, model, tokenizer, device, k):
    """
    Run the saliency analysis (``Analyse``) and return the cleaned mark tokens.

    HumanEval analyses the full ``prompt`` with ``max_iter=10``. MBPP analyses
    the bare problem statement ``data['prompt']`` with ``max_iter=3`` and also
    removes the byte-level BPE space marker ``"Ġ"``. In both cases leading
    spaces are stripped, and marks that are empty or whitespace-only *after*
    cleaning are dropped. Filtering after cleaning means a lone ``"Ġ"`` token
    no longer turns into an empty-string mark.

    Raises:
        ValueError: For an unsupported ``benchmark_name``.
    """
    if benchmark_name == "humaneval":
        raw_marks = Analyse(prompt=prompt, model=model, tokenizer=tokenizer, device=device,
                            max_iter=10, k=k, target=None)
        cleaned = [tok.lstrip(" ") for tok in raw_marks]
    elif benchmark_name == "mbpp":
        raw_marks = Analyse(prompt=data['prompt'], model=model, tokenizer=tokenizer, device=device,
                            max_iter=3, k=k, target=None)
        # The "Ġ" character represents a space in byte-level BPE tokenizers like GPT-2's.
        cleaned = [tok.replace("Ġ", "").lstrip(" ") for tok in raw_marks]
    else:
        raise ValueError(f"Invalid benchmark: {benchmark_name}")
    return [tok for tok in cleaned if tok.strip() != ""]


def _run_task_attempt(task_id, data, prompt, *, model, tokenizer, device, k, benchmark_name,
                      masking_strength, modulated_by_prob, use_attention_mask,
                      log_file, generate_file):
    """
    Perform one attempt at a task: find marks, generate with MSP, log, and save.

    The ``{"task_id", "completion"}`` record is appended to ``generate_file``
    as the very last step, so a failing attempt never leaves a result record
    behind (only log output).
    """
    marks = _identify_marks(benchmark_name, data, prompt, model, tokenizer, device, k)
    print(f"\nIdentified marks for {task_id}: {marks}")
    _append_log_section(log_file, 'marks', str(marks))

    # Tokenize the input for main and auxiliary models
    main_inputs, aux_inputs, mask_token = msp_tokenize(
        prompt, marks, tokenizer, device, log_file=log_file
    )

    msp_processor = MSPLogitsProcessor(
        aux_model=model,
        aux_input_ids=aux_inputs,
        strength=masking_strength,
        modulated_by_prob=modulated_by_prob,
        use_attention_mask=use_attention_mask,
        mask_token=mask_token,
        tokenizer=tokenizer,
    )

    start_time = time.perf_counter()  # monotonic clock for durations
    output_sequences = model.generate(
        input_ids=main_inputs,
        logits_processor=[msp_processor],
        # Standard Hugging Face generation parameters
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    elapsed = time.perf_counter() - start_time

    prompt_len = main_inputs.shape[1]
    # Decode the full sequence and strip the decoded prompt text. Decoding only
    # the new ids can lose the first token's leading whitespace with some
    # tokenizers. If decoding is not prefix-stable, fall back to decoding the
    # new ids alone instead of slicing at a wrong offset.
    input_text = tokenizer.decode(main_inputs[0], skip_special_tokens=False)
    generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=False)
    if generated_text.startswith(input_text):
        new_tokens_text = generated_text[len(input_text):]
    else:
        new_tokens_text = tokenizer.decode(output_sequences[0][prompt_len:], skip_special_tokens=False)
    completion = filter_code(fix_indents(new_tokens_text))

    num_new_tokens = output_sequences.shape[1] - prompt_len
    tokens_per_second = num_new_tokens / elapsed if elapsed > 0 else 0
    generation_info = (
        f"\nTime taken: {elapsed:.2f} seconds\n"
        f"New tokens generated: {num_new_tokens}\n"
        f"Generation speed: {tokens_per_second:.2f} tokens/second\n"
    )
    _append_log_section(log_file, 'generation_info', generation_info)
    _append_log_section(log_file, 'completion', completion)

    write_jsonl(generate_file, [{"task_id": task_id, "completion": completion}], append=True)


def tune_msp(
        model_name: str,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        device: str,
        masking_strength: float,
        benchmark_name: str,
        benchmark_path: str,
        experiment_name: str,
        k: int,
        log_dir: str,
        generate_data_dir: str,
        modulated_by_prob: bool = True,
        use_attention_mask: bool = True,
        max_attempts: int = 5,
        retry_delay: float = 5.0,
):
    """
    Runs the MSP generation process for a given set of hyperparameters.

    Each task in ``benchmark_path`` whose ``task_id`` is not already in the
    output file is processed in four steps:

    1. Saliency marks are identified.
    2. A completion is sampled with an ``MSPLogitsProcessor`` attached.
    3. A human-readable trace is appended to ``log_dir/<run>.txt``.
    4. ``{"task_id", "completion"}`` is appended to
       ``generate_data_dir/<run>.jsonl``.

    Here ``<run>`` is
    ``{experiment_name}_{benchmark_name}_{masking_strength}w_{model short name}``.

    Args:
        model_name: Model identifier; its last ``/`` component names the files.
        model: Causal LM used both for generation and as the MSP auxiliary model.
        tokenizer: Tokenizer matching ``model``.
        device: Device string forwarded to the mark analysis and tokenization.
        masking_strength: ``strength`` passed to ``MSPLogitsProcessor``.
        benchmark_name: ``"humaneval"`` or ``"mbpp"``.
        benchmark_path: JSONL file with the benchmark tasks.
        experiment_name: Prefix for the output file names.
        k: ``k`` forwarded to the saliency analysis.
        log_dir: Directory for the ``.txt`` trace (created if missing).
        generate_data_dir: Directory for the ``.jsonl`` results (created if missing).
        modulated_by_prob: Forwarded to ``MSPLogitsProcessor``.
        use_attention_mask: Forwarded to ``MSPLogitsProcessor``.
        max_attempts: Attempts per task before giving up on it (minimum 1). A
            task that keeps failing is skipped without writing a record, so a
            later run picks it up again via the resume logic.
        retry_delay: Seconds to wait between attempts.

    Raises:
        ValueError: If ``benchmark_name`` is not supported.
    """
    # Fail fast: validated once here rather than retried per task.
    if benchmark_name not in _SUPPORTED_BENCHMARKS:
        raise ValueError(f"Invalid benchmark: {benchmark_name}")

    print("-" * 50)
    print(f"Model name: {model_name}")
    print(f"Device: {device}")
    print(f"Benchmark: {benchmark_name}")
    print(f"K: {k}")
    print(f"Masking Strength: {masking_strength}")
    print(f"Modulated by probability: {modulated_by_prob}")
    print(f"Use attention mask: {use_attention_mask}\n")
    print("-" * 50)

    # Create output directories if they don't exist
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(generate_data_dir, exist_ok=True)

    model_name_short = model_name.split('/')[-1]
    run_name = f'{experiment_name}_{benchmark_name}_{masking_strength}w_{model_name_short}'
    log_file = os.path.join(log_dir, f'{run_name}.txt')
    generate_file = os.path.join(generate_data_dir, f'{run_name}.jsonl')

    # Resume from previous run by checking existing task_ids
    existing_task_ids = _load_completed_task_ids(generate_file)
    print(f"Found {len(existing_task_ids)} completed tasks. Skipping them.")

    tasks = load_tasks_from_jsonl(benchmark_path)
    attempts = max(1, int(max_attempts))
    failed_task_ids = []

    for data in tqdm(tasks, desc=f"Processing {benchmark_name}"):
        task_id = data['task_id']
        if task_id in existing_task_ids:
            continue

        prompt = data['prompt'] if benchmark_name == "humaneval" else data['info_prompt']
        _append_log_section(log_file, task_id, prompt)

        for attempt in range(1, attempts + 1):
            try:
                _run_task_attempt(
                    task_id, data, prompt,
                    model=model, tokenizer=tokenizer, device=device, k=k,
                    benchmark_name=benchmark_name, masking_strength=masking_strength,
                    modulated_by_prob=modulated_by_prob, use_attention_mask=use_attention_mask,
                    log_file=log_file, generate_file=generate_file,
                )
            except Exception as e:  # run boundary: report and retry (bounded)
                print(f"An error occurred while processing {task_id}: {e}")
                _append_log_section(log_file, 'error', f"attempt {attempt}/{attempts}: {e!r}")
                if attempt < attempts:
                    print("Retrying...")
                    time.sleep(retry_delay)  # Wait before retrying
            else:
                break  # Success: move on to the next task
        else:
            # Every attempt failed; skip so one bad task cannot hang the sweep.
            failed_task_ids.append(task_id)
            print(f"Giving up on {task_id} after {attempts} attempts; "
                  f"it will be retried on the next run.")

    if failed_task_ids:
        print(f"{len(failed_task_ids)} task(s) were skipped after repeated failures: {failed_task_ids}")


def _strength_grid(start: float, end: float, step: float) -> list:
    """
    Return the masking strengths ``start, start + step, ...`` strictly below ``end``.

    ``np.arange`` with a non-integer step derives its length from a
    floating-point division, so it can unexpectedly include ``end``. For
    example, ``np.arange(1, 1.3, 0.1)`` yields four values. Here the count is
    computed with a small tolerance so ``end`` stays exclusive. Each value is
    built as ``start + i * step`` and rounded to 6 decimals, which removes
    floating-point noise (``1.4000000000000001`` -> ``1.4``) without merging
    distinct values from steps finer than 0.01.

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f"strength step must be positive, got {step}")
    count = int(np.ceil((end - start) / step - 1e-9))
    return [round(start + i * step, 6) for i in range(max(count, 0))]


def main():
    """
    Parse CLI arguments, load the model once, and sweep ``k`` x masking strength.

    ``k`` ranges over ``[k_start, k_end)`` and the masking strength over
    ``[strength_start, strength_end)`` in increments of ``strength_step``.
    ``tune_msp`` runs for each combination, writing under
    ``<output_dir>/logs/<model>/k_<k>/`` and
    ``<output_dir>/generated_data/<model>/k_<k>/``.
    """
    parser = argparse.ArgumentParser(description="Run MSP tuning experiments.")
    parser.add_argument('--model_name_or_path', type=str, required=True,
                        help="Path to the pretrained model or model identifier from huggingface.co/models")
    parser.add_argument('--benchmark_name', type=str, required=True, choices=['humaneval', 'mbpp'],
                        help="Name of the benchmark to use.")
    parser.add_argument('--benchmark_path', type=str, required=True,
                        help="Path to the benchmark dataset file (e.g., HumanEval.jsonl).")
    parser.add_argument('--output_dir', type=str, default='./output', help="Directory to save logs and generated data.")
    parser.add_argument('--experiment_name', type=str, default='test01',
                        help="A name for the experiment, used in output file names.")

    # Hyperparameter ranges
    parser.add_argument('--k_start', type=int, default=3, help="Starting value for k.")
    parser.add_argument('--k_end', type=int, default=5, help="Ending value for k (exclusive).")
    parser.add_argument('--strength_start', type=float, default=1.1, help="Starting value for masking strength.")
    parser.add_argument('--strength_end', type=float, default=2.1, help="Ending value for masking strength.")
    parser.add_argument('--strength_step', type=float, default=0.1, help="Step size for masking strength.")

    # Boolean flags
    parser.add_argument('--no_prob_modulation', action='store_true',
                        help="Disable probability modulation for masking strength.")
    parser.add_argument('--no_attention_mask', action='store_true', help="Disable use of attention mask.")

    args = parser.parse_args()
    if args.strength_step <= 0:
        parser.error("--strength_step must be a positive number.")

    # Determine device
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # float16 saves memory and time on accelerators, but many CPU kernels lack
    # (fast) half-precision support, so use float32 on the CPU fallback.
    dtype = torch.float32 if device == "cpu" else torch.float16

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        device_map="auto",
        torch_dtype=dtype,
    )

    # A trailing "/" on a local path would otherwise make the last path
    # component (used in directory and file names) empty.
    model_id = args.model_name_or_path.rstrip('/') or args.model_name_or_path
    model_name_str = model_id.split('/')[-1]
    # Loop-invariant: compute the strength grid once, not per k.
    strength_values = _strength_grid(args.strength_start, args.strength_end, args.strength_step)

    # Main hyperparameter tuning loop
    for k in tqdm(range(args.k_start, args.k_end), desc="Tuning MSP - k"):
        # Define specific directories for this k value
        log_dir = os.path.join(args.output_dir, 'logs', model_name_str, f'k_{k}')
        generate_data_dir = os.path.join(args.output_dir, 'generated_data', model_name_str, f'k_{k}')

        # Create a dynamic name for this run based on k
        dynamic_experiment_name = f"{args.experiment_name}_k_{k}"

        for strength in tqdm(strength_values, desc=f"Strength loop (k={k})", leave=False):
            tune_msp(
                model_name=model_id,
                model=model,
                tokenizer=tokenizer,
                device=device,
                masking_strength=strength,  # already de-noised by _strength_grid
                benchmark_name=args.benchmark_name,
                benchmark_path=args.benchmark_path,
                experiment_name=dynamic_experiment_name,
                k=k,
                log_dir=log_dir,
                generate_data_dir=generate_data_dir,
                modulated_by_prob=not args.no_prob_modulation,
                use_attention_mask=not args.no_attention_mask
            )


if __name__ == "__main__":
    # Run the CLI sweep only when executed as a script, not when imported.
    main()