import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import transformers
from envs import generate_darkroom_env
from trajectory_generation import rollin_mdp, generate_preference_histories_from_envs
from utils import convert_to_tensor

def online_evaluate_policy(goals, policy, horizon, n_episodes=5):
    """
    Roll the policy out online in one darkroom environment per goal.

    For every goal a darkroom environment is generated (``dim=10``) and the
    policy context is reset once. Each transition the policy experiences
    (state, action, next state, reward) is appended back into its context, so
    the context keeps growing across all ``n_episodes`` episodes of that goal.

    Args:
        goals: Iterable of goal positions; each is converted with ``np.array``.
        policy: Object providing ``reset()``, ``act(state, sample=True)`` and
            ``append(key, value)``.
        horizon: Number of environment steps per episode.
        n_episodes: Number of episodes rolled out per goal.
    Returns:
        Dict mapping ``(goal[0], goal[1])`` to
        ``(best_episode_return, mean_episode_return)``.
    """
    print("Online Evaluation...")
    results = {}
    for raw_goal in goals:
        goal = np.array(raw_goal)
        env = generate_darkroom_env(dim=10, horizon=100, goals=[goal])[0]
        policy.reset()

        returns_sum = 0
        best_return = -float("inf")
        for _episode in range(n_episodes):
            episode_return = 0
            state = env.reset()
            for _step in range(horizon):
                action = policy.act(state, sample=True)
                next_state, reward = env.transit(state, action)
                episode_return += reward
                # Feed the observed transition back into the in-context history.
                for key, value in (
                    ("context_states", state),
                    ("context_actions", action),
                    ("context_next_states", next_state),
                    ("context_rewards", reward),
                ):
                    policy.append(key, value)
                state = next_state
            returns_sum += episode_return
            best_return = max(best_return, episode_return)

        mean_return = returns_sum / n_episodes
        print(f"total reward: {mean_return}")
        results[(goal[0], goal[1])] = (best_return, mean_return)
    return results

def offline_evaluate_policy(goals, policy, horizon, n_episodes=5):
    """
    Evaluate the policy against a fixed context gathered by an expert roll-in.

    For every goal a darkroom environment is generated (``dim=10``) and
    ``rollin_mdp(env, "expert", mode="step", random_p=0.0)`` produces the
    context. The policy is reset, given that context once, and then rolled
    out for ``n_episodes`` episodes. The context is NOT extended during the
    rollouts, so every episode conditions on the same roll-in.

    Args:
        goals: Iterable of goal positions; each is converted with ``np.array``.
        policy: Object providing ``reset()``, ``act(state, sample=True)`` and
            ``append(key, value)``.
        horizon: Number of environment steps per episode.
        n_episodes: Number of episodes rolled out per goal.
    Returns:
        Dict mapping ``(goal[0], goal[1])`` to
        ``(best_episode_return, mean_episode_return)``.
    """
    print("Offline Evaluation...")
    results = {}
    for raw_goal in goals:
        goal = np.array(raw_goal)
        env = generate_darkroom_env(dim=10, horizon=100, goals=[goal])[0]
        ctx_states, ctx_actions, ctx_next_states, ctx_rewards = rollin_mdp(
            env, "expert", mode="step", random_p=0.0
        )

        policy.reset()
        # Prepend a batch axis; rewards additionally get a trailing feature axis.
        policy.append("context_states", ctx_states[None, :, :])
        policy.append("context_actions", ctx_actions[None, :, :])
        policy.append("context_next_states", ctx_next_states[None, :, :])
        policy.append("context_rewards", ctx_rewards[None, :, None])

        returns_sum = 0
        best_return = -float("inf")
        for _episode in range(n_episodes):
            episode_return = 0
            state = env.reset()
            for _step in range(horizon):
                action = policy.act(state, sample=True)
                state, reward = env.transit(state, action)
                episode_return += reward
            returns_sum += episode_return
            best_return = max(best_return, episode_return)

        mean_return = returns_sum / n_episodes
        print(f"total reward: {mean_return}")
        results[(goal[0], goal[1])] = (best_return, mean_return)
    return results

def online_evaluate_policy_with_preference(goals, policy, horizon, n_episodes=10, preference_model=None):
    """
    Evaluate the policy online, labelling its context with learned rewards.

    For every goal a darkroom environment is generated (``dim=10``) and one
    trajectory pair is sampled with ``generate_preference_histories_from_envs``.
    The policy is reset once per goal and rolled out for ``n_episodes``
    episodes. Every transition is appended to the policy context. The value
    written to ``context_rewards`` is the preference model's estimate for
    ``(state, action)``, not the environment reward. The returned statistics,
    however, use the true environment rewards.

    Args:
        goals: Iterable of goal positions; each is converted with ``np.array``.
        policy: Object providing ``reset()``, ``act(state, sample=True)``,
            ``append(key, value)`` and a ``_device`` attribute.
        horizon: Number of environment steps per episode.
        n_episodes: Number of episodes rolled out per goal.
        preference_model: Callable
            ``preference_model(preference, state, action, test=True)`` that
            returns a tensor reward estimate. Required despite the ``None``
            default (the default is kept for signature compatibility).
    Returns:
        Dict mapping ``(goal[0], goal[1])`` to
        ``(best_episode_return, mean_episode_return)``.
    Raises:
        ValueError: If ``preference_model`` is None, or a sampled preference
            label is neither 0 nor 1.
    """
    if preference_model is None:
        raise ValueError("online_evaluate_policy_with_preference requires a preference_model.")

    print("Online Evaluation...")
    rewards = {}
    context_keys = ("context_states", "context_actions", "context_next_states", "context_rewards")

    for goal in goals:
        goal = np.array(goal)
        eval_env = generate_darkroom_env(dim=10, horizon=100, goals=[goal])[0]
        traj = generate_preference_histories_from_envs([eval_env], 1, 1, "uniform", "step", 0.2)[0]
        policy.reset()
        device = policy._device

        # The label picks which source trajectory lands in slot "traj_1"
        # (presumably the preferred one); the other goes to "traj_2".
        label = traj["preference"]
        if label == 0:
            sources = ("traj_1", "traj_2")
        elif label == 1:
            sources = ("traj_2", "traj_1")
        else:
            raise ValueError(f"Unexpected preference label {label!r}; expected 0 or 1.")
        preference = {
            slot: {
                key: torch.from_numpy(traj[source][key]).float().to(device)[None, :]
                for key in context_keys
            }
            for slot, source in zip(("traj_1", "traj_2"), sources)
        }

        env_reward = 0
        max_reward = -float("inf")
        for _ in range(n_episodes):
            state = eval_env.reset()
            total_reward = 0
            for _ in range(horizon):
                action = policy.act(state, sample=True)
                next_state, reward = eval_env.transit(state, action)
                # Inference only: avoid building an autograd graph per step.
                with torch.no_grad():
                    estimated_reward = preference_model(
                        preference,
                        torch.from_numpy(state).float().to(device)[None, :],
                        torch.from_numpy(action).float().to(device)[None, :],
                        test=True,
                    ).detach()
                total_reward += reward
                policy.append("context_states", state)
                policy.append("context_actions", action)
                policy.append("context_next_states", next_state)
                # The policy sees the learned reward, not the true one.
                policy.append("context_rewards", estimated_reward)
                state = next_state
            env_reward += total_reward
            max_reward = max(max_reward, total_reward)

        print(f"Collected Total Reward: {env_reward/n_episodes}")
        rewards[(goal[0], goal[1])] = (max_reward, env_reward/n_episodes)
    return rewards

def offline_evaluate_policy_with_preference(goals, policy, horizon, n_episodes=10, preference_model=None):
    """
    Evaluate the policy with a fixed roll-in context relabelled by a reward model.

    For every goal a darkroom environment is generated (``dim=10``) and one
    trajectory pair is sampled with ``generate_preference_histories_from_envs``.
    Then a context trajectory is collected with
    ``rollin_mdp(env, "expert", mode="step", random_p=0.8)``. The roll-in's
    own rewards are discarded. Each ``(state, action)`` of the roll-in is
    instead scored by ``preference_model`` and those estimates become the
    ``context_rewards``. The context stays fixed while the policy is rolled
    out for ``n_episodes`` episodes of ``horizon`` steps.

    Args:
        goals: Iterable of goal positions; each is converted with ``np.array``.
        policy: Object providing ``reset()``, ``act(state, sample=True)``,
            ``append(key, value)`` and a ``_device`` attribute.
        horizon: Number of environment steps per evaluation episode.
        n_episodes: Number of episodes rolled out per goal.
        preference_model: Callable
            ``preference_model(preference, states, actions, test=True)`` with
            a ``_device`` attribute. Required despite the ``None`` default.
    Returns:
        Tuple ``(rewards, reward_li)``:
        ``rewards`` maps ``(goal[0], goal[1])`` to
        ``(best_episode_return, mean_episode_return)``, and ``reward_li``
        lists every episode's true return across all goals, in order.
    Raises:
        ValueError: If ``preference_model`` is None, or a sampled preference
            label is neither 0 nor 1.
    """
    if preference_model is None:
        raise ValueError("offline_evaluate_policy_with_preference requires a preference_model.")

    print("Offline Evaluation...")
    rewards = {}
    reward_li = []
    context_keys = ("context_states", "context_actions", "context_next_states", "context_rewards")

    for goal in goals:
        goal = np.array(goal)
        eval_env = generate_darkroom_env(dim=10, horizon=100, goals=[goal])[0]
        traj = generate_preference_histories_from_envs([eval_env], 1, 1, "uniform", "step", 0.2)[0]
        policy.reset()
        device = policy._device

        # The ground-truth roll-in rewards are deliberately not shown to the policy.
        context_states, context_actions, context_next_states, _true_context_rewards = rollin_mdp(
            eval_env, "expert", mode="step", random_p=0.8
        )

        # The label picks which source trajectory lands in slot "traj_1"
        # (presumably the preferred one); the other goes to "traj_2".
        label = traj["preference"]
        if label == 0:
            sources = ("traj_1", "traj_2")
        elif label == 1:
            sources = ("traj_2", "traj_1")
        else:
            raise ValueError(f"Unexpected preference label {label!r}; expected 0 or 1.")
        preference = {
            slot: {
                key: torch.from_numpy(traj[source][key]).float().to(device)[None, :]
                for key in context_keys
            }
            for slot, source in zip(("traj_1", "traj_2"), sources)
        }

        policy.append("context_states", context_states[None, :, :])
        policy.append("context_actions", context_actions[None, :, :])
        policy.append("context_next_states", context_next_states[None, :, :])

        # Relabel the roll-in. The loop length comes from the array being
        # indexed, and the arrays are moved to the device once rather than
        # once per step. The `horizon` argument is left untouched for the
        # rollouts below.
        context_len = context_states.shape[0]
        query_states = torch.from_numpy(context_states).float().to(device)
        query_actions = torch.from_numpy(context_actions).float().to(device)
        estimated_rewards = torch.zeros((1, context_len, 1)).to(preference_model._device)
        with torch.no_grad():
            for i in range(context_len):
                estimated_rewards[:, i, :] = preference_model(
                    preference, query_states[i:i + 1], query_actions[i:i + 1], test=True
                ).detach()
        policy.append("context_rewards", estimated_rewards)

        env_reward = 0
        max_reward = -float("inf")
        for _ in range(n_episodes):
            state = eval_env.reset()
            total_reward = 0
            for _ in range(horizon):
                action = policy.act(state, sample=True)
                state, reward = eval_env.transit(state, action)
                total_reward += reward
            env_reward += total_reward
            max_reward = max(max_reward, total_reward)
            reward_li.append(total_reward)

        print(f"Collected Total Reward: {env_reward/n_episodes}")
        rewards[(goal[0], goal[1])] = (max_reward, env_reward/n_episodes)
    return rewards, reward_li


def online_evaluate_policy_DPT(goals, policy, horizon, n_episodes=10, preference_model=None):
    """
    Evaluate a preference-conditioned (DPT-style) policy that outputs action probabilities.

    For every goal a darkroom environment is generated (``dim=10``) and one
    trajectory pair is sampled with ``generate_preference_histories_from_envs``.
    At each step the policy is called as
    ``policy(preference, state, test=True)[0]``. The first element is treated
    as a probability vector over ``policy._action_dim`` discrete actions. One
    action is sampled from it and one-hot encoded before stepping the
    environment.

    Args:
        goals: Iterable of goal positions; each is converted with ``np.array``.
        policy: Callable model with ``_device`` and ``_action_dim`` attributes.
        horizon: Number of environment steps per episode.
        n_episodes: Number of episodes rolled out per goal.
        preference_model: Unused; accepted so the signature matches the other
            preference-based evaluators.
    Returns:
        Tuple ``(rewards, reward_li)``:
        ``rewards`` maps ``(goal[0], goal[1])`` to
        ``(best_episode_return, mean_episode_return)``, and ``reward_li``
        lists every episode's return across all goals, in order.
    Raises:
        ValueError: If a sampled preference label is neither 0 nor 1.
    """
    print("Online Evaluation...")
    rewards = {}
    reward_li = []
    context_keys = ("context_states", "context_actions", "context_next_states", "context_rewards")
    action_indices = np.arange(policy._action_dim)

    for goal in goals:
        goal = np.array(goal)
        eval_env = generate_darkroom_env(dim=10, horizon=100, goals=[goal])[0]
        traj = generate_preference_histories_from_envs([eval_env], 1, 1, "uniform", "step", 0.2)[0]
        device = policy._device

        # The label picks which source trajectory lands in slot "traj_1"
        # (presumably the preferred one); the other goes to "traj_2".
        label = traj["preference"]
        if label == 0:
            sources = ("traj_1", "traj_2")
        elif label == 1:
            sources = ("traj_2", "traj_1")
        else:
            raise ValueError(f"Unexpected preference label {label!r}; expected 0 or 1.")
        preference = {
            slot: {
                key: torch.from_numpy(traj[source][key]).float().to(device)[None, :]
                for key in context_keys
            }
            for slot, source in zip(("traj_1", "traj_2"), sources)
        }

        env_reward = 0
        max_reward = -float("inf")
        for _ in range(n_episodes):
            state = eval_env.reset()
            total_reward = 0
            for _ in range(horizon):
                tensor_state = torch.tensor(state).float().to(device)[None, :]
                with torch.no_grad():
                    action_probs = policy(preference, tensor_state, test=True)[0]
                if torch.is_tensor(action_probs):
                    action_probs = action_probs.detach().cpu().numpy()
                # np.random.choice checks that p sums to 1 using a tight float64
                # tolerance unless p is a float32 ndarray. Raw float32 model
                # output can miss it, so normalise in float64 first.
                probs = np.asarray(action_probs, dtype=np.float64)
                probs = probs / probs.sum()
                action_index = np.random.choice(action_indices, p=probs)

                action = np.zeros(policy._action_dim)
                action[action_index] = 1.0

                next_state, reward = eval_env.transit(state, action)
                total_reward += reward
                state = next_state
            env_reward += total_reward
            max_reward = max(max_reward, total_reward)
            reward_li.append(total_reward)
        print(f"Collected Total Reward: {env_reward/n_episodes}")
        rewards[(goal[0], goal[1])] = (max_reward, env_reward/n_episodes)
    return rewards, reward_li