import time
import os

import numpy as np
import torch
import gym

from typing import Optional, Dict, List
from tqdm import tqdm
from collections import deque
from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.logger import Logger
from offlinerlkit.policy import BasePolicy
import pickle


# model-free policy trainer
class MFPolicyTrainer:
    """Model-free offline policy trainer with periodic evaluation.

    Every epoch draws ``step_per_epoch`` mini-batches from ``buffer`` and
    hands them to the policy. Epochs whose 1-based index is strictly below
    ``pre_num`` call ``policy.learnPRDC``; all other epochs call
    ``policy.learn``. Every ``eval_interval`` epochs the policy is evaluated
    on ``eval_env``, the metrics are logged and a checkpoint is written.

    Args:
        policy: Policy to optimise. Must provide ``train``, ``eval``,
            ``learn``, ``select_action`` and ``state_dict``. ``learnPRDC``
            is only looked up when an epoch index is below ``pre_num``.
        eval_env: Environment used for evaluation. Must provide ``reset``,
            a 4-tuple ``step`` and ``get_normalized_score``.
        buffer: Source of training batches (``buffer.sample(batch_size)``).
        logger: Sink for metrics; also supplies ``checkpoint_dir`` and
            ``model_dir`` for saved weights.
        epoch: Number of training epochs.
        step_per_epoch: Number of gradient steps per epoch.
        batch_size: Mini-batch size requested from ``buffer``.
        eval_episodes: Number of complete episodes per evaluation run.
        pre_num: Epochs with index ``e < pre_num`` (``e`` starts at 1) use
            ``policy.learnPRDC`` instead of ``policy.learn``.
        num_evaluate: Number of independent evaluation runs per evaluation
            round; only the run with the highest mean reward is logged.
        lr_scheduler: Optional scheduler stepped once at the end of every
            epoch.
        eval_interval: Evaluate (and checkpoint) every this many epochs.

    Raises:
        ValueError: If ``eval_interval`` is not positive.
    """

    def __init__(
            self,
            policy: BasePolicy,
            eval_env: gym.Env,
            buffer: ReplayBuffer,
            logger: Logger,
            epoch: int = 1000,
            step_per_epoch: int = 1000,
            batch_size: int = 256,
            eval_episodes: int = 10,
            pre_num: int = 0,
            num_evaluate: int = 1,
            lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
            eval_interval: int = 10
    ) -> None:
        if eval_interval <= 0:
            raise ValueError("eval_interval must be a positive integer")

        self.policy = policy
        self.eval_env = eval_env
        self.buffer = buffer
        self.logger = logger

        self._epoch = epoch
        self._step_per_epoch = step_per_epoch
        self._batch_size = batch_size
        self._eval_episodes = eval_episodes
        self._num_evaluate = num_evaluate
        self._eval_interval = eval_interval

        self._pre_num = pre_num

        self.lr_scheduler = lr_scheduler

    def train(self) -> Dict[str, float]:
        """Run the full training schedule.

        Returns:
            ``{"last_10_performance": value}`` where ``value`` is the mean of
            the logged normalized episode reward over (at most) the last 10
            evaluation rounds, or NaN when no evaluation round ran.
        """
        start_time = time.time()

        num_timesteps = 0
        last_10_performance = deque(maxlen=10)

        # train loop
        for e in range(1, self._epoch + 1):
            num_timesteps += self._train_one_epoch(e)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # evaluate current policy
            if e % self._eval_interval == 0:
                best_result = self._best_evaluation(e)

                last_10_performance.append(best_result["norm_ep_rew_mean"])
                self.logger.logkv("eval/normalized_episode_reward", best_result["norm_ep_rew_mean"])
                self.logger.logkv("eval/normalized_episode_reward_std", best_result["norm_ep_rew_std"])
                self.logger.logkv("eval/episode_length", best_result["ep_length_mean"])
                self.logger.logkv("eval/episode_length_std", best_result["ep_length_std"])
                self.logger.set_timestep(num_timesteps)
                self.logger.dumpkvs()

                # save checkpoint
                torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth"))

        self.logger.log("total time: {:.2f}s".format(time.time() - start_time))
        torch.save(self.policy.state_dict(), os.path.join(self.logger.model_dir, "policy.pth"))
        self.logger.close()

        if last_10_performance:
            final_performance = np.mean(last_10_performance)
        else:
            # No evaluation round ran (fewer epochs than eval_interval):
            # report NaN explicitly rather than via np.mean's empty-slice
            # RuntimeWarning.
            final_performance = float("nan")

        return {"last_10_performance": final_performance}

    def _train_one_epoch(self, e: int) -> int:
        """Run the gradient steps of epoch ``e`` and return how many ran."""
        self.policy.train()

        # Resolve the update method lazily so that a policy without
        # ``learnPRDC`` still works when no epoch index is below ``pre_num``.
        learn_fn = self.policy.learnPRDC if e < self._pre_num else self.policy.learn

        steps_done = 0
        pbar = tqdm(range(self._step_per_epoch), desc=f"Epoch #{e}/{self._epoch}")
        for _ in pbar:
            batch = self.buffer.sample(self._batch_size)

            loss = learn_fn(batch)
            pbar.set_postfix(**loss)

            for k, v in loss.items():
                self.logger.logkv_mean(k, v)

            steps_done += 1

        return steps_done

    def _best_evaluation(self, e: int) -> Dict[str, float]:
        """Evaluate ``num_evaluate`` times and return the best run's summary.

        "Best" is the run with the highest normalized mean episode reward.
        Reward statistics are on the ``get_normalized_score(...) * 100``
        scale; episode-length statistics are raw step counts.
        """
        eval_results = []

        # Run _evaluate(e) `num_evaluate` times and collect the results
        for _ in range(self._num_evaluate):
            eval_info = self._evaluate(e)
            ep_rewards = eval_info["eval/episode_reward"]
            ep_lengths = eval_info["eval/episode_length"]

            norm_ep_rew_mean = self.eval_env.get_normalized_score(np.mean(ep_rewards)) * 100
            # Normalize each episode before taking the spread. Feeding the
            # raw std through get_normalized_score would also subtract the
            # reference offset, which is meaningless for a dispersion value.
            norm_ep_rew_std = np.std(
                [self.eval_env.get_normalized_score(ep_reward) for ep_reward in ep_rewards]
            ) * 100

            eval_results.append({
                "norm_ep_rew_mean": norm_ep_rew_mean,
                "norm_ep_rew_std": norm_ep_rew_std,
                "ep_length_mean": np.mean(ep_lengths),
                "ep_length_std": np.std(ep_lengths)
            })

        # Find the result with the maximum norm_ep_rew_mean
        return max(eval_results, key=lambda x: x['norm_ep_rew_mean'])

    def _evaluate(self, e: int) -> Dict[str, List[float]]:
        """Roll out the policy deterministically for ``eval_episodes`` episodes.

        Args:
            e: Current epoch index. Not used by this implementation; kept so
                existing callers and overrides keep working.

        Returns:
            A dict with the per-episode totals under ``"eval/episode_reward"``
            and the per-episode step counts under ``"eval/episode_length"``.
        """
        self.policy.eval()
        obs = self.eval_env.reset()
        eval_ep_info_buffer = []
        num_episodes = 0
        episode_reward, episode_length = 0, 0

        while num_episodes < self._eval_episodes:
            # The policy receives a single-row batch; the env receives the
            # flattened action.
            action = self.policy.select_action(obs.reshape(1, -1), deterministic=True)
            next_obs, reward, terminal, _ = self.eval_env.step(action.flatten())
            episode_reward += reward
            episode_length += 1

            obs = next_obs

            if terminal:
                # Episode finished: record its totals and start a new one.
                eval_ep_info_buffer.append(
                    {"episode_reward": episode_reward, "episode_length": episode_length}
                )
                num_episodes += 1
                episode_reward, episode_length = 0, 0
                obs = self.eval_env.reset()

        return {
            "eval/episode_reward": [ep_info["episode_reward"] for ep_info in eval_ep_info_buffer],
            "eval/episode_length": [ep_info["episode_length"] for ep_info in eval_ep_info_buffer]
        }
