import os
import re
import time

import numpy as np
import torch

from tonic import logger
from cluster import announce_fraction_finished, exit_for_resume, save_metrics_params


class Trainer:
    """Trainer used to train and evaluate an agent on an environment.

    Attach the agent and environments with `initialize`, then call `run`.
    `run` steps the training environment until `max_steps` is reached,
    evaluates on the test environment at the end of every epoch, logs epoch
    statistics and periodically writes checkpoints.
    """

    # Elapsed wall-clock seconds after which `run` calls `exit_for_resume`
    # right after writing a checkpoint (presumably to stay within a cluster
    # job time limit -- confirm against the job configuration).
    RESUME_AFTER_SECONDS = 14000

    def __init__(
            self, steps=int(1e7), epoch_steps=int(2e4), save_steps=int(5e5),
            test_episodes=10, show_progress=True, replace_checkpoint=False,
    ):
        """Store the training schedule.

        Args:
            steps: total number of environment steps, summed over all
                workers, after which training stops.
            epoch_steps: number of steps between two epoch ends (test +
                logging).
            save_steps: number of steps between two checkpoints.
            test_episodes: number of test episodes run at each epoch end.
            show_progress: whether to draw the logger's progress bar.
            replace_checkpoint: if True, only the latest `step_*` checkpoint
                is kept in the checkpoint directory.
        """
        self.max_steps = steps
        self.epoch_steps = epoch_steps
        self.save_steps = save_steps
        self.test_episodes = test_episodes
        self.show_progress = show_progress
        self.replace_checkpoint = replace_checkpoint

    def initialize(self, agent, environment, test_environment=None):
        """Attach the agent, the training environment and an optional test
        environment (testing is skipped when it is falsy)."""
        self.agent = agent
        self.environment = environment
        self.test_environment = test_environment

    def run(self, params, steps=0, epochs=0, episodes=0):
        """Runs the main training loop until `self.max_steps` is reached.

        Args:
            params: forwarded to `_test`, which passes it to
                `save_metrics_params` together with each test score.
            steps, epochs, episodes: initial counter values (non-zero when
                continuing a previous run).

        Returns:
            The per-worker score accumulators when training stops, i.e. the
            summed rewards of the episodes still in progress.
        """
        start_time = last_epoch_time = time.time()

        # Start the environments.
        observations, tendon_states = self.environment.start()

        num_workers = len(observations)
        scores = np.zeros(num_workers)
        lengths = np.zeros(num_workers, int)
        self.steps, epoch_steps = steps, 0
        steps_since_save = 0

        while True:
            # Agents with an exploration object flag every
            # `test_episode_every`-th episode as greedy; others get None.
            if hasattr(self.agent, 'expl'):
                greedy_episode = not episodes % self.agent.expl.test_episode_every
            else:
                greedy_episode = None
            assert not np.isnan(observations.sum())
            actions = self.agent.step(
                observations, self.steps, tendon_states, greedy_episode)
            assert not np.isnan(actions.sum())
            logger.store('train/action', actions, stats=True)

            # Take a step in the environments.
            observations, tendon_states, info = self.environment.step(actions)
            self.agent.update(**info, steps=self.steps)

            scores += info['rewards']
            lengths += 1
            # Step counters count transitions over all workers.
            self.steps += num_workers
            epoch_steps += num_workers
            steps_since_save += num_workers

            # Show the progress bar.
            if self.show_progress:
                logger.show_progress(
                    self.steps, self.epoch_steps, self.max_steps)

            # Log and reset the workers whose episode just ended.
            for i in range(num_workers):
                if info['resets'][i]:
                    logger.store('train/episode_score', scores[i], stats=True)
                    logger.store(
                        'train/episode_length', lengths[i], stats=True)
                    scores[i] = 0
                    lengths[i] = 0
                    episodes += 1

            # End of the epoch.
            if epoch_steps >= self.epoch_steps:
                # Evaluate the agent on the test environment.
                if self.test_environment:
                    self._test(params)

                # Log the data.
                epochs += 1
                current_time = time.time()
                epoch_time = current_time - last_epoch_time
                sps = epoch_steps / epoch_time
                logger.store('train/episodes', episodes)
                logger.store('train/epochs', epochs)
                logger.store('train/seconds', current_time - start_time)
                logger.store('train/epoch_seconds', epoch_time)
                logger.store('train/epoch_steps', epoch_steps)
                logger.store('train/steps', self.steps)
                logger.store('train/worker_steps', self.steps // num_workers)
                logger.store('train/steps_per_second', sps)
                logger.dump()
                last_epoch_time = time.time()
                epoch_steps = 0
                announce_fraction_finished(self.steps / self.max_steps)

            # End of training.
            stop_training = self.steps >= self.max_steps

            # Save a checkpoint.
            if stop_training or steps_since_save >= self.save_steps:
                self._save_checkpoint(epochs, episodes)
                # Carry the overshoot so saves stay roughly aligned to
                # multiples of `save_steps`.
                steps_since_save = self.steps % self.save_steps
                if time.time() - start_time > self.RESUME_AFTER_SECONDS:
                    exit_for_resume()

            if stop_training:
                return scores

    def _save_checkpoint(self, epochs, episodes):
        """Write agent, logger and counter state to `checkpoints/step_<steps>`
        under the logger's path, then prune older checkpoints if requested."""
        path = os.path.join(logger.get_path(), 'checkpoints')
        save_path = os.path.join(path, f'step_{self.steps}')
        self.agent.save(save_path)
        logger.save(save_path)
        self.save_time(save_path, epochs, episodes)
        # Prune only after the new checkpoint exists, so an interruption while
        # saving never leaves the run without any checkpoint.
        if self.replace_checkpoint:
            self._remove_stale_checkpoints(path, self.steps)

    def _remove_stale_checkpoints(self, path, keep_step):
        """Delete every `step_*` file in `path` except those of `keep_step`.

        A file belongs to step `n` when its name starts with `step_<n>`
        (e.g. `step_100.pt`, `step_100_actor.pt`); `step_1000.pt` does not
        belong to step 100. A missing directory is ignored.
        """
        if not os.path.isdir(path):
            return
        for file in os.listdir(path):
            if not file.startswith('step_'):
                continue
            match = re.match(r'step_(\d+)', file)
            if match and int(match.group(1)) == keep_step:
                continue
            os.remove(os.path.join(path, file))

    def _test(self, params=None):
        """Runs `self.test_episodes` episodes on the test environment.

        For each episode, the score, the length and a 0/1 flag telling whether
        the episode ended by termination are stored in the logger. If `params`
        is given, each score is also passed to `save_metrics_params`.
        """
        # Start the environment once; later calls continue from the stored
        # observations.
        if not hasattr(self, 'test_observations'):
            self.test_observations, _ = self.test_environment.start()
            # A single test worker is assumed, hence the `[0]` indexing below.
            assert len(self.test_observations) == 1

        for _ in range(self.test_episodes):
            score, length = 0, 0
            success_rate = 0

            while True:
                # Select an action.
                actions = self.agent.test_step(
                    self.test_observations, self.steps)
                assert not np.isnan(actions.sum())
                logger.store('test/action', actions, stats=True)

                # Take a step in the environment.
                self.test_observations, _, info = self.test_environment.step(
                    actions)
                self.agent.test_update(**info, steps=self.steps)

                score += info['rewards'][0]
                length += 1

                if info['resets'][0]:
                    # A termination (rather than e.g. a time-limit reset) is
                    # counted as a success -- presumably the test environment
                    # terminates on task completion; confirm.
                    if info['terminations'][0]:
                        success_rate += 1
                    break

            # Log the data.
            metrics = {'test/episode_score': score,
                       'test/episode_length': length,
                       'test/success_rate': success_rate}
            for k, v in metrics.items():
                logger.store(k, v, stats=True)
            if params is not None:
                save_metrics_params({'test/episode_score': score}, params)

    def save_time(self, path, epochs, episodes):
        """Save the epoch/episode/step counters with `torch.save` to the
        `time.pt` file next to the checkpoint `path` (see `get_path`)."""
        time_path = self.get_path(path, 'time')
        time_dict = {'epochs': epochs,
                     'episodes': episodes,
                     'steps': self.steps}
        torch.save(time_dict, time_path)

    def get_path(self, path, post_fix):
        """Map a checkpoint path such as `<dir>/step_<n>` to `<dir>/<post_fix>.pt`.

        Only the last occurrence of 'step' is split on, so directory names
        that themselves contain 'step' are preserved.
        """
        return path.rsplit('step', 1)[0] + post_fix + '.pt'
