# evaluator.py
import os
import re
import torch
import pickle
import numpy as np
import networkx as nx
from tqdm import tqdm

def encode(s, stoi):
    """Encode the decimal-number tokens of ``s`` into vocabulary ids.

    Only maximal runs of digits are treated as tokens; every other character
    (spaces, ':' and so on) is ignored.  A number that is not a key of
    ``stoi`` raises the lookup error of ``stoi`` (KeyError for a dict).
    """
    number_tokens = re.findall(r'\d+', s)
    return list(map(stoi.__getitem__, number_tokens))

def decode(l, itos):
    """Map each token id in ``l`` through ``itos`` and join the symbols with single spaces."""
    symbols = [itos[token_id] for token_id in l]
    return ' '.join(symbols)

def find_third_number_position(number_string):
    """Return the character offset at which the third whitespace-separated token starts.

    The offset is the combined length of the first two tokens plus one
    separator after each of them.  It is therefore exact only when tokens are
    separated by a single character and the string has no leading whitespace.
    Two separators are always counted, even if fewer than two tokens exist.
    """
    leading_tokens = number_string.split()[:2]
    separators = 2  # one separator assumed after each of the two leading tokens
    return separators + sum(map(len, leading_tokens))

def check_path(G, gen_str, reachability):
    """Validate a generated path string against graph ``G``.

    ``gen_str`` is tokenised into digit runs.  Token 0 is the source and
    token 1 the target; token 2 must repeat the source, and tokens 3.. form
    the walk.  Each step must stay within ``len(reachability)``, follow an
    edge of ``G`` (compared on the string tokens) and, until the target is
    hit, satisfy ``reachability[target][node] != 0``.  The walk must end
    exactly on its first visit to the target.

    Returns:
        '' for a valid path, otherwise a short error description.
    """
    tokens = re.findall(r'\d+', gen_str)
    if len(tokens) < 3:
        return 'too short'

    source, target = tokens[0], tokens[1]
    if tokens[2] != source:
        return 'incorrect start'

    last_idx = len(tokens) - 1
    for idx in range(3, len(tokens)):
        prev_tok, cur_tok = tokens[idx - 1], tokens[idx]
        cur_node = int(cur_tok)
        if not 0 <= cur_node < len(reachability):
            return f'wrong syntax {cur_node}'
        if not G.has_edge(prev_tok, cur_tok):
            return f'non-existence edge {prev_tok, cur_tok}'
        if cur_tok == target:
            # Reaching the target is only correct as the final token.
            return '' if idx == last_idx else f'stop error {cur_tok}'
        if reachability[int(target)][cur_node] == 0:
            return f'cant reach from {cur_tok}'

    return 'incorrect end'

def check_path_unreachable(G, gen_str, gt, unreachable_flag=False, reachability=None):
    """Check a generated path while honouring the unreachable marker ``'x'``.

    If the generation contains ``'x'`` within its first four tokens (i.e. it
    has fewer than four digit/'x' tokens), it is judged only on whether the
    ground truth ``gt`` also marks the pair unreachable.  If ``gt`` contains
    ``'x'`` but the generation does not, the pair is reported as unreachable.
    Otherwise the generation is validated with ``check_path``.

    Args:
        G: graph passed through to ``check_path``.
        gen_str: generated string.
        gt: ground-truth string; only searched for the ``'x'`` marker.
        unreachable_flag: currently unused; kept for interface compatibility.
        reachability: reachability matrix forwarded to ``check_path``.  It is
            required whenever the path has to be validated.

    Returns:
        '' when the generation is correct, otherwise an error description.

    Raises:
        TypeError: if the path must be validated but ``reachability`` is None.
    """
    # Unreachable-marker handling is simplified here; extend if the task needs more.
    tokens = re.findall(r'\d+|x', gen_str)
    if 'x' in tokens and len(tokens) < 4:
        return '' if 'x' in gt else 'reachable pair'
    if 'x' in gt and 'x' not in gen_str:
        return 'unreachable pair'
    if reachability is None:
        # check_path cannot run without the matrix; fail with a clear message
        # instead of the opaque missing-argument error the old call produced.
        raise TypeError('check_path_unreachable() requires `reachability` '
                        'to validate a reachable path')
    return check_path(G, gen_str, reachability)

def load_test_data(data_path, meta, args):
    """Load and encode the evaluation prompts for ``args.type_data``.

    Args:
        data_path: dataset directory containing the ``*.txt`` split files.
        meta: metadata dict; reads 'stoi', 'itos', 'simple_format' and 'block_size'.
        args: namespace providing ``type_data`` and ``test_num`` (None = all).

    Returns:
        Tuple ``(encode_texts, ground_truth, itos, block_size)``:

        - ``encode_texts`` is a LongTensor of encoded prompts, one row per
          non-blank input line.
        - ``ground_truth`` holds the corresponding unmodified lines.

        Both are truncated to ``args.test_num`` entries when that is set.
    """
    stoi = meta['stoi']
    itos = meta['itos']
    simple_format = meta['simple_format']
    block_size = meta['block_size']

    # Choose the split file: training set, a degree-specific test set, or a named file.
    typedata = args.type_data
    if typedata == 'train':
        f_path = f'{data_path}/simple_train.txt'
    elif typedata in ['0', '1', '2']:
        f_path = f'{data_path}/test_degree{typedata}.txt'
    else:
        f_path = f'{data_path}/{typedata}.txt'

    encode_texts, ground_truth = [], []
    with open(f_path, encoding='gbk') as f:
        for line in f:
            # A blank line (e.g. a trailing newline) would encode to an empty
            # row and make torch.tensor below fail on ragged input.
            if not line.strip():
                continue
            if not simple_format:
                prompt = line.split(':')[0] + ':'
            else:
                # Keep the first two numbers (with their separators) and
                # re-append the first number, giving "source target source".
                pos = find_third_number_position(line)
                prompt = line[:pos] + line[:line.find(' ')]
            encode_texts.append(encode(prompt, stoi))
            # Store the full, unmodified line as the reference, not the prompt.
            ground_truth.append(line)

    encode_texts = torch.tensor(encode_texts, dtype=torch.long)
    test_num = args.test_num if args.test_num is not None else len(encode_texts)
    encode_texts = encode_texts[:test_num]
    # Keep ground_truth aligned with the truncated prompts.
    ground_truth = ground_truth[:test_num]

    return encode_texts, ground_truth, itos, block_size

def evaluate_model(model, args, device, log_file=None, step=None):
    """
    Evaluate the model on the configured test split and return its accuracy.

    Reads ``data/{args.dataset}/{args.data_dir}/simple_meta.pkl``, the graph
    ``path_graph.graphml`` and the matrix ``true_reach_matrix.npy``.  It then
    generates a completion for every test prompt and validates the first
    line of each completion with ``check_path``.  Optionally, it writes the
    per-sample results to ``args.out_dir``.

    The model's train/eval mode is restored on exit, so this is safe to call
    in the middle of a training loop.

    Args:
        model: model exposing ``eval()`` and
            ``generate(idx, max_new_tokens, temperature=..., top_k=...)``.
        args: namespace read for ``dataset``, ``data_dir``, ``batch_size``,
            ``temperature``, ``write_result``, ``type_data``, ``out_dir`` and,
            when ``step`` is None and results are written, ``ckpt_iter``.
            ``load_test_data`` additionally reads ``test_num``.
        device: device the prompt batches are moved to.
        log_file: currently unused (the summary logging below is disabled).
        step: training step used to name the prediction file; falls back to
            ``args.ckpt_iter`` when None.

    Returns:
        Fraction of prompts whose generated path passes ``check_path``
        (0.0 when there are no prompts).
    """
    # Remember the caller's mode; otherwise evaluating mid-training would leave
    # the model in eval mode (dropout disabled) for the rest of training.
    was_training = getattr(model, 'training', False)
    model.eval()
    try:
        with torch.no_grad():
            # Load metadata
            data_path = f'data/{args.dataset}/{args.data_dir}'
            meta_path = f'{data_path}/simple_meta.pkl'
            with open(meta_path, 'rb') as f:
                meta = pickle.load(f)

            # Load graph structure and reachability matrix
            path_graph = nx.read_graphml(f'{data_path}/path_graph.graphml')
            reachability = np.load(f'{data_path}/true_reach_matrix.npy')

            # Load test prompts; the returned block_size is used as the generation budget.
            encode_texts, _ground_truth, itos, max_new_tokens = load_test_data(data_path, meta, args)
            batch_size = args.batch_size
            test_num = len(encode_texts)
            batch_num = (test_num + batch_size - 1) // batch_size

            total_cnt, correct_cnt = 0, 0
            all_results = []

            for i in tqdm(range(batch_num), desc="Evaluating", leave=False):
                x = encode_texts[i * batch_size:(i + 1) * batch_size].to(device)
                y = model.generate(x, max_new_tokens, temperature=args.temperature, top_k=len(itos))
                # Keep only the first line of each decoded sample.
                y_pred = [decode(y[t].tolist(), itos).split('\n')[0] for t in range(len(x))]

                for item in y_pred:
                    total_cnt += 1
                    # Simplification: unreachable pairs are not handled here
                    # (check_path_unreachable covers that case if needed).
                    feedback = check_path(path_graph, item, reachability)
                    if feedback == '':
                        correct_cnt += 1
                    symbol = ' ' + feedback if feedback else ''
                    content = item.split('[PAD]', 1)[0].strip()
                    all_results.append(content + symbol)

            accuracy = correct_cnt / total_cnt if total_cnt > 0 else 0.0

            # Write the prediction file (optional)
            if args.write_result:
                # step == 0 is a valid step, so test for None rather than truthiness.
                ckpt_tag = step if step is not None else args.ckpt_iter
                pred_file = f'pred_{args.type_data}_{ckpt_tag}_tem_{args.temperature}.txt'
                os.makedirs(args.out_dir, exist_ok=True)
                with open(os.path.join(args.out_dir, pred_file), 'w') as f:
                    f.writelines(line + '\n' for line in all_results)

            # # Write the summary line to log_file or stdout (optional, disabled)
            # result_line = f"Total {test_num}, Correct {correct_cnt} ({accuracy:.4f})"
            # if log_file:
            #     with open(log_file, 'a') as f:
            #         f.write(f"[Step {step}] {result_line}\n")
            # else:
            #     print(result_line)

            # # Append to the global results/ file (mirrors the original logic, disabled)
            # result_file = f'results/{args.dataset}.txt'
            # if args.fix_att:
            #     result_file = f'results/{args.dataset}_fix_att.txt'
            # elif args.result_name:
            #     result_file = f'results/{args.dataset}_{args.result_name}.txt'

            # os.makedirs('results', exist_ok=True)
            # with open(result_file, 'a') as f:
            #     model_id = f"From {args.out_dir}{args.ckpt_iter}_ckpt.pt" if hasattr(args, 'ckpt_iter') else f"Step {step}"
            #     f.write(f"{model_id}, Test data degree {args.type_data}\n")
            #     f.write(result_line + ".\n")

            return accuracy
    finally:
        if was_training:
            model.train()