import torch
import numpy as np
import matplotlib.pyplot as plt
from gym.envs.mujoco.ant_v3 import AntEnv
import ipdb
import os



def state_to_action(old_state, agent, agent_type, task_params, device, obs_dim=27):
    """Query a policy for the action to take from observation ``old_state``.

    Only the first ``obs_dim`` entries of ``old_state`` are fed to the policy,
    as a float32 batch of size one on ``device``.

    Args:
        old_state: NumPy observation from the environment (assumed 1-D).
        agent: For ``'agent'``, an object whose ``get_action(st, test=True)``
            returns a 3-tuple whose first element is the action. For
            ``'imitation'``, a callable mapping the observation batch to an
            action.
        agent_type: ``'agent'`` or ``'imitation'``.
        task_params: 1-D array concatenated after the observation for
            ``'agent'``. It is ignored for ``'imitation'``.
        device: Torch device the input tensor is moved to.
        obs_dim: Number of leading observation entries to use (default 27).

    Returns:
        The action produced by the policy for the single-observation batch.

    Raises:
        ValueError: If ``agent_type`` is not recognised. Previously this fell
            through and surfaced as an ``UnboundLocalError``.
    """
    obs = old_state[:obs_dim]
    if agent_type == 'agent':
        # The RL agent is conditioned on the task/physics parameters as well.
        st = np.expand_dims(np.concatenate((obs, task_params)), 0)
        st = torch.from_numpy(st).float().to(device)
        action, _, _ = agent.get_action(st, test=True)
        return action
    if agent_type == 'imitation':
        st = torch.from_numpy(obs).float().to(device).unsqueeze(0)
        return agent(st)
    raise ValueError(
        f"unknown agent_type {agent_type!r}; expected 'agent' or 'imitation'"
    )

def to_numpy(arr):
    """Convert a torch tensor to a NumPy array outside the autograd graph.

    The tensor is detached from autograd and moved to the CPU before
    conversion. For a tensor already on the CPU, the returned array shares
    memory with ``arr``.
    """
    # `.detach()` is the supported replacement for the legacy `.data`
    # attribute. Detaching before `.cpu()` keeps the copy out of autograd.
    return arr.detach().cpu().numpy()

def evaluate_reward_ant(dataset, agent, agent_type, physics_params, device, seed,
                        T_policy=100, reward_norm=1):
    """Roll out ``agent`` in an Ant environment built from ``physics_params``.

    The dataset-specific ``change_xml`` helper writes the physics parameters
    into ``ant_tmp_replace.xml``. An ``AntEnv`` is created from that file,
    seeded, and stepped with the policy for exactly ``T_policy`` steps.

    Args:
        dataset: ``'ant'`` selects the ``helper.ant_ppo.rl_code`` helpers. Any
            other value selects ``helper.ant_ppo.rl_code_8_legs``.
        agent: Policy forwarded to ``state_to_action``.
        agent_type: ``'agent'`` or ``'imitation'``; see ``state_to_action``.
        physics_params: Torch tensor of physics parameters. For ``'ant'``,
            only the first two entries are passed to ``change_xml``. The whole
            array is also passed to the policy as ``task_params``.
        device: Torch device used for policy inputs.
        seed: Seed passed to ``env.seed``.
        T_policy: Number of environment steps to roll out (default 100).
        reward_norm: Divisor applied to every per-step reward (default 1).

    Returns:
        A dict with the following keys:
        ``'avg_reward'`` -- mean normalised per-step reward;
        ``'final_reward'`` -- reward of the last step;
        ``'max_reward'`` -- largest per-step reward;
        ``'params'`` -- the ``physics_params`` argument, unchanged.

    Raises:
        ValueError: If ``T_policy`` is smaller than 1, since the reward
            statistics are undefined for an empty rollout.
    """
    if T_policy < 1:
        raise ValueError(f"T_policy must be >= 1, got {T_policy}")

    params = to_numpy(physics_params)
    # The helper packages are project-local, so they are imported lazily
    # for whichever model variant is requested.
    if dataset == 'ant':
        from helper.ant_ppo.rl_code.ant_utils import change_xml
        ant_xml_dir = os.path.join(os.getcwd(), 'helper', 'ant_ppo', 'rl_code', 'ant_xml')
        xml_file = os.path.join(ant_xml_dir, 'ant_tmp_replace.xml')
        change_xml(xml_file, ant_xml_dir, params[0], params[1])
    else:
        from helper.ant_ppo.rl_code_8_legs.ant_utils import change_xml
        ant_xml_dir = os.path.join(os.getcwd(), 'helper', 'ant_ppo', 'rl_code_8_legs', 'ant_xml')
        xml_file = os.path.join(ant_xml_dir, 'ant_tmp_replace.xml')
        change_xml(xml_file, ant_xml_dir, params)

    env = AntEnv(xml_file)
    rewards = []
    try:
        env.seed(seed)
        old_state = env.reset()
        for _ in range(T_policy):
            action = state_to_action(old_state, agent, agent_type, params, device)
            # NOTE(review): `done` is ignored, so the rollout keeps stepping
            # past episode termination for a fixed horizon. Confirm this is
            # intended.
            old_state, reward, _done, _info = env.step(to_numpy(action)[0])
            rewards.append(reward / reward_norm)
    finally:
        # Release the simulator/viewer even if the policy or step raises.
        env.close()

    return {
        'avg_reward': np.mean(rewards),
        'final_reward': rewards[-1],
        # Bug fix: previously np.max(reward) took only the last step's reward.
        'max_reward': np.max(rewards),
        'params': physics_params,
    }
