'''
This script rolls out trained PPO policies in the configured environment and
records the resulting trajectories as demonstrations for the dataset.
'''
import os.path
import sys

import gymnasium as gym
import hydra
import numpy as np
import pygame
import torch
# timelimit wrapper
from gymnasium.wrappers import TimeLimit
from gymnasium.wrappers.time_aware_observation import TimeAwareObservation
from minigrid.core.actions import Actions
# from RecordController import ManualControlRecord
from minigrid.minigrid_env import MiniGridEnv
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.policies import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv

from experiments.envs.manydoors import ManyDoorsEnv
from experiments.envs.metaworld import (MetaWorldSafetySpeedWrapper,
                                        MetaWorldSawyerEnv)
from experiments.envs.twodoors import TwoDoorsEnv

from ..envs.wrappers import OneHotFullImage, OneHotPartialImage

# Make the parent directory of this file importable.
# NOTE(review): this runs after the import statements above, so it cannot
# influence how those imports are resolved.
sys.path.append(
    os.path.join(os.path.dirname(__file__), os.pardir)
)

# Recreate the custom feature extractor


class CustomExtractor(BaseFeaturesExtractor):
    """Pass-through features extractor: observations are returned unchanged."""

    def __init__(self, observation_space, features_dim):
        """Initialise the base extractor, then size features from the space.

        ``features_dim`` is forwarded to the base class, but the stored
        ``_features_dim`` is subsequently overwritten with the first
        dimension of ``observation_space.shape``.
        """
        super().__init__(observation_space, features_dim)
        # Assuming 1D input
        self._features_dim = observation_space.shape[0]

    def forward(self, observations):
        """Return ``observations`` as-is (no transformation applied)."""
        return observations


# Policy keyword arguments selecting the pass-through extractor above.
# NOTE(review): not referenced anywhere else in the visible part of this file.
policy_kwargs = {
    "features_extractor_class": CustomExtractor,
    "features_extractor_kwargs": {"features_dim": 35},
}


def _make_env(env_name, identity):
    """Build the rollout environment for ``env_name``, or None if unknown.

    "twodoors" and "manydoors" are wrapped in ``ImgObsWrapper`` and a
    single-env ``DummyVecEnv``; any name ending in "v2" is treated as a
    metaworld environment.
    """
    if env_name == "twodoors":
        base_env = ImgObsWrapper(TwoDoorsEnv(
            render_mode="none", position_reward=False, identity=identity))
        return DummyVecEnv([lambda: base_env])
    if env_name == "manydoors":
        base_env = ImgObsWrapper(ManyDoorsEnv(
            render_mode="none", position_reward=False, identity=identity))
        return DummyVecEnv([lambda: base_env])
    if env_name[-2:] == "v2":
        # all metaworld environments end in v2
        return MetaWorldSafetySpeedWrapper(
            MetaWorldSawyerEnv(env_name), identity)
    return None


def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    A ``DummyVecEnv`` returns just the observation from ``reset()``; the
    other environments are expected to return an ``(obs, info)`` pair.
    """
    result = env.reset()
    if isinstance(env, DummyVecEnv):
        return result
    obs, _info = result
    return obs


def _step_env(env, action):
    """Step ``env`` and return ``(obs, reward, done)``.

    A ``DummyVecEnv`` returns ``(obs, rewards, dones, infos)``; the other
    environments are expected to return the five-element
    ``(obs, reward, terminated, truncated, info)`` tuple, where either flag
    ends the episode.
    """
    result = env.step(action)
    if isinstance(env, DummyVecEnv):
        obs, reward, dones, _infos = result
        return obs, reward, bool(np.any(dones))
    obs, reward, terminated, truncated, _info = result
    return obs, reward, bool(np.any(terminated)) or bool(np.any(truncated))


@hydra.main(version_base=None, config_path="../config", config_name="record_demos")
def main(cfg):
    """Roll out trained PPO policies and save the trajectories as demos.

    For each identity in ``[2, 1]`` an "optimal" PPO model is loaded and
    ``cfg.num_demos`` episodes are generated with the checkpoints named
    ``{cfg.env}_{identity}_{i % 10}`` found in the policy directory. Every
    step is scored by the optimal model's ``evaluate_actions`` (value and
    log-probability of the action taken). Each episode is appended to the
    demo list as a dict with keys ``obs``, ``actions``, ``rewards``,
    ``logprobs``, ``values`` and ``identity``.

    Demos previously stored in ``demos/{cfg.env}.npy`` are loaded first, and
    the combined list is written to ``demos_{cfg.env}.npy`` inside the Hydra
    run directory.

    Config fields read: ``policy_dir``, ``env``, ``num_demos``, ``render``
    and the optional ``optimal_policy_<identity>``.
    """
    if cfg.policy_dir is None:
        # Hydra overrides are written as key=value, without a leading "--".
        print("policy_dir not specified. Please specify a policy_dir in the config file, or in the command line with policy_dir=<path>")
        print("As a default, the policy_dir will be the latest policy set trained")

        name = cfg.env
        if cfg.env[-2:] == "v2":
            name = "metaworld"

        policy_dir = f"./policies/{name}"

        # A missing directory is reported the same way as an empty one
        # instead of raising from os.listdir.
        if os.path.isdir(policy_dir):
            all_dirs = sorted(
                item for item in os.listdir(policy_dir)
                if os.path.isdir(f"{policy_dir}/{item}"))
        else:
            all_dirs = []

        if not all_dirs:
            print("Error: No policies found in ./policies/env_name")
            return

        # "Latest" means last in lexicographic order of the directory names.
        policy_dir = f"{policy_dir}/{all_dirs[-1]}"
    else:
        policy_dir = cfg.policy_dir

    print(f"using policy_dir: {policy_dir}")

    path = f"demos/{cfg.env}.npy"
    print("path: ", path)
    if os.path.isfile(path):
        print("loading old demos file")
        print("size so far: ", end=" ")
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # demo files from a trusted source.
        demos = list(np.load(path, allow_pickle=True))  # list for appending
        print(len(demos))
    else:
        print("making new file")
        demos = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_demos = cfg.num_demos  # number of demos to record
    for identity in [2, 1]:
        env = _make_env(cfg.env, identity)
        if env is None:
            print("Error: environment not recognized")
            print("Must be either twodoors, manydoors, or one of the metaworld envs")
            return

        optimal_policy_name = cfg.get(f'optimal_policy_{identity}')
        if optimal_policy_name is not None:
            optimal_model = PPO.load(policy_dir + "/" + optimal_policy_name)
        else:
            # get last model
            all_models = sorted(
                model for model in os.listdir(policy_dir)
                if f"{cfg.env}_{identity}" in model)
            if not all_models:
                print(
                    f"Error: no models matching {cfg.env}_{identity} found in {policy_dir}")
                return
            # this is the model we will use to generate synthetic advantages
            print(f"loading optimal model: {all_models[-1]}")
            optimal_model = PPO.load(
                f"{policy_dir}/{all_models[-1]}", env=env)

            # Swap in the pass-through extractor on the loaded policy.
            optimal_model.policy.features_extractor = CustomExtractor(
                env.observation_space, features_dim=35)

            optimal_model.policy.pi_features_extractor = CustomExtractor(
                env.observation_space, features_dim=35)

        best_logprob = -np.inf
        best_total_reward = -np.inf
        highest_likelihood_demo = None
        best_demo = None

        for i in range(num_demos):
            # Cycle through the checkpoints suffixed _0 .. _9.
            model_num = i % 10

            # A fresh environment is built for every demo.
            env = _make_env(cfg.env, identity)

            model = PPO.load(
                f"{policy_dir}/{cfg.env}_{identity}_{model_num}", env=env)

            obs = _reset_env(env)
            done = False

            obses = [obs]
            actions = []
            rewards = []
            logprobs = []  # logprob is unscaled advantage
            values = []

            total_logprob = 0
            total_reward = 0

            while not done:
                action, _states = model.predict(obs)

                actions.append(action)

                th_obs = torch.Tensor(obs).to(device)
                th_act = torch.Tensor(action).to(device)

                # Scoring only: no gradients are needed, and keeping them
                # would retain the autograd graph for the whole episode.
                with torch.no_grad():
                    value, log_prob, _entropy = optimal_model.policy.evaluate_actions(
                        th_obs, th_act)

                total_logprob += log_prob

                obs, reward, done = _step_env(env, action)
                obses.append(obs)
                rewards.append(reward)
                total_reward += reward
                logprobs.append(log_prob)
                values.append(value)
                if cfg.render:
                    env.render()

            if total_logprob > best_logprob:
                best_logprob = total_logprob
                highest_likelihood_demo = (
                    obses, actions, rewards, logprobs, values, identity)

            if total_reward > best_total_reward:
                best_total_reward = total_reward
                best_demo = (obses, actions, rewards,
                             logprobs, values, identity)

            logprobs = torch.tensor(logprobs).detach().cpu().numpy()
            values = torch.tensor(values).detach().cpu().numpy()

            demo = {"obs": obses, "actions": actions, "rewards": rewards,
                    "logprobs": logprobs, "values": values,  "identity": identity}

            demos.append(demo)

        # Both stay None when no demo was recorded (e.g. num_demos == 0).
        if best_demo is not None:
            print(f"best demo actions: {best_demo[1]}")
            print(f"best_demo_reward: {np.sum(best_demo[2])}")

        if highest_likelihood_demo is not None:
            print(
                f"highest likelihood demo actions: {highest_likelihood_demo[1]}")
            print(
                f"highest_likelihood_demo_reward: {np.sum(highest_likelihood_demo[2])}")

    demos = np.array(demos)
    print(f"size of demos: {demos.shape[0]}")

    # get hydra run dir for logging
    hydra_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    with open(f"{hydra_dir}/demos_{cfg.env}.npy", "wb") as f:
        np.save(f, demos, allow_pickle=True)


# Run the Hydra-decorated entry point only when executed as a script.
if __name__ == "__main__":
    main()
