from dotenv import load_dotenv
load_dotenv()

import os
import sys
sys.path.append('.')

import json
import argparse
import multiprocessing as mp
import numpy as np
import torch
import random
from dotenv import load_dotenv
from datasets import Dataset
from huggingface_hub import create_repo, HfFolder
from vllm import LLM
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeout
from collections import defaultdict
from functools import partial
from tqdm import tqdm

from examples.datasets import load_dataset_and_prompt_handlers
from examples.mutator.codegen import mutator_engine
from examples.solver.codegen import solver_engine
from examples.verdict.codegen import verdict_engine
from src.agent import Agent
from src.verdict import Verdict
from src.core.sample import sample_mutations, sample_solutions_data_gen
from src.core.reward import generate_solver_results, compute_verdict_scores
from src.utils.load_vllm_model import get_vllm_models
from tqdm import tqdm

def parse_args():
    """Parse and validate command-line arguments for the generate/validate pipeline.

    Returns:
        argparse.Namespace with every option. ``solver_model_name`` falls back
        to ``model_name`` when it is not given.

    Exits via ``parser.error`` when validate-mode inputs are missing:
    ``--solver_results_path`` is always required; ``--mutations_path`` is also
    required when ``--generation_type mutations`` (mutation validation opens
    both files); ``--hf_repo_name`` is required together with ``--push_to_hub``.
    """
    parser = argparse.ArgumentParser(description="Generate and validate mutations/solutions using Mutator, Solver, and Verdict models.")

    # Mode selection
    parser.add_argument("--generation_type", type=str, default="mutations", choices=["mutations", "solutions"], help="What to generate")
    parser.add_argument("--mode", type=str, required=True, choices=['generate', 'validate'], help="Whether to generate mutations or validate existing ones")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    # Model arguments
    parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct", help="Name of the base model")
    parser.add_argument("--solver_model_name", type=str, help="Name of the solver model (defaults to model_name if not specified)")
    parser.add_argument("--mutator_peft_dir", type=str, help="Path to mutator PEFT/LoRA checkpoint directory")
    parser.add_argument("--solver_peft_dir", type=str, help="Path to solver PEFT/LoRA checkpoint directory")

    # Common arguments
    parser.add_argument("--splits_file", type=str, default=None, help="Path to the splits file.")
    parser.add_argument("--dataset", type=str, default="bigcodebench-complete", help="Name of the dataset.")
    parser.add_argument("--max_model_len", type=int, default=1024, help="Maximum number of input tokens for the model.")
    parser.add_argument("--max_new_tokens", type=int, default=1024, help="Maximum number of new generated tokens.")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.4, help="GPU memory utilization for the model.")
    parser.add_argument("--enforce_eager", action="store_true", help="Whether to enforce eager mode.")
    parser.add_argument("--mutation_subparts_to_generate", type=str, nargs='+', default=["mutations", "solver_results"], help="Which parts to generate. Provide one or more of: mutations, solver_results, etc.")

    # Attacker (mutator) generation parameters
    parser.add_argument("--attacker_temperature", type=float, default=0.7, help="Temperature for attacker generation.")
    parser.add_argument("--attacker_top_p", type=float, default=1.0, help="Top-p for attacker generation.")
    parser.add_argument("--attacker_top_k", type=int, default=-1, help="Top-k for attacker generation.")
    parser.add_argument("--attacker_repetition_penalty", type=float, default=1.0, help="Repetition penalty for attacker generation.")

    # Solver generation parameters
    parser.add_argument("--solver_temperature", type=float, default=0.7, help="Temperature for solver generation.")
    parser.add_argument("--solver_top_p", type=float, default=1.0, help="Top-p for solver generation.")
    parser.add_argument("--solver_top_k", type=int, default=-1, help="Top-k for solver generation.")
    parser.add_argument("--solver_repetition_penalty", type=float, default=1.0, help="Repetition penalty for solver generation.")

    # Generation mode arguments
    generation_group = parser.add_argument_group('Generation options')
    generation_group.add_argument("--num_mutator_iters", type=int, default=3, help="Number of mutation samples per problem.")
    generation_group.add_argument("--num_solver_iters", type=int, default=3, help="Number of solver samples per problem.")
    generation_group.add_argument("--save_dir", type=str, default="data/mutations", help="Directory to save generated results")

    # Validation mode arguments
    validation_group = parser.add_argument_group('Validation options')
    validation_group.add_argument("--mutations_path", type=str, help="Path to saved mutations JSON file")
    validation_group.add_argument("--solver_results_path", type=str, help="Path to saved solver results JSON file")
    validation_group.add_argument("--save_path", type=str, default="data/mutator_solver_verdict", help="Path to save the validated dataset")
    validation_group.add_argument("--hf_repo_name", type=str, default="", help="Name of the Hugging Face repository to push the dataset")
    validation_group.add_argument("--push_to_hub", action="store_true", help="Whether to push the dataset to Hugging Face Hub")

    # Debug/testing arguments
    debug_group = parser.add_argument_group('Debug options')
    debug_group.add_argument("--debug", action="store_true", help="Run with a small subset of data for testing")
    debug_group.add_argument("--debug_samples", type=int, default=1, help="Number of samples to use in debug mode")

    # Workers
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to use for parallel processing")

    args = parser.parse_args()

    # Set solver_model_name to model_name if not specified
    if not args.solver_model_name:
        args.solver_model_name = args.model_name

    # Validate arguments based on mode
    if args.mode == 'validate':
        # Both validation paths read the solver results file.
        if not args.solver_results_path:
            parser.error("--solver_results_path is required for validate mode")
        # Mutation validation also opens the mutations file; fail fast here
        # instead of later with open(None).
        if args.generation_type == 'mutations' and not args.mutations_path:
            parser.error("--mutations_path is required for validate mode with --generation_type mutations")
        if args.push_to_hub and not args.hf_repo_name:
            parser.error("--hf_repo_name is required when --push_to_hub is set")

    return args


def generate_mutations(
    dataset: list, 
    mutator: Agent, 
    solver: Agent,
    num_mutator_iters: int = 10,
    num_solver_iters: int = 10,
    save_dir: str = 'data/mutations',
    mutation_subparts_to_generate: "list[str] | None" = None,
):
    """Generate mutations and/or solver results and save them as JSON under ``save_dir``.

    Args:
        dataset: Problems passed to ``sample_mutations``.
        mutator: Agent used by ``sample_mutations``.
        solver: Agent used by ``generate_solver_results``.
        num_mutator_iters: Forwarded to ``sample_mutations``.
        num_solver_iters: Forwarded to ``generate_solver_results``.
        save_dir: Output directory, created if missing. Receives
            ``mutations.json`` and/or ``solver_results.json``.
        mutation_subparts_to_generate: Stages to run; any of ``"mutations"``
            and ``"solver_results"``. ``None`` (the default) means both.

    When only ``"solver_results"`` is requested, the mutations are read back
    from ``save_dir/mutations.json``, which must already exist.
    """
    # None default instead of a list literal so no mutable default is shared between calls.
    if mutation_subparts_to_generate is None:
        mutation_subparts_to_generate = ['mutations', 'solver_results']

    os.makedirs(save_dir, exist_ok=True)
    mutations_file = os.path.join(save_dir, 'mutations.json')
    solver_results_file = os.path.join(save_dir, 'solver_results.json')

    mutations = None
    if "mutations" in mutation_subparts_to_generate:
        print(f"Generating mutations for {len(dataset)} problems")
        mutations, _, _ = sample_mutations(mutator, dataset, num_mutator_iters)
        with open(mutations_file, 'w') as f:
            json.dump(mutations, f)
        print(f"Saved {len(mutations)} mutations to {save_dir}")

    if "solver_results" in mutation_subparts_to_generate:
        if mutations is None:
            # Solver-only run: reuse mutations saved by an earlier invocation.
            with open(mutations_file, 'r') as f:
                mutations = json.load(f)
        print(f"Generating solver results for {len(mutations)} mutations")
        solver_results = generate_solver_results(mutations, solver, num_solver_iters)
        with open(solver_results_file, 'w') as f:
            json.dump(solver_results, f)
        print(f"Saved {len(solver_results)} solver results to {save_dir}")


def _mutator_score(dataset: str, mutation: dict, solution_scores: list) -> float:
    """Return the mutator reward for one mutation that survived filtering.

    A mutation the dataset-specific rule accepts scores
    ``1 - mean(solution_scores)`` (presumably solver pass scores, so defeating
    more solutions scores higher — confirm against compute_verdict_scores);
    a rejected one scores 0.0.
    """
    mutation_info = mutation['mutation_info']
    if "kodcode" in dataset:
        # For kodcode, a valid mutation must have an AssertionError in tracebacks;
        # anything else is a compile error or the wrong type of error.
        is_valid = any('AssertionError' in traceback for traceback in mutation_info.get('tracebacks', []))
    elif "bigcodebench" in dataset:
        # For bigcodebench, a valid mutation must not pass all tests.
        is_valid = 'ALL' not in mutation_info
    else:
        # For other datasets, a mutation that failed to generate is invalid.
        is_valid = bool(mutation['success'])
    return 1.0 - np.mean(solution_scores) if is_valid else 0.0


def validate_mutations(
    dataset: str,
    mutations_path: str,
    solver_results_path: str,
    verdict: Verdict,
    save_path: str = 'data/mutator_solver_verdict',
    hf_repo_name: str = '<>/<>',
    push_to_hub: bool = True,
    num_workers: int = 4
):
    """Validate previously generated mutations with ``verdict`` and build a dataset.

    Steps:
      1. Load mutations and solver results from the two JSON files.
      2. Run ``verdict`` on every mutation (``num_workers`` threads). For
         kodcode/bigcodebench every mutation is kept here and validity is decided
         when scoring. Otherwise a mutation is kept only if ``success`` is truthy
         and the verdict did not pass it.
      3. Score the kept mutations' solver results with ``compute_verdict_scores``
         and derive a mutator score per mutation (``_mutator_score``).
      4. Build a ``datasets.Dataset``, save it to ``save_path`` and, if
         ``push_to_hub``, push it to ``hf_repo_name``.

    Mutations that end up with no verdict scores (or no solver results) are
    reported and dropped, because they cannot be scored.

    Returns:
        The built ``Dataset``.
    """
    # Load saved mutations and solver results
    with open(mutations_path) as f:
        mutations = json.load(f)
    with open(solver_results_path) as f:
        solver_results = json.load(f)

    # Index solver results by mutation once, instead of rescanning the whole
    # list for every mutation.
    solver_results_by_mutation = defaultdict(list)
    for res in solver_results:
        solver_results_by_mutation[res['mutation_id']].append(res)

    def process_mutation(mut):
        """Return ``(mutation, its solver results)`` if kept, else ``None``."""
        passed, info = verdict(problem=mut['problem'], completion=mut['mutation'])
        if "kodcode" in dataset or "bigcodebench" in dataset:
            # Validity for these datasets is judged later from mutation_info.
            mutation_is_valid = True
        else:
            mutation_is_valid = mut['success'] and not passed
        if not mutation_is_valid:
            return None
        mut['mutation_info'] = info
        if "failed to map segment from shared object" in str(info) or "timeout" in str(info):
            print(str(info))
        # .get() so lookups never insert into the shared defaultdict from worker threads.
        return mut, solver_results_by_mutation.get(mut['mutation_id'], [])

    filtered_mutations = []
    filtered_solver_results = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # executor.map yields results in input order, so the dataset row order
        # no longer depends on thread timing (as it did with as_completed).
        results = executor.map(process_mutation, mutations)
        for result in tqdm(results, total=len(mutations), desc="Validating mutations"):
            if result:
                filtered_mutations.append(result[0])
                filtered_solver_results.extend(result[1])

    print(f"Filtered {len(filtered_mutations)} valid mutations from {len(mutations)} total mutations")
    verdict_dict_list = compute_verdict_scores(filtered_solver_results, verdict)
    print(f"Computed verdict scores for {len(filtered_mutations)} mutations")

    # Group verdict scores/info and kept solver results by mutation_id.
    scores_by_mutation = defaultdict(list)
    info_by_mutation = defaultdict(list)
    for verdict_dict in verdict_dict_list:
        mutation_id = verdict_dict['mutation_id']
        scores_by_mutation[mutation_id].append(verdict_dict['score'])
        info_by_mutation[mutation_id].append(verdict_dict['info'])
    kept_results_by_mutation = defaultdict(list)
    for res in filtered_solver_results:
        kept_results_by_mutation[res['mutation_id']].append(res)

    # Drop mutations that cannot be scored. Previously a warning was printed and
    # the later scores_by_mutation[...] lookup raised KeyError (and the
    # solver_prompt lookup raised StopIteration when there were no results).
    scored_mutations = []
    for mut in filtered_mutations:
        mutation_id = mut['mutation_id']
        if mutation_id not in scores_by_mutation:
            print(f"[Warning] mutation_id {mutation_id} not found in scores_by_mutation")
            continue
        if mutation_id not in kept_results_by_mutation:
            print(f"[Warning] mutation_id {mutation_id} has no solver results")
            continue
        scored_mutations.append(mut)

    mutator_scores = [
        _mutator_score(dataset, mut, scores_by_mutation[mut['mutation_id']])
        for mut in scored_mutations
    ]

    def solutions_of(mut):
        return kept_results_by_mutation[mut['mutation_id']]

    # NOTE(review): 'solution_scores'/'solutions_info' follow verdict_dict_list
    # order while 'solutions'/'solutions_explanation' follow solver-result
    # order; confirm compute_verdict_scores preserves input order so they align.
    dataset_dict = {
        'mutation_id': [mut['mutation_id'] for mut in scored_mutations],
        'task_id': [mut['problem']['task_id'] for mut in scored_mutations],
        'mutator_prompt': [mut['mutator_prompt'] for mut in scored_mutations],
        'solver_prompt': [solutions_of(mut)[0]['solver_prompt'] for mut in scored_mutations],
        'response': [f"```python\n{mut['mutation']}\n```" for mut in scored_mutations],
        'mutation_explanation': [mut['mutation_explanation'] for mut in scored_mutations],
        "mutation_info": [json.dumps(mut['mutation_info']) for mut in scored_mutations],
        'mutator_score': mutator_scores,
        'solution_scores': [json.dumps(scores_by_mutation[mut['mutation_id']]) for mut in scored_mutations],
        'solutions': [json.dumps([f"```python\n{sol['solution']}\n```" for sol in solutions_of(mut)]) for mut in scored_mutations],
        'solutions_explanation': [json.dumps([sol['solution_explanation'] for sol in solutions_of(mut)]) for mut in scored_mutations],
        'solutions_info': [json.dumps(info_by_mutation[mut['mutation_id']]) for mut in scored_mutations],
    }

    hf_dataset = Dataset.from_dict(dataset_dict)
    # save_path used to be accepted but ignored; persist locally like validate_solutions.
    print(f"Saving dataset to {save_path}")
    hf_dataset.save_to_disk(save_path)
    if push_to_hub:
        token = os.environ.get('HF_TOKEN') or HfFolder.get_token()
        create_repo(hf_repo_name, repo_type='dataset', exist_ok=True, private=False, token=token)
        hf_dataset.push_to_hub(hf_repo_name, private=False, token=token)
        print(f"Successfully pushed {len(scored_mutations)} mutations to {hf_repo_name}")

    return hf_dataset


def generate_solutions(
    dataset: list, 
    solver: Agent, 
    num_solver_iters: int = 10,
):
    """Sample solutions for the problems in ``dataset`` and return them.

    Thin wrapper around ``sample_solutions_data_gen``. Nothing is written to
    disk here; ``main`` combines and saves the results.

    Args:
        dataset: Problems to solve.
        solver: Agent passed to ``sample_solutions_data_gen``.
        num_solver_iters: Forwarded to ``sample_solutions_data_gen``.

    Returns:
        Whatever ``sample_solutions_data_gen`` returns. ``main`` concatenates two
        such results with ``+``, so presumably a list — confirm.
    """
    return sample_solutions_data_gen(solver, dataset, num_solver_iters)


# def validate_solutions(
#     solver_results_path: str,
#     verdict: Verdict,
#     save_path: str = 'data/solver_verdict',
#     hf_repo_name: str = '<>/<>',
#     push_to_hub: bool = True,
#     num_workers: int = 4
# ):
#     """Validate previously generated solutions using a verdict model."""
#     # Load solver results
#     with open(solver_results_path) as f:
#         solutions = json.load(f)

#     # Function to process a single solution
#     def process_solution(solution):
#         passed, info = verdict(problem=solution['problem'], completion=solution['solution'])
#         if passed is not None:  # Include all valid solutions
#             sol_dict = {
#                 'solution_id': solution['solution_id'],
#                 'problem_id': solution['problem_id'],
#                 'solver_prompt': solution['prompt'],
#                 'solution': solution['solution'],
#                 'solution_info': solution['solution_info'],
#                 'score': float(passed),
#                 'verdict_info': info,
#                 'problem': solution['problem']
#             }
#             # Check if reward = num_passed_unittests / num_unittests is in info
#             if "reward" in info.keys():
#                 sol_dict['score'] = info['reward']
#             return sol_dict
#         return None

#     filtered_solutions = []
    
#     # Process solutions with multiple workers
#     with ThreadPoolExecutor(max_workers=num_workers) as executor:
#         futures = {executor.submit(process_solution, sol): sol for sol in solutions}
#         for future in tqdm(as_completed(futures), total=len(futures), desc="Validating solutions"):
#             result = future.result()
#             if result:
#                 filtered_solutions.append(result)
#                 if "failed to map segment from shared object" in str(result['verdict_info']) or "timeout" in str(result['verdict_info']):
#                     print(f"Warning for solution {result['solution_id']}: {result['verdict_info']}")

#     print(f"Processed {len(filtered_solutions)} solutions from {len(solutions)} total solutions")

#     dataset_dict = {
#         'mutation_id': [None for _ in filtered_solutions],
#         'task_id': [sol['problem']['task_id'] for sol in filtered_solutions],
#         'mutator_prompt': [None for _ in filtered_solutions],
#         'solver_prompt': [sol['solver_prompt'] for sol in filtered_solutions],
#         'response': [None for _ in filtered_solutions],
#         "mutation_info": [None for _ in filtered_solutions],
#         'mutator_score': [None for _ in filtered_solutions],
#         'solution_scores': [json.dumps([sol['score']]) for sol in filtered_solutions],
#         'solutions': [json.dumps([f"```python\n{sol['solution']}\n```"]) for sol in filtered_solutions],
#         'solutions_info': [json.dumps(sol['verdict_info']) for sol in filtered_solutions]
#     }

#     dataset = Dataset.from_dict(dataset_dict)
#     print(f"Saving dataset to {save_path}")
#     dataset.save_to_disk(save_path)
#     if push_to_hub:
#         dataset.push_to_hub(hf_repo_name)
#     print(f"Pushed {len(filtered_solutions)} solutions to {hf_repo_name}")


import os
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from datasets import Dataset

def ensure_jsonlines(path: str) -> str:
    """Return a JSON-Lines version of ``path``, converting a ``.json`` array if needed.

    For ``foo.json`` (a top-level JSON array), write ``foo.jl`` with one record
    per line and return that path. An existing ``foo.jl`` is reused only if it
    is at least as new as ``foo.json`` (or ``foo.json`` no longer exists), so a
    regenerated ``.json`` is not silently shadowed by a stale cache. Any path
    not ending in ``.json`` is returned unchanged.

    The one-time conversion still loads the whole ``.json`` into memory.
    """
    json_suffix = '.json'
    if not path.endswith(json_suffix):
        return path
    # Swap only the trailing suffix; str.replace would also rewrite '.json'
    # occurring earlier in the path (e.g. in a directory name).
    jl = path[:-len(json_suffix)] + '.jl'
    if os.path.exists(jl) and (
        not os.path.exists(path) or os.path.getmtime(jl) >= os.path.getmtime(path)
    ):
        return jl

    print(f"Converting {path} → {jl}")
    with open(path, 'r') as fin:
        data = json.load(fin)
    # Write to a temp file, then rename it into place, so an interrupted
    # conversion never leaves a truncated .jl that later runs would trust.
    tmp = jl + '.tmp'
    with open(tmp, 'w') as fout:
        for rec in data:
            fout.write(json.dumps(rec) + '\n')
    os.replace(tmp, jl)
    return jl

def stream_jsonlines(path: str):
    """Lazily yield one decoded JSON object per non-blank line of ``path``.

    Whitespace-only lines (e.g. an accidental empty line in a hand-edited file)
    are skipped instead of raising ``json.JSONDecodeError``.
    """
    with open(path, 'r') as f:
        for line in f:
            if line.strip():
                yield json.loads(line)

def process_solution(solution, verdict):
    """Score a single generated solution and package it as a dataset record.

    ``verdict`` is called with the solution's problem and code. A ``None``
    ``passed`` value means the solution is discarded (returns ``None``).
    Otherwise the record's score is ``float(passed)``, replaced by
    ``info['reward']`` when the verdict reports one.
    """
    passed, info = verdict(problem=solution['problem'], completion=solution['solution'])
    if passed is None:
        return None

    record = {
        'solution_id': solution['solution_id'],
        'problem_id': solution['problem_id'],
        'solver_prompt': solution['prompt'],
        'solution': solution['solution'],
        'solution_info': solution['solution_info'],
        'score': float(passed),
        'verdict_info': info,
        'problem': solution['problem'],
    }
    # A reported reward takes precedence over the raw pass/fail value.
    if 'reward' in info:
        record['score'] = info['reward']
    return record

def _process_solution_batch(batch, verdict, out_list, num_workers):
    """Validate ``batch`` with up to ``num_workers`` threads, appending kept records to ``out_list``.

    Each item goes through ``process_solution``. ``None`` results are dropped.
    Records are appended in the same order as ``batch`` (``executor.map``
    preserves input order), so the final dataset order is reproducible rather
    than dependent on thread completion order.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = executor.map(lambda sol: process_solution(sol, verdict), batch)
        for res in tqdm(results, total=len(batch), desc="Processing batch", leave=False):
            if not res:
                continue
            out_list.append(res)
            verdict_text = str(res['verdict_info'])
            if 'failed to map segment' in verdict_text or 'timeout' in verdict_text:
                print(f"Warning for solution {res['solution_id']}: {res['verdict_info']}")

def validate_solutions(
    solver_results_path: str,
    verdict: Verdict,
    save_path: str = 'data/solver_verdict',
    hf_repo_name: str = '<>/<>',
    push_to_hub: bool = True,
    num_workers: int = 4,
    batch_size: int = 500,
):
    """Validate previously generated solutions with ``verdict`` (streaming & batching).

    The results file is converted to JSON Lines if needed (``ensure_jsonlines``)
    and streamed, ``batch_size`` raw solutions at a time. Each batch is scored
    concurrently with ``num_workers`` threads. Kept records are collected into a
    ``datasets.Dataset`` (mutation-specific columns are ``None``), saved to
    ``save_path`` and, if ``push_to_hub``, pushed to ``hf_repo_name``.

    Args:
        solver_results_path: ``.json`` array or JSON-Lines file of solutions.
        verdict: Callable scoring one solution (see ``process_solution``).
        save_path: Directory passed to ``Dataset.save_to_disk``.
        hf_repo_name: Hub repository used when ``push_to_hub`` is true.
        push_to_hub: Whether to push the dataset to the Hub.
        num_workers: Threads per batch.
        batch_size: Raw solutions handed to the thread pool at once.

    Returns:
        The built ``Dataset``.
    """
    results_path = ensure_jsonlines(solver_results_path)

    # One cheap extra pass to size the overall progress bar.
    total_solutions = sum(1 for _ in stream_jsonlines(results_path))
    print(f"Found {total_solutions} solutions to validate")

    filtered_solutions = []
    batch = []
    with tqdm(total=total_solutions, desc="Validating solutions") as pbar:
        for sol in stream_jsonlines(results_path):
            batch.append(sol)
            if len(batch) >= batch_size:
                _process_solution_batch(batch, verdict, filtered_solutions, num_workers)
                pbar.update(len(batch))
                batch.clear()
        if batch:
            _process_solution_batch(batch, verdict, filtered_solutions, num_workers)
            pbar.update(len(batch))

    print(f"Processed {len(filtered_solutions)} solutions (streamed)")

    n = len(filtered_solutions)
    dataset_dict = {
        'mutation_id': [None] * n,
        'task_id': [sol['problem']['task_id'] for sol in filtered_solutions],
        'mutator_prompt': [None] * n,
        'solver_prompt': [sol['solver_prompt'] for sol in filtered_solutions],
        'response': [None] * n,
        'mutation_info': [None] * n,
        'mutator_score': [None] * n,
        'solution_scores': [json.dumps([sol['score']]) for sol in filtered_solutions],
        'solutions': [json.dumps([f"```python\n{sol['solution']}\n```"]) for sol in filtered_solutions],
        'solutions_info': [json.dumps(sol['verdict_info']) for sol in filtered_solutions]
    }

    dataset = Dataset.from_dict(dataset_dict)
    print(f"Saving dataset to {save_path}")
    dataset.save_to_disk(save_path)
    if push_to_hub:
        dataset.push_to_hub(hf_repo_name)
        # Report a push only when one actually happened.
        print(f"Pushed {len(filtered_solutions)} solutions to {hf_repo_name}")
    return dataset


def _build_solver(args, vllm_model, prompt_handler):
    """Build a solver agent from the ``--solver_*`` CLI options and ``prompt_handler``.

    Shared by every generation path; they differ only in the prompt handler.
    """
    return solver_engine(
        model_name=args.solver_model_name,
        vllm_model=vllm_model,
        max_tokens=args.max_new_tokens,
        temperature=args.solver_temperature,
        top_p=args.solver_top_p,
        top_k=args.solver_top_k,
        repetition_penalty=args.solver_repetition_penalty,
        n=args.num_solver_iters,
        peft_dir=args.solver_peft_dir,
        prompt_handler=prompt_handler,
    )


def main():
    """CLI entry point: parse args, seed RNGs, load the dataset, then generate or validate.

    ``--mode generate`` loads the vLLM models and then does one of two things:
      * ``--generation_type mutations``: samples mutations plus solver results
        via ``generate_mutations``;
      * ``--generation_type solutions``: samples solutions with the codegen
        prompt handler and with the incorrect-codegen prompt handler, and writes
        them together to ``<save_dir>/solutions.json``.
    ``--mode validate`` scores previously saved outputs with the verdict engine.
    """
    args = parse_args()

    # Seed Python, NumPy and torch RNGs for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Load dataset
    print("Loading dataset and prompt handlers...")
    dataset_engine, mutator_prompt_handler, solver_prompt_handler, codegen_prompt_handler, incorrect_codegen_prompt_handler = load_dataset_and_prompt_handlers(args.dataset)
    return_problem = args.generation_type == "solutions"
    trainset, testset = dataset_engine(splits_file=args.splits_file, return_problem=return_problem)
    if args.debug:
        trainset = trainset[:args.debug_samples]
        print(f"Debug mode: using {args.debug_samples} samples")
    print(f"Trainset size: {len(trainset)}")
    print(f"Testset size: {len(testset)}")

    if args.mode == 'generate':
        print("Starting to load models...")
        mutator_vllm_model, solver_vllm_model = get_vllm_models(
            mutator_model_name=args.model_name,
            solver_model_name=args.solver_model_name,
            max_model_len=args.max_model_len,
            enforce_eager=args.enforce_eager,
            num_gpus=torch.cuda.device_count(),
            gpu_memory_utilization=args.gpu_memory_utilization,
        )
        print("VLLM models loaded.")
        if args.generation_type == "mutations":
            mutator = mutator_engine(
                model_name=args.model_name,
                vllm_model=mutator_vllm_model,
                max_tokens=args.max_new_tokens,
                temperature=args.attacker_temperature,
                top_p=args.attacker_top_p,
                top_k=args.attacker_top_k,
                repetition_penalty=args.attacker_repetition_penalty,
                n=args.num_mutator_iters,
                peft_dir=args.mutator_peft_dir,
                prompt_handler=mutator_prompt_handler
            )
            solver = _build_solver(args, solver_vllm_model, solver_prompt_handler)
            generate_mutations(
                trainset, mutator, solver,
                num_mutator_iters=args.num_mutator_iters,
                num_solver_iters=args.num_solver_iters,
                save_dir=args.save_dir,
                mutation_subparts_to_generate=args.mutation_subparts_to_generate
            )
        else:  # solutions
            solver = _build_solver(args, solver_vllm_model, codegen_prompt_handler)
            solutions = generate_solutions(
                trainset, solver,
                num_solver_iters=args.num_solver_iters,
            )

            # Same solver settings, but prompted with the incorrect-codegen handler.
            buggy_solver = _build_solver(args, solver_vllm_model, incorrect_codegen_prompt_handler)
            buggy_solutions = generate_solutions(
                trainset, buggy_solver,
                num_solver_iters=args.num_solver_iters,
            )

            # Combine both sets and save them; validation is a separate run (--mode validate).
            all_solutions = solutions + buggy_solutions
            os.makedirs(args.save_dir, exist_ok=True)
            with open(os.path.join(args.save_dir, 'solutions.json'), 'w') as f:
                json.dump(all_solutions, f)

    else:  # validate mode
        verdict = verdict_engine(args.dataset)
        if args.generation_type == "mutations":
            validate_mutations(
                dataset=args.dataset,
                mutations_path=args.mutations_path,
                solver_results_path=args.solver_results_path,
                verdict=verdict,
                save_path=args.save_path,
                hf_repo_name=args.hf_repo_name,
                push_to_hub=args.push_to_hub,
                num_workers=args.num_workers
            )
        else:  # solutions
            validate_solutions(
                solver_results_path=args.solver_results_path,
                verdict=verdict,
                save_path=args.save_path,
                hf_repo_name=args.hf_repo_name,
                push_to_hub=args.push_to_hub,
                num_workers=args.num_workers
            )


if __name__ == "__main__":
    # Load '.env' from the current working directory before running the CLI.
    # NOTE(review): this path has no '~', so user-home expansion does nothing
    # here; if a home-directory file was intended, use '~/.env' with expansion.
    dotenv_file = '.env'
    load_dotenv(dotenv_file)
    main()
