import time
import os

import numpy as np
import cv2
import torch
import gym
import torch.nn as nn
import torch.nn.functional as F
import imageio
import random

from typing import Optional, Dict, List
from tqdm import tqdm
from collections import deque
from buffer import ReplayBuffer
from logger import Logger
from agent import BasePolicy
from hrl.env import MinigridWrapper
import crafter

# model-free policy trainer
class MFPolicyTrainer:
    """Offline trainer for a high-level policy driving a latent-conditioned low-level policy.

    Every epoch runs ``step_per_epoch`` update steps on batches sampled from
    ``buffer``, steps the optional LR scheduler, evaluates on ``eval_env`` and
    writes a checkpoint into ``logger.log_path``.

    At evaluation time the high-level policy's output is used to index
    ``embed_index_set``; the resulting codebook indices are looked up in
    ``embed_codebook``, passed through ``project_out`` and handed to the
    low-level policy as context for choosing the environment action.
    """

    # Achievement names read from ``info['achievements']`` by ``crafter_eval``.
    # The order is preserved because it determines the order of the log lines.
    _CRAFTER_TASKS = (
        "collect_coal",
        "collect_diamond",
        "collect_drink",
        "collect_iron",
        "collect_sapling",
        "collect_stone",
        "collect_wood",
        "defeat_skeleton",
        "defeat_zombie",
        "eat_cow",
        "eat_plant",
        "make_iron_pickaxe",
        "make_iron_sword",
        "make_stone_pickaxe",
        "make_stone_sword",
        "make_wood_pickaxe",
        "make_wood_sword",
        "place_furnace",
        "place_plant",
        "place_stone",
        "place_table",
        "wake_up",
    )

    def __init__(
        self,
        policy: "BasePolicy",
        eval_env: "gym.Env",
        buffer: "ReplayBuffer",
        logger: "Logger",
        render: bool,
        epoch: int = 1_000_000,
        step_per_epoch: int = 1000,
        batch_size: int = 256,
        eval_episodes: int = 10,
        low_policy: nn.Module = None,
        embed_codebook: torch.Tensor = None,
        project_out: nn.Linear = None,
        embed_index_set: Optional[np.ndarray] = None,
        deterministic: bool = True,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        low_policy_epochs: int = 100,
    ) -> None:
        """Store the training components and hyper-parameters.

        Args:
            policy: High-level policy. Must provide ``train()``, ``eval()``,
                ``learn(batch)``, ``get_action(obs)`` and a ``network``
                attribute whose ``state_dict()`` is checkpointed.
            eval_env: Environment used for evaluation. ``step`` must return
                the 4-tuple ``(obs, reward, done, info)``.
            buffer: Replay buffer providing ``sample(batch_size)``.
            logger: Logger providing ``log``, ``log_var``, ``log_str`` and a
                ``log_path`` directory.
            render: When true, ``_evaluate`` saves one PNG per step under
                ``<log_path>/traj``.
            epoch: Number of training epochs. Converted with ``int()`` so a
                float such as ``1e6`` is accepted.
            step_per_epoch: Number of update steps per epoch.
            batch_size: Batch size requested from ``buffer``.
            eval_episodes: Number of episodes per evaluation.
            low_policy: Low-level policy. Must provide ``train()``,
                ``eval()``, ``learn(...)``, ``get_action(...)`` and an
                ``actor`` attribute whose ``state_dict()`` is checkpointed.
                Defaults to ``None`` but is required by ``train`` and both
                evaluation methods.
            embed_codebook: Weight tensor passed to ``F.embedding``.
            project_out: Module applied to the looked-up embeddings to
                produce the low-level policy's context.
            embed_index_set: Array indexed by the high-level action to obtain
                codebook indices.
            deterministic: Forwarded to ``low_policy.get_action`` in
                ``_evaluate``.
            lr_scheduler: Optional scheduler stepped once per epoch.
            low_policy_epochs: The low-level policy is updated only during
                the first ``low_policy_epochs`` epochs.
        """
        self.policy = policy
        self.eval_env = eval_env
        self.buffer = buffer
        self.logger = logger
        self.render = render

        # ``range`` in ``train`` needs a real int; the historical default was
        # the float literal 1e6, which made the training loop raise TypeError.
        self._epoch = int(epoch)
        self._step_per_epoch = step_per_epoch
        self._batch_size = batch_size
        self._eval_episodes = eval_episodes
        self._low_policy_epochs = low_policy_epochs
        self.lr_scheduler = lr_scheduler
        self.embed_index_set = embed_index_set
        self.low_policy = low_policy
        self.embed_codebook = embed_codebook
        self.project_out = project_out

        self.deterministic = deterministic

    def train(self) -> Dict[str, float]:
        """Run the full training loop.

        Returns:
            ``{"last_10_performance": m}`` where ``m`` is the mean of the
            values recorded for the last (up to) ten evaluations: the
            normalized episode reward, or 0 for ``MinigridWrapper``
            environments. ``m`` is NaN when nothing was recorded, which is
            the case for Crafter environments.
        """
        start_time = time.time()

        num_timesteps = 0
        last_10_performance = deque(maxlen=10)
        # train loop
        for e in range(1, self._epoch + 1):
            self.policy.train()
            self.low_policy.train()
            update_low_policy = e <= self._low_policy_epochs

            pbar = tqdm(range(self._step_per_epoch), desc=f"Epoch #{e}/{self._epoch}")
            for _ in pbar:
                batch = self.buffer.sample(self._batch_size)
                if update_low_policy:
                    low_loss = self.low_policy.learn(
                        batch, self.embed_index_set, self.embed_codebook, self.project_out
                    )
                    if num_timesteps % self._step_per_epoch == 0:
                        self.logger.log(low_loss, num_timesteps)

                loss = self.policy.learn(batch)
                if num_timesteps % 1000 == 0:
                    self.logger.log(loss, num_timesteps)
                num_timesteps += 1

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # evaluate current policy
            if isinstance(self.eval_env, crafter.Env):
                task_success_counter = self.crafter_eval()
                for task, success_count in task_success_counter.items():
                    success_rate = success_count / self._eval_episodes * 100
                    print(f"Task: {task}, Success Rate: {success_rate:.2f}%")
                    self.logger.log_str(f"Task: {task}, Success Rate: {success_rate:.2f}%")
                self.logger.log_str('='*50 + str(e))
                # Checkpoint here as well: the `continue` below skips the
                # save at the end of the loop body, so without this an
                # interrupted Crafter run would leave no weights on disk.
                self._save_checkpoint()
                continue

            eval_info = self._evaluate()
            ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"])
            ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"])

            if isinstance(self.eval_env, MinigridWrapper):
                # No normalized score is computed for this wrapper; zeros are
                # logged as placeholders.
                norm_ep_rew_mean, norm_ep_rew_std = 0, 0
                last_10_performance.append(0)
            else:
                norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100
                # NOTE(review): the std goes through the same normalizer as
                # the mean; if that normalizer subtracts an offset the logged
                # std is shifted by it -- confirm this is intended.
                norm_ep_rew_std = self.eval_env.get_normalized_score(ep_reward_std) * 100
                last_10_performance.append(norm_ep_rew_mean)

            self.logger.log_var("eval/normalized_episode_reward", norm_ep_rew_mean, num_timesteps)
            self.logger.log_var("eval/normalized_episode_reward_std", norm_ep_rew_std, num_timesteps)
            self.logger.log_var("eval/episode_length", ep_length_mean, num_timesteps)
            self.logger.log_var("eval/episode_length_std", ep_length_std, num_timesteps)

            self.logger.log_str("{}:eval/normalized_episode_reward; ep_reward_mean {}, step {}".format(norm_ep_rew_mean, ep_reward_mean, num_timesteps))
            # save checkpoint
            self._save_checkpoint()

        self.logger.log_str("total time: {:.2f}s".format(time.time() - start_time))
        self._save_checkpoint()

        # np.mean of an empty sequence is NaN plus a RuntimeWarning; return
        # the NaN explicitly so runs that record nothing stay quiet.
        if last_10_performance:
            mean_performance = np.mean(last_10_performance)
        else:
            mean_performance = np.nan
        return {"last_10_performance": mean_performance}

    def _save_checkpoint(self) -> None:
        """Write the policy network and low-level actor weights to ``logger.log_path``."""
        log_dir = self.logger.log_path
        torch.save(self.policy.network.state_dict(), os.path.join(log_dir, "policy.pth"))
        torch.save(self.low_policy.actor.state_dict(), os.path.join(log_dir, "low_policy.pth"))

    def _low_level_action(self, obs: np.ndarray, deterministic: bool):
        """Return the low-level policy's action for ``obs``.

        The high-level policy picks an index, ``embed_index_set`` maps it to
        codebook indices, and the projected embeddings become the context
        given to the low-level policy. Runs without gradient tracking.
        """
        action_index = self.policy.get_action(obs)
        with torch.no_grad():
            code_indices = torch.from_numpy(self.embed_index_set[action_index]).to(self.embed_codebook.device)
            action_quantize = F.embedding(code_indices, self.embed_codebook)
            action_context = self.project_out(action_quantize)
            # `.to(action_context)` moves the observation to the context's
            # device and dtype in one call.
            obs_tensor = torch.from_numpy(obs).unsqueeze(0).to(action_context)
            return self.low_policy.get_action(obs_tensor, action_context, deterministic=deterministic)

    def _evaluate(self) -> Dict[str, List[float]]:
        """Roll out ``eval_episodes`` episodes on ``eval_env``.

        Returns:
            A dict with ``"eval/episode_reward"`` and ``"eval/episode_length"``,
            each a list holding one entry per finished episode.
        """
        self.policy.eval()
        # Mode switches are loop-invariant, so do them once up front.
        self.low_policy.eval()
        self.project_out.eval()

        save_traj = None
        if self.render:
            save_traj = os.path.join(self.logger.log_path, "traj")
            os.makedirs(save_traj, exist_ok=True)

        obs = self.eval_env.reset()
        episode_rewards, episode_lengths = [], []
        episode_reward, episode_length = 0, 0

        while len(episode_rewards) < self._eval_episodes:
            action = self._low_level_action(obs, self.deterministic)
            next_obs, reward, terminal, _ = self.eval_env.step(action.detach().cpu().numpy().flatten())

            if self.render:
                image = self.eval_env.render(mode="rgb_array")
                # NOTE(review): files are named by the step index within the
                # episode, so later episodes overwrite earlier frames.
                # NOTE(review): cv2.imwrite expects BGR channel order; if
                # render() returns RGB the saved colours are swapped -- confirm.
                cv2.imwrite(os.path.join(save_traj, '{}.png'.format(episode_length)), image)

            episode_reward += reward
            episode_length += 1
            obs = next_obs

            if terminal:
                episode_rewards.append(episode_reward)
                episode_lengths.append(episode_length)
                episode_reward, episode_length = 0, 0
                obs = self.eval_env.reset()

        return {
            "eval/episode_reward": episode_rewards,
            "eval/episode_length": episode_lengths,
        }

    def crafter_eval(self) -> Dict[str, int]:
        """Roll out ``eval_episodes`` Crafter episodes and count unlocked achievements.

        Returns:
            A dict mapping every name in ``_CRAFTER_TASKS`` to the number of
            episodes whose final ``info['achievements']`` reports a positive
            count for it.
        """
        # Evaluate in eval mode, matching `_evaluate`.
        self.policy.eval()
        self.low_policy.eval()
        self.project_out.eval()

        task_success_counter = {task: 0 for task in self._CRAFTER_TASKS}

        for _ in range(self._eval_episodes):
            # Observations are transposed with axes (2, 0, 1) before being
            # given to the policies.
            obs = np.transpose(self.eval_env.reset(), (2, 0, 1))
            done = False

            while not done:
                action = self._low_level_action(obs, deterministic=True)
                # NOTE(review): the action is passed to the env exactly as
                # returned by the low-level policy (no numpy conversion,
                # unlike `_evaluate`) -- confirm the env accepts it.
                next_obs, _, done, info = self.eval_env.step(action)
                obs = np.transpose(next_obs, (2, 0, 1))

            # Only the info dict of the episode's final step is inspected.
            achievements = info['achievements']
            for task in self._CRAFTER_TASKS:
                if achievements.get(task, 0) > 0:
                    task_success_counter[task] += 1

        return task_success_counter