import torch
import pickle
import json
import os
from datetime import datetime
from tqdm import tqdm
import pandas as pd

from GIB_MCS_wo_mcs_feature import GIBErrorDetectionModel, build_graph, load_evaluation_results, is_correct_graph

def load_trained_model(model_path):
    """Rebuild a ``GIBErrorDetectionModel`` from a saved checkpoint.

    The checkpoint is expected to be a mapping holding ``'model_config'``
    (with ``'embedding_dim'`` and ``'hidden_dim'``) and
    ``'model_state_dict'``. Tensors are mapped onto CUDA when available,
    otherwise onto the CPU.

    Args:
        model_path: Filesystem path passed to ``torch.load``.

    Returns:
        Tuple ``(model, device)``: the model moved to ``device`` and switched
        to eval mode, and the ``torch.device`` it lives on.
    """
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    # NOTE(review): torch.load unpickles data; only load checkpoints you trust.
    ckpt = torch.load(model_path, map_location=target_device)

    cfg = ckpt['model_config']
    net = GIBErrorDetectionModel(
        embedding_dim=cfg['embedding_dim'],
        hidden_dim=cfg['hidden_dim'],
    )
    net.load_state_dict(ckpt['model_state_dict'])
    net.to(target_device)
    net.eval()  # inference only: disable dropout / batch-norm updates

    print(f"Model loaded from {model_path}")
    return net, target_device


def inference_one_graph(model, graph_info, problem_id, graph_id, is_correct, threshold=0.3):
    """Run ``model`` on one graph and collect a per-edge mask score report.

    Args:
        model: Trained model; called as ``model(graph_info, [])`` and expected
            to return a mapping with ``'edge_masks'`` (indexable, one scalar
            tensor per edge) and ``'edge_ids'`` (sequence of ``(u, v)`` pairs).
        graph_info: Per-graph record. Reads ``'structure_str'`` (required) and,
            optionally, ``'embeddings'`` and ``'reasoning_string'``.
        problem_id: Identifier of the problem the graph belongs to.
        graph_id: Identifier of the graph within the problem.
        is_correct: Correctness flag; stored as ``int(is_correct)``.
        threshold: Edges whose mask score is ``>= threshold`` get
            ``prediction == 1``, all others ``0``.

    Returns:
        dict with ``problem_id``, ``graph_id``, ``is_correct`` (0/1),
        ``edge_details`` (one entry per edge, in ``edge_ids`` order),
        ``num_edges`` and ``success``.
    """
    # Only the forward pass needs autograd disabled; the post-processing below
    # just reads scalar values.
    with torch.no_grad():
        # NOTE(review): the second positional argument is an empty list; its
        # meaning is defined by the model's forward() — confirm.
        outputs = model(graph_info, [])

    edge_masks = outputs['edge_masks']
    edge_ids = outputs['edge_ids']
    G = build_graph(graph_info['structure_str'])
    # Tolerate a missing key *and* an explicit None value.
    embeddings = graph_info.get("embeddings") or {}

    edge_details = []
    for i, (u, v) in enumerate(edge_ids):
        # get_edge_data returns None when (u, v) is not an edge of G; fall back
        # to the synthetic "u->v" label instead of crashing.
        edge_data = G.get_edge_data(u, v) or {}
        edge_label = edge_data.get("label", f"{u}->{v}")

        # NOTE(review): when the label has an embedding entry the label itself
        # is used as the edge text; otherwise the whole reasoning string is used.
        # This mirrors the original behavior — confirm it is intended.
        if embeddings.get(edge_label) is None:
            edge_text = graph_info.get("reasoning_string", "")
        else:
            edge_text = str(edge_label)

        score = float(edge_masks[i].item())
        edge_details.append({
            "edge_index": i,
            "edge_nodes": [u, v],
            "edge_label": edge_label,
            "edge_text": edge_text,
            "mask_score": score,
            "prediction": 1 if score >= threshold else 0,
        })

    return {
        "problem_id": problem_id,
        "graph_id": graph_id,
        "is_correct": int(is_correct),
        "edge_details": edge_details,
        "num_edges": len(edge_ids),
        "success": True,
    }


def batch_inference(model_path, pkl_path, eval_results_path,
                    output_dir="./inference_results_ablation", threshold=0.3):
    """Score every graph in a pickle file and write the results to JSON.

    The pickle is expected to map ``problem_id -> {graph_id: graph_info}``.
    Each graph is labelled with ``is_correct_graph`` against the evaluation
    results and scored with ``inference_one_graph``. All per-graph results
    are written to ``<output_dir>/results_<YYYYmmdd_HHMMSS>.json``.

    Args:
        model_path: Checkpoint path passed to ``load_trained_model``.
        pkl_path: Path to the pickled graph data.
        eval_results_path: Path passed to ``load_evaluation_results``.
        output_dir: Directory for the JSON output; created if missing.
        threshold: Edge prediction threshold forwarded to each graph.

    Returns:
        List of per-graph result dicts, in problem then graph iteration order.
    """
    os.makedirs(output_dir, exist_ok=True)

    model, _unused_device = load_trained_model(model_path)

    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    with open(pkl_path, "rb") as pkl_file:
        graphs_by_problem = pickle.load(pkl_file)

    evaluation = load_evaluation_results(eval_results_path)

    problem_ids = list(graphs_by_problem)
    print(f"Total problems: {len(problem_ids)}")

    results = []
    for problem_id in tqdm(problem_ids, desc="Inference"):
        for graph_id, graph_info in graphs_by_problem[problem_id].items():
            correct = is_correct_graph(graph_info, problem_id, graph_id, evaluation)
            results.append(
                inference_one_graph(model, graph_info, problem_id, graph_id,
                                    correct, threshold)
            )

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(output_dir, f"results_{stamp}.json")
    with open(out_path, "w", encoding="utf-8") as out_file:
        json.dump(results, out_file, indent=2, ensure_ascii=False)
    print(f"Results saved to {out_path}")

    return results


if __name__ == "__main__":
    # The original assignments had no right-hand side (a SyntaxError); take the
    # paths from the command line instead.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run edge-mask inference over pickled graphs and save JSON results."
    )
    parser.add_argument("model_path", help="Path to the trained model checkpoint.")
    parser.add_argument("pkl_path", help="Path to the pickled graph data.")
    parser.add_argument("eval_results_path", help="Path to the evaluation results.")
    parser.add_argument(
        "--output-dir",
        default="./inference_results_ablation",
        help="Directory for the JSON results (default: %(default)s).",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.3,
        help="Mask score at or above which an edge is predicted as 1 (default: %(default)s).",
    )
    args = parser.parse_args()

    batch_inference(
        args.model_path,
        args.pkl_path,
        args.eval_results_path,
        output_dir=args.output_dir,
        threshold=args.threshold,
    )
