import os
import re
import json
import argparse
import random
from base_prompt import *
import requests
import stat
import ollama

import openai

openai.api_key = os.getenv("OPENAI_API_KEY")

# Test
def load_data(args):
    """Load the problems, the evaluation question ids and the few-shot example ids.

    Reads ``problems.json`` and ``pid_splits.json`` from ``args.data_root`` and the
    captions file ``args.caption_file`` (whose top-level ``"captions"`` mapping is
    keyed by question id). Every problem gets a ``caption`` field, set to ``""``
    when no caption exists for it.

    Args:
        args: namespace providing ``data_root``, ``caption_file``, ``test_split``,
            ``test_number``, ``shot_qids`` and ``shot_number``.

    Returns:
        ``(problems, qids, shot_qids)``:
        - problems: the problem dict keyed by question id;
        - qids: the ids of ``test_split``, truncated to the first ``test_number``
          when ``test_number > 0``;
        - shot_qids: the training ids used as in-context examples.

    Raises:
        ValueError: if ``shot_number`` is outside ``[0, 32]`` (when sampling) or if
            an explicitly given shot id is not in the training split.
    """
    def _read_json(path):
        # Use a context manager so the file handle is closed deterministically.
        with open(path) as f:
            return json.load(f)

    problems = _read_json(os.path.join(args.data_root, 'problems.json'))
    pid_splits = _read_json(os.path.join(args.data_root, 'pid_splits.json'))
    captions = _read_json(args.caption_file)["captions"]

    for qid, problem in problems.items():
        problem['caption'] = captions.get(qid, "")

    qids = pid_splits[args.test_split]
    if args.test_number > 0:
        qids = qids[:args.test_number]
    print(f"number of test problems: {len(qids)}\n")

    # pick up shot examples from the training set
    train_qids = pid_splits['train']
    if args.shot_qids is None:
        if not 0 <= args.shot_number <= 32:
            raise ValueError(f"shot_number must be in [0, 32], got {args.shot_number}")
        # Depends on the global `random` state; the caller seeds it beforehand.
        shot_qids = random.sample(train_qids, args.shot_number)  # random sample
    else:
        shot_qids = [str(qid) for qid in args.shot_qids]
        train_set = set(train_qids)
        missing = [qid for qid in shot_qids if qid not in train_set]
        if missing:
            raise ValueError(f"shot_qids not in the training split: {missing}")
    print("training question ids for prompting: ", shot_qids, "\n")

    return problems, qids, shot_qids


def get_qwen_result(image_path, prompt, args):
    messages = [{
        "role": "user",
        "content": prompt
    }]


    if os.path.exists(image_path):  # 这里可以添加条件检查，以确保路径有效
        messages[0]["images"] = [image_path]


    response = ollama.chat(
        model="llava:7b",
        stream=False,
        messages=messages,
        options={
            "temperature": args.temperature,
            "max_tokens": args.max_tokens,
            "top_p": args.top_p,
            "frequency_penalty": args.frequency_penalty,
            "presence_penalty": args.presence_penalty,
            "stop": ["\n"]
        }
    )

    output = response['message']['content']

    # extract the answer
    pattern = re.compile(r'The answer is ([A-Z]).')
    res = pattern.findall(output)
    if len(res) == 1:
        answer = res[0]  # 'A', 'B', ...
    else:
        answer = "FAILED"

    return answer, output


def get_pred_idx(prediction, choices, options):
    """Map a predicted option letter (e.g. 'C') to its choice index (e.g. 2).

    Only the first ``len(choices)`` letters of *options* are valid. Any other
    prediction (such as "FAILED" or a letter beyond the available choices)
    yields a uniformly random valid index drawn from the global ``random`` state.
    """
    num_choices = len(choices)
    valid_letters = options[:num_choices]
    if prediction not in valid_letters:
        return random.choice(range(num_choices))
    return options.index(prediction)


def get_result_file(args):
    """Return the JSON result path for this configuration, creating its directory.

    The path has the form
    ``<output_root>/<model>/<label>_test_<prompt_format>_seed_<seed>_8random.json``.
    """
    file_name = f"{args.label}_test_{args.prompt_format}_seed_{args.seed}_8random.json"
    result_path = os.path.join(args.output_root, args.model, file_name)
    parent_dir = os.path.dirname(result_path)
    os.makedirs(parent_dir, exist_ok=True)
    return result_path

# test
def save_results(result_file, current_run_data, all_runs_data, average_acc):
    """Append *current_run_data* to *all_runs_data* and write the summary as JSON.

    The written object holds ``average_accuracy``, ``total_runs`` (previous runs
    plus the current one) and ``runs``. *all_runs_data* itself is not mutated.
    The JSON goes to ``<result_file>.tmp`` first and is then moved over
    *result_file* with ``os.replace``. A reader of *result_file* therefore never
    sees a half-written file.
    """
    run_count = len(all_runs_data) + 1
    payload = {
        "average_accuracy": average_acc,
        "total_runs": run_count,
        "runs": all_runs_data + [current_run_data],
    }

    staging_path = result_file + ".tmp"
    with open(staging_path, 'w') as handle:
        json.dump(payload, handle, indent=2, separators=(',', ': '))

    # Swap the fully written file into place.
    os.replace(staging_path, result_file)

def load_existing_results(result_file):
    """Read previously saved runs from *result_file*.

    Returns:
        ``(runs, average_accuracy)``: the ``runs`` list and the ``average_accuracy``
        value stored in the file. The result is ``([], 0)`` when the file does not
        exist. Missing keys fall back to ``[]`` / ``0`` instead of raising
        ``KeyError``. Files whose runs are stored under other keys
        (e.g. ``completed_runs`` / ``all_results``) therefore yield an empty run list.
    """
    if not os.path.exists(result_file):
        return [], 0
    with open(result_file) as f:
        data = json.load(f)
    return data.get("runs", []), data.get("average_accuracy", 0)

def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before. Passing a list
            allows programmatic use and testing.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, default='./data/scienceqa')
    parser.add_argument('--output_root', type=str, default='./sqa/results')
    parser.add_argument('--caption_file', type=str, default='./data/captions.json')
    parser.add_argument('--model', type=str, default='llava_test')
    # NOTE: type=list splits one CLI string into characters,
    # e.g. `--options ABCDE` -> ['A', 'B', 'C', 'D', 'E'].
    parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
    # user options
    parser.add_argument('--label', type=str, default='exp0')
    parser.add_argument('--test_split', type=str, default='val', choices=['test', 'val', 'minival'])
    parser.add_argument('--test_number', type=int, default=10, help='GPT-3 is expensive. -1 for whole val/test set')
    parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
    parser.add_argument('--save_every', type=int, default=10, help='Save the result with every n examples.')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--prompt_format',
                        type=str,
                        default='CQM-A',
                        choices=[
                            'CQM-A', 'CQM-LA', 'CQM-EA', 'CQM-LEA', 'CQM-ELA', 'CQM-AL', 'CQM-AE', 'CQM-ALE', 'QCM-A',
                            'QCM-LA', 'QCM-EA', 'QCM-LEA', 'QCM-ELA', 'QCM-AL', 'QCM-AE', 'QCM-ALE', 'QCML-A', 'QCME-A',
                            'QCMLE-A', 'QCLM-A', 'QCEM-A', 'QCLEM-A', 'QCML-AE'
                        ],
                        help='prompt format template')
    parser.add_argument('--shot_number', type=int, default=3, help='Number of n-shot training examples.')
    # nargs='+' collects whole ids (`--shot_qids 123 4567`). The former type=list
    # split a single id into its characters ('123' -> ['1', '2', '3']).
    parser.add_argument('--shot_qids', type=str, nargs='+', default=None, help='Question indexes of shot examples')
    parser.add_argument('--seed', type=int, default=10, help='random seed')
    # Sampling settings forwarded to the model
    parser.add_argument('--temperature', type=float, default=0.5)
    parser.add_argument('--max_tokens',
                        type=int,
                        default=512,
                        help='The maximum number of tokens allowed for the generated answer.')
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--frequency_penalty', type=float, default=0.0)
    parser.add_argument('--presence_penalty', type=float, default=0.0)

    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    print('====Input Arguments====')
    print(json.dumps(vars(args), indent=2, sort_keys=False))

    result_file = get_result_file(args)
    print("Result file:", result_file)

    # 加载已有结果
    if os.path.exists(result_file):
        with open(result_file) as f:
            existing_data = json.load(f)
            completed_runs = existing_data.get("completed_runs", [])
            print(f"Found {len(completed_runs)} existing runs")
    else:
        existing_data = {
            "average_accuracy": 0.0,
            "total_runs": 5,
            "completed_runs": [],
            "all_results": {}
        }
        completed_runs = []

    original_seed = args.seed
    total_runs = 5

    for run_idx in range(len(completed_runs), total_runs):
        print(f"\n=== Starting Run {run_idx + 1}/{total_runs} ===")
        args.seed = original_seed + 2 * run_idx  # 不同运行使用不同种子
        random.seed(args.seed)

        # 检查是否存在未完成的运行记录
        run_id = f"run_{run_idx + 1}"
        run_data = existing_data["all_results"].get(run_id, {
            "run_number": run_idx + 1,
            "seed": args.seed,
            "correct": 0,
            "total": 0,
            "acc": 0.0,
            "results": {},
            "outputs": {},
            "status": "running",
            "checkpoint": None
        })

        # 恢复检查点
        if run_data["status"] == "completed":
            print(f"Run {run_idx + 1} already completed, skipping")
            continue

        problems, qids, shot_qids = load_data(args)
        start_index = 0

        # 恢复进度
        if run_data["checkpoint"] is not None:
            print(f"Resuming from checkpoint: {run_data['checkpoint']}")
            start_index = run_data["checkpoint"] + 1
            existing_results = run_data["results"]
        else:
            existing_results = {}

        try:
            for i, qid in enumerate(qids[start_index:]):
                current_index = start_index + i

                if qid in existing_results:
                    continue

                # 原有处理逻辑
                choices = problems[qid]["choices"]
                answer = problems[qid]["answer"]
                label = args.options[answer]

                # 生成prompt
                prompt = build_prompt(problems, shot_qids, qid, args)

                # 处理图像
                image_path = os.path.join(args.data_root, 'images/test', qid, 'image.png')

                # 获取预测结果
                prediction, output = get_qwen_result(image_path, prompt, args)
                pred_idx = get_pred_idx(prediction, choices, args.options)

                # 更新结果
                run_data["results"][qid] = pred_idx
                run_data["outputs"][qid] = output
                if pred_idx == answer:
                    run_data["correct"] += 1
                run_data["total"] += 1

                # 定期保存检查点
                if (current_index + 1) % args.save_every == 0:
                    run_data["acc"] = run_data["correct"] / run_data["total"] * 100
                    run_data["checkpoint"] = current_index
                    existing_data["all_results"][run_id] = run_data

                    # 临时保存
                    temp_file = result_file + ".tmp"
                    os.makedirs(os.path.dirname(temp_file), exist_ok=True)
                    with open(temp_file, "w") as f:
                        json.dump(existing_data, f, indent=2)
                    os.replace(temp_file, result_file)
                    print(
                        f"Checkpoint saved at {current_index + 1}/{len(qids)} | "
                        f"当前准确率: {run_data['acc']:.2f}% ({run_data['correct']}/{run_data['total']})"  # 新增输出
                    )

            # 完成当前运行
            run_data["status"] = "completed"
            run_data["acc"] = run_data["correct"] / len(qids) * 100
            run_data["total"] = len(qids)
            del run_data["checkpoint"]

            # 更新总数据
            completed_runs.append(run_data["acc"])
            existing_data["completed_runs"] = completed_runs
            existing_data["all_results"][run_id] = run_data

            # 计算平均准确率
            if completed_runs:
                existing_data["average_accuracy"] = sum(completed_runs) / len(completed_runs)

            # 最终保存
            with open(result_file, "w") as f:
                json.dump(existing_data, f, indent=2)
            print(f"Run {run_idx + 1} completed. Accuracy: {run_data['acc']:.2f}%")

        except Exception as e:
            emergency_path = result_file + ".emergency"
            os.makedirs(os.path.dirname(emergency_path), exist_ok=True)  # 新增目录创建
            print(f"Error occurred: {str(e)}")
            print("Saving emergency checkpoint...")
            with open(emergency_path, "w") as f:
                json.dump(existing_data, f)
            raise

    # 最终输出
    print("\n=== Final Results ===")
    print(f"Completed runs: {len(completed_runs)}/{total_runs}")
    if completed_runs:
        print(f"Average accuracy: {existing_data['average_accuracy']:.2f}%")
        print("Individual runs:")
        for idx, acc in enumerate(completed_runs):
            print(f"Run {idx + 1}: {acc:.2f}%")
