import torch
import numpy as np
import matplotlib.pyplot as plt

from envs import generate_darkroom_env
from trajectory_generation import rollin_mdp, generate_preference_histories_from_envs
from utils import convert_to_tensor

import metaworld
from gym.wrappers import TimeLimit

def overload_env_step(self, action):
    """Step the wrapped env and collapse its 5-tuple result into a 4-tuple.

    Meant to be bound onto a ``TimeLimit`` wrapper instance via ``__get__``.
    It advances ``self._elapsed_steps`` by one. Truncation is forced once that
    counter reaches ``self._max_episode_steps``.

    Args:
        action: Action forwarded unchanged to ``self.env.step``.

    Returns:
        ``(observation, reward, done, info)``, where ``done`` is
        ``terminated or truncated``.
    """
    observation, reward, terminated, truncated, info = self.env.step(action)
    self._elapsed_steps += 1
    limit_reached = self._elapsed_steps >= self._max_episode_steps
    # The step budget overrides whatever truncation flag the inner env reported.
    truncated = True if limit_reached else truncated
    return observation, reward, terminated or truncated, info

def online_evaluate_policy_with_preference(test_trajs, optimal_trajs, policy, horizon, n_episodes=10, preference_model=None):
    """
    Evaluate ``policy`` online on Meta-World ``pick-place-v2`` tasks using
    model-estimated rewards as the policy's in-context reward signal.

    For every task id, a fresh environment is built with
    ``metaworld.ML1(..., seed=task_id)`` and ``ml1.train_tasks[task_id]``.
    ``policy.reset()`` is called once per task. Everything later passed to
    ``policy.append`` is therefore not cleared between that task's episodes.

    Each episode does the following:
      1. Sample one labelled trajectory pair from ``test_trajs[task_id]`` as
         the preference prompt.
      2. Roll out ``policy.act`` until the env reports ``done``.
      3. After each step, append the transition to the policy context, with
         the reward replaced by ``preference_model``'s estimate.

    Args:
        test_trajs: Mapping ``task_id -> list`` of dicts. Each dict has keys
            ``"traj_1"``/``"traj_2"``, each holding ``context_states``,
            ``context_actions``, ``context_next_states`` and
            ``context_rewards`` numpy arrays. It also has a ``"preference"``
            label, which must be 0 or 1.
        optimal_trajs: Unused; kept for signature compatibility.
        policy: Object exposing ``reset()``, ``act(state)``,
            ``append(key, value)`` and a ``_device`` attribute.
        horizon: Unused; episode length is capped at 200 steps by the
            ``TimeLimit`` wrapper.
        n_episodes: Number of episodes per task; must be positive.
        preference_model: Callable invoked as
            ``preference_model(preference, states, actions, test=True)``;
            its result is ``.detach()``-ed. Required.

    Returns:
        Dict ``task_id -> (max_episode_return, mean_episode_return)``. Both
        values are computed from the true environment rewards.

    Raises:
        ValueError: If ``n_episodes`` is not positive, ``preference_model`` is
            None, or a sampled trajectory's preference label is not 0 or 1.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")
    if preference_model is None:
        raise ValueError("preference_model is required for preference-based evaluation")

    context_keys = ("context_states", "context_actions", "context_next_states", "context_rewards")

    print("Online Evaluation...")
    rewards = {}

    for task_id in test_trajs.keys():
        ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
        env = ml1.train_classes['pick-place-v2']()
        task = ml1.train_tasks[task_id]
        env.set_task(task)
        env.max_path_length = 200
        env = TimeLimit(env, max_episode_steps=200)
        # Rebind step so it returns (obs, reward, done, info) and enforces the 200-step cap.
        env.step = overload_env_step.__get__(env, env.__class__)

        # Reset once per task: the context keeps growing across this task's episodes.
        policy.reset()
        device = policy._device

        env_reward = 0
        max_reward = -float("inf")
        task_trajs = test_trajs[task_id]
        for _ in range(n_episodes):
            traj = task_trajs[np.random.randint(0, len(task_trajs))]

            # Label 0 keeps the stored order; label 1 swaps the pair, so the
            # labelled trajectory always lands in slot "traj_1".
            label = traj["preference"]
            if label == 0:
                sources = ("traj_1", "traj_2")
            elif label == 1:
                sources = ("traj_2", "traj_1")
            else:
                raise ValueError(
                    f"Unsupported preference label {label!r} for task {task_id}; expected 0 or 1."
                )
            preference = {
                slot: {
                    key: torch.from_numpy(traj[source][key]).float().to(device)[None, :]
                    for key in context_keys
                }
                for slot, source in zip(("traj_1", "traj_2"), sources)
            }

            state, _ = env.reset()
            total_reward = 0
            done = False
            while not done:
                action = policy.act(state)
                next_state, reward, done, _ = env.step(action)
                # The policy context receives the model's reward estimate, not the env reward.
                estimated_reward = preference_model(
                    preference,
                    torch.from_numpy(state).float().to(device)[None, :],
                    torch.from_numpy(action).float().to(device)[None, :],
                    test=True).detach()
                total_reward += reward
                policy.append("context_states", state)
                policy.append("context_actions", action)
                policy.append("context_next_states", next_state)
                policy.append("context_rewards", estimated_reward)
                state = next_state
            env_reward += total_reward
            max_reward = max(max_reward, total_reward)
        mean_reward = env_reward / n_episodes
        print(f"Collected Total Reward: {mean_reward}")
        rewards[task_id] = (max_reward, mean_reward)
    return rewards

def offline_evaluate_policy_with_preference(test_trajs, optimal_trajs, policy, horizon, n_episodes=10, preference_model=None):
    """
    Evaluate ``policy`` on Meta-World ``pick-place-v2`` tasks using a fixed
    context: an optimal trajectory whose rewards are estimated by
    ``preference_model``.

    For every task id, a fresh environment is built with
    ``metaworld.ML1(..., seed=task_id)`` and ``ml1.train_tasks[task_id]``.
    Each episode does the following:
      1. Reset the policy.
      2. Sample a labelled preference pair from ``test_trajs[task_id]``.
      3. Sample an optimal trajectory from ``optimal_trajs[task_id]`` and
         append its states/actions/next-states to the policy context.
      4. Label the first 200 steps of that trajectory with model-estimated
         rewards and append them as ``context_rewards``.
      5. Roll out ``policy.act`` until ``done``. The context is not updated
         during the rollout.

    Args:
        test_trajs: Mapping ``task_id -> list`` of dicts. Each dict has keys
            ``"traj_1"``/``"traj_2"``, each holding ``context_*`` numpy arrays.
            It also has a ``"preference"`` label, which must be 0 or 1.
        optimal_trajs: Mapping ``task_id -> list`` of dicts whose ``"traj_1"``
            entry holds 2-D ``context_states``/``context_actions``/
            ``context_next_states`` numpy arrays. These are indexed for the
            first 200 steps, so they need at least 200 rows.
        policy: Object exposing ``reset()``, ``act(state)``,
            ``append(key, value)`` and a ``_device`` attribute.
        horizon: Ignored; the context length and episode cap are fixed at 200.
        n_episodes: Number of episodes per task; must be positive.
        preference_model: Callable invoked as
            ``preference_model(preference, states, actions, test=True)``, with
            a ``_device`` attribute. Required.

    Returns:
        A tuple ``(rewards, reward_li)`` with:
          - ``rewards``: maps ``task_id -> (max_episode_return,
            mean_episode_return)``.
          - ``reward_li``: every episode return, in evaluation order.

    Raises:
        ValueError: If ``n_episodes`` is not positive, ``preference_model`` is
            None, or a sampled trajectory's preference label is not 0 or 1.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")
    if preference_model is None:
        raise ValueError("preference_model is required for preference-based evaluation")

    context_keys = ("context_states", "context_actions", "context_next_states", "context_rewards")

    print("Offline Evaluation...")
    rewards = {}
    reward_li = []

    for task_id in test_trajs.keys():
        ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
        env = ml1.train_classes['pick-place-v2']()
        task = ml1.train_tasks[task_id]
        env.set_task(task)
        env.max_path_length = 200
        env = TimeLimit(env, max_episode_steps=200)
        # Rebind step so it returns (obs, reward, done, info) and enforces the 200-step cap.
        env.step = overload_env_step.__get__(env, env.__class__)

        env_reward = 0
        max_reward = -float("inf")
        task_trajs = test_trajs[task_id]
        for _ in range(n_episodes):
            policy.reset()
            device = policy._device

            # NOTE: deliberately overrides the `horizon` argument to match the env's 200-step limit.
            horizon = 200
            traj = task_trajs[np.random.randint(0, len(task_trajs))]

            # Label 0 keeps the stored order; label 1 swaps the pair, so the
            # labelled trajectory always lands in slot "traj_1".
            label = traj["preference"]
            if label == 0:
                sources = ("traj_1", "traj_2")
            elif label == 1:
                sources = ("traj_2", "traj_1")
            else:
                raise ValueError(
                    f"Unsupported preference label {label!r} for task {task_id}; expected 0 or 1."
                )
            preference = {
                slot: {
                    key: torch.from_numpy(traj[source][key]).float().to(device)[None, :]
                    for key in context_keys
                }
                for slot, source in zip(("traj_1", "traj_2"), sources)
            }

            optimal_context = optimal_trajs[task_id][np.random.randint(0, len(optimal_trajs[task_id]))]["traj_1"]
            policy.append("context_states", optimal_context["context_states"][None, :, :])
            policy.append("context_actions", optimal_context["context_actions"][None, :, :])
            policy.append("context_next_states", optimal_context["context_next_states"][None, :, :])

            # Convert the optimal context to tensors once, rather than
            # re-copying (and re-transferring) the whole array on every step.
            opt_states = torch.from_numpy(optimal_context["context_states"]).float().to(device)
            opt_actions = torch.from_numpy(optimal_context["context_actions"]).float().to(device)
            estimated_rewards = torch.zeros((1, horizon, 1)).to(preference_model._device)
            for i in range(horizon):
                # opt_x[i][None, :] has shape (1, dim), and raises IndexError
                # if the trajectory is shorter than `horizon`.
                estimated_reward = preference_model(
                    preference, opt_states[i][None, :], opt_actions[i][None, :], test=True).detach()
                estimated_rewards[:, i, :] = estimated_reward

            policy.append("context_rewards", estimated_rewards)

            # Rollout: the context was fixed above, so no per-step reward estimation is needed here.
            state, _ = env.reset()
            total_reward = 0
            done = False
            while not done:
                action = policy.act(state)
                state, reward, done, _ = env.step(action)
                total_reward += reward
            env_reward += total_reward
            max_reward = max(max_reward, total_reward)
            reward_li.append(total_reward)
        mean_reward = env_reward / n_episodes
        print(f"Collected Total Reward: {mean_reward}")
        rewards[task_id] = (max_reward, mean_reward)
    return rewards, reward_li

def online_evaluate_policy_DPT(test_trajs, policy, horizon, n_episodes=10, preference_model=None):
    """
    Evaluate a preference-conditioned ``policy`` on Meta-World ``pick-place-v2``
    tasks.

    For every task id, a fresh environment is built with
    ``metaworld.ML1(..., seed=task_id)`` and ``ml1.train_tasks[task_id]``.
    Each episode samples one labelled trajectory pair as the prompt. It then
    queries ``policy(preference, state, test=True)[0]`` for the action at
    every step until the env reports ``done``.

    Args:
        test_trajs: Mapping ``task_id -> list`` of dicts. Each dict has keys
            ``"traj_1"``/``"traj_2"``, each holding ``context_*`` numpy arrays.
            It also has a ``"preference"`` label, which must be 0 or 1.
        policy: Callable ``policy(preference, states, test=True)`` with a
            ``_device`` attribute. Element ``[0]`` of its result is passed to
            ``env.step`` as the action.
        horizon: Unused; episode length is capped at 200 steps by the
            ``TimeLimit`` wrapper.
        n_episodes: Number of episodes per task; must be positive.
        preference_model: Unused; kept for signature compatibility.

    Returns:
        A tuple ``(rewards, reward_li)`` with:
          - ``rewards``: maps ``task_id -> (max_episode_return,
            mean_episode_return)``.
          - ``reward_li``: every episode return, in evaluation order.

    Raises:
        ValueError: If ``n_episodes`` is not positive or a sampled trajectory's
            preference label is not 0 or 1.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    context_keys = ("context_states", "context_actions", "context_next_states", "context_rewards")

    print("Online Evaluation...")
    rewards = {}
    reward_li = []
    for task_id in test_trajs.keys():
        ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
        env = ml1.train_classes['pick-place-v2']()
        task = ml1.train_tasks[task_id]
        env.set_task(task)
        env.max_path_length = 200
        env = TimeLimit(env, max_episode_steps=200)
        # Rebind step so it returns (obs, reward, done, info) and enforces the 200-step cap.
        env.step = overload_env_step.__get__(env, env.__class__)

        device = policy._device
        env_reward = 0
        max_reward = -float("inf")
        task_trajs = test_trajs[task_id]
        for _ in range(n_episodes):
            traj = task_trajs[np.random.randint(0, len(task_trajs))]

            # Label 0 keeps the stored order; label 1 swaps the pair, so the
            # labelled trajectory always lands in slot "traj_1".
            label = traj["preference"]
            if label == 0:
                sources = ("traj_1", "traj_2")
            elif label == 1:
                sources = ("traj_2", "traj_1")
            else:
                raise ValueError(
                    f"Unsupported preference label {label!r} for task {task_id}; expected 0 or 1."
                )
            preference = {
                slot: {
                    key: torch.from_numpy(traj[source][key]).float().to(device)[None, :]
                    for key in context_keys
                }
                for slot, source in zip(("traj_1", "traj_2"), sources)
            }

            state, _ = env.reset()
            total_reward = 0
            done = False
            while not done:
                tensor_state = torch.tensor(state).float().to(device)[None, :]
                action = policy(preference, tensor_state, test=True)[0]
                state, reward, done, _ = env.step(action)
                total_reward += reward
            env_reward += total_reward
            max_reward = max(max_reward, total_reward)
            reward_li.append(total_reward)
        mean_reward = env_reward / n_episodes
        print(f"Collected Total Reward: {mean_reward}")
        rewards[task_id] = (max_reward, mean_reward)
    return rewards, reward_li