import time
import os

import numpy as np
import torch
import gym

from typing import Optional, Dict, List
from tqdm import tqdm
from collections import deque
from dataset.buffer import ReplayBuffer
from utils.logger import Logger
from policies import BasePolicy
from reward.reward_model import RewardModel
from dataset.load_preference_dataset import load_preference_dataset
from dataset.generate_preference_data import collect_preference_data

# model-free policy trainer
class PolicyTrainer:
    """Train a policy from a replay buffer, optionally refining the reward model.

    Each training step samples a batch from ``real_buffer`` and calls
    ``policy.learn``. When ``args.update_reward`` is truthy, the trainer also
    periodically rolls the policy out into ``fake_buffer``, builds a synthetic
    preference dataset from it, retrains the reward model on real + synthetic
    preferences, and relabels ``real_buffer`` with the updated reward model.
    After every epoch the policy is evaluated on ``eval_env`` and a checkpoint
    is written.

    Args:
        args: Run configuration. This class reads ``update_reward``,
            ``max_reward_steps``, ``rollout_batch_size``, ``rollout_length``,
            ``reward_update_freq``, ``fake_ratio``, ``reward_train_epoch`` and
            ``reward_train_batch_size`` from it, and forwards it unchanged to
            ``collect_preference_data`` / ``load_preference_dataset``.
        policy: Policy being trained (``train``, ``eval``, ``learn``,
            ``rollout``, ``select_action`` and ``state_dict`` are used).
        reward: Reward model retrained during reward updates and used to
            (re)label both buffers.
        eval_env: Environment used for evaluation. Uses the classic gym API
            (``reset()`` returns the observation, ``step()`` returns a
            4-tuple) and must provide ``get_normalized_score``.
        real_buffer: Buffer the policy batches are sampled from.
        fake_buffer: Buffer that receives policy rollouts.
        logger: Project logger used for metrics, checkpoint/model directories.
        epoch: Number of training epochs.
        step_per_epoch: Number of policy updates per epoch.
        batch_size: Batch size sampled from ``real_buffer`` per update.
        eval_episodes: Number of evaluation episodes per epoch.
        lr_scheduler: Optional scheduler, stepped once per epoch.
        rollout_freq: Rollouts are collected every ``rollout_freq`` timesteps.
        num_query: Forwarded to the preference-data helpers.
        len_query: Forwarded to the preference-data helpers.
    """

    def __init__(
        self,
        args,
        policy: BasePolicy,
        reward: RewardModel,
        eval_env: gym.Env,
        real_buffer: ReplayBuffer,
        fake_buffer: ReplayBuffer,
        logger: Logger,
        epoch: int = 1000,
        step_per_epoch: int = 1000,
        batch_size: int = 256,
        eval_episodes: int = 10,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        rollout_freq: int = 1000,
        num_query: int = 2000,
        len_query: int = 200
    ) -> None:
        self.args = args
        self.policy = policy
        self.reward = reward
        self.eval_env = eval_env
        self.real_buffer = real_buffer
        self.fake_buffer = fake_buffer
        self.logger = logger

        self._epoch = epoch
        self._step_per_epoch = step_per_epoch
        self._batch_size = batch_size
        self._eval_episodes = eval_episodes
        self._rollout_freq = rollout_freq
        self._num_query = num_query
        self._len_query = len_query
        self.lr_scheduler = lr_scheduler

    def train(self, pref_real_dataset) -> Dict[str, float]:
        """Run the full training loop and return a summary of final performance.

        Args:
            pref_real_dataset: Real preference dataset, passed to
                ``reward.train`` as ``init_pref_real_dataset`` on every
                reward-model update.

        Returns:
            ``{"last_10_performance": m}`` where ``m`` is the mean of the
            normalized mean episode reward over the last (up to) 10 epochs.

        Side effects:
            Writes ``policy.pth`` to ``logger.checkpoint_dir`` after every
            epoch and to ``logger.model_dir`` at the end, then closes the
            logger.
        """
        start_time = time.time()
        num_timesteps = 0
        last_10_performance = deque(maxlen=10)

        for e in range(1, self._epoch + 1):
            self.policy.train()
            pbar = tqdm(range(self._step_per_epoch), desc=f"Epoch #{e}/{self._epoch}")

            for _ in pbar:
                if self.args.update_reward:
                    self._maybe_update_reward(num_timesteps, pref_real_dataset)

                # One policy update on a batch of (possibly relabelled) data.
                batch = self.real_buffer.sample(self._batch_size)
                loss = self.policy.learn(batch)
                pbar.set_postfix(**loss)

                for k, v in loss.items():
                    self.logger.logkv_mean(k, v)

                num_timesteps += 1

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # Evaluate the current policy and dump this epoch's metrics.
            last_10_performance.append(self._log_evaluation(num_timesteps))

            # Save checkpoint (overwritten every epoch).
            torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth"))

        self.logger.log("total time: {:.2f}s".format(time.time() - start_time))
        torch.save(self.policy.state_dict(), os.path.join(self.logger.model_dir, "policy.pth"))
        self.logger.close()

        return {"last_10_performance": np.mean(last_10_performance)}

    def _maybe_update_reward(self, num_timesteps: int, pref_real_dataset) -> None:
        """Collect rollouts and/or retrain the reward model if due at this step.

        Nothing happens once ``num_timesteps`` exceeds
        ``args.max_reward_steps``. Rollouts are collected every
        ``rollout_freq`` steps (including step 0); the reward model is
        retrained every ``args.reward_update_freq`` steps, excluding step 0.
        """
        if num_timesteps > self.args.max_reward_steps:
            return

        # Collect policy rollouts starting from states sampled from the real buffer.
        if num_timesteps % self._rollout_freq == 0:
            init_obss = self.real_buffer.sample(self.args.rollout_batch_size)["observations"].cpu().numpy()
            rollout_transitions = self.policy.rollout(init_obss, self.args.rollout_length)
            self.fake_buffer.add_batch(**rollout_transitions)

        # Train the reward model.
        if num_timesteps % self.args.reward_update_freq == 0 and num_timesteps != 0:
            # Label the rollout data with the current reward model.
            self.fake_buffer.predict_reward(reward_model=self.reward)
            fake_dataset = self.fake_buffer.sample_all()

            # Generate, then load, the synthetic (non-human-labelled)
            # preference dataset. The two helpers receive identical
            # arguments; presumably the first persists what the second
            # reads back -- confirm against their implementations.
            collect_preference_data(
                args=self.args,
                dataset=fake_dataset,
                num_query=self._num_query,
                len_query=self._len_query,
                human_label=False
            )
            pref_fake_dataset = load_preference_dataset(
                args=self.args,
                dataset=fake_dataset,
                num_query=self._num_query,
                len_query=self._len_query,
                human_label=False
            )

            # Start training the reward model on real + synthetic preferences.
            self.reward.train(
                init_pref_real_dataset=pref_real_dataset,
                init_pref_fake_dataset=pref_fake_dataset,
                fake_ratio=self.args.fake_ratio,
                n_epochs=self.args.reward_train_epoch,
                logger=self.logger,
                batch_size=self.args.reward_train_batch_size
            )

            # Relabel the rewards of the offline dataset with the updated model.
            self.real_buffer.predict_reward(reward_model=self.reward)

    def _log_evaluation(self, num_timesteps: int) -> float:
        """Evaluate the policy, log the epoch metrics, and return the normalized mean reward."""
        eval_info = self._evaluate()
        ep_rewards = eval_info["eval/episode_reward"]
        ep_lengths = eval_info["eval/episode_length"]

        norm_ep_rew_mean, norm_ep_rew_std = self._normalized_reward_stats(
            ep_rewards, self.eval_env.get_normalized_score
        )
        ep_length_mean, ep_length_std = np.mean(ep_lengths), np.std(ep_lengths)

        self.logger.logkv("eval/normalized_episode_reward", norm_ep_rew_mean)
        self.logger.logkv("eval/normalized_episode_reward_std", norm_ep_rew_std)
        self.logger.logkv("eval/episode_length", ep_length_mean)
        self.logger.logkv("eval/episode_length_std", ep_length_std)
        self.logger.set_timestep(num_timesteps)
        self.logger.dumpkvs(exclude=["reward_training_progress", "transition_training_progress"])

        return norm_ep_rew_mean

    @staticmethod
    def _normalized_reward_stats(episode_rewards: List[float], normalize_fn):
        """Return ``(mean, std)`` of the normalized episode rewards, scaled by 100.

        Args:
            episode_rewards: Raw return of each evaluation episode.
            normalize_fn: Callable mapping one raw return to its normalized
                score (e.g. ``eval_env.get_normalized_score``).

        The std is taken over the per-episode normalized scores. Passing the
        raw std through ``normalize_fn`` would wrongly apply the normalizer's
        offset to a spread, which is only meaningful for returns.
        """
        norm_mean = normalize_fn(np.mean(episode_rewards)) * 100
        norm_std = np.std([normalize_fn(r) for r in episode_rewards]) * 100
        return norm_mean, norm_std

    def _evaluate(self) -> Dict[str, List[float]]:
        """Run ``eval_episodes`` deterministic episodes on ``eval_env``.

        Returns:
            A dict with ``"eval/episode_reward"`` and ``"eval/episode_length"``,
            each a list with one entry per completed episode.
        """
        self.policy.eval()
        obs = self.eval_env.reset()
        eval_ep_info_buffer = []
        num_episodes = 0
        episode_reward, episode_length = 0, 0

        while num_episodes < self._eval_episodes:
            # The policy is given a single-row batch; the action is flattened for the env.
            action = self.policy.select_action(obs.reshape(1, -1), deterministic=True)
            next_obs, reward, terminal, _ = self.eval_env.step(action.flatten())
            episode_reward += reward
            episode_length += 1

            obs = next_obs

            if terminal:
                eval_ep_info_buffer.append(
                    {"episode_reward": episode_reward, "episode_length": episode_length}
                )
                num_episodes += 1
                episode_reward, episode_length = 0, 0
                obs = self.eval_env.reset()

        return {
            "eval/episode_reward": [ep_info["episode_reward"] for ep_info in eval_ep_info_buffer],
            "eval/episode_length": [ep_info["episode_length"] for ep_info in eval_ep_info_buffer]
        }