from data_loader import math_500_loader, math_n_loader, math_train_loader, gsm8k_test_loader, gsm8k_train_loader, collegemath_test_loader, collegemath_train_loader, aime_loader, hardmath2_loader, usamo_loader
from models.answer_verifier import AnswerVerifier, StepVerifier
from models.backtranslator import BackTranslator

from langgraph.prebuilt import create_react_agent
from langchain_deepseek import ChatDeepSeek
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import PromptTemplate
from typing import Annotated
from tqdm import tqdm


from itertools import repeat
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
from utils import timeout_handler
import os
import sys

import json
import pickle

from prover.utils import get_datetime, load_config, AttrDict
from prover.lean.verifier import Lean4ServerScheduler

from hermes import HermesReasoner

import operator
from typing import Annotated, Any, Dict, List, Literal, Sequence, TypedDict

from langchain_core.messages import BaseMessage
from langgraph.managed import RemainingSteps

# Passed as LangGraph's `recursion_limit` for every agent run (see run_agent).
MAX_RECURSION_LIMIT = 100
LEANSERVER_CONFIG = 'configs/lean4_server.py'
lean4server_cfg = load_config(LEANSERVER_CONFIG)

# Selects the prompt template file prompts/<setting>_cot.txt.
setting = 'zero_shot'

# Prompt template; its '<question>' placeholder is filled in attempt_problem_with_lean.
# Read as UTF-8 explicitly so non-ASCII math notation does not depend on the locale.
with open(f'prompts/{setting}_cot.txt', 'r', encoding='utf-8') as f:
    prompt_schema = f.read()

# setup.json provides the translator/prover entries handed to HermesReasoner and
# load_config (presumably config file paths — confirm against setup.json).
# JSON is UTF-8 by spec (RFC 8259), so decode it as such rather than via the locale.
with open('setup.json', 'r', encoding='utf-8') as f:
    setup = json.load(f)
TRANSLATOR_CONFIG = setup['translator']
PROVER_CONFIG = setup['prover']

#%% Init Lean server
# Shared Lean 4 scheduler used by every HermesReasoner built in predict_helper.
# Each setting is read from the Lean server config, falling back to the
# hard-coded default when the key is absent.
scheduler = Lean4ServerScheduler(
    max_concurrent_requests=lean4server_cfg.get("lean_max_concurrent_requests", 4),
    timeout=lean4server_cfg.get("lean_timeout", 120),
    memory_limit=lean4server_cfg.get("lean_memory_limit", 10),
    name=lean4server_cfg.get("name", 'test-server'),
)
    

def run_agent(prompt, agent):
    """Run *agent* once on a single user message.

    Args:
        prompt: Text sent as the sole user message.
        agent: A runnable exposing ``invoke(input, config)`` (here, the graph
            built by ``create_react_agent``).

    Returns:
        Whatever ``agent.invoke`` returns, unchanged.
    """
    # The config dict caps how many graph steps LangGraph may take for this run.
    run_config = {"recursion_limit": MAX_RECURSION_LIMIT}
    user_turn = {"role": "user", "content": prompt}
    return agent.invoke({"messages": [user_turn]}, run_config)

def attempt_problem_with_lean(problem, agent):
    """Fill the prompt template with *problem* and run *agent* on it.

    Args:
        problem: Problem text substituted for every '<question>' placeholder
            in the loaded prompt template (must be a str).
        agent: Agent passed through to ``run_agent``.

    Returns:
        The result of ``run_agent``.
    """
    filled_prompt = prompt_schema.replace('<question>', problem)
    return run_agent(filled_prompt, agent)


def predict_helper(data_sample, embedding_model, chat_model=None):
    """Build a Lean-backed ReAct agent and run it on one problem under a timeout.

    A fresh HermesReasoner (sharing the module-level Lean ``scheduler``) is
    created for each call and given to the agent as its only tool.

    Args:
        data_sample: Problem text; substituted into the prompt template, so it
            must be a str.
        embedding_model: Embedding model forwarded to HermesReasoner.
        chat_model: LLM that drives the agent. When omitted, falls back to the
            module-level ``llm`` created by the ``__main__`` block, which keeps
            the original call signature working.

    Returns:
        Whatever ``timeout_handler`` returns for the agent run (timeout
        behavior and units are defined by ``utils.timeout_handler``).

    Raises:
        NameError: if no ``chat_model`` is given and no module-level ``llm``
            exists (e.g. when this module is imported rather than run).
    """
    if chat_model is None:
        # `llm` is only defined when this file runs as __main__; fail fast,
        # with a clear message, before constructing the reasoner.
        try:
            chat_model = llm
        except NameError:
            raise NameError(
                "predict_helper needs a chat model: pass chat_model=... "
                "or define a module-level `llm`."
            ) from None

    reasoner = HermesReasoner(scheduler=scheduler,
                              translator_config=TRANSLATOR_CONFIG,
                              prover_config=PROVER_CONFIG,
                              embedding_model=embedding_model,
                              user_id='abc123')

    prompt = "You are a helpful assistant who is proficient in mathematics and uses verification tools to make sure every step is correct."
    agent = create_react_agent(
        model=chat_model,
        tools=[reasoner],
        prompt=prompt
    )

    return timeout_handler(
        attempt_problem_with_lean,
        args=(data_sample, agent),
        timeout_duration=1500,
    )


if __name__ == '__main__':
    # The Lean scheduler is created at import time; close it only when this
    # file runs as a script, and always (even if a step below raises).
    # Importers that use the module-level `scheduler` are responsible for
    # closing it themselves.
    try:
        # Loaded eagerly; the values are not otherwise used below.
        translator_cfg = load_config(TRANSLATOR_CONFIG)
        prover_cfg = load_config(PROVER_CONFIG)

        #%% Init Agent - DeepSeek reasoner model
        # Requires the DeepSeek API key in the environment (DEEPSEEK_API_KEY).
        # NOTE(review): DeepSeek documents sampling params such as temperature as
        # having no effect on deepseek-reasoner — verify before tuning it.
        model = 'deepseek-reasoner'
        llm = ChatDeepSeek(
            model=model,
            temperature=0.95,
            max_tokens=8192,
            timeout=300,
            max_retries=20,
        )

        # Normalized embeddings, computed on the first CUDA device.
        embedding_model = HuggingFaceEmbeddings(model_name="Qwen/Qwen3-Embedding-0.6B",
                                                model_kwargs={'device': 'cuda:0'},
                                                encode_kwargs={"normalize_embeddings": True})

        problem = 'Prove that 2+2=5'
        answer = predict_helper(problem, embedding_model=embedding_model)

        print(answer)
    finally:
        scheduler.close()