import logging
import time

import torch
from torch.distributions import Normal, kl_divergence, Categorical
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

import numpy as np

from .config import BaseConfig
from .replay_memory import ReplayMemory, BatchOutput
from .test import test
from .utils import FreezeParameters
from torch.nn import MSELoss, CrossEntropyLoss, L1Loss

from core.env import EnvBatcher
from .utils import imagine_ahead, lambda_return, bottle, root_search
from .mcts import MCTS, Node
from .model import ActorOutput

# Module-level loggers. `train_logger` receives the per-update training summary
# line; `test_logger` receives evaluation summaries (registered under the
# logger name 'train_eval').
train_logger = logging.getLogger('train')
test_logger = logging.getLogger('train_eval')


def get_dyn_update_itr(config, env_steps):
    """Return the annealed number of dynamics updates for ``env_steps``.

    The count starts at ``config.max_dynamics_update_itr`` and decreases
    linearly with ``env_steps``; it reaches ``config.min_dynamics_update_itr``
    after roughly 10% of ``config.max_env_steps`` and never drops below it.
    The result is truncated to an ``int``.
    """
    lower = config.min_dynamics_update_itr
    upper = config.max_dynamics_update_itr

    # Fraction of the annealing window that has elapsed; the small epsilon
    # keeps the divisor non-zero when max_env_steps is 0.
    progress = env_steps / (0.1 * config.max_env_steps + 1e-5)

    annealed = upper - ((upper - lower) * progress)
    if annealed < lower:
        annealed = lower
    return int(annealed)


def get_beh_update_itr(config, env_steps):
    """Return the annealed number of behaviour updates for ``env_steps``.

    The count starts at ``config.min_behaviour_update_itr`` and increases
    linearly with ``env_steps``; it reaches ``config.max_behaviour_update_itr``
    after roughly 10% of ``config.max_env_steps`` and never exceeds it.
    The result is truncated to an ``int``.
    """
    lower = config.min_behaviour_update_itr
    upper = config.max_behaviour_update_itr

    # Fraction of the annealing window that has elapsed; the small epsilon
    # keeps the divisor non-zero when max_env_steps is 0.
    progress = env_steps / (0.1 * config.max_env_steps + 1e-5)

    annealed = lower + ((upper - lower) * progress)
    if upper < annealed:
        annealed = upper
    return int(annealed)


def _update_dynamics(batch, model, dynamics_optimizer, free_nats, global_prior, config):
    """Run one optimisation step of the world model on ``batch``.

    The transition model is rolled over the batch starting from the model's
    initial belief/state, observations and rewards are reconstructed from the
    resulting beliefs and posterior states, and the sum of the observation,
    reward and KL losses is minimised with ``dynamics_optimizer``.

    Args:
        batch: sampled replay data exposing ``obs``, ``action``,
            ``action_repeat_prob``, ``reward`` and ``non_terminal``.
            Indexed along dim 0 with ``[:-1]`` / ``[1:]`` (presumably the
            time axis -- confirm against ``ReplayMemory.sample``).
        model: network bundle providing ``encoder``, ``transition``,
            ``observation`` and ``reward`` modules.
        dynamics_optimizer: optimizer stepped once by this call.
        free_nats: tensor used as an element-wise lower bound on the summed KL.
        global_prior: distribution used for the optional extra KL term.
        config: run configuration (``batch_size``, ``device``,
            ``global_kl_beta``, ``grad_clip_norm``).

    Returns:
        Tuple ``(transition_output, target_rewards, predicted_rewards,
        obs_loss, reward_loss, kl_loss, elapsed_seconds)``; the two reward
        entries are flat Python lists and the three losses are floats.
    """
    started_at = time.time()

    # Roll the transition model forward from the initial belief/state.
    belief0 = model.init_belief(config.batch_size).to(config.device)
    state0 = model.init_state(config.batch_size).to(config.device)
    encoded_obs = bottle(model.encoder, (batch.obs[1:],))
    transition_output = model.transition(belief0, state0,
                                         batch.action[:-1], batch.action_repeat_prob[:-1],
                                         encoded_obs, batch.non_terminal[:-1])

    # Decode observations and rewards from (belief, posterior state).
    latent = (transition_output.belief, transition_output.posterior_state)
    predicted_obs = bottle(model.observation, latent)
    predicted_reward = bottle(model.reward, latent)

    posterior = Normal(transition_output.posterior_mean, transition_output.posterior_std_dev)
    prior = Normal(transition_output.prior_mean, transition_output.prior_std_dev)

    # Losses: per-element errors are summed over dim 2 (observations only)
    # and averaged over dims 0 and 1.
    squared_error = MSELoss(reduction='none')
    obs_loss = squared_error(predicted_obs, batch.obs[1:]).sum(dim=2).mean(dim=(0, 1))
    reward_loss = squared_error(predicted_reward.squeeze(-1), batch.reward[:-1]).mean(dim=(0, 1))

    # KL between posterior and prior, floored element-wise at `free_nats`.
    kl_per_step = kl_divergence(posterior, prior).sum(dim=2)
    kl_loss = torch.max(kl_per_step, free_nats).mean(dim=(0, 1))
    if config.global_kl_beta != 0:
        # Optional additional pull of the posterior towards `global_prior`.
        global_kl = kl_divergence(posterior, global_prior).sum(dim=2).mean(dim=(0, 1))
        kl_loss += config.global_kl_beta * global_kl

    # Optimise; every sub-module is clipped separately to `grad_clip_norm`.
    dynamics_optimizer.zero_grad()
    (obs_loss + reward_loss + kl_loss).backward()
    for module in (model.transition, model.reward, model.observation, model.encoder):
        torch.nn.utils.clip_grad_norm_(module.parameters(), config.grad_clip_norm, norm_type=2)
    dynamics_optimizer.step()

    # NOTE(review): the target list is built from the full `batch.reward`
    # while the loss (and the prediction list) use `batch.reward[:-1]`, so the
    # two lists may differ in length -- confirm this is what the logging expects.
    target_rewards = batch.reward.detach().cpu().flatten().numpy().tolist()
    predicted_rewards = predicted_reward.detach().cpu().flatten().numpy().tolist()

    return (transition_output, target_rewards, predicted_rewards,
            obs_loss.item(), reward_loss.item(), kl_loss.item(),
            time.time() - started_at)


def _update_behaviour(transition_output, model, actor_optimizer, value_optimizer, config):
    """Run one actor update followed by one value update on imagined rollouts.

    Imagination starts from the (detached) posterior states and beliefs in
    ``transition_output``. The actor is trained to maximise the lambda-returns
    of the imagined trajectories (optionally with entropy terms and a
    search-based cross-entropy term); the value network is then regressed
    onto the detached returns.

    Args:
        transition_output: world-model rollout exposing ``posterior_state``
            and ``belief``.
        model: network bundle providing ``actor``, ``actor_repeat``,
            ``value``, ``reward``, ``transition``, ``encoder`` and
            ``observation`` modules.
        actor_optimizer: optimizer stepped once for the actor loss.
        value_optimizer: optimizer stepped once for the value loss.
        config: run configuration. When ``automatic_entropy_tuning`` is set
            (and ``explore_mode`` is 'mcts' or 'rollout'), ``config.alpha`` is
            overwritten in place after the temperature update.

    Returns:
        Tuple ``(target_returns, actor_loss, actor_repeat_entropy,
        actor_repeat_entropy_loss, value_loss, elapsed_seconds)``;
        ``target_returns`` is a flat Python list, the rest are floats.
    """
    agent_update_start_time = time.time()

    # Detach so that behaviour gradients never reach the world-model rollout.
    actor_states = transition_output.posterior_state.detach()
    actor_beliefs = transition_output.belief.detach()

    with FreezeParameters([model.transition, model.encoder, model.reward, model.observation]):
        imagination_output = imagine_ahead(actor_states, actor_beliefs, model, config.planning_horizon)
    with FreezeParameters([model.transition, model.encoder, model.reward, model.observation]):
        with FreezeParameters([model.value]):
            imged_reward = bottle(model.reward, (imagination_output.belief, imagination_output.prior_state))
            value_pred = bottle(model.value, (imagination_output.belief, imagination_output.prior_state))

    # Entropy term handed to lambda_return: scaled action entropy in the
    # exploring modes, zeros otherwise.
    if config.explore_mode in ['mcts', 'rollout']:
        if config.automatic_entropy_tuning:
            _img_action_entropy = config.alpha * imagination_output.actions_entropy.unsqueeze(-1)
        else:
            _img_action_entropy = config.actor_entropy_coeff * imagination_output.actions_entropy.unsqueeze(-1)
    else:
        _img_action_entropy = torch.zeros_like(imged_reward)

    returns = lambda_return(imged_reward, value_pred, _img_action_entropy,
                            action_repeats=imagination_output.action_repeats.unsqueeze(2),
                            bootstrap=value_pred[-1],
                            discount=config.gamma,
                            lambda_=config.disclam)
    actor_repeat_entropy = torch.mean(imagination_output.action_repeats_entropy)
    actor_repeat_log_prob = torch.mean(imagination_output.action_repeats_log_prob)

    actor_loss = -torch.mean(returns)
    actor_repeat_entropy_loss = - config.actor_repeat_entropy_coeff * actor_repeat_log_prob

    if config.explore_mode in ['mcts', 'rollout']:
        if not config.automatic_entropy_tuning:
            actor_loss += - config.actor_entropy_coeff * imagination_output.actions_entropy.mean()

    if config.optimize_with_search:
        # Flatten dims 0 and 1 of the imagined latents into one batch axis
        # (with a leading axis of size 1) for the root search.
        prior_state = imagination_output.prior_state
        belief = imagination_output.belief
        _states = prior_state.reshape(1, prior_state.shape[0] * prior_state.shape[1],
                                      prior_state.shape[-1]).detach()
        _beliefs = belief.reshape(1, belief.shape[0] * belief.shape[1],
                                  belief.shape[-1]).detach()

        # Per-chunk cross-entropy terms are collected and concatenated once,
        # so every chunk contributes to the loss.
        cross_entropy_chunks = []
        # A divisor of 1 means the whole flattened batch is searched in one chunk.
        _batch_size = int(_states.shape[1] / 1)
        for batch_start_i in range(0, _states.shape[1], _batch_size):
            _batch_range = torch.arange(batch_start_i, min(batch_start_i + _batch_size, _states.shape[1]))
            with torch.no_grad():
                # The literal 20, 20 and planning_horizon=5 are fixed search
                # settings; their meaning is defined by `root_search`.
                root_search_output = root_search(len(_batch_range), model,
                                                 _beliefs[:, _batch_range, :],
                                                 _states[:, _batch_range, :],
                                                 20, 20, config, planning_horizon=5)

                # Softmax over the searched q-values gives the target action weights.
                target_probs = Categorical(logits=root_search_output.q_values).probs

            actor_dist = model.actor.action_dist(model.actor(_states[:, _batch_range, :].squeeze(0),
                                                             _beliefs[:, _batch_range, :].squeeze(0)))
            actor_cross_entropy = 0
            for action_idx in range(root_search_output.actions.shape[-2]):
                actions_log_prob = actor_dist.log_prob(root_search_output.actions[:, action_idx, :])
                actor_cross_entropy += target_probs[:, action_idx] * actions_log_prob

            cross_entropy_chunks.append(actor_cross_entropy)

        batched_actor_cross_entropy = torch.cat(cross_entropy_chunks, dim=0)
        actor_loss += - batched_actor_cross_entropy.mean()

    # Update actor parameters
    actor_optimizer.zero_grad()
    (actor_loss + actor_repeat_entropy_loss).backward()
    torch.nn.utils.clip_grad_norm_(model.actor.parameters(), config.grad_clip_norm, norm_type=2)
    torch.nn.utils.clip_grad_norm_(model.actor_repeat.parameters(), config.grad_clip_norm, norm_type=2)
    actor_optimizer.step()

    if config.explore_mode in ['mcts', 'rollout']:
        if config.automatic_entropy_tuning:
            # Temperature update; the entropy term is detached so only
            # `log_alpha` receives a gradient from this loss.
            alpha_loss = -(config.log_alpha * (-_img_action_entropy + config.target_entropy).detach()).mean()
            config.alpha_optim.zero_grad()
            alpha_loss.backward()
            config.alpha_optim.step()

            config.alpha = config.log_alpha.exp()

    # ##################
    # Value learning
    # ##################
    # Inputs and target are detached: only the value network is trained here.
    value_beliefs = imagination_output.belief.detach()
    value_prior_states = imagination_output.prior_state.detach()
    target_return = returns.detach()

    # Unit-variance Gaussian likelihood of the target returns.
    value_dist = Normal(bottle(model.value, (value_beliefs, value_prior_states)), 1)
    value_loss = -value_dist.log_prob(target_return).mean(dim=(0, 1))

    # Update value parameters
    value_optimizer.zero_grad()
    value_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.value.parameters(), config.grad_clip_norm, norm_type=2)
    value_optimizer.step()

    return (target_return.cpu().flatten().numpy().tolist(),
            actor_loss.item(), actor_repeat_entropy.item(),
            actor_repeat_entropy_loss.item(), value_loss.item(),
            time.time() - agent_update_start_time)


def _update_with_grad_flow(memory, model, optimizer, free_nats, global_prior, config):
    """Perform one 'together_with_grad_flow' update and return its statistics.

    A batch is sampled from ``memory``; the reward loss and half of the value
    loss are back-propagated together through the (non-detached) world-model
    rollout, after which the actor is updated on a fresh imagination that
    starts from detached latents.

    Returns:
        Tuple of floats ``(obs_loss, reward_loss, kl_loss, dyn_time,
        actor_loss, actor_repeat_entropy, actor_repeat_entropy_loss,
        value_loss, beh_time)``.
    """
    dynamics_optimizer, value_optimizer, actor_optimizer = optimizer

    # ##################
    # Dynamics learning
    # ##################
    dyn_update_start_time = time.time()
    batch: BatchOutput = memory.sample(config.batch_size, config.chunk_size, config.device)

    # predict
    init_belief = model.init_belief(config.batch_size).to(config.device)
    init_state = model.init_state(config.batch_size).to(config.device)
    transition_output = model.transition(init_belief, init_state, batch.action[:-1],
                                         batch.action_repeat_prob[:-1],
                                         bottle(model.encoder, (batch.obs[1:],)),
                                         batch.non_terminal[:-1])
    predicted_obs = bottle(model.observation, (transition_output.belief, transition_output.posterior_state))
    predicted_reward = bottle(model.reward, (transition_output.belief, transition_output.posterior_state))
    posterior_dist = Normal(transition_output.posterior_mean, transition_output.posterior_std_dev)
    prior_dist = Normal(transition_output.prior_mean, transition_output.prior_std_dev)

    # estimate losses
    obs_loss = MSELoss(reduction='none')(predicted_obs, batch.obs[1:]).sum(dim=2).mean(dim=(0, 1))
    reward_loss = MSELoss(reduction='none')(predicted_reward.squeeze(-1), batch.reward[:-1]).mean(dim=(0, 1))
    div = kl_divergence(posterior_dist, prior_dist).sum(dim=2)
    kl_loss = torch.max(div, free_nats).mean(dim=(0, 1))
    if config.global_kl_beta != 0:
        kl_loss += config.global_kl_beta * kl_divergence(posterior_dist, global_prior).sum(dim=2).mean(
            dim=(0, 1))

    # NOTE(review): only the reward loss is optimised in this mode; the
    # observation and KL losses are computed for logging only. Confirm this
    # is intended rather than a leftover experiment.
    dynamics_loss = reward_loss
    dyn_time = time.time() - dyn_update_start_time

    agent_update_start_time = time.time()

    # The latents are deliberately NOT detached and the world model is not
    # frozen here, so the value loss below can send gradients back through
    # the imagined rollout.
    imagination_output = imagine_ahead(transition_output.posterior_state, transition_output.belief,
                                       model, config.planning_horizon)
    with FreezeParameters([model.transition, model.encoder, model.reward, model.observation]):
        with FreezeParameters([model.value]):
            imged_reward = bottle(model.reward, (imagination_output.belief, imagination_output.prior_state))
            value_pred = bottle(model.value, (imagination_output.belief, imagination_output.prior_state))

    _img_action_entropy = torch.zeros_like(imged_reward)
    returns = lambda_return(imged_reward, value_pred, _img_action_entropy,
                            action_repeats=imagination_output.action_repeats.unsqueeze(2),
                            bootstrap=value_pred[-1],
                            discount=config.gamma,
                            lambda_=config.disclam)

    # ##################
    # Value learning
    # ##################
    target_return = returns.detach()
    value_dist = Normal(bottle(model.value, (imagination_output.belief, imagination_output.prior_state)), 1)
    value_loss = -value_dist.log_prob(target_return).mean(dim=(0, 1))

    # Joint world-model + value step
    dynamics_optimizer.zero_grad()
    value_optimizer.zero_grad()
    (dynamics_loss + 0.5 * value_loss).backward()
    torch.nn.utils.clip_grad_norm_(model.transition.parameters(), config.grad_clip_norm, norm_type=2)
    torch.nn.utils.clip_grad_norm_(model.reward.parameters(), config.grad_clip_norm, norm_type=2)
    torch.nn.utils.clip_grad_norm_(model.observation.parameters(), config.grad_clip_norm, norm_type=2)
    torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), config.grad_clip_norm, norm_type=2)
    torch.nn.utils.clip_grad_norm_(model.value.parameters(), config.grad_clip_norm, norm_type=2)
    dynamics_optimizer.step()
    value_optimizer.step()

    # ##################
    # Actor learning (fresh imagination from detached latents)
    # ##################
    actor_states = transition_output.posterior_state.detach()
    actor_beliefs = transition_output.belief.detach()

    with FreezeParameters([model.transition, model.encoder, model.reward, model.observation]):
        imagination_output = imagine_ahead(actor_states, actor_beliefs, model, config.planning_horizon)
    with FreezeParameters([model.transition, model.encoder, model.reward, model.observation]):
        with FreezeParameters([model.value]):
            imged_reward = bottle(model.reward, (imagination_output.belief, imagination_output.prior_state))
            value_pred = bottle(model.value, (imagination_output.belief, imagination_output.prior_state))

    _img_action_entropy = torch.zeros_like(imged_reward)
    actor_returns = lambda_return(imged_reward, value_pred, _img_action_entropy,
                                  action_repeats=imagination_output.action_repeats.unsqueeze(2),
                                  bootstrap=value_pred[-1],
                                  discount=config.gamma,
                                  lambda_=config.disclam)
    actor_repeat_entropy = torch.mean(imagination_output.action_repeats_entropy)
    actor_repeat_log_prob = torch.mean(imagination_output.action_repeats_log_prob)

    actor_loss = -torch.mean(actor_returns)
    actor_repeat_entropy_loss = - config.actor_repeat_entropy_coeff * actor_repeat_log_prob

    if config.explore_mode in ['mcts', 'rollout']:
        actor_loss += - config.actor_entropy_coeff * imagination_output.actions_entropy.mean()

    # Update actor parameters (stale world-model/value grads are cleared too).
    actor_optimizer.zero_grad()
    value_optimizer.zero_grad()
    dynamics_optimizer.zero_grad()
    (actor_loss + actor_repeat_entropy_loss).backward()
    torch.nn.utils.clip_grad_norm_(model.actor.parameters(), config.grad_clip_norm, norm_type=2)
    torch.nn.utils.clip_grad_norm_(model.actor_repeat.parameters(), config.grad_clip_norm, norm_type=2)
    actor_optimizer.step()

    beh_time = time.time() - agent_update_start_time

    return (obs_loss.item(), reward_loss.item(), kl_loss.item(), dyn_time,
            actor_loss.item(), actor_repeat_entropy.item(),
            actor_repeat_entropy_loss.item(), value_loss.item(), beh_time)


def update_params(model, optimizer, memory, updates_counter, config, writer, total_env_steps):
    """Run one round of world-model and behaviour updates and log the results.

    The update schedule is selected by ``config.update_mode``:

    * ``'together_with_grad_flow'`` -- ``config.update_itrs`` joint steps in
      which the value loss is back-propagated together with the reward loss.
    * ``'together'`` -- ``config.update_itrs`` iterations, each a dynamics
      step followed by a behaviour step on the same batch.
    * ``'separate'`` -- a block of dynamics steps followed by a block of
      behaviour steps on freshly sampled batches; iteration counts are
      annealed when ``config.anneal_update_itr`` is set.

    Args:
        model: network bundle being trained.
        optimizer: tuple ``(dynamics_optimizer, value_optimizer, actor_optimizer)``.
        memory: replay memory exposing ``sample(batch_size, chunk_size, device)``.
        updates_counter: dict with ``'dynamics'`` and ``'agent'`` counters,
            incremented in place once per corresponding update.
        config: run configuration.
        writer: object exposing ``add_scalar(tag, value, step)``.
        total_env_steps: environment step count used as the logging step.

    Raises:
        NotImplementedError: if ``config.update_mode`` is not one of the
            modes listed above.
    """
    dynamics_optimizer, value_optimizer, actor_optimizer = optimizer

    # loss-trackers (summed over iterations, averaged before logging)
    itr_actor_loss, itr_actor_repeat_entropy_loss, itr_value_loss = 0, 0, 0
    itr_obs_loss, itr_kl_loss, itr_reward_loss = 0, 0, 0
    itr_actor_entropy = 0

    # time-trackers
    dynamics_update_time = []
    agent_update_time = []

    free_nats = torch.full((1,), config.free_nats, device=config.device)
    global_prior = Normal(torch.zeros_like(model.init_state()).to(config.device),
                          torch.ones_like(model.init_state()).to(config.device))

    if config.update_mode == 'together_with_grad_flow':
        _dyn_update_itr = config.update_itrs
        _bhv_update_itr = config.update_itrs
        for _ in range(_dyn_update_itr):
            _obs_loss, _reward_loss, _kl_loss, _dyn_time, \
            _actor_loss, _actor_repeat_entropy, _actor_repeat_entropy_loss, \
            _value_loss, _beh_time = _update_with_grad_flow(memory, model, optimizer,
                                                            free_nats, global_prior, config)

            updates_counter['dynamics'] += 1
            updates_counter['agent'] += 1

            # log
            itr_obs_loss += _obs_loss
            itr_reward_loss += _reward_loss
            itr_kl_loss += _kl_loss
            dynamics_update_time.append(_dyn_time)

            itr_actor_loss += _actor_loss
            itr_actor_entropy += _actor_repeat_entropy
            itr_actor_repeat_entropy_loss += _actor_repeat_entropy_loss
            itr_value_loss += _value_loss
            agent_update_time.append(_beh_time)

    elif config.update_mode == 'together':
        _dyn_update_itr = config.update_itrs
        _bhv_update_itr = config.update_itrs
        for _ in range(_dyn_update_itr):
            # ##################
            # Dynamics Learning
            # ##################

            # Sample a batch from memory
            batch: BatchOutput = memory.sample(config.batch_size, config.chunk_size, config.device)
            transition_output, _, _, \
            _obs_loss, _reward_loss, _kl_loss, _dyn_time = _update_dynamics(batch, model, dynamics_optimizer,
                                                                            free_nats, global_prior, config)

            itr_obs_loss += _obs_loss
            itr_reward_loss += _reward_loss
            itr_kl_loss += _kl_loss
            dynamics_update_time.append(_dyn_time)

            # ##################
            # Behaviour Learning
            # ##################
            _, _actor_loss, \
            _actor_repeat_entropy, _actor_repeat_entropy_loss, \
            _value_loss, _beh_time = _update_behaviour(transition_output, model, actor_optimizer,
                                                       value_optimizer, config)
            # log
            itr_actor_loss += _actor_loss
            itr_actor_entropy += _actor_repeat_entropy
            itr_actor_repeat_entropy_loss += _actor_repeat_entropy_loss
            itr_value_loss += _value_loss
            agent_update_time.append(_beh_time)

            updates_counter['dynamics'] += 1
            updates_counter['agent'] += 1

    elif config.update_mode == 'separate':
        # ##################
        # Dynamics Learning
        # ##################

        if config.anneal_update_itr:
            _dyn_update_itr = get_dyn_update_itr(config, total_env_steps)
        else:
            _dyn_update_itr = config.update_itrs

        for _ in range(_dyn_update_itr):
            # Sample a batch from memory
            batch: BatchOutput = memory.sample(config.batch_size, config.chunk_size, config.device)
            _, _, _, \
            _obs_loss, _reward_loss, _kl_loss, _dyn_time = _update_dynamics(batch, model, dynamics_optimizer,
                                                                            free_nats, global_prior, config)

            itr_obs_loss += _obs_loss
            itr_reward_loss += _reward_loss
            itr_kl_loss += _kl_loss
            dynamics_update_time.append(_dyn_time)
            updates_counter['dynamics'] += 1

        # ##################
        # Behaviour Learning
        # ##################
        if config.anneal_update_itr:
            _bhv_update_itr = get_beh_update_itr(config, total_env_steps)
        else:
            _bhv_update_itr = config.update_itrs

        for _ in range(_bhv_update_itr):
            # Sample a batch from memory
            batch: BatchOutput = memory.sample(config.batch_size, config.chunk_size, config.device)

            # `_update_behaviour` detaches the latents it reads from this
            # rollout, so no autograd graph is needed for it.
            with torch.no_grad():
                init_belief = model.init_belief(config.batch_size).to(config.device)
                init_state = model.init_state(config.batch_size).to(config.device)
                transition_output = model.transition(init_belief, init_state, batch.action[:-1],
                                                     batch.action_repeat_prob[:-1],
                                                     bottle(model.encoder, (batch.obs[1:],)),
                                                     batch.non_terminal[:-1])

            _, _actor_loss, \
            _actor_repeat_entropy, _actor_repeat_entropy_loss, \
            _value_loss, _beh_time = _update_behaviour(transition_output, model, actor_optimizer,
                                                       value_optimizer, config)
            # log
            itr_actor_loss += _actor_loss
            itr_actor_entropy += _actor_repeat_entropy
            itr_actor_repeat_entropy_loss += _actor_repeat_entropy_loss
            itr_value_loss += _value_loss
            agent_update_time.append(_beh_time)

            updates_counter['agent'] += 1
    else:
        raise NotImplementedError('unsupported update_mode: {}'.format(config.update_mode))

    # normalize losses; the divisors are floored at 1 so that a round with
    # zero iterations reports 0 instead of raising ZeroDivisionError
    bhv_norm = max(_bhv_update_itr, 1)
    dyn_norm = max(_dyn_update_itr, 1)

    itr_actor_loss /= bhv_norm
    itr_actor_entropy /= bhv_norm
    itr_actor_repeat_entropy_loss /= bhv_norm
    itr_value_loss /= bhv_norm

    itr_obs_loss /= dyn_norm
    itr_reward_loss /= dyn_norm
    itr_kl_loss /= dyn_norm

    # log losses
    writer.add_scalar('train/actor_loss', itr_actor_loss, total_env_steps)
    writer.add_scalar('train/actor_repeat_entropy_loss', itr_actor_repeat_entropy_loss, total_env_steps)
    writer.add_scalar('train/actor_repeat_entropy', itr_actor_entropy, total_env_steps)
    writer.add_scalar('train/value_loss', itr_value_loss, total_env_steps)
    writer.add_scalar('train/obs_loss', itr_obs_loss, total_env_steps)
    writer.add_scalar('train/reward_loss', itr_reward_loss, total_env_steps)
    writer.add_scalar('train/kl_loss', itr_kl_loss, total_env_steps)
    if config.automatic_entropy_tuning:
        writer.add_scalar('train/alpha', config.alpha, total_env_steps)

    writer.add_scalar('train/behaviour_itr', _bhv_update_itr, total_env_steps)
    writer.add_scalar('train/dynamics_itr', _dyn_update_itr, total_env_steps)
    writer.add_scalar('train/agent_update_counter', updates_counter['agent'], total_env_steps)
    writer.add_scalar('train/dynamics_update_counter', updates_counter['dynamics'], total_env_steps)

    # log times (totals per round and mean per update)
    total_agent_time = sum(agent_update_time)
    total_dynamics_time = sum(dynamics_update_time)
    mean_agent_time = total_agent_time / len(agent_update_time) if agent_update_time else 0.0
    mean_dynamics_time = total_dynamics_time / len(dynamics_update_time) if dynamics_update_time else 0.0

    writer.add_scalar('time_per_batch/rollout_agent_update', total_agent_time, total_env_steps)
    writer.add_scalar('time_per_batch/agent_update', mean_agent_time, total_env_steps)
    writer.add_scalar('time_per_batch/rollout_dynamics_update', total_dynamics_time, total_env_steps)
    writer.add_scalar('time_per_batch/dynamics_update', mean_dynamics_time, total_env_steps)

    _msg = 'env steps #{:<10}  dynamics updates #{:<10} agent updates #{:<10} || '
    _msg += 'actor loss:{:<8.3f} value loss: {:<8.3f} '
    _msg += 'obs loss : {:<8.3f} reward loss : {:<8.3f} kl loss: {:<8.3f}'
    _msg = _msg.format(total_env_steps,
                       updates_counter['dynamics'], updates_counter['agent'],
                       itr_actor_loss, itr_value_loss,
                       itr_obs_loss, itr_reward_loss, itr_kl_loss)
    train_logger.info(_msg)


def _test(model, env, config, writer, best_test_score, total_env_steps):
    """Evaluate ``model`` with and without search, checkpoint bests, and log.

    The model is always evaluated in 'no-search' mode. Unless
    ``config.explore_mode`` is itself 'no-search', it is also evaluated in
    ``config.explore_mode`` with ``config.num_simulations`` simulations;
    otherwise the no-search result is reused for the "with search" slot.

    Whenever a score is at least as good as the stored best for its slot, the
    model's state dict is saved to ``config.best_model_path[slot]`` and
    ``best_test_score['score'][slot]`` is updated in place. Scores and average
    action repeats are written to ``writer`` and summarised on ``test_logger``.
    """
    plain_result = test(env, model, config=config, mode='no-search')
    if config.explore_mode == 'no-search':
        search_result = plain_result
    else:
        search_result = test(env, model, config=config, mode=config.explore_mode,
                             mcts_num_simulations=config.num_simulations)

    # Checkpoint on ties as well (>=), so the latest equally-good model wins.
    best_scores = best_test_score['score']
    for slot, result in (('no_search', plain_result), ('with_search', search_result)):
        if result.score >= best_scores[slot]:
            torch.save(model.state_dict(), config.best_model_path[slot])
            best_scores[slot] = result.score

    # Test Log
    scalars = (
        ('test/with_search/episode_reward', search_result.score),
        ('test/with_search/avg_repeat', search_result.avg_repeat),
        ('test/no_search/episode_reward', plain_result.score),
        ('test/no_search/avg_repeat', plain_result.avg_repeat),
    )
    for tag, value in scalars:
        writer.add_scalar(tag, value, total_env_steps)

    template = ''.join((
        '#{:<10} |',
        'WITH SEARCH [ test score: {} avg_action_repeats:{} best score:{} ] ||',
        'NO SEARCH [ test score: {} avg_action_repeats:{} best score:{} ]',
    ))
    summary = template.format(
        total_env_steps,
        search_result.score, search_result.avg_repeat, best_scores['with_search'],
        plain_result.score, plain_result.avg_repeat, best_scores['no_search'],
    )
    test_logger.info(summary)


def train(config: BaseConfig, writer: SummaryWriter):
    """Run the training loop: alternate parameter updates with env interaction.

    Each outer iteration (1) calls ``update_params`` once the replay memory
    holds at least ``batch_size * chunk_size`` entries and ``start_step`` env
    steps have elapsed, (2) interacts with the environment for up to
    ``config.env_itr_steps`` decisions, and (3) logs / checkpoints when an
    episode has finished. The loop ends once ``total_env_steps`` reaches
    ``config.max_env_steps``.

    Actions are sampled uniformly while ``total_env_steps <= start_step``;
    afterwards they come from the strategy named by ``config.explore_mode``
    ('no-search', 'rollout', 'mcts' or 'mcts+fixed').

    :param config: experiment configuration. When
        ``config.automatic_entropy_tuning`` is set, ``alpha``,
        ``target_entropy``, ``log_alpha`` and ``alpha_optim`` are written onto
        it.
    :param writer: receives the ``train/*`` scalars and is forwarded to
        ``update_params`` and ``_test``.
    :raises NotImplementedError: if no branch handles ``config.explore_mode``.
    """
    memory = ReplayMemory(config.replay_memory_capacity, config.action_space.shape[0],
                          len(config.action_repeat_set), config.observation_space.shape[0])
    assert config.explore_mode in ['no-search', 'rollout', 'mcts', 'mcts+fixed']

    # create networks
    model = config.get_uniform_network().to(config.device)
    # evaluation copy; it receives a snapshot of ``model`` before every test run
    test_model = config.get_uniform_network().to(config.device)

    # create envs
    env = config.new_game(seed=config.seed)
    test_envs = EnvBatcher(config.new_game, config.test_episodes)

    # create optimizers
    dynamics_optimizer = Adam([{'params': model.transition.parameters()},
                               {'params': model.observation.parameters()},
                               {'params': model.reward.parameters()},
                               {'params': model.encoder.parameters()}], lr=config.dynamics_lr)
    value_optimizer = Adam([{'params': model.value.parameters()}], lr=config.value_lr)
    policy_optimizer = Adam([{'params': model.actor.parameters()},
                             {'params': model.actor_repeat.parameters()}], lr=config.policy_lr)
    optimizer = (dynamics_optimizer, value_optimizer, policy_optimizer)

    if config.automatic_entropy_tuning:
        config.alpha = 0.2
        # target entropy is minus the product of the action-space dimensions
        config.target_entropy = -torch.prod(torch.Tensor(env.action_space.shape).to(config.device)).item()
        config.log_alpha = torch.zeros(1, requires_grad=True, device=config.device)
        config.alpha_optim = Adam([config.log_alpha], lr=config.entropy_lr)

    # training trackers
    total_env_steps = 0
    episode_reward, episode_step = 0, 0
    updates_counter = {'dynamics': 0, 'agent': 0}
    best_test_score = {'score': {'with_search': float('-inf'),
                                 'no_search': float('-inf')}}

    # Fire!!!
    done = True  # forces an environment reset on the first iteration
    i_episode = 0
    while True:

        # ##################
        # Learning.
        # ##################
        if len(memory) >= (config.batch_size * config.chunk_size) and total_env_steps >= config.start_step:
            update_params(model, optimizer, memory, updates_counter, config, writer, total_env_steps)

        # ##################
        # Env. interaction.
        # ##################
        if done:
            # start a new episode: reset per-episode trackers and recurrent inputs
            episode_reward, episode_step = 0, 0
            episode_action_repeats = []
            episode_action_dist_entropy = []
            root_childs = []
            search_value_errors = []
            i_episode += 1

            belief = model.init_belief().to(config.device)
            posterior_state = model.init_state().to(config.device)
            action = model.init_action().to(config.device)
            action_repeat_one_hot = model.init_action_repeat_one_hot().to(config.device)

            obs = env.reset()
            obs = torch.FloatTensor(obs).unsqueeze(0).to(config.device)

        for _ in range(config.env_itr_steps):
            temperature = config.visit_softmax_temperature_fn(env_steps=total_env_steps)

            # determine action
            if total_env_steps <= config.start_step:
                # warm-up phase: uniformly random action and repeat count
                action = env.action_space.sample()
                action = torch.FloatTensor(action).unsqueeze(0).to(config.device)

                action_repeat_one_hot, action_repeat_n = model.actor_repeat.uniform_sample(device=config.device)

            else:
                with torch.no_grad():
                    # advance the latent state with the previous action and the current observation
                    transition_output = model.transition(belief, posterior_state,
                                                         action.unsqueeze(0),
                                                         action_repeat_one_hot.unsqueeze(0),
                                                         model.encoder(obs).unsqueeze(0))

                    belief = transition_output.belief.squeeze(0)
                    posterior_state = transition_output.posterior_state.squeeze(0)

                    actor_output = model.actor(belief, posterior_state)
                    _entropy = model.actor.action_dist(actor_output).entropy()
                    episode_action_dist_entropy.append(_entropy.item())

                    if config.explore_mode == 'no-search':
                        action = model.actor.action_sample(actor_output, deterministic=False)
                        actor_repeat_output = model.actor_repeat(belief, posterior_state, action)
                        action_repeat_one_hot, action_repeat_n = model.actor_repeat.sample(actor_repeat_output,
                                                                                           deterministic=False)
                    elif config.explore_mode == 'rollout':
                        root_search_output = root_search(1, model,
                                                         transition_output.belief,
                                                         transition_output.posterior_state,
                                                         config.proposal_action_sample,
                                                         config.uniform_action_sample,
                                                         config)

                        # act on the candidate with the highest q-value, perturbed by N(0, 0.3) noise
                        greedy_actions_idx = torch.argmax(root_search_output.q_values.squeeze(0), dim=0)
                        greedy_action = root_search_output.actions[[0], greedy_actions_idx]
                        noise_dist = Normal(torch.zeros(greedy_action.shape).to(config.device),
                                            torch.ones(greedy_action.shape).to(config.device) * 0.3)
                        action = greedy_action + noise_dist.sample()
                        action_repeat_n = root_search_output.actions_repeat[[0], greedy_actions_idx]
                        action_repeat_one_hot = root_search_output.actions_repeat_one_hot[[0], greedy_actions_idx]

                        # diagnostic: L1 gap between the value head's one-step estimates and the search q-values
                        n_actions = root_search_output.actions.shape[1]
                        _trans = model.transition(
                            belief.repeat(1, 1, n_actions).reshape(1, n_actions, model.belief_size).squeeze(0),
                            posterior_state.repeat(1, 1, n_actions).reshape(1, n_actions,
                                                                            model.state_size).squeeze(0),
                            root_search_output.actions, root_search_output.actions_repeat_one_hot)

                        _base_values = model.value(_trans.belief.squeeze(0), _trans.prior_state.squeeze(0))

                        search_value_error = L1Loss()(_base_values.squeeze(1), root_search_output.q_values.squeeze(0))
                        search_value_errors.append(search_value_error.item())

                    elif 'mcts' in config.explore_mode:
                        actor_output = model.actor(belief, posterior_state)
                        child_action = model.actor.action_sample(actor_output, deterministic=False)
                        child_actor_repeat_output = model.actor_repeat(belief, posterior_state, child_action)
                        reward = model.reward(belief, posterior_state)
                        child_action_repeat_one_hot, child_action_repeat = model.actor_repeat.sample(
                            child_actor_repeat_output, deterministic=False)
                        actor_dist = model.actor.action_dist(actor_output)
                        child_action_log_prob = - actor_dist.entropy(child_action)

                        action, action_repeat, action_repeat_one_hot = None, None, None
                        for env_i in range(1):
                            root = Node(0, root=True)
                            # 'mcts+fixed' disables progressive expansion and uses root noise instead
                            progressive = 'fixed' not in config.explore_mode
                            root.expand(model, ActorOutput(actor_output.mu[env_i, :].unsqueeze(0),
                                                           actor_output.std_dev[env_i, :].unsqueeze(0)),
                                        child_action_log_prob[env_i].item(),
                                        child_action[env_i, :].unsqueeze(0).unsqueeze(0),
                                        child_action_repeat[env_i, :].int().item(),
                                        child_action_repeat_one_hot[env_i, :].unsqueeze(0).unsqueeze(0),
                                        (belief[env_i, :].unsqueeze(0),
                                         posterior_state[env_i, :].unsqueeze(0)),
                                        reward[env_i, :].item(), progressive=progressive)
                            if not progressive:
                                root.add_exploration_noise(config.root_dirichlet_alpha,
                                                           config.root_exploration_fraction)
                            MCTS(config, progressive=progressive).run(root, model,
                                                                      num_simulations=config.num_simulations)

                            child_values = [(root.reward + (config.gamma ** cap.repeat) * child.value(),
                                             cap, child)
                                            for cap, child in root.children.items()]
                            # Compare on the value only; comparing whole tuples would fall through to
                            # the action/node objects whenever two children have equal values.
                            _, action_cap, _ = max(child_values, key=lambda entry: entry[0])
                            noise_dist = Normal(torch.zeros(action_cap.action.shape).to(config.device),
                                                torch.ones(action_cap.action.shape).to(config.device) * 0.3)

                            # NOTE: the temperature-scaled visit distribution is computed but not sampled
                            # from at present; the greedy child (plus noise) is what gets executed.
                            visit_counts = [(child.visit_count, cap, child)
                                            for cap, child in root.children.items()]
                            action_probs = [visit_count_i ** (1 / temperature) for visit_count_i, _, _ in visit_counts]
                            total_count = sum(action_probs)
                            action_probs = [x / total_count for x in action_probs]

                            if action is None:
                                action = action_cap.action + noise_dist.sample()
                                action_repeat = [action_cap.repeat]
                                action_repeat_one_hot = action_cap.repeat_one_hot
                            else:
                                action = torch.cat((action, action_cap.action + noise_dist.sample()), dim=1)
                                action_repeat.append(action_cap.repeat)
                                action_repeat_one_hot = torch.cat((action_repeat_one_hot, action_cap.repeat_one_hot),
                                                                  dim=1)
                            n_children = len(visit_counts)
                            root_childs.append(n_children)

                            # estimate q-delta
                            _trans = model.transition(
                                belief.repeat(1, 1, n_children).reshape(1, n_children, model.belief_size).squeeze(0),
                                posterior_state.repeat(1, 1, n_children).reshape(1, n_children,
                                                                                 model.state_size).squeeze(0),
                                torch.cat([cap.action for cap in root.children.keys()], dim=1),
                                torch.cat([cap.repeat_one_hot for cap in root.children.keys()], dim=1))

                            _base_values = model.value(_trans.belief.squeeze(0), _trans.prior_state.squeeze(0))

                            search_value_error = L1Loss()(
                                _base_values.squeeze(1),
                                torch.Tensor([child.value() for _, child in root.children.items()]).to(config.device))
                            search_value_errors.append(search_value_error.item())

                        action.squeeze_(0)
                        action_repeat_one_hot.squeeze_(0)
                        # Expose the repeat count chosen by the search to the env-stepping code below.
                        # Without this, a stale value from the random warm-up phase would be executed
                        # (or the name would be undefined if warm-up never ran).
                        action_repeat_n = torch.as_tensor(action_repeat[0])
                    else:
                        raise NotImplementedError()

            # step action in the env.
            step_action = action.cpu().data.numpy()[0]
            action_repeat = action_repeat_n.int().item()
            episode_action_repeats.append(action_repeat)

            reward_sum = 0
            discounted_reward_sum = 0
            for rep_i in range(action_repeat):
                next_obs, reward, done, info = env.step(step_action)
                episode_step += 1
                total_env_steps += 1
                episode_reward += reward
                reward_sum += reward
                discounted_reward_sum += (config.gamma ** rep_i) * reward

                # ########
                # Test
                # ########
                # Note : This is kept inside env step for-loop to keep test intervals sync. across multiple seeds.
                if total_env_steps % config.test_interval_steps == 0 and total_env_steps > config.start_step:
                    test_model.load_state_dict(model.state_dict())
                    _test(test_model, test_envs, config, writer, best_test_score, total_env_steps)

                if done:
                    break

            # add to memory; ``rep_i + 1`` is the number of repeats actually executed,
            # which is smaller than ``action_repeat`` when the episode ended early
            memory.push(obs.data.cpu().numpy(), action.data.cpu().numpy(), rep_i + 1,
                        action_repeat_one_hot.cpu().data.numpy(), discounted_reward_sum, done)

            obs = torch.FloatTensor(next_obs).unsqueeze(0).to(config.device)
            if done:
                break

        # ################
        # Log & save model
        # ################
        if done:
            writer.add_scalar('train/episode_reward', episode_reward, total_env_steps)
            writer.add_scalar('train/action_repeats', np.mean(episode_action_repeats), total_env_steps)
            writer.add_scalar('train/episode_steps', episode_step, total_env_steps)
            writer.add_scalar('train/episodes', i_episode, total_env_steps)
            writer.add_scalar('train/dynamics_updates', updates_counter['dynamics'], total_env_steps)
            writer.add_scalar('train/agent_updates', updates_counter['agent'], total_env_steps)
            writer.add_scalar('train/replay_memory_size', len(memory), total_env_steps)
            writer.add_scalar('train/temperature', config.visit_softmax_temperature_fn(env_steps=total_env_steps),
                              total_env_steps)
            if len(search_value_errors) > 0:
                writer.add_scalar('train/search_value_error', np.mean(search_value_errors), total_env_steps)
            # root_childs is empty for episodes played entirely in the random warm-up phase;
            # skip the scalar then instead of logging NaN
            if config.explore_mode == 'mcts' and len(root_childs) > 0:
                writer.add_scalar('train/avg_root_childs', np.mean(root_childs), total_env_steps)
            if len(episode_action_dist_entropy) > 0:
                writer.add_scalar('train/action_dist_entropy', np.mean(episode_action_dist_entropy), total_env_steps)

            _msg = 'total-steps #{:<10}|| train score:{:<8.3f} eps steps: {:<10} episodes: {:<10}'
            _msg += ' avg action repeat: ({:<8.3f},{:<8.3f})'
            _msg = _msg.format(total_env_steps, episode_reward, episode_step, i_episode,
                               np.mean(episode_action_repeats), np.std(episode_action_repeats))
            train_logger.info(_msg)

        # save model: periodic checkpoints are written once, when the episode ends, rather than on
        # every loop iteration of that episode; a final checkpoint is always written at the step limit
        if (done and i_episode % config.save_model_freq == 0) or total_env_steps >= config.max_env_steps:
            torch.save(model.state_dict(), config.model_path)

        # check if max. env steps reached.
        if total_env_steps >= config.max_env_steps:
            train_logger.info('max env. steps reached!!')
            break
