import numpy as np
import torch
import wandb

from algorithms.dreamer import (
    Dreamer,
    InvariantDreamer,
)
from common.logger import Video
from common.utils import to_torch, to_np, preprocess, postprocess
from environments import make_env
from setup import set_seed, set_device, setup_logger
from adapt_dreamer import get_config


def summarize_episodes(returns, successes):
    """Aggregate per-episode evaluation results into summary statistics.

    Args:
        returns: Sequence with one return value per evaluated episode.
        successes: Sequence with one success value per evaluated episode.
            An episode counts as successful when its value is greater than
            zero, so both 0/1 flags and accumulated per-step success counts
            are accepted.

    Returns:
        Dict with the keys ``return_mean``, ``return_max``, ``return_std``
        and ``success_rate`` (in that order). ``success_rate`` is the
        fraction of successful episodes and therefore lies in [0, 1].

    Raises:
        ValueError: If ``returns`` is empty.
    """
    returns = np.asarray(returns)
    if returns.size == 0:
        # max() of an empty array would fail with an opaque NumPy error.
        raise ValueError("Cannot summarize an evaluation with zero episodes")
    # Binarize per episode so the rate matches the "test/success" metric
    # instead of averaging accumulated per-step success counts.
    succeeded = np.asarray(successes) > 0
    return {
        "return_mean": returns.mean(),
        "return_max": returns.max(),
        "return_std": returns.std(),
        "success_rate": succeeded.mean(),
    }


if __name__ == "__main__":
    config = get_config()
    set_seed(config.seed)
    set_device(config.use_gpu)

    # Logger
    logger = setup_logger(config)

    # Environment
    # NOTE: eval_env is only handed to the algorithm constructor; the
    # rollouts below step `env`.
    env = make_env(config.env_id, config.seed, config.pixel_obs)
    eval_env = make_env(config.env_id, config.seed, config.pixel_obs)

    # Load agent
    if config.algo == "dreamer":
        algo = Dreamer(config, env, eval_env, logger)
    elif config.algo == "dreamer_invariant":
        algo = InvariantDreamer(config, env, eval_env, logger)
    else:
        raise NotImplementedError("Unsupported algorithm")
    algo.load_checkpoint(config.source_dir)

    # Evaluate agent
    algo.toggle_train(False)
    returns, successes = [], []
    for i in range(config.eval_episodes):
        # Fresh latent state and previous action at the start of each episode.
        belief, posterior_state, action_tensor = algo.init_latent_and_action()
        obs = env.reset()
        done = False
        episode_reward = 0
        episode_success = 0
        frames = []
        with torch.no_grad():
            while not done:
                # obs[None] adds a leading batch axis of size 1.
                obs_tensor = to_torch(preprocess(obs[None]))
                # NOTE(review): the trailing False presumably disables
                # exploration noise -- confirm against the algorithm class.
                (
                    belief,
                    posterior_state,
                    action_tensor,
                ) = algo.update_latent_and_select_action(
                    belief, posterior_state, action_tensor, obs_tensor, False
                )
                action = to_np(action_tensor)[0]
                next_obs, reward, done, info = env.step(action)
                if config.pixel_obs:
                    # Pair the observation fed to the agent with the
                    # observation model's output for the updated latent.
                    obs_hat = to_np(algo.obs_model(belief, posterior_state))
                    obs_hat = postprocess(obs_hat)[0]
                    frames.append([obs, obs_hat])
                obs = next_obs
                episode_reward += reward
                # Accumulates over steps, so it may exceed 1; it is treated
                # as "succeeded at least once" wherever it is reported.
                episode_success += info.get("success", 0)

        # Log statistics
        logger.record("test/return", episode_reward)
        logger.record("test/success", float(episode_success > 0))
        if config.pixel_obs:
            # video shape: (T, N, C, H, W) -> (N, T, C, H, W)
            video = Video(np.stack(frames).transpose(1, 0, 2, 3, 4), fps=30)
            logger.record("test/video", video, exclude="stdout")
        logger.dump(i)

        # Record episode statistics
        returns.append(episode_reward)
        successes.append(episode_success)

    # Record wandb summary
    for key, value in summarize_episodes(returns, successes).items():
        wandb.run.summary[key] = value
