## This file trains a PPO policy for each preference class (environment identity) and saves periodic checkpoints

import numpy as np
import torch
import torch.nn as nn
from typing import Tuple, Callable
from stable_baselines3.common.policies import ActorCriticPolicy
import hydra

import gymnasium as gym
from gymnasium import spaces

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor 
from stable_baselines3 import PPO

from ..envs import TwoDoorsEnv, ManyDoorsEnv
from minigrid.wrappers import FullyObsWrapper, ImgObsWrapper
from gymnasium.wrappers.time_aware_observation import TimeAwareObservation
from ..envs.wrappers import OneHotPartialImage, OneHotFullImage

class MinigridFeaturesExtractor(BaseFeaturesExtractor):
    """Small CNN that turns MiniGrid image observations into a feature vector.

    Two 2x2 convolutions (16 then 64 channels, ReLU after each) are flattened
    and projected to ``features_dim`` with a Linear + ReLU head.

    Args:
        observation_space: image observation space; ``shape[0]`` is used as the
            input channel count (assumes channel-first observations by the time
            they reach the policy — TODO confirm against the wrapper stack).
        features_dim: size of the output feature vector.
        normalized_image: accepted but not used by this extractor (presumably
            kept to mirror SB3's CNN extractor signature).
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None:
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]
        conv_layers = [
            nn.Conv2d(in_channels, 16, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(16, 64, (2, 2)),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # The flattened width depends on the grid size, so probe the conv stack
        # with one sampled observation (batch of 1) instead of hard-coding it.
        probe = torch.as_tensor(observation_space.sample()[None]).float()
        with torch.no_grad():
            flat_dim = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_dim, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Run the conv stack, then the linear head, on a batch of observations."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)

def _make_img_env(env_cls, identity, position_reward=False):
    """Instantiate ``env_cls`` for ``identity`` and keep only image observations.

    Shared by the per-environment factories so both apply the same wrapping and
    print the same diagnostics.

    Args:
        env_cls: environment class constructed with ``identity`` and
            ``position_reward`` keyword arguments.
        identity: preference class / identity forwarded to the environment.
        position_reward: forwarded to the environment constructor.

    Returns:
        The environment wrapped in ``ImgObsWrapper``.
    """
    env = env_cls(identity=identity, position_reward=position_reward)
    # Use the image as the sole observation; this is what the CNN features
    # extractor used for training consumes.
    env = ImgObsWrapper(env)

    print(f"env observation space: {env.observation_space}")
    print(f"env action space: {env.action_space.n}")
    return env


def make_custom_manydoors(identity, position_reward=False):
    """Return a ``ManyDoorsEnv`` for ``identity`` with image-only observations."""
    return _make_img_env(ManyDoorsEnv, identity, position_reward=position_reward)


def make_custom_twodoors(identity, position_reward=False):
    """Return a ``TwoDoorsEnv`` for ``identity`` with image-only observations."""
    return _make_img_env(TwoDoorsEnv, identity, position_reward=position_reward)


@hydra.main(config_path="../config", config_name="train_policies")
def main(cfg):
    """Train one PPO policy per identity and save a checkpoint after each chunk.

    For each identity in ``[1, 2]``:
    - build a vectorized env of ``cfg.num_parallel_envs`` copies;
    - create a fresh PPO model, or load ``models/{cfg.env}_{identity}_29`` when
      pretraining;
    - run ``cfg.num_policies`` successive ``learn`` calls of ``cfg.timesteps``
      steps, saving the model as ``{cfg.env}_{identity}_{i}`` after each one.

    Config keys read: ``env``, ``num_parallel_envs``, ``features_dim``,
    ``num_policies``, ``timesteps``, and optionally ``pretrain`` (default False).
    """
    print(torch.cuda.is_available())
    # No explicit torch.device is needed: PPO selects its device itself
    # (device="auto" by default).

    # Formerly a hard-coded local; now overridable from the config while
    # keeping the old default.
    pretrain = cfg.get("pretrain", False)

    env_fn = make_custom_manydoors if cfg.env == "manydoors" else make_custom_twodoors

    # Identical for every identity, so build it once outside the loop.
    policy_kwargs = dict(
        features_extractor_class=MinigridFeaturesExtractor,
        features_extractor_kwargs=dict(features_dim=cfg.features_dim),
    )

    for identity in [1, 2]:
        # Pass identity through env_kwargs rather than a lambda closing over the
        # loop variable, so every env is bound to this iteration's identity no
        # matter when make_vec_env invokes the factory.
        env = make_vec_env(env_fn, n_envs=cfg.num_parallel_envs, env_kwargs={"identity": identity})
        try:
            if pretrain:
                print(f"loading model for identity {identity}")
                # NOTE(review): checkpoints are loaded from models/ but saved
                # below without that prefix — confirm the intended layout.
                model = PPO.load(f"models/{cfg.env}_{identity}_29", env)
            else:
                model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)

            for i in range(cfg.num_policies):
                # NOTE(review): learn() resets its timestep counter on each call
                # by default; pass reset_num_timesteps=False if these chunks are
                # meant to form one continuous run.
                model.learn(total_timesteps=cfg.timesteps)
                model.save(f"{cfg.env}_{identity}_{i}")
        finally:
            # Release the vectorized environments before the next identity.
            env.close()

if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator composes cfg from
    # ../config/train_policies (plus any command-line overrides) and passes it in.
    main()