import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


import json
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from agentenv.controller import Agent, Evaluator
from agentenv.envs import WebshopTask

# Base checkpoint that the adapter is later attached to.
# No tokenizer is loaded here: the one the agent uses is loaded further down
# from the adapter checkpoint, so loading another from MODEL_PATH was dead work.
MODEL_PATH = "path/to/agentlm-7b"


from peft import PeftModel

# Adapter checkpoint directory; the tokenizer is loaded from the same place.
ADAPTER_PATH = "path/to/cotri-model"

# Load the base causal LM, then wrap it with the PEFT adapter weights.
edited_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto")
model = PeftModel.from_pretrained(edited_model, ADAPTER_PATH, device_map="auto")
model.eval()  # evaluation mode: this script only runs inference

# Tokenizer setup: reuse EOS as the padding token and pad on the left.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.eos_token_id

# Settings handed to the WebShop environment client.
webshop_client_args = {
    "env_server_base": "http://127.0.0.1:36001",
    "data_len": 2,
    "timeout": 300,
}
webshop_task = WebshopTask(client_args=webshop_client_args, n_clients=1)

# Pair the agent (model + tokenizer) with the single WebShop task.
evaluator = Evaluator(Agent(model, tokenizer), [webshop_task])

# Fall back to EOS for padding when the tokenizer has no pad token.
if tokenizer.pad_token_id is not None:
    pad_token_id = tokenizer.pad_token_id
else:
    pad_token_id = tokenizer.eos_token_id

# Decoding settings: sampling disabled, at most 128 new tokens per turn.
generation_config = GenerationConfig(
    do_sample=False,
    max_new_tokens=128,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=pad_token_id,
)

# NOTE(review): 200 task indices are requested here while the WebShop client
# is configured with data_len=2 -- confirm the two are meant to differ.
exps = evaluator.eval(
    generation_config=generation_config,
    max_rounds=7,
    idxs=list(range(200)),
)

# Top-level report; the per-experience records are appended afterwards.
output_data = {
    "score": exps.score,
    "success": exps.success,
    "experiences": [],
}


# Flatten every evaluated experience into a plain dict and append it to the
# report under output_data["experiences"].
for idx, exp in enumerate(exps.experiences):
    # Keep only the speaker and the text of each conversation turn.
    conversation = [
        {"from": message["from"], "value": message["value"]}
        for message in exp.conversation
    ]

    # `step_infos` may be absent or None on an experience; falsy entries
    # inside the list are skipped as well.
    step_infos = [
        info for info in (getattr(exp, "step_infos", []) or []) if info
    ]

    # Steps flagged with "randomization_applied".
    feedback_events = [
        {
            "original_search": info.get("original_search"),
            "effective_search": info.get("effective_search"),
            "randomization_mode": info.get("randomization_mode"),
            "feedback_text": info.get("feedback_text"),
            "random_count_in_episode": info.get("random_count_in_episode"),
        }
        for info in step_infos
        if info.get("randomization_applied")
    ]

    # Steps flagged with "null_feedback_applied". A single step can carry
    # both flags and then appears in both lists.
    null_feedback_events = [
        {
            "current_conversation_round": info.get("current_conversation_round"),
            "feedback_text": info.get("feedback_text"),
            "original_action": info.get("original_action"),
            "effective_action": info.get("effective_action"),
        }
        for info in step_infos
        if info.get("null_feedback_applied")
    ]

    # success is 1 only for a reward equal to 1, 0 for any other reward,
    # and None when the experience carries no reward at all.
    reward = getattr(exp, "reward", None)
    if reward is None:
        success_item = None
    else:
        success_item = 1 if reward == 1 else 0

    output_data["experiences"].append({
        "index": idx,
        "conversation": conversation,
        "reward": reward,
        "success": success_item,
        "score": reward,  # per-experience score is the raw reward
        "random_feedback_triggered": bool(feedback_events),
        "random_feedback_events": feedback_events,
        "null_feedback_triggered": bool(null_feedback_events),
        "null_feedback_events": null_feedback_events,
    })


# Serialize before opening the file: open(..., "w") truncates immediately, so
# a serialization error raised afterwards would wipe any previous results and
# leave a partial file behind.
serialized_output = json.dumps(output_data, ensure_ascii=False, indent=2)
with open("evaluation_cotri.json", "w", encoding="utf-8") as f:
    f.write(serialized_output)

