'''
This file records demonstration trajectories by rolling out trained PPO policies
in a MetaWorld environment and adding them to the demo dataset.
'''
import numpy as np
import gymnasium as gym
import torch
#timelimit wrapper
from gymnasium.wrappers import TimeLimit
import pygame

import os.path
import sys

import hydra

from experiments.envs.metaworld import MetaWorldSawyerEnv, MetaWorldSafetySpeedWrapper

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO


sys.path.append(os.path.join(os.path.dirname(__file__), '..')) # add parent directory to path to import env
from envs.twodoors import TwoDoorsEnv
from envs.manydoors import ManyDoorsEnv

def _resolve_policy_dir(cfg, policy_root="./policies/metaworld"):
    '''
    Return the directory that holds the trained policies, or None if none exist.

    Uses `cfg.policy_dir` when it is set. Otherwise falls back to the
    lexicographically last sub-directory of `policy_root`. The directory that
    is listed and the directory that is returned share the same root, so the
    returned path always refers to a directory that was actually found.
    '''
    if cfg.policy_dir is not None:
        return cfg.policy_dir

    print("policy_dir not specified. Please specify a policy_dir in the config file, or in the command line with policy_dir=<path>")
    print("As a default, the policy_dir will be the latest policy set trained")

    run_dirs = []
    if os.path.isdir(policy_root):
        run_dirs = sorted(
            item for item in os.listdir(policy_root)
            if os.path.isdir(os.path.join(policy_root, item))
        )

    if not run_dirs:
        print(f"Error: No policies found in {policy_root}")
        return None

    return f"{policy_root}/{run_dirs[-1]}"


def _load_existing_demos(path):
    '''
    Load previously recorded demos from `path` as a list, or start a new list.
    '''
    if not os.path.isfile(path):
        print("making new file")
        return []

    print("loading old demos file")
    # NOTE(security): allow_pickle=True executes pickled payloads on load; only
    # point this at demo files you created yourself.
    demos = list(np.load(path, allow_pickle=True))  # list so we can append
    print("size so far: ", len(demos))
    return demos


def _load_optimal_model(cfg, policy_dir, identity):
    '''
    Load the PPO model used to score actions for `identity`.

    Uses the config entry `optimal_policy_<identity>` when present; otherwise
    picks the lexicographically last file in `policy_dir` whose name contains
    `manydoors_<identity>`.

    Raises:
        FileNotFoundError: if no matching checkpoint exists in `policy_dir`.
    '''
    configured = cfg.get(f'optimal_policy_{identity}')
    if configured is not None:
        return PPO.load(policy_dir + "/" + configured)

    candidates = sorted(
        name for name in os.listdir(policy_dir)
        if f"manydoors_{identity}" in name
    )
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint matching 'manydoors_{identity}' found in {policy_dir}"
        )
    return PPO.load(f"{policy_dir}/{candidates[-1]}")


def _make_env(cfg, identity):
    '''Build the MetaWorld env for `cfg.env` wrapped for the given identity.'''
    return MetaWorldSafetySpeedWrapper(MetaWorldSawyerEnv(cfg.env), identity)


def _evaluate_under_optimal(optimal_model, obs, action, device):
    '''
    Return (value, log_prob, entropy) of `action` at `obs` under `optimal_model`.

    Runs without autograd: the outputs are only logged/stored, never
    back-propagated, so building a graph would just hold memory.
    '''
    th_obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
    th_act = torch.as_tensor(action, dtype=torch.float32, device=device)
    with torch.no_grad():
        return optimal_model.policy.evaluate_actions(th_obs, th_act)


def _preview_policy(cfg, optimal_model, identity, device):
    '''
    Roll out `optimal_model` for one episode with rendering, printing the
    value / log-prob / entropy of every action it takes.
    '''
    env = _make_env(cfg, identity)
    obs = env.reset()
    done = False

    while not done:
        action, _states = optimal_model.predict(obs)

        # Score the action against the observation it was chosen from
        # (i.e. before stepping), so the printed numbers describe a real
        # (state, action) pair.
        value, log_prob, entropy = _evaluate_under_optimal(optimal_model, obs, action, device)
        print(f"value: {value}")
        print(f"log_prob: {log_prob}")
        print(f"entropy: {entropy}")

        obs, _reward, dones, _info = env.step(action)
        env.render()
        done = 1 in dones


def _record_demo(cfg, model, optimal_model, identity, device):
    '''
    Roll out `model` for one episode and score every action under `optimal_model`.

    Returns:
        (demo, total_logprob, total_reward) where `demo` is the dict that gets
        stored in the dataset, `total_logprob` is the sum of per-step
        log-probabilities (the log-likelihood of the action sequence under
        `optimal_model`) and `total_reward` is the summed reward.

    NOTE(review): observations, actions, rewards and log-probs are indexed
    with [0], so the env presumably exposes a single-env vectorized API --
    confirm against the env wrapper.
    '''
    env = _make_env(cfg, identity)
    obs = env.reset()
    done = False

    obses = [obs[0]]
    actions = []
    rewards = []
    logprobs = []  # logprob is unscaled advantage
    values = []

    total_logprob = 0.0  # log-probabilities add up; they must not be multiplied
    total_reward = 0

    while not done:
        action, _states = model.predict(obs)
        actions.append(action[0])
        print(f"action: {action}")

        value, log_prob, _entropy = _evaluate_under_optimal(optimal_model, obs, action, device)
        step_logprob = log_prob[0].item()
        total_logprob += step_logprob

        obs, reward, dones, _info = env.step(action)
        obses.append(obs[0])
        rewards.append(reward[0])
        total_reward += reward[0]
        logprobs.append(step_logprob)
        values.append(value.reshape(-1)[0].item())
        env.render()

        done = 1 in dones

    demo = {
        "obs": obses,
        "actions": actions,
        "rewards": rewards,
        "logprobs": np.asarray(logprobs, dtype=np.float32),
        "values": np.asarray(values, dtype=np.float32),
        "identity": identity,
    }
    return demo, total_logprob, total_reward


@hydra.main(version_base=None, config_path="../config", config_name="metaworld_record_demos")
def main(cfg):
    '''
    Record `cfg.num_demos` policy rollouts per identity and save them as demos.

    For each identity (2, then 1): load the scoring ("optimal") PPO model,
    preview one rendered episode of it, then roll out the checkpoints
    `manydoors_<identity>_<i % 10>` and store each trajectory together with the
    per-step value and log-prob assigned by the optimal model. Previously
    recorded demos in `demos/<cfg.env>.npy` are loaded first and the combined
    array is written into the Hydra run directory.
    '''
    policy_dir = _resolve_policy_dir(cfg)
    if policy_dir is None:
        return

    print(f"using policy_dir: {policy_dir}")

    path = f"demos/{cfg.env}.npy"
    print("path: ", path)
    demos = _load_existing_demos(path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_demos = cfg.num_demos  # number of demos to record per identity
    for identity in [2, 1]:
        # this is the model we will use to generate synthetic advantages
        optimal_model = _load_optimal_model(cfg, policy_dir, identity)

        # visualize optimal model to make sure it works
        _preview_policy(cfg, optimal_model, identity, device)
        print("DONE")

        best_logprob = -np.inf
        best_total_reward = -np.inf
        highest_likelihood_demo = None
        best_demo = None

        for i in range(num_demos):
            model_num = i % 10
            model = PPO.load(f"{policy_dir}/manydoors_{identity}_{model_num}")

            demo, total_logprob, total_reward = _record_demo(
                cfg, model, optimal_model, identity, device
            )

            if total_logprob > best_logprob:
                best_logprob = total_logprob
                highest_likelihood_demo = demo

            if total_reward > best_total_reward:
                best_total_reward = total_reward
                best_demo = demo

            print(f"total reward: {total_reward}")
            demos.append(demo)

        # Both stay None when num_demos == 0; skip the summary in that case.
        if best_demo is not None:
            print(f"best demo actions: {best_demo['actions']}")
            print(f"best_demo_reward: {np.sum(best_demo['rewards'])}")

        if highest_likelihood_demo is not None:
            print(f"highest likelihood demo actions: {highest_likelihood_demo['actions']}")
            print(f"highest_likelihood_demo_reward: {np.sum(highest_likelihood_demo['rewards'])}")

    # dtype=object keeps one dict per entry instead of letting numpy guess.
    demos = np.array(demos, dtype=object)
    print(f"size of demos: {demos.shape[0]}")

    # get hydra run dir for logging
    hydra_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    # NOTE(review): old demos are read from demos/<env>.npy but the combined
    # array is written to the Hydra run dir -- confirm this is intended.
    with open(f"{hydra_dir}/demos_{cfg.env}.npy", "wb") as f:
        np.save(f, demos, allow_pickle=True)

# Script entry point: main() is called without arguments because the
# @hydra.main decorator on it builds and passes in `cfg`.
if __name__ == "__main__":
    main()