import os
import time

import numpy as np
import torch
try:
    from cluster import announce_fraction_finished
    from cluster import exit_for_resume
    from cluster import save_metrics_params
    CLUSTER = 1
except ImportError:
    CLUSTER = 0
    print('No cluster utils detected, proceeding without it')

from tonic import logger


class Trainer:
    """Trainer used to train and evaluate an agent on an environment.

    The trainer drives the interaction loop between the agent and the
    training environment, logs statistics once per epoch, optionally
    evaluates the agent on a test environment at the end of each epoch and
    periodically writes checkpoints.
    """

    # Wall-clock seconds after which, right after a checkpoint has been
    # written, an unfinished run asks the cluster to be resubmitted (only
    # when the optional cluster utilities are available). Can be overridden
    # per instance or in a subclass.
    resume_after_seconds = 14000

    def __init__(
        self,
        steps=int(1e7),
        epoch_steps=int(2e4),
        save_steps=int(5e5),
        test_episodes=10,
        show_progress=True,
        replace_checkpoint=False,
    ):
        """Stores the training schedule.

        Args:
            steps: total number of environment steps (summed over all
                workers) after which training stops.
            epoch_steps: number of steps between two logging/test epochs.
            save_steps: number of steps between two checkpoints.
            test_episodes: number of episodes run by each evaluation.
            show_progress: whether to display the logger's progress bar.
            replace_checkpoint: if True, existing ``step_*`` files in the
                checkpoint directory are deleted before a new one is saved.
        """
        self.max_steps = steps
        self.epoch_steps = epoch_steps
        self.save_steps = save_steps
        self.test_episodes = test_episodes
        self.show_progress = show_progress
        self.replace_checkpoint = replace_checkpoint

    def initialize(self, agent, environment, test_environment=None):
        """Attaches the agent and the training (and optional test) environments."""
        self.agent = agent
        self.environment = environment
        self.test_environment = test_environment

    def run(self, params, steps=0, epochs=0, episodes=0):
        """Runs the main training loop.

        Args:
            params: forwarded to ``_test``; when cluster utilities are
                available it is passed to ``save_metrics_params`` together
                with each test episode score.
            steps: step counter to start from (presumably restored from the
                time file written by ``save_time`` when resuming).
            epochs: epoch counter to start from.
            episodes: episode counter to start from.

        Returns:
            The per-worker scores accumulated for the episodes that were
            still running when training stopped.
        """

        start_time = last_epoch_time = time.time()

        # Start the environments.
        observations, tendon_states = self.environment.start()

        num_workers = len(observations)
        scores = np.zeros(num_workers)
        lengths = np.zeros(num_workers, int)
        self.steps, epoch_steps = steps, 0
        steps_since_save = 0

        while True:
            # Select actions. Agents with an exploration module run a greedy
            # episode every ``test_episode_every`` episodes.
            if hasattr(self.agent, "expl"):
                greedy_episode = not episodes % self.agent.expl.test_episode_every
            else:
                greedy_episode = None
            assert not np.isnan(observations.sum())
            actions = self.agent.step(
                observations, self.steps, tendon_states, greedy_episode
            )
            assert not np.isnan(actions.sum())
            logger.store("train/action", actions, stats=True)

            # Take a step in the environments.
            observations, tendon_states, info = self.environment.step(actions)
            self.agent.update(**info, steps=self.steps)

            scores += info["rewards"]
            lengths += 1
            self.steps += num_workers
            epoch_steps += num_workers
            steps_since_save += num_workers

            # Show the progress bar.
            if self.show_progress:
                logger.show_progress(self.steps, self.epoch_steps, self.max_steps)

            # Check the finished episodes.
            for i in range(num_workers):
                if info["resets"][i]:
                    logger.store("train/episode_score", scores[i], stats=True)
                    logger.store("train/episode_length", lengths[i], stats=True)
                    scores[i] = 0
                    lengths[i] = 0
                    episodes += 1

            # End of the epoch.
            if epoch_steps >= self.epoch_steps:
                # Evaluate the agent on the test environment.
                if self.test_environment:
                    self._test(params)

                # Log the data.
                epochs += 1
                current_time = time.time()
                epoch_time = current_time - last_epoch_time
                sps = epoch_steps / epoch_time
                logger.store("train/episodes", episodes)
                logger.store("train/epochs", epochs)
                logger.store("train/seconds", current_time - start_time)
                logger.store("train/epoch_seconds", epoch_time)
                logger.store("train/epoch_steps", epoch_steps)
                logger.store("train/steps", self.steps)
                logger.store("train/worker_steps", self.steps // num_workers)
                logger.store("train/steps_per_second", sps)
                logger.dump()
                last_epoch_time = time.time()
                epoch_steps = 0
                if CLUSTER:
                    announce_fraction_finished(self.steps / self.max_steps)

            # End of training.
            stop_training = self.steps >= self.max_steps

            # Save a checkpoint.
            if stop_training or steps_since_save >= self.save_steps:
                self._save_checkpoint(epochs, episodes)
                # Keep later checkpoints aligned to multiples of save_steps.
                steps_since_save = self.steps % self.save_steps
                # Only ask for a resubmission when there is work left: a
                # finished run must return instead of being requeued.
                elapsed = time.time() - start_time
                if (
                    CLUSTER
                    and not stop_training
                    and elapsed > self.resume_after_seconds
                ):
                    exit_for_resume()

            if stop_training:
                return scores

    def _save_checkpoint(self, epochs, episodes):
        """Saves the agent, the logger state and the time counters.

        The checkpoint is written to ``<log path>/checkpoints/step_<steps>``;
        with ``replace_checkpoint`` set, older ``step_*`` files are removed
        first.
        """
        path = os.path.join(logger.get_path(), "checkpoints")
        if self.replace_checkpoint and os.path.isdir(path):
            for file_name in os.listdir(path):
                if file_name.startswith("step_"):
                    os.remove(os.path.join(path, file_name))
        save_path = os.path.join(path, f"step_{self.steps}")
        self.agent.save(save_path)
        logger.save(save_path)
        self.save_time(save_path, epochs, episodes)

    def _test(self, params=None):
        """Tests the agent on the test environment.

        Runs ``self.test_episodes`` episodes on the single-worker test
        environment and logs, per episode, the score, the length, whether the
        episode ended with a termination ("success_rate"), and the mean
        squared action ("effort").

        Args:
            params: when not None and cluster utilities are available, each
                episode score is reported via ``save_metrics_params``.
        """

        # Start the environment only once; later calls continue from the
        # last observation of the previous evaluation.
        if not hasattr(self, "test_observations"):
            self.test_observations, _ = self.test_environment.start()
            assert len(self.test_observations) == 1

        # Test loop.
        for _ in range(self.test_episodes):
            score, length = 0, 0
            success_rate = 0
            actions_list = []
            # NOTE(review): activations and loads are placeholders that are
            # always 0. They appear intended to record muscle activations and
            # contact loads of the test model; wire them up or drop the
            # corresponding metrics.
            activations = []
            loads = []

            while True:
                # Select an action.
                actions = self.agent.test_step(self.test_observations, self.steps)
                actions_list.append(actions)
                activations.append(0)
                loads.append(0)
                assert not np.isnan(actions.sum())
                logger.store("test/action", actions, stats=True)

                # Take a step in the environment.
                self.test_observations, _, info = self.test_environment.step(actions)
                self.agent.test_update(**info, steps=self.steps)

                score += info["rewards"][0]
                length += 1

                if info["resets"][0]:
                    # Index the single test worker, like rewards and resets;
                    # the truthiness of the whole container is wrong for
                    # e.g. a list such as [False].
                    if info["terminations"][0]:
                        success_rate += 1
                    break

            # Log the data.
            metrics = {
                "test/episode_score": score,
                "test/episode_length": length,
                "test/success_rate": success_rate,
                "test/effort": np.mean(np.square(actions_list)),
                "test/activations": np.mean(np.square(activations)),
                "test/loads": np.mean(loads)
            }
            for k, v in metrics.items():
                logger.store(k, v, stats=True)
            if params is not None:
                if CLUSTER:
                    save_metrics_params({"test/episode_score": score}, params)

    def save_time(self, path, epochs, episodes):
        """Saves the epoch/episode/step counters next to the checkpoint ``path``."""
        time_path = self.get_path(path, "time")
        time_dict = {"epochs": epochs, "episodes": episodes, "steps": self.steps}
        torch.save(time_dict, time_path)

    def get_path(self, path, post_fix):
        """Returns the ``<post_fix>.pt`` file beside a ``step_*`` checkpoint.

        The path is cut at the LAST occurrence of "step". Directory names
        that also contain "step" (e.g. ``/runs/stepper/``) therefore remain
        intact.
        """
        return path.rsplit("step", 1)[0] + post_fix + ".pt"
