import time
import os

import numpy as np
import torch
from tqdm import tqdm
import wandb

from utils.results_utils import get_eval_statistics


class Trainer:
    """Drive dynamics-model learning, policy learning and periodic evaluation.

    ``algo`` is expected to expose ``learn_dynamics``, ``save_dynamics_model``,
    ``rollout_transitions``, ``learn_policy`` (returning a dict of scalar
    losses) and a ``policy`` object with ``train``/``eval``/``sample_action``/
    ``state_dict``. ``logger`` must provide ``print`` and ``record``.

    Args:
        algo: The learning algorithm being trained.
        eval_env: Environment used for evaluation; its ``step`` must return the
            4-tuple ``(obs, reward, terminal, info)``.
        epoch: Number of policy-training epochs.
        step_per_epoch: Policy updates performed per epoch.
        rollout_freq: Call ``algo.rollout_transitions()`` every this many updates.
        logger: Logger used for loss records and timing messages.
        log_freq: Record losses every this many updates.
        env_type: Forwarded to ``get_eval_statistics`` (presumably selects the
            score normalization -- confirm in that helper).
        save_path: Directory where ``policy.pkl`` is written after each epoch.
        eval_episodes: Number of episodes per evaluation.
        use_wandb: If true, evaluation metrics are also logged to wandb.
        hidden_dims: Optional numpy index (int, list, slice, mask, ...) of
            observation dimensions that are zeroed before the policy sees the
            observation during evaluation.
    """

    def __init__(
            self,
            algo,
            eval_env,
            epoch,
            step_per_epoch,
            rollout_freq,
            logger,
            log_freq,
            env_type,
            save_path,
            eval_episodes=10,
            use_wandb=False,
            hidden_dims=None
    ):
        self.algo = algo
        self.eval_env = eval_env

        self._epoch = epoch
        self._step_per_epoch = step_per_epoch
        self._rollout_freq = rollout_freq

        self.logger = logger
        self._log_freq = log_freq
        self._eval_episodes = eval_episodes
        self.use_wandb = use_wandb
        self.env_type = env_type
        self.save_path = save_path

        self.hidden_dims = hidden_dims

    def train_dynamics(self):
        """Fit the dynamics model, save it under the name ``"dynamics_model"`` and log elapsed time."""
        start_time = time.time()

        self.algo.learn_dynamics()

        self.algo.save_dynamics_model("dynamics_model")
        self.logger.print("total time: {:.3f}s".format(time.time() - start_time))

    def train_policy(self):
        """Train the policy for ``self._epoch`` epochs, evaluating and saving after each.

        Each epoch runs ``self._step_per_epoch`` policy updates. Model rollouts
        are refreshed every ``self._rollout_freq`` updates and losses are
        recorded every ``self._log_freq`` updates. After each epoch the policy
        is evaluated, the statistics are printed (and sent to wandb when
        enabled), and the weights are written to ``<save_path>/policy.pkl``,
        overwriting the previous epoch's file.
        """
        start_time = time.time()
        # torch.save does not create missing parent directories; an empty
        # save_path means "current directory", which always exists.
        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)
        policy_file = os.path.join(self.save_path, 'policy.pkl')

        num_timesteps = 0
        for e in range(1, self._epoch + 1):
            self.algo.policy.train()
            with tqdm(total=self._step_per_epoch, desc=f"Epoch #{e}/{self._epoch}") as t:
                while t.n < t.total:
                    # num_timesteps starts at 0, so a rollout happens before the first update.
                    if num_timesteps % self._rollout_freq == 0:
                        self.algo.rollout_transitions()
                    # update policy by sac
                    loss = self.algo.learn_policy()
                    t.set_postfix(**loss)
                    if num_timesteps % self._log_freq == 0:
                        for k, v in loss.items():
                            self.logger.record(k, v, num_timesteps, printed=False)
                    num_timesteps += 1
                    t.update(1)

            # evaluate current policy
            eval_info = self._evaluate()

            all_rewards = eval_info["eval/episode_reward"]
            avg_reward, std_reward, avg_norm_reward, std_norm_reward = get_eval_statistics(all_rewards, self.env_type)

            # NOTE: despite the "Epoch" label and the 'epochs' wandb key, this
            # value is the cumulative number of policy updates so far.
            current_time_step = e * self._step_per_epoch
            print("---------------------------------------")
            print(f"Epoch {current_time_step}: Evaluation over {self._eval_episodes} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
            print("---------------------------------------")

            if self.use_wandb:
                metrics = {'eval mean reward': avg_reward, 'avg_norm_reward': avg_norm_reward,
                           'eval std reward': std_reward, 'epochs': current_time_step}
                wandb.log(metrics)

            # save policy (overwrites the previous epoch's checkpoint)
            torch.save(self.algo.policy.state_dict(), policy_file)
        self.logger.print("total time: {:.3f}s".format(time.time() - start_time))

    def _evaluate(self):
        """Roll out the deterministic policy for ``self._eval_episodes`` episodes.

        The policy is left in eval mode; ``train_policy`` switches it back to
        train mode at the start of every epoch.

        Returns:
            dict with keys ``"eval/episode_reward"`` and ``"eval/episode_length"``,
            each a list holding one value per completed episode.
        """
        self.algo.policy.eval()
        obs = self.eval_env.reset()
        eval_ep_info_buffer = []
        num_episodes = 0
        episode_reward, episode_length = 0, 0

        while num_episodes < self._eval_episodes:
            if self.hidden_dims is not None:
                # Mask a private copy: the env may return an array it keeps
                # using internally, and zeroing it in place would corrupt it.
                obs = np.array(obs, copy=True)
                obs[self.hidden_dims] = 0.0

            action = self.algo.policy.sample_action(obs, deterministic=True)
            next_obs, reward, terminal, _ = self.eval_env.step(action)
            episode_reward += reward
            episode_length += 1

            obs = next_obs

            if terminal:
                eval_ep_info_buffer.append(
                    {"episode_reward": episode_reward, "episode_length": episode_length}
                )
                num_episodes += 1
                episode_reward, episode_length = 0, 0
                obs = self.eval_env.reset()

        return {
            "eval/episode_reward": [ep_info["episode_reward"] for ep_info in eval_ep_info_buffer],
            "eval/episode_length": [ep_info["episode_length"] for ep_info in eval_ep_info_buffer]
        }
