#!/usr/bin/env python3

import torch
from .utils import to_torch

@torch.no_grad()
def evaluate_invm(agent, latent_replay, distr_replay):
    """Log the agent's inverse-dynamics loss on two replay sources.

    Pass 1 encodes observation batches drawn from ``distr_replay`` with
    ``agent.encode`` and logs the loss under the ``distr_env`` prefix.
    Pass 2 draws already-encoded latents from ``latent_replay``, averages
    them over dim 1, and logs the loss under the ``clean_env`` prefix.

    Both passes run ``512 // batch_size`` iterations, where ``batch_size`` is
    the length of the first element of the first ``distr_replay`` batch.
    ``distr_replay.iterator`` is advanced once more than that: the batch
    fetched at the end of the last iteration is never used.

    Args:
        agent: provides ``device``, ``encode(obs, to_numpy=False)`` and
            ``get_invm_loss(latent, next_latent, action)``.
        latent_replay: object whose ``iterator`` yields
            ``(latent, action, reward, discount, next_latent)`` batches.
        distr_replay: object whose ``iterator`` yields
            ``(obs, action, reward, discount, next_obs)`` batches.
    """
    from ml_logger import logger

    def encode(observation):
        # Keep the result as a tensor so it can be fed to the loss directly.
        return agent.encode(observation, to_numpy=False)

    n_samples = 512
    batch = next(distr_replay.iterator)
    batch_size = len(batch[0])
    logger.print('batch size =================> ', batch_size)
    num_batches = n_samples // batch_size

    # Pass 1: raw observations from the distr replay, encoded on the fly.
    for _ in range(num_batches):
        obs, action, _reward, _discount, next_obs = to_torch(batch, agent.device)
        latent = encode(obs)
        next_latent = encode(next_obs)
        with logger.Prefix(metrics="distr_env"):
            logger.log(invm_loss=agent.get_invm_loss(latent, next_latent, action).item())
        # Prefetch for the next iteration (also runs after the last one).
        batch = next(distr_replay.iterator)

    # Pass 2: pre-computed latents from the clean replay.
    for _ in range(num_batches):
        stored = to_torch(next(latent_replay.iterator), agent.device)
        latent, action, _reward, _discount, next_latent = stored

        # The unpacking doubles as a rank check: it fails unless latent is 3-D.
        # NOTE(review): expected shape is presumably (batch_size, 4, 1024),
        # i.e. several augmented latents per transition -- confirm.
        _num_rows, _num_augments, _latent_dim = latent.shape

        # Collapse dim 1 by averaging (not by sampling a single entry).
        latent = latent.mean(dim=1)
        next_latent = next_latent.mean(dim=1)

        with logger.Prefix(metrics="clean_env"):
            logger.log(invm_loss=agent.get_invm_loss(latent, next_latent, action).item())



def eval(env, agent, global_step,
         num_eval_episodes, action_repeat,
         to_video=None, stochastic_video=None):
    """Roll out ``agent`` in ``env`` and log mean episode reward and length.

    Each episode runs on a fresh ``deepcopy`` of ``agent`` acting with
    ``eval_mode=True``. Frames of the first episode are saved to ``to_video``
    when it is given. If ``stochastic_video`` is given, one extra episode is
    rolled out with ``eval_mode=False`` and saved there; that rollout does not
    contribute to the logged metrics.

    NOTE: the function name shadows the builtin ``eval``; it is kept because
    callers depend on it.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` returning
            ``(next_obs, reward, done, info)``, and ``physics.render(...)``.
        agent: agent exposing ``act(obs, step, eval_mode=...)``.
        global_step: step value forwarded to ``agent.act``.
        num_eval_episodes: number of episodes to average over (must be >= 1).
        action_repeat: multiplier applied to the step count when logging
            ``episode_length``.
        to_video: optional path for the video of the first episode.
        stochastic_video: optional path for the ``eval_mode=False`` video.

    Raises:
        ValueError: if ``num_eval_episodes`` is smaller than 1.
    """
    from copy import deepcopy
    from ml_logger import logger
    from . import utils

    if num_eval_episodes < 1:
        # Without at least one episode there is nothing to average, and the
        # stochastic rollout below would have no agent copy to use.
        raise ValueError(f"num_eval_episodes must be >= 1, got {num_eval_episodes}")

    size = 64  # height and width of the rendered frames
    step, total_reward = 0, 0
    for episode in range(num_eval_episodes):
        eval_agent = deepcopy(agent)  # make a new copy
        obs = env.reset()
        frames = []
        done = False
        record = episode == 0 and to_video
        while not done:
            if record:
                # todo: use gym.env.render('rgb_array') instead
                frames.append(env.physics.render(height=size, width=size, camera_id=0))

            with torch.no_grad(), utils.eval_mode(eval_agent):
                # todo: remove global_step, replace with random-on, passed-in.
                action = eval_agent.act(obs, global_step, eval_mode=True)
            obs, reward, done, info = env.step(action)

            total_reward += reward
            step += 1

        if record:
            # Append the last frame
            frames.append(env.physics.render(height=size, width=size, camera_id=0))
            logger.save_video(frames, to_video)

    # Save video using stochastic agent (eval_mode is set to False).
    # Uses the agent copy left over from the last evaluation episode.
    if stochastic_video:
        obs = env.reset()
        frames = []
        done = False
        while not done:
            frames.append(env.physics.render(height=size, width=size, camera_id=0))
            with torch.no_grad():
                action = eval_agent.act(obs, global_step, eval_mode=False)
            obs, reward, done, info = env.step(action)
        frames.append(env.physics.render(height=size, width=size, camera_id=0))
        logger.save_video(frames, stochastic_video)

    # Average over the number of episodes actually run. The loop index ends
    # at num_eval_episodes - 1, so dividing by it would inflate the metrics
    # and divide by zero for a single episode.
    logger.log(episode_reward=total_reward / num_eval_episodes,
               episode_length=step * action_repeat / num_eval_episodes)


def evaluate(agent, clean_eval_env, distr_eval_env,
             latent_replay, distr_replay, action_repeat, num_eval_episodes, progress, expl_agent=None, eval_invm=True):
    """Run the full evaluation suite and flush the logged metrics.

    Steps, in order:

    1. If ``eval_invm`` is set, log inverse-dynamics losses via
       ``evaluate_invm`` (using ``expl_agent`` when provided, else ``agent``).
    2. If ``expl_agent`` is provided, temporarily load its encoder weights
       into ``agent.encoder``.
    3. Evaluate ``agent`` on ``clean_eval_env`` (prefix ``clean_eval``) and on
       ``distr_eval_env`` (prefix ``eval``).
    4. If ``expl_agent`` is provided, evaluate it on ``distr_eval_env``
       (prefix ``eval_expl``) and restore the original encoder weights of
       ``agent``, even if an evaluation step raised.
    5. Log the attributes of ``progress`` and flush the logger.

    Args:
        agent: agent to evaluate; must expose ``encoder`` when ``expl_agent``
            is provided.
        clean_eval_env: environment for the ``clean_eval`` rollouts.
        distr_eval_env: environment for the ``eval`` / ``eval_expl`` rollouts.
        latent_replay: replay source forwarded to ``evaluate_invm``.
        distr_replay: replay source forwarded to ``evaluate_invm``.
        action_repeat: forwarded to ``eval`` for episode-length logging.
        num_eval_episodes: number of episodes per ``eval`` call.
        progress: object with a ``step`` attribute; ``vars(progress)`` is
            logged at the end.
        expl_agent: optional second agent whose encoder is swapped into
            ``agent`` and which is itself evaluated on ``distr_eval_env``.
        eval_invm: whether to run the inverse-dynamics evaluation.
    """
    from copy import deepcopy
    from ml_logger import logger

    use_expl = expl_agent is not None

    if eval_invm:
        # Evaluate inv_dynamics head on clean_env latent
        # NOTE: only use encoder
        logger.print('evaluating invm')
        with logger.Prefix(metrics="eval"):
            evaluate_invm(expl_agent if use_expl else agent, latent_replay, distr_replay)

    org_enc_state = None
    if use_expl:
        # state_dict() returns references to the live parameters and
        # load_state_dict() copies into them in place, so an un-copied
        # snapshot would be overwritten by the swap below and the later
        # restore would do nothing.
        org_enc_state = deepcopy(agent.encoder.state_dict())

    try:
        if use_expl:
            # Replace agent encoder weights
            agent.encoder.load_state_dict(expl_agent.encoder.state_dict())

        # Evaluation on clean env
        with logger.Prefix(metrics="clean_eval"):
            logger.print('evaluating clean env')
            path = f'videos/clean/{progress.step:09d}_eval.mp4'
            eval(clean_eval_env, agent, progress.step, num_eval_episodes,
                 action_repeat, to_video=path,
                 stochastic_video=None)

        # Evaluation on distr env
        with logger.Prefix(metrics="eval"):
            logger.print('evaluating distr env')
            path = f'videos/distr/{progress.step:09d}_eval.mp4'
            eval(distr_eval_env, agent, progress.step, num_eval_episodes,
                 action_repeat, to_video=path,
                 stochastic_video=None)

        if use_expl:
            with logger.Prefix(metrics="eval_expl"):
                # NOTE(review): the message says "original policy" although
                # expl_agent is the one being rolled out -- confirm wording.
                logger.print('evaluating distr env with original policy')
                path = f'videos/distr-with-expl-policy/{progress.step:09d}_eval.mp4'
                # num_eval_episodes and action_repeat are required positional
                # parameters of eval().
                eval(distr_eval_env, expl_agent, progress.step, num_eval_episodes,
                     action_repeat, to_video=path,
                     stochastic_video=None)
    finally:
        if use_expl:
            # Put back the original encoder weights
            agent.encoder.load_state_dict(org_enc_state)

    logger.log(**vars(progress))
    logger.flush()
