import os
from argparse import ArgumentParser
from collections import deque

import cv2
import gymnasium as gym
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder

from customPolicy import CustomActorCriticPolicy
from protosac import ProtoSAC


class TerminateOnOffTrack(gym.Wrapper):
    """Preprocess image observations and end the episode when the car is off-track.

    Every RGB observation from the wrapped environment is converted to
    grayscale and resized to ``frame_size`` x ``frame_size`` (no cropping is
    performed). The most recent ``num_frames`` processed frames are stacked
    along the last axis, so observations have shape
    ``(frame_size, frame_size, num_frames)`` and dtype ``uint8``.

    Args:
        env: Environment to wrap; its observations must be RGB images that
            ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)`` accepts.
        num_frames: Number of consecutive frames to stack. Defaults to 4.
        frame_size: Side length, in pixels, of the resized square frame.
            Defaults to 64.

    Raises:
        ValueError: If ``num_frames`` or ``frame_size`` is smaller than 1.
    """

    def __init__(self, env, num_frames: int = 4, frame_size: int = 64):
        super().__init__(env)
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        if frame_size < 1:
            raise ValueError(f"frame_size must be >= 1, got {frame_size}")
        self.num_frames = num_frames
        self.frame_size = frame_size
        # Bounded deque: appending a new frame drops the oldest one.
        self.frame_stack = deque(maxlen=num_frames)
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(frame_size, frame_size, num_frames),  # grayscale frames stacked on the last axis
            dtype=np.uint8,
        )

    def step(self, action):
        """Step the wrapped environment and return the stacked observation.

        Sets ``terminated`` to ``True`` and ``info["off_track"] = True`` when
        the base environment reports a truthy ``on_grass`` attribute.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        stacked_obs = self.get_stacked_obs(self.preprocess(obs))

        # Read the flag from the base environment: since Gymnasium 1.0,
        # wrappers (including those added by ``gym.make``) do not forward
        # arbitrary attributes, so ``getattr(self.env, ...)`` would always
        # fall back to the default.
        # NOTE(review): confirm the base environment actually defines
        # ``on_grass``; if it does not, this check never fires.
        off_track = getattr(self.env.unwrapped, "on_grass", False)

        if off_track:
            terminated = True  # end the episode immediately
            info["off_track"] = True

        return stacked_obs, reward, terminated, truncated, info

    def preprocess(self, obs):
        """Convert an RGB frame to a ``frame_size`` x ``frame_size`` grayscale ``uint8`` image."""
        gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(
            gray, (self.frame_size, self.frame_size), interpolation=cv2.INTER_AREA
        )
        return resized.astype(np.uint8)

    def get_stacked_obs(self, new_frame):
        """Push ``new_frame`` onto the stack and return the stacked observation.

        When the stack is not yet full (right after a reset), the remaining
        slots are filled with copies of ``new_frame``.
        """
        self.frame_stack.append(new_frame)
        missing = self.frame_stack.maxlen - len(self.frame_stack)
        self.frame_stack.extend([new_frame] * missing)
        # Stack along the last axis to get shape (frame_size, frame_size, num_frames).
        return np.stack(self.frame_stack, axis=-1)

    def reset(self, **kwargs):
        """Reset the wrapped environment and start a fresh frame stack."""
        obs, info = self.env.reset(**kwargs)
        processed_frame = self.preprocess(obs)
        # Drop frames from the previous episode before stacking the first one.
        self.frame_stack.clear()
        stacked_obs = self.get_stacked_obs(processed_frame)
        return stacked_obs, info


# Default training budget; also the default value of the --episodes option.
TOTAL_TIMESTEPS = 30_000

# Base keyword arguments forwarded to the model constructor. The entry point
# layers per-environment values on top of these.
model_kwargs = dict(learning_rate=1e-3, verbose=1)

# Command-line index -> Gymnasium environment id.
ENVIRONMENT = dict(enumerate((
    "Pendulum-v1",
    "LunarLanderContinuous-v3",
    "MountainCarContinuous-v0",
    "HalfCheetah-v5",
    "Humanoid-v5",
    "Hopper-v5",
    "CarRacing-v3",
)))

# Human-readable "index: name" listing shown in the --environment help text.
env_options = ", ".join(f"{index}: {name}" for index, name in ENVIRONMENT.items())

def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"`` (case-insensitive).

    Returns:
        The parsed boolean. The empty string parses as ``False``.

    Raises:
        ValueError: If the string is not a recognised boolean spelling.
            ``argparse`` turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def environment_overrides(environment_id):
    """Return the hyperparameter overrides for one ``ENVIRONMENT`` index.

    Args:
        environment_id: Integer key into ``ENVIRONMENT``.

    Returns:
        A fresh dict of model keyword arguments to layer on top of the
        defaults; empty when the environment has no specific settings.
    """
    if environment_id == 1:
        # LunarLanderContinuous-v3, hyperparameters at
        # https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/sac.yml
        return {
            'learning_rate': 7.3e-4,
            'tau': 0.01,
            'learning_starts': 10_000,
            'policy_kwargs': dict(net_arch=[400, 300]),
        }
    if environment_id == 2:  # MountainCarContinuous-v0
        return {
            'learning_rate': 3e-4,
            'buffer_size': 50000,
            'batch_size': 512,
            'ent_coef': 0.1,
            'train_freq': 32,
            'gradient_steps': 32,
            'gamma': 0.9999,
            'tau': 0.01,
            'learning_starts': 0,
            'use_sde': True,
            'policy_kwargs': dict(log_std_init=-3.67, net_arch=[64, 64]),
        }
    if environment_id in (3, 4, 5):  # HalfCheetah-v5, Humanoid-v5, Hopper-v5
        return {
            'learning_rate': 3e-4,
            'learning_starts': 10000,
        }
    if environment_id == 6:  # CarRacing-v3
        return {
            'learning_rate': 3e-4,
            'buffer_size': 300000,
            'tau': 0.02,
            'train_freq': 8,
            'gradient_steps': 8,
            'learning_starts': 1000,
            'use_sde': True,
            'use_sde_at_warmup': True,
        }
    return {}


def main():
    """Parse the command line, train the selected model, and save it.

    Videos are written under ``videos/<env name>`` and the trained model under
    ``models/<env name>/<env name>``. The environment is always closed, even
    when training or saving fails.
    """
    parser = ArgumentParser()
    parser.add_argument('--environment', type=int, choices=ENVIRONMENT.keys(),
                        help=f"Choose the OpenAI Gym environment to use. {env_options}. Default: {ENVIRONMENT[0]}",
                        default=0)
    # The flag keeps its historical name, but the value is a timestep budget:
    # it is passed to ``model.learn(total_timesteps=...)``.
    parser.add_argument('--episodes', type=int,
                        help=f"The number of timesteps to train for in the given environment. Default: {TOTAL_TIMESTEPS}",
                        default=TOTAL_TIMESTEPS)
    # ``nargs='?'`` with ``const=True`` accepts both a bare ``--baseline`` and
    # an explicit ``--baseline True`` / ``--baseline False``.
    parser.add_argument('--baseline', type=str2bool, nargs='?', const=True,
                        help="Specify if you want to use the baseline. Default: False",
                        default=False)

    args = parser.parse_args()

    name_env = ENVIRONMENT[args.environment]

    # Build a per-run copy so the module-level defaults are left untouched.
    kwargs = {**model_kwargs, **environment_overrides(args.environment)}

    def make_env():
        env = gym.make(name_env, render_mode="rgb_array")

        if args.environment == 6:
            env = TerminateOnOffTrack(env)

        env = Monitor(env)  # record stats such as returns
        return env

    env = DummyVecEnv([make_env])

    video_length = 200
    # Clamp to 1 so a budget below 4 timesteps cannot cause a modulo by zero.
    video_interval = max(1, args.episodes // 4)
    env = VecVideoRecorder(
        env,
        f"videos/{name_env}",
        # Start recording ``video_length`` steps before each quarter of the budget.
        record_video_trigger=lambda step: (step + video_length) % video_interval == 0,
        video_length=video_length,
    )

    try:
        if args.baseline:
            model = SAC('MlpPolicy', env, **kwargs)
        else:
            model = ProtoSAC(CustomActorCriticPolicy, env, **kwargs)

        # start training
        model.learn(total_timesteps=args.episodes, log_interval=4)

        os.makedirs(f"models/{name_env}", exist_ok=True)
        model.save(f"models/{name_env}/{name_env}")
    finally:
        env.close()


if __name__ == "__main__":
    main()
