from url_benchmark.agent.q_learning import QLearningAgent
import torch
import numpy as np
import time

def evaluate(env, agent, replay_loader, inf_logger, eval_type, work_dir, step, index):
    """Compare the agent's greedy policy with a tabular Q-learning reference.

    Procedure:
    1. Sample an evaluation task from the environment.
    2. Solve that task with a tabular ``QLearningAgent`` and plot its
       Q-function.
    3. Run the agent's inference for the task and time the call.
    4. Query the agent's Q-values for every state returned by the
       environment, plot the resulting value function, and count for how
       many states the greedy action is contained in the tabular solver's
       action entry for that state.

    Args:
        env: Environment object; must provide ``sample_eval_task``,
            ``get_state_list``, ``get_obs_from_state`` and
            ``plot_v_function``.
        agent: Agent under evaluation; must provide ``inference``,
            ``q_function_inference`` and ``cfg.device``.
        replay_loader: Forwarded unchanged to ``agent.inference``.
        inf_logger: Forwarded unchanged to ``agent.inference``.
        eval_type: Task type string, forwarded to ``env.sample_eval_task``
            and used to label the plots.
        work_dir: Output location forwarded to both plotting calls.
        step: Training step, used only to label the plots.
        index: Task index, forwarded to ``env.sample_eval_task`` and used to
            label the plots.

    Returns:
        A ``(num_pos, num_neg, inf_time)`` tuple: the number of states whose
        greedy action is in the reference action entry, the number of states
        where it is not, and the duration of ``agent.inference`` in seconds.
    """
    # 1. Sample the evaluation task. The third returned element (the reward
    #    array) is not needed here.
    pos_goal, neg_goal, _, reward_fn = env.sample_eval_task(eval_type, index)

    # 2. Tabular reference solution for the sampled reward function.
    q_agent = QLearningAgent(env, reward_fn)
    q_agent.solve()
    q_agent.plot_q_function(work_dir=work_dir, step=step, task_str=eval_type + f"_{index}")

    # 3. Agent inference. perf_counter() is monotonic, so the measured
    #    interval cannot be distorted by system clock adjustments.
    # NOTE(review): if agent.inference launches asynchronous GPU work, this
    # only measures the time until the call returns -- confirm whether a
    # device synchronisation is wanted here.
    start_time = time.perf_counter()
    z = agent.inference(replay_loader, inf_logger, pos_goal, neg_goal, reward_fn)
    inf_time = time.perf_counter() - start_time

    # 4. Batch one observation per state and evaluate the agent's Q-function.
    state_list = env.get_state_list()
    obs_batch = torch.stack(
        [torch.tensor(env.get_obs_from_state(state)) for state in state_list],
        dim=0,
    ).to(agent.cfg.device)
    q_list = agent.q_function_inference(obs_batch, pos_goal, neg_goal, z).detach()
    v_list = torch.max(q_list, dim=1)[0]
    a_list = torch.argmax(q_list, dim=1).cpu()
    env.plot_v_function(
        work_dir,
        obs_batch.cpu(),
        v_list,
        a_list,
        f"training_step_{step}_{eval_type}_{index}_v_function",
    )

    # Count agreement with the reference policy. The reference table is keyed
    # with the first two state components swapped
    # (presumably (row, col) vs. (x, y) ordering -- verify against QLearningAgent).
    greedy_actions = a_list.tolist()
    num_pos = sum(
        1
        for state, action in zip(state_list, greedy_actions)
        if action in q_agent.actions[(state[1], state[0])]
    )
    num_neg = len(state_list) - num_pos

    print('Positive: ', num_pos, ' | Negative: ', num_neg, ' | Inference Time: ', inf_time)
    return num_pos, num_neg, inf_time


        

