import json
import logging
import random as rd
import time
from collections import deque
from typing import Dict, Iterator, List, Tuple, Union

import ecole
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import gmean
from torch.utils.tensorboard import SummaryWriter

from rl.agents.replay_buffer import PrioritizedReplayBuffer
from rl.agents.tree_dqn import TreeDQNAgent
from rl.rewards.retro_branching import RetroBranching
from rl.rewards.reward_agent import RewardAgent


class TreeDQNLearner:
    """Learner class to train all dqn agents."""

    def __init__(
        self,
        agent: TreeDQNAgent,
        reward_agent: RewardAgent,
        env: ecole.environment.Environment,
        instances: Iterator[str],
        evaluation_instances: List[str],
        save_path: str,
        n_epochs: int = 200000,
        memory_size: int = 100000,
        memory_min_size: int = 20000,
        batch_size: int = 128,
        steps_per_update: int = 10,
        n_grad_accumulation: int = 1,
        seed: int = 42,
        lr: float = 5e-5,
        gamma: float = 0.99,
        epsilon_decay: float = 1e-4,
        epsilon_greedy: bool = False,
        max_epsilon: float = 1.0,
        min_epsilon: float = 2.5e-2,
        temperature_decay: float = 0,  # 1e-5
        # PER parameters
        alpha: float = 0.6,
        beta: float = 0.4,
        prior_eps: float = 1e-6,
        # N-step Learning
        n_step: int = 3,
        # Experiment
        writer: SummaryWriter = None,
        save_frequency: int = 1000,
        overfit_one_instance: bool = False,
        time_limit: int = 3600,
    ):
        """Initialization.

        Args:
            agent (DQNAgent)
            env (gym.Env): openAI Gym environment
            memory_size (int): length of memory
            batch_size (int): batch size for sampling
            target_hard_update (int): period for target model's hard update
            lr (float): learning rate
            gamma (float): discount factor
            alpha (float): determines how much prioritization is used
            beta (float): determines how much importance sampling is used
            prior_eps (float): guarantees every transition can be sampled
            n_step (int): step number to calculate n-step td error
        """
        self.agent = agent
        self.reward_agent = reward_agent
        self.env = env
        self.instances = instances
        self.evaluation_instances = evaluation_instances
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.steps_per_update = steps_per_update
        self.n_grad_accumulation = n_grad_accumulation
        self.accumulation_steps = 0 if n_grad_accumulation > 1 else 1
        self.seed = seed
        self.gamma = gamma
        self.epsilon_decay = epsilon_decay
        self.epsilon_greedy = epsilon_greedy
        self.max_epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.epsilon = max_epsilon
        self.temperature_decay = temperature_decay
        self.time_limit = time_limit
        self.overfit_one_instance = overfit_one_instance
        if self.overfit_one_instance:
            self.only_instance = self.evaluation_instances[0]
            self.evaluation_instances = [self.only_instance]

        self.off_policy = True

        # network
        self.save_path = save_path if not None else False
        self.save_frequency = save_frequency
        self.writer = writer

        # device: cpu / gpu
        self.device = self.agent.device

        # PER
        # memory for 1-step Learning and for the current episode
        self.alpha = alpha
        self.beta_0 = beta
        self.beta = beta
        self.prior_eps = prior_eps
        self.memory_min_size = memory_min_size
        self.memory_size = memory_size
        if self.off_policy:
            self.memory = PrioritizedReplayBuffer(memory_size, batch_size, alpha=alpha)

        # memory for N-step Learning
        self.use_n_step = n_step > 1
        if self.use_n_step:
            self.n_step = n_step
            self.n_step_buffer = deque(maxlen=self.n_step)
        else:
            self.n_step = 1

        # memory for tree episode !
        self.episode_memory = []

        # optimizer
        self.lr = lr
        self.optimizer = optim.Adam(self.agent.dqn.parameters(), lr=self.lr)

        # transition to store in memory
        self.use_retro_trajectories = isinstance(self.reward_agent, RetroBranching)
        self.use_bb_mdp = not self.use_retro_trajectories
        self.agent.retro = self.use_retro_trajectories
        self.agent.bb_mdp = self.use_bb_mdp
        self.transition = {}

        if self.use_bb_mdp:
            self.reward_agent.synchronize(n_step, gamma)

        # mode: train / test
        self.is_test = False
        self.best_performance = 1e9

    def select_action(self, state: Tuple[np.ndarray]) -> int:
        """Select an action from the input state."""
        reduced_state, action_set = state[:-1], torch.LongTensor(state[-1])
        if np.random.random() < self.epsilon and not self.is_test:
            action = rd.choice(action_set).item()  # random action to be tested
        else:
            with torch.no_grad():
                available_actions_scores = (
                    self.agent._compute_Q_value(reduced_state).cpu().gather(0, action_set)
                )
                if self.epsilon_greedy or self.is_test:  # epsilon greedy policy
                    selected_action = available_actions_scores.detach().argmax().numpy()
                else:  # epsilon stochastic policy -> sample from softmax distribution on Q-values
                    available_actions_softmax_scores = F.softmax(available_actions_scores / self.T, dim=0)
                    selected_action = torch.multinomial(available_actions_softmax_scores, 1).detach().numpy()
            action = action_set.numpy()[selected_action].item()

        if not self.is_test:
            self.agent_action_count += 1
            action_idx = np.where(action_set == action)[0].item()
            self.transition = {
                "state": state,
                "action": action,
                "action_idx": action_idx,
                "depth": self.current_node_depth,
            }

        return action

    def step(
        self, state: Tuple[np.ndarray], action: int
    ) -> Tuple[Tuple[np.ndarray], int, Union[np.float64, List[Dict]], np.ndarray, bool]:
        """Take an action and return the response of the env."""
        next_obs, next_action_set, reward, done, info = self.env.step(action)
        next_state = extract_state_from_obs(next_obs, next_action_set) if not done else state

        if not self.is_test:
            self.transition.update({"reward": reward, "next_state": next_state, "done": done})
            # episode memory contains at index step_idx the transition at step step_idx
            self.episode_memory.append(self.transition)

        _custom_reward = self.reward_agent.extract(self.env.model, done, next_action_set)

        if done:
            reward = _custom_reward

        return next_state, reward, info, done

    def get_tree_episode_score(self, final_reward: List[Dict[int, int]]) -> float:
        """Compute the score of the current tre episode."""
        if self.use_retro_trajectories:
            score = 0
            for trajectory in final_reward:
                score += sum(trajectory.values())
        elif self.use_bb_mdp:
            score = -self.env.model.as_pyscipopt().getNNodes()

        return score

    def _get_n_step_info(self, true_value: bool = False) -> Tuple[np.int64, np.ndarray, bool]:
        """Return n step rew, next_state, and done."""
        final_transition = self.n_step_buffer[-1]
        n_step_next_state = final_transition["next_state"]
        reward = "reward" if not true_value else "reward_true"
        done = "done" if not true_value else "done_true"
        n_step_reward = final_transition[reward]
        n_step_done = final_transition[done]
        for tree_transition in reversed(list(self.n_step_buffer)[:-1]):
            n_s = tree_transition["next_state"]
            r = tree_transition[reward]
            d = tree_transition[done]
            n_step_reward = r + self.gamma * n_step_reward * (1 - d)
            n_step_next_state, n_step_done = (n_s, d) if d else (n_step_next_state, n_step_done)

        return n_step_reward, n_step_next_state, n_step_done

    def process_episode_into_retro_trajectories(self, final_reward: List[Dict]) -> bool:
        """Rebuild the current episode as retrospective trajectories and store them in memory.

        Args:
            final_reward: one mapping per retrospective trajectory, associating the
                index of a step in ``self.episode_memory`` with the reward of that step.

        Returns:
            True when the episode was processed. False when the episode holds at most
            one transition, or when its length differs from the number of visited
            nodes recorded by the reward agent; the reason is then appended to
            ``self.unsuccessful_episodes``.
        """

        # remove transition belonging to no retrospective trajectory
        if len(self.episode_memory) > 1:
            # Check if there is a real episode, evacuate double root node
            first_transition, second_transition = self.episode_memory[:2]
            assert first_transition["depth"] == 0
            # Two consecutive transitions at depth 0: the first one is dropped.
            first_step_is_redundant = bool(second_transition["depth"] == 0)
            if first_step_is_redundant:
                self.episode_memory = self.episode_memory[1:]
                logging.info("First step is redundant !")
            if self.reward_agent.only_use_leaves_closed_by_brancher_as_terminal_nodes:
                # Drop the steps that the reward agent listed for removal.
                self.episode_memory = [
                    transition
                    for idx, transition in enumerate(self.episode_memory)
                    if idx not in self.reward_agent.experience_to_remove
                ]
        else:
            self.unsuccessful_episodes += [f"Len of episode: {len(self.episode_memory)}"]
            return False

        # Sanity check: one stored transition per visited node of the search tree.
        if len(self.episode_memory) != len(self.reward_agent.search_tree.tree.graph["visited_node_ids"]):
            self.unsuccessful_episodes += [
                f"Size of memory and tree don't match :/ \n\
                Episode: {len(self.episode_memory)} vs tree size: {len(self.reward_agent.search_tree.tree.graph['visited_node_ids'])}"
            ]
            return False

        # retrieve true reward and next_state for each transition of the episode
        for retrospective_trajectory in final_reward:
            previous_tree_transition = None
            for step_idx, reward in retrospective_trajectory.items():
                # dive along the retrospective trajectory
                tree_transition = self.episode_memory[step_idx]
                # update the reward, next_state and done, considering the episode ends at the current step
                tree_transition.update(
                    {"reward": reward, "next_state": tree_transition["state"], "done": True}
                )
                if previous_tree_transition is not None:
                    # previous transition was not terminal, update accordingly
                    previous_tree_transition.update({"next_state": tree_transition["state"], "done": False})
                previous_tree_transition = tree_transition
            # At this point only the last step of the trajectory is still flagged as terminal.

            # Transitions from current retrospective trajectory have been updated.
            # Let's build n_step transitions.
            retro_trajectory_transitions = [self.episode_memory[key] for key in retrospective_trajectory]
            for tree_transition in retro_trajectory_transitions:
                if self.use_n_step:
                    self.n_step_buffer.append(tree_transition)
                    # Nothing is stored until the buffer holds n_step transitions.
                    if len(self.n_step_buffer) < self.n_step:
                        continue
                    n_step_reward, n_step_next_state, n_step_done = self._get_n_step_info()
                    # The oldest buffered transition receives the n-step information
                    # (the dict is updated in place) and is the one stored.
                    # NOTE(review): the buffer is not cleared in this method, so it
                    # carries over between trajectories and episodes - confirm intended.
                    transition_to_store = self.n_step_buffer[0]
                    transition_to_store.update(
                        {
                            "n_step_reward": n_step_reward,
                            "n_step_next_state": n_step_next_state,
                            "n_step_done": n_step_done,
                        }
                    )
                else:
                    transition_to_store = tree_transition
                self.memory.store(**transition_to_store)

        return True

    def process_episode_into_bb_mdp(self, final_reward: Dict[int, Dict[str, object]]) -> bool:
        """
        Transform the episode trajectory into a tree MDP and store it in the memory.

        Args:
            final_reward: mapping from a step index of ``self.episode_memory`` to a
                dict read through the keys ``"next_states"`` and ``"reward"`` (plus
                ``"n_step_next_states"`` and ``"n_step_reward"`` when n-step learning
                is enabled). The ``*next_states`` entries hold step indices.

        Returns:
            True when the episode was stored. False when the episode holds at most
            one transition, or when its length and the number of visited nodes
            recorded by the reward agent differ by anything other than one extra
            visited node; the reason is then appended to ``self.unsuccessful_episodes``.
        """
        # Perform some sanity checks
        if len(self.episode_memory) > 1:
            # Check if there is a real episode, evacuate double root node
            first_transition, second_transition = self.episode_memory[:2]
            assert first_transition["depth"] == 0
            # Two consecutive transitions at depth 0: the first one is dropped.
            first_step_is_redundant = bool(second_transition["depth"] == 0)
            if first_step_is_redundant:
                self.episode_memory = self.episode_memory[1:]
                logging.info("First step is redundant !")
        else:
            self.unsuccessful_episodes += [f"Len of episode: {len(self.episode_memory)}"]
            return False

        # Is everything alright ?
        if len(self.episode_memory) != len(self.reward_agent.visited_nodes_to_step_idx):
            # Handle the case when time_limit leads to imbalance between n episodes and tree size
            transition_gap = len(self.reward_agent.visited_nodes_to_step_idx) - len(self.episode_memory)
            if transition_gap == 1:
                logging.info("Runtime limit --> remove last visited node")
                # The extra node has the step index just past the end of the episode.
                idx_to_remove = len(self.episode_memory)
                self.reward_agent.visited_nodes_to_step_idx = {
                    node: step_idx
                    for node, step_idx in self.reward_agent.visited_nodes_to_step_idx.items()
                    if step_idx != idx_to_remove
                }
                final_reward = {
                    step_idx: tree_transition
                    for step_idx, tree_transition in final_reward.items()
                    if step_idx != idx_to_remove
                }
            else:
                # Something really weird happened ...
                self.unsuccessful_episodes += [
                    f"Size of memory and tree don't match :/ \n\
                    Episode: {len(self.episode_memory)} vs tree size: {len(self.reward_agent.visited_nodes_to_step_idx)}"
                ]
                return False

        # All good, let's process the episode
        for idx, tree_transition in final_reward.items():
            # Check for memory inflation!
            original_transition = self.episode_memory[idx]
            # Successor states; indices beyond the recorded episode are skipped.
            next_states = [
                self.episode_memory[next_state_idx]["state"]
                for next_state_idx in tree_transition["next_states"]
                if next_state_idx < len(self.episode_memory)
            ]
            reward = tree_transition["reward"]
            # A transition without any recorded successor is treated as terminal.
            done = len(next_states) == 0
            original_transition.update({"next_state": next_states, "reward": reward, "done": done})
            if self.use_n_step:
                # Same construction for the n-step successors.
                n_step_next_states = [
                    self.episode_memory[n_step_next_state_idx]["state"]
                    for n_step_next_state_idx in tree_transition[
                        "n_step_next_states"
                    ]
                    if n_step_next_state_idx < len(self.episode_memory)
                ]
                n_step_reward = tree_transition["n_step_reward"]
                n_step_done = len(n_step_next_states) == 0
                original_transition.update(
                    {
                        "n_step_next_state": n_step_next_states,
                        "n_step_reward": n_step_reward,
                        "n_step_done": n_step_done,
                    }
                )

            self.memory.store(**original_transition)

        return True

    def update_model(self) -> torch.Tensor:
        """Update the model by gradient descent."""
        # PER needs beta to calculate weights
        samples = self.memory.sample_batch(beta=self.beta)
        importance_weights = torch.FloatTensor(samples["weights"].reshape(-1, 1)).to(self.device)
        indices = samples["indices"]
        weights = importance_weights

        # 1-step Learning loss
        # PER: importance sampling before average
        elementwise_loss = self.agent.compute_dqn_loss(samples, self.gamma, n_step=False)
        loss = torch.mean(elementwise_loss * weights)

        # N-step Learning loss
        # We are gonna combine 1-step loss and n-step loss so as to prevent high-variance.
        # The original rainbow employs n-step loss only.
        if self.use_n_step:
            gamma = self.gamma**self.n_step
            elementwise_n_step_loss = self.agent.compute_dqn_loss(samples, gamma, n_step=True)
            elementwise_loss += elementwise_n_step_loss
            # PER: importance sampling before average
            loss = torch.mean(elementwise_loss * weights) / self.n_grad_accumulation

        loss.backward()

        if self.accumulation_steps % self.n_grad_accumulation == 0:
            nn.utils.clip_grad_norm_(self.agent.dqn.parameters(), 10.0)
            self.optimizer.step()
            self.optimizer.zero_grad()

        # PER: update priorities
        loss_for_prior = elementwise_loss.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.prior_eps
        self.memory.update_priorities(indices, new_priorities)

        if self.writer is not None:
            average_weights = torch.mean(weights).item()
            raw_loss = torch.mean(elementwise_loss).item()
            if self.use_n_step:
                raw_n_step_loss = torch.mean(elementwise_n_step_loss).item()
                raw_one_step_loss = raw_loss - raw_n_step_loss
                self.writer.add_scalar("Loss_n_step/epoch", raw_n_step_loss, self.epoch_cnt)
                self.writer.add_scalar("Loss_one_step/epoch", raw_one_step_loss, self.epoch_cnt)
            max_priority = self.memory.max_priority
            self.writer.add_scalar("Beta/epoch", self.beta, self.epoch_cnt)
            self.writer.add_scalar("Loss_raw/epoch", raw_loss, self.epoch_cnt)
            self.writer.add_scalar("Loss_importance_weights/epoch", average_weights, self.epoch_cnt)
            self.writer.add_scalar("Max_priority/epoch", max_priority, self.epoch_cnt)
            if self.agent.classification:
                entropy = sum(self.agent.loss_fn.entropy) / 2
                cross_entropy = sum(self.agent.loss_fn.cross_entropy) / 2
                KL_divergence = sum(self.agent.loss_fn.KL_divergence) / 2
                self.writer.add_scalar("Entropy/epoch", entropy, self.epoch_cnt)
                self.writer.add_scalar("Entropy crossed/epoch", cross_entropy, self.epoch_cnt)
                self.writer.add_scalar("KL divergence/epoch", KL_divergence, self.epoch_cnt)

        return loss.item()

    def run_epochs(self, n_gradient_steps: int = 1):
        """
        Wrapper to run epochs and update all metrics.
        """

        keep_running = n_gradient_steps > 0
        Q_function_update_time = 0

        while keep_running:
            # PER: increase beta
            fraction = min(2 * self.epoch_cnt / self.n_epochs, 1.0)
            self.beta = self.beta_0 + fraction * (1.0 - self.beta_0)

            t1 = time.perf_counter()
            loss = self.update_model()
            t2 = time.perf_counter()
            Q_function_update_time += t2 - t1
            self.losses.append(loss)

            if self.n_grad_accumulation > 1:
                self.accumulation_steps = (self.accumulation_steps + 1) % self.n_grad_accumulation

            if self.accumulation_steps % self.n_grad_accumulation == 0:
                n_gradient_steps -= 1
                self.epoch_cnt += 1
                if self.writer is not None:
                    self.writer.add_scalar("Loss/epoch", loss, self.epoch_cnt)
                    self.writer.add_scalar("Epoch/step", self.epoch_cnt, self.step_cnt)
                    self.writer.add_scalar("Epsilon/step ", self.epsilon, self.epoch_cnt)
                    self.writer.add_scalar("Temperature/step", self.T, self.epoch_cnt)
                    self.writer.add_scalar("Q_function_update_time", Q_function_update_time, self.epoch_cnt)

                # linearly decrease epsilon and temperature
                self.epsilon = max(
                    self.min_epsilon,
                    self.epsilon - (self.max_epsilon - self.min_epsilon) * self.epsilon_decay,
                )
                self.epsilons.append(self.epsilon)
                self.T = max(1e-3, self.T - (1.0 - 1e-3) * self.temperature_decay)
                self.agent._target_update(self.epoch_cnt)
                self.Q_function_update_time_episode += Q_function_update_time
                Q_function_update_time = 0

                if self.save_path and self.epoch_cnt % self.save_frequency == 0:
                    self.agent._save(self.save_path, self.epoch_cnt)

            if n_gradient_steps == 0 or len(self.memory) == 0:
                keep_running = False

    def train(self):
        """
        Train the agent.
        """
        self.is_test = False
        self.epoch_cnt = 0
        self.step_cnt = 0
        self.episode_cnt = 0
        self.failure = 0
        self.T = 1.0
        self.losses = []
        self.scores = []
        self.epsilons = []
        self.n_nodes = []
        self.presolve_times = []
        self.solving_times = []
        self.n_integer_variables_after_presolve = []
        self.dual_gaps = []
        self.n_nodes_validation = []
        self.solving_times_validation = []
        self.dual_gaps_validation = []
        self.n_steps_validation = []
        self.n_epochs_validation = []
        self.unsuccessful_episodes = []

        # Overfit 1 instance
        if self.overfit_one_instance:
            current_instance = self.only_instance
            logging.info(
                f"Evaluate on instance: {current_instance}. Overfit ? {self.overfit_one_instance}. \n"
            )

        while self.epoch_cnt < self.n_epochs:
            done = True
            self.episode_memory = []
            self.current_node_depth = 0
            self.agent_action_count, self.expert_action_count = 0, 0
            self.ecole_transition_time_episode, self.Q_function_update_time_episode = 0, 0
            while done:
                if not self.overfit_one_instance:
                    current_instance = next(self.instances)
                    logging.info(f"Solving instance: {current_instance}.")
                state, reward_offset, done, info = self._reset_env(current_instance)
                score = reward_offset  # Primal Dual Integral
                n_nodes = info["n_nodes"]
                solving_time = info["time"]
                presolve_time = info["time"]
                score = 0
                initial_action_set = state[-1] if not done else []
                n_integer_variable_after_presolve = len(initial_action_set)

            # Start episode
            logging.info(
                f"Start episode {self.episode_cnt}, after {self.epoch_cnt} epochs and {self.step_cnt} transition steps."
            )
            while not done:
                action = self.select_action(state)
                t1 = time.perf_counter()
                next_state, reward, info, done = self.step(state, action)
                t2 = time.perf_counter()
                self.ecole_transition_time_episode += t2 - t1
                self.step_cnt += 1
                n_nodes += info["n_nodes"]
                solving_time += info["time"]
                score += 1
                state = next_state
                self.current_node_depth = info["depth"]

            ## End of episode

            final_gap = self.env.model.as_pyscipopt().getGap()

            # Add transitions to replay buffer
            t1 = time.perf_counter()
            if self.use_retro_trajectories:
                successful_episode = self.process_episode_into_retro_trajectories(final_reward=reward)
            elif self.use_bb_mdp and n_nodes:
                successful_episode = self.process_episode_into_bb_mdp(final_reward=reward)
            t2 = time.perf_counter()
            episode_post_processing_time = t2 - t1

            if successful_episode:
                self.scores.append(score)
                self.n_nodes.append(n_nodes)
                self.dual_gaps.append(final_gap)
                self.solving_times.append(solving_time)
                self.presolve_times.append(presolve_time)
                self.n_integer_variables_after_presolve.append(n_integer_variable_after_presolve)
                self.episode_cnt += 1
            else:
                motive = self.unsuccessful_episodes[-1]
                logging.info(f"{len(self.unsuccessful_episodes)} unsuccessful episodes, latest: \n{motive}")

            logging.info(
                f"End of episode, obtained a return of {score}, average return {np.mean(self.scores):.1f}.\n \
                       Number of B&B nodes: {n_nodes}, average tree size: {np.mean(self.n_nodes):.1f}\n \
                       Solving time: {(solving_time):.1f}, average solving time: {np.mean(self.solving_times):.1f}.\n "
            )

            if self.off_policy and successful_episode:
                training_ready = len(self.memory) > self.memory_min_size
                n_gradient_steps = min(len(self.episode_memory) // self.steps_per_update, 500)
                # If training is ready
                if training_ready and n_gradient_steps > 0:
                    logging.info(f"Performing {n_gradient_steps} gradients steps.")
                    self.run_epochs(n_gradient_steps=n_gradient_steps)
                    # breakpoint()

            if self.writer is not None:
                self.writer.add_scalar("Return", -score, self.episode_cnt)
                self.writer.add_scalar("N_nodes", n_nodes, self.episode_cnt)
                self.writer.add_scalar("Solving times", solving_time, self.episode_cnt)
                self.writer.add_scalar("Presolve times", presolve_time, self.episode_cnt)
                self.writer.add_scalar(
                    "Episode post-processing time", episode_post_processing_time, self.episode_cnt
                )
                self.writer.add_scalar(
                    "Episode Q function update time", self.Q_function_update_time_episode, self.episode_cnt
                )
                self.writer.add_scalar(
                    "Episode ecole transtion time", self.ecole_transition_time_episode, self.episode_cnt
                )
                self.writer.add_scalar(
                    "N_integer_variables_after_presolve", n_integer_variable_after_presolve, self.episode_cnt
                )
                self.writer.add_scalar("Size of replay buffer", len(self.memory), self.episode_cnt)
                agent_action_ratio = self.agent_action_count / (
                    self.agent_action_count + self.expert_action_count
                )
                self.writer.add_scalar("Agent action ratio", agent_action_ratio, self.episode_cnt)
                self.writer.flush()

            # Run validation every 100 episodes
            if self.episode_cnt % 100 == 0:
                validation_n_nodes, validation_dual_gap, validation_solving_time = self.evaluate_agent()
                self.n_nodes_validation.append(validation_n_nodes)
                self.dual_gaps_validation.append(validation_dual_gap)
                self.solving_times_validation.append(validation_solving_time)
                self.n_steps_validation.append(self.step_cnt)
                self.n_epochs_validation.append(self.epoch_cnt)

                if self.writer is not None:
                    self.writer.add_scalar("Validation nodes", validation_n_nodes, self.epoch_cnt)
                    self.writer.add_scalar("Validation gap", validation_dual_gap, self.epoch_cnt)
                    self.writer.add_scalar(
                        "Validation solving times", validation_solving_time, self.epoch_cnt
                    )

        ## End of training

        logging.info(
            f"End of training, total episodes: {self.episode_cnt}, total epochs: {self.epoch_cnt}, total transition: {self.step_cnt}.\n"
        )

        validation_n_nodes, validation_dual_gap, validation_solving_time = self.evaluate_agent()
        self.n_nodes_validation.append(validation_n_nodes)
        self.dual_gaps_validation.append(validation_dual_gap)
        self.solving_times_validation.append(validation_solving_time)
        self.n_steps_validation.append(self.step_cnt)
        self.n_epochs_validation.append(self.epoch_cnt)

        if self.writer is not None:
            self.writer.add_scalar("Validation nodes", validation_n_nodes, self.epoch_cnt)
            self.writer.add_scalar("Validation gap", validation_dual_gap, self.epoch_cnt)
            self.writer.add_scalar("Validation solving times", validation_solving_time, self.epoch_cnt)

        if self.save_path:
            training_dict = {
                "scores": self.scores,
                "losses": self.losses,
                "n_nodes": self.n_nodes,
                "solving_times": self.solving_times,
                "epsilons": self.epsilons,
                "dual_gaps": self.dual_gaps,
                "unsuccessful_episodes": len(self.unsuccessful_episodes),
            }
            validation_dict = {
                "n_nodes": self.n_nodes_validation,
                "dual_gaps": self.dual_gaps_validation,
                "solving_times": self.solving_times_validation,
                "n_steps": self.n_steps_validation,
                "n_epochs": self.n_epochs_validation,
            }

            with open(f"{self.save_path}training.json", "w") as f:
                json.dump(training_dict, f, indent=4)

            with open(f"{self.save_path}validation.json", "w") as f:
                json.dump(validation_dict, f, indent=4)

        if self.writer is not None:
            self.writer.close()

    def _reset_env(
        self, instance: str
    ) -> Tuple[Tuple[np.ndarray], np.ndarray, np.float64, np.ndarray, bool]:
        obs, action_set, reward_offset, done, info = self.env.reset(instance)
        done = True if obs is None else done
        state = extract_state_from_obs(obs, action_set) if not done else obs
        if not done:
            self.reward_agent.before_reset(instance)
            _ = self.reward_agent.extract(self.env.model, done, action_set)
        return state, reward_offset, done, info

    def _evaluate_agent_on_benchmark(self):
        """When training on a set of instance, evaluate on a benchmark!"""
        self.is_test = True
        logging.info(f"Evaluate current agent ... Overfit ? {self.overfit_one_instance}.")
        eval_scores = []
        eval_n_nodes = []
        eval_final_gaps = []
        eval_solving_times = []
        for eval_instance in self.evaluation_instances[:20]:
            logging.info(f"Solving instance: {eval_instance}.")
            state, _, done, info = self._reset_env(eval_instance)
            n_nodes = info["n_nodes"]
            solving_time = info["time"]
            score = 0
            # Start episode
            while not done:
                action = self.select_action(state)
                next_state, reward, info, done = self.step(state, action)
                n_nodes += info["n_nodes"]
                solving_time += info["time"]
                score += 1
                state = next_state

            # Parsonson test score = self.get_tree_episode_score(final_reward=reward)
            final_gap = self.env.model.as_pyscipopt().getGap()

            eval_scores.append(score)
            eval_n_nodes.append(n_nodes)
            eval_final_gaps.append(final_gap)
            eval_solving_times.append(solving_time)

        mean_n_nodes = gmean(eval_n_nodes)
        mean_final_gap = gmean(eval_final_gaps)
        mean_solving_time = gmean(eval_solving_times)

        if mean_n_nodes < self.best_performance:
            self.best_performance = mean_n_nodes
            self.agent._save(self.save_path, self.epoch_cnt, best=True, final=False)

        logging.info(
            f"Current agent evaluated, obtained an average return of {np.mean(eval_scores):.1f} +- {np.std(eval_scores):.1f}.\n \
                       Size of the B&B tree: {np.mean(eval_n_nodes):.1f} +- {np.std(eval_n_nodes):.1f}. \n"
        )
        self.is_test = False
        return mean_n_nodes, mean_final_gap, mean_solving_time

    def evaluate_agent(self):
        return self._evaluate_agent_on_benchmark()


def extract_state_from_obs(
    obs: ecole.core.observation,
    action_set: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Flatten an observation and its action set into the state tuple used by the learner.

    The tuple holds, in order: row features (float32), edge indices (int16),
    edge values (float32), variable features (float32) and the action set (int16).

    NOTE(review): int16 cannot represent indices above 32767 - confirm that the
    instances stay below that size.
    """
    edges = obs.edge_features
    row_features = obs.row_features.astype(np.float32)
    edge_indices = edges.indices.astype(np.int16)
    edge_values = edges.values.astype(np.float32)
    variable_features = obs.variable_features.astype(np.float32)
    candidates = action_set.astype(np.int16)
    return (row_features, edge_indices, edge_values, variable_features, candidates)
