import collections

import torch
import gym
import numpy as np
from gym.wrappers import TimeLimit
from stable_baselines3.common.buffers import ReplayBuffer

import wandb
from omegaconf import OmegaConf
from wandb.integration.sb3 import WandbCallback

from envs.cw_envs import CwTargetEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecTransposeImage
from stable_baselines3.common.utils import set_random_seed

from ocr.slate.slate import SLATE
from rl.sac.policies import *
from rl.sac.sac import *


def load_features_extractor_state_dict(path, map_location="cpu"):
    """Return the OCR module weights stored in a checkpoint file.

    Args:
        path: Checkpoint file readable by ``torch.load``. The loaded object
            must be a mapping with an ``"ocr_module_state_dict"`` entry.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            that a checkpoint written on a GPU that does not exist on this
            machine can still be deserialized; ``Module.load_state_dict``
            copies the values onto the module's own device afterwards.

    Returns:
        The object stored under ``"ocr_module_state_dict"``.

    Raises:
        KeyError: If the checkpoint has no ``"ocr_module_state_dict"`` entry.
    """
    checkpoint = torch.load(path, map_location=map_location)
    return checkpoint["ocr_module_state_dict"]



def make_env(config, rank):
    """Return a zero-argument factory that builds one monitored environment.

    Nothing is read from ``config`` until the returned callable is invoked,
    so the lookup of ``config['env_conf_path']`` and ``config['seed']``
    happens wherever the vectorized-env wrapper calls the factory.

    Args:
        config: Mapping with at least ``'env_conf_path'`` and ``'seed'``.
        rank: Integer offset added to ``config['seed']`` for this environment.

    Returns:
        A callable that creates a ``CwTargetEnv`` wrapped in ``TimeLimit``
        and ``Monitor``.
    """

    def _thunk():
        env_cfg = OmegaConf.load(config['env_conf_path'])
        env_seed = config['seed'] + rank

        # Seed the global RNGs before the environment is constructed.
        set_random_seed(env_seed)
        base_env = CwTargetEnv(env_cfg, env_seed)
        base_env.action_space.seed(env_seed)

        # The episode cap comes from the environment's own attribute.
        max_steps = base_env.unwrapped._max_episode_length
        return Monitor(TimeLimit(base_env, max_steps))

    return _thunk


if __name__ == "__main__":
    # Hyper-parameters and paths for this run. The whole dict is also handed
    # to wandb.init() below as the run configuration.
    config = {
        'total_timesteps': 30_000_000,
        'hidden_dim': 512,
        'batch_size': 256,
        'learning_rate': 2e-4,
        'buffer_size': 1000000,
        'ent_coef': 'auto',
        'gamma': 0.99,
        'depth': 1,
        'tau': 0.005,
        'train_freq': 1,
        'gradient_steps': -1,
        'learning_starts': 5000,
        'seed': 0,
        'vec_env_cls': SubprocVecEnv,
        'n_envs': 16,
        'env_conf_path': 'envs/config/reaching-hard.yaml',
        'slate_conf_path': 'ocr/slate/config/slate.yaml',
        'slate_checkpoint_path': 'reaching-hard_slate/model_best.pth',
        'eval_freq': 625,
        'eval_n_episodes': 10,
        'use_wm_optimizer': True,
        'device': torch.device('cuda:0'),
        'target_entropy': -3,
        'project': 'goca_reaching_entropy-3-0'
    }

    features_extractor_state_dict = load_features_extractor_state_dict(
        config['slate_checkpoint_path'])

    # Training environments: one per rank, each seeded with config['seed'] + rank.
    env = config['vec_env_cls']([make_env(config, i) for i in range(config['n_envs'])])
    eval_env = None

    try:
        # Slot count and slot size of the OCR model define the shape of the
        # object-centric representation used by every network below.
        config_ocr = OmegaConf.load(config['slate_conf_path'])
        config['num_objects'] = config_ocr.slotattr.num_slots
        config['embedding_dim'] = config_ocr.slotattr.slot_size

        config_env = OmegaConf.load(config['env_conf_path'])
        features_extractor = SLATE(config_ocr, config_env, env.observation_space, preserve_slot_order=True)
        features_extractor._module.load_state_dict(features_extractor_state_dict)
        features_extractor = features_extractor.cuda(config['device'])

        transition_model = TransitionModelGNN(
            observation_space=env.observation_space,
            action_space=env.action_space,
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_objects=config['num_objects'],
            ignore_action=False,
            copy_action=True,
            use_interactions=True,
            edge_actions=True)

        reward_model = RewardModelGNN(
            observation_space=env.observation_space,
            action_space=env.action_space,
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_objects=config['num_objects'],
            ignore_action=False,
            copy_action=True,
            use_interactions=True,
            edge_actions=True)

        # Only the pretrained feature extractor is put in eval mode and frozen;
        # the transition and reward models are not touched by this loop.
        for model in (features_extractor,):
            model.eval()
            for param in model.parameters():
                param.requires_grad = False

        value_model = ValueModelGNN(
            observation_space=env.observation_space,
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_objects=config['num_objects'],
            use_interactions=True)

        # The critic is assembled from the transition, reward and value models.
        critic = ContinuousCriticWMGNN(
            transition_model,
            reward_model,
            value_model,
            action_space=env.action_space,
            n_critics=2,
            gamma=config['gamma'],
            depth=config['depth'],)

        actor = ActorGNN(
            observation_space=env.observation_space,
            action_space=env.action_space,
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            num_objects=config['num_objects'],
            use_interactions=True)

        policy = SACWMGNNPolicy(
            observation_space=env.observation_space,
            action_space=env.action_space,
            lr=config['learning_rate'],
            features_extractor=features_extractor,
            actor=actor,
            critic=critic,
        )

        # The buffer's observation space is (num_objects, embedding_dim), not
        # the environment's own space.
        # NOTE(review): presumably it stores the extractor's slot embeddings
        # instead of raw observations - confirm against SACWMGNN.
        replay_buffer = ReplayBuffer(
            buffer_size=config['buffer_size'],
            observation_space=gym.spaces.Box(
                low=-np.inf, high=np.inf, shape=(config['num_objects'], config['embedding_dim'])
            ),
            action_space=env.action_space,
            device=config['device'],
            n_envs=config['n_envs'],
            optimize_memory_usage=False,
            handle_timeout_termination=True,
        )

        run = wandb.init(
            project=config['project'],
            config=config,
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=True,  # auto-upload the videos of agents playing the game
            save_code=True,  # optional
            name='run-0'
        )

        model = SACWMGNN(
            policy=policy,
            env=env,
            learning_rate=config['learning_rate'],
            replay_buffer=replay_buffer,
            batch_size=config['batch_size'],
            target_entropy=config['target_entropy'],
            device=config['device'],
            buffer_size=config['buffer_size'],
            learning_starts=config['learning_starts'],
            gamma=config['gamma'],
            train_freq=config['train_freq'],
            gradient_steps=config['gradient_steps'],
            ent_coef=config['ent_coef'],
            verbose=1,
            tau=config['tau'],
            seed=config['seed'],
            tensorboard_log=wandb.run.dir,
        )

        # The evaluation environment uses rank n_envs, i.e. the first seed
        # offset not taken by a training environment.
        eval_env = VecTransposeImage(DummyVecEnv([make_env(config, config['n_envs'])]))

        model.learn(
            total_timesteps=config['total_timesteps'],
            log_interval=10,
            callback=[
                WandbCallback(
                    gradient_save_freq=0,
                    verbose=2,
                    model_save_freq=config['eval_freq'],
                    model_save_path=f"{wandb.run.dir}/models/",
                ),
                EvalCallback(
                    eval_env,
                    eval_freq=config['eval_freq'],
                    n_eval_episodes=config['eval_n_episodes'],
                    best_model_save_path=f"{wandb.run.dir}/models/",
                    log_path=f"{wandb.run.dir}/eval_logs/",
                    deterministic=False,
                ),
            ]
        )
        # Only reached on success, so a failed run is not reported as a
        # cleanly finished one.
        run.finish()
    finally:
        # Close the vectorized environments even when setup or training
        # raises, instead of leaving them to be torn down at interpreter exit.
        if eval_env is not None:
            eval_env.close()
        env.close()

