import collections
import os

import cv2
import gym
import numpy as np
from gym import spaces
import torch

import wandb

import envs.register_shapes2d
from ocr.features_extractor import CSWMSlotExtractor
from rl.callbacks import EvalCallback, WandbCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecTransposeImage

from rl.sac.policies import *
from rl.sac.sac import *
from rl.wrappers.episode_recorder import EpisodeRecorder


class FailOnTimelimitWrapper(gym.Wrapper):
    """Report an episode as failed when it ends without an explicit success flag.

    On every terminal step (``done`` truthy) whose ``info`` dict lacks an
    ``'is_success'`` entry, ``'is_success'`` is set to ``False``. A value
    already provided by the wrapped env is left untouched. Note that this
    applies to *any* terminal step without the key, not only time-limit
    truncations.
    """

    def __init__(self, env):
        # Nothing extra to set up; kept for an explicit constructor signature.
        super().__init__(env)

    def step(self, action):
        """Step the wrapped env, defaulting ``info['is_success']`` on episode end."""
        obs, rew, episode_over, step_info = super().step(action)
        if episode_over:
            # Only fill in a default; never overwrite a flag the env reported.
            step_info.setdefault('is_success', False)
        return obs, rew, episode_over, step_info


def make_env(config, rank):
    """Return a zero-argument factory that builds one seeded training env.

    The returned callable (suitable for a VecEnv constructor) creates
    ``config['env_id']`` via ``gym.make``, wraps it in
    ``FailOnTimelimitWrapper``, seeds both the env and its action space with
    ``config['seed'] + rank``, and finally wraps it in SB3's ``Monitor``.

    Args:
        config: mapping with at least ``'env_id'`` and ``'seed'`` keys.
        rank: integer offset added to the base seed so parallel envs differ.
    """

    def _build_env():
        wrapped = FailOnTimelimitWrapper(gym.make(config['env_id']))
        env_seed = config['seed'] + rank
        wrapped.seed(env_seed)
        wrapped.action_space.seed(env_seed)
        # Monitor goes outermost so it sees the info produced by the wrapper.
        return Monitor(wrapped)

    return _build_env


def make_eval_env(config, rank, video_folder, record_video_trigger):
    """Return a zero-argument factory that builds an evaluation env with recording.

    The env is built exactly like a training env (``make_env(config, rank)``)
    and then wrapped in ``EpisodeRecorder``, writing into ``video_folder`` for
    the episodes selected by ``record_video_trigger``.

    Args:
        config: same mapping accepted by ``make_env``.
        rank: seed offset for this env.
        video_folder: destination directory handed to ``EpisodeRecorder``.
        record_video_trigger: callable handed to ``EpisodeRecorder``; presumably
            maps an episode index to whether that episode is recorded — confirm
            against ``EpisodeRecorder``.
    """

    def _build_eval_env():
        monitored_env = make_env(config, rank)()
        return EpisodeRecorder(
            monitored_env,
            video_folder,
            record_video_trigger,
            # NOTE(review): presumably prunes recordings older than 12 hours,
            # checking once per hour — verify in EpisodeRecorder.
            remove_older_seconds=12 * 60 * 60,
            remove_job_interval_seconds=1 * 60 * 60,
        )

    return _build_eval_env


if __name__ == "__main__":
    # Experiment configuration; the whole dict is also logged to wandb below.
    config = {
        'total_timesteps': 30_000_000,
        'hidden_dim': 512,
        'batch_size': 128,
        'edge_actions': True,
        'learning_rate': 2e-4,
        'buffer_size': 1000000,
        'ent_coef': 'auto',
        'gamma': 0.99,
        'tau': 0.005,
        'train_freq': 1,
        'gradient_steps': -1,
        'learning_starts': 5000,
        'seed': 0,
        'vec_env_cls': SubprocVecEnv,
        'n_envs': 16,
        'env_id': 'PushingNoAgent5x5-v0',
        'eval_freq': 625,
        'eval_n_episodes': 30,
        'depth': 1,
        'use_wm_optimizer': True,
        'target_entropy': 0.98 * np.log(16) * 0.6,
        'num_objects': 5,
        'embedding_dim': 64,
        'cnn_hidden_dim': 32,
        'mlp_hidden_dim': 128,
        'hinge': 0.025,
        'coef': 1,
        'project': 'pushing-no-agent5x5_roca-cswm_ent-0-6',
    }

    torch.set_float32_matmul_precision('medium')

    wandb_dir = f'./wandb/{config["project"]}'
    # Create the output directory before any env is built: the eval env below
    # is pointed at a video folder underneath it.
    os.makedirs(wandb_dir, exist_ok=True)

    # Training envs use ranks 0..n_envs-1, i.e. seeds seed..seed+n_envs-1.
    env = config['vec_env_cls']([make_env(config, i) for i in range(config['n_envs'])])
    # A single eval env with rank n_envs, so its seed differs from every
    # training env. Video recording is triggered when the episode index modulo
    # eval_n_episodes is 0 or 1 (presumably the first two episodes of each
    # evaluation round).
    eval_env = VecTransposeImage(DummyVecEnv([make_eval_env(
        config, config['n_envs'], video_folder=f"{wandb_dir}/files/media/videos", record_video_trigger=lambda x: x % config['eval_n_episodes'] in (0, 1)
    )]))

    try:
        # The extractor is built from the *transposed* eval observation space.
        # NOTE(review): presumably channel-first to match the VecTransposeImage
        # wrapping SB3 applies to the image training env internally — confirm.
        features_extractor = CSWMSlotExtractor(observation_space=eval_env.observation_space, num_slots=config['num_objects'],
                                               slot_size=config['embedding_dim'], cnn_hidden_dim=config['cnn_hidden_dim'],
                                               mlp_hidden_dim=config['mlp_hidden_dim'])

        transition_model = TransitionModelGNN(
            observation_space=env.observation_space,
            action_space=env.action_space,
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_objects=config['num_objects'],
            ignore_action=False,
            copy_action=True,
            use_interactions=True,
            edge_actions=config['edge_actions'])

        reward_model = RewardModelGNN(
            observation_space=env.observation_space,
            action_space=env.action_space,
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_objects=config['num_objects'],
            ignore_action=False,
            copy_action=True,
            use_interactions=True,
            edge_actions=config['edge_actions']
        )

        value_model = ValueModelGNN(
            observation_space=env.observation_space,
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_objects=config['num_objects'],
            use_interactions=True)

        # Critic composed from the transition, reward and value models
        # (n_critics=2, rollout depth from config).
        critic = DiscreteCriticWMGNN(
            transition_model,
            reward_model,
            value_model,
            action_space=env.action_space,
            n_critics=2,
            gamma=config['gamma'],
            depth=config['depth'],
        )

        actor = DiscreteActorGNN(
            observation_space=env.observation_space,
            action_space=env.action_space,
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_objects=config['num_objects'],
            use_interactions=True)

        policy = SACCustomPolicy(
            observation_space=env.observation_space,
            action_space=env.action_space,
            lr=config['learning_rate'],
            features_extractor=features_extractor,
            actor=actor,
            critic=critic,
            use_wm_optimizer=config['use_wm_optimizer'],
            is_frozen_features_extractor=False,
        )

        run = wandb.init(
            project=config['project'],
            config=config,
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=True,  # auto-upload the videos of agents playing the game
            save_code=True,  # optional
            dir=wandb_dir,
            name=f'run_h-{config["hinge"]}_{config["seed"]}',
        )
        run_dir = run.dir

        model = DiscreteSACWMGNN(policy=policy, env=env, learning_rate=config['learning_rate'],
                                 batch_size=config['batch_size'], replay_buffer=None,
                                 buffer_size=config['buffer_size'], learning_starts=config['learning_starts'], gamma=config['gamma'],
                                 train_freq=config['train_freq'], gradient_steps=config['gradient_steps'], ent_coef=config['ent_coef'],
                                 verbose=1, tau=config['tau'], seed=config['seed'], target_entropy=config['target_entropy'],
                                 tensorboard_log=run_dir,
                                 # Only the contrastive-loss settings are forwarded.
                                 contrastive_learning_kwargs={key: value for key, value in config.items() if key in ('hinge', 'coef')})

        model.learn(
            total_timesteps=config['total_timesteps'],
            log_interval=10,
            callback=[
                WandbCallback(
                    gradient_save_freq=0,
                    verbose=2,
                    model_save_freq=config['eval_freq'],
                    model_save_path=f"{run_dir}/models/",
                ),
                EvalCallback(
                    eval_env,
                    eval_freq=config['eval_freq'],
                    n_eval_episodes=config['eval_n_episodes'],
                    best_model_save_path=f"{run_dir}/models/",
                    log_path=f"{run_dir}/eval_logs/",
                    deterministic=False,
                ),
            ]
        )
    finally:
        # Always release env resources (SubprocVecEnv worker processes, the
        # recorder-wrapped eval env), even if setup or training raises.
        env.close()
        eval_env.close()

    # Outside the `finally` on purpose: on an exception, wandb's own exit
    # handling records the run as failed instead of as cleanly finished.
    run.finish()

