import collections
import os

import cv2
import gym
import numpy as np
import torch
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer

import wandb
from omegaconf import OmegaConf
from torch import nn
from wandb.integration.sb3 import WandbCallback

import envs
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecTransposeImage

from ocr.slate.slate import SLATE
from rl.sac.policies import *
from rl.sac.sac import *
from rl.wrappers.episode_recorder import EpisodeRecorder


class WarpFrame(gym.ObservationWrapper):
    """Resize every observation frame to ``(height, width)``, keeping the channel axis.

    The new observation space is a channel-last ``Box`` of shape
    ``(height, width, C)``, where ``C`` is the third dimension of the wrapped
    env's observation space. Its dtype is copied from the wrapped space, and
    its bounds are fixed to ``[0, 255]``.

    :param env: environment to wrap; its observation space must be 3-D (H, W, C).
    :param width: target frame width in pixels.
    :param height: target frame height in pixels.
    """

    def __init__(self, env: gym.Env, width: int = 64, height: int = 64):
        gym.ObservationWrapper.__init__(self, env)
        self.width = width
        self.height = height
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width, env.observation_space.shape[2]),
            dtype=env.observation_space.dtype
        )

    def observation(self, frame: np.ndarray) -> np.ndarray:
        """Return ``frame`` resized to ``(height, width, C)`` with bicubic interpolation."""
        # cv2.resize takes dsize as (width, height), the reverse of numpy's (rows, cols).
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_CUBIC)
        if frame.ndim == 2:
            # cv2.resize drops a trailing singleton channel axis; restore it so the
            # frame matches the declared (height, width, 1) observation space.
            frame = frame[:, :, np.newaxis]
        return frame


class FailOnTimelimitWrapper(gym.Wrapper):
    """Record an explicit failure for episodes that end without an outcome.

    On every step where ``done`` is true, this wrapper inserts
    ``'is_success': False`` into the step's ``info`` dict, unless the wrapped
    env already supplied an ``'is_success'`` entry. That entry is kept as-is.
    This applies to any episode end, not only time-limit endings.
    """

    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        obs, rew, is_done, step_info = super().step(action)
        if not is_done:
            return obs, rew, is_done, step_info
        # setdefault only writes when the key is absent, preserving the env's verdict.
        step_info.setdefault('is_success', False)
        return obs, rew, is_done, step_info


def load_features_extractor_state_dict(path, map_location=None):
    """Load the pretrained OCR module weights from a checkpoint file.

    :param path: path to a file written by ``torch.save``. Its top-level object
        must be a mapping with an ``"ocr_module_state_dict"`` entry.
    :param map_location: forwarded to ``torch.load``. ``None`` (the default)
        keeps ``torch.load``'s default device placement. ``"cpu"`` loads all
        tensors onto the CPU.
    :return: the object stored under ``"ocr_module_state_dict"``.
    :raises KeyError: if the checkpoint has no ``"ocr_module_state_dict"`` entry.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    try:
        return checkpoint["ocr_module_state_dict"]
    except KeyError:
        raise KeyError(
            f"checkpoint {path!r} has no 'ocr_module_state_dict' entry; "
            f"available keys: {list(checkpoint)}"
        ) from None


def make_env(config, rank):
    """Return a zero-argument factory that builds one seeded training env.

    The config is only read when the factory is called. The factory then:

    1. creates ``config['env_id']``;
    2. wraps it in ``FailOnTimelimitWrapper`` and then ``WarpFrame``, resized to
       ``config['width']`` x ``config['height']``;
    3. seeds the env and its action space with ``config['seed'] + rank``;
    4. wraps the result in ``Monitor``.

    :param config: experiment config mapping.
    :param rank: worker index; offsets the seed so parallel envs differ.
    """
    def _build_env():
        base = gym.make(config['env_id'])
        resized = WarpFrame(FailOnTimelimitWrapper(base), width=config['width'], height=config['height'])
        env_seed = config['seed'] + rank
        resized.seed(env_seed)
        resized.action_space.seed(env_seed)
        return Monitor(resized)

    return _build_env


def make_eval_env(config, rank, video_folder, record_video_trigger):
    """Return a zero-argument factory for an evaluation env that records episodes.

    The factory builds the same env as ``make_env(config, rank)`` and wraps it
    in ``EpisodeRecorder(env, video_folder, record_video_trigger)``.

    :param config: experiment config mapping (see ``make_env``).
    :param rank: seed offset for the eval env.
    :param video_folder: directory handed to ``EpisodeRecorder``.
    :param record_video_trigger: predicate handed to ``EpisodeRecorder``.
    """
    def _build_recorded_env():
        base_env = make_env(config, rank)()
        recorder = EpisodeRecorder(base_env, video_folder, record_video_trigger)
        return recorder

    return _build_recorded_env


if __name__ == "__main__":
    # Experiment configuration; the whole dict is also passed to wandb.init.
    config = {
        'total_timesteps': 30_000_000,
        'hidden_dim': 512,
        'batch_size': 128,
        'learning_rate': 2e-4,
        'buffer_size': 1000000,
        'ent_coef': 'auto',
        'gamma': 0.99,
        'tau': 0.005,
        'train_freq': 1,
        'gradient_steps': -1,
        'learning_starts': 5000,
        'seed': 0,
        'vec_env_cls': SubprocVecEnv,
        'n_envs': 16,
        'env_id': 'Navigation10x10-v0',
        'width': 96,
        'height': 96,
        'env_conf_path': 'envs/config/navigation10x10.yaml',
        'slate_conf_path': 'ocr/slate/config/navigation10x10.yaml',
        'slate_checkpoint_path': 'navigation10x10/model_best.pth',
        'eval_freq': 625,
        'eval_n_episodes': 20,
        'depth': 1,
        'use_wm_optimizer': True,
        # NOTE(review): presumably 28 is the number of discrete actions, so this
        # targets 0.98 * 0.6 of the maximum policy entropy log(28) — confirm.
        'target_entropy': 0.98 * np.log(28) * 0.6,
        'project': 'navigation10x10_sacwmgnn-slate_auto-alpha-0-6'
    }

    features_extractor_state_dict = load_features_extractor_state_dict(
        config['slate_checkpoint_path'])

    # One env factory per worker; worker i is seeded with config['seed'] + i.
    env = config['vec_env_cls']([make_env(config, i) for i in range(config['n_envs'])])

    # Slot count/size come from the SLATE config and size the GNN modules and
    # the replay buffer's observation shape below.
    config_ocr = OmegaConf.load(config['slate_conf_path'])
    config['num_objects'] = config_ocr.slotattr.num_slots
    config['embedding_dim'] = config_ocr.slotattr.slot_size

    config_env = OmegaConf.load(config['env_conf_path'])
    features_extractor = SLATE(config_ocr, config_env, env.observation_space, preserve_slot_order=False)
    features_extractor._module.load_state_dict(features_extractor_state_dict)
    features_extractor = features_extractor.cuda()

    # The pretrained SLATE extractor stays frozen: eval mode and no gradients.
    features_extractor.eval()
    for param in features_extractor.parameters():
        param.requires_grad = False

    transition_model = TransitionModelGNN(
        observation_space=env.observation_space,
        action_space=env.action_space,
        embedding_dim=config['embedding_dim'],
        hidden_dim=config['hidden_dim'],
        num_objects=config['num_objects'],
        ignore_action=False,
        copy_action=True,
        use_interactions=True,
        edge_actions=True)

    reward_model = RewardModelGNN(
        observation_space=env.observation_space,
        action_space=env.action_space,
        embedding_dim=config['embedding_dim'],
        hidden_dim=config['hidden_dim'],
        num_objects=config['num_objects'],
        ignore_action=False,
        copy_action=True,
        use_interactions=True,
        edge_actions=True)

    value_model = ValueModelGNN(
        observation_space=env.observation_space,
        embedding_dim=config['embedding_dim'],
        hidden_dim=config['hidden_dim'],
        num_objects=config['num_objects'],
        use_interactions=True)

    critic = DiscreteCriticWMGNN(
        transition_model,
        reward_model,
        value_model,
        action_space=env.action_space,
        n_critics=2,
        gamma=config['gamma'],
        depth=config['depth'],
    )

    actor = DiscreteActorGNN(
        observation_space=env.observation_space,
        action_space=env.action_space,
        embedding_dim=config['embedding_dim'],
        hidden_dim=config['hidden_dim'],
        num_objects=config['num_objects'],
        use_interactions=True)

    policy = SACCustomPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        lr=config['learning_rate'],
        features_extractor=features_extractor,
        actor=actor,
        critic=critic,
        use_wm_optimizer=config['use_wm_optimizer'],
    )

    # The buffer's observation space is the (num_objects, embedding_dim) slot
    # shape, not the image shape — presumably the algorithm stores extracted
    # slot features rather than raw frames; confirm in DiscreteSACWMGNN.
    replay_buffer = ReplayBuffer(
        buffer_size=config['buffer_size'],
        observation_space=gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(config['num_objects'], config['embedding_dim'])
        ),
        action_space=env.action_space,
        device="cuda",
        n_envs=config['n_envs'],
        optimize_memory_usage=False,
        handle_timeout_termination=True,
    )

    wandb_dir = f'./wandb/{config["project"]}'
    os.makedirs(wandb_dir, exist_ok=True)

    run = wandb.init(
        project=config['project'],
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
        dir=wandb_dir,
        name=f'run-{config["seed"]}',
    )
    run_dir = run.dir

    model = DiscreteSACWMGNN(policy=policy, env=env, learning_rate=config['learning_rate'],
                  batch_size=config['batch_size'], replay_buffer=replay_buffer,
                  buffer_size=config['buffer_size'], learning_starts=config['learning_starts'], gamma=config['gamma'],
                  train_freq=config['train_freq'], gradient_steps=config['gradient_steps'], ent_coef=config['ent_coef'],
                  verbose=1, tau=config['tau'], seed=config['seed'], target_entropy=config['target_entropy'],
                  tensorboard_log=run_dir)

    # Rank n_envs gives the eval env seed config['seed'] + n_envs, distinct from
    # every training worker. The trigger keeps indices x with
    # x % eval_n_episodes in (0, 1) (presumably the recorder's episode counter).
    eval_env = VecTransposeImage(DummyVecEnv([make_eval_env(
        config, config['n_envs'], video_folder=f"{run_dir}/videos/",
        record_video_trigger=lambda x: x % config['eval_n_episodes'] in (0, 1),
    )]))

    try:
        model.learn(
            total_timesteps=config['total_timesteps'],
            log_interval=10,
            callback=[
                WandbCallback(
                    gradient_save_freq=0,
                    verbose=2,
                    model_save_freq=config['eval_freq'],
                    model_save_path=f"{run_dir}/models/",
                ),
                EvalCallback(
                    eval_env,
                    eval_freq=config['eval_freq'],
                    n_eval_episodes=config['eval_n_episodes'],
                    best_model_save_path=f"{run_dir}/models/",
                    log_path=f"{run_dir}/eval_logs/",
                    deterministic=False,
                ),
            ]
        )
    finally:
        # Always shut down the vectorized envs (including any worker processes
        # and the EpisodeRecorder-wrapped eval env), even if training raises.
        eval_env.close()
        env.close()
    # Only reached on success; on failure wandb's own exit handling records the crash.
    run.finish()

