#!/usr/bin/env python3

"""
This module contains functions for running inference using Ollama models.
It processes SWE-bench instances and generates patches using the specified Ollama model.
"""

import json
import logging
import re
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path
import requests
from tqdm.auto import tqdm
import tiktoken
from swebench.inference.make_datasets.utils import extract_diff
from swebench.harness.constants import RUN_EVALUATION_LOG_DIR

# Directory for the run log file, relative to the current working directory.
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)

# Configure the root logger to write to logs/run_ollama.log and to stderr.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
    handlers=[logging.FileHandler(log_dir / "run_ollama.log"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

def count_tokens(text: str, model: str = "gpt-3.5-turbo") -> int:
    """
    Count the tokens in ``text`` with a tiktoken encoder.

    Tokenizers differ between model families, so for non-OpenAI models (such as
    the Ollama models this script targets) the result is an estimate.

    Args:
        text: Text to tokenize.
        model: Model name used to pick the tiktoken encoder; unknown names fall
            back to the ``cl100k_base`` encoding.

    Returns:
        int: Number of tokens.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # No encoder registered for this model name: use cl100k_base instead.
        encoding = tiktoken.get_encoding("cl100k_base")
    # By default tiktoken raises ValueError when the text contains special-token
    # strings such as "<|endoftext|>". Arbitrary issue text or model output may
    # contain them, so count them as ordinary text instead.
    return len(encoding.encode(text, disallowed_special=()))

def write_token_stats(run_id: str, model_name: str, stats: dict, instance_ids: set):
    """
    Write token statistics to ``token_stats.json``, accumulating totals across runs.

    The file is ``RUN_EVALUATION_LOG_DIR / run_id / <model_name with "/" replaced
    by "__"> / "token_stats.json"``. The merge rules are:

    * No file yet: ``stats`` is written with ``processed_instance_ids`` set to
      ``instance_ids``.
    * The file exists and ``instance_ids`` contains ids it has not recorded: the
      file's input/output-token and iteration totals are added to this run's
      totals, the id lists are unioned and the averages are recomputed.
    * The file exists but this run adds no new ids: the file's contents are kept.
    * The file cannot be read, parsed or merged: it is replaced by ``stats`` alone.

    In the merge case, the totals of ids that were *already* recorded are still
    added, because per-instance counts are not available here.

    The JSON is first written to a ``.tmp`` sibling that is then renamed over the
    target. An interrupted write therefore cannot leave a truncated stats file
    behind.

    Args:
        run_id: Run identifier (directory level under RUN_EVALUATION_LOG_DIR).
        model_name: Model name; "/" is replaced with "__" for the directory name.
        stats: This run's statistics. They must contain ``total_input_tokens``,
            ``total_output_tokens`` and ``total_iterations``. The dict is not
            modified.
        instance_ids: Instance ids processed in this run.
    """
    stats_dir = RUN_EVALUATION_LOG_DIR / run_id / model_name.replace("/", "__")
    stats_dir.mkdir(parents=True, exist_ok=True)
    stats_file = stats_dir / "token_stats.json"

    # This run's stats on their own, used when there is nothing usable to merge
    # with. It is a separate copy, so a merge that fails part-way cannot leave
    # half-accumulated totals in the fallback (or in the caller's dict).
    fresh_stats = dict(stats)
    fresh_stats["processed_instance_ids"] = sorted(instance_ids, key=str)

    if not stats_file.exists():
        logger.info("No existing stats file found, creating new one")
        result = fresh_stats
    else:
        try:
            with open(stats_file, "r") as f:
                existing_stats = json.load(f)

            processed_ids = set(existing_stats.get("processed_instance_ids", []))
            new_ids = instance_ids - processed_ids
            if new_ids:
                merged = dict(stats)
                merged["total_input_tokens"] += existing_stats.get("total_input_tokens", 0)
                merged["total_output_tokens"] += existing_stats.get("total_output_tokens", 0)
                merged["total_tokens"] = merged["total_input_tokens"] + merged["total_output_tokens"]
                merged["total_iterations"] += existing_stats.get("total_iterations", 0)
                # Sorted (by string form) so the file content is deterministic.
                merged["processed_instance_ids"] = sorted(processed_ids | instance_ids, key=str)

                if merged["total_iterations"] > 0:
                    merged["average_input_tokens_per_iteration"] = merged["total_input_tokens"] / merged["total_iterations"]
                    merged["average_output_tokens_per_iteration"] = merged["total_output_tokens"] / merged["total_iterations"]
                    merged["average_iterations_per_instance"] = merged["total_iterations"] / len(merged["processed_instance_ids"])

                logger.info(f"Accumulated tokens for {len(new_ids)} new instances: input={merged['total_input_tokens']}, output={merged['total_output_tokens']}")
                logger.info(f"Average iterations per instance: {merged.get('average_iterations_per_instance', 0):.2f}")
                result = merged
            else:
                logger.info("No new instances processed, keeping existing stats")
                result = existing_stats
        except Exception as e:
            logger.error(f"Error reading existing stats file: {e}")
            # fresh_stats was never touched by the failed merge attempt.
            logger.info("Using new stats instead")
            result = fresh_stats

    # Older or hand-edited files may lack this field (or the fields it derives from).
    if "average_iterations_per_instance" not in result:
        recorded_ids = result.get("processed_instance_ids") or []
        result["average_iterations_per_instance"] = (
            result.get("total_iterations", 0) / len(recorded_ids) if recorded_ids else 0
        )

    # Write to a temporary sibling first, then rename it over the real file.
    tmp_file = stats_file.with_name(stats_file.name + ".tmp")
    with open(tmp_file, "w") as f:
        json.dump(result, f, indent=4)
    tmp_file.replace(stats_file)
    logger.info(f"Token statistics written to {stats_file}")

def get_output_file(
    output_dir,
    model_name,
    dataset_path,
    split,
    temperature,
    top_p,
    min_len,
    max_len,
    shard_id,
    num_shards,
    run_id,
):
    """
    Build the path of the JSON-lines output file for a run, creating its directory.

    The file name encodes the following:

    * the dataset file name and split;
    * the model name, with "/" replaced by "__";
    * the run id and the sampling parameters;
    * the min/max length bounds, only when they are given;
    * the shard, only when both ``shard_id`` and ``num_shards`` are given.

    Returns:
        pathlib.Path: ``<output_dir>/<encoded name>.jsonl``.
    """
    # Each optional tag is included only when its condition holds, in this order.
    optional_tags = (
        (min_len is not None, f"__min-{min_len}"),
        (max_len is not None, f"__max-{max_len}"),
        (shard_id is not None and num_shards is not None, f"__shard-{shard_id}-{num_shards}"),
    )
    suffix = "".join(tag for enabled, tag in optional_tags if enabled)

    dataset_tag = Path(dataset_path).name + "__" + split
    model_tag = model_name.replace("/", "__")
    file_name = (
        f"{dataset_tag}__{model_tag}__run-{run_id}"
        f"__temp-{temperature}__top-p-{top_p}{suffix}.jsonl"
    )
    output_file = Path(output_dir, file_name)

    parent_dir = output_file.parent
    if not parent_dir.exists():
        parent_dir.mkdir(parents=True, exist_ok=True)
    return output_file

def call_ollama(
    prompt,
    model_name,
    temperature=0.0,
    top_p=1.0,
    max_tokens=None,
    url="http://10.129.164.27:11434/api/generate",
    timeout=(10, None),
):
    """
    Call the Ollama ``/api/generate`` endpoint (non-streaming) and return the generated text.

    Args:
        prompt: Prompt text sent to the model.
        model_name: Name of the Ollama model.
        temperature: Sampling temperature.
        top_p: Top-p (nucleus) sampling parameter.
        max_tokens: Maximum number of tokens to generate, sent as Ollama's
            ``num_predict`` option. ``None`` sends no limit. Earlier versions put
            ``max_tokens`` at the top level of the request, where Ollama ignores
            it, so generation was never actually limited; ``None`` keeps that.
        url: Full URL of the ``/api/generate`` endpoint.
        timeout: ``requests`` timeout as ``(connect, read)``. The default bounds
            only the connection phase, so long generations are not cut off, but
            an unreachable server no longer hangs forever.

    Returns:
        str: The ``response`` field of the reply.

    Raises:
        Any error from the HTTP request, a non-2xx status, or a reply without a
        ``response`` field. It is logged and then re-raised.
    """
    # Ollama reads sampling parameters from "options"; unknown top-level keys
    # (the previous "temperature"/"top_p"/"max_tokens") are ignored.
    options = {"temperature": temperature, "top_p": top_p}
    if max_tokens is not None:
        options["num_predict"] = max_tokens
    data = {
        "model": model_name,
        "prompt": prompt,
        "stream": False,
        "options": options,
    }

    try:
        response = requests.post(url, json=data, timeout=timeout)
        response.raise_for_status()
        return response.json()["response"]
    except Exception as e:
        logger.error(f"Error calling Ollama API: {e}")
        raise

def generate(
    dataset,
    model_name,
    temperature,
    top_p,
    fileobj,
    run_id,
):
    """
    Generate a patch for each instance in ``dataset`` and append the results to ``fileobj``.

    Each instance dict must provide ``instance_id``, ``problem_statement``,
    ``repo`` and ``base_commit``. For each instance:

    * a prompt is built from those fields and sent to the model via ``call_ollama``;
    * a JSON line is written to ``fileobj`` and flushed immediately. It holds the
      full output, the diff extracted with ``extract_diff``, and the input/output
      token counts from ``count_tokens``.

    Token statistics are persisted with ``write_token_stats`` when the loop ends.
    This also happens when the loop is aborted (too many failures, Ctrl-C), so
    the usage of instances that already finished is not lost.

    Args:
        dataset: Sequence of instance dicts.
        model_name: Ollama model name.
        temperature: Sampling temperature passed to ``call_ollama``.
        top_p: Top-p value passed to ``call_ollama``.
        fileobj: Writable text file that receives one JSON object per line.
        run_id: Run identifier used to locate the statistics file.

    Raises:
        ValueError: Once 3 instances have failed ("too many failures").
    """
    fail_count = 0
    total_input_tokens = 0
    total_output_tokens = 0
    total_iterations = 0
    processed_ids = set()

    try:
        for ix, instance in enumerate(tqdm(dataset, desc="Generating patches")):
            try:
                # Build the prompt for this instance.
                input_text = f"""Problem Statement:
{instance['problem_statement']}

Repository: {instance['repo']}
Base Commit: {instance['base_commit']}

Please generate a patch to fix this issue. The patch should be in git diff format.
"""
                input_tokens = count_tokens(input_text)
                total_input_tokens += input_tokens

                logger.info(f"Processing instance {instance['instance_id']}")
                logger.info(f"Input tokens: {input_tokens}")

                start = datetime.now()
                output = call_ollama(
                    prompt=input_text,
                    model_name=model_name,
                    temperature=temperature,
                    top_p=top_p
                )

                output_tokens = count_tokens(output)
                total_output_tokens += output_tokens
                total_iterations += 1
                processed_ids.add(instance["instance_id"])

                logger.info(f"Output tokens: {output_tokens}")
                logger.info(f"Generated response in {(datetime.now() - start).total_seconds()} seconds")
                logger.info(output[:200])

                diff = extract_diff(output)
                res = {
                    "instance_id": instance["instance_id"],
                    "full_output": output,
                    "model_patch": diff,
                    "model_name_or_path": model_name,
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                }
                # Flush per record so an interrupted run keeps completed results.
                print(json.dumps(res), file=fileobj, flush=True)

            except Exception as e:
                logger.exception(e)
                logger.error(f"failed on {ix}")
                fail_count += 1
                if fail_count >= 3:
                    raise ValueError("too many failures")
    finally:
        # Runs on normal completion and on abort, so spent tokens are always recorded.
        stats = {
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_input_tokens + total_output_tokens,
            "total_iterations": total_iterations,
            "average_input_tokens_per_iteration": total_input_tokens / total_iterations if total_iterations > 0 else 0,
            "average_output_tokens_per_iteration": total_output_tokens / total_iterations if total_iterations > 0 else 0,
            "average_iterations_per_instance": total_iterations / len(processed_ids) if processed_ids else 0,
            "timestamp": datetime.now().isoformat(),
            "model_name": model_name,
            "temperature": temperature,
            "top_p": top_p
        }
        write_token_stats(run_id, model_name, stats, processed_ids)

def get_all_existing_ids(output_file):
    """
    Return the set of ``instance_id`` values already recorded in ``output_file``.

    ``output_file`` is the JSON-lines file written by ``generate``. Blank lines
    are ignored. Lines that are not valid JSON, are not objects, or lack
    ``instance_id`` (e.g. a record truncated by an interrupted run) are logged
    and skipped, so those instances are simply generated again.

    Args:
        output_file: ``pathlib.Path`` of the output file.

    Returns:
        set: Instance ids found; empty if the file does not exist.
    """
    if not output_file.exists():
        logger.info("Output file does not exist, no existing IDs to filter")
        return set()

    existing_ids = set()
    with open(output_file) as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                existing_ids.add(json.loads(line)["instance_id"])
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                logger.warning(f"Skipping malformed line {lineno} in {output_file}: {e}")

    logger.info(f"Found {len(existing_ids)} existing ids")
    return existing_ids

def _select_shard(dataset, shard_id, num_shards):
    """Return the ``shard_id``-th of ``num_shards`` contiguous, near-equal slices of ``dataset``."""
    base, extra = divmod(len(dataset), num_shards)
    # The first ``extra`` shards each receive one additional item.
    start = base * shard_id + min(shard_id, extra)
    stop = start + base + (1 if shard_id < extra else 0)
    return dataset[start:stop]


def main(
    model_name,
    dataset_path,
    split,
    temperature,
    top_p,
    output_dir,
    min_len,
    max_len,
    shard_id,
    num_shards,
    run_id,
    num_samples,
):
    """
    Run inference with an Ollama model over a JSON-lines dataset.

    Processing order:

    1. Load the dataset (blank lines are skipped).
    2. Keep the first ``num_samples`` samples.
    3. Select the requested contiguous shard.
    4. Drop instances already present in the output file.
    5. Generate patches for the rest, appending to the output file.

    Args:
        model_name: Ollama model name.
        dataset_path: Path to the JSON-lines dataset.
        split: Dataset split name (used in the output file name).
        temperature: Sampling temperature.
        top_p: Top-p sampling parameter.
        output_dir: Directory for the output file.
        min_len: Only used in the output file name; no length filtering is applied.
        max_len: Only used in the output file name; no length filtering is applied.
        shard_id: Index of the shard to process (requires ``num_shards``).
        num_shards: Total number of shards (requires ``shard_id``).
        run_id: Run identifier.
        num_samples (int): Number of samples to run; any value <= 0 (e.g. -1)
            means all samples.

    Raises:
        ValueError: If only one of ``shard_id``/``num_shards`` is given, or
            ``shard_id`` is outside ``[0, num_shards)``.
    """
    if shard_id is not None and num_shards is None:
        raise ValueError("num_shards must be specified with shard_id")
    if shard_id is None and num_shards is not None:
        raise ValueError("shard_id must be specified with num_shards")
    if num_shards is not None and not 0 <= shard_id < num_shards:
        raise ValueError(
            f"shard_id must satisfy 0 <= shard_id < num_shards, got shard_id={shard_id}, num_shards={num_shards}"
        )
    if min_len is not None or max_len is not None:
        logger.warning("min_len/max_len only affect the output file name; no length filtering is applied")

    output_file = get_output_file(
        output_dir=output_dir,
        model_name=model_name,
        dataset_path=dataset_path,
        split=split,
        temperature=temperature,
        top_p=top_p,
        min_len=min_len,
        max_len=max_len,
        shard_id=shard_id,
        num_shards=num_shards,
        run_id=run_id,
    )

    logger.warning(f"output_file: {output_file}")

    # A missing output file means a fresh run: delete any stale statistics file.
    if not output_file.exists():
        stats_dir = RUN_EVALUATION_LOG_DIR / run_id / model_name.replace("/", "__")
        stats_file = stats_dir / "token_stats.json"
        if stats_file.exists():
            logger.info(f"Output file does not exist, clearing statistics file: {stats_file}")
            stats_file.unlink()

    # Load the JSON-lines dataset; blank lines would otherwise make json.loads fail.
    with open(dataset_path) as f:
        dataset = [json.loads(line) for line in f if line.strip()]

    logger.info(f"Loaded {len(dataset)} samples from dataset")

    # Limit the sample count first, before sharding and filtering.
    if num_samples > 0:
        dataset = dataset[:num_samples]
        logger.info(f"Limiting to {num_samples} samples before filtering")
    else:
        logger.info(f"Processing all {len(dataset)} samples")

    # Shard before filtering existing ids, so a shard's membership stays stable across resumes.
    if num_shards is not None:
        dataset = _select_shard(dataset, shard_id, num_shards)
        logger.info(f"Selected shard {shard_id} of {num_shards}: {len(dataset)} samples")

    # Skip instances already written to the output file (resume support).
    existing_ids = get_all_existing_ids(output_file)
    dataset = [x for x in dataset if x["instance_id"] not in existing_ids]
    logger.info(f"After filtering existing IDs: {len(dataset)} samples remaining")

    if not dataset:
        logger.warning("No samples to process after filtering!")
        return

    # Generate patches, appending to the output file.
    with open(output_file, "a") as f:
        generate(
            dataset=dataset,
            model_name=model_name,
            temperature=temperature,
            top_p=top_p,
            fileobj=f,
            run_id=run_id,
        )

    logger.info("Done")

if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), in display order.
    cli_spec = (
        ("--model_name", {"type": str, "default": "deepseek-r1:32b", "help": "Name of the Ollama model to use"}),
        ("--dataset_path", {"type": str, "required": True, "help": "Path to the input dataset file"}),
        ("--split", {"type": str, "default": "test", "help": "Dataset split to use"}),
        ("--output_dir", {"type": str, "default": "./outputs", "help": "Directory to save outputs"}),
        ("--temperature", {"type": float, "default": 0.0, "help": "Sampling temperature"}),
        ("--top_p", {"type": float, "default": 1.0, "help": "Top-p sampling parameter"}),
        ("--min_len", {"type": int, "default": None, "help": "Minimum length of input sequences to include"}),
        ("--max_len", {"type": int, "default": None, "help": "Maximum length of input sequences to include"}),
        ("--shard_id", {"type": int, "default": None, "help": "ID of the shard to load"}),
        ("--num_shards", {"type": int, "default": None, "help": "Total number of shards"}),
        ("--run_id", {"type": str, "required": True, "help": "Run ID - identifies the run"}),
        ("--num_samples", {"type": int, "default": -1, "help": "Number of samples to process (-1 for all samples)"}),
    )

    parser = ArgumentParser()
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)

    # Every parsed option maps one-to-one onto a keyword parameter of main().
    args = parser.parse_args()
    main(**vars(args))